State Space Models — Beyond Transformers
Transformers dominate LLMs, but the cost of self-attention grows quadratically with sequence length. State Space Models (SSMs) such as S4 and Mamba offer linear-time alternatives that perform well on long sequences.
- SSM Theory — Continuous-time dynamical systems discretized for deep learning
- Mamba — Selective state spaces with input-dependent dynamics
- Linear Attention — Attention variants with O(n) complexity
- Trade-offs — When to use SSMs vs Transformers
Not all intelligence requires attention—sometimes, state is enough.
State Space Models
While transformers have achieved remarkable success, their quadratic attention complexity limits their applicability to very long sequences. State Space Models (SSMs) offer a promising alternative with linear-time processing and strong theoretical foundations in control theory and signal processing.
Definition: State Space Model
A State Space Model (SSM) is a sequence model that processes inputs through a hidden state evolving according to the equations of a linear dynamical system. SSMs map input sequences to output sequences using either a recurrence or a convolution, with cost that scales linearly with sequence length.
SSM Theory
Continuous-Time State Space
Continuous SSM
$$h'(t) = A\,h(t) + B\,x(t), \qquad y(t) = C\,h(t) + D\,x(t)$$
Here,
- $h(t)$ = Hidden state at time t
- $x(t)$ = Input at time t
- $y(t)$ = Output at time t
- $A$ = State transition matrix (n x n)
- $B$ = Input projection matrix (n x 1)
- $C$ = Output projection matrix (1 x n)
- $D$ = Feedthrough matrix (skip connection)
The continuous-time formulation connects SSMs to classical control theory, enabling principled initialization and stability analysis. The matrices A, B, C, D are learned parameters.
Discretization
To use SSMs in deep learning, we must discretize the continuous-time system:
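Discretized SSM
$$h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t, \qquad y_t = C\,h_t + D\,x_t$$
Here,
- $h_t$ = Discrete hidden state at step t
- $x_t$ = Discrete input at step t
- $y_t$ = Discrete output at step t
- $\bar{A}$ = Discretized state matrix
- $\bar{B}$ = Discretized input matrix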
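The discretized recurrence can be evaluated step by step or unrolled into a causal convolution with the kernel K[l] = C · Ā^l · B̄. A minimal sketch for a scalar state (n = 1), assuming the discretized values `a_bar` and `b_bar` are already given; the function names are illustrative and the feedthrough term defaults to 0:

```python
def ssm_recurrent(xs, a_bar, b_bar, c, d=0.0):
    """Evaluate h_t = a_bar*h_{t-1} + b_bar*x_t, y_t = c*h_t + d*x_t step by step."""
    h = 0.0
    ys = []
    for x in xs:
        h = a_bar * h + b_bar * x
        ys.append(c * h + d * x)
    return ys


def ssm_kernel(length, a_bar, b_bar, c):
    """Unrolled impulse response: K[l] = c * a_bar**l * b_bar."""
    return [c * (a_bar ** l) * b_bar for l in range(length)]


def ssm_convolutional(xs, a_bar, b_bar, c, d=0.0):
    """Same outputs as the recurrence, written as a causal convolution with K."""
    kernel = ssm_kernel(len(xs), a_bar, b_bar, c)
    ys = []
    for t, x in enumerate(xs):
        acc = sum(kernel[l] * xs[t - l] for l in range(t + 1))
        ys.append(acc + d * x)
    return ys


xs = [1.0, 0.0, 2.0, -1.0]
rec = ssm_recurrent(xs, a_bar=0.5, b_bar=1.0, c=2.0)
conv = ssm_convolutional(xs, a_bar=0.5, b_bar=1.0, c=2.0)
```

Both lists hold the same values. The recurrent form suits step-by-step generation, while the convolutional form allows all positions to be computed in parallel during training (this direct sum is quadratic; practical implementations use an FFT).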
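Zero-Order Hold Discretization
$$\bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\Delta B$$
Here,
- $\Delta$ = Discretization step size
- $\exp(\cdot)$ = Matrix exponential
- $I$ = Identity matrix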
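For a scalar system the zero-order hold formulas reduce to ordinary exponentials, which makes them easy to check by hand. A minimal sketch, assuming a single state with nonzero `a`; the Euler variant is included only for comparison and both function names are illustrative:

```python
import math


def zoh_discretize(a, b, delta):
    """Zero-order hold for a scalar system h'(t) = a*h(t) + b*x(t), a != 0."""
    a_bar = math.exp(delta * a)
    # (delta*a)^{-1} * (exp(delta*a) - 1) * delta*b, with delta cancelled
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar


def euler_discretize(a, b, delta):
    """First-order (Euler) approximation, shown only for comparison."""
    return 1.0 + delta * a, delta * b


a_bar, b_bar = zoh_discretize(a=-1.0, b=1.0, delta=0.1)
a_eul, b_eul = euler_discretize(a=-1.0, b=1.0, delta=0.1)
```

For small Δ the two agree to first order (here about 0.905 vs 0.9 for the state coefficient). Zero-order hold also preserves the steady state under a constant input: `b_bar / (1 - a_bar)` equals `-b / a`.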
SSM vs Transformer Complexity
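| Aspect | Transformer | SSM |
|---|---|---|
| Training | O(n²d), parallel over positions | O(nd²), parallel via convolution or scan |
| Inference | O(nd) per step (attends over the KV cache) | O(d²) per step |
| Memory | O(n²) attention scores in training; O(nd) KV cache in inference | O(d) state |
| Long-range | Direct access to any pair of positions | Information carried through a fixed-size state |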
Complexity Comparison
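For sequence length n = 100,000 and dimension d = 256 (leading-order terms only, constants dropped):
Transformer:
- Training: O(100,000² × 256) ≈ 2.56 × 10¹² FLOPs
- Memory: O(100,000²) = 10¹⁰ attention scores, about 40 GB in fp32 if the full matrix is materialized
SSM (Mamba):
- Training: O(100,000 × 256²) ≈ 6.55 × 10⁹ FLOPs
- Memory: O(256) state values, about 1 KB in fp32
By these counts the SSM needs about 390x fewer training FLOPs (the ratio n/d) and roughly 4 × 10⁷ times less memory at this sequence length.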
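The arithmetic above can be reproduced in a few lines. This is a sketch that uses only the leading-order terms from the complexity table (n²d and nd²), so it ignores constants, projections, and hardware effects; the function names are illustrative:

```python
def attention_training_flops(n, d):
    """Leading-order cost of self-attention over a length-n sequence: n^2 * d."""
    return n * n * d


def ssm_training_flops(n, d):
    """Leading-order cost used in the table for an SSM layer: n * d^2."""
    return n * d * d


n, d = 100_000, 256
attn = attention_training_flops(n, d)  # 2.56e12
ssm = ssm_training_flops(n, d)         # about 6.55e9
ratio = attn / ssm                     # simplifies to n / d
```

Because the ratio simplifies to n/d, the advantage only appears when the sequence is much longer than the model dimension; for n below d the same estimate favors attention.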
Mamba Architecture
Selective State Spaces
Definition: Mamba
Mamba (Gu & Dao, 2023) is a selective state space model that makes the SSM parameters input-dependent, enabling content-aware reasoning. Its authors report Transformer-quality language modeling with compute that scales linearly in sequence length.
The key innovation of Mamba is selectivity: the projections B and C and the step size Δ are computed from the input, while A remains a learned, input-independent parameter:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    """Simplified Mamba block."""

    def __init__(self, d_model, d_state=16, d_conv=4):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # Input projection producing the SSM branch (x) and the gate (z)
        self.in_proj = nn.Linear(d_model, 2 * d_model)
        # Depthwise conv; padding of d_conv - 1 plus trimming keeps it causal
        self.conv1d = nn.Conv1d(d_model, d_model, d_conv,
                                groups=d_model, padding=d_conv - 1)
        # Input-dependent SSM parameters: 1 value for dt, d_state each for B and C
        self.x_proj = nn.Linear(d_model, d_state * 2 + 1)
        self.dt_proj = nn.Linear(1, d_model)
        # State matrix stored as log(-A), so A = -exp(A_log) is always negative
        A = torch.arange(1, d_state + 1).float().repeat(d_model, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(d_model))
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Forward pass with selective scan. x: (batch, seq_len, d_model)."""
        seq_len = x.shape[1]
        # Split into SSM branch and gate
        xz = self.in_proj(x)  # (B, L, 2D)
        x, z = xz.chunk(2, dim=-1)
        # Causal convolution: trim the extra outputs produced by the padding
        x = self.conv1d(x.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x = F.silu(x)
        # Input-dependent SSM parameters
        x_dbl = self.x_proj(x)  # (B, L, 2*d_state + 1)
        dt, B_param, C_param = x_dbl.split(
            [1, self.d_state, self.d_state], dim=-1
        )
        dt = F.softplus(self.dt_proj(dt))  # (B, L, D), strictly positive
        A = -torch.exp(self.A_log)  # (D, N), strictly negative
        # Selective scan (input-dependent dynamics)
        y = selective_scan(x, dt, A, B_param, C_param, self.D)
        # Gating
        y = y * F.silu(z)
        return self.out_proj(y)
Selective Scan
Definition: Selective Scan
The selective scan is the core operation of Mamba. It computes the SSM recurrence with input-dependent parameters B, C, and Δ, enabling the model to selectively remember or forget information based on the input content.
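import torch

def selective_scan(x, delta, A, B, C, D):
    """Selective scan algorithm (sequential reference version).

    x, delta: (batch, seq_len, dim); A: (dim, state_size);
    B, C: (batch, seq_len, state_size); D: (dim,).
    """
    batch, seq_len, dim = x.shape
    state_size = A.shape[1]
    # Discretize: exact exponential for A, simplified delta * B for B
    deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, D, N)
    deltaB_x = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)  # (B, L, D, N)
    # Scan: h_t = deltaA_t * h_{t-1} + deltaB_t * x_t, then y_t = C_t . h_t
    # (a Python loop for clarity; fast implementations use a parallel scan)
    h = torch.zeros(batch, dim, state_size, device=x.device, dtype=x.dtype)
    ys = []
    for t in range(seq_len):
        h = deltaA[:, t] * h + deltaB_x[:, t]
        y = (h * C[:, t].unsqueeze(1)).sum(-1)  # (B, D)
        ys.append(y)
    y = torch.stack(ys, dim=1)  # (B, L, D)
    return y + x * D  # skip connection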
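To see why an input-dependent Δ lets the model remember or forget selectively, it helps to shrink the scan to one channel with a single state value. A minimal sketch in plain Python, assuming a fixed negative `a` and hand-picked step sizes (in Mamba they come from a learned projection of the input); the function name is illustrative:

```python
import math


def scalar_selective_scan(xs, deltas, a=-1.0, b=1.0, c=1.0):
    """One-channel, one-state version of the selective scan recurrence."""
    h = 0.0
    outputs = []
    for x, delta in zip(xs, deltas):
        decay = math.exp(delta * a)    # close to 1 when delta is small
        h = decay * h + delta * b * x  # a small delta also shrinks the write
        outputs.append(c * h)
    return outputs


xs = [5.0, 9.0, 9.0, 9.0]
# Large step on the first token (write it), tiny steps afterwards (ignore them)
keep_first = scalar_selective_scan(xs, deltas=[1.0, 1e-4, 1e-4, 1e-4])
# Large step on every token: each new input keeps changing the state
track_all = scalar_selective_scan(xs, deltas=[1.0, 1.0, 1.0, 1.0])
```

With tiny steps the state stays close to the first input even though later inputs are larger, while with large steps every token changes it. A fixed-parameter SSM has to use one Δ for all tokens and cannot make this choice per input.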
Mamba Variants
| Model | Parameters | Speedup vs Transformer | Quality |
|---|---|---|---|
| Mamba-130M | 130M | 5x | Competitive |
| Mamba-370M | 370M | 4x | Competitive |
| Mamba-1.4B | 1.4B | 3x | Competitive |
| Mamba-2.8B | 2.8B | 3x | Competitive |
| Jamba | 52B (Mamba + Attention) | 2x | State-of-the-art |
S4 (Structured State Spaces)
HiPPO Initialization
Definition: HiPPO Matrix
The HiPPO (High-order Polynomial Projection Operators) matrix provides a principled initialization for the state transition matrix A, so that the hidden state holds a compressed summary of the input history and the SSM can capture long-range dependencies.
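HiPPO-LegS Matrix
$$A_{ij} = -\begin{cases} \sqrt{(2i+1)(2j+1)} & \text{if } i > j \\ i + 1 & \text{if } i = j \\ 0 & \text{if } i < j \end{cases}$$
Here,
- $A_{ij}$ = Element of the HiPPO matrix
- $i, j$ = Row and column indices (starting at 0)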
import torch

def hippo_legs_matrix(N):
    """Generate the N x N HiPPO-LegS matrix."""
    P = torch.sqrt(1 + 2 * torch.arange(N).float())  # P[i] = sqrt(2i + 1)
    A = torch.zeros(N, N)
    for i in range(N):
        for j in range(N):
            if i > j:
                A[i, j] = P[i] * P[j]  # below the diagonal
            elif i == j:
                A[i, j] = i + 1  # on the diagonal
    return -A  # the upper triangle stays zero
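One property worth checking is stability. The HiPPO-LegS matrix is lower triangular, so its eigenvalues are its diagonal entries -1, -2, ..., -N, all negative. A small pure-Python sketch that builds the same entries as nested lists (the helper name `hippo_legs_rows` is illustrative):

```python
import math


def hippo_legs_rows(n_state):
    """Pure-Python HiPPO-LegS matrix as nested lists."""
    p = [math.sqrt(2 * i + 1) for i in range(n_state)]
    rows = []
    for i in range(n_state):
        row = []
        for j in range(n_state):
            if i > j:
                row.append(-p[i] * p[j])
            elif i == j:
                row.append(-(i + 1.0))
            else:
                row.append(0.0)
        rows.append(row)
    return rows


A = hippo_legs_rows(3)
# A is lower triangular, so its eigenvalues are the diagonal entries
eigenvalues = [A[i][i] for i in range(3)]
```

For N = 3 the matrix is [[-1, 0, 0], [-√3, -2, 0], [-√5, -√15, -3]]. Negative eigenvalues mean the continuous-time state decays rather than blows up when the input is zero.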
S4 Architecture
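import torch
import torch.nn as nn

def fft_conv(x, kernel):
    """Causal convolution of x (B, L, d) with kernel (d, L) via FFT."""
    L = x.shape[1]
    x_f = torch.fft.rfft(x, n=2 * L, dim=1)
    k_f = torch.fft.rfft(kernel.t(), n=2 * L, dim=0)
    return torch.fft.irfft(x_f * k_f, n=2 * L, dim=1)[:, :L]

class S4Layer(nn.Module):
    """Simplified S4-style layer: HiPPO-initialized A with a naive kernel."""

    def __init__(self, d_model, N=64):
        super().__init__()
        self.d_model = d_model
        self.N = N
        # HiPPO initialization. A is stored directly, because log(-A) is
        # -inf on the zero upper triangle of the HiPPO matrix.
        self.A = nn.Parameter(hippo_legs_matrix(N))
        # Other parameters (B and C are shared across channels here)
        self.B = nn.Parameter(torch.randn(N))
        self.C = nn.Parameter(torch.randn(N))
        self.D = nn.Parameter(torch.ones(d_model))
        # Per-channel step size, kept positive via exp
        self.log_delta = nn.Parameter(torch.zeros(d_model))

    def compute_kernel(self, A, B, C, delta, L):
        """Naive kernel K[:, l] = C . A_bar^l . B_bar; real S4 avoids this loop."""
        dA = delta.view(-1, 1, 1) * A  # (d, N, N)
        A_bar = torch.matrix_exp(dA)  # zero-order hold state matrix
        eye = torch.eye(self.N, device=A.device)
        # B_bar = (delta A)^{-1} (exp(delta A) - I) delta B, shape (d, N)
        v = (torch.linalg.solve(dA, A_bar - eye) @ B) * delta.unsqueeze(-1)
        cols = []
        for _ in range(L):
            cols.append(v @ C)  # (d,)
            v = (A_bar @ v.unsqueeze(-1)).squeeze(-1)
        return torch.stack(cols, dim=-1)  # (d, L)

    def forward(self, x):
        """Forward pass using convolution. x: (batch, L, d_model)."""
        L = x.shape[1]
        delta = torch.exp(self.log_delta)
        # Compute kernel
        kernel = self.compute_kernel(self.A, self.B, self.C, delta, L)
        # Convolution
        y = fft_conv(x, kernel)
        return y + x * self.D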
Linear Attention
Linear Attention Formulation
Definition: Linear Attention
Linear attention replaces softmax attention with a kernel-based approximation, reducing the cost in sequence length from O(n²) to O(n) while keeping the ability to capture long-range dependencies.
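Linear Attention
$$\mathrm{LinAttn}(Q, K, V) = \frac{\phi(Q)\left(\phi(K)^{\top} V\right)}{\phi(Q)\left(\phi(K)^{\top} \mathbf{1}\right)}$$
Here, the division is applied row by row, and
- $Q, K, V$ = Query, Key, Value matrices
- $\phi$ = Feature map (kernel approximation)
- $\mathbf{1}$ = All-ones vector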
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearAttention(nn.Module):
    """Linear attention with kernel approximation (non-causal)."""

    def __init__(self, d_model, n_heads, feature_map="elu"):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        # Feature maps are strictly positive so the normalizer stays > 0
        if feature_map == "elu":
            self.feature_map = lambda x: F.elu(x) + 1
        elif feature_map == "relu":
            self.feature_map = lambda x: F.relu(x) + 1
        else:
            raise ValueError(f"Unknown feature_map: {feature_map}")

    def forward(self, x):
        B, L, _ = x.shape
        # Project to Q, K, V and move heads forward: each is (B, H, L, d_k)
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.d_k)
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(2))
        # Apply feature map
        q = self.feature_map(q)
        k = self.feature_map(k)
        # Linear attention computation: form phi(K)^T V first, O(L * d_k^2) per head
        kv = torch.einsum("bhld,bhle->bhde", k, v)
        num = torch.einsum("bhld,bhde->bhle", q, kv)
        # Normalize by phi(Q) . sum over positions of phi(K)
        k_sum = k.sum(dim=2)  # (B, H, d_k)
        denominator = torch.einsum("bhld,bhd->bhl", q, k_sum)
        denominator = denominator.unsqueeze(-1).clamp(min=1e-6)
        y = num / denominator  # (B, H, L, d_k)
        # Merge heads back: (B, L, H * d_k)
        return self.out(y.transpose(1, 2).reshape(B, L, -1))
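The class above sums over all positions, so it is the non-causal form. For autoregressive use, the same factorization can be run as a recurrence with a fixed-size running state, which is what links linear attention to SSMs. A minimal pure-Python sketch for a single head, with a quadratic reference for comparison; all names here are illustrative:

```python
import math


def phi(vec):
    """Feature map elu(x) + 1, applied elementwise (always positive)."""
    return [x + 1.0 if x > 0 else math.exp(x) for x in vec]


def causal_linear_attention(queries, keys, values):
    """Causal linear attention using a running state; per-step cost is independent of L."""
    d_k, d_v = len(keys[0]), len(values[0])
    S = [[0.0] * d_v for _ in range(d_k)]  # running sum of phi(k) v^T
    z = [0.0] * d_k                        # running sum of phi(k)
    outputs = []
    for q, k, v in zip(queries, keys, values):
        fq, fk = phi(q), phi(k)
        for i in range(d_k):
            z[i] += fk[i]
            for j in range(d_v):
                S[i][j] += fk[i] * v[j]
        norm = sum(fq[i] * z[i] for i in range(d_k))
        outputs.append([sum(fq[i] * S[i][j] for i in range(d_k)) / norm
                        for j in range(d_v)])
    return outputs


def causal_quadratic_reference(queries, keys, values):
    """Same result via explicit pairwise weights phi(q_t) . phi(k_s) for s <= t."""
    outputs = []
    for t, q in enumerate(queries):
        fq = phi(q)
        weights = [sum(a * b for a, b in zip(fq, phi(keys[s]))) for s in range(t + 1)]
        total = sum(weights)
        outputs.append([sum(w * values[s][j] for s, w in enumerate(weights)) / total
                        for j in range(len(values[0]))])
    return outputs


Q = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]]
K = [[0.2, 0.1], [-0.3, 0.5], [0.7, -0.1]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
fast = causal_linear_attention(Q, K, V)
slow = causal_quadratic_reference(Q, K, V)
```

Both functions return the same outputs up to rounding. The recurrent version keeps only `S` and `z`, so its memory does not grow with sequence length.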
Hybrid Architectures
Mamba + Transformer Hybrids
Definition: Hybrid SSM-Transformer
Hybrid architectures combine the efficiency of SSMs with the expressivity of attention, using SSMs for local processing and attention for global reasoning.
import torch.nn as nn

class HybridSSMTransformer(nn.Module):
    """Hybrid model with SSM and Transformer blocks."""

    def __init__(self, d_model, n_layers, n_heads, mamba_ratio=0.75):
        super().__init__()
        self.layers = nn.ModuleList()
        n_mamba = int(n_layers * mamba_ratio)
        n_attention = n_layers - n_mamba
        # Mamba layers for local processing
        for _ in range(n_mamba):
            self.layers.append(MambaBlock(d_model))
        # Attention layers for global reasoning.
        # nn.TransformerEncoderLayer stands in for the undefined TransformerBlock;
        # it is non-causal unless a mask is supplied.
        for _ in range(n_attention):
            self.layers.append(
                nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
            )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
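The class above stacks all Mamba layers first and all attention layers last. Hybrids such as Jamba interleave the two layer types instead. A minimal sketch of such a layer schedule, assuming one attention layer closes each group; the function and parameter names are hypothetical:

```python
def interleaved_schedule(n_layers, attention_every):
    """Return layer types with one attention layer closing each group of `attention_every`."""
    if attention_every < 1:
        raise ValueError("attention_every must be >= 1")
    return [
        "attention" if (i + 1) % attention_every == 0 else "mamba"
        for i in range(n_layers)
    ]


schedule = interleaved_schedule(n_layers=8, attention_every=4)
mamba_fraction = schedule.count("mamba") / len(schedule)
```

With `attention_every=4` the Mamba share is 0.75, the same as the default `mamba_ratio`, but the attention layers are spread through the depth of the model. The schedule can then drive layer construction in a loop.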
When to Use SSMs vs Transformers
| Scenario | Recommended | Reason |
|---|---|---|
| Long sequences (>100K) | SSM | Linear complexity |
| Autoregressive generation | SSM | O(1) state per step |
| Complex reasoning | Transformer | Global attention |
| Memory-constrained | SSM | O(d) state |
| Very long context | SSM | No KV cache limit |
| Task-specific fine-tuning | Transformer | Better adaptation |
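The Jamba model (AI21) shows that hybrid architectures can be competitive with Transformer-only models of similar size. It interleaves one attention layer with every seven Mamba layers (about 12.5% attention) and adds mixture-of-experts layers, keeping most of the efficiency of the SSM while retaining some attention.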
Practice Exercises
- Conceptual: Explain why the selective scan in Mamba enables content-aware reasoning. How does this differ from fixed-parameter SSMs?
- Mathematical: For a sequence of length n = 50,000 and dimension d = 512, calculate the FLOPs for training a transformer vs an SSM. What is the speedup ratio?
- Practical: Implement a simple S4 layer with HiPPO initialization and test it on a long-range dependency task.
- Research: Compare the performance of Mamba, Transformer, and hybrid models on the Long Range Arena benchmark. What explains the performance differences?
Key Takeaways:
- SSMs offer linear-time alternatives to quadratic-complexity transformers
- Mamba introduces input-dependent (selective) dynamics for content-aware reasoning
- HiPPO initialization enables efficient long-range dependency modeling
- Hybrid architectures combine SSM efficiency with transformer expressivity
- SSMs excel at long sequences; transformers excel at complex reasoning
What to Learn Next
- Mixture of Experts: Sparse architectures that scale efficiently.
- Flash Attention and Memory Efficiency: Optimizing transformer attention for efficiency.
- LLM Architecture Deep Dive: Understanding transformer architectures in depth.
- Scaling Laws and Chinchilla: How model size affects performance.
- Speculative Decoding: Speeding up inference with draft models.
- KV Cache Optimization: Optimizing transformer inference memory.