Model Risk Management for Large Language Models
LLMs present unique model risk challenges. They are not just ML models; they are complex systems that require specialized risk management across hallucination, misuse, privacy, and output quality.
LLM-Specific Risk Categories
from dataclasses import dataclass, field
from typing import List, Dict, Optional
from enum import Enum
from datetime import datetime

class LLMRiskType(Enum):
    HALLUCINATION = "Hallucination"
    JAILBREAK = "Jailbreak/Misuse"
    BIAS = "Bias & Fairness"
    PRIVACY = "Privacy Leakage"
    TOXICITY = "Toxic Content"
    INCONSISTENCY = "Output Inconsistency"
    OUTDATED = "Knowledge Cutoff"
    OVERRELIANCE = "User Overreliance"

@dataclass
class LLMRiskProfile:
    model_name: str
    deployment_context: str
    risk_scores: Dict[LLMRiskType, int]
    mitigations_implemented: List[str]
    last_assessed: datetime

    def get_overall_risk(self) -> str:
        avg_score = sum(self.risk_scores.values()) / len(self.risk_scores)
        if avg_score >= 4:
            return "Critical"
        elif avg_score >= 3:
            return "High"
        elif avg_score >= 2:
            return "Medium"
        return "Low"

# Example risk profile
customer_service_llm = LLMRiskProfile(
    model_name="GPT-4 Turbo",
    deployment_context="Customer Service Chatbot",
    risk_scores={
        LLMRiskType.HALLUCINATION: 3,
        LLMRiskType.JAILBREAK: 2,
        LLMRiskType.BIAS: 2,
        LLMRiskType.PRIVACY: 4,
        LLMRiskType.TOXICITY: 2,
        LLMRiskType.INCONSISTENCY: 3,
        LLMRiskType.OUTDATED: 2,
        LLMRiskType.OVERRELIANCE: 3
    },
    mitigations_implemented=[
        "RAG grounding",
        "Content filtering",
        "Human escalation path"
    ],
    last_assessed=datetime.now()
)
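A quick sketch of how the profile can be used in practice; the expected result follows directly from the illustrative scores above:

# Illustrative usage of the risk profile defined above
overall = customer_service_llm.get_overall_risk()
print(f"{customer_service_llm.deployment_context}: {overall} overall risk")
# With the scores above (average 2.625), this prints "Medium"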
Hallucination Risk Management
class HallucinationManager:
    """Manage hallucination risk in LLM systems."""

    def __init__(self):
        self.detection_results: List[Dict] = []

    def detect_hallucination(
        self,
        response: str,
        context: str,
        method: str = "nli"
    ) -> Dict:
        """Detect potential hallucinations."""
        result = {
            "response": response[:100] + "...",
            "detection_method": method,
            "timestamp": datetime.now()
        }
        if method == "nli":
            # Natural Language Inference approach
            score = self._nli_check(response, context)
            result["nli_score"] = score
            result["hallucination_risk"] = "High" if score < 0.5 else "Low"
        elif method == "entity_verification":
            # Check named entities against knowledge base
            entities = self._extract_entities(response)
            verified = self._verify_entities(entities)
            result["entities_found"] = len(entities)
            result["entities_verified"] = verified
            result["hallucination_risk"] = "High" if verified < 0.7 else "Low"
        self.detection_results.append(result)
        return result

    def _nli_check(self, response: str, context: str) -> float:
        """Check if response is entailed by context."""
        # Placeholder for NLI model inference
        return 0.75

    def _extract_entities(self, text: str) -> List[str]:
        """Extract named entities from text."""
        # Placeholder for NER
        return ["entity1", "entity2"]

    def _verify_entities(self, entities: List[str]) -> float:
        """Verify entities against knowledge base."""
        # Placeholder for verification
        return 0.8

    def implement_rag_grounding(self, config: Dict) -> Dict:
        """Configure RAG to reduce hallucinations."""
        rag_config = {
            "retrieval_settings": {
                "top_k": config.get("top_k", 5),
                "similarity_threshold": config.get("threshold", 0.7),
                "reranking_enabled": True
            },
            "grounding_instructions": """
                Answer ONLY based on the provided context.
                If the context doesn't contain the answer, say:
                "I don't have information about that in my knowledge base."
                Always cite your sources.
            """,
            "citation_format": "[Source: {document_name}]",
            "confidence_scoring": True
        }
        return rag_config

    def get_hallucination_metrics(self) -> Dict:
        """Calculate hallucination metrics over time."""
        if not self.detection_results:
            return {}
        high_risk_count = sum(
            1 for r in self.detection_results
            if r.get("hallucination_risk") == "High"
        )
        return {
            "total_checked": len(self.detection_results),
            "high_risk_count": high_risk_count,
            "hallucination_rate": high_risk_count / len(self.detection_results)
        }
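A minimal usage sketch; in production the placeholder NLI and NER checks above would be swapped for real models, so the scores here are illustrative:

# Illustrative usage (scores come from the placeholder checks above)
manager = HallucinationManager()
check = manager.detect_hallucination(
    response="Your order ships within 2 business days.",
    context="Standard orders ship within 2 business days of purchase.",
    method="nli"
)
print(check["hallucination_risk"])          # "Low" with the placeholder score of 0.75
print(manager.get_hallucination_metrics())  # {'total_checked': 1, 'high_risk_count': 0, 'hallucination_rate': 0.0}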
Output Validation Framework
class OutputValidator:
    """Validate LLM outputs for various risk factors."""

    def __init__(self):
        self.validators: List[Dict] = []

    def add_validator(self, name: str, validator_func: callable, priority: int = 5):
        """Add a validation function."""
        self.validators.append({
            "name": name,
            "function": validator_func,
            "priority": priority
        })
        self.validators.sort(key=lambda x: x["priority"])

    def validate(self, output: str, context: Dict = None) -> Dict:
        """Run all validators on output."""
        results = {
            "output": output[:200],
            "passed": True,
            "validations": []
        }
        for validator in self.validators:
            try:
                is_valid, details = validator["function"](output, context)
                results["validations"].append({
                    "name": validator["name"],
                    "passed": is_valid,
                    "details": details
                })
                if not is_valid:
                    results["passed"] = False
            except Exception as e:
                results["validations"].append({
                    "name": validator["name"],
                    "passed": False,
                    "error": str(e)
                })
                results["passed"] = False
        return results

# Create validators
validator = OutputValidator()

# PII detection validator
def pii_validator(output: str, context: Dict) -> tuple:
    """Check for PII in output."""
    import re
    pii_patterns = [
        r'\b\d{3}-\d{2}-\d{4}\b',    # SSN
        r'\b\d{16}\b',               # Credit card
        r'\b[\w.-]+@[\w.-]+\.\w+\b'  # Email
    ]
    found = []
    for pattern in pii_patterns:
        if re.search(pattern, output):
            found.append(pattern)
    return len(found) == 0, {"patterns_found": found}

validator.add_validator("PII Detection", pii_validator, priority=1)

# Length validator
def length_validator(output: str, context: Dict) -> tuple:
    """Ensure output is within acceptable length."""
    max_length = context.get("max_length", 2000) if context else 2000
    return len(output) <= max_length, {"length": len(output), "max": max_length}

validator.add_validator("Length Check", length_validator, priority=5)

# Toxicity validator
def toxicity_validator(output: str, context: Dict) -> tuple:
    """Check for toxic content."""
    # Placeholder - would use actual toxicity model
    toxic_words = ["harmful", "dangerous"]  # Simplified
    found = [w for w in toxic_words if w.lower() in output.lower()]
    return len(found) == 0, {"toxic_content": found}

validator.add_validator("Toxicity Check", toxicity_validator, priority=2)
Model Monitoring and Observability
class LLMObservability:
    """Comprehensive LLM monitoring and observability."""

    def __init__(self, model_id: str):
        self.model_id = model_id
        self.logs: List[Dict] = []
        self.metrics: Dict[str, List] = {
            "latency": [],
            "token_usage": [],
            "error_rate": [],
            "validation_pass_rate": []
        }

    def log_interaction(
        self,
        request_id: str,
        prompt: str,
        response: str,
        latency_ms: float,
        tokens_used: int,
        validation_result: Dict
    ):
        """Log a single interaction."""
        log_entry = {
            "request_id": request_id,
            "timestamp": datetime.now(),
            "prompt_preview": prompt[:100],
            "response_preview": response[:100],
            "latency_ms": latency_ms,
            "tokens_used": tokens_used,
            "validation_passed": validation_result.get("passed", False)
        }
        self.logs.append(log_entry)
        self.metrics["latency"].append(latency_ms)
        self.metrics["token_usage"].append(tokens_used)
        self.metrics["validation_pass_rate"].append(
            1 if validation_result.get("passed") else 0
        )

    def get_health_dashboard(self) -> Dict:
        """Generate health dashboard data."""
        def avg(lst):
            return sum(lst) / len(lst) if lst else 0

        recent_latency = self.metrics["latency"][-100:]  # Last 100 interactions
        return {
            "model_id": self.model_id,
            "total_interactions": len(self.logs),
            "metrics": {
                "avg_latency_ms": avg(recent_latency),
                "p95_latency_ms": sorted(recent_latency)[int(0.95 * (len(recent_latency) - 1))] if recent_latency else None,
                "avg_tokens": avg(self.metrics["token_usage"][-100:]),
                "validation_pass_rate": avg(self.metrics["validation_pass_rate"][-100:])
            },
            "alerts": self._check_alerts()
        }

    def _check_alerts(self) -> List[str]:
        """Check for alert conditions."""
        alerts = []
        # High latency alert
        recent_latency = self.metrics["latency"][-10:]
        if recent_latency and sum(recent_latency) / len(recent_latency) > 5000:
            alerts.append("High average latency detected")
        # Low validation pass rate
        recent_validation = self.metrics["validation_pass_rate"][-50:]
        if recent_validation and sum(recent_validation) / len(recent_validation) < 0.9:
            alerts.append("Validation pass rate below threshold")
        return alerts
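A sketch of how the observability layer could be wired to the OutputValidator from the previous section; the request ID, latency, and token count here are purely illustrative:

# Illustrative wiring of validation and observability
observability = LLMObservability(model_id="customer-service-llm")
response_text = "You can return items within 30 days of delivery."
validation = validator.validate(response_text)  # reuses the validator built above
observability.log_interaction(
    request_id="req-001",            # hypothetical request ID
    prompt="What is the return policy?",
    response=response_text,
    latency_ms=850.0,                # illustrative latency
    tokens_used=64,                  # illustrative token count
    validation_result=validation
)
print(observability.get_health_dashboard())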
Tomorrow, we’ll explore operational risk in AI systems!