Introduction
There are many data visualization libraries in Python, but Matplotlib is the most popular of them all. Matplotlib's popularity is due to its reliability and utility - it can create both simple and complex plots with little code, and those plots can be customized in a variety of ways.
In this tutorial, we'll cover how to plot a Joint Plot in Matplotlib, which consists of a Scatter Plot and multiple Distribution Plots on the same Figure.
Joint Plots are used to explore the relationship between two variables, as well as the distribution of each variable, at the same time.
Note: This sort of task is much better suited to libraries such as Seaborn, which has a built-in jointplot() function. With Matplotlib, we'll construct a Joint Plot manually, using GridSpec and multiple Axes objects, instead of having Seaborn do it for us.
Importing Data
We'll use the famous Iris Dataset, since it lets us explore the relationship between features such as SepalWidthCm and SepalLengthCm through a Scatter Plot, while also exploring, through Distribution Plots, how those sepal lengths and widths are distributed across each Species.
Let's import the dataset and take a peek:
import pandas as pd
df = pd.read_csv('iris.csv')
print(df.head())
This results in:
   Id  SepalLengthCm  SepalWidthCm  PetalLengthCm  PetalWidthCm      Species
0   1            5.1           3.5            1.4           0.2  Iris-setosa
1   2            4.9           3.0            1.4           0.2  Iris-setosa
2   3            4.7           3.2            1.3           0.2  Iris-setosa
3   4            4.6           3.1            1.5           0.2  Iris-setosa
4   5            5.0           3.6            1.4           0.2  Iris-setosa
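If iris.csv isn't at hand, a minimal sketch like the one below rebuilds the five rows printed above as a DataFrame, which is enough to try out the plotting code that follows. It is only a sample, not the full 150-row dataset, and the names sample_df and required are hypothetical; the column check works the same way on the real df.

```python
import pandas as pd

# The five rows shown by df.head() above, typed in by hand (a sample, not the full dataset)
sample_df = pd.DataFrame({
    "Id": [1, 2, 3, 4, 5],
    "SepalLengthCm": [5.1, 4.9, 4.7, 4.6, 5.0],
    "SepalWidthCm": [3.5, 3.0, 3.2, 3.1, 3.6],
    "PetalLengthCm": [1.4, 1.4, 1.3, 1.5, 1.4],
    "PetalWidthCm": [0.2, 0.2, 0.2, 0.2, 0.2],
    "Species": ["Iris-setosa"] * 5,
})

# Sanity-check that the columns used later in this tutorial are present
required = {"SepalLengthCm", "SepalWidthCm", "Species"}
missing = required - set(sample_df.columns)
print(sample_df.shape, "missing:", missing or "none")
```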
We'll be exploring the bivariate relationship between the SepalLengthCm and SepalWidthCm features here, but also their distributions. We can approach this in two ways - with respect to their Species or not.
We can totally disregard the Species feature and simply plot one histogram per feature, covering every flower instance. Alternatively, we can color-code the instances and plot a separate distribution for each Species, highlighting the differences between them.
We'll explore both options here, starting with the simpler one - disregarding the Species altogether.
Plot a Joint Plot in Matplotlib with Single-Class Histograms
In the first approach, we'll just load in the flower instances and plot them as-is, with no regard to their Species.
We'll be using a GridSpec to customize our figure's layout, to make space for three different plots and Axes instances.
To invoke the GridSpec constructor, we'll want to import it alongside the pyplot module:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
Now, let's create our Figure and the Axes objects:
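df = pd.read_csv('iris.csv')
fig = plt.figure()
# A 4x4 grid: the scatter plot takes the bottom-left 3x3 block,
# leaving one row above it and one column to its right for the histograms
gs = GridSpec(4, 4)
ax_scatter = fig.add_subplot(gs[1:4, 0:3])
ax_hist_x = fig.add_subplot(gs[0, 0:3])   # top: distribution of the x-axis feature
ax_hist_y = fig.add_subplot(gs[1:4, 3])   # right: distribution of the y-axis feature
plt.show()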
We've created 3 Axes instances by adding subplots to the figure, using our GridSpec instance to position them. This results in a Figure with 3 empty Axes instances:
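To see exactly which grid cells each slice claims, a small sketch like the one below prints a character map of the 4x4 GridSpec, marking the scatter plot as S, the top histogram as X and the right histogram as Y. It assumes Matplotlib 3.2 or newer, where SubplotSpec exposes rowspan and colspan; the layout_map() helper is hypothetical.

```python
from matplotlib.gridspec import GridSpec

gs = GridSpec(4, 4)
# Same slices as the layout above: scatter, top histogram, right histogram
cells = {
    "S": gs[1:4, 0:3],
    "X": gs[0, 0:3],
    "Y": gs[1:4, 3],
}

def layout_map(cells, nrows=4, ncols=4):
    """Return one string per grid row, with '.' for unused cells."""
    grid = [["." for _ in range(ncols)] for _ in range(nrows)]
    for label, spec in cells.items():
        # rowspan/colspan are range objects covering the cells this slice spans
        for r in spec.rowspan:
            for c in spec.colspan:
                grid[r][c] = label
    return ["".join(row) for row in grid]

for row in layout_map(cells):
    print(row)
```

The single unused cell in the top-right corner is what keeps the two histograms from overlapping.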
Now that we've got the layout and positioning in place, all we have to do is plot the data on our Axes. Let's update the script so that we plot the SepalLengthCm and SepalWidthCm features through a Scatter Plot on our ax_scatter axes, and each of these features on the ax_hist_x and ax_hist_y axes, respectively:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
df = pd.read_csv('iris.csv')
fig = plt.figure()
gs = GridSpec(4, 4)
ax_scatter = fig.add_subplot(gs[1:4, 0:3])
ax_hist_x = fig.add_subplot(gs[0,0:3])
ax_hist_y = fig.add_subplot(gs[1:4, 3])
ax_scatter.scatter(df['SepalLengthCm'], df['SepalWidthCm'])
ax_hist_x.hist(df['SepalLengthCm'])
ax_hist_y.hist(df['SepalWidthCm'], orientation = 'horizontal')
plt.show()
We've passed orientation='horizontal' to the hist() call on ax_hist_y so that its histogram is drawn sideways, filling the vertical strip to the right of the Scatter Plot that we reserved for it with the GridSpec:
This results in a Joint Plot of the relationship between the SepalLengthCm and SepalWidthCm features, as well as the distributions for the respective features.
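One refinement worth considering: in the script above, each histogram autoscales independently of the Scatter Plot, so the ranges of the top and right Axes may not line up exactly with the scatter's. A minimal sketch, assuming you want them aligned, passes sharex/sharey to add_subplot() and hides the duplicated tick labels. The joint_plot() helper is hypothetical, and the Agg backend line only makes the sketch run headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to display with plt.show()
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

def joint_plot(x, y, bins=10):
    """Scatter plot of x vs. y with marginal histograms on shared axes."""
    fig = plt.figure(figsize=(6, 6))
    gs = GridSpec(4, 4, figure=fig)
    ax_scatter = fig.add_subplot(gs[1:4, 0:3])
    # sharex/sharey tie the histograms' data ranges to the scatter plot's axes
    ax_hist_x = fig.add_subplot(gs[0, 0:3], sharex=ax_scatter)
    ax_hist_y = fig.add_subplot(gs[1:4, 3], sharey=ax_scatter)

    ax_scatter.scatter(x, y)
    ax_hist_x.hist(x, bins=bins)
    ax_hist_y.hist(y, bins=bins, orientation="horizontal")

    # Shared tick labels would just repeat the scatter plot's, so hide them
    ax_hist_x.tick_params(axis="x", labelbottom=False)
    ax_hist_y.tick_params(axis="y", labelleft=False)
    return fig, ax_scatter, ax_hist_x, ax_hist_y
```

With the real data, this would be called as joint_plot(df['SepalLengthCm'], df['SepalWidthCm']).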
Plot a Joint Plot in Matplotlib with Multiple-Class Histograms
Now, another case we might want to explore is the distribution of these features with respect to the Species of the flower, since the species could well affect the range of sepal lengths and widths.
For this, we won't use a single histogram per axis containing all flower instances; rather, we'll overlay one histogram per Species on both axes.
To do this, we'll first have to split the DataFrame we've been using by flower Species:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
df = pd.read_csv('iris.csv')
setosa = df[df['Species']=='Iris-setosa']
virginica = df[df['Species']=='Iris-virginica']
versicolor = df[df['Species']=='Iris-versicolor']
species = df['Species']
colors = {
'Iris-setosa' : 'tab:blue',
'Iris-versicolor' : 'tab:red',
'Iris-virginica' : 'tab:green'
}
Here, we've just filtered the DataFrame by the Species feature into three separate datasets. The setosa, virginica and versicolor datasets now contain only their respective instances.
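If you'd rather not hardcode each species label, a groupby() sketch like the following produces the same per-species split in one pass. The handful of rows here are stand-in values in the dataset's format, used only for illustration; with the real data you would keep df = pd.read_csv('iris.csv'), and the name by_species is hypothetical.

```python
import pandas as pd

# Stand-in rows in the same format as iris.csv, for illustration only
df = pd.DataFrame({
    "SepalLengthCm": [5.1, 7.0, 6.3, 4.9],
    "SepalWidthCm": [3.5, 3.2, 3.3, 3.0],
    "Species": ["Iris-setosa", "Iris-versicolor", "Iris-virginica", "Iris-setosa"],
})

# groupby() yields (label, sub-DataFrame) pairs, one per distinct Species value
by_species = {name: group for name, group in df.groupby("Species")}

setosa = by_species["Iris-setosa"]
virginica = by_species["Iris-virginica"]
versicolor = by_species["Iris-versicolor"]
print({name: len(group) for name, group in by_species.items()})
```

The dictionary also makes it easy to loop over every species later, instead of writing one line per species.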
We'll also want to color each of these instances differently based on its Species, both in the Scatter Plot and in the Histograms. For that, we've extracted the Species column as a Series and made a colors dictionary, which we'll use to map() the Species of each flower to a color later on.
Now, let's make our Figure, GridSpec and Axes instances:
fig = plt.figure()
gs = GridSpec(4, 4)
ax_scatter = fig.add_subplot(gs[1:4, 0:3])
ax_hist_x = fig.add_subplot(gs[0, 0:3])   # top: SepalLengthCm distribution
ax_hist_y = fig.add_subplot(gs[1:4, 3])   # right: SepalWidthCm distribution
Finally, we can plot the Scatter Plot and Histograms, setting their colors and orientations accordingly:
# Color each point by mapping its species label to a color
ax_scatter.scatter(df['SepalLengthCm'], df['SepalWidthCm'], c=species.map(colors))
# Overlay one semi-transparent histogram per species on each marginal axis
ax_hist_x.hist(versicolor['SepalLengthCm'], color='tab:red', alpha=0.4)
ax_hist_x.hist(virginica['SepalLengthCm'], color='tab:green', alpha=0.4)
ax_hist_x.hist(setosa['SepalLengthCm'], color='tab:blue', alpha=0.4)
ax_hist_y.hist(versicolor['SepalWidthCm'], orientation='horizontal', color='tab:red', alpha=0.4)
ax_hist_y.hist(virginica['SepalWidthCm'], orientation='horizontal', color='tab:green', alpha=0.4)
ax_hist_y.hist(setosa['SepalWidthCm'], orientation='horizontal', color='tab:blue', alpha=0.4)
plt.show()
The map() call results in a Series of colors:
0 tab:blue
1 tab:blue
2 tab:blue
3 tab:blue
4 tab:blue
...
145 tab:green
146 tab:green
147 tab:green
148 tab:green
149 tab:green
When provided to the c argument of the scatter() function, these colors are applied to the points in order, so each instance is colored according to its species.
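One caveat worth checking: Series.map() with a dictionary returns NaN for any label that isn't a key, and NaN is not a valid color for scatter(). The sketch below uses a deliberately misspelled, hypothetical label 'Iris-versicolour' to show how such labels can be spotted before plotting; the names point_colors and unmapped are also hypothetical.

```python
import pandas as pd

colors = {
    'Iris-setosa': 'tab:blue',
    'Iris-versicolor': 'tab:red',
    'Iris-virginica': 'tab:green',
}

# The last label is intentionally misspelled, so it has no entry in colors
species = pd.Series(['Iris-setosa', 'Iris-virginica', 'Iris-versicolour'])
point_colors = species.map(colors)

# Labels that map() could not translate come back as NaN
unmapped = species[point_colors.isna()].unique()
print(point_colors.tolist())
print("unmapped labels:", list(unmapped))
```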
For the Histograms, we've simply plotted three histograms on each axis, one for each Species, with their respective colors. You can opt for a step Histogram here, and tweak the alpha value to create different-looking distributions.
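By default, each hist() call picks its own 10 bins from that species' own range, so the bars of different species don't line up with each other. A minimal sketch, assuming you'd rather bin every species identically, computes one set of edges over the whole column with numpy.histogram_bin_edges() and loops over the colors dictionary. The overlay_histograms() helper is hypothetical, and histtype is exposed so you can switch to 'step'.

```python
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt

def overlay_histograms(ax, df, column, colors, bins=10,
                       orientation="vertical", histtype="bar"):
    """Overlay one histogram per Species on ax, all using the same bin edges."""
    # Edges are computed once from the full column and shared by every species
    edges = np.histogram_bin_edges(df[column], bins=bins)
    for name, color in colors.items():
        values = df.loc[df["Species"] == name, column]
        ax.hist(values, bins=edges, color=color, alpha=0.4,
                orientation=orientation, histtype=histtype)
    return edges
```

With the figure from above, overlay_histograms(ax_hist_x, df, 'SepalLengthCm', colors) would fill the top histogram, and passing orientation='horizontal' with 'SepalWidthCm' would fill the right one.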
Running this code results in:
Now, each Species has its own color and distribution, plotted separately from the other flowers. Furthermore, the colors match those in the Scatter Plot, which makes the figure intuitive to read and interpret.
Note: If you find the blended colors where the Histograms overlap (for example, where the red and blue ones meet) distracting, setting histtype to step will remove the filled colors:
Conclusion
In this guide, we've taken a look at how to plot a Joint Plot in Matplotlib - a Scatter Plot with accompanying Distribution Plots (Histograms) on both axes of the plot, to explore the distribution of the variables that constitute the Scatter Plot itself.
Although this task is better suited to libraries like Seaborn, which have built-in support for Joint Plots, Matplotlib is the underlying engine that enables Seaborn to make these plots effortlessly.