Back to Blog
5 min read

Caching Strategies for LLM Applications

Caching is essential for reducing costs and latency in LLM applications. Today, I will cover caching strategies tailored for generative AI workloads.

Caching Challenges with LLMs

# Properties of LLM output that make naive response caching unreliable.
caching_challenges = dict(
    non_deterministic="Same input can produce different outputs",
    semantic_similarity="Similar prompts might need same response",
    context_dependence="Response depends on conversation history",
    freshness="Some queries need real-time information",
)

Exact Match Caching

import hashlib
import json
from datetime import datetime, timedelta, timezone

import redis

class ExactMatchCache:
    """Cache based on exact prompt match"""

    def __init__(self, redis_url: str, ttl_hours: int = 24):
        self.redis = redis.from_url(redis_url)
        self.ttl = ttl_hours * 3600

    def _generate_key(self, prompt: str, model: str, params: dict) -> str:
        """Generate cache key from request parameters"""
        key_data = {
            "prompt": prompt,
            "model": model,
            "temperature": params.get("temperature", 0.7),
            "max_tokens": params.get("max_tokens")
        }
        key_string = json.dumps(key_data, sort_keys=True)
        return f"llm_cache:{hashlib.sha256(key_string.encode()).hexdigest()}"

    def get(self, prompt: str, model: str, params: dict) -> dict | None:
        """Get cached response"""
        key = self._generate_key(prompt, model, params)
        data = self.redis.get(key)

        if data:
            cached = json.loads(data)
            # Update access time for LRU
            self.redis.expire(key, self.ttl)
            return cached

        return None

    def set(self, prompt: str, model: str, params: dict, response: dict):
        """Cache response"""
        key = self._generate_key(prompt, model, params)
        cache_data = {
            "response": response,
            "cached_at": datetime.utcnow().isoformat(),
            "model": model
        }
        self.redis.setex(key, self.ttl, json.dumps(cache_data))

    def invalidate(self, prompt: str, model: str, params: dict):
        """Remove from cache"""
        key = self._generate_key(prompt, model, params)
        self.redis.delete(key)

Semantic Caching

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

class SemanticCache:
    """Redis-backed cache that matches prompts by embedding similarity.

    Each entry stores the prompt, its embedding and the response as JSON
    under ``<cache_index>:<md5 of prompt>``. A lookup embeds the query and
    returns the stored entry whose cosine similarity to it is highest. Only
    entries with similarity strictly above ``similarity_threshold`` are
    considered.
    """

    def __init__(
        self,
        embedding_client,
        redis_url: str,
        similarity_threshold: float = 0.95,
        ttl_seconds: int = 86400,
        embedding_model: str = "text-embedding-ada-002",
    ):
        self.embedding_client = embedding_client
        self.redis = redis.from_url(redis_url)
        self.threshold = similarity_threshold
        self.cache_index = "semantic_cache"
        self.ttl_seconds = ttl_seconds  # entry lifetime; default is 24 hours
        self.embedding_model = embedding_model

    def _get_embedding(self, text: str) -> list:
        """Embed ``text`` with the configured embedding model."""
        response = self.embedding_client.embeddings.create(
            model=self.embedding_model,
            input=text,
        )
        return response.data[0].embedding

    def get(self, prompt: str) -> dict | None:
        """Return the most similar cached answer above the threshold, or ``None``.

        A hit is ``{"response", "similarity", "original_prompt"}``.
        """
        prompt_embedding = self._get_embedding(prompt)

        best_match = None
        best_similarity = 0

        # Linear scan over every entry; in production use a vector index.
        # SCAN walks the keyspace incrementally instead of blocking Redis
        # the way KEYS does.
        for key in self.redis.scan_iter(match=f"{self.cache_index}:*"):
            raw = self.redis.get(key)
            if raw is None:
                # The entry expired between the scan and this read.
                continue
            data = json.loads(raw)

            similarity = cosine_similarity(
                [prompt_embedding],
                [data["embedding"]],
            )[0][0]

            if similarity > self.threshold and similarity > best_similarity:
                best_similarity = similarity
                best_match = data

        if best_match is None:
            return None

        return {
            "response": best_match["response"],
            # Plain float instead of a numpy scalar, so callers can serialize it freely.
            "similarity": float(best_similarity),
            "original_prompt": best_match["prompt"],
        }

    def set(self, prompt: str, response: dict):
        """Embed ``prompt`` and store it with ``response`` for ``ttl_seconds``."""
        embedding = self._get_embedding(prompt)
        key = f"{self.cache_index}:{hashlib.md5(prompt.encode()).hexdigest()}"

        cache_data = {
            "prompt": prompt,
            "embedding": embedding,
            "response": response,
            "cached_at": datetime.now(timezone.utc).isoformat(),
        }

        self.redis.setex(key, self.ttl_seconds, json.dumps(cache_data))

