plot_dt_regressor_sunburst

er_evaluation.plots.plot_dt_regressor_sunburst(dt_regressor, X, y, feature_names, weights=None, label='Value', color_function=None)

Creates a sunburst plot of a decision tree regressor.

Parameters:
  • dt_regressor (DecisionTreeRegressor) – A fitted decision tree regressor model.

  • X (numpy array or pandas DataFrame) – The input features used to fit the model.

  • y (numpy array or pandas Series) – The target values used to fit the model.

  • feature_names (list of str) – The names of the input features.

  • weights (pandas Series, optional) – Sampling weights for the values in y. Default is None.

  • label (str, optional) – The label for the color scale. Default is “Value”.

  • color_function (function, optional) – A function applied to the subset of y values that fall within each node; its return value determines that node's color. If None, each node's predicted value is used as its color. Default is None.

Returns:

A sunburst plot of the decision tree regressor.

Return type:

plotly.graph_objs.Figure

Examples

>>> import numpy as np
>>> from sklearn.tree import DecisionTreeRegressor
>>> from er_evaluation.plots import plot_dt_regressor_sunburst
>>> X = np.array([[1], [2], [3], [4], [5]])
>>> y = np.array([0, 1, 0, 1, 0])
>>> dt_regressor = DecisionTreeRegressor(max_depth=2)
>>> _ = dt_regressor.fit(X, y)  # fit() returns the estimator; discard it to keep doctest output clean
>>> feature_names = ['x']
>>> fig = plot_dt_regressor_sunburst(dt_regressor, X, y, feature_names)
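The color_function parameter operates on the subset of y values within each node. The sketch below illustrates one way to recover those per-node subsets with scikit-learn's DecisionTreeRegressor.decision_path, using the same toy data as the example above. The helper node_color_values is hypothetical: it is not part of er_evaluation and is not claimed to be how plot_dt_regressor_sunburst works internally. It assumes a fitted single-output regressor and ignores sample weights.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def node_color_values(dt_regressor, X, y, color_function=None):
    """Hypothetical helper (not part of er_evaluation): one color value per tree node.

    Assumes a fitted, single-output regressor and ignores sample weights.
    """
    y = np.asarray(y)
    # decision_path returns a sparse (n_samples, n_nodes) indicator matrix:
    # entry (i, j) is nonzero when sample i passes through node j.
    membership = dt_regressor.decision_path(X).toarray().astype(bool)
    if color_function is None:
        # Fall back to each node's predicted value, stored in tree_.value.
        return dt_regressor.tree_.value[:, 0, 0]
    # Apply color_function to the y values of the samples that reach each node.
    return np.array(
        [color_function(y[membership[:, node]]) for node in range(membership.shape[1])]
    )


X = np.array([[1], [2], [3], [4], [5]])
y = np.array([0, 1, 0, 1, 0])
dt_regressor = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)

default_colors = node_color_values(dt_regressor, X, y)        # predicted value per node
spread_colors = node_color_values(dt_regressor, X, y, np.std)  # spread of y per node
```

On the training data, passing np.mean reproduces the default colors: with the default squared-error criterion and no sample weights, each node's prediction is the mean of the training targets that reach it. A function such as np.std instead highlights nodes whose targets are heterogeneous.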