Skip to content
Back to Blog
1 min read

Concept Drift Detection in Machine Learning

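Concept drift occurs when the relationship between a model's inputs and its target changes over time, so a model that was accurate at training time gradually (or suddenly) degrades in production. I wrote “Concept Drift Detection in Machine Learning” to share practical, production-minded guidance on detecting it.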

Understanding Concept Drift Types

  • Sudden Drift: Abrupt change in the concept
  • Gradual Drift: Slow transition between concepts
  • Incremental Drift: Continuous small changes
  • Recurring Drift: Concepts that reappear periodically

Detection Methods

import numpy as np
import pandas as pd
from scipy import stats
from sklearn.metrics import accuracy_score, log_loss
from collections import deque

class ConceptDriftDetector:
    """Detect concept drift by monitoring batch accuracy against a baseline.

    Each call to :meth:`update` scores one batch with ``accuracy_score`` and
    stores the score in a rolling window. Once the window holds enough
    batches, drift is reported when the window's mean accuracy has dropped
    by more than 5% relative to the baseline *and* a one-sample t-test
    rejects "window mean equals baseline" at the 5% level.

    Args:
        model: Object exposing ``predict(X)``.
        baseline_performance: Reference accuracy. Must be positive because
            degradation is expressed relative to it.
        window_size: Maximum number of batch scores retained (at least 2,
            since the t-test needs two observations).

    Raises:
        ValueError: If ``baseline_performance`` is not positive or
            ``window_size`` is smaller than 2.
    """

    def __init__(self, model, baseline_performance: float, window_size: int = 500):
        if baseline_performance <= 0:
            raise ValueError("baseline_performance must be positive")
        if window_size < 2:
            raise ValueError("window_size must be at least 2")

        self.model = model
        self.baseline_performance = baseline_performance
        self.window_size = window_size
        # Rolling window of per-batch accuracy; the oldest score falls off.
        self.performance_window = deque(maxlen=window_size)

    def update(self, X, y_true) -> dict:
        """Score one batch, record it, and check for drift.

        Returns:
            ``{"drift_detected": False, "status": "collecting_data"}`` while
            the window is still filling, otherwise the dict produced by
            :meth:`_check_drift`.
        """
        y_pred = self.model.predict(X)

        batch_performance = accuracy_score(y_true, y_pred)
        self.performance_window.append(batch_performance)

        # Start testing at half a window, but never with fewer than the two
        # observations a t-test requires.
        min_batches = max(2, self.window_size // 2)
        if len(self.performance_window) >= min_batches:
            return self._check_drift()

        return {"drift_detected": False, "status": "collecting_data"}

    def _check_drift(self) -> dict:
        """Compare the current window with the baseline.

        Returns:
            Dict with ``drift_detected`` (plain ``bool``),
            ``baseline_performance``, ``current_performance`` (window mean),
            ``degradation_pct`` (relative drop in percent) and ``p_value``
            (two-sided one-sample t-test against the baseline).
        """
        performances = list(self.performance_window)
        window_performance = float(np.mean(performances))
        degradation = (
            self.baseline_performance - window_performance
        ) / self.baseline_performance

        _, p_value = stats.ttest_1samp(performances, self.baseline_performance)
        p_value = float(p_value)

        # A NaN p-value (degenerate sample) compares False, i.e. "no drift".
        # bool() keeps numpy scalar types out of the result so it can be
        # serialized and compared with `is True`.
        drift_detected = bool(degradation > 0.05 and p_value < 0.05)

        return {
            "drift_detected": drift_detected,
            "baseline_performance": self.baseline_performance,
            "current_performance": window_performance,
            "degradation_pct": degradation * 100,
            "p_value": p_value,
        }


class PageHinkleyTest:
    """Page-Hinkley test for detecting an upward shift in a stream's mean.

    The detector accumulates ``value - running_mean - delta`` and signals
    drift when that cumulative sum rises more than ``threshold`` above the
    smallest value it has reached. After signalling, the detector resets
    itself so one change produces one alarm instead of an alarm on every
    subsequent update.

    Args:
        min_instances: Number of observations required before drift can be
            reported.
        delta: Tolerance subtracted from each deviation.
        threshold: Alarm level for the Page-Hinkley statistic.
    """

    def __init__(self, min_instances: int = 30, delta: float = 0.005, threshold: float = 50):
        self.min_instances = min_instances
        self.delta = delta
        self.threshold = threshold
        self.reset()

    def reset(self):
        """Clear all running statistics."""
        self.n = 0
        self.sum = 0.0
        self.x_mean = 0.0
        self.p_h_sum = 0.0
        self.p_h_min = float('inf')

    def update(self, value: float) -> dict:
        """Consume one observation and report whether drift was detected.

        Returns:
            Dict with ``drift_detected``, ``p_h_value`` (current statistic),
            ``threshold`` and ``n_samples``. When drift is detected the dict
            describes the state at detection time; the detector itself is
            reset before returning.
        """
        self.n += 1

        # Incremental running mean.
        self.x_mean += (value - self.x_mean) / self.n

        # Cumulative deviation from the mean, less the tolerance.
        self.sum += value - self.x_mean - self.delta

        # `p_h_sum` mirrors `sum`; both attributes are kept for callers that
        # read either one.
        self.p_h_sum = self.sum
        self.p_h_min = min(self.p_h_min, self.p_h_sum)

        p_h_value = self.p_h_sum - self.p_h_min
        drift_detected = bool(
            self.n >= self.min_instances and p_h_value > self.threshold
        )

        result = {
            "drift_detected": drift_detected,
            "p_h_value": p_h_value,
            "threshold": self.threshold,
            "n_samples": self.n,
        }

        # Build the report first, then start over for the new concept.
        if drift_detected:
            self.reset()

        return result


class DDM:
    """Drift Detection Method (DDM) over a stream of correct/incorrect flags.

    Maintains the running error rate ``p`` and its binomial standard
    deviation ``s = sqrt(p * (1 - p) / n)``, remembers the lowest ``p + s``
    seen, and raises a warning or drift when the current ``p + s`` exceeds
    ``p_min + level * s_min``.

    Args:
        min_instances: Observations required before any decision is made.
        warning_level: Multiplier of ``s_min`` for the warning threshold.
        drift_level: Multiplier of ``s_min`` for the drift threshold.
    """

    def __init__(self, min_instances: int = 30, warning_level: float = 2.0, drift_level: float = 3.0):
        self.min_instances = min_instances
        self.warning_level = warning_level
        self.drift_level = drift_level
        self.reset()

    def reset(self):
        """Clear all running statistics."""
        self.n = 0
        self.p = 0.0  # Running error rate
        self.s = 0.0  # Standard deviation of the error rate
        self.p_min = float('inf')
        self.s_min = float('inf')

    def update(self, prediction_correct: bool) -> dict:
        """Consume one prediction outcome.

        Returns:
            While fewer than ``min_instances`` outcomes have been seen:
            ``{"status": "collecting", "drift_detected": False,
            "warning": False}``. Afterwards a dict with ``drift_detected``,
            ``warning``, ``error_rate``, ``threshold`` and ``n_samples``.
            On drift the dict describes the state at detection time and the
            detector is reset before returning.
        """
        self.n += 1

        error = 0 if prediction_correct else 1
        self.p += (error - self.p) / self.n
        # max() guards against a tiny negative product from rounding.
        self.s = float(np.sqrt(max(0.0, self.p * (1 - self.p)) / self.n))

        if self.n < self.min_instances:
            return {"status": "collecting", "drift_detected": False, "warning": False}

        # Track the minimum only after warm-up. Estimates from the first few
        # samples are unreliable: a single correct first prediction gives
        # p = s = 0, which would pin the thresholds at zero and make any
        # later error look like drift.
        if self.p + self.s < self.p_min + self.s_min:
            self.p_min = self.p
            self.s_min = self.s

        drift_threshold = self.p_min + self.drift_level * self.s_min
        warning_threshold = self.p_min + self.warning_level * self.s_min

        level = self.p + self.s
        drift_detected = bool(level > drift_threshold)
        warning = bool(level > warning_threshold and not drift_detected)

        # Snapshot the statistics before resetting; otherwise a drift report
        # would carry error_rate 0.0 and n_samples 0.
        result = {
            "drift_detected": drift_detected,
            "warning": warning,
            "error_rate": self.p,
            "threshold": drift_threshold,
            "n_samples": self.n,
        }

        if drift_detected:
            self.reset()

        return result

ADWIN Algorithm

class ADWIN:
    """Adaptive Windowing (ADWIN) drift detector.

    The window is stored as a list of buckets ordered oldest first. Each
    bucket summarizes a power-of-two number of consecutive values by their
    sum, and at most ``max_buckets`` buckets of any one size are kept, so
    the number of buckets grows far more slowly than the window itself.

    After every insertion the window is split at each bucket boundary. If
    the means of the older and newer parts differ by more than a Hoeffding
    bound, the oldest bucket is dropped and the check repeats.

    NOTE: the bound assumes values lie in [0, 1] (for example error rates).

    Args:
        delta: Confidence parameter in (0, 1); smaller is more conservative.
        max_buckets: Maximum number of buckets of the same size (>= 1).

    Raises:
        ValueError: If ``delta`` is outside (0, 1) or ``max_buckets`` < 1.
    """

    def __init__(self, delta: float = 0.002, max_buckets: int = 5):
        if not 0 < delta < 1:
            raise ValueError("delta must be in the open interval (0, 1)")
        if max_buckets < 1:
            raise ValueError("max_buckets must be at least 1")
        self.delta = delta
        self.max_buckets = max_buckets
        self.reset()

    def reset(self):
        """Empty the window."""
        self.bucket_sizes = []   # element count per bucket, oldest first
        self.bucket_totals = []  # sum of values per bucket, oldest first
        self.total = 0.0         # sum of every value in the window
        self.variance = 0.0      # kept for compatibility; not used
        self.width = 0           # number of values in the window

    def update(self, value: float) -> dict:
        """Insert one value and report whether the window was cut.

        Returns:
            Dict with ``drift_detected``, ``window_size`` and ``mean`` of the
            window after any cut.
        """
        self._add_element(value)
        drift_detected = self._detect_change()

        return {
            "drift_detected": drift_detected,
            "window_size": self.width,
            "mean": self.total / self.width if self.width > 0 else 0
        }

    def _add_element(self, value: float):
        """Append ``value`` as a new size-1 bucket and re-compress."""
        self.width += 1
        self.total += value

        self.bucket_sizes.append(1)
        self.bucket_totals.append(value)

        self._compress_buckets()

    def _compress_buckets(self):
        """Merge the two oldest buckets of any size that has too many.

        Sizes are non-increasing from oldest to newest, so buckets of equal
        size are adjacent and a merge keeps the ordering intact. A merge can
        overflow the next size up, hence the cascade.
        """
        size = 1
        while True:
            same = [i for i, s in enumerate(self.bucket_sizes) if s == size]
            if len(same) <= self.max_buckets:
                return
            keep, absorb = same[0], same[1]
            self.bucket_sizes[keep] = size * 2
            self.bucket_totals[keep] += self.bucket_totals[absorb]
            del self.bucket_sizes[absorb]
            del self.bucket_totals[absorb]
            size *= 2

    def _has_cut(self) -> bool:
        """Return True if some bucket-boundary split shows a mean shift."""
        # ln(4 / delta') with delta' = delta / width.
        log_term = np.log(4.0 * self.width / self.delta)

        older_n = 0
        older_total = 0.0
        # Skip the last boundary: the newer part must be non-empty.
        for size, subtotal in zip(self.bucket_sizes[:-1], self.bucket_totals[:-1]):
            older_n += size
            older_total += subtotal
            newer_n = self.width - older_n
            newer_total = self.total - older_total

            harmonic = 1.0 / (1.0 / older_n + 1.0 / newer_n)
            epsilon = np.sqrt(log_term / (2.0 * harmonic))

            if abs(older_total / older_n - newer_total / newer_n) > epsilon:
                return True
        return False

    def _detect_change(self) -> bool:
        """Drop the oldest buckets while the window shows a mean shift.

        Returns:
            True if at least one bucket was dropped.
        """
        if self.width < 10:
            return False

        changed = False
        while len(self.bucket_sizes) > 1 and self._has_cut():
            self.width -= self.bucket_sizes.pop(0)
            self.bucket_totals.pop(0)
            # Recompute instead of subtracting to avoid accumulating error.
            self.total = sum(self.bucket_totals)
            changed = True

        return changed

Combining Multiple Detectors

class EnsembleDriftDetector:
    """Combine four drift detectors and report drift when enough agree.

    The ensemble feeds each batch to a performance-based detector, a
    Page-Hinkley test and ADWIN (both on the batch error rate), and DDM (on
    each individual prediction).

    Args:
        model: Object exposing ``predict(X)``; used by the performance
            detector.
        baseline_performance: Reference accuracy for the performance
            detector.
        min_votes: Number of detectors (out of four) that must flag drift
            for the ensemble to report it. The default of 2 means "at least
            half", not a strict majority.
    """

    def __init__(self, model, baseline_performance: float, min_votes: int = 2):
        self.performance_detector = ConceptDriftDetector(model, baseline_performance)
        self.ph_detector = PageHinkleyTest()
        self.ddm_detector = DDM()
        self.adwin_detector = ADWIN()
        self.min_votes = min_votes

    def update(self, X, y_true, y_pred) -> dict:
        """Feed one batch to every detector and return the vote.

        Args:
            X: Batch features, passed to the performance detector.
            y_true: True labels (any array-like).
            y_pred: Predicted labels (any array-like, same shape as y_true).

        Returns:
            Dict with ``drift_detected``, ``votes`` (number of detectors
            flagging drift) and ``details`` holding each detector's result.

        Raises:
            ValueError: If ``y_true`` and ``y_pred`` differ in shape.
        """
        # Coerce first so plain lists work and a shape mismatch is caught
        # before any detector state is modified.
        true_arr = np.asarray(y_true)
        pred_arr = np.asarray(y_pred)
        if true_arr.shape != pred_arr.shape:
            raise ValueError(
                f"y_true and y_pred must have the same shape, "
                f"got {true_arr.shape} and {pred_arr.shape}"
            )

        # Performance-based detection
        perf_result = self.performance_detector.update(X, y_true)

        # Error-rate-based detection
        mismatches = pred_arr != true_arr
        error_rate = float(np.mean(mismatches))

        ph_result = self.ph_detector.update(error_rate)
        adwin_result = self.adwin_detector.update(error_rate)

        # Per-prediction detection; drift anywhere in the batch counts.
        ddm_results = [
            self.ddm_detector.update(not bool(wrong)) for wrong in mismatches
        ]
        ddm_drift = any(r['drift_detected'] for r in ddm_results)

        votes = [
            perf_result['drift_detected'],
            ph_result['drift_detected'],
            adwin_result['drift_detected'],
            ddm_drift
        ]
        # Plain int so the result contains no numpy scalar types.
        vote_count = sum(1 for vote in votes if vote)

        drift_detected = vote_count >= self.min_votes

        return {
            "drift_detected": drift_detected,
            "votes": vote_count,
            "details": {
                "performance": perf_result,
                "page_hinkley": ph_result,
                "adwin": adwin_result,
                "ddm": ddm_drift
            }
        }

Visualization

import matplotlib.pyplot as plt

def visualize_concept_drift(timestamps, error_rates, drift_points, title="Concept Drift Detection"):
    """Plot the error rate over time with a marker at each drift point.

    Args:
        timestamps: X-axis values, one per error rate.
        error_rates: Error rate observed at each timestamp.
        drift_points: X positions at which to draw a vertical drift marker.
        title: Plot title.

    Returns:
        The matplotlib figure containing the plot.
    """
    fig, ax = plt.subplots(figsize=(12, 5))

    ax.plot(timestamps, error_rates, label='Error Rate', color='blue')

    # Only the first marker carries the legend label; labels that start with
    # an underscore are skipped by the legend, so the entry appears once
    # instead of once per drift point.
    for index, drift_time in enumerate(drift_points):
        marker_label = 'Drift Detected' if index == 0 else '_nolegend_'
        ax.axvline(x=drift_time, color='red', linestyle='--', alpha=0.7, label=marker_label)

    ax.set_xlabel('Time')
    ax.set_ylabel('Error Rate')
    ax.set_title(title)
    ax.legend()

    fig.tight_layout()
    return fig

# Simulated usage: a stable period, a period with a clearly higher error
# rate, then a partial recovery. The regime changes sit at t=400 and t=700.
timestamps = [*range(1000)]
error_rates = np.concatenate((
    np.random.normal(loc=0.1, scale=0.02, size=400),   # Normal period
    np.random.normal(loc=0.25, scale=0.03, size=300),  # Drift period
    np.random.normal(loc=0.12, scale=0.02, size=300),  # Recovery
))
drift_points = [400, 700]

fig = visualize_concept_drift(
    timestamps=timestamps,
    error_rates=error_rates,
    drift_points=drift_points,
)
plt.savefig("concept_drift.png")

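Detecting concept drift enables timely model retraining to maintain prediction accuracy over time.

Takeaways

  • Monitor the model's error continuously; Page-Hinkley, DDM, and ADWIN each flag a sustained rise in a different way.
  • Combine several detectors and act when they agree, rather than reacting to a single alarm.
  • Once drift is confirmed, retrain the model and reset the detectors so they track the new concept.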

Michael John Peña

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.