
RWKV and Linear Attention


RWKV and Linear Attention — Scaling Beyond Quadratic Complexity

Transformers achieve remarkable performance but suffer from O(n²) attention complexity. RWKV and linear attention variants offer O(n) alternatives that maintain competitive quality while enabling dramatically longer context windows.

  • RWKV: recurrent linear attention built from time-mixing (WKV) and channel-mixing blocks
  • Linear Attention: kernel feature maps that replace or approximate softmax for O(n) attention
  • Token Shifting: local positional information from mixing each token with its predecessor
  • Production Trade-offs: when linear attention wins over standard transformers

Efficiency is not about doing less—it's about doing more with the same resources.

RWKV and Linear Attention

Standard softmax attention computes pairwise interactions between all tokens, resulting in quadratic complexity. Linear attention variants reformulate attention as a kernel operation, enabling linear-time computation. RWKV combines these ideas with a recurrent formulation for efficient autoregressive inference.

Definition: Linear Attention

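Linear attention replaces the softmax similarity exp(qᵀk/√d_k) with a kernel ϕ(q)ᵀϕ(k) built from a feature map ϕ. This kernel factorizes into a query part and a key part, so the attention computation can be reordered as a matrix product whose cost is linear in sequence length: O(n·d²) instead of O(n²·d).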

The Attention Bottleneck

Quadratic vs Linear Complexity

Standard Attention Complexity

$$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Here,

  • Q = Query matrix (n × d_k)
  • K = Key matrix (n × d_k)
  • V = Value matrix (n × d_v)
  • d_k = Key dimension

The standard attention mechanism requires computing the n×n attention matrix QK^T, which costs O(n²·d_k) FLOPs and O(n²) memory.
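
As a reference point, here is a minimal sketch of the softmax attention equation for a single head with no batching. It is written in NumPy rather than PyTorch to keep it dependency-light, and `softmax_attention` is a hypothetical helper name. The `(n, n)` `weights` array is exactly the matrix whose size drives the O(n²) cost.

```python
import numpy as np


def softmax_attention(Q, K, V):
    """Scaled dot-product attention for one head. Q, K: (n, d_k); V: (n, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # (n, n): the quadratic part
    scores -= scores.max(axis=-1, keepdims=True)    # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ V, weights


rng = np.random.default_rng(0)
n, d_k, d_v = 6, 4, 3
Q, K, V = rng.normal(size=(n, d_k)), rng.normal(size=(n, d_k)), rng.normal(size=(n, d_v))
out, weights = softmax_attention(Q, K, V)
print(out.shape, weights.shape)  # (6, 3) (6, 6)
```

Doubling n quadruples the size of `weights`. That matrix is the bottleneck the rest of this lesson works around.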

Linear Attention Formulation

$$\text{LinAttn}(Q,K,V)_i = \frac{\sum_{j=1}^{n} \phi(q_i)^T \phi(k_j)\, v_j}{\sum_{j=1}^{n} \phi(q_i)^T \phi(k_j)}$$

Here,

  • ϕ = Feature map (e.g., ELU + 1)
  • q_i = Query at position i
  • k_j = Key at position j
  • v_j = Value at position j

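The key insight is associativity. Each score ϕ(q_i)ᵀϕ(k_j) factorizes, so the numerator can be written as ϕ(q_i)ᵀ(Σ_j ϕ(k_j)v_jᵀ) and the denominator as ϕ(q_i)ᵀ(Σ_j ϕ(k_j)). Computing ϕ(K)ᵀV first (a d×d matrix shared by all queries) and then multiplying by ϕ(Q) avoids the n×n attention matrix entirely. This reduces complexity from O(n²·d) to O(n·d²). Softmax attention cannot be reordered this way, because exp(q_iᵀk_j) does not split into finite-dimensional query and key factors.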
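
The reordering can be checked numerically. The sketch below is single-head, non-causal NumPy, and the function names are hypothetical. It computes linear attention twice with the ELU + 1 feature map: once with the explicit n×n kernel matrix, and once with ϕ(K)ᵀV computed first.

```python
import numpy as np


def elu_plus_one(x):
    """phi(x) = ELU(x) + 1, a positive feature map."""
    return np.where(x > 0, x + 1.0, np.exp(x))


def linear_attention_quadratic(Q, K, V, phi=elu_plus_one):
    """Left-to-right order: builds the (n, n) kernel matrix explicitly."""
    A = phi(Q) @ phi(K).T                      # (n, n)
    return (A @ V) / A.sum(axis=1, keepdims=True)


def linear_attention_linear(Q, K, V, phi=elu_plus_one):
    """Reordered: phi(K)^T V is (d, d_v), so no (n, n) matrix is formed."""
    q, k = phi(Q), phi(K)
    kv = k.T @ V                               # (d, d_v), shared by every query
    z = k.sum(axis=0)                          # (d,), shared normalizer
    return (q @ kv) / (q @ z)[:, None]


rng = np.random.default_rng(0)
n, d, d_v = 8, 4, 3
Q, K, V = rng.normal(size=(n, d)), rng.normal(size=(n, d)), rng.normal(size=(n, d_v))
print(np.allclose(linear_attention_quadratic(Q, K, V), linear_attention_linear(Q, K, V)))  # True
```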

Complexity Comparison

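| Mechanism | FLOPs (full sequence) | Memory | Inference cost per new token |
|---|---|---|---|
| Standard Attention | O(n²d) | O(n²) attention matrix (O(nd) KV cache when decoding) | O(nd) with a KV cache |
| Linear Attention | O(nd²) | O(d²) recurrent state | O(d²) |
| RWKV | O(nd²) | O(d) recurrent state per layer (RWKV-4) | O(d²) |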

Practical Complexity

For sequence length n = 100,000 and dimension d = 128:

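Standard Attention: n²·d = 100,000² × 128 ≈ 1.28 × 10¹² FLOPs
Linear Attention: n·d² = 100,000 × 128² ≈ 1.64 × 10⁹ FLOPs

Speedup: ~780× fewer FLOPs for long sequences. The ratio is simply n/d = 100,000 / 128 ≈ 781. These are leading-order counts that ignore constant factors and the Q/K/V projections.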
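
The arithmetic above is easy to reproduce. This sketch counts only the leading-order terms n²·d and n·d², as the text does; `attention_flops` is a hypothetical helper name.

```python
def attention_flops(n, d):
    """Leading-order FLOP terms from the text: n^2 * d (softmax) and n * d^2 (linear)."""
    return n * n * d, n * d * d


quadratic, linear = attention_flops(100_000, 128)
print(f"{quadratic:.3g} vs {linear:.3g}, ratio {quadratic / linear:.2f}")
# 1.28e+12 vs 1.64e+09, ratio 781.25

print(attention_flops(64, 128))  # (524288, 1048576): short sequence, quadratic form is cheaper
```

Because the ratio is n/d, the reordering only pays off once the sequence is longer than the head dimension. For n = 64 and d = 128 the quadratic form is the cheaper of the two.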

RWKV Architecture

Overview

Definition: RWKV

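RWKV (Receptance Weighted Key Value) is a linear attention architecture that combines the parallelizable training of transformers with the constant-memory inference of RNNs. Each block contains two sub-blocks, both of which read token-shifted inputs:

  • a time-mixing sub-block, built around the WKV (weighted key-value) operator, which plays the role of attention
  • a channel-mixing sub-block, which plays the role of the feed-forward network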

Core Components

import torch
import torch.nn as nn


class RWKVChannelMix(nn.Module):
    """RWKV-4-style channel mixing (the FFN analogue) with token shifting.

    The WKV operator belongs to the time-mixing block, not to this one.
    """

    def __init__(self, d_model, d_ffn, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.d_model = d_model

        # Per-channel time-mixing weights (mu); 1.0 = use only x_t at initialization
        self.time_mix_r = nn.Parameter(torch.ones(d_model))
        self.time_mix_k = nn.Parameter(torch.ones(d_model))

        # Linear projections: key expands to d_ffn, value projects back to d_model
        self.key = nn.Linear(d_model, d_ffn, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_ffn, d_model, bias=False)

        # Pad one zero step at the start of the time axis and crop the last step,
        # so position t receives x_{t-1}; expects x of shape (batch, seq_len, d_model)
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

    def forward(self, x):
        """Causal forward pass: each position uses only x_t and x_{t-1}."""
        xx = self.time_shift(x)  # Shift by 1 position

        # Time-mixed inputs
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)

        # Squared-ReLU feed-forward on the key path
        k = torch.square(torch.relu(self.key(xk)))
        kv = self.value(k)

        # Sigmoid receptance gates how much of the FFN output passes through
        r = torch.sigmoid(self.receptance(xr))
        return r * kv

Time-Mixing Mechanism

Definition: Time-Mixing

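Time-mixing in RWKV (implemented as a token shift) blends the current input with the previous token's input using learnable per-channel weights. It gives every layer direct access to the previous token. Together with the position-dependent decay inside WKV, this lets RWKV work without explicit positional encodings.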

Time-Mixing Operation

$$x_t^{\text{mixed}} = \mu \odot x_t + (1 - \mu) \odot x_{t-1}$$

Here,

  • μ = Mixing parameter (learned per channel)
  • x_t = Current input
  • x_{t-1} = Previous input (time-shifted)
  • ⊙ = Element-wise multiplication

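import torch
import torch.nn as nn


class RWKVTokenShift(nn.Module):
    """Time-shift operation: mix each position with its predecessor, per channel."""

    def __init__(self, d_model):
        super().__init__()
        # Pads one zero step at the front of the time axis and crops the last step,
        # so the output keeps length seq_len and position t holds x_{t-1}
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.mix = nn.Parameter(torch.ones(d_model))  # mu, one value per channel

    def forward(self, x):
        """Shift tokens by one position and mix."""
        # x: (batch, seq_len, d_model)
        xx = self.time_shift(x)  # Shift right; already length seq_len, no extra slicing
        return x * self.mix + xx * (1 - self.mix)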
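
To see what the shift and mix do to actual numbers, here is a framework-free NumPy sketch of the same operation. The helper names `token_shift` and `time_mix` are hypothetical. It uses three channels with μ = 1, 0.5 and 0, so each extreme of the interpolation is visible.

```python
import numpy as np


def token_shift(x):
    """Return x_{t-1} at every position t (zeros at t = 0). x: (batch, seq_len, d_model)."""
    shifted = np.zeros_like(x)
    shifted[:, 1:, :] = x[:, :-1, :]
    return shifted


def time_mix(x, mu):
    """mu * x_t + (1 - mu) * x_{t-1}, with one mu per channel (shape (d_model,))."""
    return mu * x + (1 - mu) * token_shift(x)


x = np.arange(12, dtype=float).reshape(1, 4, 3)  # batch 1, 4 tokens, 3 channels
mu = np.array([1.0, 0.5, 0.0])                   # current only / half-half / previous only
print(time_mix(x, mu)[0])
```

Each channel behaves differently:

  • The first channel passes x_t through unchanged.
  • The last channel outputs x_{t-1} (zero at t = 0).
  • The middle channel averages the two.

For the input 0..11 the mixed rows are [0, 0.5, 0], [3, 2.5, 2], [6, 5.5, 5] and [9, 8.5, 8].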

WKV Attention

WKV (Weighted Key-Value) Attention

$$\text{WKV}_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i}\, v_i + e^{u + k_t}\, v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u + k_t}}$$

Here,

  • w = Decay parameter (learned per channel)
  • k_i = Key at position i
  • v_i = Value at position i
  • u = Bonus parameter for current token
  • t = Current time step

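The WKV mechanism is essentially linear attention with exponential decay, applied independently to each channel. The decay parameter w controls how quickly past information is forgotten, and the bonus parameter u gives extra weight to the current token. Each step multiplies the weight of every past token by the same factor e^{-w}. The two sums can therefore be carried forward as a running numerator and denominator, so autoregressive inference needs only a fixed-size state per layer.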
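
The recurrence can be made concrete with a NumPy sketch covering one sequence with C channels; the function names are hypothetical. It evaluates the WKV formula two ways and checks that they agree:

  • directly, in O(T²)
  • with a running numerator `a` and denominator `b`, in O(T), using the updates a ← e^{-w}·a + e^{k_t}·v_t and b ← e^{-w}·b + e^{k_t}

Production RWKV kernels also track a running maximum exponent to avoid overflow in `exp`. That is omitted here.

```python
import numpy as np


def wkv_direct(k, v, w, u):
    """O(T^2) evaluation of the WKV formula. k, v: (T, C); w, u: (C,)."""
    T, C = k.shape
    out = np.zeros((T, C))
    for t in range(T):
        # Current token: bonus weight exp(u + k_t)
        num = np.exp(u + k[t]) * v[t]
        den = np.exp(u + k[t])
        # Past tokens i < t: weight exp(-(t - 1 - i) * w + k_i)
        for i in range(t):
            weight = np.exp(-(t - 1 - i) * w + k[i])
            num = num + weight * v[i]
            den = den + weight
        out[t] = num / den
    return out


def wkv_recurrent(k, v, w, u):
    """O(T) evaluation with a per-channel running numerator a and denominator b."""
    T, C = k.shape
    a = np.zeros(C)  # sum over past i of exp(-(t - 1 - i) * w + k_i) * v_i
    b = np.zeros(C)  # the same sum without v_i
    out = np.zeros((T, C))
    for t in range(T):
        bonus = np.exp(u + k[t])
        out[t] = (a + bonus * v[t]) / (b + bonus)
        # Absorb token t into the state; older tokens decay by exp(-w) once more
        a = np.exp(-w) * a + np.exp(k[t]) * v[t]
        b = np.exp(-w) * b + np.exp(k[t])
    return out


rng = np.random.default_rng(0)
T, C = 6, 4
k = rng.normal(size=(T, C))
v = rng.normal(size=(T, C))
w = np.exp(rng.normal(size=C))  # any positive per-channel decay rate (drawn this way for the demo)
u = rng.normal(size=C)
print(np.allclose(wkv_direct(k, v, w, u), wkv_recurrent(k, v, w, u)))  # True
```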

Linear Attention Variants

Kernel-Based Linear Attention

import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearAttention(nn.Module):
    """Non-causal (bidirectional) linear attention with an ELU + 1 feature map."""

    def __init__(self, d_model, n_heads, eps=1e-6):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.eps = eps

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def feature_map(self, x):
        """phi(x) = ELU(x) + 1: strictly positive, so the normalizer stays positive."""
        return F.elu(x) + 1

    def forward(self, x):
        B, L, _ = x.shape

        # Project to Q, K, V: each (B, L, H, d_k)
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)

        # Apply feature map
        q = self.feature_map(q)  # (B, L, H, d_k)
        k = self.feature_map(k)

        # KV matrix per head: sum_l phi(k_l) v_l^T -> (B, H, d_k, d_v)
        kv = torch.einsum("blhd,blhv->bhdv", k, v)

        # Normalizer: sum_l phi(k_l) -> (B, H, d_k)
        k_sum = k.sum(dim=1)

        # Apply to queries; no (L x L) matrix is ever formed
        numerator = torch.einsum("blhd,bhdv->blhv", q, kv)
        denominator = torch.einsum("blhd,bhd->blh", q, k_sum)
        denominator = denominator.unsqueeze(-1).clamp(min=self.eps)

        output = numerator / denominator  # (B, L, H, d_v)
        return self.out(output.reshape(B, L, -1))
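
The module above is non-causal: every position sees the whole sequence through `kv` and `k_sum`. For autoregressive use, position i may only see positions j ≤ i, and the two sums become running sums updated once per token. The NumPy sketch below (single head; names are hypothetical) shows that this recurrent form matches masking the explicit kernel matrix. This RNN view of linear attention is the one RWKV builds on.

```python
import numpy as np


def phi(x):
    """ELU(x) + 1 feature map (positive)."""
    return np.where(x > 0, x + 1.0, np.exp(x))


def causal_linear_attention_masked(Q, K, V):
    """Reference O(n^2) form: explicit kernel matrix with a causal (lower-triangular) mask."""
    A = np.tril(phi(Q) @ phi(K).T)   # A[i, j] = phi(q_i)^T phi(k_j) for j <= i, else 0
    return (A @ V) / A.sum(axis=1, keepdims=True)


def causal_linear_attention_recurrent(Q, K, V):
    """O(n) RNN form: state S (d x d_v) and normalizer z (d,) updated once per token."""
    n, d = Q.shape
    S = np.zeros((d, V.shape[1]))
    z = np.zeros(d)
    out = np.zeros((n, V.shape[1]))
    for i in range(n):
        k_i = phi(K[i])
        S += np.outer(k_i, V[i])     # S = sum_{j <= i} phi(k_j) v_j^T
        z += k_i                     # z = sum_{j <= i} phi(k_j)
        q_i = phi(Q[i])
        out[i] = (q_i @ S) / (q_i @ z)
    return out


rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(7, 4)), rng.normal(size=(7, 4)), rng.normal(size=(7, 2))
print(np.allclose(causal_linear_attention_masked(Q, K, V),
                  causal_linear_attention_recurrent(Q, K, V)))  # True
```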

Performer (Random Feature Attention)

Definition: Performer

The Performer (Choromanski et al., 2021) uses random features to approximate the softmax kernel, enabling linear attention while maintaining approximate softmax behavior.

Random Feature Map

$$\phi(x) = \frac{1}{\sqrt{m}} \left[ e^{\omega_1^T x - \|x\|^2/2}, \ldots, e^{\omega_m^T x - \|x\|^2/2} \right]$$

Here,

  • ω_i = Random projection vectors, drawn i.i.d. from N(0, I_d)
  • m = Number of random features
  • ϕ = Random feature map; ϕ(x)ᵀϕ(y) is an unbiased estimate of exp(xᵀy)

import torch
import torch.nn as nn


class PerformerAttention(nn.Module):
    """Non-causal Performer-style attention with positive random features.

    The original Performer additionally orthogonalizes the random features.
    """

    def __init__(self, d_model, n_heads, n_features=256):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.n_features = n_features

        # Random projection matrix (fixed buffer, not trained)
        self.register_buffer(
            "projection_matrix",
            self._create_projection(n_features, self.d_k)
        )

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def _create_projection(self, n_features, d_k):
        """Rows omega_i ~ N(0, I_{d_k}), as in the feature-map formula."""
        return torch.randn(n_features, d_k)

    def feature_map(self, x):
        """phi(x) = exp(omega_i^T x - ||x||^2 / 2) / sqrt(m)."""
        # x: (B, L, H, d_k); projection_matrix: (n_features, d_k)
        # Scale by d_k^(-1/4) so that phi(q)^T phi(k) estimates exp(q^T k / sqrt(d_k))
        x = x * self.d_k ** -0.25
        proj = torch.einsum("blhd,md->blhm", x, self.projection_matrix)
        sq_norm = (x ** 2).sum(dim=-1, keepdim=True) / 2
        # No max-subtraction here, so very large activations can overflow exp
        return torch.exp(proj - sq_norm) / self.n_features ** 0.5

    def forward(self, x):
        B, L, _ = x.shape

        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)

        # Compute feature maps
        q_prime = self.feature_map(q)  # (B, L, H, m)
        k_prime = self.feature_map(k)

        # Linear attention: Q'(K'^T V) instead of (Q'K'^T)V
        kv = torch.einsum("blhm,blhv->bhmv", k_prime, v)
        numerator = torch.einsum("blhm,bhmv->blhv", q_prime, kv)

        denominator = torch.einsum("blhm,bhm->blh", q_prime, k_prime.sum(dim=1))
        denominator = denominator.unsqueeze(-1).clamp(min=1e-6)

        output = numerator / denominator
        return self.out(output.reshape(B, L, -1))
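
The feature map's key property can be checked by Monte Carlo: when ω_i ~ N(0, I), ϕ(x)ᵀϕ(y) is an unbiased estimate of exp(xᵀy). This NumPy sketch compares the estimate against the exact kernel for one pair of small vectors. `positive_random_features` is a hypothetical helper name, and it uses plain i.i.d. Gaussian features rather than orthogonalized ones.

```python
import numpy as np


def positive_random_features(x, omega):
    """phi(x) = exp(omega_i^T x - ||x||^2 / 2) / sqrt(m), one entry per row omega_i."""
    m = omega.shape[0]
    return np.exp(omega @ x - x @ x / 2) / np.sqrt(m)


rng = np.random.default_rng(0)
d, m = 8, 200_000
x = 0.3 * rng.normal(size=d)
y = 0.3 * rng.normal(size=d)
omega = rng.normal(size=(m, d))  # rows omega_i ~ N(0, I_d), i.i.d. (not orthogonalized)

estimate = positive_random_features(x, omega) @ positive_random_features(y, omega)
exact = np.exp(x @ y)
print(estimate, exact)  # close; the gap shrinks roughly as 1 / sqrt(m)
```

The estimator's variance grows quickly with the norms of x and y, so large-magnitude queries and keys need more random features for the same accuracy.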

RWKV-6 and Beyond

Evolution of RWKV

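| Version | Key Innovation | Context Length | Performance |
|---|---|---|---|
| RWKV-4 | Basic WKV attention | 8K | Competitive |
| RWKV-5 (Eagle) | Multi-head matrix-valued states | 32K | Strong |
| RWKV-6 (Finch) | Data-dependent decay | 128K+ | Near-transformer |
| RWKV-7 (Goose) | Generalized delta-rule state update | 1M+ | Transformer-matching |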

Data-Dependent Decay

Data-Dependent Decay (RWKV-6)

$$w_t = \text{softplus}(W_w x_t + b_w)$$

Here,

  • w_t = Time-dependent decay at step t
  • W_w = Decay weight matrix
  • b_w = Decay bias
  • x_t = Input at time t

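RWKV-6 introduces data-dependent decay. Instead of a single learned w per channel, each step computes its own w_t from the current input, so the model can decide token by token how quickly to forget past information. The softplus form above is a simplified illustration that keeps w_t positive, so the per-step decay factor e^{-w_t} stays between 0 and 1. RWKV-6 itself derives the rate from a low-rank projection of the token-shifted input. The goal is better handling of long-range dependencies than a fixed per-channel decay allows.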
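
Below is a simplified NumPy sketch of how a per-step decay enters the WKV recurrence. Its assumptions:

  • It keeps RWKV-4's per-channel state; RWKV-6 itself uses matrix-valued per-head states.
  • It uses the softplus parameterization shown above.
  • All names are hypothetical.

The toy input spikes at t = 1. That spike produces a large w_1, which wipes out most of what was stored from token 0.

```python
import numpy as np


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return np.logaddexp(0.0, x)


def wkv_data_dependent(x, k, v, W_w, b_w, u):
    """WKV recurrence with a per-step decay rate w_t = softplus(x_t W_w + b_w).

    x: (T, d_in); k, v: (T, C); W_w: (d_in, C); b_w, u: (C,).
    Simplification: RWKV-4-style per-channel state, not RWKV-6's per-head matrix state.
    """
    T, C = k.shape
    w = softplus(x @ W_w + b_w)       # (T, C), always > 0
    a, b = np.zeros(C), np.zeros(C)   # running numerator and denominator
    out = np.zeros((T, C))
    for t in range(T):
        bonus = np.exp(u + k[t])
        out[t] = (a + bonus * v[t]) / (b + bonus)
        decay = np.exp(-w[t])         # in (0, 1): fraction of the past kept at this step
        a = decay * a + np.exp(k[t]) * v[t]
        b = decay * b + np.exp(k[t])
    return out, w


# Toy sequence: token 0 carries value 1, later tokens carry 0, all keys are 0.
T, C = 3, 1
k = np.zeros((T, C))
v = np.array([[1.0], [0.0], [0.0]])
u = np.zeros(C)
W_w, b_w = np.ones((1, C)), np.zeros(C)

x_spike = np.array([[0.0], [5.0], [0.0]])  # large input at t = 1 -> large w_1
out, w = wkv_data_dependent(x_spike, k, v, W_w, b_w, u)
out_keep, _ = wkv_data_dependent(np.zeros((T, 1)), k, v, W_w, b_w, u)
print(out[2, 0], out_keep[2, 0])  # ~0.0033 vs ~0.2: the spike erased most of token 0
```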

Training with Linear Attention

import torch.nn as nn


class RWKVBlock(nn.Module):
    """Complete RWKV block with time and channel mixing (pre-LayerNorm residuals)."""

    def __init__(self, d_model, d_ffn, layer_id):
        super().__init__()

        # Time mixing (attention analogue): token shift + WKV operator.
        # RWKVAttention is not defined in this lesson; it is assumed to map
        # (batch, seq_len, d_model) -> (batch, seq_len, d_model) causally.
        self.attention = RWKVAttention(d_model)

        # Channel mixing (FFN analogue); applies its own token shift internally
        self.channel_mix = RWKVChannelMix(d_model, d_ffn, layer_id)

        # Layer normalization
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Time mixing (attention branch); the token shift happens inside the sub-block
        x = x + self.attention(self.ln1(x))

        # Channel mixing (FFN branch)
        x = x + self.channel_mix(self.ln2(x))

        return x

Hybrid Approaches

RWKV + Attention Hybrids

Definition: Hybrid RWKV-Transformer

Hybrid architectures combine RWKV layers for efficient long-range processing with standard attention layers for global reasoning, achieving both efficiency and expressivity.

import torch.nn as nn


class HybridRWKVTransformer(nn.Module):
    """Hybrid model with RWKV and attention blocks (3 RWKV : 1 attention)."""

    def __init__(self, d_model, n_layers, n_heads, mlp_ratio=4.0):
        super().__init__()
        self.layers = nn.ModuleList()

        for i in range(n_layers):
            if i % 4 == 0:
                # Attention layer every 4th block (layers 0, 4, 8, ...).
                # TransformerBlock is assumed: a standard causal pre-LN block
                # mapping (batch, seq_len, d_model) -> same shape.
                self.layers.append(
                    TransformerBlock(d_model, n_heads, mlp_ratio)
                )
            else:
                # RWKV layer (RWKVBlock as defined above)
                d_ffn = int(d_model * mlp_ratio)
                self.layers.append(
                    RWKVBlock(d_model, d_ffn, layer_id=i)
                )

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model) token embeddings
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
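
The `i % 4 == 0` rule fixes the layer layout, and with it how much inference memory grows with context. This plain-Python sketch lays out a 32-layer stack and compares its KV-cache size with an all-attention stack of the same depth. The helper names are hypothetical. It assumes every attention layer stores a key and a value per position and counts nothing for the RWKV layers.

```python
def layer_pattern(n_layers, attn_every=4):
    """Mirror the `i % 4 == 0` rule: attention at layers 0, 4, 8, ..., RWKV elsewhere."""
    return ["attn" if i % attn_every == 0 else "rwkv" for i in range(n_layers)]


def kv_cache_entries(pattern, seq_len, d_model):
    """Cached key + value entries for one sequence; only attention layers keep a growing cache."""
    return sum(2 * seq_len * d_model for kind in pattern if kind == "attn")


pattern = layer_pattern(32)
hybrid_cache = kv_cache_entries(pattern, seq_len=8192, d_model=4096)
full_cache = kv_cache_entries(["attn"] * 32, seq_len=8192, d_model=4096)
print(pattern.count("attn"), pattern.count("rwkv"))  # 8 24
print(hybrid_cache / full_cache)                     # 0.25
```

Only 8 of the 32 layers use attention here, so the KV cache is a quarter of the full transformer's. The RWKV layers instead carry a fixed-size state that does not grow with sequence length. If n_layers is not a multiple of 4, the attention share is slightly above 25%, because layer 0 is always attention.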

Performance Comparison

Benchmark Results

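| Model | Size | PPL | MMLU | Speed (tokens/s) |
|---|---|---|---|---|
| Transformer (LLaMA) | 7B | 5.68 | 46.8 | 1,200 |
| RWKV-6 | 7B | 5.82 | 45.2 | 2,800 |
| Hybrid (3:1 RWKV:Attn) | 7B | 5.71 | 46.1 | 2,400 |
| Mamba | 7B | 5.91 | 44.9 | 3,100 |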

In this comparison, the hybrid RWKV-Transformer comes within 0.03 PPL and 0.7 MMLU points of the transformer while generating 2× faster (2,400 vs. 1,200 tokens/s). It uses full attention in only 1 of every 4 layers (~25%), which suggests that most layers do not need full attention.

Practice Exercises

  1. Conceptual: Explain why linear attention fails to match transformer quality on tasks requiring precise token-to-token comparison. What specific capabilities are lost?

  2. Mathematical: Derive the computational complexity of RWKV inference for a sequence of length n with model dimension d. How does this compare to transformer inference with KV cache?

  3. Practical: Implement RWKV-4 channel mixing with time-shifting and verify it produces correct autoregressive outputs on a simple language modeling task.

  4. Research: Compare the perplexity curves of RWKV-6 vs transformer models of similar size during training. At what sequence lengths does RWKV's advantage become apparent?

Key Takeaways:

  • Linear attention reduces complexity from O(n²d) to O(nd²) by replacing softmax with a kernel feature map and reordering the matrix products
  • RWKV combines linear attention with time-mixing for efficient autoregressive inference
  • WKV attention uses exponential decay to weight past information
  • Hybrid RWKV-Transformer models achieve near-transformer quality with better efficiency
  • RWKV-6 introduces data-dependent decay for dynamic forgetting control
  • Linear attention and RWKV decode with a fixed-size state, so inference memory does not grow with context length, which makes very long contexts practical

What to Learn Next

-> Sparse Attention Patterns: Longformer, BigBird, and efficient attention patterns for long documents.

-> State Space Models: Mamba, S4, and linear-time alternatives to transformers.

-> Flash Attention and Memory Efficiency: IO-aware attention optimization for modern hardware.

-> Mixture of Experts: sparse architectures that scale efficiently.

-> LLM Architecture Deep Dive: understanding transformer architectures in depth.

-> Long Context Window: scaling context windows to millions of tokens.
