Principles of Data Science

6.2 Classification Using Machine Learning


Learning Outcomes

By the end of this section, you should be able to:

  • 6.2.1 Perform logistic regression on datasets and interpret the results.
  • 6.2.2 Perform k-means clustering on datasets.
  • 6.2.3 Define the concept of density-based clustering and use DBScan on datasets.
  • 6.2.4 Interpret the confusion matrix in clustering or classifying data.

Classification problems come in many flavors. For example, suppose you are tasked with creating an algorithm that diagnoses heart disease. The input features may include biological sex, age, blood pressure data, and cholesterol levels. The output would be either: yes (diagnose heart disease) or no (do not diagnose heart disease). That is, your algorithm should classify patients as “yes” or “no” based on an array of features, or symptoms in medical terminology. Logistic regression is one tool for classification when there are only two possible outputs. This is often called a binary (binomial) classification problem. If the goal is to classify data into more than two classes or categories, then the problem is referred to as multiclass (multinomial) classification.

Other kinds of classification problems involve finding two or more clusters in the data. A cluster is a collection of data points that are closer to one another than to other data points, according to some definition of closeness. For example, certain characteristics of music, such as tempo, instrumentation, overall length, and types of chord patterns used, can be used as features to classify the music into various genres, like rock, pop, country, and hip-hop. A supervised machine learning algorithm such as a decision tree (see Decision Trees) or random forest (see Other Machine Learning Techniques) may be trained and used to classify music into the various predefined genres. Alternatively, an unsupervised model such as k-means clustering could be used to group similar-sounding songs together without necessarily adhering to a concept of genre.

This section will focus on logistic regression techniques and clustering algorithms such as k-means and DBScan. We’ll use measures of accuracy including the confusion matrix.

Logistic Regression

A logistic regression model L(X) takes input vectors X (features) and produces an output of “yes” or “no.” Much like linear regression, the logistic regression model first fits a continuous function based on known, labeled data. However, instead of fitting a line to the data, a sigmoid function is used. A sigmoid function is a specific type of function that maps any real-valued number to a value between 0 and 1. The term sigmoid comes from the fact that the graph has a characteristic “S” shape (as shown in Figure 6.6), and sigma (σ) is the Greek letter that corresponds to our letter S. We often use the notation σ(x) to denote a sigmoid function. The basic formula for the sigmoid is:

\sigma(x) = \frac{1}{1 + e^{-x}}
A line graph. The x-axis ranges from -1 to 1 and the y-axis in the center of the graph ranges from 0 to 1. The graph displays a sigmoid function that forms an S-curve, starting near (-1, 0), crossing the central y-axis at (0, 0.5), and ending near (1, 1).
Figure 6.6 Graph of a Typical Sigmoid Function

One advantage of the sigmoid over a linear function is that the values of the sigmoid stay close to 0 when x is negative and close to 1 when x is positive. This makes it ideal for classifying arbitrary inputs x into one of two classes, “yes” (1) or “no” (0). In fact, we can interpret the y-axis as the probability that an event occurs. The sigmoid function can be shifted right or left and compressed or expanded horizontally as necessary to fit the data. In the following formula, the parameters a and b control the position and shape of the sigmoid function, respectively.

\sigma(a + bx) = \frac{1}{1 + e^{-(a + bx)}}

As complicated as this formula may seem at first, it is built up from simple elements. The presence of a + bx suggests that there may be a linear regression y = a + bx lurking in the background. Indeed, if we solve to find the inverse function of σ(x), we can use it to isolate a + bx, which in turn allows us to use standard linear regression techniques to build the logistic regression. Recall that the natural logarithm function, ln x, is the inverse function of the exponential, e^x. In what follows, let p stand for σ(x).

p = \sigma(x) = \frac{1}{1 + e^{-x}}
\frac{1}{p} = 1 + e^{-x}
\frac{1}{p} - 1 = e^{-x}
e^{-x} = \frac{1}{p} - 1 = \frac{1 - p}{p}
-x = \ln\left(\frac{1 - p}{p}\right)
x = -\ln\left(\frac{1 - p}{p}\right) = \ln\left[\left(\frac{1 - p}{p}\right)^{-1}\right] = \ln\left(\frac{p}{1 - p}\right)

In the final line, we used the power property of logarithms, \ln(A^n) = n \ln A, with power n = -1. The function we’ve just obtained is called the logit function.

\operatorname{logit}(p) = \ln\left(\frac{p}{1 - p}\right)

The key property of the logit function is that it is the inverse function of σ(x). Recall from algebra, if g(x) is the inverse function of f(x), then we have g(f(x)) = x for all x. In our case, that means logit(σ(x)) = x for all x. Now applying the logit function to the general sigmoid function, we can write:

\operatorname{logit}(\sigma(a + bx)) = a + bx

Thus, the logit function linearizes the sigmoid. Let’s dig a little deeper into the logit function and what it measures.
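As a quick numerical check of this inverse relationship, a few lines of Python (a standalone sketch, not part of the textbook's code listings) confirm that applying the logit to the sigmoid recovers the original input and linearizes σ(a + bx):

Python Code

      import numpy as np

      def sigmoid(x):
          return 1 / (1 + np.exp(-x))

      def logit(p):
          return np.log(p / (1 - p))

      x = np.linspace(-5, 5, 11)
      print(np.allclose(logit(sigmoid(x)), x))                  # True: logit undoes the sigmoid
      print(np.allclose(logit(sigmoid(2 + 3 * x)), 2 + 3 * x))  # True: it also recovers a + bx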

Log-Odds and Discrete Logistic Regression

Logistic regression is a prediction model based on the idea of odds. The odds that an event E occurs are the ratio of the probability of E occurring to the probability of E not occurring. Let p = P(E).

\operatorname{Odds}(E) = \frac{P(E)}{P(E')} = \frac{P(E)}{1 - P(E)} = \frac{p}{1 - p}

Thus, the logit function is simply the logarithm of the odds of an event with a given probability p. This is why logit is often called log-odds. While probabilities p range from 0 to 1, log-odds range from −∞ to ∞.

Suppose we would like to predict the likelihood of an event A, given that another event B either occurs or does not occur. We will create a logistic regression model, f(x) = 1/(1 + e^{-(a + bx)}), such that:

  • f(0) is the likelihood that A occurs when B does not occur, and
  • f(1) is the likelihood that A occurs when B does occur.

We obtain a and b using the logit function. If p1 is the probability that A occurs given that B does not occur (i.e., p1 = P(A | B′)), using the notation of conditional probability (see Normal Continuous Probability Distributions), and if p2 is the probability that A occurs given that B does occur (i.e., p2 = P(A | B)), then

a = logit(p1), and b = logit(p2) − logit(p1)

This model is called a discrete logistic regression because the feature variable is either 1 or 0 (whether or not event B occurred).

Example 6.4

Problem

At a particular college, it has been found that in-state students are more likely to graduate than out-of-state students. Roughly 75% of in-state students end up graduating, while only about 30% of out-of-state students do so. Build a discrete logistic regression model based on this data.
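One way to apply the formulas above (a sketch of the computation, not necessarily the textbook's worked solution): let B be the event “the student is in-state” and A the event “the student graduates,” so that p1 = P(A | B′) = 0.30 and p2 = P(A | B) = 0.75. Then a = logit(0.30) ≈ −0.847 and b = logit(0.75) − logit(0.30) ≈ 1.946. A short Python check:

Python Code

      import numpy as np

      def logit(p):
          return np.log(p / (1 - p))

      p1 = 0.30   # P(graduate | out-of-state), i.e., B does not occur
      p2 = 0.75   # P(graduate | in-state), i.e., B occurs
      a = logit(p1)
      b = logit(p2) - logit(p1)
      print(round(a, 3), round(b, 3))   # approximately -0.847 and 1.946

      # Check: f(0) and f(1) should recover the original probabilities
      f = lambda x: 1 / (1 + np.exp(-(a + b * x)))
      print(round(f(0), 2), round(f(1), 2))   # 0.3 and 0.75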

The previous example used only a single feature variable, “in-state,” and so the usefulness of the model is very limited. We shall see how to expand the model to include multiple inputs or features in Multiple Regression Techniques.

Maximum Likelihood and Continuous Logistic Regression

The method of logistic regression can also be used to predict a yes/no response based on continuous input. Building the model requires finding values for the parameters of the sigmoid function that produce the most accurate results. To explain how this works, let’s first talk about likelihood. The likelihood for a model f(x) is computed as follows. For each data point xk, if the label of xk is 1 (or yes), then calculate f(xk); if the label is 0 (or no), then calculate 1 − f(xk). The product of these individual results is the likelihood score for the model. We want to find the model that has the maximum likelihood.

Example 6.5

Problem

Consider the models

f(x) = \frac{1}{1 + e^{1.6 - 1.1x}}, \qquad g(x) = \frac{1}{1 + e^{2.1 - 0.8x}}

Suppose there are four data points, A = 1, B = 2, C = 3, D = 4, and both A and B are known to have label 0, while C and D have label 1. Find the likelihood scores of each model and determine which model is a better fit for the given data.
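A sketch of how the likelihood computation could be carried out in Python (assuming the two exponents read 1.6 − 1.1x and 2.1 − 0.8x, as written above; not necessarily the textbook's worked solution):

Python Code

      import numpy as np

      f = lambda x: 1 / (1 + np.exp(1.6 - 1.1 * x))
      g = lambda x: 1 / (1 + np.exp(2.1 - 0.8 * x))

      points = [1, 2, 3, 4]
      labels = [0, 0, 1, 1]   # A and B have label 0; C and D have label 1

      def likelihood(model, points, labels):
          # Multiply model(x) for label-1 points and 1 - model(x) for label-0 points
          terms = [model(x) if y == 1 else 1 - model(x) for x, y in zip(points, labels)]
          return float(np.prod(terms))

      print(round(likelihood(f, points, labels), 3))
      print(round(likelihood(g, points, labels), 3))
      # The model with the larger likelihood score is the better fit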

How does one find the model that has maximum likelihood? The exact method for finding the coefficients a and b for the model σ(a + bx) falls outside the scope of this text. Fortunately, there are software packages available that can be used to do the work for you.

For example, the same university ran a study on the effect of first-year GPA on completion of a college degree. A small sample of the data is shown in Table 6.5. What is the model that best fits this data?

GPA Completed College?
1.5 N
2.4 Y
3.4 Y
2.1 N
2.5 Y
0.8 N
2.9 N
4.0 Y
2.3 Y
2.1 N
3.6 Y
0.5 N
Table 6.5 GPA vs. College Completion Data

Figure 6.7 shows the data in graphical form.

A scatterplot with an X axis that ranges from 0 to 4 and a Y axis that ranges from 0 to 1. Six data points lie at y = 0, at GPAs 0.5, 0.8, 1.5, 2.1, 2.1 (the two at 2.1 overlap), and 2.9, and six data points lie at y = 1, at GPAs 2.3, 2.4, 2.5, 3.4, 3.6, and 4.0.
Figure 6.7 Graph of GPA (x-axis) and College Completion (y-axis). This uses values 0 for non-completion and 1 for completion. The graph suggests that higher GPAs tend to predict completion of a degree.

The logistic model (found using computer software) that maximizes the likelihood is:

\frac{1}{1 + e^{2.826 - 1.197x}}

The logistic regression model that best fits the data is shown in Figure 6.8, along with the data points colored green for predicted completion and red for predicted non-completion.

A graph with an X axis that ranges from 0 to 4 and a Y axis that ranges from 0 to 1, showing the fitted sigmoid curve. The six data points with GPAs 0.5, 0.8, 1.5, 2.1, 2.1, and 2.3 are red (predicted non-completion), and the six data points with GPAs 2.4, 2.5, 2.9, 3.4, 3.6, and 4.0 are green (predicted completion).
Figure 6.8 Logistic Regression Model Fitting the Data in Table 6.5

Logistic Regression in Python

A sample of Python code for logistic regression, using the data from Table 6.5, appears next. We use the LogisticRegression function found in the Python library sklearn.linear_model.

Python Code

      # Import libraries
      from sklearn.linear_model import LogisticRegression
      
      # Define input and output data
      data = [[1.5, 0], [2.4, 1], [3.4, 1], [2.1, 0], [2.5, 1],
              [0.8, 0], [2.9, 0], [4.0, 1], [2.3, 1], [2.1, 0], [3.6, 1], [0.5, 0]]
      
      # Separate features (X) and labels (y)
      X = [[row[0]] for row in data]  # Feature
      y = [row[1] for row in data]    # Label
      
      # Build the logistic model
      Lmodel = LogisticRegression()
      Lmodel.fit(X,y)
      
      # Display the coefficients of the logistic regression
      (a,b) = (Lmodel.intercept_[0], Lmodel.coef_[0,0])
      print("Coefficients: ", (round(a,3),round(b,3)))
      
      # Display the accuracy of the model
      s = Lmodel.score(X,y)
      print("Accuracy: ", round(s, 3))
    

The resulting output will look like this:

Coefficients: (-2.826, 1.197)
Accuracy: 0.833

The code produces the coefficients that can be plugged in for a and b in the model:

\sigma(-2.826 + 1.197x) = \frac{1}{1 + e^{-(-2.826 + 1.197x)}} = \frac{1}{1 + e^{2.826 - 1.197x}}
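Once fitted, the same model object can also estimate completion probabilities for new GPA values. For example, continuing the code above with a hypothetical GPA of 3.0:

Python Code

      # Estimate the probability of completion for a hypothetical GPA of 3.0
      prob = Lmodel.predict_proba([[3.0]])[0, 1]
      print("P(completion | GPA = 3.0):", round(prob, 3))   # roughly 0.68 with the coefficients above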

k-Means Clustering

Clustering algorithms perform like digital detectives, uncovering patterns and structure in data. The goal is to find and label groups of data points that are close to one another. This section develops methods for grouping data (clustering) that incorporate machine learning.

The k-means clustering algorithm can classify or group similar data points into clusters or categories without prior knowledge of what those categories might be (i.e., unsupervised learning). However, the user must provide a guess as to the value of k, the number of clusters to expect. (Note: You may get around this issue by iterating the algorithm through many values of k and evaluating which k-value produced the best results. The trade-off is the time it takes to run the algorithm with multiple values of k and to interpret the results. The so-called elbow method is often used to find the best value of k, but explaining this method is beyond the scope of this text.) The user must also provide initial guesses as to where the centroids might be. We shall see by example how these choices may affect the algorithm.

k-Means Clustering Fundamentals

The basic idea behind k-means is quite intuitive: each cluster is centered around its centroid in such a way that the total distance from the data points in the cluster to the centroid is minimized. The trick is in finding those centroids! Briefly, the centroid of a set of data points is defined as the arithmetic mean of the data points, regarded as vectors.

\text{Centroid} = \frac{1}{N} \sum_{i=1}^{N} X_i

The basic steps of the k-means clustering algorithm are as follows:

  1. Guess the locations of kk cluster centers.
  2. Compute distances from each data point to each cluster center. Data points are assigned to the cluster whose center is closest to them.
  3. Find the centroid of each cluster.
  4. Repeat steps 2 and 3 until the centroids stabilize—that is, the new centroids are the same (or approximately the same, within some given tolerance) as the centroids from the previous iteration.
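These steps are compact enough to sketch directly in code. The following is a minimal illustration using NumPy with made-up data and made-up initial centers; it is not the implementation used by scikit-learn, which adds smarter initialization and other refinements.

Python Code

      import numpy as np

      def k_means(points, centers, tol=1e-6, max_iter=100):
          # points: (N, d) array of data; centers: initial (k, d) guesses (step 1)
          for _ in range(max_iter):
              # Step 2: assign each point to the nearest center
              dists = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=2)
              labels = dists.argmin(axis=1)
              # Step 3: recompute each centroid as the mean of its assigned points
              new_centers = np.array([points[labels == j].mean(axis=0)
                                      if np.any(labels == j) else centers[j]
                                      for j in range(len(centers))])
              # Step 4: stop when the centroids stabilize
              if np.allclose(new_centers, centers, atol=tol):
                  break
              centers = new_centers
          return labels, centers

      # Made-up data forming two obvious groups, with rough initial guesses
      pts = np.array([[1.0, 1.0], [1.2, 0.8], [0.9, 1.1],
                      [8.0, 8.0], [8.2, 7.9], [7.8, 8.3]])
      labels, centers = k_means(pts, centers=np.array([[0.0, 0.0], [5.0, 5.0]]))
      print(labels)
      print(centers)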

If the clustering was done well, then clusters should be clearly separated. The silhouette score is a measure of how well-separated the clusters are and is defined as follows:

Let a(i) be the mean distance between point i and all other data points in the same cluster as point i.

For each cluster J not containing point i, let b_J(i) be the mean distance between point i and all data points in cluster J. Then, let b(i) be the minimum of all the values of b_J(i).

Then the silhouette score is found by the formula

S = \frac{1}{N} \sum_{i=1}^{N} \frac{b(i) - a(i)}{\max\{a(i),\, b(i)\}}

Values of S lie between −1 and 1, with values close to 1 indicating well-separated clusters, values close to 0 indicating that clusters are ambiguous, and values close to −1 indicating poor clustering with points assigned to clusters arbitrarily. Fortunately, statistical software packages that can do k-means clustering will be able to compute the silhouette score for you.
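To make the formula concrete, the following sketch computes the silhouette score by hand for a made-up four-point, two-cluster example and compares it with scikit-learn's built-in silhouette_score (with only two clusters, b(i) is simply the mean distance to the other cluster):

Python Code

      import numpy as np
      from sklearn.metrics import silhouette_score

      # Made-up data: two tight clusters of two points each
      X = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]])
      labels = np.array([0, 0, 1, 1])

      def silhouette_by_hand(X, labels):
          n = len(X)
          scores = []
          for i in range(n):
              same = [j for j in range(n) if labels[j] == labels[i] and j != i]
              other = [j for j in range(n) if labels[j] != labels[i]]
              a = np.mean([np.linalg.norm(X[i] - X[j]) for j in same])    # a(i)
              b = np.mean([np.linalg.norm(X[i] - X[j]) for j in other])   # b(i): one other cluster here
              scores.append((b - a) / max(a, b))
          return np.mean(scores)

      print(round(silhouette_by_hand(X, labels), 3))   # computed from the formula above
      print(round(silhouette_score(X, labels), 3))     # scikit-learn's value should agree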

Example 6.6

Problem

An outbreak of fungus has affected a small garden area. The fungus originated underground, but it produced mushrooms that can easily be spotted. From a photo of the area, it is surmised that there are three clusters of mushrooms (see Figure 6.9). Use k-means to classify the mushrooms into three groups.

A scatterplot with an X axis labeled “East” that ranges from 20 to 160 and a Y axis labeled “North” that ranges from 20 to 90. There are 19 data points that represent locations of mushrooms in a garden and represent three apparent clusters with a couple of outliers.
Figure 6.9 The Locations of Mushrooms in a Garden. There are three apparent clusters.

Among clustering algorithms, k-means is rather simple, easy to implement, and very fast. However, it suffers from many drawbacks, including sensitivity to the initial choice of centroids. Furthermore, k-means does not detect clusters that are not simply “blobs.” In the next section, we will discuss another clustering algorithm that generally performs better on a wider range of data.

k-Means Clustering in Python

Our previous example used a sample of only 19 data points selected from the 115 data points found in the file FungusLocations.csv. Here is the Python code to produce a k-means clustering from a dataset:

Python Code

      # Import libraries
      import pandas as pd  ## for dataset management
      import matplotlib.pyplot as plt ## for data visualization
      from sklearn.cluster import KMeans
      from sklearn.metrics import silhouette_score
      
      # Read data
      data = pd.read_csv('FungusLocations.csv').dropna()
      
      # Build K-means model with 3 clusters
      km = KMeans(n_clusters=3, n_init='auto')
      km.fit(data)
      
      # Calculate silhouette score
      silhouette_avg = silhouette_score(data, km.labels_)
      print("Silhouette Score:", round(silhouette_avg,2))
      
      # Visualize the result of k-Means clustering using matplotlib
      plt.scatter(data['East'], data['North'], c=km.labels_, cmap='viridis')
      plt.xlabel('East')
      plt.ylabel('North')
      plt.title('k-Means Clustering Result')
      plt.show()
      

The resulting output will look like this:

Silhouette Score: 0.77
A scatterplot labeled “k-Means Clustering Result” with an X axis labeled “East” that ranges from 20 to 160 and a Y axis labeled “North” that ranges from 20 to 90. The 115 data points represent locations of mushrooms in a garden and are shown in three clusters colored yellow, purple, and green.

With a silhouette score of about 0.77, the separation of data points into three clusters seems appropriate.

Exploring Further

Centroids versus Other Means

Instead of using centroids (which are computed using arithmetic means, or averages), other means may be employed, including medians, trimmed averages, geometric or harmonic means, and many others. These algorithms may have different names, such as k-medians, k-medoids, etc. More information can be found in articles from Neptune and Tidymodels.

Density-Based Clustering (DBScan)

Suppose you need to classify the data in Figure 6.13 into two clusters.

A density-based clustering scan with an X axis that ranges from -15 to 15 and a Y axis that ranges from -15 to 15. Blue dots form a larger outer circle cluster and a smaller inner cluster in the center of the circle.
Figure 6.13 Dataset with an Inner and Outer Ring

How might you do this? There seems to be a clear distinction between points on the inner ring and points on the outer ring. While k-means clustering would not do well on this dataset, a density-based clustering algorithm such as DBScan might perform better.

DBScan Fundamentals

The DBScan algorithm (DBScan is an acronym that stands for density-based spatial clustering of applications with noise) works by locating core data points that are part of dense regions of the dataset and expanding from those cores by adding neighboring points according to certain closeness criteria.

First, we define two parameters: K and r. A point in the dataset is called a core point if it has at least K neighbors, counting itself, within a distance of r from itself.

The steps of DBScan are as follows:

  1. Choose a starting point that is a core point and assign it to a cluster. (Note: If there are no core points in the dataset at all, then the parameters K and r will need to be adjusted.)
  2. Add all core points that are close (less than a distance of rr) to the first point into the first cluster. Keep adding to the cluster in this way until there are no more core points close enough to the first cluster.
  3. If there are any core points not assigned to a cluster, choose one to start a new cluster and repeat step 2 until all core points are assigned to clusters.
  4. For each non-core point, check to see if it is close to any core points. If so, then add it to the cluster of the closest core point. If a non-core point is not close enough to any core points, it will not be added to any cluster. Such a point is regarded as an outlier or noise.

Since the values of K and r are so important to the performance of DBScan, the algorithm is typically run multiple times on the same dataset using a range of K values and r values. If there are a large number of outliers and/or very few clusters, then the model may be underfitting. Conversely, a very small number of outliers (or zero outliers) and/or many clusters present in the model may be a sign of overfitting.
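A rough way to explore this trade-off in code is to sweep over a grid of parameter values and record how many clusters and noise points each combination produces. The following is a sketch with made-up data and a made-up grid; sensible ranges for K and r depend on the scale and density of the actual dataset.

Python Code

      import numpy as np
      from sklearn.cluster import DBSCAN

      # Made-up two-dimensional data standing in for a real dataset
      rng = np.random.default_rng(0)
      data = np.vstack([rng.normal(0, 1, (30, 2)), rng.normal(8, 1, (30, 2))])

      for min_samples in [3, 5, 10]:       # plays the role of K
          for eps in [0.5, 1.0, 2.0]:      # plays the role of r
              labels = DBSCAN(eps=eps, min_samples=min_samples).fit(data).labels_
              n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
              n_noise = list(labels).count(-1)
              print(f"K={min_samples}, r={eps}: {n_clusters} clusters, {n_noise} noise points")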

Let’s work through a small example using DBScan with K = 3 and r = 2.1 to classify the following dataset into clusters. (Points are labeled by letters for convenience.)

A: (3.5, 5.8) B: (4.5, 3.5) C: (5.0, 5.0) D: (6.0, 6.0) E: (6.0, 3.0) F: (6.9, 4.7)

G: (9.0, 4.0) H: (11.0, 5.0) J: (12.0, 4.0) K: (12.4, 5.4) L: (13.2, 2.5)

Note: In practice, DBScan would never be done by hand, but the steps shown as follows illustrate how the algorithm works. Figure 6.14 displays the data.

A scatterplot with an X axis that ranges from 0 to 16 and a Y axis that ranges from 0 to 8. There are 11 data points labeled A through L randomly scattered.
Figure 6.14 Original Dataset with 11 Unlabeled Points.

The core points are those that have at least K neighbors (or K − 1, not counting the point itself) within a distance of r. The easiest way to determine core points would be to draw circles of radius 2.1 centered at each point, as shown in Figure 6.15, and then count the number of points within each circle.

A scatterplot with an X axis that ranges from 0 to 16 and a Y axis that ranges from 0 to 8. There are 11 data points labeled A through L randomly scattered. There are 11 blue circles, each centered around a data point. The circles overlap and intersect illustrating which points are core or not core.
Figure 6.15 Circles of Radius r Centered at Each Point. Drawing circles of radius r around each data point makes it easier to see which points are core and which are not core.

Based on Figure 6.15, point A has only two neighbors, A and C, so A is not a core point. On the other hand, point B has three neighbors, B, C, and E, so B is core. The full list of core points is B, C, D, E, F, H, J, and K. So B may be chosen to start the first cluster. Then, building from B, points C and E are added next. Next, D, E, and F are added to the growing first cluster. Neither A nor G are added at this stage since they are not core points. Having found all core points that are in the first cluster, let H start a second cluster. The only core points left are close enough to H that the second cluster is H, J, and K.

Finally, we try to place the non-core points into clusters. Point A is close enough to C to make it into the first cluster. Point L is close enough to J to become a part of the second cluster. Point G does not have any core points near enough, so G will be un-clustered and considered noise. Thus, cluster 1 consists of {A, B, C, D, E, F}, and cluster 2 consists of {H, J, K, L}. The results of DBScan clustering are shown in Figure 6.16.
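For comparison, scikit-learn's DBSCAN should reproduce this hand-worked clustering when eps and min_samples are set to the r and K used above (min_samples counts the point itself, matching this text's definition of a core point); a label of −1 marks noise:

Python Code

      import numpy as np
      from sklearn.cluster import DBSCAN

      # The 11 points A-L from the worked example
      pts = np.array([[3.5, 5.8], [4.5, 3.5], [5.0, 5.0], [6.0, 6.0], [6.0, 3.0], [6.9, 4.7],
                      [9.0, 4.0], [11.0, 5.0], [12.0, 4.0], [12.4, 5.4], [13.2, 2.5]])
      labels = DBSCAN(eps=2.1, min_samples=3).fit(pts).labels_
      for name, lab in zip("ABCDEFGHJKL", labels):
          print(name, lab)   # expected: A-F in one cluster, H-L in another, G labeled -1 (noise)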

A scatterplot with an X axis that ranges from 0 to 16 and a Y axis that ranges from 0 to 8. There are 11 data points labeled A through L. Data points A-F are blue, G is black, and H-L are red, representing a DBScan of two clusters. Cluster 1 is in blue and Cluster 2 is in red.
Figure 6.16 Graph Showing Results of DBScan Clustering. Cluster 1 is in blue and Cluster 2 is in red.

DBScan in Python

The Python library sklearn.cluster has a module named DBSCAN. Here is how it works on the dataset DBScanExample.csv.

Python Code

      # Import libraries
      import pandas as pd ## for dataset management
      from sklearn.cluster import DBSCAN
      
      # Read data
      data = pd.read_csv('DBScanExample.csv').dropna()
      
      # Run DBScan
      db = DBSCAN(eps=1.3, min_samples=5).fit(data)
      labels = db.labels_
      
      # Find the number of clusters and points of noise
      n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
      n_noise = list(labels).count(-1)
      print("Number of clusters: ", n_clusters)
      print("Points of noise:", n_noise)
    

The resulting output will look like this:

Number of clusters: 2
Points of noise: 8

Python Code

      # Visualize clusters
      import matplotlib.pyplot as plt ## for data visualization
      
      plt.scatter(data['X'], data['Y'], c=db.labels_)
    

The resulting output will look like this:

A density-based clustering scan with an X axis that ranges from -10 to 10 and a Y axis that ranges from -10 to 10. Yellow dots form a larger outer circle cluster and green dots form a smaller inner cluster in the center of the circle. Eight purple dots appear as outliers inside and outside the circle.

The Confusion Matrix

When training and testing any algorithm that performs classification, such as logistic regression models, k-means clustering, or DBScan, it is important to measure accuracy and identify error. The confusion matrix is one way of quantifying and visualizing the effectiveness of a classification model. This concept will be illustrated by the following example.

Suppose you have trained a model that classifies images of plants into one of three categories: flowers, trees, and grasses. When you test the model on 100 additional images of plants, you discover that it correctly identified flowers, trees, and grasses the majority of the time, but there were also quite a few mistakes. Table 6.6 displays the results in a confusion matrix:

Identified as a Flower Identified as a Tree Identified as a Grass
Is a flower 23 3 9
Is a tree 2 32 0
Is a grass 12 1 18
Table 6.6 A Confusion Matrix for Flowers, Trees, and Grasses

Thus, we can see that 23 flowers, 32 trees, and 18 grasses were identified correctly, giving the model an accuracy of (23 + 32 + 18)/100 = 0.73, or 73%. Note: The terms TP, TN, FP, and FN introduced in What Is Machine Learning? and defined again in the next example do not apply when classifying data into three or more categories; however, the confusion matrix gives more detailed information about the mistakes. Where there are relatively low numbers off the main diagonal, the model performed well. Conversely, higher numbers off the main diagonal indicate greater rates of misidentification, or confusion. In this example, when presented with a picture of a flower, the model did a pretty good job of identifying it as a flower but would mistake flowers for grasses at a higher rate than mistaking them for trees. The model did very well identifying trees as trees, never mistaking them for grasses and only twice mistaking them for flowers. On the other hand, the model did not do so well identifying grasses. Almost 40% of the images of grasses were mistaken for flowers! With this information in hand, you could go back to your model and adjust parameters in a way that may address these specific issues. There may be a way to tweak the model that helps with identifying grasses in particular. Or you may decide to take extra steps whenever an image is identified as a flower, perhaps running the image through additional analysis to be sure that it truly is a flower and not just a mislabeled grass.
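The same accuracy computation works for any confusion matrix: sum the main diagonal and divide by the total count. A small illustration with the matrix from Table 6.6:

Python Code

      import numpy as np

      cm = np.array([[23,  3,  9],
                     [ 2, 32,  0],
                     [12,  1, 18]])
      accuracy = np.trace(cm) / cm.sum()   # (23 + 32 + 18) / 100
      print(accuracy)                      # 0.73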

When there are only two classes (binary classification), the confusion matrix of course has only four entries. There are special terms that apply to this case when the two classes are “Positive” and “Negative.” Think of diagnosing a disease. Either the patient has the disease or they do not. A doctor can perform a test, and the hope is that the test will determine for sure whether the patient has that disease. Unfortunately, no test is 100% accurate. The two cases in which the test fails are called false positive and false negative. A false positive (also called type I error) occurs when the true state is negative, but the model or test predicts positive. A false negative (also called type II error) is just the opposite: the state is positive, but the model or test predicts negative. Both errors can be dangerous in fields such as medicine. The terminology for the four possibilities is shown in Table 6.7.

Predicted Positive Predicted Negative
Actual Positive True Positive (TP) – hit False Negative (FN) – miss
Actual Negative False Positive (FP) – false alarm True Negative (TN) – correctly rejected
Table 6.7 Terminology Summarizing the Four Possibilities from a Confusion Matrix

Adjustments to a model that reduce one type of error generally increase the rate of the other type of error. That is, if you want your model to have a very low rate of false negatives, then it may become positively biased (more sensitive), predicting more positives whether they are true or false positives. On the other hand, if the rate of false positives needs to be reduced, then the model may become negatively biased (less sensitive) and yield more negative predictions overall, both true and false negatives.
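One common way to trade one type of error for the other is to move the probability threshold used to call a prediction “positive.” The following sketch uses made-up one-feature data and a scikit-learn logistic regression to show how the counts of false positives and false negatives shift as the threshold changes:

Python Code

      import numpy as np
      from sklearn.linear_model import LogisticRegression
      from sklearn.metrics import confusion_matrix

      # Made-up one-feature data standing in for a real dataset
      rng = np.random.default_rng(1)
      X = rng.uniform(0, 4, size=(200, 1))
      y = (X[:, 0] + rng.normal(0, 1, 200) > 2).astype(int)

      model = LogisticRegression().fit(X, y)
      probs = model.predict_proba(X)[:, 1]

      # A lower threshold makes the model more sensitive: fewer false negatives, more false positives
      for threshold in [0.3, 0.5, 0.7]:
          preds = (probs >= threshold).astype(int)
          tn, fp, fn, tp = confusion_matrix(y, preds).ravel()
          print(threshold, "FP:", fp, "FN:", fn)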

Visualizing Confusion Matrices

Confusion matrices are often shaded or colored to contrast high and low values, producing what is called a heatmap; this is helpful in locating any abnormal behavior in the model. Figure 6.17 provides a heatmap for the flower/tree/grass example.

A heatmap with 9 cells. From left to right, columns are labeled “flower,” “tree,” “grass.” From top to bottom, rows are labeled “flower,” “tree,” “grass.” The top row reads 23, 3, 9. The middle row reads 2, 32, 0. The bottom row reads 12, 1, 18. Cells are varying shades of gray with 0 being the lightest and 32 being the darkest.
Figure 6.17 Heatmap for Flower/Tree/Grass Example

The darker shades indicate higher values. The main diagonal stands out with darker cells. This is to be expected if our classifier is doing its job. However, darker shades in cells that are off the main diagonal indicate higher rates of misclassification. The numbers 12 in the lower left and 9 in the upper right are prominent. The model clearly has more trouble with flowers and grasses than it does with trees.

Generating a Confusion Matrix in Python

The dataset CollegeCompletionData.csv contains 62 data points (of which only 12 were used in Table 6.5). The following code produces a logistic regression based on all 62 points. Then, the confusion matrix is generated. Finally, a visual display of the confusion matrix as a heatmap is generated using another visualization library called seaborn. If you do not need the heatmap, then just type print(cf) directly after finding the confusion matrix to print a text version of the matrix.

Python Code

      # Import libraries
      import pandas as pd  ## for dataset management
      import matplotlib.pyplot as plt ## for data visualization
      import seaborn as sns ## for heatmap visualization
      from sklearn.linear_model import LogisticRegression
      from sklearn.metrics import confusion_matrix
      
      # Read data
      data = pd.read_csv('CollegeCompletionData.csv').dropna()
      x = data[['GPA']]
      y = data['Completion']
      
      # Build the logistic model
      model = LogisticRegression()
      model.fit(x,y)
      
      # Generate model predictions
      y_pred = model.predict(x)
      
      # Generate the confusion matrix
      cf = confusion_matrix(y, y_pred)
      
      # Plot the heatmap using seaborn and matplotlib
      sns.heatmap(cf, annot=True, fmt='d', cmap='Blues', cbar=True)
      plt.xlabel('Predicted')
      plt.ylabel('True')
      plt.yticks(rotation=0)
      plt.title('Confusion Matrix')
      plt.show()
    

The resulting output will look like this:

A confusion matrix visualizing the performance of a binary classification model. The matrix has two rows and two columns, representing the true and predicted classes, respectively. The diagonal cells (top-left and bottom-right) show the number of correctly classified instances (true positives and true negatives), while the off-diagonal cells show the number of misclassified instances (false positives and false negatives). Clockwise from top left the boxes read: 20, 10, 24, 8. A color scale on the right indicates the frequency of each value. It runs up the right side of the matrix starting as light blue at 5.0 and getting darker at the top at 25. The color of the boxes aligns with this scale.

From the preceding analysis, we find 20 true positives and 24 true negatives, representing an accuracy of 44/62 ≈ 71%. There were 10 false negatives and 8 false positives.
