Plotly Scatter Plot - Tutorial with Examples

Introduction

Plotly is a Python data visualization library built on top of the plotly.js JavaScript library, focused on interactive and web-based visualizations. It offers the simplicity of Seaborn, with a high-level API, along with the interactivity of Bokeh.

In addition to the core library's functionality, combining the built-in Plotly Express module with Dash, a framework built on top of Flask, makes Plotly an excellent choice for web-based applications and interactive, data-driven dashboards.

In this guide, we'll take a look at how to plot a Scatter Plot with Plotly.

Scatter Plots explore the relationship between two numerical variables (features) of a dataset.

Import Data

We'll be working with the Heart Attack Dataset from Kaggle, which contains data on various bodily metrics that we could use as indicators of heart attack risk.

Let's import the dataset and print the head() to take a peek:

import pandas as pd

df = pd.read_csv('heart.csv')
print(df.head())

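This results in the following (abridged; the dataset has more columns than shown here, such as sex, which we'll use later):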

   age cp  trtbps  chol  fbs  restecg  thalachh  exng  oldpeak  slp  caa output
0   63   3     145   233    1        0       150     0      2.3    0    0      1
1   37   2     130   250    0        1       187     0      3.5    0    0      1
2   41   1     130   204    0        0       172     0      1.4    2    0      1
3   56   1     120   236    0        1       178     0      0.8    2    0      1
4   57   0     120   354    0        1       163     1      0.6    2    0      1
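If you'd like to try the snippets without downloading heart.csv, one option is to rebuild a small stand-in DataFrame from the five rows printed above. This is a minimal sketch under that assumption: five rows are far too few to draw conclusions from, and columns missing from the printout (such as sex) are simply left out.

```python
import pandas as pd

# Stand-in for heart.csv, built only from the five rows printed above.
# Columns missing from that printout (e.g. 'sex') are not included.
df = pd.DataFrame({
    'age':      [63, 37, 41, 56, 57],
    'cp':       [3, 2, 1, 1, 0],
    'trtbps':   [145, 130, 130, 120, 120],
    'chol':     [233, 250, 204, 236, 354],
    'fbs':      [1, 0, 0, 0, 0],
    'restecg':  [0, 1, 0, 1, 1],
    'thalachh': [150, 187, 172, 178, 163],
    'exng':     [0, 0, 0, 0, 1],
    'oldpeak':  [2.3, 3.5, 1.4, 0.8, 0.6],
    'slp':      [0, 0, 2, 2, 2],
    'caa':      [0, 0, 0, 0, 0],
    'output':   [1, 1, 1, 1, 1],
})

print(df.shape)  # (5, 12)
print(df.columns.tolist())
```

Any later snippet that references sex still needs the full CSV.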

Let's explore the relationships between features such as the thalachh (maximum recorded heart rate), trtbps (resting blood pressure), chol (amount of cholesterol) and output (0 or 1, representing lower or higher chances of experiencing a heart attack respectively).

First, let's go ahead and save our features separately for brevity's sake:

max_heartrate = df['thalachh']
resting_blood_pressure = df['trtbps']
cholesterol_level = df['chol']
output = df['output']

Plot a Scatter Plot with Plotly

Now we can plot a Scatter Plot. Let's first explore the relationship between max_heartrate and cholesterol_level. To plot a Scatter Plot with Plotly, we'll use the scatter() function of the Plotly Express module, conventionally imported as px:

import plotly.express as px

fig = px.scatter(x=cholesterol_level, y=max_heartrate)
fig.show()

The only required arguments are the x and y features. This plots a Scatter Plot (without descriptive axis labels) and opens it in your web browser, or renders it inline if you're working in a Jupyter notebook:

Alternatively, if you don't want to define your variables beforehand, Plotly offers a syntax very similar to Seaborn's: you specify the data source and the names of the features you'd like to visualize. Plotly then uses the column names as axis labels and plots the features directly, without the separate variables we defined earlier:

import pandas as pd
import plotly.express as px

df = pd.read_csv('heart.csv')

fig = px.scatter(df, x='chol', y='thalachh')
fig.show()

This results in:

Note: You can also mix these approaches, supplying your DataFrame as the source while passing the predefined variables instead of the feature column names in the scatter() call:

fig = px.scatter(df, x=cholesterol_level, y=max_heartrate)
fig.show()

This results in a labeled Scatter Plot as well:

There doesn't seem to be much of a correlation between the cholesterol level and maximum heart rate of individuals in this dataset.
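To back up that visual impression with a number, you could compute Pearson's correlation coefficient with pandas. A minimal sketch: the five stand-in rows below (from the printout earlier) are far too few for the value to mean anything, so on real data run the same line on the full DataFrame loaded from heart.csv.

```python
import pandas as pd

# Stand-in values (first five rows); replace with the full df from heart.csv
df = pd.DataFrame({
    'chol': [233, 250, 204, 236, 354],
    'thalachh': [150, 187, 172, 178, 163],
})

# Pearson's r lies between -1 and 1; values near 0 suggest little linear relationship
r = df['chol'].corr(df['thalachh'])
print(f"Pearson r between chol and thalachh: {r:.3f}")
```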

Customizing a Plotly Scatter Plot

Plain plots rarely tell the whole story, though. The point is to make certain characteristics of the data intuitive to see.

In our case, this might include coloring the markers depending on the output feature, or adding hover_data, which specifies what's shown in the tooltip when someone hovers over a marker.


Currently, the hover_data isn't very helpful, only showing us the x and y values, which can already be reasonably inferred from observing the resulting plot.

Let's go ahead and change a few of the parameters to make this plot a bit more intuitive:

import pandas as pd
import plotly.express as px

df = pd.read_csv('heart.csv')

fig = px.scatter(df, x='chol', y='thalachh', color='output', hover_data=['sex', 'age'])
fig.show()

We've mapped the color of each marker to the output feature, so individuals with higher and lower chances of experiencing a heart attack are drawn in different colors. We've also added each individual's sex and age to the hover tooltip of their marker.

This results in:

Finally, you can also change the size of the markers, either by passing a scalar value (such as 5) to the fig.update_traces() method, or by passing a vector value (such as a feature to map the size to) to the size argument.

Let's map the oldpeak feature to the size of each marker:

import pandas as pd
import plotly.express as px

df = pd.read_csv('heart.csv')

fig = px.scatter(df, x='chol',
                 y='thalachh',
                 color='output',
                 size='oldpeak',
                 hover_data=['sex', 'age'])
fig.show()

Now, each marker will have a variable size, depending on the values of the oldpeak feature:

Or, if you want all markers to have the same fixed size, you can update the Figure's traces:

import pandas as pd
import plotly.express as px

df = pd.read_csv('heart.csv')

fig = px.scatter(df, x='chol',
                 y='thalachh',
                 color='output',
                 hover_data=['sex', 'age'])

fig.update_traces(marker={'size': 10})
fig.show()

This results in:

Conclusion

In this guide, we've taken a look at how to plot a Scatter Plot using Python and Plotly.



Author: David Landup

Entrepreneur, Software and Machine Learning Engineer, with a deep fascination towards the application of Computation and Deep Learning in Life Sciences (Bioinformatics, Drug Discovery, Genomics), Neuroscience (Computational Neuroscience), robotics and BCIs.

Great passion for accessible education and promotion of reason, science, humanism, and progress.

© 2013-2024 Stack Abuse. All rights reserved.
