Matplotlib Scatter Plot with Distribution Plots (Joint Plot) - Tutorial and Examples

Introduction

There are many data visualization libraries in Python, but Matplotlib is the most popular of them. Its popularity comes from its reliability and utility: it can create both simple and complex plots with little code, and those plots can be customized in a variety of ways.

In this tutorial, we'll cover how to plot a Joint Plot in Matplotlib, which consists of a Scatter Plot and two Distribution Plots (one along each axis) on the same Figure.

Joint Plots are used to explore the relationship between two variables (bivariate data) and the distribution of each variable at the same time.

Note: This sort of task is much better suited to libraries such as Seaborn, which has a built-in jointplot() function. With Matplotlib, we'll construct a Joint Plot manually, using GridSpec and multiple Axes objects, instead of having Seaborn do it for us.

Importing Data

We'll use the famous Iris Dataset, since it lets us explore the relationship between features such as SepalWidthCm and SepalLengthCm through a Scatter Plot and, at the same time, explore how sepal length and width are distributed across each Species through Distribution Plots.

Let's import the dataset and take a peek:

import pandas as pd

df = pd.read_csv('iris.csv')
print(df.head())

This results in:

   Id  SepalLengthCm  SepalWidthCm  PetalLengthCm  PetalWidthCm      Species
0   1            5.1           3.5            1.4           0.2  Iris-setosa
1   2            4.9           3.0            1.4           0.2  Iris-setosa
2   3            4.7           3.2            1.3           0.2  Iris-setosa
3   4            4.6           3.1            1.5           0.2  Iris-setosa
4   5            5.0           3.6            1.4           0.2  Iris-setosa

We'll be exploring the bivariate relationship between the SepalLengthCm and SepalWidthCm features here, but also their distributions. We can approach this in two ways - with respect to their Species or not.

We can disregard the Species feature entirely and simply plot a histogram of each feature across all flower instances. Alternatively, we can color-code the instances and plot a separate distribution for each Species, highlighting the differences between them.

We'll explore both options here, starting with the simpler one - disregarding the Species altogether.

Plot a Joint Plot in Matplotlib with Single-Class Histograms

In the first approach, we'll just load in the flower instances and plot them as-is, with no regard to their Species.

We'll be using a GridSpec to customize our figure's layout, making space for three different plots, each on its own Axes instance.

To use the GridSpec constructor, we'll import it alongside the pyplot module:

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

Now, let's create our Figure and create the Axes objects:

df = pd.read_csv('iris.csv')

fig = plt.figure()
gs = GridSpec(4, 4)

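ax_scatter = fig.add_subplot(gs[1:4, 0:3])  # bottom-left 3x3 block
ax_hist_x = fig.add_subplot(gs[0, 0:3])     # top row, above the scatter plot
ax_hist_y = fig.add_subplot(gs[1:4, 3])     # right-most column, beside the scatter plot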

plt.show()

We've created three Axes instances by adding subplots to the figure, using our GridSpec instance to position them. This results in a Figure with three empty Axes:

[Figure: Matplotlib GridSpec layout for a joint plot, with three empty Axes]
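To check which grid cells each slice claims, you can ask the resulting SubplotSpec objects for their spans. This is a minimal sketch, assuming Matplotlib 3.2 or newer, where SubplotSpec exposes rowspan and colspan as range objects; the role names in the dictionary are just descriptive labels:

```python
from matplotlib.gridspec import GridSpec

gs = GridSpec(4, 4)  # 4 rows x 4 columns; row 0 is the top row

# The same slices used above, keyed by the role each Axes plays
layout = {
    "scatter": gs[1:4, 0:3],        # bottom-left 3x3 block
    "top histogram": gs[0, 0:3],    # top row, above the scatter plot
    "right histogram": gs[1:4, 3],  # right-most column, beside the scatter plot
}

for role, spec in layout.items():
    # rowspan and colspan are range objects listing the grid cells covered
    print(f"{role}: rows {list(spec.rowspan)}, cols {list(spec.colspan)}")
```

This prints rows [1, 2, 3] and cols [0, 1, 2] for the scatter block, rows [0] and cols [0, 1, 2] for the top histogram, and rows [1, 2, 3] and cols [3] for the right histogram. The top-right cell, gs[0, 3], is deliberately left empty.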

Now that we've got the layout and positioning in place, all we have to do is plot the data on our Axes. Let's update the script so that we plot the SepalLengthCm and SepalWidthCm features against each other as a Scatter Plot on the ax_scatter Axes, and the distribution of each feature on the ax_hist_x (top) and ax_hist_y (right) Axes:

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

df = pd.read_csv('iris.csv')

fig = plt.figure()
gs = GridSpec(4, 4)

ax_scatter = fig.add_subplot(gs[1:4, 0:3])
ax_hist_x = fig.add_subplot(gs[0,0:3])
ax_hist_y = fig.add_subplot(gs[1:4, 3])

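# Scatter plot of the two features against each other
ax_scatter.scatter(df['SepalLengthCm'], df['SepalWidthCm'])

# Top histogram: x-axis feature; right histogram: y-axis feature, drawn sideways
ax_hist_x.hist(df['SepalLengthCm'])
ax_hist_y.hist(df['SepalWidthCm'], orientation='horizontal')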

plt.show()

We've passed orientation='horizontal' to the histogram on ax_hist_y so that its bars extend sideways, matching the tall, narrow Axes that the GridSpec placed on the right-hand side of the Scatter Plot:

[Figure: simple Matplotlib joint plot with single-class histograms]

This results in a Joint Plot of the relationship between the SepalLengthCm and SepalWidthCm features, as well as the distributions for the respective features.
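The marginal histograms above compute their own axis limits, so they are not guaranteed to line up with the Scatter Plot. A minimal sketch of one way to keep them aligned, using the sharex and sharey arguments of add_subplot(). The joint_plot() helper is hypothetical, and the sample values are the first five rows of SepalLengthCm and SepalWidthCm from the df.head() output above:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec


def joint_plot(x, y):
    """Scatter plot of x vs. y with marginal histograms (hypothetical helper)."""
    fig = plt.figure()
    gs = GridSpec(4, 4, figure=fig)

    ax_scatter = fig.add_subplot(gs[1:4, 0:3])
    # sharex/sharey tie the histogram limits to the scatter plot's limits
    ax_hist_x = fig.add_subplot(gs[0, 0:3], sharex=ax_scatter)
    ax_hist_y = fig.add_subplot(gs[1:4, 3], sharey=ax_scatter)

    ax_scatter.scatter(x, y)
    ax_hist_x.hist(x)
    ax_hist_y.hist(y, orientation="horizontal")

    # The shared tick labels would only repeat those of the scatter plot
    ax_hist_x.tick_params(labelbottom=False)
    ax_hist_y.tick_params(labelleft=False)
    return fig, ax_scatter, ax_hist_x, ax_hist_y


# First five rows of SepalLengthCm / SepalWidthCm from the df.head() output
sepal_length = [5.1, 4.9, 4.7, 4.6, 5.0]
sepal_width = [3.5, 3.0, 3.2, 3.1, 3.6]
fig, ax_scatter, ax_hist_x, ax_hist_y = joint_plot(sepal_length, sepal_width)
```

In the full script you would pass df['SepalLengthCm'] and df['SepalWidthCm'] instead, then call plt.show() with an interactive backend or fig.savefig() to render the result.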

Plot a Joint Plot in Matplotlib with Multiple-Class Histograms

Now, another case we might want to explore is the distribution of these features with respect to the Species of the flower, since the species may well affect the range of sepal lengths and widths.

For this, instead of one histogram per axis containing all flower instances, we'll overlay one histogram per Species on each axis.

To do this, we'll first have to split the DataFrame we've been using by flower Species:

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

df = pd.read_csv('iris.csv')

setosa = df[df['Species']=='Iris-setosa']
virginica = df[df['Species']=='Iris-virginica']
versicolor = df[df['Species']=='Iris-versicolor']
species = df['Species']
colors = {
    'Iris-setosa': 'tab:blue',
    'Iris-versicolor': 'tab:red',
    'Iris-virginica': 'tab:green'
}

Here, we've filtered the DataFrame by the Species feature into three separate DataFrames. The setosa, virginica and versicolor DataFrames now contain only their respective instances.

We'll also want to color each of these instances differently, based on their Species, both in the Scatter Plot and in the Histograms. For that, we've extracted the Species column as a Series and made a colors dictionary, which we'll use to map() the Species of each flower to a color later on.
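Instead of writing one boolean filter per species, the same split can be done with DataFrame.groupby(), which also picks up any species you didn't list by hand. A minimal sketch on a small toy DataFrame: the two setosa rows come from the df.head() output above, while the remaining rows are illustrative placeholder values, not a claim about the real dataset:

```python
import pandas as pd

# Toy stand-in for the Iris DataFrame (only the setosa rows are from df.head())
df = pd.DataFrame({
    "SepalLengthCm": [5.1, 4.9, 7.0, 6.4, 6.3, 5.8],
    "SepalWidthCm": [3.5, 3.0, 3.2, 3.2, 3.3, 2.7],
    "Species": ["Iris-setosa", "Iris-setosa",
                "Iris-versicolor", "Iris-versicolor",
                "Iris-virginica", "Iris-virginica"],
})

# One sub-DataFrame per species, keyed by the species label
by_species = {name: group for name, group in df.groupby("Species")}

for name, group in by_species.items():
    print(name, len(group))
```

Here by_species['Iris-setosa'] plays the same role as the setosa DataFrame built with a boolean filter.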

Now, let's make our Figure, GridSpec and Axes instances:

fig = plt.figure()
gs = GridSpec(4, 4)

ax_scatter = fig.add_subplot(gs[1:4, 0:3])
ax_hist_x = fig.add_subplot(gs[0, 0:3])
ax_hist_y = fig.add_subplot(gs[1:4, 3])

Finally, we can plot the Scatter Plot and Histograms, setting their colors and orientations accordingly:

# Color each point by its species
ax_scatter.scatter(df['SepalLengthCm'], df['SepalWidthCm'], c=species.map(colors))

# Top histograms: sepal length, one per species
ax_hist_x.hist(versicolor['SepalLengthCm'], color='tab:red', alpha=0.4)
ax_hist_x.hist(virginica['SepalLengthCm'], color='tab:green', alpha=0.4)
ax_hist_x.hist(setosa['SepalLengthCm'], color='tab:blue', alpha=0.4)

# Right histograms: sepal width, one per species, drawn sideways
ax_hist_y.hist(versicolor['SepalWidthCm'], orientation='horizontal', color='tab:red', alpha=0.4)
ax_hist_y.hist(virginica['SepalWidthCm'], orientation='horizontal', color='tab:green', alpha=0.4)
ax_hist_y.hist(setosa['SepalWidthCm'], orientation='horizontal', color='tab:blue', alpha=0.4)

plt.show()

The map() call results in a Series of colors:

0       tab:blue
1       tab:blue
2       tab:blue
3       tab:blue
4       tab:blue
         ...
145    tab:green
146    tab:green
147    tab:green
148    tab:green
149    tab:green

When passed to the c argument of the scatter() function, these colors are applied to the points in the same order, effectively coloring each instance according to its species.
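One thing to watch for with map(): any label missing from the dictionary becomes NaN, which is not a valid color. A minimal sketch of the mapping step with a fallback color; the "Iris-unknown" label and the point coordinates are hypothetical, chosen only to show the behavior:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, so no window is needed
import matplotlib.pyplot as plt
import pandas as pd

colors = {
    "Iris-setosa": "tab:blue",
    "Iris-versicolor": "tab:red",
    "Iris-virginica": "tab:green",
}

# Short stand-in for df['Species']; "Iris-unknown" is a hypothetical label
species = pd.Series(["Iris-setosa", "Iris-virginica", "Iris-versicolor", "Iris-unknown"])

point_colors = species.map(colors)
print(point_colors.tolist())  # the unmapped label comes back as NaN

# NaN is not a valid color, so give unmapped labels a fallback before plotting
point_colors = point_colors.fillna("tab:gray")

fig, ax = plt.subplots()
# Arbitrary coordinates; c takes one color per point, in the same order
points = ax.scatter([5.1, 6.3, 7.0, 5.5], [3.5, 3.3, 3.2, 2.4], c=point_colors)
```

The print() call shows ['tab:blue', 'tab:green', 'tab:red', nan], and after fillna() every point has a usable color.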

For the Histograms, we've simply plotted three histograms on each Axes, one for each Species, with their respective colors. You can opt for a step Histogram here, and tweak the alpha value to create different-looking distributions.

Running this code results in:

[Figure: Matplotlib joint plot with multi-class histograms]

Now, each Species has its own color and distribution, plotted separately from the other flowers. The histograms also share their colors with the Scatter Plot, which makes the whole figure easy to read and interpret.
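Since the colors carry the species information, a legend helps readers decode them. Because the whole Scatter Plot comes from a single scatter() call, there are no per-species labels for legend() to pick up automatically, so a common workaround is to pass proxy artists. A minimal sketch, assuming the same colors dictionary as earlier; in the full script you would call legend() on the existing ax_scatter rather than creating a new Axes:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, so no window is needed
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

colors = {
    "Iris-setosa": "tab:blue",
    "Iris-versicolor": "tab:red",
    "Iris-virginica": "tab:green",
}

fig, ax_scatter = plt.subplots()

# One proxy patch per species, since a single scatter() call has no per-species labels
handles = [Patch(color=color, label=name) for name, color in colors.items()]
ax_scatter.legend(handles=handles, title="Species", loc="upper right")
```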

Note: If you find the blended colors where the histograms overlap (for example, where the red and blue histograms mix) distracting, passing histtype='step' to hist() draws only the outlines, without filled bars:

[Figure: Matplotlib joint plot with step histograms]
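A related detail when overlaying per-species histograms: each hist() call without a bins argument picks its own 10 bins over that species' own range, so the bars of different species don't line up with each other. A minimal sketch that computes one set of bin edges per feature with np.histogram_bin_edges() and reuses it for every species, drawing step histograms by default. The species_histograms() helper and the toy DataFrame values are hypothetical, not part of the tutorial's code:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def species_histograms(ax, df, column, colors, bins=10, histtype="step", **hist_kwargs):
    """Overlay one histogram per species on ax, all sharing the same bin edges.

    Hypothetical helper: df needs a 'Species' column plus the numeric column.
    """
    # Edges come from the whole column, so every species is binned identically
    edges = np.histogram_bin_edges(df[column], bins=bins)
    counts = {}
    for name, group in df.groupby("Species"):
        n, _, _ = ax.hist(group[column], bins=edges, color=colors[name],
                          histtype=histtype, **hist_kwargs)
        counts[name] = n
    return edges, counts


colors = {
    "Iris-setosa": "tab:blue",
    "Iris-versicolor": "tab:red",
    "Iris-virginica": "tab:green",
}

# Toy stand-in for the Iris DataFrame (illustrative values only)
df = pd.DataFrame({
    "SepalLengthCm": [5.1, 4.9, 7.0, 6.4, 6.3, 5.8],
    "SepalWidthCm": [3.5, 3.0, 3.2, 3.2, 3.3, 2.7],
    "Species": ["Iris-setosa", "Iris-setosa", "Iris-versicolor",
                "Iris-versicolor", "Iris-virginica", "Iris-virginica"],
})

fig, (ax_top, ax_right) = plt.subplots(1, 2)
edges, counts = species_histograms(ax_top, df, "SepalLengthCm", colors)
# Extra keyword arguments go straight to hist(), e.g. for the right-hand Axes
species_histograms(ax_right, df, "SepalWidthCm", colors, orientation="horizontal")
```

In the joint plot, ax_top and ax_right would be the top and right histogram Axes created from the GridSpec.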

Conclusion

In this guide, we've taken a look at how to plot a Joint Plot in Matplotlib: a Scatter Plot with accompanying Distribution Plots (Histograms) along both of its axes, used to explore the distribution of each variable that makes up the Scatter Plot.

Although this task is better suited to libraries like Seaborn, which have built-in support for Joint Plots, Matplotlib is the underlying engine Seaborn uses to draw these plots.
