State Space Models
What are State Space Models?
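State space models (SSMs) are sequence models built on a continuous-time state space representation that is discretized so it can be applied to token sequences. Their compute cost grows linearly with sequence length, whereas self-attention in Transformers grows quadratically.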
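As a reference point, the formulation can be written out as below. This is the standard linear state space form with a common discretization, not something specific to this tutorial; the symbols $h$ (state), $x$ (input), $y$ (output) and $\Delta$ (step size) are notation chosen here to line up with `A`, `B`, `C`, `D` and `dt` in the code that follows.

$$
h'(t) = A\,h(t) + B\,x(t), \qquad y(t) = C\,h(t) + D\,x(t)
$$

$$
\bar{A} = \exp(\Delta A), \qquad \bar{B} \approx \Delta B
$$

$$
h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t, \qquad y_t = C\,h_t + D\,x_t
$$

Each step updates a fixed-size state $h_t$ from the previous one, so processing a sequence of length $L$ takes $L$ such updates. That is where the linear cost comes from.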
Mamba Architecture
```python
import torch
import torch.nn as nn


class SelectiveSSM(nn.Module):
    """Simplified selective SSM layer (reference loop, not an optimized kernel)."""

    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # A and D are learned but input-independent; B and C are projected from the input
        self.A_log = nn.Parameter(torch.randn(d_model, d_state))
        self.B = nn.Linear(d_model, d_state)
        self.C = nn.Linear(d_model, d_state)
        self.D = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        batch, seq_len, d_model = x.shape

        # A is kept negative so the state decays; B and C depend on the input (selective)
        A = -torch.exp(self.A_log)  # [d_model, d_state]
        B = self.B(x)  # [batch, seq, d_state]
        C = self.C(x)  # [batch, seq, d_state]

        # Discretize (simplification: the step size dt is a fixed constant here)
        dt = torch.full((batch, seq_len, 1), 0.01, device=x.device, dtype=x.dtype)
        dA = torch.exp(dt.unsqueeze(-1) * A)  # [batch, seq, d_model, d_state]
        dB = dt.unsqueeze(-1) * B.unsqueeze(-2)  # [batch, seq, 1, d_state]

        # Selective scan (linear recurrence over time)
        h = torch.zeros(batch, self.d_model, self.d_state, device=x.device, dtype=x.dtype)
        outputs = []
        for t in range(seq_len):
            # dB[:, t] broadcasts over d_model; x[:, t] broadcasts over d_state
            h = dA[:, t] * h + dB[:, t] * x[:, t].unsqueeze(-1)  # [batch, d_model, d_state]
            y = (h * C[:, t].unsqueeze(1)).sum(-1) + self.D * x[:, t]  # [batch, d_model]
            outputs.append(y)
        return torch.stack(outputs, dim=1)  # [batch, seq, d_model]
```
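The layer above uses a constant step size `dt`. In Mamba the step size is itself predicted from the input, which is a large part of what makes the model selective. A minimal sketch of that idea follows. The `InputDependentStep` module and its single linear projection are hypothetical names and choices made for illustration, not the reference Mamba implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class InputDependentStep(nn.Module):
    """Hypothetical helper: predicts a positive step size per token and channel."""

    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        # softplus keeps dt positive, so exp(dt * A) never exceeds 1 when A is negative
        return F.softplus(self.proj(x))  # [batch, seq, d_model]


torch.manual_seed(0)
d_model, d_state = 8, 4
x = torch.randn(2, 5, d_model)
A = -torch.exp(torch.randn(d_model, d_state))  # negative, as in SelectiveSSM

dt = InputDependentStep(d_model)(x)    # [2, 5, 8]
dA = torch.exp(dt.unsqueeze(-1) * A)   # [2, 5, 8, 4]
print(dt.shape, dA.shape)
```

Here `dt` has shape `[batch, seq, d_model]` rather than `[batch, seq, 1]`, so each channel gets its own step size. It still broadcasts through the discretization step in the same way as the constant version.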
SSM vs Transformer
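| Aspect | SSM | Transformer |
|---|---|---|
| Time complexity (sequence length L) | O(L) | O(L^2) |
| Memory | Linear in L | Quadratic in L (naive attention) |
| Parallelism | Sequential recurrence (parallelizable with a scan) | Parallel across positions |
| Long-range | Good (history compressed into a fixed-size state) | Excellent (direct access to every position in the context) |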
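On the parallelism row: the recurrence in the forward loop is sequential as written, but because it is linear it can also be evaluated with cumulative operations that parallelize well. The sketch below checks this for a scalar recurrence `h_t = a_t * h_{t-1} + b_t`. The function names are made up for this example, and dividing by a cumulative product is numerically fragile on long sequences, so treat it as a demonstration of the idea rather than how production kernels work (those typically use an associative scan).

```python
import torch


def scan_sequential(a, b):
    """h_t = a_t * h_{t-1} + b_t with a zero initial state, computed step by step."""
    h = torch.zeros_like(b[0])
    out = []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return torch.stack(out)


def scan_cumulative(a, b):
    """Same recurrence via cumulative products and sums (no Python loop)."""
    p = torch.cumprod(a, dim=0)  # p_t = a_0 * ... * a_t
    return p * torch.cumsum(b / p, dim=0)


torch.manual_seed(0)
a = torch.rand(64, dtype=torch.float64) * 0.1 + 0.9  # decay factors in [0.9, 1.0)
b = torch.randn(64, dtype=torch.float64)

h_seq = scan_sequential(a, b)
h_par = scan_cumulative(a, b)
print(torch.allclose(h_seq, h_par))
```

The two functions agree because unrolling the recurrence gives `h_t = sum over s <= t of (a_{s+1} * ... * a_t) * b_s`, and each product of decay factors equals `p_t / p_s`.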
Summary
State space models like Mamba offer an efficient alternative to Transformers: their cost scales linearly with sequence length, which makes them well suited to very long sequences.
Next: We'll explore inference optimization.