Conversation-Aware Caching

class ConversationCache:
    """Cache that considers conversation context"""

    def __init__(self, redis_url: str):
        self.redis = redis.from_url(redis_url)

    def _context_hash(self, messages: list, depth: int = 3) -> str:
        """Hash recent conversation context"""
        recent = messages[-depth:] if len(messages) > depth else messages
        context_str = json.dumps([m["content"][:100] for m in recent])
        return hashlib.md5(context_str.encode()).hexdigest()

    def get(self, prompt: str, context: list) -> dict | None:
        """Get cached response considering context"""
        context_hash = self._context_hash(context)
        key = f"conv_cache:{context_hash}:{hashlib.md5(prompt.encode()).hexdigest()}"
        data = self.redis.get(key)
        return json.loads(data) if data else None

    def set(self, prompt: str, context: list, response: dict, ttl: int = 3600):
        """Cache with context awareness"""
        context_hash = self._context_hash(context)
        key = f"conv_cache:{context_hash}:{hashlib.md5(prompt.encode()).hexdigest()}"
        self.redis.setex(key, ttl, json.dumps(response))

Tiered Caching

class TieredCache:
    """Multi-tier caching strategy"""

    def __init__(self, embedding_client, redis_url: str):
        self.exact_cache = ExactMatchCache(redis_url)
        self.semantic_cache = SemanticCache(embedding_client, redis_url)
        self.stats = {"exact_hits": 0, "semantic_hits": 0, "misses": 0}

    def get(self, prompt: str, model: str, params: dict) -> dict | None:
        """Try caches in order of speed"""

        # Tier 1: Exact match (fastest)
        exact = self.exact_cache.get(prompt, model, params)
        if exact:
            self.stats["exact_hits"] += 1
            return {"source": "exact", **exact}

        # Tier 2: Semantic match (slower but broader)
        semantic = self.semantic_cache.get(prompt)
        if semantic:
            self.stats["semantic_hits"] += 1
            return {"source": "semantic", **semantic}

        self.stats["misses"] += 1
        return None

    def set(self, prompt: str, model: str, params: dict, response: dict):
        """Store in all applicable caches"""
        # Always store in exact cache
        self.exact_cache.set(prompt, model, params, response)

        # Store in semantic cache for broader matching
        self.semantic_cache.set(prompt, response)

    def get_stats(self) -> dict:
        total = sum(self.stats.values())
        return {
            **self.stats,
            "hit_rate": (self.stats["exact_hits"] + self.stats["semantic_hits"]) / total if total else 0
        }

Cache-Aware LLM Client

