Decision trees are widely used in machine learning. We'll assume you are already familiar with the concept of decision trees and that you've just trained a tree-based model!
Note: If not, you can read our in-depth guide, "Decision Trees in Python with Scikit-Learn".
Now, it is time to try to explain how the tree has reached a decision. This means it is necessary to take a look under the hood, and see how the model has split the data internally to build the tree.
To be able to plot the resulting tree, let's create one. First, we'll load a toy wine dataset and divide it into train and test sets:
```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

SEED = 42

data = datasets.load_wine()
X = data.data
y = data.target

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=SEED)
```
Now that the toy data has been divided, we can fit the Decision Tree model:
```python
from sklearn.tree import DecisionTreeClassifier

dt = DecisionTreeClassifier(max_depth=4, random_state=SEED)
dt.fit(X_train, y_train)
```
Great! Notice that we have defined a maximum depth of 4; this means the generated tree will have at most 5 levels (the root node plus four levels of splits). This helps with interpretability when plotting, since we'll only have 5 levels to read through.
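To confirm this, the fitted estimator exposes `get_depth()` and `get_n_leaves()`. Here is a quick, self-contained sketch (recreating the same `dt` as above) that checks the depth never exceeds the limit we set:

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

SEED = 42
data = datasets.load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=SEED)

dt = DecisionTreeClassifier(max_depth=4, random_state=SEED).fit(X_train, y_train)

# The learned depth never exceeds max_depth (it can be smaller
# if the data separates earlier)
print(dt.get_depth())            # depth of the deepest leaf
print(dt.get_n_leaves())         # number of leaf nodes
print(dt.score(X_test, y_test))  # test-set accuracy
```

The depth reported here counts edges from the root, which is why `max_depth=4` can still produce five levels of nodes.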
Now, to plot the tree and see the underlying splits made by the model, we'll use Scikit-Learn's plot_tree() function and matplotlib to define a size for the plot. You pass the fitted model into plot_tree() as the main argument. We will also pass the feature and class names, and customize the plot so that each tree node is displayed with rounded edges, filled with colors according to the classes, and showing the proportion of each class in each node:
```python
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

features = data.feature_names
classes = data.target_names

plt.figure(figsize=(10, 8))
plot_tree(dt,
          feature_names=features,
          class_names=classes,
          rounded=True,      # Rounded node edges
          filled=True,       # Color nodes according to class
          proportion=True);  # Show class proportions instead of raw sample counts
```
That's it! This is the plotted underlying tree for this model.
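If you prefer a text summary over a figure (for logs or quick inspection in a terminal), Scikit-Learn also offers `export_text()`, which prints each split as an indented rule. A minimal sketch, rebuilding the same fitted tree so it runs standalone:

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

SEED = 42
data = datasets.load_wine()
X_train, _, y_train, _ = train_test_split(
    data.data, data.target, random_state=SEED)

dt = DecisionTreeClassifier(max_depth=4, random_state=SEED).fit(X_train, y_train)

# Each line is a rule such as "|--- color_intensity <= ...",
# with leaves labeled by their predicted class
report = export_text(dt, feature_names=list(data.feature_names))
print(report)
```

This produces the same splits as the plot, just rendered as nested if/else rules.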
Notice that the first characteristic the tree considers to differentiate between wines is color intensity, followed by the amounts of proline and flavanoids. Also, since we have filled the tree, class_0 nodes are orange, class_1 nodes are green, and class_2 nodes are purple.
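The same ranking can be read programmatically from the model's `feature_importances_` attribute (impurity-based importances, one value per feature, summing to 1). A short sketch that rebuilds the tree and lists the top features:

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

SEED = 42
data = datasets.load_wine()
X_train, _, y_train, _ = train_test_split(
    data.data, data.target, random_state=SEED)

dt = DecisionTreeClassifier(max_depth=4, random_state=SEED).fit(X_train, y_train)

# Sort feature indices from most to least important
order = np.argsort(dt.feature_importances_)[::-1]
for idx in order[:3]:
    print(data.feature_names[idx], round(dt.feature_importances_[idx], 3))
```

This gives a quick numeric counterpart to reading the top levels of the plotted tree.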
Data Scientist, Research Software Engineer, and teacher. Cassia is passionate about transformative processes in data, technology, and life. She holds degrees in Philosophy and Information Systems, along with a Stricto Sensu Master's degree in Foundations of Mathematics.