Back to Blog
6 min read

Concept Drift Detection in Machine Learning

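Concept drift occurs when the relationship between input features and the target variable changes over time — formally, when the conditional distribution P(y | X) shifts. Unlike data drift (a change in the input distribution P(X)), concept drift can occur even when input distributions remain stable.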

Understanding Concept Drift Types

  • Sudden Drift: Abrupt change in the concept
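  • Gradual Drift: Old and new concepts alternate for a while, with the new one becoming more frequent until it takes over
  • Incremental Drift: The concept moves through a sequence of small intermediate changes
  • Recurring Drift: Previously seen concepts reappear, often periodically (e.g., seasonal patterns)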

Detection Methods

import numpy as np
import pandas as pd
from scipy import stats
from sklearn.metrics import accuracy_score, log_loss
from collections import deque

class ConceptDriftDetector:
    """Detect concept drift by monitoring model accuracy against a baseline.

    Every call to ``update`` scores one labelled batch with ``model.predict``
    and appends that batch's accuracy to a sliding window holding the most
    recent ``window_size`` batch accuracies. Once the window is at least half
    full, drift is reported when the windowed mean accuracy is both

    * more than ``degradation_threshold`` (relative) below
      ``baseline_performance``, and
    * significantly different from the baseline under a two-sided one-sample
      t-test at level ``significance``.

    The window averages per-batch accuracies without weighting by batch size.
    """

    def __init__(self, model, baseline_performance: float, window_size: int = 500,
                 degradation_threshold: float = 0.05, significance: float = 0.05):
        """Create the detector.

        Args:
            model: Object exposing ``predict(X)``.
            baseline_performance: Reference accuracy the window is compared to.
            window_size: Maximum number of batch accuracies kept (counted in
                batches, not samples).
            degradation_threshold: Minimum relative drop below the baseline
                (0.05 = 5 %) required to flag drift.
            significance: p-value below which the drop counts as significant.
        """
        self.model = model
        self.baseline_performance = baseline_performance
        self.window_size = window_size
        self.degradation_threshold = degradation_threshold
        self.significance = significance
        self.performance_window = deque(maxlen=window_size)

    def update(self, X, y_true) -> dict:
        """Score one batch and, once enough batches are stored, test for drift.

        Returns:
            ``{"drift_detected": False, "status": "collecting_data"}`` while
            fewer than ``window_size // 2`` batch accuracies are stored,
            otherwise the dict produced by ``_check_drift``.
        """
        y_pred = self.model.predict(X)

        # Accuracy of this batch becomes one observation in the window
        batch_performance = accuracy_score(y_true, y_pred)
        self.performance_window.append(batch_performance)

        # Only test once the window is at least half full
        if len(self.performance_window) >= self.window_size // 2:
            return self._check_drift()

        return {"drift_detected": False, "status": "collecting_data"}

    def _check_drift(self) -> dict:
        """Compare the windowed accuracy with the baseline.

        Returns:
            Dict with ``drift_detected`` (bool), ``baseline_performance``,
            ``current_performance`` (window mean), ``degradation_pct``
            (relative drop in percent; 0 when the baseline is not positive)
            and ``p_value``.
        """
        performances = np.asarray(self.performance_window, dtype=float)
        window_performance = float(performances.mean())

        # Relative drop; a non-positive baseline cannot degrade and would divide by zero.
        if self.baseline_performance > 0:
            degradation = (self.baseline_performance - window_performance) / self.baseline_performance
        else:
            degradation = 0.0

        p_value = self._p_value(performances)

        # The degradation check supplies the direction; the t-test supplies significance.
        drift_detected = bool(
            degradation > self.degradation_threshold and p_value < self.significance
        )

        return {
            "drift_detected": drift_detected,
            "baseline_performance": self.baseline_performance,
            "current_performance": window_performance,
            "degradation_pct": degradation * 100,
            "p_value": p_value,
        }

    def _p_value(self, performances: np.ndarray) -> float:
        """Two-sided one-sample t-test p-value of the window against the baseline.

        The t statistic is undefined for fewer than two observations or zero
        variance; SciPy then yields inf/NaN with a RuntimeWarning, and a NaN
        p-value would silently suppress detection. Those cases are resolved
        explicitly: no evidence with < 2 samples, and an exact verdict when
        every observation is identical.
        """
        if performances.size < 2:
            return 1.0
        if np.all(performances == performances[0]):
            return 0.0 if performances[0] != self.baseline_performance else 1.0
        _, p_value = stats.ttest_1samp(performances, self.baseline_performance)
        return float(p_value)


class PageHinkleyTest:
    """Page-Hinkley test for detecting an increase in the mean of a stream.

    Accumulates ``m_t = sum_i (x_i - mean_i - delta)`` (``mean_i`` being the
    running mean including ``x_i``) and reports drift when
    ``m_t - min(m_1..m_t)`` exceeds ``threshold`` after at least
    ``min_instances`` observations. Feeding error rates therefore flags a
    rising error. After signalling drift the detector resets, so one change
    produces one alarm instead of an alarm on every following update.
    """

    def __init__(self, min_instances: int = 30, delta: float = 0.005, threshold: float = 50):
        """Create the detector.

        Args:
            min_instances: Observations required before drift can be reported.
            delta: Tolerated magnitude of change subtracted at every step.
            threshold: Alarm level for ``m_t - min(m)``.
        """
        self.min_instances = min_instances
        self.delta = delta
        self.threshold = threshold
        self.reset()

    def reset(self):
        """Reset all accumulated statistics."""
        self.n = 0
        self.sum = 0.0
        self.x_mean = 0.0
        self.p_h_sum = 0.0
        self.p_h_min = float('inf')

    def update(self, value: float) -> dict:
        """Add one observation and test for drift.

        Returns:
            Dict with ``drift_detected``, the statistic ``p_h_value``, the
            ``threshold`` and ``n_samples``. Values describe the state that
            was tested, even when the detector resets because drift fired.
        """
        self.n += 1

        # Incremental running mean (includes the new value)
        self.x_mean += (value - self.x_mean) / self.n

        # Cumulative deviation from the running mean, minus the tolerance delta
        self.sum += value - self.x_mean - self.delta

        self.p_h_sum = self.sum
        self.p_h_min = min(self.p_h_min, self.p_h_sum)

        p_h_value = self.p_h_sum - self.p_h_min
        drift_detected = self.n >= self.min_instances and p_h_value > self.threshold

        result = {
            "drift_detected": drift_detected,
            "p_h_value": p_h_value,
            "threshold": self.threshold,
            "n_samples": self.n,
        }

        # Start a fresh test after an alarm (as DDM does) so the alarm does not latch
        if drift_detected:
            self.reset()

        return result


class DDM:
    """Drift Detection Method (DDM) over a stream of prediction outcomes.

    Tracks the running error rate ``p`` and its standard deviation
    ``s = sqrt(p * (1 - p) / n)``. From ``min_instances`` outcomes onward it
    records the point where ``p + s`` was lowest (``p_min``, ``s_min``) and
    signals a warning when ``p + s > p_min + warning_level * s_min`` and drift
    when ``p + s > p_min + drift_level * s_min``. The detector resets itself
    after drift.
    """

    def __init__(self, min_instances: int = 30, warning_level: float = 2.0, drift_level: float = 3.0):
        """Create the detector.

        Args:
            min_instances: Outcomes collected before minimums are tracked and
                thresholds are tested.
            warning_level: Multiplier of ``s_min`` for the warning threshold.
            drift_level: Multiplier of ``s_min`` for the drift threshold.
        """
        self.min_instances = min_instances
        self.warning_level = warning_level
        self.drift_level = drift_level
        self.reset()

    def reset(self):
        """Clear all statistics and start tracking a fresh concept."""
        self.n = 0
        self.p = 0.0  # Running error rate
        self.s = 0.0  # Standard deviation of the error-rate estimate
        self.p_min = float('inf')
        self.s_min = float('inf')

    def update(self, prediction_correct: bool) -> dict:
        """Add one prediction outcome and test the thresholds.

        Returns:
            ``{"status": "collecting", ...}`` during warm-up, otherwise a dict
            with ``drift_detected``, ``warning``, ``error_rate``, the drift
            ``threshold`` and ``n_samples``, all describing the state that
            was tested (captured before any reset).
        """
        self.n += 1

        error = 0 if prediction_correct else 1
        self.p += (error - self.p) / self.n
        self.s = np.sqrt(self.p * (1 - self.p) / self.n)

        # Minimums must only be tracked once the estimates are stable: a single
        # early correct prediction gives p = s = 0, which would pin both
        # thresholds at 0 and make any later error register as drift.
        if self.n < self.min_instances:
            return {"status": "collecting", "drift_detected": False, "warning": False}

        if self.p + self.s < self.p_min + self.s_min:
            self.p_min = self.p
            self.s_min = self.s

        drift_threshold = self.p_min + self.drift_level * self.s_min
        warning_threshold = self.p_min + self.warning_level * self.s_min

        level = self.p + self.s
        drift_detected = bool(level > drift_threshold)
        warning = bool(level > warning_threshold) and not drift_detected

        # Build the report before a possible reset so it describes the samples
        # that triggered drift rather than the freshly cleared state.
        result = {
            "drift_detected": drift_detected,
            "warning": warning,
            "error_rate": self.p,
            "threshold": drift_threshold,
            "n_samples": self.n,
        }

        if drift_detected:
            self.reset()

        return result

ADWIN Algorithm

class ADWIN:
    """Adaptive Windowing (ADWIN) drift detector.

    The window is stored as an exponential histogram: parallel lists
    ``bucket_sizes`` / ``bucket_totals`` / ``bucket_sq_totals`` ordered from
    oldest to newest, holding at most ``max_buckets`` buckets of each size
    (sizes are powers of two). After every insertion each bucket boundary is
    tried as a split into an older sub-window W0 and a newer sub-window W1.
    While some split has ``|mean(W0) - mean(W1)| > epsilon_cut`` the oldest
    bucket is dropped, and drift is reported.

    ``epsilon_cut = sqrt(2/m * var(W) * ln(2/delta')) + 2/(3m) * ln(2/delta')``
    with ``m = 1 / (1/n0 + 1/n1)`` and ``delta' = delta / n``.

    NOTE(review): the bound is derived for values in [0, 1] (e.g. error
    rates) — confirm before feeding other ranges.
    """

    def __init__(self, delta: float = 0.002, max_buckets: int = 5, min_window: int = 10):
        """Create the detector.

        Args:
            delta: Confidence parameter; smaller values mean fewer false alarms.
            max_buckets: Buckets of one size kept before the two oldest merge.
            min_window: No change test is run on windows shorter than this.
        """
        self.delta = delta
        self.max_buckets = max_buckets
        self.min_window = min_window
        self.reset()

    def reset(self):
        """Empty the window."""
        self.bucket_sizes = []
        self.bucket_totals = []
        self.bucket_sq_totals = []
        self.total = 0.0
        self.variance = 0.0  # Population variance of the current window
        self.width = 0
        self._sq_total = 0.0

    def update(self, value: float) -> dict:
        """Add a new value and check for drift.

        Returns:
            Dict with ``drift_detected``, the current ``window_size`` and the
            window ``mean`` (after any shrinking caused by drift).
        """
        self._add_element(value)
        drift_detected = self._detect_change()

        return {
            "drift_detected": drift_detected,
            "window_size": self.width,
            "mean": self.total / self.width if self.width > 0 else 0
        }

    def _add_element(self, value: float):
        """Append ``value`` as a new size-1 bucket and re-compress."""
        value = float(value)
        self.width += 1
        self.total += value
        self._sq_total += value * value

        self.bucket_sizes.append(1)
        self.bucket_totals.append(value)
        self.bucket_sq_totals.append(value * value)

        self._compress_buckets()
        self._update_variance()

    def _compress_buckets(self):
        """Merge the two oldest buckets of any size that exceeds ``max_buckets``.

        Sizes are non-increasing from oldest to newest, so buckets of one size
        are contiguous and merging the two oldest of size s yields a bucket of
        size 2s right after the existing 2s group; only that level can then
        overflow, so the cascade moves upward one level at a time.
        """
        limit = max(self.max_buckets, 1)
        size = 1
        while True:
            positions = [i for i, s in enumerate(self.bucket_sizes) if s == size]
            if len(positions) <= limit:
                return
            first, second = positions[0], positions[1]
            self.bucket_sizes[first] = 2 * size
            self.bucket_totals[first] += self.bucket_totals[second]
            self.bucket_sq_totals[first] += self.bucket_sq_totals[second]
            del self.bucket_sizes[second]
            del self.bucket_totals[second]
            del self.bucket_sq_totals[second]
            size *= 2

    def _update_variance(self):
        """Recompute the window's population variance from its running sums."""
        if self.width == 0:
            self.variance = 0.0
            return
        mean = self.total / self.width
        # Clamp tiny negative values caused by floating-point cancellation
        self.variance = max(self._sq_total / self.width - mean * mean, 0.0)

    def _has_significant_cut(self) -> bool:
        """Return True if some bucket boundary splits the window significantly."""
        n = self.width
        log_term = np.log(2.0 * n / self.delta)  # ln(2 / delta') with delta' = delta / n
        n0, sum0 = 0, 0.0
        # Every boundary between consecutive buckets (oldest first) is a candidate
        for size, bucket_total in zip(self.bucket_sizes[:-1], self.bucket_totals[:-1]):
            n0 += size
            sum0 += bucket_total
            n1 = n - n0
            sum1 = self.total - sum0
            m = 1.0 / (1.0 / n0 + 1.0 / n1)
            epsilon = (np.sqrt(2.0 / m * self.variance * log_term)
                       + 2.0 / (3.0 * m) * log_term)
            if abs(sum0 / n0 - sum1 / n1) > epsilon:
                return True
        return False

    def _drop_oldest_bucket(self):
        """Remove the oldest bucket and refresh the window statistics."""
        self.width -= self.bucket_sizes.pop(0)
        self.bucket_totals.pop(0)
        self.bucket_sq_totals.pop(0)
        # Recompute sums from the buckets to avoid drift from repeated subtraction
        self.total = sum(self.bucket_totals)
        self._sq_total = sum(self.bucket_sq_totals)
        self._update_variance()

    def _detect_change(self) -> bool:
        """Shrink the window while a significant cut exists; report if it shrank."""
        drift = False
        while self.width >= self.min_window and self._has_significant_cut():
            self._drop_oldest_bucket()
            drift = True
        return drift

Combining Multiple Detectors

class EnsembleDriftDetector:
    """Run four drift detectors side by side and report drift by vote.

    Detectors: the accuracy-window test (``ConceptDriftDetector``),
    Page-Hinkley and ADWIN on the per-batch error rate, and DDM on every
    individual prediction. Drift is reported when at least ``min_votes`` of
    the four fire on the same batch. The default of 2 means half of the
    detectors, not a strict majority.
    """

    def __init__(self, model, baseline_performance: float, min_votes: int = 2):
        """Create the ensemble.

        Args:
            model: Object exposing ``predict(X)``; used by the performance detector.
            baseline_performance: Reference accuracy for the performance detector.
            min_votes: Number of detectors that must agree to report drift.
        """
        self.performance_detector = ConceptDriftDetector(model, baseline_performance)
        self.ph_detector = PageHinkleyTest()
        self.ddm_detector = DDM()
        self.adwin_detector = ADWIN()
        self.min_votes = min_votes

    def update(self, X, y_true, y_pred) -> dict:
        """Update all detectors with one batch and return the vote.

        Args:
            X: Batch features, forwarded to the model by the performance detector
                (which computes its own predictions from ``X``).
            y_true: True labels (array-like).
            y_pred: Predicted labels for the same samples (array-like); used by
                the error-based detectors.

        Returns:
            Dict with ``drift_detected``, the number of ``votes`` and the
            per-detector ``details``.
        """
        # Performance-based detection
        perf_result = self.performance_detector.update(X, y_true)

        # np.asarray makes plain lists work: `list != list` is one bool, not elementwise
        y_true_arr = np.asarray(y_true)
        y_pred_arr = np.asarray(y_pred)

        # Error-based detection on the batch error rate
        errors = (y_pred_arr != y_true_arr).astype(float)
        error_rate = np.mean(errors)

        ph_result = self.ph_detector.update(error_rate)
        adwin_result = self.adwin_detector.update(error_rate)

        # Per-prediction detection; drift anywhere in the batch counts as one vote
        ddm_results = [self.ddm_detector.update(pred == true)
                       for pred, true in zip(y_pred_arr, y_true_arr)]
        ddm_drift = any(r['drift_detected'] for r in ddm_results)

        votes = [
            perf_result['drift_detected'],
            ph_result['drift_detected'],
            adwin_result['drift_detected'],
            ddm_drift
        ]
        n_votes = sum(votes)

        drift_detected = n_votes >= self.min_votes

        return {
            "drift_detected": drift_detected,
            "votes": n_votes,
            "details": {
                "performance": perf_result,
                "page_hinkley": ph_result,
                "adwin": adwin_result,
                "ddm": ddm_drift
            }
        }

Visualization

import matplotlib.pyplot as plt

def visualize_concept_drift(timestamps, error_rates, drift_points, title="Concept Drift Detection"):
    """Plot an error-rate series with a dashed vertical line at each drift point.

    Args:
        timestamps: x-values of the series.
        error_rates: Error rate at each timestamp.
        drift_points: x-positions to mark as drift.
        title: Axes title.

    Returns:
        The created matplotlib Figure (not saved or shown here).
    """
    fig, ax = plt.subplots(figsize=(12, 5))

    ax.plot(timestamps, error_rates, label='Error Rate', color='blue')

    # Label only the first marker so the legend shows a single "Drift Detected" entry
    for i, drift_time in enumerate(drift_points):
        ax.axvline(x=drift_time, color='red', linestyle='--', alpha=0.7,
                   label='Drift Detected' if i == 0 else '_nolegend_')

    ax.set_xlabel('Time')
    ax.set_ylabel('Error Rate')
    ax.set_title(title)
    ax.legend()

    plt.tight_layout()
    return fig

# Simulated usage: the error rate jumps at t=400 and partially recovers at t=700
rng = np.random.default_rng(42)  # fixed seed so the demo figure is reproducible
timestamps = list(range(1000))
error_rates = np.concatenate([
    rng.normal(0.1, 0.02, 400),   # Normal period
    rng.normal(0.25, 0.03, 300),  # Drift period
    rng.normal(0.12, 0.02, 300),  # Recovery
])
drift_points = [400, 700]  # true change points of the simulation, not detector output

fig = visualize_concept_drift(timestamps, error_rates, drift_points)
# Save the figure we built explicitly rather than pyplot's implicit "current" figure
fig.savefig("concept_drift.png")

Detecting concept drift enables timely model retraining to maintain prediction accuracy over time.

Michael John Peña

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.