Correlation analysis can help us understand whether, and how strongly, a pair of variables are related.
In data science and machine learning, this can help us understand relationships between features/predictor variables and outcomes. It can also help us understand dependencies between different feature variables. Typical questions include:
- How strong is the correlation between mental stress and cardiac issues?
- Is there a correlation between literacy rate and frequency of criminal activities?
This tutorial will help you learn the different techniques and approaches used to understand correlations that exist between features in any dataset.
In correlation analysis, we also estimate a sample correlation coefficient between a pair of variables. The correlation coefficient of a pair of variables tells us how strongly one variable changes with respect to the other.
The correlation coefficient ranges from -1 to 1.
The graphs below show different pairs of variables with different correlation coefficient $\rho$:
Graphs of different variables with different Correlation Coefficients. [Source: Wikipedia]
The magnitude signifies the strength of relationship between the variables.
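- So, a value of -1 or +1 indicates that the two variables are perfectly related.
- A value of 0 indicates that the variables have no relationship of the kind the coefficient measures (for Pearson's coefficient, no linear relationship); they may still be related in other ways.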
The sign (positive / negative) signifies the direction of relationship:
- A positive value indicates that higher values of one variable are accompanied by higher values of the other variable
- A negative value indicates that higher values of one variable are accompanied by lower values of the other variable
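To see the sign and magnitude in action, NumPy's np.corrcoef returns the Pearson correlation matrix of its inputs, with the coefficient for the pair in the off-diagonal entry. The arrays below are made up for illustration, and the noise level in the third pair is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.arange(20, dtype=float)

# Perfect positive and perfect negative linear relationships
y_up = 2 * x + 1
y_down = -3 * x + 10

# A noisy positive relationship (noise scale chosen arbitrarily)
y_noisy = x + rng.normal(0, 5, size=x.size)

# np.corrcoef returns a 2x2 matrix; entry [0, 1] is the coefficient for the pair
r_up = np.corrcoef(x, y_up)[0, 1]
r_down = np.corrcoef(x, y_down)[0, 1]
r_noisy = np.corrcoef(x, y_noisy)[0, 1]

print(f"perfect positive: {r_up:.2f}")
print(f"perfect negative: {r_down:.2f}")
print(f"noisy positive:   {r_noisy:.2f}")
```

The noisy pair lands somewhere strictly between 0 and 1: the direction is still positive, but the points no longer fall on a single line.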
The two most common types of correlation coefficient are Pearson’s and Spearman’s.
Pearson’s correlation coefficient is a measure of linear correlation between two variables. The graphs you saw above show Pearson’s correlation coefficient.
Spearman’s rank correlation coefficient assesses monotonic relationships, irrespective of whether they are linear or not. That is to say, it only measures how consistently one variable increases (or decreases) as the other variable increases.
Let us take an example to understand the difference:
Example of a non-linear monotonically related pair of variables. [Source: Wikipedia]
Here we can see that when there is an increase in X, there is always an increase in Y. However, the increase is not linear.
Thus Pearson’s correlation is 0.88 while Spearman’s correlation is a perfect 1.
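Spearman's coefficient is Pearson's coefficient computed on the ranks of the values instead of the raw values, which is why any strictly increasing relationship scores a perfect 1. The sketch below applies that definition with pandas' Series.rank() to made-up data where y = e^x; it is not the dataset from the figure, so its Pearson value differs from the 0.88 quoted above. pandas can also compute this directly with Series.corr(method="spearman").

```python
import numpy as np
import pandas as pd

# Made-up monotonic but strongly non-linear relationship
x = pd.Series(np.arange(1, 11, dtype=float))
y = np.exp(x)

# Pearson: linear association between the raw values
pearson = np.corrcoef(x, y)[0, 1]

# Spearman: Pearson correlation of the ranks (ties would get average ranks)
spearman = np.corrcoef(x.rank(), y.rank())[0, 1]

print(f"Pearson:  {pearson:.2f}")
print(f"Spearman: {spearman:.2f}")
```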
Also note that correlation is a symmetric relationship. That is to say, if A correlates with B, then B correlates with A.
Before we proceed further, we should clarify that although correlation shows relationships between variables and it helps in predictive analysis, it does not imply causation.
For example: the sales of sunglasses and the sales of ice-cream are highly correlated. But an increase in ice-cream sales does not cause an increase in the sales of sunglasses, or vice-versa. In this case, they are correlated because they both depend on a third variable: how hot and sunny it is.
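The ice-cream example can be simulated with made-up numbers. Below, a hypothetical temperature series drives both sales figures, and neither sales series depends on the other, yet the two still come out strongly correlated. Removing the straight-line effect of temperature from each series (fitting a line with np.polyfit and keeping the residuals, one simple way to compute a partial correlation) makes the correlation essentially disappear. All coefficients and noise levels are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 1000

# Hypothetical daily temperature: the hidden third variable
temperature = rng.normal(25, 5, n)

# Both sales figures depend on temperature, but not on each other
ice_cream = 2 * temperature + rng.normal(0, 2, n)
sunglasses = 3 * temperature + rng.normal(0, 3, n)

raw_r = np.corrcoef(ice_cream, sunglasses)[0, 1]

def residuals(y, x):
    """What is left of y after removing a straight-line fit on x."""
    slope, intercept = np.polyfit(x, y, 1)
    return y - (slope * x + intercept)

partial_r = np.corrcoef(residuals(ice_cream, temperature),
                        residuals(sunglasses, temperature))[0, 1]

print(f"correlation of sales:          {raw_r:.2f}")
print(f"after controlling for weather: {partial_r:.2f}")
```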
Although correlation coefficients give us an idea about the strength of the relationship between two variables, it is not possible for a single number to give us the full picture.
Since correlation coefficients can only give us limited information, visualisation can be of great aid in understanding the relationship between variables.
For example, the following graphs show the famous Anscombe’s Quartet. It consists of four data sets whose correlation coefficients are the same (about 0.816, identical to three decimal places), even though the data sets are very different from one another.
Anscombe’s Quartet [Source: Wikipedia]
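The quartet's numbers are small enough to check directly. The values below are Anscombe's published data (seaborn also ships them as sns.load_dataset("anscombe"), which downloads the file). Computing Pearson's coefficient for each set gives about 0.816 every time, even though the scatter plots look nothing alike.

```python
import numpy as np

# Anscombe's quartet (Anscombe, 1973); sets I-III share the same x values
x_123 = [10, 8, 13, 9, 11, 14, 6, 4, 12, 7, 5]
quartet = {
    "I":   (x_123, [8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68]),
    "II":  (x_123, [9.14, 8.14, 8.74, 8.77, 9.26, 8.10, 6.13, 3.10, 9.13, 7.26, 4.74]),
    "III": (x_123, [7.46, 6.77, 12.74, 7.11, 7.81, 8.84, 6.08, 5.39, 8.15, 6.42, 5.73]),
    "IV":  ([8, 8, 8, 8, 8, 8, 8, 19, 8, 8, 8],
            [6.58, 5.76, 7.71, 8.84, 8.47, 7.04, 5.25, 12.50, 5.56, 7.91, 6.89]),
}

correlations = {}
for name, (x, y) in quartet.items():
    # Off-diagonal entry of the 2x2 correlation matrix
    correlations[name] = np.corrcoef(x, y)[0, 1]
    print(f"{name:>3}: r = {correlations[name]:.3f}")
```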
For the rest of this tutorial (and the next tutorial as well), we will see how to use visualization to understand the relationship between two variables. To do this, we will be using the seaborn Python library and a couple of datasets.
We will be using the standard data science libraries — NumPy, Pandas, Matplotlib and Seaborn. So let’s start by importing them.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# change pandas display options
pd.set_option('display.width', 200)
pd.set_option('display.max_columns', 20)
We’ve also asked pandas to increase the display width to 200 characters, and the maximum number of columns it should display to 20. This will make sure we can view our data properly.
We will be using the wine dataset and the tips dataset available on CommonLounge. We’ll load these datasets using the load_commonlounge_dataset() function, which returns a DataFrame.
Let’s first load the wine dataset (we’ll load a modified version for this tutorial):
wine_data = load_commonlounge_dataset('wine_v2')
Let’s also load the tips dataset:
tips_data = load_commonlounge_dataset('tips')
Below, we’ve included a brief description of these two datasets.
The wine dataset is the result of a chemical analysis of wines grown in the same region in Italy.
The following are the variables in the dataset:
Wine: This is the target variable to be predicted. It is a categorical variable divided into a set of three classes denoting three different types of wines. The classes are labelled as 1, 2 and 3.
All other attributes are continuous numerical variables:
Alcohol: alcohol content
Malic.acid: one of the principal organic acids
Ash: inorganic matter left
Acl: the alkalinity of ash
Flavanoids: particular type of phenol
Nonflavanoid.phenols: particular type of phenol
Proanth: particular type of phenol
Color.int: color intensity
Hue: hue of a wine
OD: protein content measurements
Proline: an amino acid
Let’s take a glimpse of the first few instances in the dataset using the head() method:
wine_data.head()
The data is taken from UCI Machine Learning Repository.
Citation: Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science.
The tips dataset contains information about the tips collected at a restaurant.
Here’s a description of the attributes:
tip: (continuous) Tips paid. The target variable to be predicted.
total_bill: (continuous) Total bill
sex: (categorical) Male / Female
smoker: (categorical) whether there were smokers in the party (Yes / No)
day: (categorical) day of the week
time: (categorical) Dinner / Lunch
size: (integer) size of the party
This data is taken from Seaborn’s data repository.
Again, let’s take a glimpse of the first few instances in the dataset:
tips_data.head()
We will also use a third dataset, which we will only use for the exercises in this article. For now, we will load it and keep it aside.
This is a synthetic dataset composed of 5 feature variables called
feature5. It also has two target variables. Let's load it:
qns = load_commonlounge_dataset('corr_qns')
To explore correlation between two numerical variables, we will use two different kinds of plots: Scatter plots and Hexbin plots.
The tips dataset has an output variable which is continuous and some input variables which are also continuous. Let us plot them against each other and see if there is a trend.
We’ll start by plotting the tip variable against the total_bill variable.
To do so, we will use the jointplot() function from the seaborn library. Here’s the syntax for the function:
sns.jointplot(x=xdata, y=ydata)
Let’s give it a try:
sns.jointplot(x=tips_data["total_bill"], y=tips_data["tip"])
plt.show()
As you can see, apart from plotting a scatter plot, the jointplot() function also plots the univariate distributions of x and y along the x and y axes.
- We can infer from this graph that there is indeed a positive correlation between tip and total_bill. However, the correlation appears stronger for lower values of total_bill than for higher values.
- For lower values of total_bill (roughly 0 to 20), tip rises steadily as total_bill increases.
- But as total_bill increases (roughly 30 to 50), we see the variance in tip increasing. Some tips are high for a high total_bill, while others are very low, so the correlation in this range is not very strong.
- We can also see that total_bill is more skewed than tip, which is relatively more symmetric, centered roughly around 3 dollars.
These insights would have been hard to extract without visualisation.
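The impression that the relationship is tighter for small bills can also be checked numerically, by computing the correlation separately within ranges of the x variable. corr_within_ranges below is a hypothetical helper introduced here (not part of pandas or seaborn), demonstrated on synthetic data whose spread grows with x to mimic the pattern in the plot.

```python
import numpy as np
import pandas as pd

def corr_within_ranges(df, x, y, edges):
    """Hypothetical helper: Pearson r between columns x and y inside each
    half-open interval [lo, hi) formed by consecutive values in edges."""
    results = {}
    for lo, hi in zip(edges[:-1], edges[1:]):
        part = df[(df[x] >= lo) & (df[x] < hi)]
        results[(lo, hi)] = part[x].corr(part[y])
    return results

# Synthetic data: noise grows with x, so the relationship loosens at high x
rng = np.random.default_rng(42)
x = rng.uniform(3, 50, 2000)
demo = pd.DataFrame({"x": x, "y": 1 + 0.15 * x + rng.normal(0, 0.05 * x)})

for (lo, hi), r in corr_within_ranges(demo, "x", "y", [0, 20, 30, 50]).items():
    print(f"x in [{lo}, {hi}): r = {r:.2f}")
```

With the tutorial's data, the call would look like corr_within_ranges(tips_data, "total_bill", "tip", [0, 20, 30, 60]). Keep in mind that restricting the range of x tends to lower r by itself, so it is fairest to compare ranges of similar width.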
As we saw in the graph above, the data points in a scatter plot may overlap each other. This makes it harder to gauge the density of points in certain regions of the plot.
The hexbin plot helps overcome this problem.
Let us use the same variables to plot a hexbin plot. We will use the same jointplot() function, but this time we will pass the value "hex" to the kind parameter. Here’s the syntax:
sns.jointplot(x=xdata, y=ydata, kind="hex")
Note: The default value of kind is "scatter". That is why it plots a scatter plot when no argument is passed to this parameter.
sns.jointplot(x=tips_data["total_bill"], y=tips_data["tip"], kind="hex")
plt.show()
Hexbin plots are the two-variable equivalent of histograms. They divide the plane into hexagonal bins, each covering an interval of both x and y, and count the data points that fall into each bin. The darker the colour of a hexagon, the more data points it contains.
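seaborn's hex jointplot relies on Matplotlib's hexbin for its central panel, so the counting can be inspected with Matplotlib directly. In the sketch below (made-up data), the object returned by ax.hexbin stores one count per hexagon, and those counts are what the colour scale encodes.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; not needed in a notebook
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(20, 8, 500)
y = 0.15 * x + rng.normal(0, 1, 500)

fig, ax = plt.subplots()
# gridsize sets the number of hexagons along the x-axis
hb = ax.hexbin(x, y, gridsize=15, cmap="Blues")
fig.colorbar(hb, ax=ax, label="points per hexagon")

counts = hb.get_array()  # one count per hexagon
print("hexagons:", counts.size)
print("most crowded hexagon holds", int(counts.max()), "points")
print("total points counted:", int(counts.sum()))
plt.close(fig)
```

Fewer, larger hexagons (a smaller gridsize) give a smoother but coarser picture of the density; more hexagons show finer detail but noisier counts.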
- As we can again see, there is a strong cluster of points at the lower values of total_bill, while the data is much more spread out at the higher values.
- This means that the tip is very likely to be $2 to $4 if the total_bill is between $10 and $20. But if the total_bill is between $40 and $50, the tip could be anywhere between $3 and $10, with all values being more or less equally likely.
So far for plotting, we have been using the following syntax:
sns.jointplot(x=xdata, y=ydata)
Since we have a DataFrame with all the data, this usually translates to:
sns.jointplot(x=df["x column name"], y=df["y column name"])
When the x and y data are stored in the same DataFrame, seaborn also supports another syntax, which is:
sns.jointplot(x="x column name", y="y column name", data=df)
These two syntaxes are equivalent, and all the seaborn functions discussed in this tutorial support both.
To explore correlation between a categorical and a numerical variable, we will use strip plots, swarm plots, box plots and violin plots. All of these plots will be drawn using functions from the seaborn library.
For this section, we will use both datasets:
- In the wine dataset, the output variable Wine is categorical, while all the input features are numerical.
- In the tips dataset, the output variable tip is numerical, while several input variables (sex, smoker, day and time) are categorical.
Let us explore these cases visually and see if we can draw some inferences.
Strip plots are similar to scatter plots, but for the situation when one of the variables is categorical.
We will be using the stripplot() function. The syntax is:
sns.stripplot(x="column name", y="column name", data=DataFrame)
Let us plot the output variable tip against the input variable time:
sns.stripplot(x="time", y="tip", data=tips_data)
plt.show()
Although we can see the data points, they are not very clear because so many of them overlap. This makes it difficult to gauge the density of the data points.
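Solution: We can fix this by using the parameter called jitter, which spreads the points out randomly along the categorical axis.
jitter can take a float value specifying the amount of jitter, or a boolean value, where True (or 1) applies a default amount of jitter and False (or 0) means no jitter.
The default value of jitter depends on the seaborn version: recent releases default to True, while older releases did not jitter by default, so the plot above may look different depending on your version.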
Let’s re-draw the plot with the added parameter:
sns.stripplot(x="time", y="tip", data=tips_data, jitter=0.25)
plt.show()
The plot looks much better now.
- Although the difference isn’t drastic, it looks like dinner tips on average are higher than lunch tips.
- We can also see that there are more data points for dinner.
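Under the hood, jitter simply adds a small random offset along the categorical axis, so that points with identical values stop landing on top of each other. The sketch below reproduces the idea by hand with NumPy and Matplotlib on made-up lunch/dinner tips. jittered_positions is a hypothetical helper, amount plays the role of the float passed to jitter, and seaborn's own implementation may differ in its details.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; not needed in a notebook
import matplotlib.pyplot as plt
import numpy as np

def jittered_positions(codes, amount, rng):
    """Hypothetical helper: shift each point away from its category
    position (0, 1, 2, ...) by a uniform offset in [-amount, amount)."""
    codes = np.asarray(codes, dtype=float)
    return codes + rng.uniform(-amount, amount, size=codes.size)

rng = np.random.default_rng(0)
# Made-up data: category 0 = "Lunch", 1 = "Dinner"; tips rounded to 50 cents
codes = np.repeat([0, 1], [60, 160])
tips = np.round(rng.normal(3, 1, codes.size) * 2) / 2

positions = jittered_positions(codes, 0.25, rng)

fig, ax = plt.subplots()
ax.scatter(positions, tips, s=10, alpha=0.6)
ax.set_xticks([0, 1])
ax.set_xticklabels(["Lunch", "Dinner"])
ax.set_ylabel("tip")
plt.close(fig)
```

Because the offsets stay within plus or minus amount, every point remains closer to its own category than to the neighbouring one.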
Let us now plot the output variable Wine against the input variable Alcohol:
sns.stripplot(x="Alcohol", y="Wine", data=wine_data, jitter=0.25)
plt.show()
We can see a big problem in the graph.
Since we plotted the categorical variable on the y-axis and our category labels are integers, the stripplot() function is not able to infer which variable is actually categorical, or what the orientation of the plot should be.
Solution: To fix this, we can use the orient parameter to explicitly tell the function to plot the graph horizontally. Valid values for orient are "v" (orient vertically) and "h" (orient horizontally).
Let us re-draw the plot with the added parameter:
sns.stripplot(x="Alcohol", y="Wine", data=wine_data, orient="h", jitter=0.25)
plt.show()
That’s much better!
We can very clearly see that there is a relationship between Alcohol levels and the category of Wine:
- Wines of category 2 have much lower Alcohol levels (average around 12%).
- Wines of category 1 have the highest Alcohol levels (average around 13.5% to 14%).
- And wines of category 3 are somewhere in the middle (average Alcohol level around 13% to 13.5%).
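An alternative to orient is to store the category labels as strings, so that seaborn can tell from the data types alone which variable is categorical and lay the plot out horizontally. The sketch below assumes this type-based inference of recent seaborn versions and uses a made-up stand-in for wine_data with arbitrary Alcohol values.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; not needed in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Made-up stand-in for wine_data: integer class labels 1, 2, 3
rng = np.random.default_rng(0)
demo = pd.DataFrame({
    "Wine": np.repeat([1, 2, 3], 30),
    "Alcohol": np.concatenate([rng.normal(13.7, 0.4, 30),
                               rng.normal(12.3, 0.5, 30),
                               rng.normal(13.1, 0.5, 30)]),
})

# Strings instead of integers: seaborn now treats Wine as categorical
demo["Wine"] = demo["Wine"].astype(str)

# No orient argument: the layout is inferred from the column types
ax = sns.stripplot(x="Alcohol", y="Wine", data=demo, jitter=0.25)
plt.close(ax.figure)
```

With the tutorial's data, a converted copy such as wine_data.assign(Wine=wine_data["Wine"].astype(str)) could be passed instead, leaving the original DataFrame untouched.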
One of the major problems with strip plots was the overlap of data points, which made it difficult to understand the density of points in some places.
Swarm plots are similar to strip plots, but the points are adjusted so that they don’t overlap.
We will be using the swarmplot() function. The syntax is:
sns.swarmplot(x="column name", y="column name", data=DataFrame)
Let us plot the variables day and tip:
sns.swarmplot(x="day", y="tip", data=tips_data)
plt.show()
As we can see, the plot is much cleaner than the strip plots. The width of each swarm gives us a good sense of where most of the data points lie.
We can draw the following inferences:
- The majority of the data seems to come from the weekend.
- On average, tips on weekends seem to be higher than tips on weekdays.
- We can see a lot of points on the same horizontal line. That’s probably because people like paying tips in round values such as $2, $3, $4 or $2.50, $3.50 and so forth.
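The round-number explanation is easy to test numerically. share_of_round_values below is a hypothetical helper that reports what fraction of values are whole multiples of a given step; the example Series is made up, and with the tutorial's data you could call share_of_round_values(tips_data["tip"]).

```python
import numpy as np
import pandas as pd

def share_of_round_values(values, step=0.5):
    """Hypothetical helper: fraction of values that are whole multiples
    of `step`, tolerant to floating-point error on either side."""
    remainder = np.mod(np.asarray(values, dtype=float), step)
    # A remainder very close to 0 or very close to `step` both count as round
    is_round = np.isclose(remainder, 0) | np.isclose(remainder, step)
    return is_round.mean()

example = pd.Series([2.00, 3.50, 1.01, 4.00, 2.23])
print(share_of_round_values(example))          # 0.6 (3 of 5 values)
print(share_of_round_values(example, step=1))  # 0.4 (2 of 5 values)
```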
You should have already come across boxplots before in a separate tutorial. But in case you skipped it, we’ve provided a short recap.
Here’s a sample boxplot for a toy data set having just 10 observations – 144, 147, 153, 154, 156, 157, 161, 164, 170, 181.
Box plot for the toy dataset
A box plot has the following components:
- 1st Quartile (25th percentile)
- Median (2nd Quartile or 50th percentile)
- 3rd Quartile (75th percentile)
- Interquartile Range (IQR – difference between 3rd and 1st Quartile)
- Whiskers: mark the lowest data point which lies within 1.5 IQR of the 1st quartile, and the highest data point which lies within 1.5 IQR of the 3rd quartile
- Outliers: any data points beyond 1.5 IQR of the 1st or 3rd quartile, i.e. values which are greater than 3rd Quartile + 1.5 * IQR or less than 1st Quartile - 1.5 * IQR.
In other words, whiskers extend from the quartiles to the rest of the distribution, except for points that are determined to be “outliers”.
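The components above can be computed step by step for the toy dataset. This sketch uses NumPy's default linear-interpolation percentiles; other quartile conventions can give slightly different quartiles, so the numbers may not match every textbook or tool exactly.

```python
import numpy as np

data = np.array([144, 147, 153, 154, 156, 157, 161, 164, 170, 181])

# Quartiles via NumPy's default (linear interpolation) percentile method
q1, median, q3 = np.percentile(data, [25, 50, 75])
iqr = q3 - q1

# Fences sit 1.5 * IQR beyond the 1st and 3rd quartiles
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr

# Whiskers end at the most extreme data points still inside the fences
inside = data[(data >= lower_fence) & (data <= upper_fence)]
whiskers = (int(inside.min()), int(inside.max()))
outliers = data[(data < lower_fence) | (data > upper_fence)]

print(f"Q1={q1}, median={median}, Q3={q3}, IQR={iqr}")  # Q1=153.25, median=156.5, Q3=163.25, IQR=10.0
print("whiskers:", whiskers)                            # whiskers: (144, 170)
print("outliers:", outliers)                            # outliers: [181]
```

So in this toy dataset, 181 is the only point drawn as an outlier, and the upper whisker stops at 170.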
The syntax for plotting a boxplot is:
sns.boxplot(x="column name", y="column name", data=DataFrame, orient="orientation")
The orient parameter is used for exactly the same reason as we saw in the Strip plot section, that is, to explicitly provide information about which variable is categorical.
Let us plot a graph between Proline and Wine:
sns.boxplot(x="Proline", y="Wine", data=wine_data, orient="h")
plt.show()
As we can see, the plot gives us a good view of the relation between proline and wine category, without actually plotting the individual data points.
- For category 1 Wine, the median Proline level is higher than any Proline value from category 2 and 3 Wine. Thus, high values of Proline strongly suggest that the wine is from category 1.
- For lower Proline levels, the overlap between categories 2 and 3 is quite high, so we can’t conclude much about them from this plot.
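The claim about category 1 can be checked numerically with a pandas groupby. median_vs_max is a hypothetical helper, and the small DataFrame is made up so the snippet runs on its own; with the tutorial's data the call would be median_vs_max(wine_data, "Proline", "Wine").

```python
import pandas as pd

def median_vs_max(df, value, category):
    """Hypothetical helper: per-category median and maximum of `value`."""
    return df.groupby(category)[value].agg(["median", "max"])

# Made-up stand-in for wine_data (values are illustrative only)
demo = pd.DataFrame({
    "Wine":    [1, 1, 1, 2, 2, 2, 3, 3, 3],
    "Proline": [1100, 1250, 1400, 400, 520, 900, 600, 650, 800],
})

stats = median_vs_max(demo, "Proline", "Wine")
print(stats)

# Is category 1's median above every value seen in the other categories?
others_max = stats.loc[stats.index != 1, "max"].max()
print(stats.loc[1, "median"] > others_max)
```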
Violin plots are similar to boxplots, but they have the added advantage of also showing the underlying distribution of the data.
Here’s what a sample violin plot looks like:
Note: Violin plot displays an estimate of the distribution of data using a method called kernel density estimation. This estimation procedure is influenced by the sample size, and violins for relatively small samples might look misleadingly smooth.
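To make kernel density estimation concrete, here is a minimal NumPy sketch: the estimate is an average of normal "bumps" centred on each observation. It uses Scott's rule for the bandwidth (seaborn's violinplot also defaults to a bandwidth method named "scott", though its exact implementation may differ). Even the 8-point sample yields a smooth curve, which is why small-sample violins can look deceptively well-behaved.

```python
import numpy as np

def gaussian_kde_1d(samples, grid):
    """Minimal Gaussian KDE: average of normal bumps centred on each sample,
    with bandwidth from Scott's rule (sample std times n ** -0.2)."""
    samples = np.asarray(samples, dtype=float)
    n = samples.size
    bandwidth = samples.std(ddof=1) * n ** (-1 / 5)
    # One row per grid point, one column per sample
    z = (grid[:, None] - samples[None, :]) / bandwidth
    bumps = np.exp(-0.5 * z ** 2) / (bandwidth * np.sqrt(2 * np.pi))
    return bumps.mean(axis=1)

rng = np.random.default_rng(0)
grid = np.linspace(-6, 6, 1201)
small = rng.normal(0, 1, 8)    # tiny sample
large = rng.normal(0, 1, 800)  # larger sample from the same distribution

for label, sample in [("n=8", small), ("n=800", large)]:
    density = gaussian_kde_1d(sample, grid)
    area = density.sum() * (grid[1] - grid[0])  # should be close to 1
    print(f"{label}: area under curve = {area:.3f}")
```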
The syntax is as follows:
sns.violinplot(x="column name", y="column name", data=DataFrame, orient="orientation")
Let us plot the Proline variable against Wine again, but this time as a violin plot:
sns.violinplot(x="Proline", y="Wine", data=wine_data, orient="h")
plt.show()
- From this plot, we can draw all the conclusions we made earlier from the boxplot.
- In addition, we can see the Proline levels for category 2 vs category 3 Wines much more clearly: although there is an overlap, the modes are slightly separated.
The orient parameter can be used for all the plots mentioned in this section: strip plots, swarm plots, box plots and violin plots.
- Correlation coefficient (ranges from -1 to +1) tells us about the relationship between two variables.
- Its magnitude signifies how strongly related the variables are.
- Its sign signifies the direction of relationship (positive vs negative).
- Scatterplots and hexbin plots visualize two numerical variables.
- Stripplots and swarmplots visualize a numerical variable with a categorical variable.
- Boxplots and violinplots also visualize a numerical variable with a categorical variable, but they don’t plot every data point, and instead display quartiles, density, etc.
- The orient parameter is used to explicitly specify which variable is categorical.
Two numerical variables:
sns.jointplot(x=xdata, y=ydata, kind="hex")
One numerical variable and one categorical variable:
sns.stripplot(x=xdata, y=ydata, jitter=float, orient="h" or "v")
sns.swarmplot(x=xdata, y=ydata, orient="h")
sns.boxplot(x=xdata, y=ydata, orient="h")
sns.violinplot(x=xdata, y=ydata, orient="h")