Handling Imbalanced Data in Machine Learning
Introduction
Imbalanced data occurs when the distribution of classes in a dataset is highly skewed, meaning one class has significantly more samples than the other(s). This issue is particularly common in real-world scenarios such as fraud detection, medical diagnosis, and anomaly detection. If not handled properly, machine learning models tend to be biased towards the majority class, leading to poor performance in predicting the minority class.
In this guide, we will explore:
- The effects of imbalanced data
- Different techniques to handle imbalanced data
- Step-by-step implementation in Python
Step 1: Understanding Imbalanced Data
What is Imbalanced Data?
A dataset is considered imbalanced when one class significantly outweighs the other(s).
Example:
Consider a fraud detection dataset where:
- 95% of transactions are legitimate (majority class)
- 5% of transactions are fraudulent (minority class)
If we train a model without handling this imbalance, it may predict all transactions as legitimate and still achieve 95% accuracy, but fail to detect fraudulent cases effectively.
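The arithmetic behind this example can be checked with a minimal sketch. The labels below are hypothetical toy data matching the 95/5 split, and the "model" is simply a rule that always predicts the majority class.

```python
# Hypothetical toy labels: 95 legitimate (0) and 5 fraudulent (1) transactions
y_true = [0] * 95 + [1] * 5

# A "model" that always predicts the majority class
y_pred = [0] * len(y_true)

# Accuracy: share of all predictions that are correct
correct = sum(t == p for t, p in zip(y_true, y_pred))
accuracy = correct / len(y_true)

# Recall on the fraud class: detected frauds / all actual frauds
true_positives = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
fraud_recall = true_positives / sum(y_true)

print(f"Accuracy: {accuracy:.2f}")          # 0.95
print(f"Fraud recall: {fraud_recall:.2f}")  # 0.00
```

The accuracy looks strong, yet not a single fraudulent transaction is caught, which is why recall on the minority class is the number to watch.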
Effects of Imbalanced Data
- Biased Model: The model favors the majority class and ignores the minority class.
- Poor Recall for Minority Class: The model fails to detect minority class instances.
- Misleading Accuracy: High accuracy may be achieved simply by predicting the majority class.
Step 2: Techniques to Handle Imbalanced Data
There are several techniques to address the problem of imbalanced data.
A. Resampling Techniques
Resampling methods aim to balance the dataset by modifying the sample distribution.
1. Oversampling (Increasing Minority Class Samples)
Oversampling involves increasing the number of samples in the minority class by either duplicating existing samples or generating synthetic samples.
a) Random Oversampling
Randomly duplicate instances from the minority class.
✅ Simple and effective for small datasets.
❌ May lead to overfitting.
Implementation in Python:
from imblearn.over_sampling import RandomOverSampler
from collections import Counter
# Sample imbalanced dataset
X = [[i] for i in range(20)]
y = [0] * 14 + [1] * 6  # Imbalanced (14:6); SMOTE's default k_neighbors=5 needs at least 6 minority samples
# Apply random oversampling
ros = RandomOverSampler(random_state=42)
X_resampled, y_resampled = ros.fit_resample(X, y)
print("Class distribution after oversampling:", Counter(y_resampled))
b) SMOTE (Synthetic Minority Over-sampling Technique)
SMOTE generates synthetic samples for the minority class by interpolating between a minority sample and one of its k nearest minority-class neighbors.
✅ Reduces overfitting risk compared to random oversampling.
❌ Can create noise if the minority class is too sparse.
Implementation in Python:
from imblearn.over_sampling import SMOTE
# Apply SMOTE
smote = SMOTE(random_state=42)
X_smote, y_smote = smote.fit_resample(X, y)
print("Class distribution after SMOTE:", Counter(y_smote))
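SMOTE's core step is interpolation: a new point is placed on the line segment between a minority sample and one of its nearest minority neighbors. The sketch below illustrates only that step using the standard library. The function name `synthetic_minority_point` and the toy points are hypothetical, and this is not the imbalanced-learn implementation.

```python
import math
import random

def synthetic_minority_point(minority, k=2, rng=None):
    """Create one synthetic point between a minority sample and one of
    its k nearest minority neighbours (illustrative sketch only)."""
    rng = rng or random.Random(0)
    i = rng.randrange(len(minority))
    base = minority[i]
    # k nearest neighbours of the base point among the other minority points
    neighbours = sorted(
        (p for j, p in enumerate(minority) if j != i),
        key=lambda p: math.dist(base, p),
    )[:k]
    neighbour = rng.choice(neighbours)
    gap = rng.random()  # position along the segment, in [0, 1)
    return [b + gap * (n - b) for b, n in zip(base, neighbour)]

# Hypothetical 2-D minority samples
minority_points = [[1.0, 1.0], [1.5, 2.0], [2.0, 1.0], [8.0, 8.0]]
new_point = synthetic_minority_point(minority_points, k=2, rng=random.Random(42))
print(new_point)
```

Because the new point always lies between two existing minority points, an isolated sample such as `[8.0, 8.0]` can pull synthetic points into regions that may belong to the majority class, which is the noise risk noted above.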
2. Undersampling (Reducing Majority Class Samples)
Undersampling reduces the number of samples in the majority class to balance the dataset.
a) Random Undersampling
Randomly removes instances from the majority class.
✅ Helps reduce dataset size, making training faster.
❌ Risk of losing important information.
Implementation in Python:
from imblearn.under_sampling import RandomUnderSampler
# Apply random undersampling
rus = RandomUnderSampler(random_state=42)
X_under, y_under = rus.fit_resample(X, y)
print("Class distribution after undersampling:", Counter(y_under))
b) NearMiss Algorithm
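NearMiss keeps the majority class samples that lie closest to minority class samples. Version 1, used below, ranks majority samples by their average distance to their nearest minority neighbors.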
✅ Retains informative samples instead of random removal.
❌ Computationally expensive.
from imblearn.under_sampling import NearMiss
# Apply NearMiss
nm = NearMiss(version=1)
X_nm, y_nm = nm.fit_resample(X, y)
print("Class distribution after NearMiss:", Counter(y_nm))
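The selection rule behind NearMiss version 1 can be sketched in a few lines: keep the majority points whose mean distance to their k nearest minority points is smallest. The helper `nearmiss1_select` and the 1-D toy data are hypothetical and serve only to illustrate the idea, not to reproduce the library's code.

```python
import math

def nearmiss1_select(majority, minority, n_keep, k=3):
    """Keep the n_keep majority points with the smallest mean distance
    to their k nearest minority points (the idea behind NearMiss-1)."""
    def mean_dist_to_nearest_minority(point):
        distances = sorted(math.dist(point, m) for m in minority)
        nearest = distances[:k]
        return sum(nearest) / len(nearest)

    return sorted(majority, key=mean_dist_to_nearest_minority)[:n_keep]

# Hypothetical 1-D toy data: majority on the left, minority on the right
majority_points = [[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]]
minority_points = [[6.0], [7.0], [8.0]]

kept = nearmiss1_select(majority_points, minority_points, n_keep=3)
print(kept)  # [[5.0], [4.0], [3.0]]
```

The retained points are the ones nearest the class boundary, which is where the classifier most needs majority examples.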
B. Algorithmic Approaches
Instead of modifying the data, we can modify algorithms to handle class imbalance.
1. Changing Class Weights (Cost-Sensitive Learning)
Assign higher penalties to the misclassification of the minority class by adjusting class weights in models like Logistic Regression, SVM, or Random Forest.
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
# Split into training and test sets; stratify keeps the class ratio in both
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
# Train model with class weights
model = RandomForestClassifier(class_weight='balanced', random_state=42)
model.fit(X_train, y_train)
# Predictions
y_pred = model.predict(X_test)
# Evaluation
print(classification_report(y_test, y_pred))
✅ Prevents bias towards the majority class.
❌ May not work well if class imbalance is extreme.
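To make the weighting concrete: scikit-learn documents the `'balanced'` mode as `n_samples / (n_classes * np.bincount(y))`. The sketch below reproduces that formula with the standard library; the helper name `balanced_class_weights` and the 95/5 labels are hypothetical.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Weight for class c = n_samples / (n_classes * count of c)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {c: n_samples / (n_classes * n) for c, n in counts.items()}

# Hypothetical labels: 95 majority (0) and 5 minority (1) samples
weights = balanced_class_weights([0] * 95 + [1] * 5)
print(weights)  # class 1 gets weight 10.0, class 0 roughly 0.53
```

With these weights, misclassifying one minority sample costs about as much as misclassifying 19 majority samples, which matches the 95:5 ratio.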
2. Anomaly Detection Models
For extreme imbalances (e.g., fraud detection), treat the minority class as an “anomaly” using models like Isolation Forest or One-Class SVM.
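A minimal sketch of this approach with scikit-learn's `IsolationForest` might look as follows. The synthetic data, the 5% `contamination` value, and all variable names are assumptions chosen for illustration. In practice the contamination rate would come from domain knowledge.

```python
import numpy as np
from sklearn.ensemble import IsolationForest

rng = np.random.default_rng(42)

# Hypothetical data: 190 "normal" points near 0 and 10 "anomalies" near 6
normal = rng.normal(loc=0.0, scale=1.0, size=(190, 2))
anomalies = rng.normal(loc=6.0, scale=0.5, size=(10, 2))
X_all = np.vstack([normal, anomalies])

# contamination is the assumed share of anomalies in the data (here 5%)
iso = IsolationForest(contamination=0.05, random_state=42)
labels = iso.fit_predict(X_all)  # 1 = inlier, -1 = flagged as anomaly

n_flagged = int((labels == -1).sum())
print("Flagged as anomalies:", n_flagged, "of", len(X_all))
```

Note that the model never sees class labels: it learns what typical points look like and flags those that are easy to isolate.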
C. Evaluation Metrics for Imbalanced Data
Accuracy is misleading in imbalanced datasets. Instead, use:
- Precision – How many predicted positive cases were actually positive? $\text{Precision} = \frac{TP}{TP + FP}$
- Recall (Sensitivity) – How many actual positive cases were correctly predicted? $\text{Recall} = \frac{TP}{TP + FN}$
- F1-Score – Harmonic mean of Precision & Recall. $F1 = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}}$
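- ROC-AUC Score – Measures how well the model ranks positive cases above negative ones across all classification thresholds (0.5 is random guessing, 1.0 is perfect).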
- Confusion Matrix – Helps visualize true positives, false positives, false negatives, and true negatives.
Implementation in Python:
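from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
# Evaluate the model
print(classification_report(y_test, y_pred))
# ROC-AUC is more informative with scores than hard labels: use the predicted probability of class 1
y_score = model.predict_proba(X_test)[:, 1]
print("ROC-AUC Score:", roc_auc_score(y_test, y_score))
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))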
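To see how the formulas for precision, recall, and F1 behave on concrete numbers, here is a small worked example. The confusion-matrix counts are hypothetical and chosen only to make the arithmetic easy to follow.

```python
# Hypothetical confusion-matrix counts for the positive (minority) class
tp, fp, fn, tn = 40, 10, 20, 930

precision = tp / (tp + fp)   # 40 / 50
recall = tp / (tp + fn)      # 40 / 60
f1 = 2 * precision * recall / (precision + recall)
accuracy = (tp + tn) / (tp + fp + fn + tn)

print(f"Precision: {precision:.2f}")  # 0.80
print(f"Recall:    {recall:.2f}")     # 0.67
print(f"F1-score:  {f1:.2f}")         # 0.73
print(f"Accuracy:  {accuracy:.2f}")   # 0.97
```

Accuracy is 0.97 even though a third of the positive cases are missed, so precision, recall, and F1 give a more honest picture here.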
Step 3: Choosing the Right Method
Scenario | Recommended Technique |
---|---|
Small dataset | Random Oversampling, SMOTE |
Large dataset | Random Undersampling |
Extreme imbalance (e.g., 99:1) | Cost-sensitive learning, Anomaly detection |
Interpretability is important | Class weighting in models |
High computational power available | SMOTE or NearMiss |