LLM Caching
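LLM caching reduces cost and latency by storing responses and reusing them for repeated or similar queries. Exact-match lookups serve identical prompts, and semantic similarity (comparing prompt embeddings) extends cache hits to paraphrased prompts.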
Semantic Cache Implementation
import hashlib
import json
from datetime import datetime, timedelta
from typing import Optional
import numpy as np
from dataclasses import dataclass
@dataclass()
class CacheEntry:
    """A single cached LLM response together with its lookup metadata.

    Entries are mutable on purpose: the cache bumps ``hit_count`` in place
    each time the entry is served.
    """

    # Dictionary key the entry is stored under in SemanticCache.entries.
    prompt_hash: str
    # Embedding of the stored prompt; compared by cosine similarity against
    # query embeddings during semantic lookup.
    embedding: np.ndarray
    # The cached model output handed back on a hit.
    response: str
    # Model identifier; semantic lookups only consider entries for the same model.
    model: str
    # When the entry was written (SemanticCache uses naive datetime.now());
    # the TTL window starts here.
    created_at: datetime
    # Lifetime of the entry in seconds, counted from created_at.
    ttl_seconds: int
    # How many times this entry has been served.
    hit_count: int = 0
class SemanticCache:
    """In-memory LLM response cache with exact and semantic (embedding) lookup.

    ``get`` first tries an exact match on the (prompt, model) pair, which needs
    no embedding call. On a miss it embeds the prompt and returns the response
    of the most cosine-similar unexpired entry for the same model, provided
    that similarity is at least ``similarity_threshold``. Expired entries are
    deleted whenever a lookup encounters them. The semantic scan is linear in
    the number of stored entries.

    Args:
        embedding_model: Object exposing ``embed_query(text)`` that returns a
            numeric vector (list or array). Presumably a LangChain embeddings
            object; only ``embed_query`` is used here.
        similarity_threshold: Minimum cosine similarity for a semantic hit.
            Higher values reduce the risk of answering a different question.
    """

    def __init__(self, embedding_model, similarity_threshold: float = 0.95):
        self.entries = {}  # prompt hash -> CacheEntry
        self.embedding_model = embedding_model
        self.threshold = similarity_threshold

    def _hash_prompt(self, prompt: str, model: str) -> str:
        """Return a stable key identifying the exact (prompt, model) pair."""
        # JSON-encode the pair so the two fields cannot run together: with a
        # plain "prompt:model" join, ("a:b", "c") and ("a", "b:c") collide and
        # an exact hit could serve another model's response.
        content = json.dumps([prompt, model])
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    def _get_embedding(self, text: str) -> np.ndarray:
        """Embed *text* and return it as a float ndarray."""
        # embed_query may return a plain Python list; normalise the type.
        return np.asarray(self.embedding_model.embed_query(text), dtype=float)

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Return the cosine similarity of *a* and *b*; 0.0 if either has zero norm."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            # 0/0 would yield nan (with a RuntimeWarning); a zero vector
            # carries no direction, so treat it as matching nothing.
            return 0.0
        return float(np.dot(a, b) / denom)

    def _is_expired(self, entry: "CacheEntry", now: datetime) -> bool:
        """Return True if *entry*'s TTL has fully elapsed at time *now*."""
        return now - entry.created_at >= timedelta(seconds=entry.ttl_seconds)

    def get(self, prompt: str, model: str) -> Optional[str]:
        """Return a cached response for *prompt* under *model*, or None on a miss.

        Increments ``hit_count`` on the entry that is served. Any expired
        entries encountered along the way are removed from the cache.
        """
        now = datetime.now()

        prompt_hash = self._hash_prompt(prompt, model)
        entry = self.entries.get(prompt_hash)
        if entry is not None:
            if not self._is_expired(entry, now):
                entry.hit_count += 1
                return entry.response
            del self.entries[prompt_hash]

        query_embedding = self._get_embedding(prompt)
        best_entry = None
        best_similarity = float("-inf")
        expired_keys = []
        for key, candidate in self.entries.items():
            if self._is_expired(candidate, now):
                # Collect and delete after the loop: a dict must not change
                # size while it is being iterated.
                expired_keys.append(key)
                continue
            if candidate.model != model:
                continue
            similarity = self._cosine_similarity(query_embedding, candidate.embedding)
            # Keep the closest qualifying entry, not merely the first one found.
            if similarity >= self.threshold and similarity > best_similarity:
                best_entry, best_similarity = candidate, similarity

        for key in expired_keys:
            del self.entries[key]

        if best_entry is None:
            return None
        best_entry.hit_count += 1
        return best_entry.response

    def set(self, prompt: str, model: str, response: str, ttl_seconds: int = 3600):
        """Store *response* for (prompt, model), replacing any existing entry.

        Args:
            prompt: Prompt text; it is embedded for later semantic lookups.
            model: Model identifier the response belongs to.
            response: Response text to cache.
            ttl_seconds: Lifetime of the entry in seconds from now.
        """
        prompt_hash = self._hash_prompt(prompt, model)
        self.entries[prompt_hash] = CacheEntry(
            prompt_hash=prompt_hash,
            embedding=self._get_embedding(prompt),
            response=response,
            model=model,
            created_at=datetime.now(),
            ttl_seconds=ttl_seconds,
        )

    def get_stats(self) -> dict:
        """Return entry count, total hits and mean hits per entry.

        Counts include expired entries that no lookup has purged yet.
        """
        total_hits = sum(e.hit_count for e in self.entries.values())
        return {
            "total_entries": len(self.entries),
            "total_hits": total_hits,
            "avg_hits_per_entry": total_hits / len(self.entries) if self.entries else 0,
        }
Prompt Caching (Prefix Matching)
from functools import lru_cache
import tiktoken
class PromptCacher:
    """Memoise chat responses per (system prompt, user prompt) and estimate prefix reuse.

    Nothing here performs provider-side prompt (prefix) caching. ``cache_prefix``
    memoises whole responses locally, and ``estimate_savings`` only estimates
    how many tokens a prefix shared by a batch of prompts covers.

    Args:
        llm_client: Chat client; only ``invoke(messages).content`` is used
            (presumably a LangChain chat model — confirm against callers).
        tokenizer_model: Model name passed to ``tiktoken.encoding_for_model``
            to pick the encoding used for token counting.
    """

    def __init__(self, llm_client, tokenizer_model: str = "gpt-4"):
        self.llm = llm_client
        self.cache = {}  # (system_prompt, user_prompt) -> response text
        self.tokenizer = tiktoken.encoding_for_model(tokenizer_model)

    def find_common_prefix(self, prompts: list) -> str:
        """Return the longest character prefix shared by every prompt ("" if none or empty)."""
        if not prompts:
            return ""
        # The common prefix can be no longer than the shortest prompt; stop at
        # the first position where any prompt disagrees.
        shortest = min(prompts, key=len)
        for index, char in enumerate(shortest):
            if any(prompt[index] != char for prompt in prompts):
                return shortest[:index]
        return shortest

    def cache_prefix(self, system_prompt: str, user_prompts: list) -> list:
        """Return one response per user prompt, calling the LLM only on cache misses.

        Args:
            system_prompt: System message sent with every request.
            user_prompts: User messages, answered in order.

        Returns:
            Response texts in the same order as *user_prompts*.
        """
        results = []
        for user_prompt in user_prompts:
            # Key on the complete text of both prompts: truncating the key
            # would hand one prompt's answer to a different prompt that merely
            # shares the same opening characters.
            cache_key = (system_prompt, user_prompt)
            if cache_key not in self.cache:
                self.cache[cache_key] = self.llm.invoke([
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ]).content
            results.append(self.cache[cache_key])
        return results

    def estimate_savings(self, prompts: list) -> dict:
        """Estimate the share of a batch's tokens that a cached shared prefix could cover.

        The common prefix must be sent once to populate a cache. Every later
        prompt could then reuse it, so the reusable token count is
        ``prefix_tokens * (len(prompts) - 1)``. Counts are approximate: the
        prefix is tokenised on its own, and its end may not fall on a token
        boundary inside the full prompts.

        Returns:
            dict with ``prefix_tokens`` (tokens in the common prefix),
            ``total_tokens`` (sum over all prompts) and ``potential_savings_pct``
            (reusable tokens as a percentage of ``total_tokens``; 0 when there
            are no tokens).
        """
        prefix_tokens = len(self.tokenizer.encode(self.find_common_prefix(prompts)))
        total_tokens = sum(len(self.tokenizer.encode(p)) for p in prompts)
        reusable_tokens = prefix_tokens * max(len(prompts) - 1, 0)
        return {
            "prefix_tokens": prefix_tokens,
            "total_tokens": total_tokens,
            "potential_savings_pct": (reusable_tokens / total_tokens * 100) if total_tokens else 0,
        }
Redis-Based Distributed Cache
import redis
import json
import hashlib
from typing import Optional
class RedisLLMCache:
    """Exact-match LLM response cache stored in Redis with a per-entry TTL.

    Keys look like ``llm_cache:{model}:{sha256(prompt)}``. This keeps entries
    in their own namespace, away from other data in the same database, and
    lets ``invalidate`` target a model by name. Each hit also increments a
    companion ``<key>:hits`` counter.
    """

    KEY_PREFIX = "llm_cache:"

    def __init__(self, redis_url: "str | redis.Redis" = "redis://localhost:6379"):
        """Connect to Redis.

        Args:
            redis_url: A Redis connection URL. Alternatively, an existing
                client object exposing the redis-py methods used here; it is
                used as-is.
        """
        if isinstance(redis_url, str):
            self.redis = redis.from_url(redis_url)
        else:
            self.redis = redis_url
        self.default_ttl = 3600

    def _make_key(self, prompt: str, model: str) -> str:
        """Return the namespaced Redis key for (prompt, model)."""
        # Keep the model readable in the key so invalidate() can match it;
        # hash only the (arbitrarily long) prompt.
        digest = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
        return f"{self.KEY_PREFIX}{model}:{digest}"

    def get(self, prompt: str, model: str) -> Optional[dict]:
        """Return the stored dict (response, model, prompt_hash, metadata), or None on a miss."""
        key = self._make_key(prompt, model)
        cached = self.redis.get(key)
        if cached is None:
            return None
        hits_key = f"{key}:hits"
        self.redis.incr(hits_key)
        # Give the counter the entry's remaining lifetime so it does not linger
        # after the entry expires (TTL is -1 for no expiry, -2 if the key is gone).
        remaining = self.redis.ttl(key)
        if remaining is not None and remaining > 0:
            self.redis.expire(hits_key, remaining)
        return json.loads(cached)

    def set(self, prompt: str, model: str, response: str,
            ttl: Optional[int] = None, metadata: Optional[dict] = None):
        """Store *response* for (prompt, model) as JSON, expiring after *ttl* seconds.

        A falsy *ttl* (None or 0) falls back to ``self.default_ttl``.
        """
        key = self._make_key(prompt, model)
        data = {
            "response": response,
            "model": model,
            "prompt_hash": key,
            "metadata": metadata or {},
        }
        self.redis.setex(key, ttl or self.default_ttl, json.dumps(data))

    def invalidate(self, pattern: str):
        """Delete this cache's entries (and hit counters) whose key contains *pattern*.

        Matching is confined to keys under ``KEY_PREFIX``. *pattern* is used
        inside a Redis glob, so ``*``, ``?`` and ``[`` act as wildcards. Passing
        a model name drops every entry for that model.
        """
        # SCAN instead of KEYS: KEYS walks the whole keyspace in one blocking call.
        batch = []
        for key in self.redis.scan_iter(match=f"{self.KEY_PREFIX}*{pattern}*", count=500):
            batch.append(key)
            if len(batch) >= 500:
                self.redis.delete(*batch)
                batch.clear()
        if batch:
            self.redis.delete(*batch)

    def get_stats(self) -> dict:
        """Return server-wide Redis key counts and keyspace hit/miss statistics.

        These figures come from INFO and cover every key on the server,
        not only this cache's entries.
        """
        info = self.redis.info("stats")
        keyspace = self.redis.info("keyspace")
        hits = info.get("keyspace_hits", 0)
        misses = info.get("keyspace_misses", 0)
        lookups = hits + misses
        return {
            "total_keys": sum(db.get("keys", 0) for db in keyspace.values()),
            "hits": hits,
            "misses": misses,
            # A fresh server reports 0 hits and 0 misses; avoid dividing by zero.
            "hit_rate": hits / lookups if lookups else 0.0,
        }
Multi-Layer Cache Strategy
class MultiLayerCache:
    """Three-layer LLM response cache, checked fastest-first.

    * L1: in-process dict of exact (model, prompt) hits, LRU-bounded and TTL'd.
    * L2: a SemanticCache (exact hash, then embedding-similarity lookup).
    * L3: an optional RedisLLMCache.

    A hit in L2 or L3 is promoted into L1, and ``set`` writes to every layer.
    A promoted L1 entry gets a fresh ``ttl_seconds``, so it can outlive the
    lower-layer copy it came from by up to that long.

    Args:
        embedding_model: Passed to SemanticCache for prompt embeddings.
        redis_client: Forwarded unchanged to the RedisLLMCache constructor;
            L3 is disabled when it is falsy.
        l1_max_entries: Maximum number of L1 entries; beyond it the least
            recently used entry is evicted.
        ttl_seconds: Lifetime of entries this object writes, in every layer.
    """

    def __init__(self, embedding_model, redis_client=None,
                 l1_max_entries: int = 1024, ttl_seconds: int = 3600):
        # (model, prompt) -> (response, expires_at). A tuple key cannot be
        # confused the way "model:prompt" strings can. Dicts keep insertion
        # order, so the first key is always the least recently used one.
        self.l1_cache = {}
        self.l1_max_entries = l1_max_entries
        self.ttl_seconds = ttl_seconds
        self.l2_cache = SemanticCache(embedding_model)
        self.l3_cache = RedisLLMCache(redis_client) if redis_client else None

    def _l1_get(self, key: tuple) -> Optional[str]:
        """Return the unexpired L1 response for *key*, marking it most recently used."""
        entry = self.l1_cache.pop(key, None)
        if entry is None:
            return None
        response, expires_at = entry
        if datetime.now() >= expires_at:
            return None  # expired; the pop above already removed it
        self.l1_cache[key] = entry  # re-insert at the end = most recently used
        return response

    def _l1_put(self, key: tuple, response: str) -> None:
        """Insert or refresh *key* in L1 with a fresh TTL, evicting LRU entries past the limit."""
        self.l1_cache.pop(key, None)
        self.l1_cache[key] = (response, datetime.now() + timedelta(seconds=self.ttl_seconds))
        while len(self.l1_cache) > self.l1_max_entries:
            del self.l1_cache[next(iter(self.l1_cache))]

    def get(self, prompt: str, model: str) -> Optional[str]:
        """Return a cached response from the first layer that has one, else None.

        Misses are detected with ``is None``, so an empty-string response is a valid hit.
        """
        key = (model, prompt)
        response = self._l1_get(key)
        if response is not None:
            return response

        response = self.l2_cache.get(prompt, model)
        if response is not None:
            self._l1_put(key, response)
            return response

        if self.l3_cache is not None:
            cached = self.l3_cache.get(prompt, model)
            if cached is not None:
                response = cached["response"]
                self._l1_put(key, response)
                return response
        return None

    def set(self, prompt: str, model: str, response: str):
        """Write *response* for (prompt, model) to every layer with ``self.ttl_seconds``."""
        self._l1_put((model, prompt), response)
        self.l2_cache.set(prompt, model, response, ttl_seconds=self.ttl_seconds)
        if self.l3_cache is not None:
            self.l3_cache.set(prompt, model, response, ttl=self.ttl_seconds)
Key Takeaways
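- Semantic caching catches similar but not identical queries; the similarity threshold trades hit rate against the risk of returning an answer to a different question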
- Prefix caching reduces costs for repeated system prompts
- Multi-layer caching optimizes for different access patterns
- TTL management ensures cache freshness
- Cache hit monitoring tracks effectiveness and savings