
Sparse Attention Patterns


Sparse Attention Patterns — Efficient Long-Range Modeling

For many tasks, full quadratic attention is more than the model needs. Sparse attention patterns such as local windows, global tokens, and random connections can reduce complexity to O(n·√n) or O(n) while preserving much of the ability to model long-range dependencies.

  • Local Attention — Focus on nearby tokens with sliding windows
  • Global Tokens — Special tokens that attend to all positions
  • Longformer & BigBird — Combining patterns for optimal coverage
  • Random Attention — Stochastic connections for guaranteed mixing

Sparse attention shows that a model often does not need to see every token pair to understand the whole sequence.

Sparse Attention Patterns

Full attention computes interactions between all token pairs, but many of these interactions are redundant or contribute little. Sparse attention patterns compute attention only for a selected subset of token pairs, substantially reducing computational cost, often with little loss in model quality.

Definition: Sparse Attention

Sparse attention restricts each token to attend to only a subset of other tokens, defined by a sparsity pattern. The pattern determines which token pairs compute attention scores, reducing the number of attention connections from O(n²) to O(n·k) where k << n.
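A minimal sketch of the definition, assuming PyTorch and single-head inputs of shape (L, d): the hypothetical helper `sparse_attention` turns a boolean pattern S into a mask by setting disallowed scores to -inf before the softmax. It reproduces what a sparsity pattern means, not its savings, since all L × L scores are still computed. Every row of S needs at least one allowed key, otherwise that row's softmax becomes NaN.

```python
import torch
import torch.nn.functional as F


def sparse_attention(q, k, v, pattern):
    """Scaled dot-product attention restricted to a boolean pattern S.

    q, k, v: (L, d) tensors; pattern: (L, L) bool, True where S_ij = 1.
    Dense reference: all L x L scores are computed, then masked.
    """
    scores = q @ k.T / q.shape[-1] ** 0.5
    # Disallowed pairs get -inf, so the softmax assigns them zero weight
    scores = scores.masked_fill(~pattern, float("-inf"))
    return F.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
L, d = 6, 4
q, k, v = torch.randn(L, d), torch.randn(L, d), torch.randn(L, d)

# Example pattern S: each token sees itself and its immediate neighbours
idx = torch.arange(L)
S = (idx[:, None] - idx[None, :]).abs() <= 1
out = sparse_attention(q, k, v, S)
print(out.shape)  # torch.Size([6, 4])
```

With an all-ones pattern the function reduces to ordinary attention, and with an identity pattern each token simply returns its own value vector.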

Why Sparse Attention Works

Locality in Natural Language

In natural language, many semantic relationships can be captured through local context plus a small number of long-range connections. Sparse attention patterns exploit this by keeping dense local attention while selectively adding long-range connections.

Coverage Guarantees

Attention Coverage

C(S) = \frac{|\{(i,j) : S_{ij} = 1\}|}{n^2}

Here,

  • S = Sparsity pattern matrix (n × n)
  • S_{ij} = 1 if token i attends to token j, 0 otherwise
  • n = Sequence length
  • C(S) = Coverage ratio
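The coverage ratio is simply the fraction of nonzero entries in S. A short PyTorch sketch, where `coverage` and `local_pattern` are illustrative helper names rather than library functions:

```python
import torch


def coverage(S: torch.Tensor) -> float:
    """C(S) = |{(i, j) : S_ij = 1}| / n^2 for an n x n 0/1 or bool pattern."""
    n = S.shape[0]
    return S.bool().sum().item() / n ** 2


def local_pattern(n: int, w: int) -> torch.Tensor:
    """S_ij = 1 iff |i - j| <= w // 2 (sliding window)."""
    idx = torch.arange(n)
    return (idx[:, None] - idx[None, :]).abs() <= w // 2


print(coverage(torch.ones(8, 8)))     # 1.0 (full attention)
print(coverage(local_pattern(8, 2)))  # 0.34375: 8 diagonal + 14 off-diagonal = 22 of 64
```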

Sparsity Benefits

For sequence length n = 8,192:

  • Full attention: 8,192² = 67,108,864 connections
  • Local window (w = 256): 8,192 × 256 = 2,097,152 connections (about 3.1% of full)
  • BigBird pattern (about 512 keys per token): 8,192 × 512 = 4,194,304 connections (6.25% of full)

These sparse patterns use roughly 94-97% fewer connections, while retaining much of the expressivity of full attention.
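These figures follow from simple counting. The plain-Python sketch below reproduces them, treating each pattern as n queries times a fixed number of keys per query; the 512-key budget for the BigBird row is the figure implied by the arithmetic above, not a fixed BigBird setting.

```python
n = 8_192
full = n * n

# Connections per pattern: n queries times the number of keys each one sees
patterns = {
    "Local window (w=256)": n * 256,
    "BigBird-style (~512 keys/token)": n * 512,
}

print(f"Full attention: {full:,} connections")
for name, conns in patterns.items():
    share = conns / full
    print(f"{name}: {conns:,} connections ({share:.2%} of full, {1 - share:.2%} fewer)")
```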

Local (Sliding Window) Attention

Basic Local Attention

Definition: Local Attention

Local attention restricts each token to attend only to a fixed window of nearby tokens, typically w/2 tokens on each side. This captures local context efficiently with O(n·w) complexity.

Local Attention Window

S_{ij} = \begin{cases} 1 & \text{if } |i - j| \leq w/2 \\ 0 & \text{otherwise} \end{cases}

Here,

  • S_{ij} = Sparsity pattern entry
  • w = Window size
  • i, j = Token positions

import torch
import torch.nn as nn
import torch.nn.functional as F


class LocalAttention(nn.Module):
    """Local sliding window attention (dense reference implementation).

    The full L x L score matrix is still computed and then masked, so this
    version illustrates the pattern but does not reduce compute or memory.
    """
    
    def __init__(self, d_model, n_heads, window_size=256):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.window_size = window_size
        
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, L, _ = x.shape
        
        # (B, L, 3, H, d_k) -> q, k, v each of shape (B, L, H, d_k)
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)
        
        # Boolean (L, L) mask: True where |i - j| <= window_size // 2
        mask = self._create_local_mask(L, x.device)
        
        # Scores (B, H, L, L); masked entries get -inf so softmax gives them zero weight
        scores = torch.einsum("blhd,bmhd->bhlm", q, k) / (self.d_k ** 0.5)
        scores = scores.masked_fill(~mask, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        output = torch.einsum("bhlm,bmhd->blhd", attn, v)
        
        return self.out(output.reshape(B, L, -1))
    
    def _create_local_mask(self, seq_len, device):
        """Create sliding window attention mask (vectorised, no Python loop)."""
        idx = torch.arange(seq_len, device=device)
        return (idx[:, None] - idx[None, :]).abs() <= self.window_size // 2
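The LocalAttention class above builds the full L × L score matrix and masks it, so its cost is still quadratic. The sketch below shows one way to get the O(n·w) behaviour instead: pad the keys and values, then use `Tensor.unfold` to view the 2r + 1 keys each query can see, with r = w // 2. It assumes inputs already split into heads with shape (B, H, L, D); `banded_local_attention` is a hypothetical helper, and production kernels are more elaborate.

```python
import torch
import torch.nn.functional as F


def banded_local_attention(q, k, v, window_size):
    """Sliding-window attention without an L x L score matrix.

    q, k, v: (B, H, L, D). Query i attends to keys j with |i - j| <= r,
    where r = window_size // 2 (the same pattern as LocalAttention's mask).
    """
    B, H, L, D = q.shape
    r = window_size // 2
    # Pad the sequence axis so every query has 2r + 1 key slots
    k_pad = F.pad(k, (0, 0, r, r))                   # (B, H, L + 2r, D)
    v_pad = F.pad(v, (0, 0, r, r))
    # unfold -> (B, H, L, D, 2r + 1): the key window for each query position
    k_win = k_pad.unfold(2, 2 * r + 1, 1)
    v_win = v_pad.unfold(2, 2 * r + 1, 1)

    scores = torch.einsum("bhld,bhldw->bhlw", q, k_win) / D ** 0.5
    # Mask slots that fall in the padding (key positions < 0 or >= L)
    offsets = torch.arange(-r, r + 1, device=q.device)
    key_pos = torch.arange(L, device=q.device)[:, None] + offsets  # (L, 2r + 1)
    valid = (key_pos >= 0) & (key_pos < L)
    scores = scores.masked_fill(~valid, float("-inf"))

    attn = F.softmax(scores, dim=-1)                 # (B, H, L, 2r + 1)
    return torch.einsum("bhlw,bhldw->bhld", attn, v_win)


torch.manual_seed(0)
q, k, v = (torch.randn(2, 3, 10, 8) for _ in range(3))
out = banded_local_attention(q, k, v, window_size=4)
print(out.shape)  # torch.Size([2, 3, 10, 8])
```

The score tensor here has shape (B, H, L, 2r + 1), so its memory grows linearly in L for a fixed window, while the result matches dense masked attention with the same window.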

Dilated Local Attention

Dilated Window

S_{ij} = \begin{cases} 1 & \text{if } |i - j| \leq (w/2) \cdot d \text{ and } (i - j) \bmod d = 0 \\ 0 & \text{otherwise} \end{cases}

Here,

  • w = Window size
  • d = Dilation rate (spacing between attended positions)

Dilated attention increases the receptive field without increasing computation. With dilation d, a window of size w covers a range of w×d tokens, enabling longer-range dependencies with the same number of attention connections.
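A small sketch of the dilated pattern, assuming the window reaches (w/2)·d positions on each side so that it spans about w × d tokens, as described above; `dilated_window_mask` is a hypothetical helper. The number of attended keys per query stays fixed while the span grows with d.

```python
import torch


def dilated_window_mask(n, w, d):
    """S_ij = 1 iff |i - j| <= (w // 2) * d and (i - j) is a multiple of d."""
    idx = torch.arange(n)
    diff = idx[:, None] - idx[None, :]
    return (diff.abs() <= (w // 2) * d) & (diff % d == 0)


n, w = 64, 8
for d in (1, 2, 4):
    S = dilated_window_mask(n, w, d)
    row = S[n // 2].nonzero().flatten()  # keys seen by a query in the middle
    print(f"d={d}: {row.numel()} keys, span {row.min().item()}..{row.max().item()}")
# d=1: 9 keys, span 28..36
# d=2: 9 keys, span 24..40
# d=4: 9 keys, span 16..48
```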

Global Attention Tokens

Special Token Approach

Definition: Global Tokens

Global tokens are special positions that attend to all other tokens and are attended to by all tokens. They serve as information bottlenecks, aggregating and distributing information across the sequence.

import torch
import torch.nn as nn
import torch.nn.functional as F


class GlobalAttention(nn.Module):
    """Attention with learned global tokens plus a local window."""
    
    def __init__(self, d_model, n_heads, n_global=8, window_size=256):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.n_global = n_global
        self.window_size = window_size
        
        # Global token embeddings (learned)
        self.global_tokens = nn.Parameter(torch.randn(1, n_global, d_model))
        
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, L, _ = x.shape
        
        # Prepend global tokens
        global_tokens = self.global_tokens.expand(B, -1, -1)
        x_with_global = torch.cat([global_tokens, x], dim=1)
        L_total = L + self.n_global
        
        # Compute attention
        qkv = self.qkv(x_with_global).reshape(B, L_total, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)
        
        # Pattern: global tokens attend everywhere; sequence tokens attend to globals + local window
        mask = self._create_global_mask(L, x.device)
        
        scores = torch.einsum("blhd,bmhd->bhlm", q, k) / (self.d_k ** 0.5)
        scores = scores.masked_fill(~mask, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        output = torch.einsum("bhlm,bmhd->blhd", attn, v)
        
        # Drop the global-token positions so the output matches the input shape (B, L, d_model)
        output = output[:, self.n_global:]
        return self.out(output.reshape(B, L, -1))
    
    def _create_global_mask(self, seq_len, device):
        """Global tokens attend to all; sequence tokens attend to globals + local window."""
        L_total = seq_len + self.n_global
        idx = torch.arange(L_total, device=device)
        
        # Local window over all positions
        mask = (idx[:, None] - idx[None, :]).abs() <= self.window_size // 2
        # Global tokens attend to everything
        mask[:self.n_global, :] = True
        # Everything attends to global tokens
        mask[:, :self.n_global] = True
        
        return mask
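The GlobalAttention class above prepends learned global tokens. Longformer instead marks a few existing positions as global, for example a [CLS] token for classification or the question tokens in QA. A sketch of that variant's mask combined with a sliding window, where `local_global_mask` is a hypothetical helper:

```python
import torch


def local_global_mask(n, window_size, global_positions):
    """Sliding window plus symmetric global attention on chosen positions.

    Global positions are existing tokens (e.g. [CLS] at index 0): they attend
    to every token, and every token attends to them.
    """
    idx = torch.arange(n)
    mask = (idx[:, None] - idx[None, :]).abs() <= window_size // 2
    is_global = torch.zeros(n, dtype=torch.bool)
    is_global[list(global_positions)] = True
    # Open the full row and column of every global position
    mask |= is_global[:, None] | is_global[None, :]
    return mask


S = local_global_mask(n=16, window_size=4, global_positions=[0])
print(S.sum().item())  # 100: 74 window entries plus 13 new in row 0 and 13 new in column 0
```

The number of allowed pairs grows as roughly n·w plus 2·g·n for g global positions, which is linear in n when w and g are fixed.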

Longformer Architecture

Combining Local + Global

Definition: Longformer

Longformer (Beltagy et al., 2020) combines sliding window attention with a small number of global tokens, achieving O(n) complexity while maintaining the ability to model long-range dependencies through the global tokens.

Longformer Attention Pattern

S_{ij}^{\text{Long}} = \underbrace{S_{ij}^{\text{local}}}_{\text{window}} \lor \underbrace{S_{ij}^{\text{global}}}_{\text{global tokens}} \lor \underbrace{S_{ij}^{\text{dilated}}}_{\text{dilated window}}

Here,

  • S_{ij}^{local} = Local sliding window
  • S_{ij}^{global} = Global token connections
  • S_{ij}^{dilated} = Dilated window pattern

import torch
import torch.nn as nn
import torch.nn.functional as F


class LongformerAttention(nn.Module):
    """Longformer-style attention: local window + dilated window + global tokens.

    Simplifications: one shared Q/K/V projection (the paper uses separate
    projections for global attention) and a dense, masked L x L score matrix.
    """
    
    def __init__(self, d_model, n_heads, window_size=512, n_global=1, dilation=1):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.window_size = window_size
        # Global attention on the first n_global positions (e.g. the [CLS] token)
        self.n_global = n_global
        self.dilation = dilation
        
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, L, _ = x.shape
        
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)
        
        # Create combined sparse mask
        mask = self._create_longformer_mask(L, x.device)
        
        scores = torch.einsum("blhd,bmhd->bhlm", q, k) / (self.d_k ** 0.5)
        scores = scores.masked_fill(~mask, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        output = torch.einsum("bhlm,bmhd->blhd", attn, v)
        
        return self.out(output.reshape(B, L, -1))
    
    def _create_longformer_mask(self, seq_len, device):
        """Create Longformer sparse attention mask (vectorised)."""
        idx = torch.arange(seq_len, device=device)
        diff = idx[:, None] - idx[None, :]
        half = self.window_size // 2
        
        # Local window
        mask = diff.abs() <= half
        
        # Dilated window: every dilation-th offset, reaching half * dilation tokens away
        mask |= (diff.abs() <= half * self.dilation) & (diff % self.dilation == 0)
        
        # Global tokens: attend to all positions and are attended by all
        mask[:self.n_global, :] = True
        mask[:, :self.n_global] = True
        
        return mask

BigBird Architecture

Three-Way Sparse Pattern

Definition: BigBird

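BigBird (Zaheer et al., 2020) combines local attention, global tokens, and random connections. The random connections draw on random-graph theory: with a few random edges per token, the shortest path between two tokens is typically only logarithmic in n, so information can travel between any two tokens in O(log n) steps. BigBird's theoretical results (universal approximation and Turing completeness) also rely on the global tokens.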

BigBird Pattern

S^{\text{BigBird}} = S^{\text{local}} \lor S^{\text{global}} \lor S^{\text{random}}

Here,

  • S^{local} = Sliding window (w connections per token)
  • S^{global} = Global tokens (g connections per token)
  • S^{random} = Random connections (r connections per token)

import torch
import torch.nn as nn
import torch.nn.functional as F


class BigBirdAttention(nn.Module):
    """BigBird-style attention with local + global + random connections.

    Token-level, dense reference version: the paper applies the pattern to
    blocks of tokens and never computes the masked-out scores.
    """
    
    def __init__(self, d_model, n_heads, window_size=256, n_global=64, n_random=64):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.window_size = window_size
        self.n_global = n_global
        self.n_random = n_random
        
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, L, _ = x.shape
        
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)
        
        mask = self._create_bigbird_mask(L, x.device)
        
        scores = torch.einsum("blhd,bmhd->bhlm", q, k) / (self.d_k ** 0.5)
        scores = scores.masked_fill(~mask, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        output = torch.einsum("bhlm,bmhd->blhd", attn, v)
        
        return self.out(output.reshape(B, L, -1))
    
    def _create_bigbird_mask(self, seq_len, device):
        """Create BigBird sparse mask; random edges are resampled on every call."""
        idx = torch.arange(seq_len, device=device)
        
        # Local window
        mask = (idx[:, None] - idx[None, :]).abs() <= self.window_size // 2
        
        # Global tokens (first n_global): attend to all and are attended by all
        mask[:self.n_global, :] = True
        mask[:, :self.n_global] = True
        
        # Random connections: n_random distinct keys per query, all on the same device
        n_random = min(self.n_random, seq_len)
        random_idx = torch.rand(seq_len, seq_len, device=device).argsort(dim=-1)[:, :n_random]
        mask[idx[:, None], random_idx] = True
        
        return mask
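The class above samples connections token by token. The BigBird paper applies the same three components to blocks of tokens, so that the computation maps onto dense block matrix multiplies. The sketch below builds such a block-level pattern and expands it to token level; the block counts (3 window blocks, 1 global block, 2 random blocks) are illustrative assumptions, and `bigbird_block_mask` is a hypothetical helper.

```python
import torch


def bigbird_block_mask(n_blocks, block_size, n_window=3, n_global=1, n_random=2, seed=0):
    """Block-level BigBird-style pattern expanded to a token-level bool mask.

    Illustrative layout: each query block sees n_window neighbouring blocks,
    the first n_global blocks are global, and n_random extra key blocks are
    sampled for every non-global query block.
    """
    g = torch.Generator().manual_seed(seed)
    idx = torch.arange(n_blocks)
    blocks = (idx[:, None] - idx[None, :]).abs() <= n_window // 2  # sliding window
    blocks[:n_global, :] = True                                     # global rows
    blocks[:, :n_global] = True                                     # global columns
    for i in range(n_global, n_blocks):
        rand = torch.randperm(n_blocks, generator=g)[:n_random]
        blocks[i, rand] = True
    # Expand every block entry into a block_size x block_size tile
    return blocks.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)


M = bigbird_block_mask(n_blocks=8, block_size=4)
print(M.shape)  # torch.Size([32, 32])
```

Working at block granularity trades some flexibility in which pairs are connected for computations that run efficiently on accelerators.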

Theoretical Guarantees

Connectivity and Information Flow

Information Flow Distance

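d(i,j) = \min \left\{ m : \exists\, p_0 = i, p_1, \dots, p_m = j \text{ such that } S_{p_k p_{k+1}} = 1 \text{ for all } 0 \le k < m \right\}

Here,

  • d(i,j) = Minimum number of attention hops (layers) needed to connect i and j
  • S_{p_k p_{k+1}} = Sparsity pattern edge between consecutive tokens p_k and p_{k+1} on the path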

In random graphs, the shortest path between two nodes is typically O(log n) long, so BigBird's random connections let information flow across the entire sequence in a logarithmic number of hops. Global tokens shorten this further: any two tokens are at most two hops apart through a global token. Short paths matter for tasks that require long-range reasoning.
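The hop counts can be checked directly by treating the pattern as a graph and running breadth-first search. A sketch with a hypothetical `max_hops` helper, which symmetrises S and reports the largest shortest-path length (the graph diameter), or -1 if the graph is disconnected:

```python
from collections import deque

import torch


def max_hops(S: torch.Tensor) -> int:
    """Diameter of the (symmetrised) attention graph, via BFS from every node."""
    n = S.shape[0]
    adj = (S | S.T).tolist()
    neighbours = [[j for j in range(n) if adj[i][j] and j != i] for i in range(n)]
    worst = 0
    for src in range(n):
        dist = [-1] * n
        dist[src] = 0
        queue = deque([src])
        while queue:
            u = queue.popleft()
            for v in neighbours[u]:
                if dist[v] < 0:
                    dist[v] = dist[u] + 1
                    queue.append(v)
        if min(dist) < 0:
            return -1  # some token is unreachable
        worst = max(worst, max(dist))
    return worst


n, w = 64, 8
idx = torch.arange(n)
local = (idx[:, None] - idx[None, :]).abs() <= w // 2

local_global = local.clone()
local_global[0, :] = True  # token 0 is global
local_global[:, 0] = True

g = torch.Generator().manual_seed(0)
rand = torch.zeros(n, n, dtype=torch.bool)
for i in range(n):
    rand[i, torch.randperm(n, generator=g)[:2]] = True  # 2 random keys per token

print(max_hops(local))         # 16
print(max_hops(local_global))  # 2
print(max_hops(local | rand))  # seed-dependent; adding edges never lengthens shortest paths
```

For n = 64 and w = 8, the window alone needs 16 hops to connect the two ends of the sequence, while a single global token brings every pair within 2 hops.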

Theoretical Properties

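| Property | Local only | Local + Global | BigBird |
|---|---|---|---|
| Connected | Yes (but diameter grows as O(n/w)) | Yes | Yes |
| O(log n) diameter | No | Yes (at most 2 hops via a global token) | Yes |
| Time complexity reduction | Yes | Yes | Yes |
| Expressive enough | No | Yes | Yes |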

Linear Complexity Transformers

Routing Transformers

Definition: Routing Attention

Routing attention groups tokens into clusters using online k-means, then computes attention only among tokens in the same cluster, achieving sub-quadratic complexity.

Routing Complexity

O\left(\frac{n^2 d}{k} + n \cdot k \cdot d\right)

Here,

  • n = Sequence length
  • k = Number of clusters
  • d = Model dimension
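Balancing the two terms suggests choosing k ≈ √n. A short Python check, assuming the cost is n²d/k (attention inside k clusters of about n/k tokens each) plus nkd (comparing every token with k centroids), with constants ignored; `routing_cost` is a hypothetical helper.

```python
import math


def routing_cost(n, k, d):
    """Cost model n^2 d / k + n k d: within-cluster attention plus centroid assignment."""
    return n * n * d / k + n * k * d


n, d = 4096, 64
best_k = min(range(1, n + 1), key=lambda k: routing_cost(n, k, d))
print(best_k, math.isqrt(n))                     # 64 64
print(routing_cost(n, best_k, d) / (n * n * d))  # 0.03125, i.e. 2 / sqrt(n) of n^2 d
```

At k = √n both terms equal n^{1.5}·d, so the total cost is O(n^{1.5}·d).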

Reformer (LSH Attention)

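import torch
import torch.nn as nn
import torch.nn.functional as F


class LSHAttention(nn.Module):
    """Simplified Reformer-style LSH attention (shared-QK, chunked buckets).

    Simplifications relative to Reformer: no attention to the previous chunk,
    no causal mask, and hash rounds are averaged rather than combined using
    their softmax normalisers.
    """
    
    def __init__(self, d_model, n_heads, n_hashes=4, n_buckets=8, chunk_size=64):
        super().__init__()
        assert n_buckets % 2 == 0, "n_buckets must be even"
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.n_hashes = n_hashes
        self.n_buckets = n_buckets
        self.chunk_size = chunk_size
        
        # Shared-QK: queries and keys come from the same projection
        self.qk = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, L, _ = x.shape
        # (B, L, d_model) -> (B, H, L, d_k)
        qk = self.qk(x).reshape(B, L, self.n_heads, self.d_k).transpose(1, 2)
        v = self.v(x).reshape(B, L, self.n_heads, self.d_k).transpose(1, 2)
        
        # One independent hash round per iteration, then average over rounds
        outputs = [self._bucketed_attention(qk, v, self._lsh_hash(qk))
                   for _ in range(self.n_hashes)]
        output = torch.stack(outputs).mean(dim=0)
        
        return self.out(output.transpose(1, 2).reshape(B, L, -1))
    
    def _lsh_hash(self, x):
        """Angular LSH: bucket = argmax([xR, -xR]) for a random projection R."""
        R = torch.randn(x.shape[-1], self.n_buckets // 2, device=x.device)
        proj = x @ R
        return torch.cat([proj, -proj], dim=-1).argmax(dim=-1)  # (B, H, L)
    
    def _bucketed_attention(self, qk, v, buckets):
        """Sort tokens by bucket, split into chunks, attend within each chunk."""
        B, H, L, D = qk.shape
        c = self.chunk_size
        assert L % c == 0, "sequence length must be a multiple of chunk_size"
        
        # Sort by (bucket, position) so each bucket is contiguous and ordered
        pos = torch.arange(L, device=qk.device)
        sort_idx = (buckets * L + pos).argsort(dim=-1)          # (B, H, L)
        gather_idx = sort_idx.unsqueeze(-1).expand(-1, -1, -1, D)
        
        q_s = qk.gather(2, gather_idx).reshape(B, H, L // c, c, D)
        k_s = F.normalize(q_s, dim=-1)                           # shared-QK uses unit-norm keys
        v_s = v.gather(2, gather_idx).reshape(B, H, L // c, c, D)
        b_s = buckets.gather(2, sort_idx).reshape(B, H, L // c, c)
        
        scores = torch.einsum("bhncd,bhnmd->bhncm", q_s, k_s) / (D ** 0.5)
        # Within a chunk, only keys from the same bucket are visible
        same_bucket = b_s.unsqueeze(-1) == b_s.unsqueeze(-2)
        scores = scores.masked_fill(~same_bucket, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        out_s = torch.einsum("bhncm,bhnmd->bhncd", attn, v_s).reshape(B, H, L, D)
        
        # Undo the sort so outputs line up with the original token order
        return torch.empty_like(out_s).scatter_(2, gather_idx, out_s)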
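Reformer hashes vectors with angular LSH: project onto random directions and take the argmax over [xR, -xR], so vectors pointing in similar directions tend to share a bucket. A standalone sketch, with `angular_lsh` as an illustrative name and a fixed seed for reproducibility, checking two properties that follow from the construction:

```python
import torch


def angular_lsh(x, n_buckets, seed=0):
    """Bucket id = argmax([xR, -xR]) for a random Gaussian projection R (d x n_buckets/2)."""
    g = torch.Generator().manual_seed(seed)
    R = torch.randn(x.shape[-1], n_buckets // 2, generator=g)
    proj = x @ R
    return torch.cat([proj, -proj], dim=-1).argmax(dim=-1)


x = torch.randn(5, 16, generator=torch.Generator().manual_seed(1))
b = angular_lsh(x, n_buckets=8)

# The hash depends only on direction: scaling a vector keeps its bucket...
assert torch.equal(angular_lsh(2.0 * x, n_buckets=8), b)
# ...and flipping it moves it to the "opposite" bucket, offset by n_buckets / 2
assert torch.equal(angular_lsh(-x, n_buckets=8), (b + 4) % 8)
print(b.tolist())
```

Because only bucket membership matters, queries and keys must be hashed with the same projection R, which is one reason shared-QK attention pairs naturally with this scheme.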

Choosing Sparse Patterns

When to Use Each Pattern

| Pattern | Best for | Typical sequence length | Quality |
|---|---|---|---|
| Local only | Text classification, NER | < 8K | Good |
| Local + Global | QA, summarization | 8K-128K | Very good |
| BigBird | Long document understanding | 16K-256K | Excellent |
| LSH | Memory-constrained settings | 16K-64K | Good |
| Dilated | Audio, time series | 8K-128K | Good |

In many practical applications, Longformer's local + global pattern offers a good balance of efficiency and quality. BigBird's random connections add theoretical guarantees but often bring only marginal practical benefit on typical NLP tasks.

Practice Exercises

  1. Conceptual: Explain why random connections in BigBird are necessary for theoretical guarantees, even though local + global patterns often work well in practice.

  2. Mathematical: For a sequence of length 32,768, compute the memory required for full attention vs Longformer with window size 512 and 1 global token.

  3. Practical: Implement a sliding window attention and compare its speed and quality against full attention on a language modeling task.

  4. Research: Investigate how sparse attention patterns affect the model's ability to perform in-context learning. Does sparsity hurt or help few-shot performance?

Key Takeaways:

  • Sparse attention reduces complexity from O(n²) to O(n·k) where k << n
  • Local attention captures nearby context efficiently
  • Global tokens provide information bottlenecks for long-range reasoning
  • BigBird combines local + global + random for theoretical guarantees
  • Longformer's local + global pattern is often the most practical sparse choice for NLP tasks
  • Dilated attention increases receptive field without increasing computation
  • Well-chosen sparse patterns often come close to full-attention quality on many tasks

What to Learn Next

-> RWKV and Linear Attention: O(n) attention through kernel approximation and recurrent formulations.

-> Flash Attention and Memory Efficiency: IO-aware attention optimization for modern hardware.

-> Hardware-Aware LLM Design: Optimizing model architecture for the GPU memory hierarchy.

-> Long Context Window: Scaling context windows to millions of tokens.

-> State Space Models: Mamba and linear-time sequence modeling.

-> Mixture of Experts: Sparse architectures that scale efficiently.
