Back to Blog
6 min read

Data Drift Detection for ML Models

Data drift occurs when the statistical properties of production data differ from those of the training data. Detecting drift early lets you investigate, retrain, or roll back before model performance silently degrades.

Types of Data Drift

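  • Covariate Shift: The input feature distribution P(X) changes, while the relationship P(y|X) stays the same
  • Prior Probability Shift: The target distribution P(y) changes (e.g. the class balance shifts)
  • Concept Drift: The relationship between features and target, P(y|X), changes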

Statistical Methods for Drift Detection

import numpy as np
import pandas as pd
from scipy import stats
from typing import Dict, Tuple
import warnings

class StatisticalDriftDetector:
    """Detect data drift between a reference dataset and production data.

    Numerical columns are compared with the two-sample Kolmogorov-Smirnov
    test, the Population Stability Index (PSI) and the Jensen-Shannon
    divergence; categorical columns with a chi-squared goodness-of-fit test
    against the reference category frequencies.

    Note: with large samples the KS and chi-squared tests flag even tiny,
    practically irrelevant shifts, and testing many features at the same
    ``p_threshold`` inflates false alarms; treat p-values as a screening
    signal rather than a verdict.
    """

    def __init__(self, reference_data: pd.DataFrame):
        """Store the reference (e.g. training) data and profile its columns.

        Args:
            reference_data: Baseline data that production data is compared to.
        """
        self.reference_data = reference_data
        self.feature_stats = self._compute_stats(reference_data)

    def _compute_stats(self, df: pd.DataFrame) -> Dict:
        """Compute per-column summary statistics and classify each column.

        Columns of any numeric dtype (int32, float32, nullable Int64, ...)
        other than booleans are 'numerical'; everything else is 'categorical'.

        Returns:
            Mapping of column name -> dict with a 'type' key plus either
            mean/std/min/max/quartiles (numerical) or normalised value
            counts (categorical).
        """
        stats_dict = {}
        for col in df.columns:
            series = df[col]
            # Match every numeric width, not only int64/float64; booleans stay
            # categorical.
            is_numerical = (
                pd.api.types.is_numeric_dtype(series)
                and not pd.api.types.is_bool_dtype(series)
            )
            if is_numerical:
                stats_dict[col] = {
                    'type': 'numerical',
                    'mean': series.mean(),
                    'std': series.std(),
                    'min': series.min(),
                    'max': series.max(),
                    'percentiles': series.quantile([0.25, 0.5, 0.75]).tolist(),
                }
            else:
                stats_dict[col] = {
                    'type': 'categorical',
                    'value_counts': series.value_counts(normalize=True).to_dict(),
                }
        return stats_dict

    def kolmogorov_smirnov_test(self, feature: str, production_data: pd.Series) -> Tuple[float, float]:
        """Two-sample KS test between reference and production values of ``feature``.

        NaNs are dropped from both samples.

        Returns:
            ``(statistic, p_value)``, or ``(nan, nan)`` when either sample is
            empty after dropping NaNs (scipy cannot test an empty sample).
        """
        ref_data = self.reference_data[feature].dropna()
        prod_data = production_data.dropna()
        if ref_data.empty or prod_data.empty:
            return float('nan'), float('nan')
        statistic, p_value = stats.ks_2samp(ref_data, prod_data)
        return statistic, p_value

    def chi_squared_test(self, feature: str, production_data: pd.Series) -> Tuple[float, float]:
        """Chi-squared goodness-of-fit test of production categories vs. the reference.

        Expected counts are the reference category proportions scaled to the
        production sample size. NaN values are ignored (``value_counts``
        drops them).

        Returns:
            ``(statistic, p_value)``, or ``(nan, nan)`` when either side has
            no non-null values.
        """
        ref_counts = self.reference_data[feature].value_counts()
        prod_counts = production_data.value_counts()

        # Union of categories in a deterministic order (sorted by string form
        # so mixed-type labels can still be ordered).
        all_cats = sorted(set(ref_counts.index) | set(prod_counts.index), key=str)
        ref_aligned = np.array([ref_counts.get(c, 0) for c in all_cats], dtype=float)
        prod_aligned = np.array([prod_counts.get(c, 0) for c in all_cats], dtype=float)

        total_ref = ref_aligned.sum()
        total_prod = prod_aligned.sum()
        if total_ref == 0 or total_prod == 0:
            return float('nan'), float('nan')

        expected = ref_aligned / total_ref * total_prod

        # Floor tiny/zero expectations (e.g. categories unseen in the
        # reference) so the statistic stays finite, then rescale: the
        # observed and expected totals must match or scipy.stats.chisquare
        # rejects the input.
        expected = np.maximum(expected, 1.0)
        expected = expected * (total_prod / expected.sum())

        statistic, p_value = stats.chisquare(prod_aligned, expected)
        return statistic, p_value

    def population_stability_index(self, feature: str, production_data: pd.Series, bins: int = 10) -> float:
        """Population Stability Index of ``feature`` over equal-width reference bins.

        PSI = sum((prod% - ref%) * ln(prod% / ref%)); a common rule of thumb
        treats PSI > 0.2 as significant drift.

        Returns:
            The PSI, or ``nan`` when either sample is empty after dropping NaNs.
        """
        ref_data = self.reference_data[feature].dropna().to_numpy(dtype=float)
        prod_data = production_data.dropna().to_numpy(dtype=float)
        if ref_data.size == 0 or prod_data.size == 0:
            return float('nan')

        # Bin edges (and reference counts) come from the reference data only.
        ref_hist, bin_edges = np.histogram(ref_data, bins=bins)
        # np.histogram ignores values outside the edges; clip production data
        # into the reference range so out-of-range values land in the outer
        # bins instead of vanishing (which would hide tail drift).
        clipped_prod = np.clip(prod_data, bin_edges[0], bin_edges[-1])
        prod_hist, _ = np.histogram(clipped_prod, bins=bin_edges)

        ref_pct = ref_hist / ref_data.size
        prod_pct = prod_hist / prod_data.size

        # Avoid log(0) / division by zero for empty bins.
        ref_pct = np.where(ref_pct == 0, 0.0001, ref_pct)
        prod_pct = np.where(prod_pct == 0, 0.0001, prod_pct)

        psi = np.sum((prod_pct - ref_pct) * np.log(prod_pct / ref_pct))
        return psi

    def jensen_shannon_divergence(self, feature: str, production_data: pd.Series, bins: int = 10) -> float:
        """Jensen-Shannon divergence (natural log) between binned reference and production data.

        Both samples are binned over the common min..max range.

        Returns:
            The divergence, ``0.0`` when both samples consist of one identical
            value, or ``nan`` when either sample is empty after dropping NaNs.
        """
        ref_data = self.reference_data[feature].dropna().to_numpy(dtype=float)
        prod_data = production_data.dropna().to_numpy(dtype=float)
        if ref_data.size == 0 or prod_data.size == 0:
            return float('nan')

        min_val = min(ref_data.min(), prod_data.min())
        max_val = max(ref_data.max(), prod_data.max())
        if min_val == max_val:
            # Every value in both samples is the same: identical distributions
            # (and zero-width bins would otherwise produce NaN).
            return 0.0
        bin_edges = np.linspace(min_val, max_val, bins + 1)

        # Bins are equal width, so normalised counts equal normalised densities.
        ref_hist = np.histogram(ref_data, bins=bin_edges)[0].astype(float)
        prod_hist = np.histogram(prod_data, bins=bin_edges)[0].astype(float)
        ref_hist /= ref_hist.sum()
        prod_hist /= prod_hist.sum()

        m = 0.5 * (ref_hist + prod_hist)
        js = 0.5 * stats.entropy(ref_hist, m) + 0.5 * stats.entropy(prod_hist, m)
        return js

    def detect_all_drift(self, production_data: pd.DataFrame, p_threshold: float = 0.05,
                         psi_threshold: float = 0.2) -> pd.DataFrame:
        """Run the appropriate drift tests on every reference column present in production.

        Args:
            production_data: Production data; reference columns missing from
                it are skipped.
            p_threshold: Significance level for the KS / chi-squared tests.
            psi_threshold: PSI above which a numerical feature is flagged.

        Returns:
            One row per compared feature with its statistics and a boolean
            'drift_detected'. Always contains 'feature', 'type' and
            'drift_detected' columns, even when no feature was compared.
        """
        results = []

        for col in self.reference_data.columns:
            if col not in production_data.columns:
                continue

            prod_col = production_data[col]

            if self.feature_stats[col]['type'] == 'numerical':
                ks_stat, ks_p = self.kolmogorov_smirnov_test(col, prod_col)
                psi = self.population_stability_index(col, prod_col)
                js = self.jensen_shannon_divergence(col, prod_col)

                results.append({
                    'feature': col,
                    'type': 'numerical',
                    'ks_statistic': ks_stat,
                    'ks_p_value': ks_p,
                    'psi': psi,
                    'js_divergence': js,
                    # NaN statistics (empty samples) compare False -> no drift.
                    'drift_detected': bool(ks_p < p_threshold or psi > psi_threshold),
                })
            else:
                chi_stat, chi_p = self.chi_squared_test(col, prod_col)

                results.append({
                    'feature': col,
                    'type': 'categorical',
                    'chi_statistic': chi_stat,
                    'chi_p_value': chi_p,
                    'drift_detected': bool(chi_p < p_threshold),
                })

        if not results:
            # Keep the key columns so callers can filter on 'drift_detected'.
            return pd.DataFrame({
                'feature': pd.Series(dtype=object),
                'type': pd.Series(dtype=object),
                'drift_detected': pd.Series(dtype=bool),
            })
        return pd.DataFrame(results)

Window-Based Drift Detection

class WindowDriftDetector:
    """Detect drift in a 1-D numeric stream by comparing it with a baseline window.

    Two detectors are provided: a sliding-window two-sample KS test against
    an adaptively updated baseline, and a two-sided CUSUM on values
    standardised by the baseline mean and standard deviation.
    """

    def __init__(self, window_size: int = 1000, step_size: int = 100):
        """
        Args:
            window_size: Number of samples per window (and kept as baseline).
            step_size: Offset between the starts of consecutive windows.
        """
        self.window_size = window_size
        self.step_size = step_size
        self.baseline_window = None

    def set_baseline(self, data: np.ndarray):
        """Use the last ``window_size`` samples of ``data`` as the baseline."""
        self.baseline_window = data[-self.window_size:]

    def detect_drift_adwin(self, data: np.ndarray, p_threshold: float = 0.05) -> pd.DataFrame:
        """Sliding-window KS drift detection with an adaptive baseline.

        Despite the method name this is not the ADWIN algorithm proper: each
        fixed-size window of ``data`` (starting every ``step_size`` samples)
        is compared with the baseline by a two-sample KS test. A window
        that shows no drift replaces the baseline, so this method mutates
        ``self.baseline_window``.

        Args:
            data: 1-D sequence of observations.
            p_threshold: KS p-value below which a window is flagged.

        Returns:
            DataFrame with one row per full window: window_start, window_end,
            ks_statistic, p_value, drift_detected (empty if ``data`` is
            shorter than one window).

        Raises:
            ValueError: If no baseline has been set.
        """
        if self.baseline_window is None:
            raise ValueError("Baseline not set")

        results = []
        # "+ 1" so the last full window, ending exactly at len(data), is tested.
        last_start = len(data) - self.window_size
        for i in range(0, last_start + 1, self.step_size):
            current_window = data[i:i + self.window_size]

            stat, p_value = stats.ks_2samp(self.baseline_window, current_window)
            drifted = bool(p_value < p_threshold)

            results.append({
                'window_start': i,
                'window_end': i + self.window_size,
                'ks_statistic': stat,
                'p_value': p_value,
                'drift_detected': drifted,
            })

            # Adapt: a window consistent with the baseline becomes the new baseline.
            if not drifted:
                self.baseline_window = current_window

        return pd.DataFrame(
            results,
            columns=['window_start', 'window_end', 'ks_statistic', 'p_value', 'drift_detected'],
        )

    def detect_drift_cusum(self, data: np.ndarray, threshold: float = 5.0, slack: float = 0.5) -> Dict:
        """Two-sided CUSUM (cumulative sum) drift detection against the baseline.

        Values are standardised with the baseline mean/std as z_t, then
        S+_t = max(0, S+_{t-1} + z_t - slack) and
        S-_t = max(0, S-_{t-1} - z_t - slack), starting from 0 so the first
        observation contributes too. A zero-variance baseline yields
        inf/nan standardised values (numpy division warnings).

        Args:
            data: 1-D sequence of observations.
            threshold: Alarm level for either cumulative sum.
            slack: Allowance (in baseline standard deviations) subtracted each step.

        Returns:
            Dict with 'cusum_positive', 'cusum_negative' (arrays), 'drift_points'
            (indices where either sum exceeds ``threshold``) and 'drift_detected'.

        Raises:
            ValueError: If no baseline has been set.
        """
        if self.baseline_window is None:
            raise ValueError("Baseline not set")

        baseline_mean = np.mean(self.baseline_window)
        baseline_std = np.std(self.baseline_window)

        normalized = (data - baseline_mean) / baseline_std

        cusum_pos = np.zeros(len(normalized))
        cusum_neg = np.zeros(len(normalized))

        pos = neg = 0.0
        for i, z in enumerate(normalized):
            pos = max(0.0, pos + z - slack)
            neg = max(0.0, neg - z - slack)
            cusum_pos[i] = pos
            cusum_neg[i] = neg

        drift_points = np.where((cusum_pos > threshold) | (cusum_neg > threshold))[0]

        return {
            'cusum_positive': cusum_pos,
            'cusum_negative': cusum_neg,
            'drift_points': drift_points.tolist(),
            'drift_detected': len(drift_points) > 0,
        }

Visualization

import matplotlib.pyplot as plt

def visualize_drift(reference_data, production_data, feature_name, drift_result):
    """Visualize drift for a feature"""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Distribution comparison
    axes[0].hist(reference_data, bins=30, alpha=0.5, label='Reference', density=True)
    axes[0].hist(production_data, bins=30, alpha=0.5, label='Production', density=True)
    axes[0].legend()
    axes[0].set_title(f'{feature_name} Distribution')

    # Q-Q plot
    ref_sorted = np.sort(reference_data)
    prod_sorted = np.sort(production_data)
    # Interpolate to same length
    interp_prod = np.interp(
        np.linspace(0, 1, len(ref_sorted)),
        np.linspace(0, 1, len(prod_sorted)),
        prod_sorted
    )
    axes[1].scatter(ref_sorted, interp_prod, alpha=0.5)
    axes[1].plot([ref_sorted.min(), ref_sorted.max()],
                 [ref_sorted.min(), ref_sorted.max()], 'r--')
    axes[1].set_xlabel('Reference Quantiles')
    axes[1].set_ylabel('Production Quantiles')
    axes[1].set_title('Q-Q Plot')

    # Statistics
    stats_text = f"KS Statistic: {drift_result.get('ks_statistic', 'N/A'):.4f}\n"
    stats_text += f"P-Value: {drift_result.get('ks_p_value', 'N/A'):.4f}\n"
    stats_text += f"PSI: {drift_result.get('psi', 'N/A'):.4f}\n"
    stats_text += f"Drift: {'Yes' if drift_result.get('drift_detected') else 'No'}"
    axes[2].text(0.5, 0.5, stats_text, ha='center', va='center',
                 fontsize=12, transform=axes[2].transAxes)
    axes[2].axis('off')
    axes[2].set_title('Drift Statistics')

    plt.tight_layout()
    return fig

# Usage
detector = StatisticalDriftDetector(reference_df)
drift_results = detector.detect_all_drift(production_df)

for _, row in drift_results.iterrows():
    if not row['drift_detected']:
        continue
    feature = row['feature']
    fig = visualize_drift(
        reference_df[feature].dropna(),
        production_df[feature].dropna(),
        feature,
        row.to_dict()
    )
    fig.savefig(f"drift_{feature}.png")
    # pyplot keeps every figure alive until it is closed; close each one
    # after saving so plotting many drifted features does not leak memory.
    plt.close(fig)

Automated Drift Monitoring Pipeline

from datetime import datetime, timedelta
import schedule
import time

class DriftMonitoringPipeline:
    """Periodically pull recent production data, run drift detection and alert.

    Args:
        reference_data: Baseline DataFrame handed to StatisticalDriftDetector.
        data_source: Object exposing ``get_recent_data(hours=...)``; presumably
            returns a DataFrame with the reference columns (verify per source).
        alert_callback: Called with the DataFrame of drifted-feature rows when
            at least one feature drifts.
        min_samples: Minimum number of recent rows required to run detection.
        lookback_hours: How many hours of recent data to request per run.
    """

    def __init__(self, reference_data, data_source, alert_callback,
                 min_samples: int = 100, lookback_hours: int = 24):
        self.detector = StatisticalDriftDetector(reference_data)
        self.data_source = data_source
        self.alert_callback = alert_callback
        self.min_samples = min_samples
        self.lookback_hours = lookback_hours
        # One summary dict appended per completed detection run.
        self.drift_history = []

    def run_detection(self):
        """Run drift detection on the most recent ``lookback_hours`` of production data.

        Appends a summary to ``drift_history`` and calls ``alert_callback`` if
        any feature drifted.

        Returns:
            The per-feature results DataFrame, or None when the source returned
            no data or fewer than ``min_samples`` rows.
        """
        from datetime import timezone

        production_data = self.data_source.get_recent_data(hours=self.lookback_hours)

        if production_data is None or len(production_data) < self.min_samples:
            print("Insufficient data for drift detection")
            return None

        results = self.detector.detect_all_drift(production_data)
        if 'drift_detected' in results.columns:
            drifted_features = results[results['drift_detected']]
        else:
            # No features were compared (e.g. no shared columns): nothing drifted.
            drifted_features = results

        self.drift_history.append({
            # Timezone-aware UTC timestamp; datetime.utcnow() returns a naive
            # value and is deprecated since Python 3.12.
            'timestamp': datetime.now(timezone.utc).isoformat(),
            'total_features': len(results),
            'drifted_features': len(drifted_features),
            'details': results.to_dict('records')
        })

        if len(drifted_features) > 0:
            self.alert_callback(drifted_features)

        return results

    def start(self, interval_hours: int = 6):
        """Schedule ``run_detection`` every ``interval_hours`` and block forever.

        The first run happens only after the first interval elapses; pending
        jobs are polled once a minute.
        """
        schedule.every(interval_hours).hours.do(self.run_detection)

        while True:
            schedule.run_pending()
            time.sleep(60)

# Usage
def send_alert(drifted_features):
    """Print a console alert with the number and names of drifted features."""
    count = len(drifted_features)
    print(f"ALERT: Drift detected in {count} features!")
    summary = drifted_features.loc[:, ['feature', 'drift_detected']]
    print(summary.to_string())

pipeline = DriftMonitoringPipeline(
    reference_data=training_df,
    data_source=production_data_source,
    alert_callback=send_alert
)

# Run once. run_detection() returns None when there is not enough recent
# data, so guard before indexing into the results.
results = pipeline.run_detection()
if results is not None:
    n_drifted = int(results['drift_detected'].sum()) if 'drift_detected' in results.columns else 0
    print(f"Drift detected in {n_drifted} features")

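Early drift detection enables proactive model maintenance and helps catch silent failures in production before they affect users. Treat the default thresholds (p < 0.05, PSI > 0.2) as starting points and tune them per feature: with large samples, statistical tests flag even trivial shifts, and testing many features at once inflates false alarms.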

Michael John Peña

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.