
Model Risk Management for Large Language Models

LLMs present unique model risk challenges. They're not just ML models; they're complex systems that require specialized risk management approaches.

LLM-Specific Risk Categories

from dataclasses import dataclass, field
from typing import List, Dict, Optional, Callable
from enum import Enum
from datetime import datetime

class LLMRiskType(Enum):
    HALLUCINATION = "Hallucination"
    JAILBREAK = "Jailbreak/Misuse"
    BIAS = "Bias & Fairness"
    PRIVACY = "Privacy Leakage"
    TOXICITY = "Toxic Content"
    INCONSISTENCY = "Output Inconsistency"
    OUTDATED = "Knowledge Cutoff"
    OVERRELIANCE = "User Overreliance"

@dataclass
class LLMRiskProfile:
    model_name: str
    deployment_context: str
    risk_scores: Dict[LLMRiskType, int]
    mitigations_implemented: List[str]
    last_assessed: datetime

    def get_overall_risk(self) -> str:
        avg_score = sum(self.risk_scores.values()) / len(self.risk_scores)
        if avg_score >= 4:
            return "Critical"
        elif avg_score >= 3:
            return "High"
        elif avg_score >= 2:
            return "Medium"
        return "Low"

# Example risk profile
customer_service_llm = LLMRiskProfile(
    model_name="GPT-4 Turbo",
    deployment_context="Customer Service Chatbot",
    risk_scores={
        LLMRiskType.HALLUCINATION: 3,
        LLMRiskType.JAILBREAK: 2,
        LLMRiskType.BIAS: 2,
        LLMRiskType.PRIVACY: 4,
        LLMRiskType.TOXICITY: 2,
        LLMRiskType.INCONSISTENCY: 3,
        LLMRiskType.OUTDATED: 2,
        LLMRiskType.OVERRELIANCE: 3
    },
    mitigations_implemented=[
        "RAG grounding",
        "Content filtering",
        "Human escalation path"
    ],
    last_assessed=datetime.now()
)
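
Given the scores in the example, the overall rating follows directly from get_overall_risk, which averages the per-category scores. A quick check:

# Aggregate the example profile into a single rating
print(customer_service_llm.get_overall_risk())
# (3+2+2+4+2+3+2+3) / 8 = 2.625, which maps to "Medium"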

Hallucination Risk Management

class HallucinationManager:
    """Manage hallucination risk in LLM systems."""

    def __init__(self):
        self.detection_results: List[Dict] = []

    def detect_hallucination(
        self,
        response: str,
        context: str,
        method: str = "nli"
    ) -> Dict:
        """Detect potential hallucinations."""
        result = {
            "response": response[:100] + "...",
            "detection_method": method,
            "timestamp": datetime.now()
        }

        if method == "nli":
            # Natural Language Inference approach
            score = self._nli_check(response, context)
            result["nli_score"] = score
            result["hallucination_risk"] = "High" if score < 0.5 else "Low"

        elif method == "entity_verification":
            # Check named entities against knowledge base
            entities = self._extract_entities(response)
            verified = self._verify_entities(entities)
            result["entities_found"] = len(entities)
            result["entities_verified"] = verified
            result["hallucination_risk"] = "High" if verified < 0.7 else "Low"

        self.detection_results.append(result)
        return result

    def _nli_check(self, response: str, context: str) -> float:
        """Check if response is entailed by context."""
        # Placeholder for NLI model inference
        return 0.75

    def _extract_entities(self, text: str) -> List[str]:
        """Extract named entities from text."""
        # Placeholder for NER
        return ["entity1", "entity2"]

    def _verify_entities(self, entities: List[str]) -> float:
        """Verify entities against knowledge base."""
        # Placeholder for verification
        return 0.8

    def implement_rag_grounding(self, config: Dict) -> Dict:
        """Configure RAG to reduce hallucinations."""
        rag_config = {
            "retrieval_settings": {
                "top_k": config.get("top_k", 5),
                "similarity_threshold": config.get("threshold", 0.7),
                "reranking_enabled": True
            },
            "grounding_instructions": """
                Answer ONLY based on the provided context.
                If the context doesn't contain the answer, say:
                "I don't have information about that in my knowledge base."
                Always cite your sources.
            """,
            "citation_format": "[Source: {document_name}]",
            "confidence_scoring": True
        }
        return rag_config

    def get_hallucination_metrics(self) -> Dict:
        """Calculate hallucination metrics over time."""
        if not self.detection_results:
            return {}

        high_risk_count = sum(
            1 for r in self.detection_results
            if r.get("hallucination_risk") == "High"
        )

        return {
            "total_checked": len(self.detection_results),
            "high_risk_count": high_risk_count,
            "hallucination_rate": high_risk_count / len(self.detection_results)
        }
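
Here's a minimal sketch of how the manager could sit inside a response pipeline. The retrieved_context and llm_response strings are invented placeholders, and the reported risk simply reflects the stub _nli_check score above:

# Hypothetical wiring of HallucinationManager into a RAG response flow
manager = HallucinationManager()

retrieved_context = "The refund policy allows returns within 30 days of purchase."
llm_response = "You can return items within 30 days for a full refund."

check = manager.detect_hallucination(llm_response, retrieved_context, method="nli")
if check["hallucination_risk"] == "High":
    pass  # e.g. fall back to a templated answer or escalate to a human

print(manager.get_hallucination_metrics())
# {'total_checked': 1, 'high_risk_count': 0, 'hallucination_rate': 0.0}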

Output Validation Framework

class OutputValidator:
    """Validate LLM outputs for various risk factors."""

    def __init__(self):
        self.validators: List[Dict] = []

    def add_validator(self, name: str, validator_func: Callable, priority: int = 5):
        """Add a validation function."""
        self.validators.append({
            "name": name,
            "function": validator_func,
            "priority": priority
        })
        self.validators.sort(key=lambda x: x["priority"])

    def validate(self, output: str, context: Optional[Dict] = None) -> Dict:
        """Run all validators on output."""
        results = {
            "output": output[:200],
            "passed": True,
            "validations": []
        }

        for validator in self.validators:
            try:
                is_valid, details = validator["function"](output, context)
                results["validations"].append({
                    "name": validator["name"],
                    "passed": is_valid,
                    "details": details
                })
                if not is_valid:
                    results["passed"] = False
            except Exception as e:
                results["validations"].append({
                    "name": validator["name"],
                    "passed": False,
                    "error": str(e)
                })
                results["passed"] = False

        return results

# Create validators
validator = OutputValidator()

# PII detection validator
import re

def pii_validator(output: str, context: Dict) -> tuple:
    """Check for PII in output."""
    pii_patterns = [
        r'\b\d{3}-\d{2}-\d{4}\b',    # US Social Security number
        r'\b\d{16}\b',               # 16-digit card number
        r'\b[\w.-]+@[\w.-]+\.\w+\b'  # Email address
    ]
    found = [p for p in pii_patterns if re.search(p, output)]
    return len(found) == 0, {"patterns_found": found}

validator.add_validator("PII Detection", pii_validator, priority=1)

# Length validator
def length_validator(output: str, context: Dict) -> tuple:
    """Ensure output is within acceptable length."""
    max_length = context.get("max_length", 2000) if context else 2000
    return len(output) <= max_length, {"length": len(output), "max": max_length}

validator.add_validator("Length Check", length_validator, priority=5)

# Toxicity validator
def toxicity_validator(output: str, context: Dict) -> tuple:
    """Check for toxic content."""
    # Placeholder - would use actual toxicity model
    toxic_words = ["harmful", "dangerous"]  # Simplified
    found = [w for w in toxic_words if w.lower() in output.lower()]
    return len(found) == 0, {"toxic_content": found}

validator.add_validator("Toxicity Check", toxicity_validator, priority=2)

Model Monitoring and Observability

class LLMObservability:
    """Comprehensive LLM monitoring and observability."""

    def __init__(self, model_id: str):
        self.model_id = model_id
        self.logs: List[Dict] = []
        self.metrics: Dict[str, List] = {
            "latency": [],
            "token_usage": [],
            "error_rate": [],
            "validation_pass_rate": []
        }

    def log_interaction(
        self,
        request_id: str,
        prompt: str,
        response: str,
        latency_ms: float,
        tokens_used: int,
        validation_result: Dict
    ):
        """Log a single interaction."""
        log_entry = {
            "request_id": request_id,
            "timestamp": datetime.now(),
            "prompt_preview": prompt[:100],
            "response_preview": response[:100],
            "latency_ms": latency_ms,
            "tokens_used": tokens_used,
            "validation_passed": validation_result.get("passed", False)
        }

        self.logs.append(log_entry)
        self.metrics["latency"].append(latency_ms)
        self.metrics["token_usage"].append(tokens_used)
        self.metrics["validation_pass_rate"].append(
            1 if validation_result.get("passed") else 0
        )

    def get_health_dashboard(self) -> Dict:
        """Generate health dashboard data over the last 100 interactions."""
        def avg(lst):
            return sum(lst) / len(lst) if lst else 0

        recent_latency = sorted(self.metrics["latency"][-100:])
        p95_latency = (
            recent_latency[int(len(recent_latency) * 0.95)]
            if recent_latency else None
        )

        return {
            "model_id": self.model_id,
            "total_interactions": len(self.logs),
            "metrics": {
                "avg_latency_ms": avg(self.metrics["latency"][-100:]),
                "p95_latency_ms": p95_latency,
                "avg_tokens": avg(self.metrics["token_usage"][-100:]),
                "validation_pass_rate": avg(self.metrics["validation_pass_rate"][-100:])
            },
            "alerts": self._check_alerts()
        }

    def _check_alerts(self) -> List[str]:
        """Check for alert conditions."""
        alerts = []

        # High latency alert
        recent_latency = self.metrics["latency"][-10:]
        if recent_latency and sum(recent_latency) / len(recent_latency) > 5000:
            alerts.append("High average latency detected")

        # Low validation pass rate
        recent_validation = self.metrics["validation_pass_rate"][-50:]
        if recent_validation and sum(recent_validation) / len(recent_validation) < 0.9:
            alerts.append("Validation pass rate below threshold")

        return alerts
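
Tying the pieces together, each request can be logged with its validation result and the dashboard queried on demand. The request ID, latency, and token count below are invented values for illustration, reusing the validator instance from the previous section:

# Hypothetical end-to-end logging of one interaction
observability = LLMObservability(model_id="customer-service-llm-v1")

response_text = "Your order has shipped and will arrive on Friday."
validation_result = validator.validate(response_text)

observability.log_interaction(
    request_id="req-0001",   # invented identifier
    prompt="Where is my order?",
    response=response_text,
    latency_ms=850.0,        # invented latency
    tokens_used=210,         # invented token count
    validation_result=validation_result
)

print(observability.get_health_dashboard())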

Tomorrow, we’ll explore operational risk in AI systems!

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.