Sparse Attention Patterns — Efficient Long-Range Modeling
For many tasks, full quadratic attention is unnecessary. Sparse attention patterns such as local windows, global tokens, and random connections can reduce complexity to O(n·√n) or O(n) while preserving the ability to model long-range dependencies.
- Local Attention — Focus on nearby tokens with sliding windows
- Global Tokens — Special tokens that attend to all positions
- Longformer & BigBird — Combining patterns for optimal coverage
- Random Attention — Stochastic connections for guaranteed mixing
Sparse attention proves that you don't need to see everything to understand everything.
Sparse Attention Patterns
Full attention computes interactions between all token pairs, but many of these interactions are redundant or unnecessary. Sparse attention patterns selectively compute attention only for a subset of token pairs, dramatically reducing computational cost while maintaining model quality.
Definition: Sparse Attention
Sparse attention restricts each token to attend to only a subset of other tokens, defined by a sparsity pattern. The pattern determines which token pairs compute attention scores, reducing the number of attention connections from O(n²) to O(n·k) where k << n.
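As a minimal sketch of this definition, the function below applies single-head scaled dot-product attention under an arbitrary boolean pattern S. The name `sparse_attention` and the (L, D) tensor layout are illustrative assumptions. It still computes all n² scores and masks the disallowed ones, so it shows what a pattern means rather than how sparse kernels save compute.

```python
import torch
import torch.nn.functional as F


def sparse_attention(q, k, v, pattern):
    """Attention restricted to a boolean sparsity pattern.

    q, k, v: (L, D) tensors; pattern: (L, L) bool, pattern[i, j] = True if
    token i may attend to token j. Every row needs at least one True entry,
    otherwise the softmax over an all -inf row produces NaN.
    """
    scores = q @ k.T / q.shape[-1] ** 0.5
    # Disallowed pairs get -inf, so they receive zero attention weight
    scores = scores.masked_fill(~pattern, float("-inf"))
    return F.softmax(scores, dim=-1) @ v
```

Passing an all-True pattern recovers full attention, and an identity pattern makes every token copy its own value vector, which makes both convenient sanity checks.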
Why Sparse Attention Works
Information Bottleneck Theory
In natural language, most semantic relationships can be captured through local context and a small number of long-range connections. Sparse attention patterns exploit this by maintaining dense local attention while selectively adding long-range connections.
Coverage Guarantees
Attention Coverage
$$\rho = \frac{1}{n^2} \sum_{i=1}^{n} \sum_{j=1}^{n} S_{ij}$$
Here,
- $S$ = Sparsity pattern matrix (n × n)
- $S_{ij}$ = 1 if token i attends to token j, 0 otherwise
- $n$ = Sequence length
- $\rho$ = Coverage ratio
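The coverage ratio can be computed directly from a boolean pattern. The sketch below assumes S is stored as an (n, n) `torch.bool` tensor; the helper name `coverage_ratio` is hypothetical. For a sliding window of radius r = w/2, the number of allowed pairs is n(2r + 1) − r(r + 1), slightly below n(w + 1), because rows near the sequence edges lose part of their window.

```python
import torch


def coverage_ratio(pattern: torch.Tensor) -> float:
    """rho = (1 / n^2) * sum_ij S_ij for a boolean (n, n) sparsity pattern."""
    n = pattern.shape[0]
    return pattern.sum().item() / n ** 2


# Example: sliding window of size w = 64 (radius 32) over n = 1024 tokens
n, w = 1024, 64
idx = torch.arange(n)
local = (idx[:, None] - idx[None, :]).abs() <= w // 2
print(coverage_ratio(local))  # 65504 / 1048576, about 0.0625
```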
Sparsity Benefits
For sequence length n = 8,192:
- Full attention: 8,192² = 67,108,864 connections
- Local window (w = 256): 8,192 × 256 = 2,097,152 connections (3.1% of full)
- BigBird pattern: 8,192 × 512 = 4,194,304 connections (6.25% of full)

The sparse patterns use roughly 94–97% fewer connections than full attention while maintaining model expressivity.
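These counts follow from multiplying the sequence length by the number of keys each query sees. A quick check in Python, assuming a fixed per-token budget and ignoring edge effects at the sequence boundaries (the `connection_stats` helper is illustrative):

```python
def connection_stats(n, per_token):
    """Return (number of attention connections, fraction of full n * n attention)."""
    sparse = n * per_token
    return sparse, sparse / (n * n)


n = 8192
print(f"full: {n * n:,} connections")  # full: 67,108,864 connections
for name, per_token in [("local window (w=256)", 256), ("BigBird pattern", 512)]:
    count, frac = connection_stats(n, per_token)
    print(f"{name}: {count:,} connections, {frac:.2%} of full, {1 - frac:.2%} fewer")
```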
Local (Sliding Window) Attention
Basic Local Attention
Definition: Local Attention
Local attention restricts each token to attend only to a fixed window of nearby tokens, typically w/2 tokens on each side. This captures local context efficiently with O(n·w) complexity.
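Local Attention Window
$$S_{ij} = \begin{cases} 1 & \text{if } |i - j| \le w/2 \\ 0 & \text{otherwise} \end{cases}$$
Here,
- $S_{ij}$ = Sparsity pattern entry
- $w$ = Window size
- $i, j$ = Token positions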
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LocalAttention(nn.Module):
    """Local sliding window attention.

    Reference implementation: it masks a full L x L score matrix, so it shows
    the pattern but not the O(n*w) savings of a banded kernel.
    """
    def __init__(self, d_model, n_heads, window_size=256):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.window_size = window_size
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, L, _ = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)  # each (B, L, n_heads, d_k)
        # Local attention mask (L, L), broadcast over batch and heads
        mask = self._create_local_mask(L, x.device)
        # Scaled scores (B, n_heads, L, L); pairs outside the window get -inf
        scores = torch.einsum("blhd,bmhd->bhlm", q, k) / (self.d_k ** 0.5)
        scores = scores.masked_fill(~mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        output = torch.einsum("bhlm,bmhd->blhd", attn, v)
        return self.out(output.reshape(B, L, -1))

    def _create_local_mask(self, seq_len, device):
        """Sliding window mask: True where |i - j| <= window_size // 2."""
        idx = torch.arange(seq_len, device=device)
        return (idx[:, None] - idx[None, :]).abs() <= self.window_size // 2
```
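A masked dense implementation still allocates an L × L score matrix. To obtain O(n·w) cost, keys and values can instead be gathered into per-query windows. The sketch below does this with `Tensor.unfold`; the function name `banded_local_attention` and its (B, L, H, D) layout are assumptions, and production implementations typically rely on specialized banded or chunked kernels rather than this exact approach.

```python
import torch
import torch.nn.functional as F


def banded_local_attention(q, k, v, window_size):
    """Sliding window attention without an L x L matrix.

    q, k, v: (B, L, H, D). Query i attends to keys j with |i - j| <= window_size // 2.
    """
    B, L, H, D = q.shape
    r = window_size // 2
    # Zero-pad the sequence dimension (dim 1) by r on both sides
    k_pad = F.pad(k, (0, 0, 0, 0, r, r))  # (B, L + 2r, H, D)
    v_pad = F.pad(v, (0, 0, 0, 0, r, r))
    # Windows of length 2r + 1 along dim 1 -> (B, L, H, D, 2r + 1)
    k_win = k_pad.unfold(1, 2 * r + 1, 1)
    v_win = v_pad.unfold(1, 2 * r + 1, 1)
    scores = torch.einsum("blhd,blhdw->blhw", q, k_win) / D ** 0.5
    # Window slot j of query i maps to original position i + j - r; mask padding
    pos = torch.arange(L, device=q.device)[:, None] + torch.arange(-r, r + 1, device=q.device)
    valid = (pos >= 0) & (pos < L)  # (L, 2r + 1)
    scores = scores.masked_fill(~valid[None, :, None, :], float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return torch.einsum("blhw,blhdw->blhd", attn, v_win)
```

The score tensor now has shape (B, L, H, 2·(w // 2) + 1), which grows linearly in L for a fixed window, and the output matches the dense masked version.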
Dilated Local Attention
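Dilated Window
$$S_{ij} = \begin{cases} 1 & \text{if } |i - j| \le \frac{w}{2}\, d \ \text{and}\ (i - j) \bmod d = 0 \\ 0 & \text{otherwise} \end{cases}$$
Here,
- $w$ = Window size
- $d$ = Dilation rate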
Dilated attention increases the receptive field without increasing computation. With dilation d, a window of size w covers a range of w×d tokens, enabling longer-range dependencies with the same number of attention connections.
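A dilated window mask can be built with the same offset trick as the local mask: keep positions whose offset is a multiple of d and lies within (w/2)·d. The helper below is a hypothetical sketch returning a boolean (n, n) pattern. With w = 4, an interior token sees 5 positions whether d = 1 or d = 2, but the dilated version spans twice the distance.

```python
import torch


def dilated_window_mask(seq_len, window_size, dilation, device=None):
    """True where |i - j| <= (window_size // 2) * dilation and (i - j) is a multiple of dilation."""
    idx = torch.arange(seq_len, device=device)
    offset = idx[:, None] - idx[None, :]
    half_span = (window_size // 2) * dilation
    return (offset.abs() <= half_span) & (offset % dilation == 0)


mask = dilated_window_mask(16, window_size=4, dilation=2)
print(mask[8].nonzero().flatten().tolist())  # [4, 6, 8, 10, 12]
```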
Global Attention Tokens
Special Token Approach
Definition: Global Tokens
Global tokens are special positions that attend to all other tokens and are attended to by all tokens. They serve as information bottlenecks, aggregating and distributing information across the sequence.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GlobalAttention(nn.Module):
    """Attention with learned global tokens plus a local window."""
    def __init__(self, d_model, n_heads, n_global=8, window_size=256):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.n_global = n_global
        self.window_size = window_size  # local window for ordinary tokens
        # Global token embeddings (learned)
        self.global_tokens = nn.Parameter(torch.randn(1, n_global, d_model))
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, L, _ = x.shape
        # Prepend global tokens
        global_tokens = self.global_tokens.expand(B, -1, -1)
        x_with_global = torch.cat([global_tokens, x], dim=1)
        L_total = L + self.n_global
        # Compute attention over the extended sequence
        qkv = self.qkv(x_with_global).reshape(B, L_total, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)
        # Pattern: global tokens attend everywhere; ordinary tokens attend to globals + local window
        mask = self._create_global_mask(L, x.device)
        scores = torch.einsum("blhd,bmhd->bhlm", q, k) / (self.d_k ** 0.5)
        scores = scores.masked_fill(~mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        output = torch.einsum("bhlm,bmhd->blhd", attn, v)
        # Drop the global positions so the output keeps the input shape (B, L, d_model)
        output = output[:, self.n_global:]
        return self.out(output.reshape(B, L, -1))

    def _create_global_mask(self, seq_len, device):
        """Global tokens attend to all, ordinary tokens attend to globals + local window."""
        idx = torch.arange(seq_len + self.n_global, device=device)
        # Local window over the extended sequence
        mask = (idx[:, None] - idx[None, :]).abs() <= self.window_size // 2
        # Global tokens attend to everything
        mask[:self.n_global, :] = True
        # Everything attends to global tokens
        mask[:, :self.n_global] = True
        return mask
```
Longformer Architecture
Combining Local + Global
Definition: Longformer
Longformer (Beltagy et al., 2020) combines sliding window attention with a small number of global tokens, achieving O(n) complexity while maintaining the ability to model long-range dependencies through the global tokens.
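Longformer Attention Pattern
$$S = S_{\text{local}} \cup S_{\text{global}} \cup S_{\text{dilated}}$$
Here,
- $S_{\text{local}}$ = Local sliding window
- $S_{\text{global}}$ = Global token connections
- $S_{\text{dilated}}$ = Dilated window pattern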
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LongformerAttention(nn.Module):
    """Longformer-style attention with local, global, and dilated patterns.

    Reference implementation: it masks a dense L x L score matrix, so it shows
    the pattern but not Longformer's memory savings.
    """
    def __init__(self, d_model, n_heads, window_size=512, n_global=1, dilation=1):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.window_size = window_size
        self.n_global = n_global  # the first n_global positions (e.g. [CLS]) are global
        self.dilation = dilation
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, L, _ = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)
        # Combined sparse mask, broadcast over batch and heads
        mask = self._create_longformer_mask(L, x.device)
        scores = torch.einsum("blhd,bmhd->bhlm", q, k) / (self.d_k ** 0.5)
        scores = scores.masked_fill(~mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        output = torch.einsum("bhlm,bmhd->blhd", attn, v)
        return self.out(output.reshape(B, L, -1))

    def _create_longformer_mask(self, seq_len, device):
        """Union of local window, dilated window, and global-token connections."""
        idx = torch.arange(seq_len, device=device)
        offset = idx[:, None] - idx[None, :]
        half = self.window_size // 2
        # Local window: |i - j| <= w/2
        mask = offset.abs() <= half
        # Dilated window: every d-th position, aligned to i, out to distance (w/2) * d
        mask |= (offset.abs() <= half * self.dilation) & (offset % self.dilation == 0)
        # Global tokens attend everywhere and are attended to by every token
        mask[:self.n_global, :] = True
        mask[:, :self.n_global] = True
        return mask
```
BigBird Architecture
Three-Way Sparse Pattern
Definition: BigBird
BigBird (Zaheer et al., 2020) combines local attention, global tokens, and random connections. Motivated by random-graph theory, the random connections keep the shortest path between any two tokens short (logarithmic in n), so information can flow between any two tokens in O(log n) attention hops.
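BigBird Pattern
$$S = S_{\text{window}} \cup S_{\text{global}} \cup S_{\text{random}}, \qquad \text{cost} = O\big(n\,(w + g + r)\big)$$
Here,
- $S_{\text{window}}$ = Sliding window (w connections per token)
- $S_{\text{global}}$ = Global tokens (g connections per token)
- $S_{\text{random}}$ = Random connections (r connections per token)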
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BigBirdAttention(nn.Module):
    """BigBird attention with local + global + random connections.

    Token-level simplification: the BigBird paper applies this pattern to blocks
    of tokens, and this reference version still masks a dense L x L matrix.
    """
    def __init__(self, d_model, n_heads, window_size=256, n_global=64, n_random=64):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.window_size = window_size
        self.n_global = n_global
        self.n_random = n_random
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, L, _ = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)
        mask = self._create_bigbird_mask(L, x.device)
        scores = torch.einsum("blhd,bmhd->bhlm", q, k) / (self.d_k ** 0.5)
        scores = scores.masked_fill(~mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        output = torch.einsum("bhlm,bmhd->blhd", attn, v)
        return self.out(output.reshape(B, L, -1))

    def _create_bigbird_mask(self, seq_len, device):
        """Create BigBird sparse mask (a new random pattern is sampled on every call)."""
        idx = torch.arange(seq_len, device=device)
        # Local window
        mask = (idx[:, None] - idx[None, :]).abs() <= self.window_size // 2
        # Global tokens (first n_global): attend to all and are attended to by all
        mask[:self.n_global, :] = True
        mask[:, :self.n_global] = True
        # Random connections: n_random distinct keys per query
        n_rand = min(self.n_random, seq_len)
        rand_idx = torch.rand(seq_len, seq_len, device=device).topk(n_rand, dim=-1).indices
        mask[idx[:, None], rand_idx] = True
        return mask
```
Theoretical Guarantees
Connectivity and Information Flow
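Information Flow Distance
$$d(i, j) = \min\{\ell : \exists\, k_0 = i, k_1, \dots, k_\ell = j \ \text{with}\ S_{k_{t-1} k_t} = 1 \ \text{for}\ t = 1, \dots, \ell\}$$
Here,
- $d(i, j)$ = Minimum number of attention hops from i to j
- $S_{k_{t-1} k_t}$ = Sparsity pattern edge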
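BigBird's random connections keep the expected distance between any two tokens at O(log n), so information can flow across the entire sequence in a logarithmic number of attention hops, which matters for tasks requiring long-range reasoning. Global tokens shorten paths even further: with at least one global token, any two tokens are at most two hops apart (i → global token → j).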
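The hop distance d(i, j) can be measured directly by breadth-first search over the pattern, treating S_ij = 1 as an edge. The sketch below (the helper name `max_hop_distance` is hypothetical) returns the largest d(i, j) over all pairs, or −1 if some pair is unreachable. For n = 64 and a window radius of 2, local attention needs 32 hops to connect the two ends of the sequence, while a single global token at position 0 brings every pair within 2 hops.

```python
from collections import deque

import torch


def max_hop_distance(pattern: torch.Tensor) -> int:
    """Largest shortest-path length d(i, j) over all pairs; -1 if any pair is unreachable."""
    n = pattern.shape[0]
    # Adjacency lists: edge i -> j wherever S_ij = 1
    neighbors = [torch.nonzero(pattern[i]).flatten().tolist() for i in range(n)]
    worst = 0
    for src in range(n):
        dist = [-1] * n
        dist[src] = 0
        queue = deque([src])
        while queue:  # breadth-first search from src
            u = queue.popleft()
            for v in neighbors[u]:
                if dist[v] < 0:
                    dist[v] = dist[u] + 1
                    queue.append(v)
        if min(dist) < 0:
            return -1
        worst = max(worst, max(dist))
    return worst


n = 64
idx = torch.arange(n)
local = (idx[:, None] - idx[None, :]).abs() <= 2  # window radius 2
with_global = local.clone()
with_global[0, :] = True  # token 0 attends everywhere
with_global[:, 0] = True  # every token attends to token 0
print(max_hop_distance(local), max_hop_distance(with_global))  # 32 2
```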
Theoretical Properties
| Property | Local Only | Local + Global | BigBird |
|---|---|---|---|
| Connected | Yes | Yes | Yes |
| Log diameter | No | Yes | Yes |
| Sub-quadratic cost | Yes | Yes | Yes |
| Expressive enough | No | Yes | Yes |
Sub-Quadratic Transformers
Routing Transformers
Definition: Routing Attention
Routing attention groups tokens into clusters using a learned routing function, then computes attention only within and between clusters, achieving sub-quadratic complexity.
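Routing Complexity
$$O\!\left(nkd + \frac{n^2 d}{k}\right) = O\!\left(n^{1.5} d\right) \quad \text{for } k = \sqrt{n}$$
Here,
- $n$ = Sequence length
- $k$ = Number of clusters
- $d$ = Model dimension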
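A simplified single-head sketch of cluster-restricted attention follows. Several assumptions depart from the Routing Transformer itself: the centroids are given rather than learned with online k-means, each token is routed once using its query (cosine similarity to the nearest centroid), and a token attends only to tokens routed to the same cluster. With k balanced clusters of size n/k, each cluster costs O((n/k)² d), summing to O(n²d/k) across clusters.

```python
import torch
import torch.nn.functional as F


def routing_attention(q, k, v, centroids):
    """Single-head attention restricted to tokens routed to the same cluster.

    q, k, v: (L, D); centroids: (K, D) fixed cluster centers (assumed, not learned here).
    """
    d = q.shape[-1]
    # Route each token to its nearest centroid by cosine similarity of its query
    cluster = (F.normalize(q, dim=-1) @ F.normalize(centroids, dim=-1).T).argmax(dim=-1)
    out = torch.zeros_like(v)
    for c in cluster.unique():
        members = (cluster == c).nonzero().flatten()  # tokens routed to cluster c
        # Dense attention inside the cluster only: O(|members|^2 * d)
        scores = q[members] @ k[members].T / d ** 0.5
        out[members] = F.softmax(scores, dim=-1) @ v[members]
    return out
```

With a single centroid every token lands in one cluster and the function reduces to full attention; every token is always in its own cluster, so no softmax row is empty.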
Reformer (LSH Attention)
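```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LSHAttention(nn.Module):
    """Locality-Sensitive Hashing attention (simplified Reformer-style sketch).

    Uses shared query/key vectors so queries and keys land in the same sorted
    order. Omitted for brevity: causal masking, attending to the neighboring
    chunk, masking attention to self, and weighting hash rounds.
    """
    def __init__(self, d_model, n_heads, n_hashes=8, n_buckets=16, chunk_size=64):
        super().__init__()
        assert n_buckets % 2 == 0, "angular LSH uses n_buckets // 2 random directions"
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.n_hashes = n_hashes
        self.n_buckets = n_buckets
        self.chunk_size = chunk_size
        self.qk = nn.Linear(d_model, d_model)  # shared query/key projection
        self.v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, L, _ = x.shape
        assert L % self.chunk_size == 0, "sequence length must be a multiple of chunk_size"
        qk = self.qk(x).reshape(B, L, self.n_heads, self.d_k).transpose(1, 2)  # (B, H, L, D)
        v = self.v(x).reshape(B, L, self.n_heads, self.d_k).transpose(1, 2)
        # One hash round per iteration, each with fresh random rotations
        outputs = [self._block_attention(qk, v, self._lsh_hash(qk)) for _ in range(self.n_hashes)]
        # Average over hash rounds
        output = torch.stack(outputs).mean(dim=0)  # (B, H, L, D)
        return self.out(output.transpose(1, 2).reshape(B, L, -1))

    def _lsh_hash(self, x):
        """Angular LSH: bucket = argmax([xR, -xR]) for a random matrix R."""
        R = torch.randn(x.shape[-1], self.n_buckets // 2, device=x.device)
        proj = x @ R
        return torch.cat([proj, -proj], dim=-1).argmax(dim=-1)  # (B, H, L)

    def _block_attention(self, qk, v, buckets):
        """Sort by bucket, attend within fixed-size chunks, then undo the sort."""
        B, H, L, D = qk.shape
        c = self.chunk_size
        # Sort by bucket, breaking ties by original position
        sort_idx = (buckets * L + torch.arange(L, device=qk.device)).argsort(dim=-1)
        gather_idx = sort_idx.unsqueeze(-1).expand(-1, -1, -1, D)
        q_s = qk.gather(2, gather_idx).reshape(B, H, L // c, c, D)
        v_s = v.gather(2, gather_idx).reshape(B, H, L // c, c, D)
        k_s = F.normalize(q_s, dim=-1)  # shared-QK: keys are normalized queries
        b_s = buckets.gather(2, sort_idx).reshape(B, H, L // c, c)
        scores = q_s @ k_s.transpose(-1, -2) / D ** 0.5  # (B, H, n_chunks, c, c)
        # Within a chunk, only attend to tokens from the same bucket
        scores = scores.masked_fill(b_s.unsqueeze(-1) != b_s.unsqueeze(-2), float('-inf'))
        out_s = (F.softmax(scores, dim=-1) @ v_s).reshape(B, H, L, D)
        # Scatter results back to their original positions
        return torch.empty_like(out_s).scatter_(2, gather_idx, out_s)
```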
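LSH attention relies on similar vectors sharing a bucket. Reformer uses angular LSH: project a vector with a random matrix R and take the bucket argmax([xR; −xR]). The sketch below estimates collision rates empirically for a nearby and an unrelated vector; `angular_lsh` is a hypothetical helper, and the vector size, noise level, and bucket count are arbitrary choices.

```python
import torch


def angular_lsh(x, n_buckets, n_trials, generator=None):
    """Hash rows of x (N, D) with n_trials independent random matrices -> (n_trials, N) bucket ids."""
    D = x.shape[-1]
    R = torch.randn(n_trials, D, n_buckets // 2, generator=generator)
    proj = torch.einsum("nd,tdb->tnb", x, R)
    return torch.cat([proj, -proj], dim=-1).argmax(dim=-1)


g = torch.Generator().manual_seed(0)
a = torch.randn(64, generator=g)
near = a + 0.1 * torch.randn(64, generator=g)  # small angle to a
far = torch.randn(64, generator=g)  # roughly orthogonal to a
buckets = angular_lsh(torch.stack([a, near, far]), n_buckets=8, n_trials=2000, generator=g)
p_near = (buckets[:, 0] == buckets[:, 1]).float().mean().item()
p_far = (buckets[:, 0] == buckets[:, 2]).float().mean().item()
print(f"P(same bucket): near={p_near:.2f}, far={p_far:.2f}")
```

Nearby vectors collide far more often than unrelated ones, which is why sorting by bucket and attending within chunks approximates attending to the highest-scoring keys.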
Choosing Sparse Patterns
When to Use Each Pattern
| Pattern | Best For | Sequence Length | Quality |
|---|---|---|---|
| Local Only | Text classification, NER | <8K | Good |
| Local + Global | QA, summarization | 8K-128K | Very Good |
| BigBird | Long document understanding | 16K-256K | Excellent |
| LSH | Memory-constrained | 16K-64K | Good |
| Dilated | Audio, time series | 8K-128K | Good |
For most practical applications, Longformer's local + global pattern provides the best balance of efficiency and quality. BigBird's random connections add theoretical guarantees but marginal practical benefit for typical NLP tasks.
Practice Exercises
- Conceptual: Explain why random connections in BigBird are necessary for theoretical guarantees, even though local + global patterns often work well in practice.
- Mathematical: For a sequence of length 32,768, compute the memory required for full attention vs Longformer with window size 512 and 1 global token.
- Practical: Implement a sliding window attention and compare its speed and quality against full attention on a language modeling task.
- Research: Investigate how sparse attention patterns affect the model's ability to perform in-context learning. Does sparsity hurt or help few-shot performance?
Key Takeaways:
- Sparse attention reduces complexity from O(n²) to O(n·k) where k << n
- Local attention captures nearby context efficiently
- Global tokens provide information bottlenecks for long-range reasoning
- BigBird combines local + global + random for theoretical guarantees
- Longformer is the most practical sparse pattern for NLP tasks
- Dilated attention increases receptive field without increasing computation
- Sparse patterns maintain 95%+ of full attention quality for most tasks
What to Learn Next
- RWKV and Linear Attention: O(n) attention through kernel approximation and recurrent formulations.
- Flash Attention and Memory Efficiency: IO-aware attention optimization for modern hardware.
- Hardware-Aware LLM Design: Optimizing model architecture for the GPU memory hierarchy.
- Long Context Window: Scaling context windows to millions of tokens.
- State Space Models: Mamba and linear-time sequence modeling.
- Mixture of Experts: Sparse architectures that scale efficiently.