class CachedLLMClient:
    """Chat-completion client that consults a cache before calling the API.

    ``client`` is used as ``client.chat.completions.create(...)``. ``cache``
    needs ``get(prompt, model, params)`` and ``set(prompt, model, params,
    response)``, as ``TieredCache`` provides.
    """

    def __init__(self, client, cache: "TieredCache"):
        self.client = client
        self.cache = cache

    @staticmethod
    def _cache_prompt(messages: list) -> str:
        """Return the text used as the cache key for ``messages``.

        A single-message request is keyed on that message's content, so it
        matches entries written for a bare query. A multi-message request is
        keyed on the whole conversation (role and content of every message).
        Keying on the last message alone would let the same question under
        two different system prompts share one cached answer. The same would
        happen to a reply like "yes" after two different questions.
        """
        if not messages:
            return ""
        if len(messages) == 1:
            return messages[0].get("content") or ""
        turns = [{"role": m.get("role"), "content": m.get("content")} for m in messages]
        return json.dumps(turns, sort_keys=True)

    @staticmethod
    def _cached_text(cached_response):
        """Unwrap a cached ``{"content": text}`` payload into the text itself."""
        if isinstance(cached_response, dict):
            return cached_response.get("content")
        return cached_response

    def chat(
        self,
        messages: list,
        model: str = "gpt-4",
        use_cache: bool = True,
        cache_ttl: int = 3600,
        **kwargs
    ) -> dict:
        """Send a chat request, serving it from the cache when possible.

        Args:
            messages: chat messages, passed through to the API unchanged.
            model: model name; also part of the cache key.
            use_cache: when False, the cache is neither read nor written.
            cache_ttl: currently unused. The cache's ``set`` takes no TTL, so
                entries expire according to the cache's own settings.
            **kwargs: extra API parameters; also passed to the cache as params.

        Returns:
            A dict whose ``response`` is always the reply text, plus ``cached``.
            Cache hits add ``cache_source`` and ``cost`` (0). API calls add
            ``tokens`` with input/output counts.
        """
        prompt = self._cache_prompt(messages)

        if use_cache:
            cached = self.cache.get(prompt, model, kwargs)
            if cached:
                return {
                    # Entries are stored as {"content": text}; unwrap so hits
                    # and misses return the same shape.
                    "response": self._cached_text(cached["response"]),
                    "cached": True,
                    "cache_source": cached.get("source"),
                    "cost": 0,
                }

        response = self.client.chat.completions.create(
            model=model,
            messages=messages,
            **kwargs
        )

        result = response.choices[0].message.content

        # A None reply (presumably a tool-call-only response) is not cached;
        # a later hit would otherwise serve an empty answer.
        if use_cache and result is not None:
            self.cache.set(prompt, model, kwargs, {"content": result})

        return {
            "response": result,
            "cached": False,
            "tokens": {
                "input": response.usage.prompt_tokens,
                "output": response.usage.completion_tokens,
            },
        }

Cache Warming

class CacheWarmer:
    """Pre-populate a cache with responses to common or known queries.

    ``cache`` is used through ``get(prompt, model, params)`` and
    ``set(prompt, model, params, response)``, the interface of
    ``ExactMatchCache`` and ``TieredCache``. ``client`` is used through
    ``client.chat.completions.create``. Entries are written with empty
    params and the response wrapped as ``{"content": text}``.
    """

    def __init__(self, client, cache):
        self.client = client
        self.cache = cache

    def warm_from_logs(self, query_logs: list, top_n: int = 100, model: str = "gpt-4"):
        """Generate and cache answers for the ``top_n`` most frequent queries.

        Args:
            query_logs: log records (dicts) whose ``"query"`` field is counted.
                Records with a missing or empty query are skipped.
            top_n: number of most frequent queries to warm.
            model: model used for generation and in the cache key.

        Queries that already hit the cache are not regenerated. With a
        ``TieredCache``, that existence check is counted in its hit/miss
        stats and may call the embedding API.
        """
        from collections import Counter

        queries = (log.get("query") for log in query_logs)
        query_counts = Counter(q for q in queries if q)

        for query, _count in query_counts.most_common(top_n):
            if self.cache.get(query, model, {}):
                continue
            response = self.client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": query}],
            )
            content = response.choices[0].message.content
            # Don't cache an empty reply; it would be served on every later hit.
            if content is not None:
                self.cache.set(query, model, {}, {"content": content})

    def warm_faq(self, faqs: list, model: str = "gpt-4"):
        """Store curated FAQ answers directly, without calling the model.

        Args:
            faqs: dicts with ``"question"`` and ``"answer"`` keys.
            model: model name the entries are keyed under. It must match the
                model later lookups use, or those lookups miss the exact tier.
        """
        for faq in faqs:
            self.cache.set(faq["question"], model, {}, {"content": faq["answer"]})

Effective caching significantly reduces LLM costs and latency. Tomorrow, I will cover response streaming patterns.

Resources

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.