Plot Decision Trees Using Python and Scikit-Learn


Decision trees are widely used in machine learning. We'll assume you are already familiar with the concept of decision trees and that you've just trained your tree-based algorithm!

Advice: If not, you can read our in-depth "Decision Trees in Python with Scikit-Learn" guide.

Now, it is time to explain how the tree reached its decisions. This means taking a look under the hood to see how the model split the data internally to build the tree.

To be able to plot the resulting tree, let's create one. First, we'll load a toy wine dataset and divide it into train and test sets:

from sklearn import datasets
from sklearn.model_selection import train_test_split

SEED = 42

data = datasets.load_wine()
X = data.data
y = data.target

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=SEED)

Now that the toy data has been divided, we can fit the Decision Tree model:

from sklearn.tree import DecisionTreeClassifier

dt = DecisionTreeClassifier(max_depth=4,
                            random_state=SEED)
dt.fit(X_train, y_train)

Great! Notice that we have defined a maximum depth of 4, which means the generated tree will have at most 5 levels (the root plus 4 levels of splits). This will help with interpretability when plotting, since we'll only have 5 levels to read through.
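If you want to sanity-check that the depth cap took effect, a minimal sketch (reusing the dataset and split from above) can query the fitted estimator directly; `get_depth()` and `tree_.node_count` are part of scikit-learn's tree API:

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

SEED = 42

data = datasets.load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=SEED)

dt = DecisionTreeClassifier(max_depth=4, random_state=SEED)
dt.fit(X_train, y_train)

# The fitted tree never exceeds the depth cap we set
print(dt.get_depth())       # at most 4
print(dt.tree_.node_count)  # total number of nodes in the tree
```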


Now, to plot the tree and see the underlying splits made by the model, we'll use Scikit-Learn's plot_tree() function and Matplotlib to define a size for the plot.

You pass the fitted model into plot_tree() as the main argument. We will also pass the feature and class names, and customize the plot so that each tree node is displayed with rounded edges, filled with a color according to its class, and showing the proportion of each class in each node:

from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

features = data.feature_names
classes = data.target_names

plt.figure(figsize=(10, 8))
plot_tree(dt,
          feature_names=features,
          class_names=classes,
          rounded=True,     # Rounded node edges
          filled=True,      # Color nodes according to class
          proportion=True); # Show class proportions instead of raw sample counts

That's it! This is the plotted underlying tree for this model.

Notice that the first characteristic the tree considers to differentiate between wines is color intensity, followed by the amount of proline and flavanoids. Also, since we have filled the tree, class_0 nodes are orange, class_1 nodes are green, and class_2 nodes are purple.
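If you'd rather read the same splits as text, scikit-learn's export_text() renders a fitted tree as indented decision rules. A minimal sketch, reusing the dataset and model from above:

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

SEED = 42

data = datasets.load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=SEED)

dt = DecisionTreeClassifier(max_depth=4, random_state=SEED).fit(X_train, y_train)

# Render the splits as indented if/else-style rules
rules = export_text(dt, feature_names=list(data.feature_names))
print(rules)
```

Each line shows a threshold on a feature (e.g. a cut on color_intensity), which can be easier to paste into a report than a figure.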

See what else you can do with plot_tree() in the documentation!

Cássia Sampaio, Author

Data Scientist, Research Software Engineer, and teacher. Cássia is passionate about transformative processes in data, technology, and life. She holds degrees in Philosophy and Information Systems, and a Stricto Sensu Master's degree in the Foundations of Mathematics.


© 2013-2022 Stack Abuse. All rights reserved.
