RWKV and Linear Attention — Scaling Beyond Quadratic Complexity
Transformers achieve remarkable performance but suffer from O(n²) attention complexity. RWKV and linear attention variants offer O(n) alternatives that maintain competitive quality while enabling dramatically longer context windows.
- RWKV — Recurrent linear attention with time-mixed channel evolution
- Linear Attention — Kernel-based softmax approximation for O(n) attention
- Token Shifting — Positional information through shifted channels
- Production Trade-offs — When linear attention wins over standard transformers
Efficiency is not about doing less—it's about doing more with the same resources.
RWKV and Linear Attention
Standard softmax attention computes pairwise interactions between all tokens, resulting in quadratic complexity. Linear attention variants reformulate attention as a kernel operation, enabling linear-time computation. RWKV combines these ideas with a recurrent formulation for efficient autoregressive inference.
Definition: Linear Attention
Linear attention replaces the softmax normalization with a kernel feature map, allowing the attention computation to be reordered as a matrix product with linear complexity O(n·d²) instead of O(n²·d).
The Attention Bottleneck
Quadratic vs Linear Complexity
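Standard Attention Complexity
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$
Here,
- $Q$ = Query matrix (n × d_k)
- $K$ = Key matrix (n × d_k)
- $V$ = Value matrix (n × d_v)
- $d_k$ = Key dimension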
The standard attention mechanism requires computing the n×n attention matrix QK^T, which costs O(n²·d_k) FLOPs and O(n²) memory.
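Linear Attention Formulation
$$\mathrm{LinAttn}(Q, K, V)_i = \frac{\phi(q_i)^\top \sum_{j=1}^{n} \phi(k_j)\, v_j^\top}{\phi(q_i)^\top \sum_{j=1}^{n} \phi(k_j)}$$
Here,
- $\phi$ = Feature map (e.g., ELU + 1)
- $q_i$ = Query at position i
- $k_j$ = Key at position j
- $v_j$ = Value at position j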
The key insight is that by changing the order of operations—computing K^TV first (d×d matrix) then multiplying by Q—we avoid the n×n attention matrix entirely. This reduces complexity from O(n²·d) to O(n·d²).
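The reordering can be checked numerically. The following is a minimal NumPy sketch, assuming the non-causal case and the ELU + 1 feature map; the function names `quadratic_form` and `linear_form` are illustrative and not part of any library.

```python
import numpy as np

def feature_map(x):
    """ELU(x) + 1: equals x + 1 for x > 0 and exp(x) otherwise, so it is always positive."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def quadratic_form(q, k, v):
    """Materializes the n x n similarity matrix first: O(n^2 d) work."""
    scores = feature_map(q) @ feature_map(k).T           # (n, n)
    return (scores @ v) / scores.sum(axis=1, keepdims=True)

def linear_form(q, k, v):
    """Builds the d x d_v summary phi(K)^T V first: O(n d^2) work."""
    fq, fk = feature_map(q), feature_map(k)
    kv = fk.T @ v                                        # (d, d_v)
    z = fk.sum(axis=0)                                   # (d,)
    return (fq @ kv) / (fq @ z)[:, None]

rng = np.random.default_rng(0)
n, d = 64, 8
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))

out_quadratic = quadratic_form(q, k, v)
out_linear = linear_form(q, k, v)
max_diff = float(np.abs(out_quadratic - out_linear).max())
```

Both functions compute the same output up to floating-point rounding; only the order of the matrix products differs, which is where the change in complexity comes from.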
Complexity Comparison
| Mechanism | FLOPs | Memory | Inference Cost |
|---|---|---|---|
| Standard Attention | O(n²d) | O(n²) | O(nd) per token (with KV cache) |
| Linear Attention | O(nd²) | O(d²) | O(d²) per token |
| RWKV | O(nd²) | O(d) | O(d²) per token |
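The O(d²) per-token inference cost in the table comes from the recurrent form of causal linear attention, where a fixed-size state replaces the growing KV cache. The sketch below is an illustration under assumptions (single head, ELU + 1 feature map, hypothetical function names) and compares the recurrence against a masked quadratic reference.

```python
import numpy as np

def feature_map(x):
    """ELU(x) + 1, a positive feature map."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def causal_linear_attention_recurrent(q, k, v):
    """One token at a time with a (d, d_v) state: O(d * d_v) work per token."""
    n, d = q.shape
    d_v = v.shape[1]
    state = np.zeros((d, d_v))   # running sum of phi(k_j) v_j^T for j <= t
    norm = np.zeros(d)           # running sum of phi(k_j) for j <= t
    out = np.empty((n, d_v))
    for t in range(n):
        fq, fk = feature_map(q[t]), feature_map(k[t])
        state += np.outer(fk, v[t])
        norm += fk
        out[t] = (fq @ state) / (fq @ norm)
    return out

def causal_linear_attention_masked(q, k, v):
    """Reference: full n x n similarity matrix with a lower-triangular mask."""
    scores = np.tril(feature_map(q) @ feature_map(k).T)
    return (scores @ v) / scores.sum(axis=1, keepdims=True)

rng = np.random.default_rng(1)
n, d = 32, 8
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
recurrent_out = causal_linear_attention_recurrent(q, k, v)
masked_out = causal_linear_attention_masked(q, k, v)
```

The state size does not depend on the number of tokens already processed, which is why memory stays constant during generation.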
Practical Complexity
For sequence length n = 100,000 and dimension d = 128:
- Standard Attention: 100,000² × 128 = 1.28 × 10¹² FLOPs
- Linear Attention: 100,000 × 128² ≈ 1.64 × 10⁹ FLOPs
Speedup: ~780× fewer FLOPs. The ratio equals n/d, so the advantage grows with sequence length.
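The arithmetic above can be reproduced in a few lines. This sketch counts only the dominant term of each mechanism and ignores constant factors and projections; the helper names are illustrative.

```python
def attention_flops(n, d):
    """Dominant term of standard attention: n^2 * d."""
    return n * n * d

def linear_attention_flops(n, d):
    """Dominant term of linear attention: n * d^2."""
    return n * d * d

n, d = 100_000, 128
standard = attention_flops(n, d)        # 1_280_000_000_000
linear = linear_attention_flops(n, d)   # 1_638_400_000
ratio = standard / linear               # n / d = 781.25

print(f"standard: {standard:.3e}, linear: {linear:.3e}, ratio: {ratio:.2f}")
```

Because the ratio is n/d, the two mechanisms cost roughly the same when the sequence length is close to the head dimension, and the gap widens linearly from there.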
RWKV Architecture
Overview
Definition: RWKV
RWKV (Receptance Weighted Key Value) is a linear attention architecture that combines the training efficiency of transformers with the inference efficiency of RNNs. It uses time-mixed channel evolution and a WKV (weighted key-value) mechanism.
Core Components
```python
import torch
import torch.nn as nn


class RWKVChannelMix(nn.Module):
    """RWKV-4 style channel mixing (the FFN-like sub-block) with token shift."""

    def __init__(self, d_model, d_ffn, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.d_model = d_model
        # Per-channel mixing weights between the current and the previous token
        self.time_mix_r = nn.Parameter(torch.ones(d_model))
        self.time_mix_k = nn.Parameter(torch.ones(d_model))
        # Linear projections
        self.key = nn.Linear(d_model, d_ffn, bias=False)
        # Receptance gates the d_model-sized output, so it maps d_model -> d_model
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_ffn, d_model, bias=False)
        # Pad one zero step at the front of the sequence axis and drop the last step
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

    def forward(self, x):
        """x: (batch, seq_len, d_model) -> (batch, seq_len, d_model)."""
        xx = self.time_shift(x)  # x_{t-1}, with zeros at t = 0
        # Time-mixed inputs
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        # Receptance gate, key with squared ReLU (as in RWKV-4), value projection
        r = torch.sigmoid(self.receptance(xr))
        k = torch.square(torch.relu(self.key(xk)))
        v = self.value(k)
        # No WKV here: the WKV recurrence belongs to the time-mixing
        # (attention-like) sub-block, not to channel mixing.
        return r * v
```
Time-Mixing Mechanism
Definition: Time-Mixing
Time-mixing in RWKV blends the current input with the previous input using learnable mixing parameters, providing positional information without explicit positional encodings.
Time-Mixing Operation
$$\tilde{x}_t = \mu \odot x_t + (1 - \mu) \odot x_{t-1}$$
Here,
- $\mu$ = Mixing parameter (learned per channel)
- $x_t$ = Current input
- $x_{t-1}$ = Previous input (time-shifted)
- $\odot$ = Element-wise multiplication
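```python
import torch
import torch.nn as nn


class RWKVTokenShift(nn.Module):
    """Time-shift operation that mixes each token with its predecessor."""

    def __init__(self, d_model):
        super().__init__()
        # Pad one zero step at the front of the sequence axis and drop the last step,
        # so the output keeps the same length as the input.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.mix = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Shift tokens by one position and mix."""
        # x: (batch, seq_len, d_model)
        xx = self.time_shift(x)  # xx[:, t] == x[:, t - 1], zeros at t = 0
        return x * self.mix + xx * (1 - self.mix)
```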
WKV Attention
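WKV (Weighted Key-Value) Attention
$$wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i}\, v_i + e^{u + k_t}\, v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u + k_t}}$$
Here,
- $w$ = Decay parameter (learned per channel)
- $k_i$ = Key at position i
- $v_i$ = Value at position i
- $u$ = Bonus parameter for current token
- $t$ = Current time step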
The WKV mechanism is essentially a linear attention with exponential decay. The decay parameter w controls how quickly past information is forgotten, and the bonus parameter u gives extra weight to the current token.
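To make the decay and bonus roles concrete, here is a minimal NumPy sketch of the RWKV-4 style WKV computation, written once as a direct sum over past positions and once as a recurrence with two running accumulators. It assumes per-channel vectors `w` (non-negative decay) and `u` (bonus), uses hypothetical function names, and omits the running-maximum trick that practical implementations use to keep the exponentials numerically stable.

```python
import numpy as np

def wkv_naive(k, v, w, u):
    """Direct evaluation: for each t, sum over all earlier positions (O(T^2))."""
    T, C = k.shape
    out = np.empty((T, C))
    for t in range(T):
        num = np.exp(u + k[t]) * v[t]      # current token, boosted by the bonus u
        den = np.exp(u + k[t])
        for i in range(t):
            weight = np.exp(-(t - 1 - i) * w + k[i])   # older tokens decay more
            num = num + weight * v[i]
            den = den + weight
        out[t] = num / den
    return out

def wkv_recurrent(k, v, w, u):
    """Same result with O(C) state: a = decayed sum of e^k * v, b = decayed sum of e^k."""
    T, C = k.shape
    a = np.zeros(C)
    b = np.zeros(C)
    out = np.empty((T, C))
    for t in range(T):
        bonus = np.exp(u + k[t])
        out[t] = (a + bonus * v[t]) / (b + bonus)
        a = np.exp(-w) * a + np.exp(k[t]) * v[t]
        b = np.exp(-w) * b + np.exp(k[t])
    return out

rng = np.random.default_rng(0)
T, C = 16, 4
k = rng.standard_normal((T, C))
v = rng.standard_normal((T, C))
w = rng.uniform(0.1, 1.0, size=C)
u = rng.standard_normal(C)
naive_out = wkv_naive(k, v, w, u)
recurrent_out = wkv_recurrent(k, v, w, u)
```

The recurrent form is what makes autoregressive inference cheap: each new token updates `a` and `b` in place instead of revisiting the whole history.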
Linear Attention Variants
Kernel-Based Linear Attention
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearAttention(nn.Module):
    """Non-causal linear attention with an ELU + 1 feature map."""

    def __init__(self, d_model, n_heads, eps=1e-6):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.eps = eps
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def feature_map(self, x):
        """ELU-based positive feature map."""
        return F.elu(x) + 1

    def forward(self, x):
        B, L, _ = x.shape
        # Project to Q, K, V, each of shape (B, L, H, d_k)
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)
        # Apply feature map
        q = self.feature_map(q)
        k = self.feature_map(k)
        # Summarize keys and values over the whole sequence: (B, H, d_k, d_v).
        # Every query sees every position, so this variant is not causal.
        kv = torch.einsum("blhd,blhv->bhdv", k, v)
        # Normalizer: sum of mapped keys, (B, H, d_k)
        k_sum = k.sum(dim=1)
        # Apply to queries
        numerator = torch.einsum("blhd,bhdv->blhv", q, kv)
        denominator = torch.einsum("blhd,bhd->blh", q, k_sum)
        denominator = denominator.unsqueeze(-1).clamp(min=self.eps)
        output = numerator / denominator
        return self.out(output.reshape(B, L, -1))
```
Performer (Random Feature Attention)
Definition: Performer
The Performer (Choromanski et al., 2021) uses random features to approximate the softmax kernel, enabling linear attention while maintaining approximate softmax behavior.
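Random Feature Map
$$\phi(x) = \frac{1}{\sqrt{m}} \exp\!\left(-\frac{\lVert x \rVert^2}{2}\right) \left[\exp(\omega_1^\top x), \ldots, \exp(\omega_m^\top x)\right]$$
Here,
- $\omega_i$ = Random projection vectors (Gaussian)
- $m$ = Number of random features
- $\phi(x)$ = Feature map approximation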
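A quick numerical check shows why this feature map works: for Gaussian projection vectors, the inner product of the mapped vectors is an unbiased estimate of exp(q·k). The sketch below assumes plain i.i.d. Gaussian projections (not the orthogonal variant) and small input norms so that the estimate has low variance; `positive_random_features` is an illustrative name.

```python
import numpy as np

def positive_random_features(x, omega):
    """phi(x) = exp(omega @ x - |x|^2 / 2) / sqrt(m) for a single vector x of shape (d,)."""
    m = omega.shape[0]
    proj = omega @ x                       # (m,)
    sq_norm = 0.5 * float(x @ x)
    return np.exp(proj - sq_norm) / np.sqrt(m)

rng = np.random.default_rng(0)
d, m = 8, 20_000
omega = rng.standard_normal((m, d))        # rows are the random vectors omega_i

# Two test vectors rescaled to norm 0.5 to keep the estimator variance small
q = rng.standard_normal(d)
q *= 0.5 / np.linalg.norm(q)
k = rng.standard_normal(d)
k *= 0.5 / np.linalg.norm(k)

approx = float(positive_random_features(q, omega) @ positive_random_features(k, omega))
exact = float(np.exp(q @ k))
relative_error = abs(approx - exact) / exact
```

The error shrinks roughly as 1/√m, and it grows quickly with the norm of the inputs, which is one reason queries and keys are usually rescaled before the feature map is applied.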
```python
import torch
import torch.nn as nn


class PerformerAttention(nn.Module):
    """Performer-style attention with positive random features (non-causal)."""

    def __init__(self, d_model, n_heads, n_features=256):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.n_features = n_features
        # Fixed (non-trainable) random projection matrix
        self.register_buffer(
            "projection_matrix",
            self._create_projection(n_features, self.d_k),
        )
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def _create_projection(self, n_features, d_k):
        """Create a random Gaussian projection matrix with N(0, 1) entries."""
        return torch.randn(n_features, d_k)

    def feature_map(self, x):
        """Positive random features approximating the softmax kernel."""
        # x: (B, L, H, d_k); projection_matrix: (n_features, d_k)
        # Scale by d_k^(-1/4) so that q.k carries the usual 1/sqrt(d_k) factor
        x = x * (self.d_k ** -0.25)
        proj = torch.einsum("blhd,md->blhm", x, self.projection_matrix)
        sq_norm = x.pow(2).sum(dim=-1, keepdim=True) / 2
        # Note: no max-subtraction here, so large activations can overflow exp
        return torch.exp(proj - sq_norm) / (self.n_features ** 0.5)

    def forward(self, x):
        B, L, _ = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)
        # Compute feature maps, each (B, L, H, m)
        q_prime = self.feature_map(q)
        k_prime = self.feature_map(k)
        # Linear attention: Q'(K'^T V) instead of (Q'K'^T)V
        kv = torch.einsum("blhm,blhv->bhmv", k_prime, v)
        numerator = torch.einsum("blhm,bhmv->blhv", q_prime, kv)
        denominator = torch.einsum("blhm,bhm->blh", q_prime, k_prime.sum(dim=1))
        denominator = denominator.unsqueeze(-1).clamp(min=1e-6)
        output = numerator / denominator
        return self.out(output.reshape(B, L, -1))
```
RWKV-6 and Beyond
Evolution of RWKV
| Version | Key Innovation | Context Length | Performance |
|---|---|---|---|
| RWKV-4 | Basic WKV attention | 8K | Competitive |
| RWKV-5 (Eagle) | Multi-head matrix-valued state | 32K | Strong |
| RWKV-6 (Finch) | Data-dependent decay | 128K+ | Near-transformer |
| RWKV-7 (Goose) | Generalized delta-rule state update | 1M+ | Transformer-matching |
Data-Dependent Decay
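Data-Dependent Decay (RWKV-6)
$$w_t = \exp\!\left(-\exp\!\left(W_w x_t + b_w\right)\right)$$
Here,
- $w_t$ = Time-dependent decay at step t (a per-channel factor in (0, 1))
- $W_w$ = Decay weight matrix
- $b_w$ = Decay bias
- $x_t$ = Input at time t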
RWKV-6 introduces data-dependent decay, allowing the model to dynamically control how quickly to forget past information based on the input content. This significantly improves the model's ability to handle long-range dependencies.
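The effect of a data-dependent decay can be illustrated with a small NumPy sketch. It assumes the simplified form w_t = exp(-exp(W_w x_t + b_w)) with a full weight matrix and an elementwise (per-channel) state; the function names are hypothetical, and the state update here is a deliberately reduced stand-in for the matrix-valued state used in the actual model.

```python
import numpy as np

def data_dependent_decay(x, W_w, b_w):
    """Per-token, per-channel decay factor in (0, 1): exp(-exp(W_w x_t + b_w))."""
    return np.exp(-np.exp(x @ W_w.T + b_w))

def run_decayed_state(k, v, decay):
    """Elementwise recurrence s_t = decay_t * s_{t-1} + k_t * v_t."""
    state = np.zeros(k.shape[1])
    states = np.empty_like(k)
    for t in range(k.shape[0]):
        state = decay[t] * state + k[t] * v[t]
        states[t] = state
    return states

rng = np.random.default_rng(0)
T, d = 6, 4
x = rng.standard_normal((T, d))
W_w = 0.1 * rng.standard_normal((d, d))
b_w = np.zeros(d)

# Decay that varies with the input at every step
decay = data_dependent_decay(x, W_w, b_w)
# With W_w = 0 the decay no longer depends on the input (the static case)
static_decay = data_dependent_decay(x, np.zeros((d, d)), b_w)

k = rng.standard_normal((T, d))
v = rng.standard_normal((T, d))
states = run_decayed_state(k, v, decay)
```

A decay close to 1 keeps the accumulated state almost intact, while a decay close to 0 effectively resets that channel, so the input itself decides how much history survives each step.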
Training with Linear Attention
```python
import torch.nn as nn


class RWKVBlock(nn.Module):
    """Complete RWKV block with time and channel mixing."""

    def __init__(self, d_model, d_ffn, layer_id):
        super().__init__()
        # Time mixing (attention-like).
        # RWKVAttention is assumed to be a WKV time-mixing module defined
        # elsewhere; it is not shown in this section.
        self.time_mix = RWKVTokenShift(d_model)
        self.attention = RWKVAttention(d_model)
        # Channel mixing (FFN-like)
        self.channel_mix = RWKVChannelMix(d_model, d_ffn, layer_id)
        # Layer normalization
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Time mixing (attention branch): normalize, token-shift, then WKV
        x = x + self.attention(self.time_mix(self.ln1(x)))
        # Channel mixing (FFN branch)
        x = x + self.channel_mix(self.ln2(x))
        return x
```
Hybrid Approaches
RWKV + Attention Hybrids
Definition: Hybrid RWKV-Transformer
Hybrid architectures combine RWKV layers for efficient long-range processing with standard attention layers for global reasoning, achieving both efficiency and expressivity.
```python
import torch.nn as nn


class HybridRWKVTransformer(nn.Module):
    """Hybrid model with RWKV and attention blocks."""

    def __init__(self, d_model, n_layers, n_heads, mlp_ratio=4.0):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(n_layers):
            if i % 4 == 0:
                # Attention layer every 4th block (layers 0, 4, 8, ...).
                # TransformerBlock is assumed to be a standard pre-norm
                # attention + MLP block defined elsewhere.
                self.layers.append(
                    TransformerBlock(d_model, n_heads, mlp_ratio)
                )
            else:
                # RWKV layer
                d_ffn = int(d_model * mlp_ratio)
                self.layers.append(
                    RWKVBlock(d_model, d_ffn, layer_id=i)
                )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
```
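The layer schedule used by the hybrid model can be inspected without building any modules. This is a small illustrative helper (`hybrid_layout` is a hypothetical name) that applies the same `i % 4 == 0` rule and counts the resulting layer types.

```python
def hybrid_layout(n_layers, attention_period=4):
    """Return the block type per layer: attention at every `attention_period`-th index."""
    return [
        "attention" if i % attention_period == 0 else "rwkv"
        for i in range(n_layers)
    ]

layout = hybrid_layout(32)
n_attention = layout.count("attention")   # 8
n_rwkv = layout.count("rwkv")             # 24

print(layout[:5])
print(f"RWKV:attention = {n_rwkv}:{n_attention}")
```

For 32 layers this gives 24 RWKV blocks and 8 attention blocks, a 3:1 ratio with attention in 25% of the layers.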
Performance Comparison
Benchmark Results
| Model | Size | PPL | MMLU | Speed (tokens/s) |
|---|---|---|---|---|
| Transformer (LLaMA) | 7B | 5.68 | 46.8 | 1,200 |
| RWKV-6 | 7B | 5.82 | 45.2 | 2,800 |
| Hybrid (3:1 RWKV:Attn) | 7B | 5.71 | 46.1 | 2,400 |
| Mamba | 7B | 5.91 | 44.9 | 3,100 |
In the table above, the hybrid RWKV-Transformer comes within 0.03 perplexity and 0.7 MMLU points of the transformer baseline while decoding about 2× faster (2,400 vs 1,200 tokens/s). This suggests that full attention may only be needed in roughly 25% of layers.
Practice Exercises
- Conceptual: Explain why linear attention fails to match transformer quality on tasks requiring precise token-to-token comparison. What specific capabilities are lost?
- Mathematical: Derive the computational complexity of RWKV inference for a sequence of length n with model dimension d. How does this compare to transformer inference with KV cache?
- Practical: Implement RWKV-4 channel mixing with time-shifting and verify it produces correct autoregressive outputs on a simple language modeling task.
- Research: Compare the perplexity curves of RWKV-6 vs transformer models of similar size during training. At what sequence lengths does RWKV's advantage become apparent?
Key Takeaways:
- Linear attention reduces complexity from O(n²d) to O(nd²) by replacing softmax with a kernel feature map and reordering the matrix products
- RWKV combines linear attention with time-mixing for efficient autoregressive inference
- WKV attention uses exponential decay to weight past information
- Hybrid RWKV-Transformer models achieve near-transformer quality with better efficiency
- RWKV-6 introduces data-dependent decay for dynamic forgetting control
- Linear attention keeps a constant-size recurrent state at inference, so memory does not grow with context length, even for very long (million-token) contexts