Long Context & Context Window

The context window, the maximum number of tokens a model can process at once, is one of the most important limitations of modern LLMs. Extending it without degrading quality is an active area of research with significant practical implications.

Context Window Limitations

Context window: the maximum sequence length a transformer model can process in a single forward pass. It is determined by the model's positional encoding scheme and by memory constraints.

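Standard transformer self-attention has quadratic complexity in the sequence length:

Attention Computational Complexity

O(n^2 \cdot d)

Here,

  • n = sequence length (number of tokens)
  • d = model (hidden) dimension

Attention Memory Complexity

O(n^2 + n \cdot d)

Here,

  • n^2 = size of the attention score matrix
  • n \cdot d = size of the query, key, and value activations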
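
To make the quadratic term concrete, the sketch below estimates how much memory it would take to materialize the full n × n score matrix for one layer. The function name, the 32-head default, and the 2-byte (fp16) element size are illustrative assumptions, not figures for any particular model.

```python
def attention_scores_bytes(seq_len: int, num_heads: int = 32, bytes_per_element: int = 2) -> int:
    """Bytes needed to store one layer's full attention score matrices."""
    # One (seq_len x seq_len) matrix per head
    return num_heads * seq_len * seq_len * bytes_per_element


for n in (4_096, 32_768, 131_072):
    gib = attention_scores_bytes(n) / 2**30
    print(f"n={n:>7}: {gib:,.0f} GiB per layer")
```

Doubling the sequence length quadruples the result, which is why long contexts become impractical without the memory-efficient attention methods covered later.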

Positional Encoding Background

Absolute Positional Encodings

Standard transformers use learned or sinusoidal absolute position embeddings:

Sinusoidal Positional Encoding

PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right), \quad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)

Here,

  • pos = position of the token in the sequence
  • i = index of the dimension pair
  • d = model (embedding) dimension

Problem: a model trained on sequences of length L has never seen positions beyond L, so quality typically degrades sharply at positions > L unless the encoding is modified.
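
As a minimal sketch of the sinusoidal formula, the function below computes the encoding for a single position in plain Python. The function name is hypothetical and d is assumed to be even.

```python
import math


def sinusoidal_encoding(pos: int, d: int) -> list[float]:
    """Return the d-dimensional sinusoidal encoding for position `pos` (d even)."""
    pe = [0.0] * d
    for i in range(d // 2):
        angle = pos / (10000 ** (2 * i / d))
        pe[2 * i] = math.sin(angle)      # even dimensions use sine
        pe[2 * i + 1] = math.cos(angle)  # odd dimensions use cosine
    return pe


print(sinusoidal_encoding(0, 8))
print(sinusoidal_encoding(1, 8))
```

The formula is defined for any position, but a model only learns to interpret the values that occur within its training length.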

Relative Positional Encodings

Relative position encodings encode the distance between tokens rather than absolute position, enabling better length generalization.

RoPE (Rotary Position Embedding)

RoPE encodes positions by rotating query and key vectors in 2D subspaces, enabling the attention mechanism to be relative-position-aware.

RoPE Rotation Matrix

R_\theta(x, m) = \begin{pmatrix} x_0 \cos(m\theta) - x_1 \sin(m\theta) \\ x_0 \sin(m\theta) + x_1 \cos(m\theta) \end{pmatrix}

Here,

  • x = (x_0, x_1), one 2D pair of query or key components
  • m = position of the token
  • \theta = rotation frequency for this pair

Applied to the full d-dimensional vector, with \theta_k = 10000^{-2k/d} and \otimes denoting element-wise multiplication:

f(x, m) = \begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ \vdots \\ x_{d-2} \\ x_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \cos(m\theta_0) \\ \cos(m\theta_0) \\ \cos(m\theta_1) \\ \cos(m\theta_1) \\ \vdots \\ \cos(m\theta_{d/2-1}) \\ \cos(m\theta_{d/2-1}) \end{pmatrix} + \begin{pmatrix} -x_1 \\ x_0 \\ -x_3 \\ x_2 \\ \vdots \\ -x_{d-1} \\ x_{d-2} \end{pmatrix} \otimes \begin{pmatrix} \sin(m\theta_0) \\ \sin(m\theta_0) \\ \sin(m\theta_1) \\ \sin(m\theta_1) \\ \vdots \\ \sin(m\theta_{d/2-1}) \\ \sin(m\theta_{d/2-1}) \end{pmatrix}

The key property of RoPE is that attention scores depend only on relative positions. The dot product between a rotated query at position m and a rotated key at position n is a function of the offset m - n alone:

q_m^T k_n = g(x_m, x_n, m - n)

Combined with appropriate scaling, this property supports length generalization.
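
The relative-position property can be checked numerically. The sketch below rotates consecutive pairs of a small vector in plain Python and compares scores at two position pairs with the same offset. The helper names and the example vectors are made up for illustration.

```python
import math


def rope_rotate(x: list[float], m: int, base: float = 10000.0) -> list[float]:
    """Rotate each pair (x[2k], x[2k+1]) by the angle m * theta_k."""
    d = len(x)
    out = [0.0] * d
    for k in range(d // 2):
        theta = base ** (-2 * k / d)
        c, s = math.cos(m * theta), math.sin(m * theta)
        out[2 * k] = x[2 * k] * c - x[2 * k + 1] * s
        out[2 * k + 1] = x[2 * k] * s + x[2 * k + 1] * c
    return out


def score(q: list[float], k: list[float], m: int, n: int) -> float:
    """Dot product of a query rotated to position m and a key rotated to position n."""
    return sum(a * b for a, b in zip(rope_rotate(q, m), rope_rotate(k, n)))


q = [0.3, -1.2, 0.8, 0.5]
k = [1.1, 0.4, -0.7, 0.9]
# Both pairs have offset m - n = 3, so the two scores agree up to rounding
print(score(q, k, 5, 2), score(q, k, 105, 102))
```

Shifting both positions by the same amount leaves the score unchanged, while changing the offset changes it.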

RoPE Scaling Methods

Position Interpolation (PI)

Linearly scale positions to fit within the original context window:

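Position Interpolation

m' = m \cdot \frac{L_{\text{original}}}{L_{\text{target}}}

Here,

  • m' = interpolated position fed to RoPE
  • m = original token position
  • L_{\text{original}} = context length used during training
  • L_{\text{target}} = extended target context length
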
import torch

def rope_with_interpolation(dim: int, max_seq_len: int, original_max: int = 4096):
    """RoPE with position interpolation (PI)."""
    # One frequency per 2D pair: theta_k = 10000^(-2k/dim)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    # Compress positions so that max_seq_len maps onto the trained range
    scale = original_max / max_seq_len

    def apply_rope(x, positions):
        # x: (..., seq_len, dim); positions: (seq_len,)
        pos = positions.to(x.device).float() * scale
        angles = pos.unsqueeze(-1) * inv_freq.to(x.device)  # (seq_len, dim / 2)
        cos, sin = torch.cos(angles), torch.sin(angles)

        # Rotate each (even, odd) pair by its position-dependent angle
        x_even, x_odd = x[..., ::2], x[..., 1::2]
        out_even = x_even * cos - x_odd * sin
        out_odd = x_even * sin + x_odd * cos
        # Re-interleave the pairs back into the original layout
        return torch.stack([out_even, out_odd], dim=-1).flatten(-2)

    return apply_rope
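
As a small worked example of the interpolation formula, the sketch below maps positions from a 16,384-token window back into a 4,096-token trained range. The function name and the two lengths are illustrative assumptions.

```python
def interpolate_position(m: float, original_len: int, target_len: int) -> float:
    """Map position m in the extended window back into the trained range."""
    return m * original_len / target_len


original_len, target_len = 4096, 16384
for m in (0, 1, 4096, 16383):
    print(m, "->", interpolate_position(m, original_len, target_len))
```

Every interpolated position stays below 4096, but adjacent tokens are now only 0.25 positions apart, a finer spacing than the model saw during training. This is one reason position interpolation is usually paired with a short fine-tuning run.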

NTK-Aware Scaling

NTK-aware interpolation adjusts the base frequency rather than scaling positions directly:

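NTK-Aware RoPE

\theta'_k = \left(\alpha \cdot 10000\right)^{-2k/d}

Here,

  • \theta'_k = scaled rotation frequency for dimension pair k
  • \alpha = base scaling factor, \alpha = s^{d/(d-2)} with s = L_{\text{target}} / L_{\text{original}}
  • d = dimension of the rotated vector
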
def rope_ntk_aware(dim: int, max_seq_len: int, original_max: int = 4096):
    """NTK-aware RoPE scaling: enlarge the base instead of shrinking positions."""
    scale = max_seq_len / original_max

    # Enlarged base: the lowest frequency is stretched by `scale`,
    # while the highest frequency is left unchanged
    new_base = 10000 * scale ** (dim / (dim - 2))
    inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))

    def apply_rope(x, positions):
        # x: (..., seq_len, dim); positions: (seq_len,)
        pos = positions.to(x.device).float()
        angles = pos.unsqueeze(-1) * inv_freq.to(x.device)  # (seq_len, dim / 2)
        cos, sin = torch.cos(angles), torch.sin(angles)

        # Rotate each (even, odd) pair by its position-dependent angle
        x_even, x_odd = x[..., ::2], x[..., 1::2]
        out_even = x_even * cos - x_odd * sin
        out_odd = x_even * sin + x_odd * cos
        # Re-interleave the pairs back into the original layout
        return torch.stack([out_even, out_odd], dim=-1).flatten(-2)

    return apply_rope
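
To see how the two methods differ, the sketch below compares the ratio of scaled to original frequency for each dimension pair. It reuses the base formula from the code above; the function name and the dim=128, scale=4 setting are illustrative assumptions.

```python
def frequency_ratios(dim: int, scale: float, base: float = 10000.0) -> list[tuple[float, float]]:
    """Return (PI ratio, NTK ratio) of scaled to original frequency for each pair k."""
    new_base = base * scale ** (dim / (dim - 2))
    ratios = []
    for k in range(dim // 2):
        original = base ** (-2 * k / dim)
        ntk = new_base ** (-2 * k / dim)
        pi_ratio = 1.0 / scale  # PI shrinks every frequency by the same factor
        ratios.append((pi_ratio, ntk / original))
    return ratios


ratios = frequency_ratios(dim=128, scale=4.0)
print("highest frequency (k=0):    ", ratios[0])
print("lowest frequency (k=d/2-1): ", ratios[-1])
```

Position interpolation slows every frequency by 1/scale, whereas NTK-aware scaling leaves the highest frequency untouched and slows only the lowest one by the full 1/scale, with a smooth transition in between.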

YaRN (Yet another RoPE extensioN)

YaRN combines multiple techniques for better long-context performance:

YaRN Interpolation Factor

\gamma(m) = 1 - \frac{(1 - \alpha \cdot m / L) \cdot L}{m}

Here,

  • \gamma(m) = interpolation factor at position m
  • \alpha = scaling coefficient
  • m = token position
  • L = context length

YaRN achieves better perplexity than PI or NTK-aware alone by applying different scaling factors to different frequency components. High-frequency components (local patterns) are interpolated less, while low-frequency components (global patterns) are interpolated more.
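
The per-frequency behaviour described above can be sketched as follows. This is not a transcription of the formula in this section: it assumes the ramp usually described as "NTK-by-parts" in the YaRN paper, where each dimension pair is blended between full interpolation and no interpolation according to how many full rotations it completes within the original context. The function name and the thresholds alpha=1 and beta=32 are assumptions, and YaRN's additional attention temperature scaling is omitted.

```python
import math


def yarn_frequencies(dim: int, scale: float, original_len: int,
                     alpha: float = 1.0, beta: float = 32.0,
                     base: float = 10000.0) -> list[float]:
    """Blend interpolated and original RoPE frequencies per dimension pair."""
    freqs = []
    for k in range(dim // 2):
        theta = base ** (-2 * k / dim)
        wavelength = 2 * math.pi / theta
        r = original_len / wavelength  # rotations completed within the trained window
        gamma = min(1.0, max(0.0, (r - alpha) / (beta - alpha)))
        # gamma = 1: keep theta (no interpolation); gamma = 0: full PI (theta / scale)
        freqs.append((1 - gamma) * theta / scale + gamma * theta)
    return freqs


freqs = yarn_frequencies(dim=128, scale=4.0, original_len=4096)
print("highest frequency:", freqs[0])
print("lowest frequency: ", freqs[-1])
```

With these settings the highest frequency is left unchanged and the lowest is divided by the full scale factor, matching the "interpolate high frequencies less, low frequencies more" description.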

ALiBi (Attention with Linear Biases)

ALiBi adds a linear bias to attention scores based on token distance, eliminating the need for positional embeddings entirely.

ALiBi Attention Bias

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}} + B\right)V

Here,

  • Q, K, V = query, key, and value matrices
  • d = head dimension
  • B = bias matrix with entries B_{ij} = -s_h \cdot |i - j|, where s_h is the slope assigned to head h

def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    """Generate ALiBi slopes for attention heads (assumes num_heads is a power of two)."""
    # Geometric sequence 2^(-8/n), 2^(-16/n), ..., 2^(-8)
    ratio = 2 ** (-8 / num_heads)
    slopes = [ratio ** i for i in range(1, num_heads + 1)]
    return torch.tensor(slopes, dtype=torch.float32)

def alibi_bias(max_seq_len: int, num_heads: int) -> torch.Tensor:
    """Create the ALiBi bias tensor of shape (num_heads, max_seq_len, max_seq_len)."""
    slopes = get_alibi_slopes(num_heads)
    positions = torch.arange(max_seq_len)
    # distances[i, j] = j - i; the bias penalizes |i - j| linearly
    distances = positions.unsqueeze(0) - positions.unsqueeze(1)
    bias = -slopes.view(-1, 1, 1) * distances.abs().unsqueeze(0)
    return bias
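
A worked example may help show what the slopes and biases look like. The sketch below repeats the slope formula in plain Python for 8 heads and prints the bias one query receives; the helper names are hypothetical.

```python
def alibi_slopes(num_heads: int) -> list[float]:
    """Geometric sequence of slopes: 2^(-8/n), 2^(-16/n), ..., 2^(-8)."""
    ratio = 2 ** (-8 / num_heads)
    return [ratio ** i for i in range(1, num_heads + 1)]


def alibi_bias_row(slope: float, query_pos: int, seq_len: int) -> list[float]:
    """Bias added to one query's scores against every key position."""
    return [-slope * abs(query_pos - j) for j in range(seq_len)]


slopes = alibi_slopes(8)
print(slopes)
print(alibi_bias_row(slopes[0], query_pos=4, seq_len=5))
```

Heads with large slopes penalize distant tokens heavily and so attend locally, while heads with small slopes can still reach far back in the sequence.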

Long-Context Benchmarks

Needle in a Haystack (NIAH)

Tests the model's ability to retrieve a specific fact from a long context:

def create_needle_in_haystack(
    tokenizer,
    needle: str,
    haystack: str,
    context_length: int,
    needle_position: float  # 0.0 (start) to 1.0 (end)
) -> str:
    """Create a NIAH test case."""
    # Truncate haystack, leaving ~100 tokens of headroom for the needle and question
    tokens_haystack = tokenizer.encode(haystack)[:context_length - 100]
    truncated_haystack = tokenizer.decode(tokens_haystack)
    
    # Insert needle at the requested relative depth (character offset)
    insert_pos = int(len(truncated_haystack) * needle_position)
    context = (
        truncated_haystack[:insert_pos] + 
        " " + needle + " " + 
        truncated_haystack[insert_pos:]
    )
    
    return f"{context}\n\nWhat is the secret? Answer:"

def evaluate_niah(model, tokenizer, context_lengths, needle_positions):
    # load_haystack_text() and generate() are helpers assumed to be defined elsewhere
    haystack = load_haystack_text()
    results = {}
    for ctx_len in context_lengths:
        for pos in needle_positions:
            prompt = create_needle_in_haystack(
                tokenizer,
                needle="The secret code is: ABC-123-XYZ",
                haystack=haystack,
                context_length=ctx_len,
                needle_position=pos
            )
            
            response = generate(model, tokenizer, prompt)
            # Score by exact substring match on the hidden code
            correct = "ABC-123-XYZ" in response
            
            results[(ctx_len, pos)] = correct
    
    return results
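
The results dictionary is keyed by (context length, needle position), so a natural next step is to average over positions for each length. The sketch below shows one way to do that; the function name is hypothetical and the example outcomes are made up purely for illustration.

```python
from collections import defaultdict


def accuracy_by_context_length(results: dict) -> dict:
    """Average the boolean NIAH outcomes over needle positions for each length."""
    hits = defaultdict(list)
    for (ctx_len, _pos), correct in results.items():
        hits[ctx_len].append(correct)
    return {ctx_len: sum(v) / len(v) for ctx_len, v in sorted(hits.items())}


# Made-up outcomes, not measurements from any model
example = {
    (4096, 0.0): True, (4096, 0.5): True,
    (32768, 0.0): True, (32768, 0.5): False,
}
print(accuracy_by_context_length(example))
```

Grouping by needle position instead of context length works the same way and shows whether a model loses facts placed in the middle of the context.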

Long-Context Evaluation Metrics

Benchmark | Context Length | Task          | Metric
NIAH      | Up to 128K     | Retrieval     | Accuracy
LongBench | Up to 32K      | Multiple      | F1, EM
∞Bench    | Up to 100K     | Multiple      | Accuracy
GovReport | Up to 20K      | Summarization | ROUGE
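
For reference, the EM and F1 metrics in the table can be sketched as below. This is a simplified version that only lowercases and splits on whitespace; each benchmark defines its own answer normalization, so the function names and normalization here are assumptions.

```python
from collections import Counter


def exact_match(prediction: str, reference: str) -> bool:
    """True when the answers are identical after trimming and lowercasing."""
    return prediction.strip().lower() == reference.strip().lower()


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token-level precision and recall."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


print(exact_match(" ABC-123 ", "abc-123"))
print(token_f1("the code is ABC-123", "ABC-123"))
```

F1 gives partial credit when the prediction contains the reference answer along with extra words, whereas EM requires the two to match exactly.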

Efficient Attention Mechanisms

Flash Attention

Flash Attention optimizes memory access patterns rather than reducing computation:

Flash Attention reduces memory from O(n²) to O(n) by computing attention in blocks and avoiding materialization of the full attention matrix. It uses tiling and kernel fusion to maximize GPU memory hierarchy utilization.

import torch
from flash_attn import flash_attn_func

def flash_attention_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True
) -> torch.Tensor:
    """
    q, k, v: (batch, seq_len, num_heads, head_dim), fp16 or bf16 on a CUDA device
    Returns: (batch, seq_len, num_heads, head_dim)
    """
    return flash_attn_func(q, k, v, causal=causal)
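
The blockwise computation behind Flash Attention rests on a streaming softmax. The sketch below illustrates that idea for a single query in plain Python: it processes keys and values in blocks, keeps only a running maximum, running sum, and running output, and matches the result of the naive computation. It is a conceptual illustration with hypothetical function names and made-up inputs, not the fused GPU kernel.

```python
import math


def naive_attention(q, keys, values):
    """Single-query attention that materializes every score at once."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def blockwise_attention(q, keys, values, block_size=2):
    """Same result, but only `block_size` scores are held at any time."""
    running_max, running_sum = float("-inf"), 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in k_blk]
        new_max = max(running_max, max(scores))
        rescale = math.exp(running_max - new_max)  # correct the earlier blocks
        weights = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * rescale + sum(weights)
        acc = [a * rescale + sum(w * v[j] for w, v in zip(weights, v_blk))
               for j, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.2, -0.5]
keys = [[1.0, 0.0], [0.5, 0.5], [-1.0, 2.0], [0.3, -0.7], [0.0, 1.0]]
values = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [0.5, 0.5], [2.0, 2.0]]
print(naive_attention(q, keys, values))
print(blockwise_attention(q, keys, values))
```

Because earlier partial results are rescaled whenever a larger score appears, the full row of scores never has to be stored, which is where the O(n) memory figure comes from.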

Ring Attention

Ring Attention distributes long sequences across multiple devices in a ring topology:

Ring Attention Memory per Device

\text{Memory}_{\text{device}} = O\left(\frac{n}{P} \cdot d + \frac{n^2}{P^2} \cdot b\right)

Here,

  • n = total sequence length
  • P = number of devices in the ring
  • d = model (hidden) dimension
  • b = batch size

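import torch
import torch.distributed as dist

def ring_attention_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    group=None
) -> torch.Tensor:
    """
    Non-causal ring attention over a sequence sharded across devices.
    q, k, v: (batch, seq_len, dim); each rank works on its own sequence chunk.
    Peer ranks below are global ranks, so `group` is assumed to span all ranks.
    """
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    chunk_size = q.shape[1] // world_size
    local = slice(rank * chunk_size, (rank + 1) * chunk_size)
    q_chunk = q[:, local]
    k_blk, v_blk = k[:, local].contiguous(), v[:, local].contiguous()

    # Running statistics for a numerically stable streaming softmax
    output = torch.zeros_like(q_chunk)
    row_max = torch.full_like(q_chunk[..., :1], float("-inf"))
    row_sum = torch.zeros_like(row_max)

    for step in range(world_size):
        # Attention of the local queries against the KV block currently held
        scores = q_chunk @ k_blk.transpose(-1, -2) / q.shape[-1] ** 0.5
        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        correction = torch.exp(row_max - new_max)
        probs = torch.exp(scores - new_max)
        output = output * correction + probs @ v_blk
        row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
        row_max = new_max

        if step < world_size - 1:
            # Send KV to the next device, receive from the previous one
            # (a production version would overlap this with the next block's compute)
            k_next, v_next = torch.empty_like(k_blk), torch.empty_like(v_blk)
            nxt, prv = (rank + 1) % world_size, (rank - 1) % world_size
            ops = [
                dist.P2POp(dist.isend, k_blk, nxt, group),
                dist.P2POp(dist.isend, v_blk, nxt, group),
                dist.P2POp(dist.irecv, k_next, prv, group),
                dist.P2POp(dist.irecv, v_next, prv, group),
            ]
            for req in dist.batch_isend_irecv(ops):
                req.wait()
            k_blk, v_blk = k_next, v_next

    return output / row_sum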

Ring Attention makes context lengths of 1M+ tokens feasible by sharding the sequence, and with it the keys and values, across multiple GPUs. Communication overhead can be hidden by overlapping each block transfer with the attention computation on the block already received.
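
The ring communication pattern can be simulated without any GPUs. The sketch below tracks which key/value block each device holds at every step, assuming each device passes its block to the next rank and receives from the previous one; the function name is hypothetical.

```python
def ring_schedule(num_devices: int) -> list[list[int]]:
    """schedule[step][rank] = index of the KV block that `rank` holds at `step`."""
    schedule = []
    holding = list(range(num_devices))  # each rank starts with its own block
    for _ in range(num_devices):
        schedule.append(holding[:])
        # Every rank passes its block to rank + 1 and receives from rank - 1
        holding = [holding[(rank - 1) % num_devices] for rank in range(num_devices)]
    return schedule


for step, blocks in enumerate(ring_schedule(4)):
    print(f"step {step}: {blocks}")
```

After as many steps as there are devices, every device has seen every block exactly once while only ever storing one block at a time.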

Practical: Handling 100K+ Context

class LongContextHandler:
    def __init__(self, model, tokenizer, max_context: int = 128000):
        self.model = model
        self.tokenizer = tokenizer
        self.max_context = max_context
    
    def process_long_document(self, document: str, query: str) -> str:
        """Process a document longer than the context window."""
        chunks = self._chunk_document(document)
        
        # Process chunks with sliding window
        summaries = []
        for chunk in chunks:
            summary = self._summarize_chunk(chunk, query)
            summaries.append(summary)
        
        # Combine summaries
        combined = "\n".join(summaries)
        
        # Final answer with condensed context
        return self._generate_answer(combined, query)
    
    def _chunk_document(self, document: str, chunk_size: int = 4096, overlap: int = 512) -> list:
        tokens = self.tokenizer.encode(document)
        chunks = []
        
        for i in range(0, len(tokens), chunk_size - overlap):
            chunk_tokens = tokens[i:i + chunk_size]
            chunks.append(self.tokenizer.decode(chunk_tokens))
        
        return chunks
    
    def _summarize_chunk(self, chunk: str, query: str) -> str:
        prompt = f"""Summarize the following text, focusing on information relevant to: {query}

Text: {chunk}

Summary:"""
        return self._generate(prompt, max_tokens=512)
    
    def _generate_answer(self, context: str, query: str) -> str:
        prompt = f"""Based on the following context, answer the question.

Context: {context}

Question: {query}

Answer:"""
        return self._generate(prompt, max_tokens=1024)
    
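    def _generate(self, prompt: str, max_tokens: int = 256) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=True,  # temperature only takes effect when sampling
                temperature=0.7
            )
        # Decode only the newly generated tokens, dropping the prompt and special tokens
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)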
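
To see what the overlapping chunking produces, the sketch below computes the token offsets of each chunk using the same stride (chunk_size - overlap) as `_chunk_document`. The function name is hypothetical, and the guard against a non-positive stride is an addition not present in the class above.

```python
def chunk_spans(num_tokens: int, chunk_size: int = 4096, overlap: int = 512) -> list[tuple[int, int]]:
    """Return (start, end) token offsets for overlapping chunks."""
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    stride = chunk_size - overlap
    return [(start, min(start + chunk_size, num_tokens))
            for start in range(0, num_tokens, stride)]


spans = chunk_spans(10_000)
print(spans)
```

Each chunk starts 512 tokens before the previous one ends, so a fact that straddles a chunk boundary still appears whole in at least one chunk.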

Summary

  • Context window is limited by quadratic attention complexity and positional encoding generalization
  • RoPE encodes positions via rotation matrices, enabling relative position awareness
  • Position Interpolation scales positions linearly to fit longer contexts
  • NTK-aware scaling adjusts base frequencies for better long-context performance
  • YaRN interpolates high and low frequencies by different amounts, improving on PI and NTK-aware scaling alone
  • ALiBi uses linear attention biases, eliminating positional embeddings
  • Flash Attention reduces memory from O(n²) to O(n) via memory-efficient tiling
  • Ring Attention distributes long sequences across multiple devices
  • For 100K+ contexts, combine efficient attention with chunking and summarization

Practice Exercises

  1. RoPE Implementation: Implement RoPE from scratch. Verify that attention scores depend only on relative positions.

  2. Position Interpolation: Compare perplexity of a model at 4096, 8192, and 16384 tokens with and without position interpolation.

  3. NTK-Aware Scaling: Implement NTK-aware RoPE scaling and compare with PI at 4x context extension.

  4. NIAH Test: Create a Needle in a Haystack test and evaluate a model at different context lengths and needle positions.

  5. Long Context Pipeline: Build a document QA system that handles 100K+ token documents using chunking and summarization.

