Back to Blog
8 min read

Input Validation for LLM Applications

Introduction

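Input validation is the first line of defense for LLM applications. Proper validation reduces the risk of prompt-injection attacks, improves data quality, and helps protect against abuse. Pattern-based checks are heuristics that a determined attacker can evade, so they work best as one layer among several rather than as a complete defense. This post covers layered validation strategies for production systems.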

Input Validation Architecture

from dataclasses import dataclass
from typing import List, Dict, Optional, Callable
from enum import Enum
import re

class ValidationLevel(Enum):
    """Strictness setting for ``InputValidator``.

    Each member selects a risk threshold in ``InputValidator._get_threshold``;
    the stricter the level, the lower the accumulated risk it tolerates.
    """

    STRICT = "strict"          # rejects at the lowest risk score
    MODERATE = "moderate"      # InputValidator's default level
    PERMISSIVE = "permissive"  # tolerates the highest risk score

@dataclass
class ValidationResult:
    """Outcome of running a validator over one piece of user input."""

    # Whether the input passed the validator's acceptance rule.
    valid: bool
    # The (possibly cleaned) input text to use downstream.
    sanitized_input: Optional[str]
    # One dict per finding, with keys such as "type", "message", "severity".
    issues: List[Dict]
    # Aggregate risk reported by the validator.
    risk_score: float

class InputValidator:
    """Comprehensive input validation for LLM applications.

    Runs a fixed chain of checks (length, encoding, characters, structure,
    content). Each check returns the issues it found plus a risk
    contribution. The contributions are summed and compared with the
    threshold for the configured ``ValidationLevel``. If any issue is found,
    the input is also passed through ``_sanitize``.

    The structure and content checks are pattern heuristics. They catch
    naive prompt-injection attempts but can be evaded (paraphrase, other
    languages, obfuscation), so treat them as one defensive layer.
    """

    # Control characters except tab (0x09), newline (0x0a) and CR (0x0d).
    _CONTROL_CHARS = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]')

    # Chat-template / instruction delimiters that user text should not carry.
    # Kept as strings because each issue records its pattern and
    # ``_sanitize`` reuses it to filter the match.
    _SUSPICIOUS_DELIMITERS = (
        r'<\|.*?\|>',
        r'\[INST\]',
        r'\[\/INST\]',
        r'<<SYS>>',
        r'<\/s>',
        r'###',
    )

    # Prompt-injection heuristics: (pattern, issue type, risk contribution).
    # ``\s+`` also matches runs of spaces, tabs or newlines, which a literal
    # single space let through ("ignore\n previous"). The leading ``\b`` and
    # the trailing one on "you are now" avoid hits inside unrelated words
    # ("you are nowhere").
    _INJECTION_PATTERNS = (
        (re.compile(r'\bignore\s+(?:all\s+)?(?:previous|prior|above)', re.IGNORECASE),
         "instruction_override", 0.4),
        (re.compile(r'\byou\s+are\s+now\b', re.IGNORECASE), "role_change", 0.3),
        (re.compile(r'\bpretend\s+(?:to\s+be|you)', re.IGNORECASE), "roleplay_attempt", 0.3),
        (re.compile(r'\bsystem\s+prompt', re.IGNORECASE), "system_probe", 0.2),
    )

    def __init__(self, level: ValidationLevel = ValidationLevel.MODERATE):
        """Create a validator.

        Args:
            level: Strictness; selects the risk threshold used by ``validate``.
        """
        self.level = level
        self.validators: List[Callable] = []
        self._setup_default_validators()

    def _setup_default_validators(self):
        """Install the default chain of validation rules, in run order."""
        self.validators = [
            self._validate_length,
            self._validate_encoding,
            self._validate_characters,
            self._validate_structure,
            self._validate_content
        ]

    def validate(self, user_input: str) -> ValidationResult:
        """Run every registered validator on ``user_input``.

        Returns:
            A ``ValidationResult``. ``valid`` is True when the summed risk is
            below the level's threshold. ``risk_score`` is that sum capped at
            1.0. ``sanitized_input`` is the sanitized text when issues were
            found, otherwise the input unchanged.
        """
        issues = []
        risk_score = 0.0

        for validator in self.validators:
            result = validator(user_input)
            if result["issues"]:
                issues.extend(result["issues"])
                risk_score += result["risk_contribution"]

        sanitized = self._sanitize(user_input, issues) if issues else user_input

        valid = risk_score < self._get_threshold()

        return ValidationResult(
            valid=valid,
            sanitized_input=sanitized,
            issues=issues,
            risk_score=min(1.0, risk_score)
        )

    def _get_threshold(self) -> float:
        """Return the risk threshold for the configured level (lower is stricter)."""
        thresholds = {
            ValidationLevel.STRICT: 0.3,
            ValidationLevel.MODERATE: 0.5,
            ValidationLevel.PERMISSIVE: 0.7
        }
        return thresholds[self.level]

    def _validate_length(self, text: str) -> Dict:
        """Flag empty or whitespace-only input and input over 10,000 characters."""
        issues = []
        risk = 0.0

        min_length = 1
        max_length = 10000

        # Whitespace-only input carries no content, so treat it as empty.
        if len(text.strip()) < min_length:
            issues.append({
                "type": "too_short",
                "message": "Input is empty or too short",
                "severity": "medium"
            })
            risk += 0.3

        if len(text) > max_length:
            issues.append({
                "type": "too_long",
                "message": f"Input exceeds {max_length} characters",
                "severity": "medium"
            })
            risk += 0.2

        return {"issues": issues, "risk_contribution": risk}

    def _validate_encoding(self, text: str) -> Dict:
        """Flag text that cannot be encoded as UTF-8, and embedded null bytes."""
        issues = []
        risk = 0.0

        try:
            # A Python str fails to encode as UTF-8 only when it contains
            # lone surrogate code points.
            text.encode('utf-8')
        except UnicodeError:
            issues.append({
                "type": "encoding_error",
                "message": "Invalid UTF-8 encoding",
                "severity": "high"
            })
            risk = 0.4

        if '\x00' in text:
            issues.append({
                "type": "null_bytes",
                "message": "Input contains null bytes",
                "severity": "high"
            })
            risk += 0.3

        return {"issues": issues, "risk_contribution": risk}

    def _validate_characters(self, text: str) -> Dict:
        """Flag control characters and a high proportion of punctuation/symbols."""
        issues = []
        risk = 0.0

        control_chars = self._CONTROL_CHARS.findall(text)
        if control_chars:
            issues.append({
                "type": "control_characters",
                "message": f"Found {len(control_chars)} control characters",
                "severity": "medium"
            })
            risk = 0.2

        # Share of characters that are neither word characters nor whitespace.
        special_ratio = len(re.findall(r'[^\w\s]', text)) / max(len(text), 1)
        if special_ratio > 0.5:
            issues.append({
                "type": "excessive_special_chars",
                "message": f"High ratio of special characters ({special_ratio:.1%})",
                "severity": "low"
            })
            risk = 0.1

        return {"issues": issues, "risk_contribution": risk}

    def _validate_structure(self, text: str) -> Dict:
        """Flag chat-template / instruction delimiters embedded in user text."""
        issues = []
        risk = 0.0

        for pattern in self._SUSPICIOUS_DELIMITERS:
            if re.search(pattern, text):
                issues.append({
                    "type": "suspicious_delimiter",
                    "message": "Found suspicious delimiter pattern",
                    "pattern": pattern,
                    "severity": "high"
                })
                risk += 0.3

        return {"issues": issues, "risk_contribution": min(1.0, risk)}

    def _validate_content(self, text: str) -> Dict:
        """Flag phrases typical of prompt-injection attempts (case-insensitive)."""
        issues = []
        risk = 0.0

        for pattern, issue_type, risk_contrib in self._INJECTION_PATTERNS:
            if pattern.search(text):
                issues.append({
                    "type": issue_type,
                    "message": f"Detected potential {issue_type.replace('_', ' ')}",
                    "severity": "high" if risk_contrib > 0.3 else "medium"
                })
                risk += risk_contrib

        return {"issues": issues, "risk_contribution": min(1.0, risk)}

    def _sanitize(self, text: str, issues: List[Dict]) -> str:
        """Return ``text`` with the problems named in ``issues`` removed or filtered.

        Control characters and null bytes are stripped. Suspicious delimiters
        are replaced with ``[filtered]``. Other issue types are reported only.
        """
        sanitized = text

        for issue in issues:
            issue_type = issue["type"]
            if issue_type == "control_characters":
                sanitized = self._CONTROL_CHARS.sub('', sanitized)
            elif issue_type == "null_bytes":
                sanitized = sanitized.replace('\x00', '')
            elif issue_type == "suspicious_delimiter":
                pattern = issue.get("pattern")
                # An empty pattern would match between every character.
                if pattern:
                    sanitized = re.sub(pattern, "[filtered]", sanitized)

        return sanitized

Type-Specific Validators

from typing import Any
import json

class TypeValidator:
    """Validate specific input types: JSON documents, URLs and code snippets.

    Each public validator is a static method that returns a
    ``ValidationResult`` and reports malformed input as an issue instead of
    raising.
    """

    # Limits applied to JSON payloads.
    MAX_JSON_DEPTH = 10
    MAX_JSON_SIZE = 100000

    # Messages for the host classifications produced by _classify_url_host.
    _URL_HOST_MESSAGES = {
        "missing_host": "URL has no host",
        "localhost_url": "URL points to localhost",
        "internal_ip": "URL points to internal IP",
    }

    # (pattern, issue type, severity) per language. Word boundaries avoid
    # hits such as "import osmosis" or "reopen(", and ``(?:import|from)``
    # also catches "from subprocess import run".
    _DANGEROUS_CODE_PATTERNS = {
        "python": [
            (r'\b(?:import|from)\s+os\b', "os_import", "medium"),
            (r'\b(?:import|from)\s+subprocess\b', "subprocess_import", "high"),
            (r'\bexec\s*\(', "exec_call", "high"),
            (r'\beval\s*\(', "eval_call", "high"),
            (r'__import__', "dynamic_import", "high"),
            (r'\bopen\s*\(', "file_open", "medium"),
        ],
        "javascript": [
            (r'\beval\s*\(', "eval_call", "high"),
            (r'\bFunction\s*\(', "function_constructor", "high"),
            (r'\brequire\s*\([\'"]child_process', "child_process", "high"),
        ]
    }

    @staticmethod
    def validate_json(text: str) -> ValidationResult:
        """Validate JSON input.

        Payloads longer than ``MAX_JSON_SIZE`` characters are rejected before
        parsing. Parsed documents nested deeper than ``MAX_JSON_DEPTH`` are
        flagged.

        Returns:
            ``valid`` is True only when no issue was found. ``risk_score`` is
            0.5 when any issue was found, else 0.0.
        """
        issues = []

        # Check the size first so an oversized payload is never parsed.
        if len(text) > TypeValidator.MAX_JSON_SIZE:
            issues.append({
                "type": "excessive_size",
                "message": "JSON too large",
                "severity": "medium"
            })
        else:
            try:
                data = json.loads(text)
            except json.JSONDecodeError as e:
                issues.append({
                    "type": "invalid_json",
                    "message": f"Invalid JSON: {str(e)}",
                    "severity": "high"
                })
            except RecursionError:
                # CPython's json parser recurses once per nesting level, so
                # input like "[[[[..." raises RecursionError, not JSONDecodeError.
                issues.append({
                    "type": "excessive_depth",
                    "message": "JSON nesting too deep to parse",
                    "severity": "medium"
                })
            else:
                depth = TypeValidator._get_json_depth(data)
                if depth > TypeValidator.MAX_JSON_DEPTH:
                    issues.append({
                        "type": "excessive_depth",
                        "message": f"JSON depth ({depth}) exceeds limit",
                        "severity": "medium"
                    })

        return ValidationResult(
            valid=not issues,
            sanitized_input=text,
            issues=issues,
            risk_score=0.5 if issues else 0.0
        )

    @staticmethod
    def _get_json_depth(obj: Any, current_depth: int = 0) -> int:
        """Return how many non-empty containers enclose the deepest value of ``obj``.

        Scalars and empty containers add no depth, so ``7`` and ``{}`` give 0,
        ``{"a": 1}`` gives 1, and ``[[1]]`` gives 2 (offset by
        ``current_depth``). The traversal is iterative, so very deep documents
        cannot exhaust the Python call stack.
        """
        max_depth = current_depth
        stack = [(obj, current_depth)]
        while stack:
            node, depth = stack.pop()
            if isinstance(node, dict):
                children = list(node.values())
            elif isinstance(node, list):
                children = node
            else:
                children = []
            if children:
                stack.extend((child, depth + 1) for child in children)
            elif depth > max_depth:
                max_depth = depth
        return max_depth

    @staticmethod
    def _classify_url_host(host: Optional[str]) -> Optional[str]:
        """Return the issue type for a URL host that must not be fetched, or None.

        The check looks only at the literal host. A DNS name that *resolves*
        to an internal address is not detected here; resolve and re-check at
        fetch time if SSRF matters.
        """
        import ipaddress

        host = (host or "").rstrip(".").lower()
        if not host:
            return "missing_host"
        if host == "localhost" or host.endswith(".localhost"):
            return "localhost_url"
        try:
            ip = ipaddress.ip_address(host)
        except ValueError:
            return None  # an ordinary host name, not an IP literal
        # Unwrap ::ffff:a.b.c.d so IPv4 rules apply to the embedded address.
        if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
            ip = ip.ipv4_mapped
        if ip.is_loopback or ip.is_unspecified:
            return "localhost_url"
        if ip.is_private or ip.is_link_local or ip.is_reserved or ip.is_multicast:
            return "internal_ip"
        return None

    @staticmethod
    def validate_url(text: str) -> ValidationResult:
        """Validate URL input against non-web schemes and internal hosts.

        Only the parsed host is inspected, so a path such as ``/v10.2/`` no
        longer looks like a 10.x address. Every finding is high severity.
        ``valid`` is True when no high-severity issue exists, and
        ``risk_score`` is 0.7 when any issue was found.
        """
        import urllib.parse
        issues = []

        try:
            parsed = urllib.parse.urlparse(text)

            # file:, javascript:, gopher: etc. are classic SSRF/LFI vectors,
            # so an unexpected scheme must block the URL, not just warn.
            if parsed.scheme not in ('http', 'https'):
                issues.append({
                    "type": "invalid_scheme",
                    "message": f"Invalid URL scheme: {parsed.scheme}",
                    "severity": "high"
                })

            host_issue = TypeValidator._classify_url_host(parsed.hostname)
            if host_issue is not None:
                issues.append({
                    "type": host_issue,
                    "message": TypeValidator._URL_HOST_MESSAGES[host_issue],
                    "severity": "high"
                })

        except (ValueError, TypeError) as e:
            # urlparse raises ValueError on malformed input such as "http://[::1".
            issues.append({
                "type": "invalid_url",
                "message": str(e),
                "severity": "high"
            })

        return ValidationResult(
            valid=not any(i["severity"] == "high" for i in issues),
            sanitized_input=text,
            issues=issues,
            risk_score=0.7 if issues else 0.0
        )

    @staticmethod
    def validate_code(text: str, language: str = "python") -> ValidationResult:
        """Flag dangerous constructs in a code snippet.

        Args:
            text: Source code to scan.
            language: "python" or "javascript" (case-insensitive). Other
                languages have no patterns and always pass.

        Returns:
            ``valid`` is False if any high-severity pattern matched.
            ``risk_score`` is ``0.3 * high + 0.1 * all_issues``, capped at 1.0.
        """
        issues = []

        # Normalise the case so "Python" does not silently skip every check.
        patterns = TypeValidator._DANGEROUS_CODE_PATTERNS.get(language.lower(), [])

        for pattern, issue_type, severity in patterns:
            if re.search(pattern, text):
                issues.append({
                    "type": issue_type,
                    "message": "Detected potentially dangerous pattern",
                    "severity": severity
                })

        high_severity = sum(1 for i in issues if i["severity"] == "high")

        return ValidationResult(
            valid=high_severity == 0,
            sanitized_input=text,
            issues=issues,
            risk_score=min(1.0, high_severity * 0.3 + len(issues) * 0.1)
        )

Context-Aware Validation

class ContextAwareValidator:
    """Validate input based on application context.

    Runs the generic ``InputValidator`` and then the checks registered for
    ``context["app_type"]`` ("customer_service", "code_assistant" or
    "medical"). Any other value gets only the generic checks.
    """

    # Context issues with these severities make the input invalid. Lower
    # severities are still reported and raise the risk score. This matches
    # TypeValidator.validate_url/validate_code, which block only on "high".
    BLOCKING_SEVERITIES = frozenset({"high"})

    _PROFANITY_PATTERN = re.compile(r'\b(damn|hell|crap)\b', re.IGNORECASE)  # simplified example

    # The lookbehind skips member access such as "process.env" while still
    # matching ".env", "/app/.env" or ".env.local".
    _SENSITIVE_PATH_PATTERN = re.compile(
        r'(\/etc\/passwd|(?<![\w.])\.env\b|credentials)', re.IGNORECASE
    )

    _EMERGENCY_PATTERNS = (
        re.compile(r'(heart attack|stroke|can\'t breathe|overdose)', re.IGNORECASE),
        re.compile(r'(bleeding heavily|severe pain|unconscious)', re.IGNORECASE),
    )

    def __init__(self, context: Dict):
        """Create a validator for the given application context.

        Args:
            context: Application settings. ``"app_type"`` selects the extra
                checks. The optional ``"validation_level"`` (a
                ``ValidationLevel`` or its string value, default moderate)
                configures the base validator.
        """
        self.context = context
        level = ValidationLevel(context.get("validation_level", ValidationLevel.MODERATE))
        self.base_validator = InputValidator(level)

    def validate(self, user_input: str) -> ValidationResult:
        """Validate with context awareness.

        The result is valid when the base validation passes and no context
        issue has a blocking severity. Each context issue adds 0.1 to the
        base risk score, capped at 1.0.
        """
        base_result = self.base_validator.validate(user_input)
        context_issues = self._validate_for_context(user_input)

        blocking = any(
            issue.get("severity") in self.BLOCKING_SEVERITIES for issue in context_issues
        )

        return ValidationResult(
            valid=base_result.valid and not blocking,
            sanitized_input=base_result.sanitized_input,
            issues=base_result.issues + context_issues,
            risk_score=min(1.0, base_result.risk_score + len(context_issues) * 0.1)
        )

    def _validate_for_context(self, text: str) -> List[Dict]:
        """Run the checks registered for the configured app type (none by default)."""
        handlers = {
            "customer_service": self._validate_customer_service,
            "code_assistant": self._validate_code_assistant,
            "medical": self._validate_medical,
        }
        handler = handlers.get(self.context.get("app_type", "general"))
        return handler(text) if handler is not None else []

    def _validate_customer_service(self, text: str) -> List[Dict]:
        """Flag profanity (low severity: reported, not blocking)."""
        issues = []

        if self._PROFANITY_PATTERN.search(text):
            issues.append({
                "type": "profanity",
                "message": "Input contains profanity",
                "severity": "low"
            })

        return issues

    def _validate_code_assistant(self, text: str) -> List[Dict]:
        """Flag references to sensitive files or credentials (case-insensitive)."""
        issues = []

        if self._SENSITIVE_PATH_PATTERN.search(text):
            issues.append({
                "type": "sensitive_path",
                "message": "Reference to sensitive file paths",
                "severity": "high"
            })

        return issues

    def _validate_medical(self, text: str) -> List[Dict]:
        """Flag possible medical emergencies for escalation.

        At most one issue is produced, so matching several phrases does not
        inflate the risk score with duplicates.
        """
        issues = []

        if any(pattern.search(text) for pattern in self._EMERGENCY_PATTERNS):
            issues.append({
                "type": "potential_emergency",
                "message": "Input may indicate medical emergency",
                "severity": "high",
                "action": "escalate"
            })

        return issues

# Usage: validate a customer-service message with context-aware rules
context = {"app_type": "customer_service", "user_tier": "premium"}
validator = ContextAwareValidator(context)

result = validator.validate("I need help with my order, this is frustrating!")
print(f"Valid: {result.valid}", f"Issues: {result.issues}", sep="\n")

Rate Limiting and Abuse Prevention

from datetime import datetime, timedelta
from collections import defaultdict

class RateLimiter:
    """Sliding-window rate limiter keyed by user id.

    A user may make at most ``max_requests`` requests within any rolling
    window of ``window_seconds``. Request timestamps are held in memory per
    user id.
    """

    def __init__(
        self,
        max_requests: int = 100,
        window_seconds: int = 60,
        clock: Optional[Callable[[], float]] = None,
    ):
        """Create a limiter.

        Args:
            max_requests: Requests allowed per window. ``<= 0`` denies everything.
            window_seconds: Length of the sliding window in seconds.
            clock: Zero-argument function returning seconds from a monotonic
                source. Defaults to ``time.monotonic``; injectable for tests.
        """
        import time

        self.max_requests = max_requests
        self.window = timedelta(seconds=window_seconds)
        # A monotonic clock is immune to wall-clock jumps (NTP corrections,
        # DST changes in naive local time). Those jumps would otherwise
        # expire requests early or keep them counted far too long.
        self._clock = clock if clock is not None else time.monotonic
        self.requests: Dict[str, List[float]] = defaultdict(list)

    def check_rate_limit(self, user_id: str) -> Dict:
        """Check whether ``user_id`` may make a request now, recording it if so.

        Returns:
            When allowed: ``{"allowed": True, "current_count", "remaining"}``.
            When denied: ``{"allowed": False, "current_count", "reset_time"}``,
            where ``reset_time`` is a local ``datetime`` at which a slot frees up.
        """
        now = self._clock()
        window_s = self.window.total_seconds()
        cutoff = now - window_s

        # Keep only requests strictly newer than the window start.
        recent = [t for t in self.requests[user_id] if t > cutoff]
        self.requests[user_id] = recent
        current_count = len(recent)

        if current_count >= self.max_requests:
            # Wait until the oldest request leaves the window. With nothing
            # recorded (max_requests <= 0), report a full window instead of
            # failing on an empty sequence.
            wait = (min(recent) + window_s - now) if recent else window_s
            return {
                "allowed": False,
                "current_count": current_count,
                "reset_time": datetime.now() + timedelta(seconds=wait)
            }

        recent.append(now)

        return {
            "allowed": True,
            "current_count": current_count + 1,
            "remaining": self.max_requests - current_count - 1
        }

class InputValidationPipeline:
    """Complete input validation pipeline.

    Order of checks: the per-user rate limit, then ``InputValidator`` on every
    input, then a ``TypeValidator`` check selected by ``input_type``.
    """

    def __init__(self):
        self.validator = InputValidator()
        self.rate_limiter = RateLimiter()
        self.type_validator = TypeValidator()

    def process(
        self,
        user_input: str,
        user_id: str,
        input_type: str = "text"
    ) -> Dict:
        """Process input through the validation pipeline.

        Args:
            user_input: Raw text from the user.
            user_id: Key used for rate limiting.
            input_type: "json", "url" or "code" add a type-specific check.
                Any other value (e.g. "text") runs only the base validator.

        Returns:
            On rate-limit denial: ``success=False``, ``error`` and an ISO-8601
            ``reset_time``. Otherwise the dict holds:

            - ``success``: both checks passed.
            - ``sanitized_input``
            - ``issues``: the combined issue list.
            - ``risk_score``: the higher of the base and type-specific scores.
            - ``rate_limit_remaining``
        """
        rate_check = self.rate_limiter.check_rate_limit(user_id)
        if not rate_check["allowed"]:
            return {
                "success": False,
                "error": "rate_limit_exceeded",
                "reset_time": rate_check["reset_time"].isoformat()
            }

        base_result = self.validator.validate(user_input)

        type_checks = {
            "json": self.type_validator.validate_json,
            "url": self.type_validator.validate_url,
            "code": self.type_validator.validate_code,
        }
        type_check = type_checks.get(input_type)
        type_result = type_check(user_input) if type_check is not None else None

        # Copy instead of aliasing so base_result.issues is left untouched.
        all_issues = list(base_result.issues)
        risk_score = base_result.risk_score
        valid = base_result.valid
        if type_result is not None:
            all_issues.extend(type_result.issues)
            # Report the worse score so a risky URL/code/JSON input is not
            # shown as low-risk just because the generic checks passed.
            risk_score = max(risk_score, type_result.risk_score)
            valid = valid and type_result.valid

        return {
            "success": valid,
            "sanitized_input": base_result.sanitized_input,
            "issues": all_issues,
            "risk_score": risk_score,
            "rate_limit_remaining": rate_check["remaining"]
        }

# Usage: run a plain-text prompt through the full pipeline
pipeline = InputValidationPipeline()

result = pipeline.process(
    user_input="What is the capital of France?",
    user_id="user_123",
    input_type="text",
)

print(f"Success: {result['success']}", f"Risk Score: {result['risk_score']:.2f}", sep="\n")

Conclusion

Input validation is critical for LLM application security. A comprehensive approach includes length and encoding checks, structure validation, content analysis, type-specific validation, context awareness, and rate limiting. Layering these checks reduces exposure to injection attacks and abuse while improving data quality. Because pattern matching can be evaded, treat these checks as one layer of defense rather than a guarantee.

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.