make_dt_regressor_plot
er_evaluation.plots.make_dt_regressor_plot(error_metrics, weights, features_df, numerical_features, categorical_features, type='sunburst', criterion='squared_error', **kwargs)
Fit a weighted decision tree regressor that predicts cluster-wise error metrics from cluster features, and create an interactive visualization of the resulting tree (a sunburst chart by default).
Parameters:
- error_metrics (Series): Cluster-wise error metrics.
- weights (Series): The sample weights to use during model fitting.
- features_df (DataFrame): The input features for each cluster as a pandas DataFrame.
- numerical_features (list): A list of column names corresponding to the numerical features in the DataFrame.
- categorical_features (list): A list of column names corresponding to the categorical features in the DataFrame.
- type (str): The type of chart to create. Defaults to 'sunburst'.
- criterion (str): The criterion used by the decision tree regressor to measure split quality. Defaults to 'squared_error'.
- **kwargs: Additional keyword arguments to pass to the fit_dt_regressor function.
Returns: plotly.graph_objs._sunburst.Sunburst: An interactive sunburst chart visualization of the fitted decision tree.
Examples
>>> import pandas as pd
>>> import numpy as np
>>> from er_evaluation.error_analysis import error_indicator
>>> from er_evaluation.plots import make_dt_regressor_plot
>>> prediction = pd.Series([0, 1, 1])
>>> reference = pd.Series([0, 1, 0])
>>> y = error_indicator(prediction, reference)
>>> weights = np.array([1, 1])
>>> features_df = pd.DataFrame({'feature1': [1, 2], 'feature2': [4, 5]})
>>> numerical_features = ['feature1']
>>> categorical_features = ['feature2']
>>> fig = make_dt_regressor_plot(y, weights, features_df, numerical_features, categorical_features)