
State Space Models — Beyond Transformers

Transformers dominate LLMs, but the cost of self-attention grows quadratically with sequence length. State Space Models (SSMs) such as S4 and Mamba offer linear-time alternatives with strong performance on long sequences.

  • SSM Theory — Continuous-time dynamical systems discretized for deep learning
  • Mamba — Selective state spaces with input-dependent dynamics
  • Linear Attention — Attention variants with O(n) complexity
  • Trade-offs — When to use SSMs vs Transformers

Not all intelligence requires attention—sometimes, state is enough.

State Space Models

While transformers have achieved remarkable success, their quadratic attention complexity limits their applicability to very long sequences. State Space Models (SSMs) offer a promising alternative with linear-time processing and strong theoretical foundations in control theory and signal processing.

Definition: State Space Model

A State Space Model (SSM) is a sequence model that maps an input sequence to an output sequence through a hidden state that evolves according to linear dynamical-system equations. The same model can be computed as a recurrence or as a convolution, and its cost scales linearly (near-linearly, for FFT-based convolution) with sequence length.

SSM Theory

Continuous-Time State Space

Continuous SSM

\begin{aligned} h'(t) &= A h(t) + B x(t) \\ y(t) &= C h(t) + D x(t) \end{aligned}

Here,

  • h(t) = hidden state at time t
  • h'(t) = time derivative of the hidden state
  • x(t) = input at time t
  • y(t) = output at time t
  • A = state transition matrix (n × n)
  • B = input projection matrix (n × 1)
  • C = output projection matrix (1 × n)
  • D = feedthrough matrix (1 × 1 here; acts as a skip connection)

The continuous-time formulation connects SSMs to classical control theory, enabling principled initialization and stability analysis. The matrices A, B, C, D are learned parameters.

Discretization

To use SSMs in deep learning, we must discretize the continuous-time system:

Discretized SSM

\begin{aligned} h_t &= \bar{A} h_{t-1} + \bar{B} x_t \\ y_t &= C h_t + D x_t \end{aligned}

Here,

  • h_t = discrete hidden state at step t
  • x_t = discrete input at step t
  • y_t = discrete output at step t
  • Ā = discretized state matrix
  • B̄ = discretized input matrix
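
Because Ā, B̄, C and D do not change over time in this form, unrolling the recurrence from h_{-1} = 0 gives y_t = Σ_{j=0}^{t} C Ā^j B̄ x_{t-j} + D x_t. The same layer can therefore be run step by step or as one causal convolution with kernel K = (C B̄, C Ā B̄, C Ā² B̄, ...). The sketch below checks this numerically with random matrices, assuming NumPy; the function names are illustrative, not from any library.

```python
import numpy as np


def ssm_recurrent(A_bar, B_bar, C, D, x):
    """Run h_t = A_bar h_{t-1} + B_bar x_t, y_t = C h_t + D x_t from h_{-1} = 0."""
    h = np.zeros(A_bar.shape[0])
    ys = []
    for x_t in x:
        h = A_bar @ h + B_bar * x_t
        ys.append(C @ h + D * x_t)
    return np.array(ys)


def ssm_kernel(A_bar, B_bar, C, L):
    """Convolution kernel K[j] = C A_bar^j B_bar for j = 0..L-1."""
    K, v = [], B_bar.copy()
    for _ in range(L):
        K.append(C @ v)
        v = A_bar @ v
    return np.array(K)


def ssm_convolutional(A_bar, B_bar, C, D, x):
    """Same output as ssm_recurrent, computed as a causal convolution with K."""
    L = len(x)
    K = ssm_kernel(A_bar, B_bar, C, L)
    # np.convolve returns the full linear convolution; the first L terms are causal
    return np.convolve(x, K)[:L] + D * x


rng = np.random.default_rng(0)
n, L = 4, 32
A_bar = 0.9 * np.eye(n) + 0.02 * rng.standard_normal((n, n))  # roughly stable
B_bar = rng.standard_normal(n)
C = rng.standard_normal(n)
D = 0.5
x = rng.standard_normal(L)

y_rec = ssm_recurrent(A_bar, B_bar, C, D, x)
y_conv = ssm_convolutional(A_bar, B_bar, C, D, x)
print(np.allclose(y_rec, y_conv))  # True
```

The recurrent form is what you want for step-by-step generation; the convolutional form lets training process the whole sequence in parallel.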

Zero-Order Hold Discretization

\bar{A} = e^{\Delta A}, \quad \bar{B} = (\Delta A)^{-1}\left(e^{\Delta A} - I\right) \cdot \Delta B

Here,

  • Δ = discretization step size
  • e^{ΔA} = matrix exponential
  • I = identity matrix
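
The ZOH formulas can be evaluated directly with a matrix exponential. A minimal sketch, assuming NumPy and SciPy (`scipy.linalg.expm`) and an invertible ΔA; `discretize_zoh` and `euler_step` are hypothetical helper names. Instead of forming (ΔA)⁻¹ explicitly it solves a linear system, the usual numerically safer choice. As a sanity check, one ZOH step is compared with a fine-grained Euler integration of the continuous system while the input is held constant over the step, which is exactly the assumption ZOH makes.

```python
import numpy as np
from scipy.linalg import expm


def discretize_zoh(A, B, delta):
    """A_bar = e^{delta A}, B_bar = (delta A)^{-1} (e^{delta A} - I) (delta B)."""
    dA = delta * A
    A_bar = expm(dA)
    # Solve (delta A) X = (A_bar - I)(delta B) instead of inverting delta A
    B_bar = np.linalg.solve(dA, (A_bar - np.eye(A.shape[0])) @ (delta * B))
    return A_bar, B_bar


def euler_step(A, B, h, x, delta, substeps=10_000):
    """Reference: integrate h' = A h + B x over [0, delta] with x held constant."""
    dt = delta / substeps
    for _ in range(substeps):
        h = h + dt * (A @ h + B * x)
    return h


A = np.array([[-1.0, 0.5], [0.0, -2.0]])
B = np.array([1.0, 0.5])
A_bar, B_bar = discretize_zoh(A, B, delta=0.1)

h0, x0 = np.array([0.3, -0.2]), 1.5
h_zoh = A_bar @ h0 + B_bar * x0            # one discrete step
h_ref = euler_step(A, B, h0, x0, delta=0.1)
print(np.max(np.abs(h_zoh - h_ref)))       # small; shrinks as substeps grows
```

For a scalar system the formulas reduce to Ā = e^{Δa} and B̄ = (e^{Δa} - 1) b / a, which is a quick way to check an implementation by hand.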

SSM vs Transformer Complexity

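Aspect | Transformer | SSM
Training | O(n²d), parallel | O(nd²), parallel
Inference (per generated token) | O(nd): attends over all cached tokens | O(d²): fixed-size state update
Memory | O(nd) KV cache, plus O(n²) if the attention matrix is materialized | O(d) state
Long-range interactions | O(n²): every pair of tokens is compared directly | O(1) in n: history is compressed into the recurrent state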

Complexity Comparison

For sequence length n = 100,000 and dimension d = 256:

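Transformer (attention, per layer):

  • Training: O(n²d) = 100,000² × 256 ≈ 2.56 × 10¹² operations
  • Memory: the n × n attention-score matrix has 100,000² = 10¹⁰ entries, about 20 GB per head in 16-bit precision if materialized (the KV cache itself grows as O(nd))

SSM (Mamba):

  • Training: O(nd²) = 100,000 × 256² ≈ 6.55 × 10⁹ operations
  • Memory: O(d) state = 256 values, about 1 KB in 32-bit precision (a Mamba layer with state size N = 16 keeps d × N = 4,096 values, about 16 KB)

Under this simplified cost model the SSM needs n/d ≈ 390 times fewer operations (roughly 400×), and its fixed-size state is millions of times smaller than the materialized attention matrix at this sequence length.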
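
The arithmetic above can be reproduced in a few lines. This sketch uses the same simplified cost model as the table (attention ≈ n²d, SSM ≈ nd², constants dropped), so the outputs are order-of-magnitude operation counts, not measured FLOPs. The byte sizes (16-bit attention scores and KV cache, 32-bit state, N = 16) are assumptions chosen to match the numbers in the text; the function names are illustrative.

```python
def simplified_costs(n, d):
    """Order-of-magnitude per-layer costs under the simplified model."""
    transformer_ops = n * n * d   # all n x n attention scores, each a d-dim dot product
    ssm_ops = n * d * d           # one d x d state update per token
    return transformer_ops, ssm_ops, transformer_ops / ssm_ops


def memory_bytes(n, d, d_state=16, attn_bytes=2, state_bytes=4):
    """Per-layer memory: attention matrix (one head), KV cache, Mamba-style state."""
    attention_matrix = n * n * attn_bytes     # n x n scores, 16-bit
    kv_cache = 2 * n * d * attn_bytes         # keys and values for n tokens, 16-bit
    ssm_state = d * d_state * state_bytes     # d x N state, 32-bit
    return attention_matrix, kv_cache, ssm_state


t_ops, s_ops, ratio = simplified_costs(100_000, 256)
print(f"{t_ops:.3g} vs {s_ops:.3g} ops, ratio {ratio:.1f}")
# 2.56e+12 vs 6.55e+09 ops, ratio 390.6

attn, kv, state = memory_bytes(100_000, 256)
print(f"{attn / 1e9:.0f} GB, {kv / 1e6:.1f} MB, {state / 1e3:.1f} KB")
# 20 GB, 102.4 MB, 16.4 KB
```

The ratio is simply n/d, so the SSM's advantage in this model grows linearly with sequence length.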

Mamba Architecture

Selective State Spaces

Definition: Mamba

Mamba (Gu & Dao, 2023) is a selective state space model that makes the SSM parameters input-dependent, enabling content-aware reasoning. It achieves transformer-quality performance with linear-time complexity.

The key innovation of Mamba is selectivity—making the SSM matrices B, C, and Δ depend on the input:

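import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MambaBlock(nn.Module):
    """Simplified Mamba block (reference scan, no hardware-aware kernel)."""

    def __init__(self, d_model, d_state=16, d_conv=4, dt_rank=None):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.dt_rank = dt_rank or math.ceil(d_model / 16)

        # Input projection: one branch for the SSM path, one for the gate
        self.in_proj = nn.Linear(d_model, 2 * d_model)
        # Depthwise causal convolution (padded, then trimmed back to length L)
        self.conv1d = nn.Conv1d(d_model, d_model, d_conv,
                                groups=d_model, padding=d_conv - 1)

        # Input-dependent SSM parameters: low-rank delta, B and C
        self.x_proj = nn.Linear(d_model, self.dt_rank + 2 * d_state)
        self.dt_proj = nn.Linear(self.dt_rank, d_model)

        # A is stored as log(-A) so that A = -exp(A_log) stays negative (stable)
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_model, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(d_model))

        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Forward pass with selective scan. x: (batch, L, d_model)."""
        _, L, _ = x.shape

        xz = self.in_proj(x)  # (batch, L, 2 * d_model)
        x, z = xz.chunk(2, dim=-1)

        # Causal depthwise convolution over the sequence axis
        x = self.conv1d(x.transpose(1, 2))[:, :, :L].transpose(1, 2)
        x = F.silu(x)

        # Input-dependent SSM parameters
        x_dbl = self.x_proj(x)  # (batch, L, dt_rank + 2 * d_state)
        dt, B_param, C_param = x_dbl.split(
            [self.dt_rank, self.d_state, self.d_state], dim=-1
        )
        dt = F.softplus(self.dt_proj(dt))  # (batch, L, d_model), positive step sizes

        # Selective scan (input-dependent dynamics)
        A = -torch.exp(self.A_log)  # (d_model, d_state)
        y = selective_scan(x, dt, A, B_param, C_param, self.D)

        # Gating
        y = y * F.silu(z)

        return self.out_proj(y)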

Selective Scan

Definition: Selective Scan

The selective scan is the core operation of Mamba. It computes the SSM recurrence with input-dependent parameters B, C, and Δ, enabling the model to selectively remember or forget information based on the input content.

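import torch


def selective_scan(x, delta, A, B, C, D):
    """Reference (sequential) selective scan.

    Shapes: x, delta: (batch, L, D); A: (D, N) with negative entries;
    B, C: (batch, L, N); D: (D,). Returns y: (batch, L, D).
    """
    batch, seq_len, dim = x.shape
    state_size = A.shape[1]

    # Discretize per step: A_bar = exp(delta * A) (ZOH); input term uses delta * B * x
    deltaA = torch.exp(delta.unsqueeze(-1) * A)                        # (batch, L, D, N)
    deltaB_x = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)  # (batch, L, D, N)

    # Sequential scan over time (Mamba's fused GPU kernel computes this as a parallel scan)
    h = torch.zeros(batch, dim, state_size, device=x.device, dtype=x.dtype)
    ys = []

    for t in range(seq_len):
        h = deltaA[:, t] * h + deltaB_x[:, t]      # (batch, D, N)
        y = (h * C[:, t].unsqueeze(1)).sum(-1)     # contract the state with C_t: (batch, D)
        ys.append(y)

    y = torch.stack(ys, dim=1)                     # (batch, L, D)
    return y + x * D                               # skip connection (D term)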
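
To see how an input-dependent Δ lets the scan remember or forget, consider one channel with a state of size 1, A = -1 and B = C = 1. With exact ZOH for both Ā and B̄ (the `selective_scan` above uses the simpler Δ·B for the input term), the update becomes h_t = e^{-Δ_t} h_{t-1} + (1 - e^{-Δ_t}) x_t: a large Δ_t overwrites the state with the current input, while Δ_t ≈ 0 leaves it almost unchanged. A NumPy sketch; the function name and the Δ values are illustrative.

```python
import numpy as np


def selective_scan_1d(x, delta, a=-1.0, b=1.0, c=1.0):
    """Single-channel, state-size-1 scan with a per-step delta.

    Uses exact ZOH for both terms: A_bar = exp(delta * a),
    B_bar = (exp(delta * a) - 1) / a * b. Assumes a < 0.
    """
    h = 0.0
    hs, ys = [], []
    for x_t, d_t in zip(x, delta):
        a_bar = np.exp(d_t * a)
        b_bar = (a_bar - 1.0) / a * b
        h = a_bar * h + b_bar * x_t
        hs.append(h)
        ys.append(c * h)
    return np.array(hs), np.array(ys)


x = np.array([1.0, 5.0, 5.0, 3.0])
delta = np.array([20.0, 1e-4, 1e-4, 20.0])  # store, ignore, ignore, overwrite
hs, ys = selective_scan_1d(x, delta)
print(np.round(hs, 4))
```

The state holds roughly 1.0 through the two small-Δ steps, even though their inputs are 5.0, and jumps to about 3.0 when Δ is large again. A fixed-parameter SSM uses the same Δ at every step and cannot make this choice per token.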

Mamba Variants

Model | Parameters | Speedup vs Transformer | Quality
Mamba-130M | 130M | 5× | Competitive
Mamba-370M | 370M | 4× | Competitive
Mamba-1.4B | 1.4B | 3× | Competitive
Mamba-2.8B | 2.8B | 3× | Competitive
Jamba | 52B (Mamba + Attention) | 2× | State-of-the-art

S4 (Structured State Spaces)

HiPPO Initialization

Definition: HiPPO Matrix

The HiPPO (High-order Polynomial Projection Operator) matrix provides principled initialization for the state transition matrix A, enabling SSMs to efficiently compress long-range dependencies.

HiPPO-LegS Matrix

A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}

Here,

  • A_{nk} = element of the HiPPO matrix
  • n, k = row and column indices
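
import torch


def hippo_legs_matrix(N):
    """Generate the N x N HiPPO-LegS matrix (lower triangular, negative diagonal)."""
    P = torch.sqrt(1 + 2 * torch.arange(N).float())  # (2n+1)^{1/2}
    A = torch.zeros(N, N)

    for i in range(N):
        for j in range(N):
            if i > j:
                A[i, j] = P[i] * P[j]   # below the diagonal
            elif i == j:
                A[i, j] = i + 1         # diagonal
            # entries above the diagonal stay 0

    return -A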
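
The double loop can be replaced by a vectorized construction. Because A is lower triangular, its eigenvalues are its diagonal entries, -1, -2, ..., -N, all negative, so the continuous-time dynamics h'(t) = A h(t) are stable. A NumPy sketch; `hippo_legs` is an illustrative name.

```python
import numpy as np


def hippo_legs(N):
    """HiPPO-LegS matrix built without Python loops."""
    p = np.sqrt(1 + 2 * np.arange(N))      # (2n+1)^{1/2}
    A = np.tril(np.outer(p, p), k=-1)      # strictly lower part: entries with n > k
    A += np.diag(np.arange(1, N + 1))      # diagonal entries n + 1
    return -A


A = hippo_legs(4)
print(np.round(A, 3))
print(np.sort(np.linalg.eigvals(A).real))
```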

S4 Architecture

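import math
import torch
import torch.nn as nn


class S4Layer(nn.Module):
    """Simplified S4-style layer: a HiPPO-initialized SSM applied as a long convolution.

    Simplification: the dense HiPPO-LegS matrix A is kept fixed, while B, C, D and
    the per-channel step size are learned. The actual S4 layer uses a trainable
    diagonal-plus-low-rank parameterization of A to compute the kernel efficiently.
    """

    def __init__(self, d_model, N=64):
        super().__init__()
        self.d_model = d_model
        self.N = N

        # HiPPO initialization (hippo_legs_matrix defined above); fixed buffer here
        self.register_buffer("A", hippo_legs_matrix(N))

        # Other parameters (B and C shared across channels)
        self.B = nn.Parameter(torch.randn(N))
        self.C = nn.Parameter(torch.randn(N))
        self.D = nn.Parameter(torch.ones(d_model))

        # Per-channel step size, initialized log-uniformly in [1e-3, 1e-1]
        self.log_delta = nn.Parameter(
            torch.empty(d_model).uniform_(math.log(1e-3), math.log(1e-1))
        )

    def compute_kernel(self, L):
        """Kernel K[h, l] = C A_bar^l B_bar for each channel h, via ZOH discretization."""
        delta = torch.exp(self.log_delta)                   # (H,)
        dA = delta[:, None, None] * self.A                  # (H, N, N)
        A_bar = torch.linalg.matrix_exp(dA)                 # exp(delta A)
        dB = (delta[:, None] * self.B).unsqueeze(-1)        # (H, N, 1)
        eye = torch.eye(self.N, device=dA.device)
        B_bar = torch.linalg.solve(dA, (A_bar - eye) @ dB)  # (H, N, 1)

        ks, v = [], B_bar
        for _ in range(L):                                  # O(L N^2) per channel
            ks.append(v.squeeze(-1) @ self.C)               # C A_bar^l B_bar: (H,)
            v = A_bar @ v
        return torch.stack(ks, dim=-1)                      # (H, L)

    def forward(self, x):
        """Forward pass using convolution. x: (batch, L, d_model)."""
        L = x.shape[1]

        kernel = self.compute_kernel(L)   # (d_model, L)
        y = fft_conv(x, kernel)           # causal convolution along the sequence

        return y + x * self.D             # skip connection (D term)


def fft_conv(u, k):
    """Causal convolution y[:, t, h] = sum_{j<=t} k[h, j] u[:, t-j, h] using the FFT."""
    L = u.shape[1]
    n = 2 * L  # zero-pad so circular convolution equals linear convolution
    u_f = torch.fft.rfft(u.transpose(1, 2), n=n)   # (batch, H, n//2 + 1)
    k_f = torch.fft.rfft(k, n=n)                   # (H, n//2 + 1)
    y = torch.fft.irfft(u_f * k_f, n=n)[..., :L]   # (batch, H, L)
    return y.transpose(1, 2)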
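
The FFT convolution is what makes the convolutional mode fast: a length-L causal convolution costs O(L²) when computed directly but O(L log L) with the FFT. The detail that matters is zero-padding both signals to at least 2L - 1 so that the FFT's circular convolution equals the linear one; without padding, the last inputs wrap around into the first outputs and causality breaks. A single-channel NumPy sketch; the function names are illustrative.

```python
import numpy as np


def causal_conv_direct(u, k):
    """y[t] = sum_{j=0}^{t} k[j] * u[t - j], computed in O(L^2)."""
    L = len(u)
    return np.array([sum(k[j] * u[t - j] for j in range(t + 1)) for t in range(L)])


def causal_fft_conv(u, k):
    """Same result in O(L log L): pad to 2L so circular convolution == linear."""
    L = len(u)
    n = 2 * L
    y = np.fft.irfft(np.fft.rfft(u, n) * np.fft.rfft(k, n), n)
    return y[:L]


rng = np.random.default_rng(0)
u, k = rng.standard_normal(64), rng.standard_normal(64)
print(np.allclose(causal_conv_direct(u, k), causal_fft_conv(u, k)))  # True

# Without padding the FFT computes a circular convolution, so early outputs
# pick up contributions from the end of the sequence.
y_circular = np.fft.irfft(np.fft.rfft(u) * np.fft.rfft(k), len(u))
print(np.allclose(causal_conv_direct(u, k), y_circular))  # False
```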

Linear Attention

Linear Attention Formulation

Definition: Linear Attention

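Linear attention replaces the softmax similarity between a query and a key with a kernel φ(q)ᵀφ(k). Because the resulting matrix products are associative, φ(K)ᵀV can be computed once and reused for every query, which reduces the cost from O(n²) to O(n) in sequence length while still letting every position draw on the whole sequence.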

Linear Attention

\text{Attention}(Q, K, V) = \frac{\phi(Q)\left(\phi(K)^T V\right)}{\phi(Q)\phi(K)^T \mathbf{1}}

Here,

  • Q, K, V = query, key, and value matrices
  • φ = feature map (kernel approximation)
  • 1 = all-ones vector
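
The saving comes purely from associativity. With the n × n kernel matrix S = φ(Q)φ(K)ᵀ, the row-normalized product S V / (S 1) equals φ(Q)(φ(K)ᵀV) / (φ(Q)(φ(K)ᵀ1)), but the right-hand side only forms d × d and length-d intermediates, so the cost drops from O(n²d) to O(nd²). The NumPy check below uses φ(x) = elu(x) + 1, the same feature map as the `"elu"` option in the code that follows; the function names are illustrative.

```python
import numpy as np


def phi(x):
    """elu(x) + 1: a positive feature map, so the normalizers stay > 0."""
    return np.where(x > 0, x + 1.0, np.exp(x))


def attention_quadratic(Q, K, V):
    S = phi(Q) @ phi(K).T                        # (n, n) kernel matrix
    return (S @ V) / S.sum(axis=1, keepdims=True)


def attention_linear(Q, K, V):
    KV = phi(K).T @ V                            # (d, d_v): no n x n matrix
    z = phi(K).sum(axis=0)                       # (d,) = phi(K)^T 1
    return (phi(Q) @ KV) / (phi(Q) @ z)[:, None]


rng = np.random.default_rng(0)
n, d = 128, 16
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
print(np.allclose(attention_quadratic(Q, K, V), attention_linear(Q, K, V)))  # True
```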
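
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearAttention(nn.Module):
    """Non-causal (bidirectional) linear attention with a kernel feature map."""

    def __init__(self, d_model, n_heads, feature_map="elu"):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

        # Feature maps must be positive so the normalizer stays > 0
        if feature_map == "elu":
            self.feature_map = lambda x: F.elu(x) + 1
        elif feature_map == "relu":
            self.feature_map = lambda x: F.relu(x) + 1
        else:
            raise ValueError(f"unknown feature_map: {feature_map}")

    def forward(self, x):
        B, L, _ = x.shape

        # Project to Q, K, V and move heads before the sequence axis: (B, H, L, d_k)
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.d_k)
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(2))

        # Apply feature map
        q = self.feature_map(q)
        k = self.feature_map(k)

        # phi(K)^T V first: a (d_k x d_k) summary per head, no L x L matrix
        kv = torch.einsum("bhld,bhle->bhde", k, v)
        num = torch.einsum("bhld,bhde->bhle", q, kv)

        # Normalize by phi(Q) (phi(K)^T 1)
        k_sum = k.sum(dim=2)  # sum over the sequence axis: (B, H, d_k)
        denominator = torch.einsum("bhld,bhd->bhl", q, k_sum)
        denominator = denominator.unsqueeze(-1).clamp(min=1e-6)

        y = num / denominator  # (B, H, L, d_k)

        # Merge heads back: (B, L, H * d_k)
        return self.out(y.transpose(1, 2).reshape(B, L, -1))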
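
The `LinearAttention` module above is non-causal: every position sees the whole sequence. For autoregressive use the sums over keys are restricted to positions s ≤ t, and they can then be carried as a running state S_t = S_{t-1} + φ(k_t)v_tᵀ and z_t = z_{t-1} + φ(k_t), with output y_t = φ(q_t)ᵀS_t / φ(q_t)ᵀz_t. That is a recurrence with a fixed-size d × d state, which is the link between linear attention and SSMs. A NumPy sketch comparing it with the masked quadratic form; the function names are illustrative.

```python
import numpy as np


def phi(x):
    return np.where(x > 0, x + 1.0, np.exp(x))  # elu(x) + 1


def causal_linear_attention_recurrent(Q, K, V):
    """One state update per token: O(n d^2) time, O(d^2) state."""
    d, d_v = Q.shape[1], V.shape[1]
    S = np.zeros((d, d_v))   # running sum of phi(k_s) v_s^T
    z = np.zeros(d)          # running sum of phi(k_s)
    out = []
    for q_t, k_t, v_t in zip(phi(Q), phi(K), V):
        S += np.outer(k_t, v_t)
        z += k_t
        out.append(q_t @ S / (q_t @ z))
    return np.array(out)


def causal_linear_attention_masked(Q, K, V):
    """Reference: O(n^2) kernel matrix with a lower-triangular (causal) mask."""
    S = np.tril(phi(Q) @ phi(K).T)
    return (S @ V) / S.sum(axis=1, keepdims=True)


rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((64, 8)) for _ in range(3))
print(np.allclose(causal_linear_attention_recurrent(Q, K, V),
                  causal_linear_attention_masked(Q, K, V)))  # True
```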

Hybrid Architectures

Mamba + Transformer Hybrids

Definition: Hybrid SSM-Transformer

Hybrid architectures combine the efficiency of SSMs with the expressivity of attention, using SSMs for local processing and attention for global reasoning.

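import torch.nn as nn


class HybridSSMTransformer(nn.Module):
    """Hybrid model that interleaves Mamba blocks with attention blocks."""

    def __init__(self, d_model, n_layers, n_heads, mamba_ratio=0.75):
        super().__init__()
        n_mamba = int(n_layers * mamba_ratio)
        n_attention = n_layers - n_mamba

        # Spread the attention layers evenly through the stack
        attn_positions = {
            (j + 1) * n_layers // n_attention - 1 for j in range(n_attention)
        }

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(n_layers):
            if i in attn_positions:
                # Attention layer for global reasoning (has its own norms and residuals;
                # add a causal mask for autoregressive use)
                self.layers.append(nn.TransformerEncoderLayer(
                    d_model, n_heads, batch_first=True, norm_first=True))
                self.norms.append(nn.Identity())
            else:
                # Mamba layer for local processing (MambaBlock defined above)
                self.layers.append(MambaBlock(d_model))
                self.norms.append(nn.LayerNorm(d_model))

    def forward(self, x):
        for layer, norm in zip(self.layers, self.norms):
            if isinstance(layer, MambaBlock):
                x = x + layer(norm(x))  # pre-norm residual around the Mamba block
            else:
                x = layer(x)
        return x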

When to Use SSMs vs Transformers

Scenario | Recommended | Reason
Long sequences (>100K tokens) | SSM | Linear complexity
Autoregressive generation | SSM | Constant compute and memory per generated token
Complex reasoning | Transformer | Global attention
Memory-constrained deployment | SSM | O(d) fixed-size state
Very long context | SSM | No KV cache that grows with the context
Task-specific fine-tuning | Transformer | Better adaptation

The Jamba model (AI21) shows that hybrid architectures can reach strong performance: it interleaves one attention layer for every seven Mamba layers (a 1:7 attention-to-Mamba ratio) and adds mixture-of-experts layers, combining the efficiency of SSMs with the expressivity of attention.

Practice Exercises

  1. Conceptual: Explain why the selective scan in Mamba enables content-aware reasoning. How does this differ from fixed-parameter SSMs?

  2. Mathematical: For a sequence of length n = 50,000 and dimension d = 512, calculate the FLOPs for training a transformer vs an SSM. What is the speedup ratio?

  3. Practical: Implement a simple S4 layer with HiPPO initialization and test it on a long-range dependency task.

  4. Research: Compare the performance of Mamba, Transformer, and hybrid models on the Long Range Arena benchmark. What explains the performance differences?

Key Takeaways:

  • SSMs offer linear-time alternatives to quadratic-complexity transformers
  • Mamba introduces input-dependent (selective) dynamics for content-aware reasoning
  • HiPPO initialization enables efficient long-range dependency modeling
  • Hybrid architectures combine SSM efficiency with transformer expressivity
  • SSMs excel at long sequences; transformers excel at complex reasoning

What to Learn Next

-> Mixture of Experts: sparse architectures that scale efficiently.

-> Flash Attention and Memory Efficiency: optimizing transformer attention for efficiency.

-> LLM Architecture Deep Dive: understanding transformer architectures in depth.

-> Scaling Laws and Chinchilla: how model size affects performance.

-> Speculative Decoding: speeding up inference with draft models.

-> KV Cache Optimization: optimizing transformer inference memory.
