Correlation Heatmaps: A Comprehensive Guide
Introduction
A correlation heatmap is a powerful visualization tool used to analyze and understand relationships between numerical variables in a dataset. It represents correlation coefficients using color gradients, making it easy to detect strong or weak relationships.
Why Use a Correlation Heatmap?
✔ Visualizes relationships between numerical features.
✔ Identifies multicollinearity (highly correlated features).
✔ Helps in feature selection for machine learning models.
✔ Detects hidden patterns in datasets.
I. Understanding Correlation
What is Correlation?
Correlation measures the strength and direction of the relationship between two variables. It is represented by the correlation coefficient (r), which ranges from -1 to +1:
| Correlation Value (r) | Relationship Type | Interpretation |
|---|---|---|
| +1.0 | Perfect Positive Correlation | As one variable increases, the other increases proportionally. |
| +0.7 to +0.9 | Strong Positive Correlation | Strong upward trend. |
| +0.4 to +0.6 | Moderate Positive Correlation | Moderate upward trend. |
| +0.1 to +0.3 | Weak Positive Correlation | Slight upward trend. |
| 0.0 | No Correlation | No linear relationship between variables. |
| -0.1 to -0.3 | Weak Negative Correlation | Slight downward trend. |
| -0.4 to -0.6 | Moderate Negative Correlation | Moderate downward trend. |
| -0.7 to -0.9 | Strong Negative Correlation | Strong downward trend. |
| -1.0 | Perfect Negative Correlation | As one variable increases, the other decreases proportionally. |
📌 Important: Correlation does not imply causation!
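To make the number concrete before building a full heatmap, here is a minimal sketch (using two made-up arrays, not the Titanic data used later) that computes r for a single pair of variables:
import numpy as np
x = np.array([1, 2, 3, 4, 5])
y = np.array([2, 4, 5, 4, 5])
r = np.corrcoef(x, y)[0, 1]  # Pearson correlation coefficient between x and y
print(round(r, 2))  # a value close to +1 indicates a strong positive relationship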
Step 1: Import Required Libraries
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
Step 2: Load the Dataset
We’ll use the Titanic dataset for demonstration.
df = pd.read_csv("https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv")
df.head()
Step 3: Selecting Numerical Columns
Since correlation applies only to numerical data, we select relevant columns.
df_numeric = df.select_dtypes(include=[np.number]) # Selecting only numerical columns
df_numeric.head()
Step 4: Compute the Correlation Matrix
correlation_matrix = df_numeric.corr()
print(correlation_matrix)
✅ What does this do?
- Computes Pearson correlation coefficients (default).
- Displays correlation values between numerical features.
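Pearson is pandas' default, but .corr() also accepts a method argument if your data calls for a rank-based measure. This is an optional aside rather than part of the main workflow:
spearman_matrix = df_numeric.corr(method="spearman")  # rank-based, robust to monotonic non-linear trends
kendall_matrix = df_numeric.corr(method="kendall")    # Kendall's tau, another rank-based option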
Step 5: Creating a Basic Correlation Heatmap
plt.figure(figsize=(10,6))
sns.heatmap(correlation_matrix, annot=True, cmap="coolwarm", fmt=".2f", linewidths=0.5)
plt.title("Correlation Heatmap")
plt.show()
✅ Interpretation:
- Red regions → Strong positive correlation.
- Blue regions → Strong negative correlation.
- White/Gray regions → Weak or no correlation.
- Diagonal values = 1.0, since each variable is perfectly correlated with itself.
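If you prefer reading the same information numerically, you can sort one column of the matrix. The snippet below uses the Survived column from the Titanic numerical subset as an example; any other column works the same way:
print(correlation_matrix["Survived"].sort_values(ascending=False))  # correlations with Survived, strongest positive first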
Step 6: Customizing the Heatmap for Better Readability
plt.figure(figsize=(12,8))
sns.heatmap(correlation_matrix, annot=True, cmap="coolwarm", fmt=".2f", linewidths=1, vmin=-1, vmax=1, square=True)
plt.title("Enhanced Correlation Heatmap")
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.show()
✅ Enhancements:
- fmt=".2f" → Displays values with 2 decimal places.
- linewidths=1 → Increases grid line visibility.
- vmin=-1, vmax=1 → Keeps color scale balanced.
- square=True → Ensures a neat square shape.
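Because the matrix is symmetric, an optional further tweak (not required by the steps above) is to mask the upper triangle so each feature pair appears only once:
mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))  # True on and above the diagonal
plt.figure(figsize=(12, 8))
sns.heatmap(correlation_matrix, mask=mask, annot=True, cmap="coolwarm", fmt=".2f",
            vmin=-1, vmax=1, square=True)
plt.title("Correlation Heatmap (Lower Triangle Only)")
plt.show()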
Step 7: Handling High Correlation (Feature Selection)
If two features are highly correlated (above 0.9 or below -0.9), one of them might be redundant.
high_corr_features = np.where(abs(correlation_matrix) > 0.9)  # row/column indices of strong correlations
high_corr_features = [(correlation_matrix.index[x], correlation_matrix.columns[y])
                      for x, y in zip(*high_corr_features) if x != y]  # skip the diagonal (self-correlations)
print("Highly Correlated Feature Pairs:", high_corr_features)
✅ What does this do?
- Detects feature pairs with strong correlation (>0.9 or <-0.9).
- Helps decide which feature to remove in machine learning.
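Note that because the correlation matrix is symmetric, each pair appears twice in the list above, once as (A, B) and once as (B, A). A small follow-up sketch keeps each pair only once:
unique_pairs = {tuple(sorted(pair)) for pair in high_corr_features}  # (A, B) and (B, A) collapse to one entry
print("Unique highly correlated pairs:", unique_pairs)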
Step 8: Filtering Out Redundant Features
We can drop one feature from each highly correlated pair to avoid multicollinearity.
correlation_threshold = 0.9 # Set a threshold
correlated_features = set()
for i in range(len(correlation_matrix.columns)):
    for j in range(i):  # only the lower triangle, so each pair is checked once
        if abs(correlation_matrix.iloc[i, j]) > correlation_threshold:
            colname = correlation_matrix.columns[i]  # mark the later of the two columns for removal
            correlated_features.add(colname)
df_filtered = df_numeric.drop(columns=correlated_features)
print("Reduced Dataset Features:", df_filtered.columns)
✅ Benefits:
- Prevents overfitting in machine learning models.
- Reduces data dimensionality for efficiency.
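As an optional sanity check, you can recompute the matrix on the reduced dataset and confirm that no off-diagonal value still exceeds the threshold:
filtered_corr = df_filtered.corr()
plt.figure(figsize=(10, 6))
sns.heatmap(filtered_corr, annot=True, cmap="coolwarm", fmt=".2f", vmin=-1, vmax=1)
plt.title("Correlation Heatmap After Removing Redundant Features")
plt.show()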
Key Takeaways
✔ Correlation heatmaps help visualize relationships between numerical variables.
✔ Red & Blue regions highlight strong correlations (positive & negative).
✔ Feature selection is essential to remove redundant variables.
✔ High correlation can lead to multicollinearity issues in ML models.