Learning Outcomes
By the end of this section, you should be able to:
- 9.5.1 Produce labeled scatterplots and scatterplots with variable density points, different colors, etc. to indicate additional information.
- 9.5.2 Create and interpret correlation heatmaps from multidimensional data.
- 9.5.3 Create and interpret graphs of three-dimensional data using a variety of methods.
A data scientist or researcher is often interested in generating more advanced visual representations such as scatterplots, correlation maps, and three-dimensional (3D) representations of data. Visualizations or graphs produced in a 3D format are especially important to convey a realistic view of the world around us. For example, on a manufacturing line, a 3D visualization of a mechanical component can be much easier to interpret than a two-dimensional blueprint drawing. 3D graphs can offer insights and perspectives that might be hidden or unavailable in a two-dimensional view.
3D visualizations are especially useful when mapping large datasets such as population density maps or global supply chain routes. In these situations, Python tools such as geopandas or the 3D plotting routines available as part of matplotlib can be used to visualize relationships, patterns, and connections in large datasets.
As we saw in Decision-Making Using Machine Learning Basics, data collection and analysis of very large datasets, or big data, has become more common as businesses, government agencies, and other organizations collect huge volumes of data on a real-time basis. Some examples of big data include social media data, which generates very large datasets every second, including images, videos, posts, likes, shares, comments, etc. Data visualization for social media data might include dashboards to visualize trends and engagement on hot topics of the day (see more on these in Reporting Results), or geospatial maps to show the geographic distribution of users and associated demographics.
In this section, we explore more advanced visualization techniques such as scatterplots with variable density points, correlation heatmaps, and three-dimensional visualization.
Scatterplots, Revisited
A scatterplot is a graphing method for bivariate data, which is paired data in which each value of one variable is paired with a value of a second variable. A scatterplot plots two numeric quantities against each other, and the goal is to determine whether there is a correlation or dependency between them. Scatterplots were discussed extensively in Inferential Statistics and Regression Analysis as a key element in correlation and regression analysis. Recall that in a scatterplot, the independent, or explanatory, quantity is labeled as the x-variable, and the dependent, or response, quantity is labeled as the y-variable.
When generating scatterplots in Inferential Statistics and Regression Analysis, we used the plt.scatter() function, which is part of the Python matplotlib library. In Correlation and Linear Regression Analysis, an example scatterplot was generated to plot revenue vs. advertising spend for a company.
The Python code for that example is reproduced as follows:
Python Code
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
# Define the x-data, which is the amount spent on advertising
x = [49, 145, 57, 153, 92, 83, 117, 142, 69, 106, 109, 121]
# Define the y-data, which is revenue
y = [12210, 17590, 13215, 19200, 14600, 14100, 17100, 18400, 14100, 15500, 16300, 17020]
# Use the scatter function to generate a scatterplot
plt.scatter(x, y)
# Add a title using the title function
plt.title("Revenue versus Advertising for a Company")
# Add labels to the x and y axes by using xlabel and ylabel functions
plt.xlabel("Advertising $000")
plt.ylabel("Revenue $000")
# Define a function to format the ticks with commas as thousands separators
def format_ticks(value, tick_number):
    return f'{value:,.0f}'
# Apply the custom formatter to the y-axis
plt.gca().yaxis.set_major_formatter(ticker.FuncFormatter(format_ticks))
# Show the plot
plt.show()
The resulting output will look like this:
We can augment this scatterplot by adding color to indicate the revenue level (y-value).
To do this, use the c parameter in the plt.scatter() function to specify that the color of each point is based on its y-value. The cmap parameter specifies the color palette (colormap) to use, such as “hot” or “coolwarm.”
In addition, a color bar can be added to the scatterplot so the user can interpret the color scale.
Here is the revised Python code, which adds the color enhancement for the data points:
Python Code
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
# Define the x-data, which is the amount spent on advertising
x = [49, 145, 57, 153, 92, 83, 117, 142, 69, 106, 109, 121]
# Define the y-data, which is revenue
y = [12210, 17590, 13215, 19200, 14600, 14100, 17100, 18400, 14100, 15500, 16300, 17020]
# Use the scatter function to generate a scatterplot
# Specify that the color of the data point will be based on revenue (y)
# Specify the 'coolwarm' color palette
plt.scatter(x, y, c=y, cmap='coolwarm')
# Plot a color bar so the user can interpret the color scale
cbar = plt.colorbar()
# Define a function to format the color bar ticks with commas as thousands separators
def format_colorbar_ticks(value, tick_number):
    return f'{value:,.0f}'
# Apply the custom formatter to the color bar ticks
cbar.ax.yaxis.set_major_formatter(ticker.FuncFormatter(format_colorbar_ticks))
# Add a title using the title function
plt.title("Revenue versus Advertising for a Company")
# Add labels to the x and y axes using xlabel and ylabel functions
plt.xlabel("Advertising $000")
plt.ylabel("Revenue $000")
# Define a function to format the y-axis ticks with commas as thousands separators
def format_ticks(value, tick_number):
    return f'{value:,.0f}'
# Apply the custom formatter to the y-axis
plt.gca().yaxis.set_major_formatter(ticker.FuncFormatter(format_ticks))
# Show the plot
plt.show()
The resulting output will look like this:
A typical application for this capability is to set the color of each data point on a scatterplot based on the value of a third variable. For example, a real estate professional may be interested in the relationships among home price, square footage of the home, and location in terms of distance in miles from the center of a city. The scatterplot can be created so that x is the square footage of the home, y is the price of the home, and z, the distance from the center of the city, is represented by the color of the point.
Python Code
# import the matplotlib plotting and tick-formatting modules
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
# define the x-data, which is square footage of a home
x = [2100, 2378, 1983, 1422, 2764, 1901, 1198, 1785, 1556, 2931, 3071, 1688]
# define the y-data, which is the corresponding home price in $000
y = [390, 427, 350, 285, 479, 299, 250, 310, 290, 495, 515, 284]
# define the z-data, which is the corresponding distance from city center
z = [4.8, 0.5, 0.9, 1.5, 0.8, 4.2, 3.1, 0.1, 2.2, 1.5, 4.1, 0.7]
# use the scatter function to generate a scatterplot
# specify that the color of the data point will be based on distance from the city center (z)
# specify the hot color palette
plt.scatter(x, y, c=z, cmap='hot')
# also plot a color bar so user can interpret the color scale
plt.colorbar()
# Add a title using the title function
plt.title("Home Price versus Square Footage (Color = Miles from Center City)")
# Add labels to the x and y axes by using xlabel and ylabel functions
plt.xlabel("Square Footage")
plt.ylabel("Home Price $000")
# Define a function to format the ticks with commas as thousands separators
def format_ticks(value, tick_number):
    return f'{value:,.0f}'
# Apply the custom formatter to the x-axis
plt.gca().xaxis.set_major_formatter(ticker.FuncFormatter(format_ticks))
# Show the plot
plt.show()
The resulting output will look like this:
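Learning outcome 9.5.1 also mentions labeled scatterplots and points that vary to carry extra information. One common way to vary the points is to scale the marker size with the s parameter of scatter(), and points can be labeled with annotate(). The sketch below reuses the home data above; the size factor of 40 and the listing IDs (H1, H2, ...) are hypothetical choices made only for illustration.

```python
import matplotlib.pyplot as plt

# Home data from the example above
x = [2100, 2378, 1983, 1422, 2764, 1901, 1198, 1785, 1556, 2931, 3071, 1688]  # square footage
y = [390, 427, 350, 285, 479, 299, 250, 310, 290, 495, 515, 284]  # home price in $000
z = [4.8, 0.5, 0.9, 1.5, 0.8, 4.2, 3.1, 0.1, 2.2, 1.5, 4.1, 0.7]  # miles from center city

# Marker area grows with distance; the factor 40 is an arbitrary illustrative choice
sizes = [40 * d for d in z]

fig, ax = plt.subplots(figsize=(8, 6))
# Size and color both encode distance; alpha keeps overlapping points readable
points = ax.scatter(x, y, s=sizes, c=z, cmap='hot', alpha=0.7, edgecolors='black')
fig.colorbar(points, ax=ax, label='Miles from Center City')

# Label each point with a hypothetical listing ID, offset slightly from the marker
for i, (xi, yi) in enumerate(zip(x, y), start=1):
    ax.annotate(f'H{i}', (xi, yi), textcoords='offset points', xytext=(5, 5), fontsize=8)

ax.set_title('Home Price versus Square Footage (Size and Color = Miles from Center)')
ax.set_xlabel('Square Footage')
ax.set_ylabel('Home Price $000')
plt.show()
```

Using size and color for the same variable is redundant on purpose here; in practice, size could encode a fourth variable instead.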
Correlation Heatmaps
Recall from Correlation and Linear Regression Analysis that correlation analysis is used to determine if one quantity is related to or correlates with another quantity. For example, you know that the value of a car is correlated with the age of the car where, in general, the older the car, the less its value. A numeric correlation coefficient (r) was calculated that provided information on both the strength and direction of the correlation.
The value of r gives us this information:
- The value of r is always between -1 and +1. The size of the correlation r indicates the strength of the linear relationship between the two variables. Values of r close to -1 or to +1 indicate a stronger linear relationship. Values of r close to zero indicate a weak linear relationship.
The sign of r gives us this information:
- A positive value of r means that when x increases, y tends to increase, and when x decreases, y tends to decrease (positive correlation).
- A negative value of r means that when x increases, y tends to decrease, and when x decreases, y tends to increase (negative correlation).
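To make the interpretation of r concrete, the short sketch below computes r for the advertising and revenue data from the first scatterplot in this section using NumPy's corrcoef() function. The earlier chapter may have computed r with a different tool, but this is the same Pearson correlation coefficient. The result is close to +1 (about 0.98), consistent with the strong positive trend visible in that scatterplot.

```python
import numpy as np

# Advertising spend (x) and revenue (y) from the first scatterplot example
x = [49, 145, 57, 153, 92, 83, 117, 142, 69, 106, 109, 121]
y = [12210, 17590, 13215, 19200, 14600, 14100, 17100, 18400, 14100, 15500, 16300, 17020]

# np.corrcoef returns a 2x2 correlation matrix; the off-diagonal entry is r
r = np.corrcoef(x, y)[0, 1]
print(f"r = {r:.3f}")
```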
Data scientists often collect data on more than two variables and are then interested in the correlations between the various pairs of variables. Of particular interest are the pairs of variables with stronger correlations. A convenient way to visualize this data is with a correlation heatmap, which assigns a color from a color palette to the correlation coefficient for each pair of variables.
Python's pandas library provides a mechanism to calculate a correlation matrix, which shows the value of r for each pair of variables in a dataset, using the DataFrame corr() method. Researchers typically use this correlation matrix to identify the variables having stronger correlations. A correlation heatmap is a visual representation of the correlation matrix that uses color coding to distinguish pairs of variables with stronger correlations from those with weaker correlations.
Typically, darker colors are used to indicate stronger correlations and lighter colors are used to indicate weaker correlations. With a single-hue palette such as Reds, however, darker colors correspond to larger values of r, so a strong negative correlation appears light; a diverging palette such as coolwarm, centered at zero, shows both directions.
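As a small sketch of the corr() method, the code below places the three home variables from the earlier color-coded scatterplot into a pandas DataFrame (the column names are our own choice) and prints the correlation matrix. Each entry is the r value for one pair of variables; the matrix is symmetric, and its diagonal entries are 1.

```python
import pandas as pd

# Home data from the color-coded scatterplot example; column names are illustrative
homes = pd.DataFrame({
    'sq_footage': [2100, 2378, 1983, 1422, 2764, 1901, 1198, 1785, 1556, 2931, 3071, 1688],
    'price_000': [390, 427, 350, 285, 479, 299, 250, 310, 290, 495, 515, 284],
    'miles_from_center': [4.8, 0.5, 0.9, 1.5, 0.8, 4.2, 3.1, 0.1, 2.2, 1.5, 4.1, 0.7],
})

# Pearson correlation (the default method) for every pair of columns
correlation_matrix = homes.corr()
print(correlation_matrix.round(2))
```

Here square footage and price are strongly positively correlated, while distance from the city center is only weakly correlated with either.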
As an example of a correlation heatmap, we can use the car_crashes dataset available in the Seaborn library; the fields in the dataset are shown in Table 9.5.
Field Name | Field Description |
---|---|
abbrev | Abbreviation of the state name |
total | Number of drivers involved in fatal collisions per billion miles |
speeding | Number of drivers involved in fatal collisions who were speeding, per billion miles |
alcohol | Number of drivers involved in fatal collisions who were alcohol-impaired, per billion miles |
not_distracted | Number of drivers involved in fatal collisions who were not distracted, per billion miles |
no_previous | Number of drivers involved in fatal collisions who had no previous accidents, per billion miles |
ins_premium | Insurance premium per insured vehicle |
ins_losses | Insurance losses incurred per insured vehicle |
There are 51 records in the dataset (one record per state plus the District of Columbia). The first five lines of the dataset appear in Table 9.6.
total | speeding | alcohol | not_distracted | no_previous | ins_premium | ins_losses | abbrev |
---|---|---|---|---|---|---|---|
18.8 | 7.332 | 5.64 | 18.048 | 15.04 | 784.55 | 145.08 | AL |
18.1 | 7.421 | 4.525 | 16.29 | 17.014 | 1053.48 | 133.93 | AK |
18.6 | 6.51 | 5.208 | 15.624 | 17.856 | 899.47 | 110.35 | AZ |
22.4 | 4.032 | 5.824 | 21.056 | 21.28 | 827.34 | 142.39 | AR |
12 | 4.2 | 3.36 | 10.92 | 10.68 | 878.41 | 165.63 | CA |
The short Python program that follows can be used to generate a correlation heatmap for the variables in the dataset; it makes use of the heatmap function available in the Seaborn library (see also Decision-Making Using Machine Learning Basics).
Python Code
import seaborn as sns
import matplotlib.pyplot as plt
# Load the car_crashes dataset from seaborn library
dataset = sns.load_dataset("car_crashes")
# Select only numeric columns
numeric_columns = dataset.select_dtypes(include=['number'])
# Calculate the correlation matrix
correlation_matrix = numeric_columns.corr()
# Generate the correlation heatmap using the Reds color palette
sns.heatmap(correlation_matrix, cmap='Reds', annot=True, fmt=".2f")
# Add a Title to the plot
plt.title('Correlation Heatmap for Variables in Car Crashes Dataset')
# Show the plot
plt.show()
The resulting output will look like this:
Notice in the correlation heatmap that dark-red coloring is visible on the diagonal from the upper left to the lower right corner of the matrix. This dark shading represents a correlation coefficient of 1.0, since every variable is perfectly correlated with itself.
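Because the diagonal entries are always 1, they carry no information about relationships between different variables. A minimal sketch of how one might list the remaining pairs from strongest to weakest is shown below. It reuses the small home dataset from earlier in this section (rather than the car crash data, so it runs without downloading anything) and sorts pairs by the absolute value of r, so that a strong negative correlation ranks as high as a strong positive one. The helper name strongest_pairs is our own, not a pandas or Seaborn function.

```python
from itertools import combinations

import pandas as pd


def strongest_pairs(df):
    """Return (var1, var2, r) tuples for each distinct pair, sorted by |r| descending."""
    corr = df.corr()
    # combinations() visits each off-diagonal pair once and skips the diagonal
    pairs = [(a, b, corr.loc[a, b]) for a, b in combinations(corr.columns, 2)]
    return sorted(pairs, key=lambda pair: abs(pair[2]), reverse=True)


# Home data from the color-coded scatterplot example; column names are illustrative
homes = pd.DataFrame({
    'sq_footage': [2100, 2378, 1983, 1422, 2764, 1901, 1198, 1785, 1556, 2931, 3071, 1688],
    'price_000': [390, 427, 350, 285, 479, 299, 250, 310, 290, 495, 515, 284],
    'miles_from_center': [4.8, 0.5, 0.9, 1.5, 0.8, 4.2, 3.1, 0.1, 2.2, 1.5, 4.1, 0.7],
})

for a, b, r in strongest_pairs(homes):
    print(f"{a} vs {b}: r = {r:.2f}")
```

The same function could be applied to the numeric_columns DataFrame from the car crash example to list its most strongly correlated pairs.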
Visualizing Three-Dimensional Data
Let’s say you are taking an anatomy class, and an instructor attempts to verbally describe the human nervous system. Compare this to an interactive 3D model of the human nervous system, where a student can view the nervous system in three dimensions, rotate the 3D image, and view the nervous system from various angles and perspectives. Let’s say an auto mechanic needs to repair a transmission on a car and attempts to read the written technical service manual to understand how to start the repair. Compare this to an interactive 3D model of the car’s transmission where the mechanic can visualize the exact location for the repair and interact with the 3D model.
These examples show the advantages and power of three-dimensional perspectives. Visualizing three-dimensional data has many benefits such as spatial perception, depth perception, and an ability to better illustrate the clustering and patterns inherent in the dataset. 3D visualization is also useful in simulation and modeling applications.
Three-dimensional visualization is also used as part of multiple linear regression analysis to view a scatterplot in three dimensions and view the corresponding regression plane (see Machine Learning in Regression Analysis).
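A minimal sketch of this idea, assuming the home data from earlier in this section, is shown below. It fits home price as a linear function of square footage and distance from the city center using ordinary least squares (numpy.linalg.lstsq), then draws the 3D scatterplot together with the fitted plane using Matplotlib's plot_surface(). This is only an illustration; the regression chapter may fit the model with a different library.

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the 3D projection in older Matplotlib)

# Home data from the color-coded scatterplot example
sqft = np.array([2100, 2378, 1983, 1422, 2764, 1901, 1198, 1785, 1556, 2931, 3071, 1688])
miles = np.array([4.8, 0.5, 0.9, 1.5, 0.8, 4.2, 3.1, 0.1, 2.2, 1.5, 4.1, 0.7])
price = np.array([390, 427, 350, 285, 479, 299, 250, 310, 290, 495, 515, 284])

# Design matrix with an intercept column: price ~ b0 + b1*sqft + b2*miles
X = np.column_stack([np.ones(len(sqft)), sqft, miles])
coef, *_ = np.linalg.lstsq(X, price, rcond=None)
b0, b1, b2 = coef

# Evaluate the fitted plane on a grid that covers the range of the data
grid_sqft, grid_miles = np.meshgrid(np.linspace(sqft.min(), sqft.max(), 20),
                                    np.linspace(miles.min(), miles.max(), 20))
grid_price = b0 + b1 * grid_sqft + b2 * grid_miles

fig = plt.figure(figsize=(8, 7))
ax = fig.add_subplot(projection='3d')
ax.scatter(sqft, miles, price, color='tab:blue', s=40)
# A semi-transparent plane lets the data points remain visible
ax.plot_surface(grid_sqft, grid_miles, grid_price, alpha=0.3, color='tab:orange')
ax.set_xlabel('Square Footage')
ax.set_ylabel('Miles from Center City')
ax.set_zlabel('Home Price $000')
ax.set_title('Home Price with Fitted Regression Plane')
plt.show()
```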
Python includes several tools and libraries to facilitate 3D visualization and analysis; some tools allow for interactivity such as the ability to zoom in/zoom out, pan, and rotate three-dimensional images. Here are some Python modules that support 3D visualization:
- Matplotlib includes facilities for 3D visualization such as the mpl_toolkits.mplot3d toolkit.
- Plotly is a library that includes 3D plotting capability, surface plots, and more.
- VTK is a library for 3D graphics and image processing (VTK stands for Visualization Toolkit).
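As a sketch of the surface plots mentioned in this list, the code below uses Matplotlib's 3D toolkit (through projection='3d') to draw a surface for the function z = sin(sqrt(x^2 + y^2)). The function is an arbitrary choice for demonstration, not data from this chapter.

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the 3D projection in older Matplotlib)

# Build a 60 x 60 grid of (x, y) points
x = np.linspace(-6, 6, 60)
y = np.linspace(-6, 6, 60)
X, Y = np.meshgrid(x, y)

# Illustrative surface: a radial ripple centered at the origin
Z = np.sin(np.sqrt(X**2 + Y**2))

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(projection='3d')
# cmap colors the surface by height, similar to the color-coded scatterplots above
surface = ax.plot_surface(X, Y, Z, cmap='viridis')
fig.colorbar(surface, ax=ax, shrink=0.6, label='z value')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_title('Surface Plot of z = sin(sqrt(x^2 + y^2))')
plt.show()
```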
To illustrate a 3D scatterplot in Python using Matplotlib, let’s generate Python code to plot the Seaborn “Iris” data, which can be downloaded from the Seaborn repository. (We used this dataset in What Are Data and Data Science? and Deep Learning and AI Basics as well.) The dataset contains 150 rows of measurements for different varieties of iris flowers; a few sample rows are shown in Table 9.7.
Sepal Length | Sepal Width | Petal Length | Petal Width | Species of Iris |
---|---|---|---|---|
5.1 | 3.5 | 1.4 | 0.2 | setosa |
4.9 | 3.0 | 1.4 | 0.2 | setosa |
6.7 | 3.0 | 5.2 | 2.3 | virginica |
5.9 | 3.0 | 5.1 | 1.8 | virginica |
We will use the mpl_toolkits.mplot3d toolkit to generate the three-dimensional scatterplot. First, we read in the “iris” dataset from the Seaborn library. Next, a set of 3D axes is created using the fig.add_subplot() function with the argument projection='3d'. The ax.scatter() function is then used to plot the data points, colored by species.
Here is the Python code:
Python Code
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
# Load the iris dataset from Seaborn data repository
flower = sns.load_dataset('iris')
# Create a blank figure with a white background
fig = plt.figure(figsize=(9, 8), facecolor='white')
# Add a subplot to the figure
ax = fig.add_subplot(111, projection='3d')
# Set the background color of the axes to white
ax.set_facecolor('white')
# Plot data points
scatter = ax.scatter(flower['sepal_length'], flower['sepal_width'], flower['petal_length'],
                     c=flower['species'].astype('category').cat.codes, cmap='viridis', marker='o')
# Set labels and title
ax.set_xlabel('Sepal Length (cm)')
ax.set_ylabel('Sepal Width (cm)')
ax.set_zlabel('Petal Length (cm)')
ax.set_title('3D Scatterplot of Iris Dataset by Species')
# Map numerical codes to actual species names
species_names = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
# Create a ScalarMappable object
norm = plt.Normalize(flower['species'].astype('category').cat.codes.min(),
                     flower['species'].astype('category').cat.codes.max())
sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm)
sm.set_array([])
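# Create the colorbar next to the 3D axes; ax=ax tells Matplotlib which axes to take space from
# (pad=0.1 increases the space between the colorbar and the plot)
cbar = fig.colorbar(sm, ax=ax, ticks=[0, 1, 2], label='Species', fraction=0.02, pad=0.1)
# Label the colorbar ticks with species names instead of numeric codes
cbar.set_ticklabels([species_names[code] for code in [0, 1, 2]])
# Adjust the colorbar tick label size if needed
cbar.ax.tick_params(labelsize=10)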
# Adjust figure layout to provide appropriate space around the plot
fig.subplots_adjust(left=0.1, right=0.85, top=0.9, bottom=0.1) # Adjust margins as needed
# Add legend
ax.legend(handles=scatter.legend_elements()[0],
          labels=[species_names[code] for code in flower['species'].astype('category').cat.codes.unique()],
          title='Species')
# Show the plot
plt.show()
The resulting output will look like this:
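When a 3D plot is displayed in an interactive window, it can be rotated with the mouse. For a static report, one option is to draw the same scatterplot from several viewpoints with ax.view_init(), which sets the elevation and azimuth angles in degrees. The sketch below assumes the four sample rows from Table 9.7 are enough to illustrate the idea; the specific angles and colors are arbitrary choices.

```python
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the 3D projection in older Matplotlib)

# The four sample rows from Table 9.7
sepal_length = [5.1, 4.9, 6.7, 5.9]
sepal_width = [3.5, 3.0, 3.0, 3.0]
petal_length = [1.4, 1.4, 5.2, 5.1]
colors = ['tab:purple', 'tab:purple', 'tab:olive', 'tab:olive']  # setosa, setosa, virginica, virginica

# (elevation, azimuth) pairs in degrees, chosen only for illustration
views = [(20, -60), (20, 30), (60, -60), (0, 0)]

fig = plt.figure(figsize=(10, 8))
for i, (elev, azim) in enumerate(views, start=1):
    ax = fig.add_subplot(2, 2, i, projection='3d')
    ax.scatter(sepal_length, sepal_width, petal_length, c=colors, s=40)
    # Rotate the camera for this panel
    ax.view_init(elev=elev, azim=azim)
    ax.set_title(f'elev={elev}, azim={azim}')
    ax.set_xlabel('Sepal Length (cm)')
    ax.set_ylabel('Sepal Width (cm)')
    ax.set_zlabel('Petal Length (cm)')
fig.tight_layout()
plt.show()
```

Viewing the same points from several angles can reveal separation between groups that is hidden in any single projection.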