Inference Optimization
KV Cache
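
    import torch

    class KVCache:
        """Preallocated key/value cache for one sequence (batch size 1)."""

        def __init__(self, max_seq_len, num_heads, head_dim):
            # Layout: (batch, num_heads, seq_len, head_dim)
            self.key_cache = torch.zeros(1, num_heads, max_seq_len, head_dim)
            self.value_cache = torch.zeros(1, num_heads, max_seq_len, head_dim)
            self.current_len = 0

        def update(self, key, value):
            # key and value hold only the new positions: (1, num_heads, new_len, head_dim)
            seq_len = key.shape[2]
            end = self.current_len + seq_len
            if end > self.key_cache.shape[2]:
                raise ValueError("KV cache overflow: sequence exceeds max_seq_len")
            self.key_cache[:, :, self.current_len:end] = key
            self.value_cache[:, :, self.current_len:end] = value
            self.current_len = end
            # Return views over the filled prefix for use in attention
            return self.key_cache[:, :, :end], self.value_cache[:, :, :end]
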
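The cache trades memory for compute, so it helps to estimate its size before choosing `max_seq_len`. The following is a minimal sketch, not part of the class above: the function name `kv_cache_bytes` and the `num_layers`, `batch_size`, and `bytes_per_elem` parameters are assumptions introduced for illustration, on the premise that a full model keeps one key tensor and one value tensor per layer.

```python
def kv_cache_bytes(num_layers, num_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    """Estimate KV cache size in bytes (hypothetical helper).

    Each layer stores two tensors (keys and values) of shape
    (batch_size, num_heads, seq_len, head_dim).
    bytes_per_elem=2 assumes 16-bit floats.
    """
    elems_per_tensor = batch_size * num_heads * seq_len * head_dim
    return 2 * num_layers * elems_per_tensor * bytes_per_elem


# Illustrative configuration: 32 layers, 32 heads, head_dim 128, 4096 tokens
size = kv_cache_bytes(num_layers=32, num_heads=32, head_dim=128, seq_len=4096)
print(size / 2**30)  # 2.0 (GiB)
```

The size grows linearly with sequence length and batch size, which is why long contexts and large batches tend to be limited by cache memory rather than by model weights.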
Speculative Decoding
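
    import torch

    class SpeculativeDecoder:
        def __init__(self, draft_model, target_model, gamma=5):
            self.draft = draft_model    # small, fast model that proposes tokens
            self.target = target_model  # large model whose output distribution is preserved
            self.gamma = gamma          # number of tokens drafted per step

        def generate(self, prompt):
            # encode, is_done, and verify are helpers this class is expected to provide;
            # generate and get_probs are assumed methods of the two model wrappers.
            tokens = self.encode(prompt)
            while not self.is_done(tokens):
                # Draft gamma candidate tokens with the cheap model
                # (assumed to return the context followed by the drafted tokens)
                draft_tokens = self.draft.generate(tokens, max_new_tokens=self.gamma)
                # Score the whole candidate sequence with both models
                target_probs = self.target.get_probs(draft_tokens)
                draft_probs = self.draft.get_probs(draft_tokens)
                # Accept or reject the drafted tokens; verify returns only the new tokens to keep
                accepted = self.verify(draft_tokens, target_probs, draft_probs)
                tokens = torch.cat([tokens, accepted])
            return tokens
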
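The accept/reject step is where speculative decoding keeps the target model's output distribution. The class above calls `self.verify` without defining it, so the following is only a sketch of one common rule (standard speculative sampling), written in plain Python rather than tensors. The signature, the list-based inputs, and the `rng` parameter are assumptions for illustration and differ from the call in the class: a drafted token is accepted with probability min(1, p/q), where p and q are the target and draft probabilities of that token, and on the first rejection a replacement is drawn from the renormalized residual max(0, p - q).

```python
import random


def verify(draft_tokens, draft_probs, target_probs, rng=random):
    """Return the tokens kept from one speculative step (hypothetical helper).

    draft_tokens: token ids proposed by the draft model (new tokens only).
    draft_probs[i], target_probs[i]: full next-token distributions (lists of
    floats summing to 1) at draft position i under each model.
    Assumes each drafted token has nonzero probability under the draft model.
    """
    accepted = []
    for i, tok in enumerate(draft_tokens):
        q = draft_probs[i][tok]   # draft probability of the proposed token
        p = target_probs[i][tok]  # target probability of the same token
        if rng.random() < min(1.0, p / q):
            accepted.append(tok)
            continue
        # Rejected: resample from the residual max(0, p - q), renormalized
        residual = [max(0.0, pt - qt)
                    for pt, qt in zip(target_probs[i], draft_probs[i])]
        weights = residual if sum(residual) > 0 else target_probs[i]
        accepted.append(rng.choices(range(len(weights)), weights=weights, k=1)[0])
        break  # everything drafted after a rejection is discarded
    return accepted
```

When the two models agree exactly, every drafted token is accepted. Full implementations usually also sample one extra token from the target model when all drafts are accepted; that step is omitted here for brevity.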
Optimization Summary
| Technique | Typical speedup | Memory use | Output quality |
|---|---|---|---|
| KV Cache | 2-4x | +50% | Same |
| Speculative decoding | 2-3x | Same | Same |
| Quantization | 3-6x | -75% | Slight loss |
| Batching | 2-8x | Same | Same |
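Quantization is the only row in the table that trades quality for speed and memory. The following is a minimal sketch of one simple scheme, symmetric absmax quantization to 8-bit integers, chosen here purely for illustration; the function names are hypothetical and production libraries use more elaborate per-channel or per-group variants. Storing one byte per weight instead of four matches the -75% memory figure if the baseline is 32-bit floats.

```python
def quantize_int8(weights):
    """Map floats to integer codes in [-127, 127] with a single shared scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(w / scale))) for w in weights]
    return codes, scale


def dequantize(codes, scale):
    """Recover approximate float values from the integer codes."""
    return [c * scale for c in codes]


weights = [1.0, -0.5, 0.0, 0.25]
codes, scale = quantize_int8(weights)
restored = dequantize(codes, scale)
# Each restored value differs from the original by at most about scale / 2
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(codes, max_error)
```

The rounding error per weight is bounded by about half the scale, which is the source of the "slight loss" in quality: a single outlier weight inflates the scale and therefore the error for every other weight that shares it.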
Summary
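Inference optimization is crucial for deploying LLMs efficiently. The techniques above are largely complementary: KV caching avoids recomputing attention over earlier tokens, speculative decoding reduces the number of sequential target-model steps, quantization shrinks the weights, and batching raises throughput. Combining them typically gives the best performance.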
Next: We'll explore deployment and serving solutions.