Feature Drift Monitoring for ML Systems

Feature drift occurs when individual feature distributions change over time. Monitoring feature-level drift helps identify the root cause of model degradation.

Feature Drift vs Data Drift

  • Data Drift: Overall input distribution changes
  • Feature Drift: Individual feature distributions change
  • Root Cause: Feature drift helps pinpoint which specific features are problematic (see the sketch below)
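
A minimal sketch of the distinction, using SciPy's two-sample KS test per feature on synthetic data (the full monitor below packages the same idea with categorical support and reporting):

import numpy as np
import pandas as pd
from scipy import stats

rng = np.random.default_rng(42)
reference = pd.DataFrame({
    'age': rng.normal(40, 10, 1000),
    'income': rng.normal(60_000, 15_000, 1000),
})

# Simulate drift in a single feature: incomes shift upward
current = reference.copy()
current['income'] = rng.normal(75_000, 15_000, 1000)

# Feature-level view: test each column separately to localize the change
for col in reference.columns:
    ks_stat, p_value = stats.ks_2samp(reference[col], current[col])
    print(f"{col}: KS={ks_stat:.3f}, p={p_value:.4f}")

# Only 'income' shows a large KS statistic; 'age' is unchanged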

Feature Drift Detector

import numpy as np
import pandas as pd
from scipy import stats
from typing import Dict, List
from dataclasses import dataclass

@dataclass
class FeatureDriftResult:
    feature_name: str
    drift_score: float
    p_value: float
    drift_detected: bool
    reference_stats: Dict
    current_stats: Dict
    method: str

class FeatureDriftMonitor:
    """Monitor drift for individual features"""

    def __init__(self, reference_data: pd.DataFrame, drift_threshold: float = 0.1):
        self.reference_data = reference_data
        self.drift_threshold = drift_threshold
        self.feature_stats = self._compute_feature_stats(reference_data)

    def _compute_feature_stats(self, df: pd.DataFrame) -> Dict:
        """Compute statistics for each feature"""
        stats_dict = {}
        for col in df.columns:
            # Use the dtype API so int32/float32 etc. are also treated as numerical
            if pd.api.types.is_numeric_dtype(df[col]):
                stats_dict[col] = {
                    'type': 'numerical',
                    'mean': df[col].mean(),
                    'std': df[col].std(),
                    'median': df[col].median(),
                    'q1': df[col].quantile(0.25),
                    'q3': df[col].quantile(0.75),
                    'min': df[col].min(),
                    'max': df[col].max(),
                    'null_rate': df[col].isnull().mean()
                }
            else:
                value_counts = df[col].value_counts(normalize=True)
                stats_dict[col] = {
                    'type': 'categorical',
                    'mode': df[col].mode().iloc[0] if len(df[col].mode()) > 0 else None,
                    'unique_count': df[col].nunique(),
                    'value_distribution': value_counts.to_dict(),
                    'null_rate': df[col].isnull().mean()
                }
        return stats_dict

    def check_feature_drift(self, current_data: pd.DataFrame) -> List[FeatureDriftResult]:
        """Check drift for all features"""
        results = []

        for feature in self.reference_data.columns:
            if feature not in current_data.columns:
                continue

            feature_type = self.feature_stats[feature]['type']

            if feature_type == 'numerical':
                result = self._check_numerical_drift(feature, current_data[feature])
            else:
                result = self._check_categorical_drift(feature, current_data[feature])

            results.append(result)

        return results

    def _check_numerical_drift(self, feature: str, current_data: pd.Series) -> FeatureDriftResult:
        """Check drift for numerical feature"""
        ref_data = self.reference_data[feature].dropna()
        cur_data = current_data.dropna()

        # KS test
        ks_stat, p_value = stats.ks_2samp(ref_data, cur_data)

        # Compute current stats
        current_stats = {
            'mean': cur_data.mean(),
            'std': cur_data.std(),
            'median': cur_data.median(),
            'q1': cur_data.quantile(0.25),
            'q3': cur_data.quantile(0.75),
            'null_rate': current_data.isnull().mean()
        }

        # Flag drift when either the effect size (KS statistic) or the
        # significance test crosses its threshold
        return FeatureDriftResult(
            feature_name=feature,
            drift_score=ks_stat,
            p_value=p_value,
            drift_detected=ks_stat > self.drift_threshold or p_value < 0.05,
            reference_stats=self.feature_stats[feature],
            current_stats=current_stats,
            method='kolmogorov_smirnov'
        )

    def _check_categorical_drift(self, feature: str, current_data: pd.Series) -> FeatureDriftResult:
        """Check drift for categorical feature"""
        ref_dist = self.feature_stats[feature]['value_distribution']
        cur_dist = current_data.value_counts(normalize=True).to_dict()

        # Calculate JS divergence
        all_categories = set(ref_dist.keys()) | set(cur_dist.keys())
        ref_probs = np.array([ref_dist.get(c, 1e-10) for c in all_categories])
        cur_probs = np.array([cur_dist.get(c, 1e-10) for c in all_categories])

        # Normalize
        ref_probs = ref_probs / ref_probs.sum()
        cur_probs = cur_probs / cur_probs.sum()

        # JS divergence
        m = 0.5 * (ref_probs + cur_probs)
        js_div = 0.5 * stats.entropy(ref_probs, m) + 0.5 * stats.entropy(cur_probs, m)

        # Chi-squared test for p-value
        ref_counts = self.reference_data[feature].value_counts()
        cur_counts = current_data.value_counts()

        # Align
        all_cats = set(ref_counts.index) | set(cur_counts.index)
        ref_aligned = np.array([ref_counts.get(c, 0) for c in all_cats])
        cur_aligned = np.array([cur_counts.get(c, 0) for c in all_cats])

        expected = ref_aligned / ref_aligned.sum() * cur_aligned.sum()
        # Avoid zero expected counts, then rescale so observed and expected
        # totals still match (scipy's chisquare requires equal sums)
        expected = np.where(expected < 1, 1, expected)
        expected = expected * cur_aligned.sum() / expected.sum()

        try:
            _, p_value = stats.chisquare(cur_aligned, expected)
        except ValueError:
            p_value = 1.0

        current_stats = {
            'mode': current_data.mode().iloc[0] if len(current_data.mode()) > 0 else None,
            'unique_count': current_data.nunique(),
            'value_distribution': cur_dist,
            'null_rate': current_data.isnull().mean()
        }

        return FeatureDriftResult(
            feature_name=feature,
            drift_score=js_div,
            p_value=p_value,
            drift_detected=js_div > self.drift_threshold or p_value < 0.05,
            reference_stats=self.feature_stats[feature],
            current_stats=current_stats,
            method='jensen_shannon'
        )

    def get_drift_report(self, current_data: pd.DataFrame) -> pd.DataFrame:
        """Generate comprehensive drift report"""
        results = self.check_feature_drift(current_data)

        report_data = []
        for result in results:
            report_data.append({
                'feature': result.feature_name,
                'drift_score': result.drift_score,
                'p_value': result.p_value,
                'drift_detected': result.drift_detected,
                'method': result.method
            })

        return pd.DataFrame(report_data).sort_values('drift_score', ascending=False)
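
A minimal usage sketch for the base monitor, assuming X_train (reference window) and X_production (current window) are pandas DataFrames with matching columns; the same monitor instance is reused by the visualization example further down:

# Assumed: X_train and X_production are DataFrames with matching columns
monitor = FeatureDriftMonitor(reference_data=X_train, drift_threshold=0.1)

report = monitor.get_drift_report(X_production)
print(report.head(10))  # features sorted by drift score, highest first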

Feature Importance-Weighted Drift

class ImportanceWeightedDriftMonitor(FeatureDriftMonitor):
    """Weight drift detection by feature importance"""

    def __init__(self, reference_data: pd.DataFrame, feature_importance: Dict[str, float],
                 drift_threshold: float = 0.1):
        super().__init__(reference_data, drift_threshold)
        self.feature_importance = feature_importance

    def get_weighted_drift_score(self, current_data: pd.DataFrame) -> float:
        """Calculate importance-weighted drift score"""
        results = self.check_feature_drift(current_data)

        weighted_sum = 0.0
        total_importance = 0.0

        for result in results:
            importance = self.feature_importance.get(result.feature_name, 0.0)
            weighted_sum += result.drift_score * importance
            total_importance += importance

        if total_importance == 0:
            return 0.0

        return weighted_sum / total_importance

    def get_critical_drifts(self, current_data: pd.DataFrame,
                            importance_threshold: float = 0.05) -> List[FeatureDriftResult]:
        """Get drift results for important features"""
        results = self.check_feature_drift(current_data)

        critical_results = []
        for result in results:
            importance = self.feature_importance.get(result.feature_name, 0.0)
            if result.drift_detected and importance >= importance_threshold:
                critical_results.append(result)

        return sorted(critical_results,
                      key=lambda x: self.feature_importance.get(x.feature_name, 0),
                      reverse=True)

# Usage
from sklearn.ensemble import RandomForestClassifier

# Get feature importance from a trained model
# (X_train, y_train, and X_production are assumed to be defined upstream)
model = RandomForestClassifier()
model.fit(X_train, y_train)
importance_dict = dict(zip(X_train.columns, model.feature_importances_))

# Create weighted monitor
weighted_monitor = ImportanceWeightedDriftMonitor(
    reference_data=X_train,
    feature_importance=importance_dict
)

# Check drift
weighted_score = weighted_monitor.get_weighted_drift_score(X_production)
critical_drifts = weighted_monitor.get_critical_drifts(X_production)

print(f"Weighted drift score: {weighted_score:.4f}")
print(f"Critical drifts: {len(critical_drifts)}")

Visualization

import matplotlib.pyplot as plt

def visualize_feature_drift(monitor: FeatureDriftMonitor,
                            current_data: pd.DataFrame,
                            top_n: int = 10):
    """Visualize feature drift"""
    report = monitor.get_drift_report(current_data)
    top_features = report.head(top_n)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Drift scores bar chart
    colors = ['red' if d else 'green' for d in top_features['drift_detected']]
    axes[0, 0].barh(top_features['feature'], top_features['drift_score'], color=colors)
    axes[0, 0].set_xlabel('Drift Score')
    axes[0, 0].set_title('Feature Drift Scores')
    axes[0, 0].axvline(x=monitor.drift_threshold, color='black', linestyle='--', label='Threshold')
    axes[0, 0].legend()

    # P-values (clipped to avoid -log10(0) when a test underflows)
    axes[0, 1].barh(top_features['feature'], -np.log10(top_features['p_value'].clip(lower=1e-300)))
    axes[0, 1].set_xlabel('-log10(p-value)')
    axes[0, 1].set_title('Statistical Significance')
    axes[0, 1].axvline(x=-np.log10(0.05), color='red', linestyle='--', label='p=0.05')
    axes[0, 1].legend()

    # Distribution comparison for the top drifted feature
    # (histograms only make sense for numerical features)
    top_feature = top_features.iloc[0]['feature']
    if monitor.feature_stats[top_feature]['type'] == 'numerical':
        ref_data = monitor.reference_data[top_feature].dropna()
        cur_data = current_data[top_feature].dropna()

        axes[1, 0].hist(ref_data, bins=30, alpha=0.5, label='Reference', density=True)
        axes[1, 0].hist(cur_data, bins=30, alpha=0.5, label='Current', density=True)
        axes[1, 0].legend()
    axes[1, 0].set_title(f'Distribution: {top_feature}')

    # Summary statistics
    # Summary statistics (computed over the full report, not just top_n)
    drifted_count = report['drift_detected'].sum()
    summary_text = f"Total features: {len(report)}\n"
    summary_text += f"Features with drift: {drifted_count}\n"
    summary_text += f"Drift rate: {drifted_count/len(report)*100:.1f}%\n\n"
    summary_text += "Top drifted features:\n"
    for _, row in top_features[top_features['drift_detected']].head(5).iterrows():
        summary_text += f"  - {row['feature']}: {row['drift_score']:.4f}\n"

    axes[1, 1].text(0.1, 0.5, summary_text, fontsize=10, verticalalignment='center',
                    family='monospace', transform=axes[1, 1].transAxes)
    axes[1, 1].axis('off')
    axes[1, 1].set_title('Summary')

    plt.tight_layout()
    return fig

# Usage (reuses the monitor and X_production from the earlier examples)
fig = visualize_feature_drift(monitor, X_production)
fig.savefig("feature_drift_report.png")

Alerting System

from datetime import datetime

class FeatureDriftAlertManager:
    """Manage alerts for feature drift"""

    # Rank severities so minimum-severity filtering compares levels, not strings
    SEVERITY_LEVELS = {'info': 0, 'warning': 1, 'high': 2, 'critical': 3}

    def __init__(self, monitor: FeatureDriftMonitor, alert_config: Dict):
        self.monitor = monitor
        self.config = alert_config
        self.alert_history = []

    def check_and_alert(self, current_data: pd.DataFrame) -> List[Dict]:
        """Check for drift and generate alerts"""
        results = self.monitor.check_feature_drift(current_data)
        alerts = []

        for result in results:
            if result.drift_detected:
                severity = self._determine_severity(result)

                # Compare severity ranks; plain string comparison would sort
                # 'critical' below 'warning' alphabetically
                min_severity = self.config.get('min_severity', 'warning')
                if self.SEVERITY_LEVELS[severity] >= self.SEVERITY_LEVELS[min_severity]:
                    alert = {
                        'timestamp': datetime.utcnow().isoformat(),
                        'feature': result.feature_name,
                        'drift_score': result.drift_score,
                        'severity': severity,
                        'message': self._generate_message(result)
                    }
                    alerts.append(alert)
                    self.alert_history.append(alert)

        return alerts

    def _determine_severity(self, result: FeatureDriftResult) -> str:
        """Determine alert severity"""
        if result.drift_score > 0.3:
            return 'critical'
        elif result.drift_score > 0.2:
            return 'high'
        elif result.drift_score > 0.1:
            return 'warning'
        return 'info'

    def _generate_message(self, result: FeatureDriftResult) -> str:
        """Generate alert message"""
        return (f"Feature drift detected in '{result.feature_name}'. "
                f"Drift score: {result.drift_score:.4f}, "
                f"Method: {result.method}")

Feature-level drift monitoring enables targeted investigation and remediation of model issues.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.