Plot Decision Trees Using Python and Scikit-Learn
Decision trees are widely used in machine learning problems. We'll assume you are already familiar with the concept of decision trees and that you've just trained your tree-based algorithm!
Advice: If not, you can read our in-depth guide, "Decision Trees in Python with Scikit-Learn".
Now, it is time to explain how the tree has reached a decision. This means taking a look under the hood to see how the model has split the data internally to build the tree.
To be able to plot the resulting tree, let's create one. First, we'll load a toy wine dataset and divide it into train and test sets:
from sklearn import datasets
from sklearn.model_selection import train_test_split
SEED = 42
data = datasets.load_wine()
X = data.data
y = data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=SEED)
Now that the toy data has been divided, we can fit the Decision Tree model:
from sklearn.tree import DecisionTreeClassifier

dt = DecisionTreeClassifier(max_depth=4,
                            random_state=SEED)
dt.fit(X_train, y_train)
Great! Notice that we have defined a maximum depth of 4; this means the generated tree will have at most 5 levels of nodes (the root plus up to 4 levels of splits). This will help with interpretability when plotting, since we'll only have a few levels to read through.
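As a quick sanity check, you can confirm the fitted tree's actual depth and number of leaves before plotting. Here is a minimal sketch, reusing the same dataset and seed as above (the actual depth can come out smaller than `max_depth`, but never larger):

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

SEED = 42
X, y = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=SEED)

dt = DecisionTreeClassifier(max_depth=4, random_state=SEED)
dt.fit(X_train, y_train)

# get_depth() reports the depth the tree actually reached (<= max_depth)
print(dt.get_depth())
# get_n_leaves() reports how many terminal nodes the tree ended up with
print(dt.get_n_leaves())
```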
Now, to plot the tree and see the underlying splits made by the model, we'll use Scikit-Learn's plot_tree() function, along with Matplotlib to define a size for the plot. You pass the fitted model into plot_tree() as the main argument. We will also pass the feature and class names, and customize the plot so that each tree node is displayed with rounded edges, filled with colors according to the classes, and showing the proportion of each class in each node:
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
features = data.feature_names
classes = data.target_names
plt.figure(figsize=(10, 8))
plot_tree(dt,
          feature_names=features,
          class_names=classes,
          rounded=True,     # Rounded node edges
          filled=True,      # Adds color according to class
          proportion=True)  # Displays class proportions instead of raw sample counts
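If you want to keep the figure for a report rather than only display it, you can save it to disk with Matplotlib's savefig(). A minimal sketch, assuming a hypothetical output filename of wine_tree.png:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this also works in scripts
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree

SEED = 42
data = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=SEED)
dt = DecisionTreeClassifier(max_depth=4, random_state=SEED).fit(X_train, y_train)

plt.figure(figsize=(10, 8))
plot_tree(dt,
          feature_names=data.feature_names,
          class_names=data.target_names,
          rounded=True, filled=True, proportion=True)
# "wine_tree.png" is an example filename; bbox_inches="tight" trims the margins
plt.savefig("wine_tree.png", dpi=150, bbox_inches="tight")
```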
That's it! This is the underlying tree plotted for this model.
Notice that the first characteristic the tree considers to differentiate between wines is the color intensity, followed by the amounts of proline and flavanoids. Also, since we have filled the tree, class_0 nodes are orange, class_1 nodes are green, and class_2 nodes are purple.
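You can also verify this reading of the plot numerically: a fitted tree exposes feature_importances_, which scores how much each feature contributed to the splits. A short sketch, again reusing the same data and seed:

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

SEED = 42
data = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=SEED)
dt = DecisionTreeClassifier(max_depth=4, random_state=SEED).fit(X_train, y_train)

# Pair each feature name with its importance and sort from most to least important
ranked = sorted(zip(data.feature_names, dt.feature_importances_),
                key=lambda pair: pair[1], reverse=True)
for name, importance in ranked[:5]:
    print(f"{name}: {importance:.3f}")
```

The importances sum to 1, so the top few entries tell you which features dominate the tree's decisions.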
See what else you can do with plot_tree() in the documentation!
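If you prefer a plain-text view of the same splits, Scikit-Learn also provides export_text(), which renders the tree as indented if/else-style rules. A minimal sketch under the same setup:

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

SEED = 42
data = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=SEED)
dt = DecisionTreeClassifier(max_depth=4, random_state=SEED).fit(X_train, y_train)

# export_text() returns a string describing each split threshold per branch
rules = export_text(dt, feature_names=list(data.feature_names))
print(rules)
```

This is handy for logs or code reviews, where an image is inconvenient.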
Data Scientist, Research Software Engineer, and teacher. Cassia is passionate about transformative processes in data, technology, and life. She holds degrees in Philosophy and Information Systems, and a Stricto Sensu Master's Degree in the Foundations of Mathematics.