Principles of Data Science

9.5 Multivariate and Network Data Visualization Using Python


Learning Outcomes

By the end of this section, you should be able to:

  • 9.5.1 Produce labeled scatterplots and scatterplots with variable density points, different colors, etc. to indicate additional information.
  • 9.5.2 Create and interpret correlation heatmaps from multidimensional data.
  • 9.5.3 Create and interpret graphs of three-dimensional data using a variety of methods.

A data scientist or researcher is often interested in generating more advanced visual representations such as scatterplots, correlation maps, and three-dimensional (3D) representations of data. Visualizations or graphs produced in a 3D format are especially important to convey a realistic view of the world around us. For example, on a manufacturing line, a 3D visualization of a mechanical component can be much easier to interpret than a two-dimensional blueprint drawing. 3D graphs can offer insights and perspectives that might be hidden or unavailable in a two-dimensional view.

3D visualizations are especially useful when mapping large datasets such as population density maps or global supply chain routes. In these situations, Python tools such as geopandas or 3D plotting routines available as part of matplotlib can be used to visualize relationships, patterns, and connections in large datasets.
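
As a minimal sketch of the geopandas approach (assuming a shapefile of country boundaries with a population column is available; the filename countries.shp and the column name pop_est here are hypothetical), a choropleth map can be drawn with a single call to the plot() method of a GeoDataFrame:

    import geopandas as gpd
    import matplotlib.pyplot as plt
    
    # Read a shapefile of country boundaries (hypothetical file and column names)
    world = gpd.read_file("countries.shp")
    
    # Shade each country by its population column to produce a choropleth map
    world.plot(column="pop_est", cmap="OrRd", legend=True, figsize=(10, 6))
    
    plt.title("Population by Country")
    plt.show()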

As we saw in Decision-Making Using Machine Learning Basics, data collection and analysis of very large datasets, or big data, has become more common as businesses, government agencies, and other organizations collect huge volumes of data on a real-time basis. Some examples of big data include social media data, which generates very large datasets every second, including images, videos, posts, likes, shares, comments, etc. Data visualization for social media data might include dashboards to visualize trends and engagement on hot topics of the day (see more on these in Reporting Results), or geospatial maps to show the geographic distribution of users and associated demographics.

In this section, we explore more advanced visualization techniques such as scatterplots with variable density points, correlation heatmaps, and three-dimensional type analysis.

Scatterplots, Revisited

A scatterplot is a graphing method for bivariate data, that is, paired data in which each value of one variable is paired with a value of a second variable. It plots (x, y) data for two numeric quantities, and the goal is to determine whether there is a correlation or dependency between the two quantities. Scatterplots were discussed extensively in Inferential Statistics and Regression Analysis as a key element in correlation and regression analysis. Recall that in a scatterplot, the independent, or explanatory, quantity is labeled as the x-variable, and the dependent, or response, quantity is labeled as the y-variable.

When generating scatterplots in Inferential Statistics and Regression Analysis, the plt.scatter() function was used, which is part of the Python matplotlib library. In Correlation and Linear Regression Analysis, an example scatterplot was generated to plot revenue vs. advertising spend for a company.

The Python code for that example is reproduced as follows:

Python Code

    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker
    
    # Define the x-data, which is the amount spent on advertising
    x = [49, 145, 57, 153, 92, 83, 117, 142, 69, 106, 109, 121]
    
    # Define the y-data, which is revenue
    y = [12210, 17590, 13215, 19200, 14600, 14100, 17100, 18400, 14100, 15500, 16300, 17020]
    
    # Use the scatter function to generate a scatterplot
    plt.scatter(x, y)
    
    # Add a title using the title function
    plt.title("Revenue versus Advertising for a Company")
    
    # Add labels to the x and y axes by using xlabel and ylabel functions
    plt.xlabel("Advertising $000")
    plt.ylabel("Revenue $000")
    
    # Define a function to format the ticks with commas as thousands separators
    def format_ticks(value, tick_number):
        return f'{value:,.0f}'
    
    # Apply the custom formatter to the y-axis
    plt.gca().yaxis.set_major_formatter(ticker.FuncFormatter(format_ticks))
    
    # Show the plot
    plt.show()
    

The resulting output will look like this:

A scatterplot labeled revenue versus advertising for a company. The X axis is labeled Advertising $000 and ranges from 60 to 140. The Y axis is labeled Revenue $000 and ranges from 12,000 to 19,000. There are 12 data points showing an increase from left to right.

We can augment this scatterplot by adding color to indicate the revenue level (y-value).

To do this, use the c parameter in the plt.scatter() function to specify that the color of the points will be based on the y-value of the data point. The parameter cmap allows the user to specify the color palette to be used, such as "hot," "coolwarm," etc.

In addition, a color bar can be added to the scatterplot so the user can interpret the color scale.

Here is a revised Python code and scatterplot that includes the color enhancement for the data points:

Python Code

    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker
    
    # Define the x-data, which is the amount spent on advertising
    x = [49, 145, 57, 153, 92, 83, 117, 142, 69, 106, 109, 121]
    
    # Define the y-data, which is revenue
    y = [12210, 17590, 13215, 19200, 14600, 14100, 17100, 18400, 14100, 15500, 16300, 17020]
    
    # Use the scatter function to generate a scatterplot
    # Specify that the color of the data point will be based on revenue (y)
    # Specify the 'coolwarm' color palette
    plt.scatter(x, y, c=y, cmap='coolwarm')
    
    # Plot a color bar so the user can interpret the color scale
    cbar = plt.colorbar()
    
    # Define a function to format the color bar ticks with commas as thousands separators
    def format_colorbar_ticks(value, tick_number):
        return f'{value:,.0f}'
    
    # Apply the custom formatter to the color bar ticks
    cbar.ax.yaxis.set_major_formatter(ticker.FuncFormatter(format_colorbar_ticks))
    
    # Add a title using the title function
    plt.title("Revenue versus Advertising for a Company")
    
    # Add labels to the x and y axes using xlabel and ylabel functions
    plt.xlabel("Advertising $000")
    plt.ylabel("Revenue $000")
    
    # Define a function to format the y-axis ticks with commas as thousands separators
    def format_ticks(value, tick_number):
        return f'{value:,.0f}'
    
    # Apply the custom formatter to the y-axis
    plt.gca().yaxis.set_major_formatter(ticker.FuncFormatter(format_ticks))
    
    # Show the plot
    plt.show()
    

The resulting output will look like this:

A scatterplot labeled revenue versus advertising for a company. The X axis is labeled Advertising $000 and ranges from 60 to 140. The Y axis is labeled Revenue $000 and ranges from 12,000 to 19,000. A color key with blue (13,000) as the lowest value and red (19,000) as the highest value runs along the right side of the graph. There are 12 data points showing an increase from left to right with the colors aligning with the key.

A typical application for this capability is to set the color of each data point on a scatterplot based on the value of a third variable. For example, a real estate professional may be interested in the correlation between home price, square footage of the home, and location in terms of distance in miles from the center of a city. The scatterplot can be created based on (x, y) data, where x is the square footage of the home, y is the price of the home, and z is the distance from the center of the city, represented by the color of the point.

Python Code

    # import the matplotlib library and the ticker module
    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker
    
    # define the x-data, which is square footage of a home
    x = [2100, 2378, 1983, 1422, 2764, 1901, 1198, 1785, 1556, 2931, 3071, 1688]
    
    # define the y-data, which is the corresponding home price in 000
    y = [390, 427, 350, 285, 479, 299, 250, 310, 290, 495, 515, 284]
    
    # define the z-data, which is the corresponding distance from city center
    z = [4.8, 0.5, 0.9, 1.5, 0.8, 4.2, 3.1, 0.1, 2.2, 1.5, 4.1, 0.7]
    
    # use the scatter function to generate a scatterplot
    # specify that the color of the data point will be based on
    # distance from the city center (z)
    # specify the 'hot' color palette
    plt.scatter(x, y, c=z, cmap='hot')
    
    # also plot a color bar so user can interpret the color scale
    plt.colorbar()
    
    # Add a title using the title function
    plt.title("Home Price versus Square Footage (Color = Miles from Center City)")
    
    # Add labels to the x and y axes by using xlabel and ylabel functions
    plt.xlabel("Square Footage")
    plt.ylabel ("Home Price $000")
    
    # Define a function to format the ticks with commas as thousands separators
    def format_ticks(value, tick_number):
        return f'{value:,.0f}'
    
    # Apply the custom formatter to the x-axis
    plt.gca().xaxis.set_major_formatter(ticker.FuncFormatter(format_ticks))
    
    # Show the plot
    plt.show()
    

The resulting output will look like this:

A scatterplot labeled home prices versus square footage (color = miles from city center). The X axis is labeled square footage and ranges from 1,250 to 3,000. The Y axis is labeled home price $000 and ranges from 250 to 500. A color key with red (1) as the lowest value and yellow (4) as the highest value runs along the right side of the graph representing miles from city center. There are 12 data points showing an increase from left to right with the colors aligning with the key.

Correlation Heatmaps

Recall from Correlation and Linear Regression Analysis that correlation analysis is used to determine if one quantity is related to or correlates with another quantity. For example, you know that the value of a car is correlated with the age of the car where, in general, the older the car, the less its value. A numeric correlation coefficient (r) was calculated that provided information on both the strength and direction of the correlation.

The value of r gives us this information:

  • The value of r is always between −1 and +1. The size of the correlation r indicates the strength of the linear relationship between the two variables. Values of r close to −1 or to +1 indicate a stronger linear relationship. Values of r close to zero indicate weak correlation.

The sign of r gives us this information:

  • A positive value of r means that when x increases, y tends to increase, and when x decreases, y tends to decrease (positive correlation).
  • A negative value of r means that when x increases, y tends to decrease, and when x decreases, y tends to increase (negative correlation).
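
Before turning to heatmaps, note that a single correlation coefficient can be computed directly in Python. As a short sketch, NumPy's corrcoef() function applied to the advertising and revenue data from the earlier scatterplot returns r as the off-diagonal entry of a 2 × 2 correlation matrix; a value close to +1 confirms the strong positive correlation visible in that plot.

    import numpy as np
    
    # Advertising spend (x) and revenue (y) from the earlier scatterplot example
    x = [49, 145, 57, 153, 92, 83, 117, 142, 69, 106, 109, 121]
    y = [12210, 17590, 13215, 19200, 14600, 14100, 17100, 18400, 14100, 15500, 16300, 17020]
    
    # np.corrcoef returns a 2 x 2 correlation matrix; the off-diagonal entry is r
    r = np.corrcoef(x, y)[0, 1]
    print(f"Correlation coefficient r = {r:.3f}")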

Data scientists often collect data on more than two variables and then are interested in the various correlations between the variables. Of particular interest are those pairs of variables with stronger correlations. A convenient way to visualize this data is with a correlation heatmap, which assigns a color palette to the various values of the correlation coefficient for the two variables.

Python provides a mechanism to calculate a correlation matrix, which shows the value of r for each pair of variables in a dataset, using the corr() function. This correlation matrix is typically used by researchers to identify those variables having stronger correlations. A correlation heatmap is a visual representation of the correlation matrix that implements color coding to visualize those variables with stronger correlations and those variables with weaker correlations.

Typically, darker colors are used to indicate variables with higher correlations and lighter colors are used to indicate variables with weaker correlations.

As an example of a correlation heatmap, we can use the "car_crashes" dataset available in the Seaborn library, which contains car crash data by state; the fields in the dataset are shown in Table 9.5.

Field Name Field Description
abbrev Abbreviation of the state name
total Total number of car crashes reported in the state
speeding Percentage of car crashes in which speeding was a contributing factor
alcohol Percentage of car crashes in which alcohol was involved
not_distracted Percentage of car crashes where the driver was not distracted
no_previous Percentage of car crashes where the driver had no previous accidents
ins_premium Insurance premium per insured vehicle
ins_losses Insurance losses incurred per insured vehicle
Table 9.5 Fields and Descriptions for the “Car_Crashes” Dataset from the Seaborn Library

There are 51 records in the dataset (one record per state, plus the District of Columbia). The first five lines of the dataset appear in Table 9.6.

total speeding alcohol not_distracted no_previous ins_premium ins_losses abbrev
18.8 7.332 5.64 18.048 15.04 784.55 145.08 AL
18.1 7.421 4.525 16.29 17.014 1053.48 133.93 AK
18.6 6.51 5.208 15.624 17.856 899.47 110.35 AZ
22.4 4.032 5.824 21.056 21.28 827.34 142.39 AR
12 4.2 3.36 10.92 10.68 878.41 165.63 CA
Table 9.6 “Car_Crashes” Dataset Excerpt

The short Python program that follows can be used to generate a correlation heatmap for the variables in the dataset; it makes use of the heatmap function available in the Seaborn library (see also Decision-Making Using Machine Learning Basics).

Python Code

    import seaborn as sns
    import matplotlib.pyplot as plt
    
    # Load the car_crashes dataset from seaborn library
    dataset = sns.load_dataset("car_crashes")
    
    # Select only numeric columns
    numeric_columns = dataset.select_dtypes(include=['number'])
    
    # Calculate the correlation matrix
    correlation_matrix = numeric_columns.corr()
    
    # Generate the correlation heatmap using the Reds color palette
    sns.heatmap(correlation_matrix, cmap='Reds', annot=True, fmt=".2f")
    
    # Add a Title to the plot
    plt.title('Correlation Heatmap for Variables in Car Crashes Dataset')
    
    # Show the plot
    plt.show()
    

The resulting output will look like this:

Heatmap showing correlation between variables in a car crash dataset. The variables are total, speeding, alcohol, not distracted, no previous, insurance premium, and insurance losses. A color key with pink (0.0) as the lowest value and dark red (1.0) as the highest value runs along the right side of the graph. The strongest positive correlations are observed between total crashes and speeding, alcohol use, and distraction.

Notice in the correlation heatmap that dark-red coloring is visible on the diagonal from the upper left to the lower right corner of the matrix. This dark shading represents a correlation coefficient of 1.0 since any variable is considered to be perfectly correlated with itself.
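
To move from the heatmap back to a ranked list of the strongest relationships, the correlation matrix itself can be reshaped and sorted with standard pandas operations. The following sketch continues from the correlation_matrix computed in the code above; because the matrix is symmetric, each pair of variables appears twice in the output.

    # Flatten the correlation matrix into a Series indexed by variable pairs
    pairs = correlation_matrix.unstack()
    
    # Drop the self-correlations on the diagonal (r = 1.0)
    pairs = pairs[pairs.index.get_level_values(0) != pairs.index.get_level_values(1)]
    
    # Rank the remaining pairs by the absolute strength of the correlation
    print(pairs.abs().sort_values(ascending=False).head(6))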

Visualizing Three-Dimensional Data

Let’s say you are taking an anatomy class, and an instructor attempts to verbally describe the human nervous system. Compare this to an interactive 3D model of the human nervous system, where a student can view the nervous system in three dimensions, rotate the 3D image, and view the nervous system from various angles and perspectives. Let’s say an auto mechanic needs to repair a transmission on a car and attempts to read the written technical service manual to understand how to start the repair. Compare this to an interactive 3D model of the car’s transmission where the mechanic can visualize the exact location for the repair and interact with the 3D model.

These examples show the advantages and power of three-dimensional perspectives. Visualizing three-dimensional data has many benefits such as spatial perception, depth perception, and an ability to better illustrate the clustering and patterns inherent in the dataset. 3D visualization is also useful in simulation and modeling applications.

Three-dimensional visualization is also used as part of multiple linear regression analysis to view a scatterplot in three dimensions and view the corresponding regression plane (see Machine Learning in Regression Analysis).
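
As a brief sketch of that idea (using small made-up numbers purely for illustration), a plane y = b0 + b1*x1 + b2*x2 can be fit with NumPy's least-squares solver and drawn over a 3D scatterplot with Matplotlib's plot_surface() function:

    import numpy as np
    import matplotlib.pyplot as plt
    
    # Illustrative data only: two predictors (x1, x2) and one response (y)
    x1 = np.array([1, 2, 3, 4, 5, 6, 7, 8])
    x2 = np.array([2, 1, 4, 3, 6, 5, 8, 7])
    y = np.array([5, 6, 9, 10, 14, 14, 19, 18])
    
    # Fit y = b0 + b1*x1 + b2*x2 by ordinary least squares
    A = np.column_stack([np.ones_like(x1), x1, x2])
    b0, b1, b2 = np.linalg.lstsq(A, y, rcond=None)[0]
    
    # Evaluate the fitted plane on a grid of (x1, x2) values
    X1, X2 = np.meshgrid(np.linspace(x1.min(), x1.max(), 20),
                         np.linspace(x2.min(), x2.max(), 20))
    Y_plane = b0 + b1 * X1 + b2 * X2
    
    # Draw the data points and the fitted regression plane in 3D
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(x1, x2, y, color='red')
    ax.plot_surface(X1, X2, Y_plane, alpha=0.3)
    ax.set_xlabel('x1')
    ax.set_ylabel('x2')
    ax.set_zlabel('y')
    ax.set_title('3D Scatterplot with Fitted Regression Plane')
    plt.show()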

Python includes several tools and libraries to facilitate 3D visualization and analysis; some tools allow for interactivity such as the ability to zoom in/zoom out, pan, and rotate three-dimensional images. Here are some Python modules that support 3D visualization:

  • Matplotlib includes facilities for 3D visualization such as the mpl_toolkits.mplot3d toolkit.
  • Plotly is a library that includes 3D plotting capability, surface plots, and more; a brief interactive Plotly sketch appears at the end of this section.
  • VTK is a library for 3D graphics and image processing (VTK stands for Visualization Toolkit).

To illustrate an example of a 3D scatterplot in Python using Matplotlib, let’s generate Python code to plot the Seaborn “Iris” data, which can be downloaded from the Seaborn repository. (We used this dataset in What Are Data and Data Science? and Deep Learning and AI Basics as well.) The dataset contains 150 rows of data related to measurements for different varieties of iris flowers, but we will only be using a few rows (see Table 9.7).

Sepal Length Sepal Width Petal Length Petal Width Species of Iris
5.1 3.5 1.4 0.2 setosa
4.9 3.0 1.4 0.2 setosa
6.7 3.0 5.2 2.3 virginica
5.9 3.0 5.1 1.8 virginica
Table 9.7 “Iris” Dataset Excerpt

We will use the mpl_toolkits.mplot3d toolkit to generate the three-dimensional scatterplot. First, we read in the "iris" dataset from the Seaborn library. Next, the 3D axes are created using the fig.add_subplot() function with projection='3d'. The ax.scatter() function is then used to plot the data points.

Here is the Python code:

Python Code

    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    import seaborn as sns
    import matplotlib.cm as cm
    import matplotlib.colors as mcolors
    from matplotlib.patches import Rectangle
    
    # Load the iris dataset from Seaborn data repository
    flower = sns.load_dataset('iris')
    
    # Create a blank figure with a white background
    fig = plt.figure(figsize=(9, 8), facecolor='white')
    
    # Add a subplot to the figure
    ax = fig.add_subplot(111, projection='3d')
    
    # Set the background color of the axes to white
    ax.set_facecolor('white')
    
    # Plot data points
    scatter = ax.scatter(flower['sepal_length'], flower['sepal_width'], flower['petal_length'], 
                         c=flower['species'].astype('category').cat.codes, cmap='viridis', marker='o')
    
    # Set labels and title
    ax.set_xlabel('Sepal Length (cm)')
    ax.set_ylabel('Sepal Width (cm)')
    ax.set_zlabel('Petal Length (cm)')
    ax.set_title('3D Scatterplot of Iris Dataset by Species')
    
    # Map numerical codes to actual species names
    species_names = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
    
    # Create a ScalarMappable object
    norm = plt.Normalize(flower['species'].astype('category').cat.codes.min(), 
                         flower['species'].astype('category').cat.codes.max())
    sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm)
    sm.set_array([])
    
    # Create the colorbar and adjust its spacing
    cbar = plt.colorbar(sm, ax=ax, ticks=[0, 1, 2], label='Species', fraction=0.02, pad=0.1)
    
    # Adjust colorbar position to increase space between colorbar and plot
    cbar.ax.tick_params(labelsize=10)  # Adjust the colorbar tick label size if needed
    
    # Adjust figure layout to provide appropriate space around the plot
    fig.subplots_adjust(left=0.1, right=0.85, top=0.9, bottom=0.1)  # Adjust margins as needed
    
    # Add legend
    ax.legend(handles=scatter.legend_elements()[0], 
              labels=[species_names[name] for name in flower['species'].astype('category').cat.codes.unique()], 
              title='Species')
    
    plt.show()
    

The resulting output will look like this:

A 3D scatter plot of the iris dataset by species showing the relationship between sepal length (cm), sepal width (cm), and petal length (cm) with a color gradient representing species setosa, versicolor, and virginica.
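
As a point of comparison with the Plotly library mentioned earlier, the same figure can be produced as a fully interactive plot (rotate, zoom, and hover with the mouse) in a single call to plotly.express.scatter_3d(); this is a minimal sketch rather than a tuned example.

    import plotly.express as px
    import seaborn as sns
    
    # Load the same iris dataset used in the Matplotlib example above
    flower = sns.load_dataset('iris')
    
    # One call builds an interactive 3D scatterplot, colored by species
    fig = px.scatter_3d(flower, x='sepal_length', y='sepal_width',
                        z='petal_length', color='species',
                        title='3D Scatterplot of Iris Dataset by Species')
    fig.show()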

Datasets

Note: The primary datasets referenced in the chapter code may also be downloaded here.
