Long Context & Context Window
The context window—the maximum number of tokens a model can process—is one of the most important limitations of modern LLMs. Extending context windows while maintaining performance is an active area of research with significant practical implications.
Context Window Limitations
The context window is the maximum sequence length a transformer model can process in a single forward pass. It is determined by the model's positional encoding scheme and by memory constraints.
Standard transformer self-attention has quadratic complexity:
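Attention Computational Complexity
$$\text{Time} = O(n^2 \cdot d)$$
Here,
- $n$ = sequence length in tokens
- $d$ = head (embedding) dimension
- $n^2$ = number of query-key pairs that must be scored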
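Attention Memory Complexity
$$\text{Memory} = O(n^2)$$
Here,
- $n$ = sequence length in tokens
- $n^2$ = number of entries in the attention score matrix (per head)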
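To make the quadratic growth concrete, the following sketch estimates the size of the attention score matrix for one layer if it were fully materialized. The function name and the defaults (32 heads, 2-byte fp16 scores, batch size 1) are illustrative assumptions, not values taken from any particular model.

```python
def attention_scores_bytes(seq_len: int, num_heads: int = 32,
                           bytes_per_value: int = 2, batch_size: int = 1) -> int:
    """Bytes needed to hold the full (seq_len x seq_len) score matrix
    for every head of one layer. All defaults are illustrative assumptions."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_value


for n in (4_096, 32_768, 131_072):
    gib = attention_scores_bytes(n) / 2**30
    print(f"n={n:>7}: {gib:,.0f} GiB per layer")
```

Under these assumptions the score matrix grows from 1 GiB at 4K tokens to 64 GiB at 32K and 1,024 GiB at 128K: an 8x longer sequence costs 64x the memory.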
Positional Encoding Background
Absolute Positional Encodings
Standard transformers use learned or sinusoidal absolute position embeddings:
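Sinusoidal Positional Encoding
$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$
Here,
- $pos$ = position of the token in the sequence
- $i$ = index of the sine/cosine dimension pair
- $d_{\text{model}}$ = embedding dimension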
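Problem: Models trained on sequences of length L do not generalize to positions > L without modification. Learned absolute embeddings have no entry for unseen positions, and sinusoidal encodings, although defined for any position, extrapolate poorly in practice.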
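As a minimal sketch of the sinusoidal scheme, the function below builds the encoding vector for a single position using only the standard library. It assumes an even `d_model` and the conventional base of 10000; the function name is illustrative.

```python
import math


def sinusoidal_encoding(pos: int, d_model: int) -> list[float]:
    """Sinusoidal encoding of one position.
    Even indices hold sin terms, odd indices hold the matching cos terms."""
    encoding = []
    for i in range(d_model // 2):
        angle = pos / (10000 ** (2 * i / d_model))
        encoding.extend([math.sin(angle), math.cos(angle)])
    return encoding


print(sinusoidal_encoding(0, 8))   # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

Position 0 always encodes to alternating 0 and 1. Later positions rotate each pair at its own frequency, fastest in the first pair and slowest in the last.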
Relative Positional Encodings
Relative position encodings encode the distance between tokens rather than absolute position, enabling better length generalization.
RoPE (Rotary Position Embedding)
RoPE encodes positions by rotating query and key vectors in 2D subspaces, enabling the attention mechanism to be relative-position-aware.
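RoPE Rotation Matrix
$$R_{\theta_i, m} = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix}, \qquad \theta_i = 10000^{-2i/d}$$
Here,
- $m$ = position of the token
- $\theta_i$ = rotation frequency of the $i$-th 2D subspace
- $d$ = head dimension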
The key property of RoPE is that attention scores depend only on relative positions. The dot product between the rotated query at position $m$ and the rotated key at position $n$ depends only on the offset $m - n$:
$$q_m^\top k_n = g(x_m, x_n, m - n)$$
This enables length generalization when combined with appropriate scaling.
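The relative-position property can be checked numerically. The sketch below is a pure-Python illustration, assuming the usual pairing of consecutive (even, odd) dimensions and a base of 10000; the vectors and function names are made up for the example.

```python
import math


def rope_rotate(x: list[float], pos: int, base: float = 10000.0) -> list[float]:
    """Rotate each consecutive (even, odd) pair of x by the angle pos * theta_i."""
    d = len(x)
    out = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        a, b = x[2 * i], x[2 * i + 1]
        out.extend([a * c - b * s, a * s + b * c])
    return out


def dot(u: list[float], v: list[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]
score_a = dot(rope_rotate(q, 5), rope_rotate(k, 2))       # positions (5, 2), offset 3
score_b = dot(rope_rotate(q, 105), rope_rotate(k, 102))   # both shifted by 100, offset 3
print(abs(score_a - score_b) < 1e-9)   # True
```

Shifting both positions by the same amount leaves the score unchanged up to floating-point error, because only the offset enters the rotation.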
RoPE Scaling Methods
Position Interpolation (PI)
Linearly scale positions to fit within the original context window:
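Position Interpolation
$$f'(x, m) = f\left(x, \frac{m L}{L'}\right)$$
Here,
- $f$ = original RoPE function applied to embedding $x$
- $m$ = position of the token
- $L$ = original (trained) context length
- $L'$ = extended context length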
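import torch

def rope_with_interpolation(dim: int, max_seq_len: int, original_max: int = 4096):
    """RoPE with position interpolation."""
    # One rotation frequency per 2D subspace: theta_i = 10000^(-2i/dim)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    # Compress positions so that [0, max_seq_len) maps into [0, original_max)
    scale = min(1.0, original_max / max_seq_len)

    def apply_rope(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        # x: (..., seq_len, dim); positions: (seq_len,)
        # Interpolate positions, then compute one angle per (position, frequency)
        angles = (positions.float() * scale).unsqueeze(-1) * inv_freq.to(x.device)
        cos, sin = torch.cos(angles), torch.sin(angles)
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        # Rotate each (even, odd) pair by its position-dependent angle
        rot_even = x_even * cos - x_odd * sin
        rot_odd = x_even * sin + x_odd * cos
        # Interleave the rotated pairs back into the original layout
        return torch.stack([rot_even, rot_odd], dim=-1).flatten(-2)

    return apply_rope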
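As a quick numeric illustration of the position scaling in the function above, the sketch below maps positions from an extended window back into the trained range. The helper name is hypothetical and the 4096 to 16384 extension is just an example.

```python
def interpolate_position(pos: int, original_max: int, extended_max: int) -> float:
    """Map a position in the extended window back into the trained range."""
    return pos * original_max / extended_max


# Extending a 4096-token model to 16384 tokens compresses positions by 4x.
for pos in (0, 4096, 16383):
    print(pos, "->", interpolate_position(pos, 4096, 16384))
```

The last position, 16383, lands at 4095.75, inside the trained range. The price is that neighbouring tokens are now only 0.25 positions apart, which is one reason interpolated models are usually fine-tuned briefly at the new length.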
NTK-Aware Scaling
NTK-aware interpolation adjusts the base frequency rather than scaling positions directly:
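NTK-Aware RoPE
$$b' = b \cdot s^{d/(d-2)}$$
Here,
- $b$ = original RoPE base (10000)
- $s$ = context extension factor $L'/L$
- $d$ = head dimension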
def rope_ntk_aware(dim: int, max_seq_len: int, original_max: int = 4096, base: float = 10000.0):
    """NTK-aware RoPE scaling: enlarge the base instead of compressing positions."""
    scale = max(1.0, max_seq_len / original_max)
    # Compute new base frequency: b' = b * s^(dim / (dim - 2))
    new_base = base * scale ** (dim / (dim - 2))
    inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))

    def apply_rope(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        # x: (..., seq_len, dim); positions: (seq_len,), left unscaled
        angles = positions.float().unsqueeze(-1) * inv_freq.to(x.device)
        cos, sin = torch.cos(angles), torch.sin(angles)
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        # Rotate each (even, odd) pair by its position-dependent angle
        rot_even = x_even * cos - x_odd * sin
        rot_odd = x_even * sin + x_odd * cos
        return torch.stack([rot_even, rot_odd], dim=-1).flatten(-2)

    return apply_rope
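The difference between the two scaling methods shows up in the per-dimension frequencies. The sketch below compares them for an assumed head dimension of 128 and a 4x extension; the helper name is illustrative. Position interpolation is expressed here as dividing every frequency by the scale, which is equivalent to scaling the positions.

```python
def rope_frequencies(dim: int, base: float = 10000.0) -> list[float]:
    """Rotation frequency theta_i = base^(-2i/dim) for each 2D subspace."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


dim, scale = 128, 4.0
original = rope_frequencies(dim)
pi_freqs = [f / scale for f in original]   # PI: every frequency shrinks by 1/s
ntk_freqs = rope_frequencies(dim, base=10000.0 * scale ** (dim / (dim - 2)))

print(pi_freqs[0] / original[0])      # 0.25: PI also slows the fastest dimension
print(ntk_freqs[0] / original[0])     # 1.0: NTK-aware leaves it untouched
print(round(ntk_freqs[-1] / original[-1], 6))   # 0.25: slowest dimension matches PI
```

With the exponent `dim / (dim - 2)`, the slowest dimension is stretched by exactly the extension factor while the fastest one is unchanged, so fine-grained local position information is preserved.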
YaRN (Yet another RoPE extensioN)
YaRN combines multiple techniques for better long-context performance:
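YaRN Interpolation Factor
$$\theta_i' = \bigl(1 - \gamma(r_i)\bigr)\frac{\theta_i}{s} + \gamma(r_i)\,\theta_i, \qquad \gamma(r) = \begin{cases} 0 & r < \alpha \\ 1 & r > \beta \\ \dfrac{r - \alpha}{\beta - \alpha} & \text{otherwise} \end{cases}$$
Here,
- $\theta_i$ = original RoPE frequency of dimension pair $i$
- $s$ = context extension factor $L'/L$
- $r_i$ = number of full rotations dimension pair $i$ completes within the original context length $L$
- $\alpha, \beta$ = thresholds below which a dimension is fully interpolated and above which it is left unchanged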
YaRN is reported to achieve better perplexity than PI or NTK-aware scaling alone by applying different scaling factors to different frequency components. High-frequency components (local patterns) are interpolated less, while low-frequency components (global patterns) are interpolated more. It also rescales the attention logits by a temperature that depends on the extension factor.
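The frequency-dependent interpolation can be sketched in a few lines. The version below is a simplified illustration rather than a reference implementation: the function names are made up, and the defaults `alpha=1`, `beta=32` and the `0.1 * ln(s) + 1` attention scale are the values the YaRN paper recommends for LLaMA-family models, assumed here for the example.

```python
import math


def yarn_frequencies(dim: int, scale: float, original_max: int = 4096,
                     base: float = 10000.0, alpha: float = 1.0,
                     beta: float = 32.0) -> list[float]:
    """Blend interpolated and original RoPE frequencies per dimension pair."""
    freqs = []
    for i in range(dim // 2):
        theta = base ** (-2 * i / dim)
        wavelength = 2 * math.pi / theta
        r = original_max / wavelength   # full rotations within the trained window
        # Ramp: 0 -> fully interpolate (theta / scale), 1 -> keep theta unchanged
        gamma = min(1.0, max(0.0, (r - alpha) / (beta - alpha)))
        freqs.append((1 - gamma) * theta / scale + gamma * theta)
    return freqs


def yarn_attention_scale(scale: float) -> float:
    """Factor applied to queries and keys, so logits scale by its square."""
    return 0.1 * math.log(scale) + 1.0


freqs = yarn_frequencies(dim=128, scale=4.0)
print(freqs[0])                      # 1.0: the fastest dimension is left unchanged
print(yarn_attention_scale(4.0))
```

With these settings the fastest dimension completes hundreds of rotations inside the trained window and is left alone, while the slowest completes less than one and is divided by the full scale factor.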
ALiBi (Attention with Linear Biases)
ALiBi adds a linear bias to attention scores based on token distance, eliminating the need for positional embeddings entirely.
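ALiBi Attention Bias
$$b^{(h)}_{ij} = -m_h \cdot |i - j|$$
Here,
- $m_h$ = fixed, head-specific slope
- $i$ = query position
- $j$ = key position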
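def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    """Generate ALiBi slopes for attention heads (assumes num_heads is a power of two)."""
    # Geometric sequence 2^(-8/n), 2^(-16/n), ..., 2^(-8)
    ratio = 2 ** (-8 / num_heads)
    slopes = [ratio ** i for i in range(1, num_heads + 1)]
    return torch.tensor(slopes, dtype=torch.float32)

def alibi_bias(max_seq_len: int, num_heads: int) -> torch.Tensor:
    """Create the ALiBi bias tensor of shape (num_heads, max_seq_len, max_seq_len)."""
    slopes = get_alibi_slopes(num_heads)
    positions = torch.arange(max_seq_len)
    # distances[i, j] = j - i; the penalty depends only on |i - j|
    distances = positions.unsqueeze(0) - positions.unsqueeze(1)
    # Added to the attention scores before softmax; a causal mask is applied separately
    bias = -slopes.unsqueeze(-1).unsqueeze(-1) * distances.abs().unsqueeze(0)
    return bias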
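To see what the slopes and biases look like, here is a small pure-Python illustration for 8 heads. The helper names are made up for the example, and the slope formula assumes a power-of-two head count.

```python
def alibi_slopes(num_heads: int) -> list[float]:
    """Geometric slopes 2^(-8/n), 2^(-16/n), ..., 2^(-8) (power-of-two n assumed)."""
    return [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)]


def alibi_row(slope: float, query_pos: int) -> list[float]:
    """Causal bias for one query: 0 for itself, more negative for older keys."""
    return [slope * (j - query_pos) for j in range(query_pos + 1)]


slopes = alibi_slopes(8)
print(slopes[0], slopes[-1])     # 0.5 0.00390625
print(alibi_row(slopes[0], 3))   # [-1.5, -1.0, -0.5, 0.0]
```

Heads with large slopes focus on nearby tokens, while heads with small slopes penalize distance only weakly and can attend far back.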
Long-Context Benchmarks
Needle in a Haystack (NIAH)
Tests the model's ability to retrieve a specific fact from a long context:
def create_needle_in_haystack(
    needle: str,
    haystack: str,
    context_length: int,
    needle_position: float,  # 0.0 (start) to 1.0 (end)
    tokenizer,
) -> str:
    """Create a NIAH test case."""
    # Truncate haystack to target length, reserving ~100 tokens for needle and question
    tokens_haystack = tokenizer.encode(haystack)[:context_length - 100]
    truncated_haystack = tokenizer.decode(tokens_haystack)
    # Insert needle at the requested relative depth (measured in characters)
    insert_pos = int(len(truncated_haystack) * needle_position)
    context = (
        truncated_haystack[:insert_pos]
        + " " + needle + " "
        + truncated_haystack[insert_pos:]
    )
    return f"{context}\n\nWhat is the secret? Answer:"

def evaluate_niah(model, tokenizer, context_lengths, needle_positions):
    """Return {(context_length, needle_position): bool} retrieval outcomes."""
    # load_haystack_text() and generate() are placeholders for your own
    # data-loading and decoding helpers; they are not defined here.
    haystack = load_haystack_text()
    results = {}
    for ctx_len in context_lengths:
        for pos in needle_positions:
            prompt = create_needle_in_haystack(
                needle="The secret code is: ABC-123-XYZ",
                haystack=haystack,
                context_length=ctx_len,
                needle_position=pos,
                tokenizer=tokenizer,
            )
            response = generate(model, tokenizer, prompt)
            results[(ctx_len, pos)] = "ABC-123-XYZ" in response
    return results
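The results dictionary is easier to read once it is collapsed into one accuracy per context length. The helper below is one possible way to do that; its name is hypothetical and the sample outcomes are made up purely for illustration, not measurements from any model.

```python
from collections import defaultdict


def accuracy_by_context_length(
    results: dict[tuple[int, float], bool]
) -> dict[int, float]:
    """Collapse per-(length, position) outcomes into accuracy per context length."""
    hits, totals = defaultdict(int), defaultdict(int)
    for (ctx_len, _pos), correct in results.items():
        totals[ctx_len] += 1
        hits[ctx_len] += int(correct)
    return {n: hits[n] / totals[n] for n in sorted(totals)}


# Made-up outcomes for illustration only.
example = {
    (4096, 0.0): True, (4096, 0.5): True,
    (32768, 0.0): True, (32768, 0.5): False,
}
print(accuracy_by_context_length(example))   # {4096: 1.0, 32768: 0.5}
```

The same grouping can be done by needle position instead, which is how the familiar depth-versus-length heatmaps are produced.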
Long-Context Evaluation Metrics
| Benchmark | Context Length | Task | Metric |
|---|---|---|---|
| NIAH | Up to 128K | Retrieval | Accuracy |
| LongBench | Up to 32K | Multiple | F1, EM |
| ∞Bench | Up to 100K | Multiple | Accuracy |
| GovReport | Up to 20K | Summarization | ROUGE |
Efficient Attention Mechanisms
Flash Attention
Flash Attention optimizes memory access patterns rather than reducing computation:
Flash Attention reduces memory from O(n²) to O(n) by computing attention in blocks and avoiding materialization of the full attention matrix. It uses tiling and kernel fusion to maximize GPU memory hierarchy utilization.
import torch
from flash_attn import flash_attn_func

def flash_attention_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True
) -> torch.Tensor:
    """
    q, k, v: (batch, seq_len, num_heads, head_dim), fp16 or bf16 tensors on a CUDA device
    Returns: (batch, seq_len, num_heads, head_dim)
    """
    return flash_attn_func(q, k, v, causal=causal)
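Blockwise attention works because softmax can be computed incrementally, carrying only a running maximum and a running sum between blocks. The sketch below illustrates the idea for a single query row with scalar values, in pure Python. It is a conceptual illustration with made-up function names, not how the fused GPU kernel is written.

```python
import math


def attention_row_full(scores: list[float], values: list[float]) -> float:
    """Reference: softmax over all scores at once, then a weighted sum."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def attention_row_blockwise(scores: list[float], values: list[float],
                            block_size: int) -> float:
    """Same result, but only one block of scores is looked at per step."""
    running_max, running_sum, acc = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        new_max = max(running_max, max(s_blk))
        correction = math.exp(running_max - new_max)   # rescale earlier blocks
        weights = [math.exp(s - new_max) for s in s_blk]
        running_sum = running_sum * correction + sum(weights)
        acc = acc * correction + sum(w * v for w, v in zip(weights, v_blk))
        running_max = new_max
    return acc / running_sum


scores = [0.5, 2.0, -1.0, 3.0, 0.1]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
full = attention_row_full(scores, values)
tiled = attention_row_blockwise(scores, values, block_size=2)
print(abs(full - tiled) < 1e-12)   # True
```

Since each step needs only the current block plus three running quantities, the full row of scores never has to be stored. This is where the O(n) memory footprint comes from.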
Ring Attention
Ring Attention distributes long sequences across multiple devices in a ring topology:
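Ring Attention Memory per Device
$$\text{Memory per device} = O\left(\frac{n}{P} \cdot d\right)$$
Here,
- $n$ = total sequence length
- $P$ = number of devices in the ring
- $d$ = hidden dimension
- $n/P$ = length of the sequence chunk held by each device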
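import torch.distributed as dist

def send_recv(tensor: torch.Tensor, group=None) -> torch.Tensor:
    """Send a tensor to the next rank in the ring and receive the previous rank's tensor."""
    rank, world_size = dist.get_rank(group), dist.get_world_size(group)
    received = torch.empty_like(tensor)
    ops = [
        dist.P2POp(dist.isend, tensor, (rank + 1) % world_size, group),
        dist.P2POp(dist.irecv, received, (rank - 1) % world_size, group),
    ]
    for request in dist.batch_isend_irecv(ops):
        request.wait()
    return received

def ring_attention_forward(q, k, v, group=None) -> torch.Tensor:
    """
    Distribute non-causal attention across devices in a ring.
    q, k, v: (batch, seq_len, dim); each device handles one sequence chunk.
    Assumes torch.distributed is initialized and `group` is the default group.
    """
    world_size, rank = dist.get_world_size(group), dist.get_rank(group)
    chunk_size = q.shape[1] // world_size
    local = slice(rank * chunk_size, (rank + 1) * chunk_size)
    q_chunk = q[:, local]
    k_blk, v_blk = k[:, local].contiguous(), v[:, local].contiguous()
    # Running softmax statistics let partial results from different KV blocks merge exactly
    output = torch.zeros_like(q_chunk)
    row_max = torch.full_like(q_chunk[..., :1], float("-inf"))
    row_sum = torch.zeros_like(row_max)
    for step in range(world_size):
        # Compute attention for the current KV block
        scores = q_chunk @ k_blk.transpose(-1, -2) / q.shape[-1] ** 0.5
        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        correction = torch.exp(row_max - new_max)
        probs = torch.exp(scores - new_max)
        output = output * correction + probs @ v_blk
        row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
        row_max = new_max
        # Send KV to the next device, receive from the previous one
        if step < world_size - 1:
            k_blk, v_blk = send_recv(k_blk, group), send_recv(v_blk, group)
    return output / row_sum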
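Ring Attention enables context lengths of 1M+ tokens by distributing the sequence, and with it the key/value blocks, across multiple GPUs. The communication overhead can be hidden by overlapping block transfers with computation, provided each block takes at least as long to compute as to transfer.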
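The communication pattern is easiest to see as a schedule. The sketch below is a single-process illustration with a hypothetical helper name. It assumes each device starts with its own KV block and passes it to rank + 1 after every step, and it lists which block each device holds at each step.

```python
def ring_schedule(world_size: int) -> list[list[int]]:
    """schedule[step][rank] = index of the KV block that `rank` holds at `step`."""
    return [
        [(rank - step) % world_size for rank in range(world_size)]
        for step in range(world_size)
    ]


schedule = ring_schedule(4)
for step, held in enumerate(schedule):
    print(f"step {step}: {held}")
# step 0: [0, 1, 2, 3]
# step 1: [3, 0, 1, 2]
# step 2: [2, 3, 0, 1]
# step 3: [1, 2, 3, 0]
```

After `world_size` steps every device has seen every KV block exactly once, so each query chunk has attended to the whole sequence while only ever holding one block at a time.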
Practical: Handling 100K+ Context
import torch

class LongContextHandler:
    def __init__(self, model, tokenizer, max_context: int = 128000):
        self.model = model
        self.tokenizer = tokenizer
        self.max_context = max_context

    def process_long_document(self, document: str, query: str) -> str:
        """Process a document longer than the context window."""
        chunks = self._chunk_document(document)
        # Summarize each overlapping chunk with respect to the query
        summaries = [self._summarize_chunk(chunk, query) for chunk in chunks]
        # Combine summaries
        combined = "\n".join(summaries)
        # Final answer with condensed context
        return self._generate_answer(combined, query)

    def _chunk_document(self, document: str, chunk_size: int = 4096, overlap: int = 512) -> list:
        tokens = self.tokenizer.encode(document, add_special_tokens=False)
        chunks = []
        # Step by chunk_size - overlap so consecutive chunks share `overlap` tokens
        for i in range(0, len(tokens), chunk_size - overlap):
            chunks.append(self.tokenizer.decode(tokens[i:i + chunk_size]))
            if i + chunk_size >= len(tokens):
                break  # the remaining tokens are already covered by this chunk
        return chunks

    def _summarize_chunk(self, chunk: str, query: str) -> str:
        prompt = (
            f"Summarize the following text, focusing on information relevant to: {query}\n"
            f"Text: {chunk}\n"
            "Summary:"
        )
        return self._generate(prompt, max_tokens=512)

    def _generate_answer(self, context: str, query: str) -> str:
        prompt = (
            "Based on the following context, answer the question.\n"
            f"Context: {context}\n"
            f"Question: {query}\n"
            "Answer:"
        )
        return self._generate(prompt, max_tokens=1024)

    def _generate(self, prompt: str, max_tokens: int = 256) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=True,  # temperature only takes effect when sampling is enabled
                temperature=0.7,
            )
        # Decode only the newly generated tokens, not the prompt
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
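It helps to check the chunk arithmetic before running a pipeline like this. The sketch below uses the same 4096-token window and 512-token overlap as the defaults above and lists the token spans a sliding window would produce. The function name is hypothetical and it works on token counts only, so no tokenizer is needed.

```python
def chunk_spans(num_tokens: int, chunk_size: int = 4096,
                overlap: int = 512) -> list[tuple[int, int]]:
    """(start, end) token spans of a sliding window with the given overlap."""
    if num_tokens <= 0:
        return []
    stride = chunk_size - overlap
    spans, start = [], 0
    while True:
        end = min(start + chunk_size, num_tokens)
        spans.append((start, end))
        if end == num_tokens:   # this window reaches the end of the document
            return spans
        start += stride


spans = chunk_spans(100_000)
print(len(spans), spans[0], spans[1], spans[-1])
# 28 (0, 4096) (3584, 7680) (96768, 100000)
```

A 100K-token document therefore yields 28 chunks. At up to 512 summary tokens each, the combined summaries come to at most 14,336 tokens, which fits comfortably in a single final prompt for a model with a 128K window.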
Summary
- Context window is limited by quadratic attention complexity and positional encoding generalization
- RoPE encodes positions via rotation matrices, enabling relative position awareness
- Position Interpolation scales positions linearly to fit longer contexts
- NTK-aware scaling adjusts base frequencies for better long-context performance
- YaRN applies frequency-dependent interpolation, scaling low-frequency components more than high-frequency ones
- ALiBi uses linear attention biases, eliminating positional embeddings
- Flash Attention reduces memory from O(n²) to O(n) via memory-efficient tiling
- Ring Attention distributes long sequences across multiple devices
- For 100K+ contexts, combine efficient attention with chunking and summarization
Practice Exercises
- RoPE Implementation: Implement RoPE from scratch. Verify that attention scores depend only on relative positions.
- Position Interpolation: Compare perplexity of a model at 4096, 8192, and 16384 tokens with and without position interpolation.
- NTK-Aware Scaling: Implement NTK-aware RoPE scaling and compare with PI at 4x context extension.
- NIAH Test: Create a Needle in a Haystack test and evaluate a model at different context lengths and needle positions.
- Long Context Pipeline: Build a document QA system that handles 100K+ token documents using chunking and summarization.