Attention Mechanisms — Deep Dive

Attention allows models to dynamically focus on relevant parts of the input when producing each output element. It is the core innovation behind Transformers.


Why Attention?

Encoder-decoder (seq2seq) RNN models without attention compress the entire input into a single fixed-length vector, creating an information bottleneck that hurts most on long inputs. Attention removes this bottleneck by letting the decoder "look at" all encoder states at each decoding step.

Definition: Attention Mechanism

Attention computes a weighted sum of values (encoder states) where the weights are determined by the compatibility between a query (decoder state) and keys (encoder states):

$$\text{Attention}(q, K, V) = \sum_{i=1}^{T} \alpha_i v_i$$

where the attention weights $\alpha_i$ are non-negative and sum to 1; they are obtained by applying a softmax to the query-key compatibility scores.
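A minimal NumPy sketch of this weighted sum for a single query. The dot product used as the compatibility function is an assumption for illustration (the scoring functions actually used by Bahdanau and Luong attention are covered in the sections that follow), and the helper names `softmax` and `attend` are hypothetical.

```python
import numpy as np

def softmax(x):
    # Subtracting the max keeps exp() from overflowing; the result is unchanged.
    z = np.exp(x - np.max(x))
    return z / z.sum()

def attend(q, K, V):
    """Soft attention for one query q (d,), keys K (T, d), values V (T, d_v)."""
    scores = K @ q             # compatibility e_i = k_i . q (dot product assumed)
    alpha = softmax(scores)    # alpha_i >= 0 and sum_i alpha_i = 1
    return alpha @ V, alpha    # context = sum_i alpha_i v_i

rng = np.random.default_rng(0)
T, d = 5, 4
K = rng.standard_normal((T, d))   # keys (encoder states)
V = rng.standard_normal((T, 3))   # values
q = rng.standard_normal(d)        # query (decoder state)

context, alpha = attend(q, K, V)
print(alpha.round(3), alpha.sum())
print(context.shape)  # (3,)
```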


Bahdanau Attention

Definition: Bahdanau (Additive) Attention

Proposed by Bahdanau et al. (2015), this uses a learned feedforward network to compute alignment scores:

$$e_{ij} = v^T \tanh(W_s s_{i-1} + W_h h_j)$$
$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T} \exp(e_{ik})}$$
$$c_i = \sum_{j=1}^{T} \alpha_{ij} h_j$$

Bahdanau Attention Score

$$e_{ij} = v^T \tanh(W_s s_{i-1} + W_h h_j)$$

Here,

  • $e_{ij}$ = Alignment score between decoder state $i$ and encoder state $j$
  • $v$ = Learnable weight vector
  • $W_s$ = Decoder state projection
  • $W_h$ = Encoder state projection
  • $s_{i-1}$ = Previous decoder hidden state
  • $h_j$ = Encoder hidden state at position $j$
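The three Bahdanau equations can be sketched in NumPy for a single decoder step. The random matrices stand in for learned parameters, and the attention hidden size `d_a` is an assumed hyperparameter that the text does not specify.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - np.max(x))
    return z / z.sum()

def bahdanau_attention(s_prev, H, W_s, W_h, v):
    """Additive attention for one decoder step.

    s_prev: previous decoder state s_{i-1}, shape (d_s,)
    H:      encoder states h_1..h_T stacked as rows, shape (T, d_h)
    W_s:    (d_a, d_s), W_h: (d_a, d_h), v: (d_a,)
    """
    # e_ij = v^T tanh(W_s s_{i-1} + W_h h_j) for all j at once:
    # W_s @ s_prev is (d_a,) and broadcasts over the (T, d_a) rows of H @ W_h.T.
    e = np.tanh(W_s @ s_prev + H @ W_h.T) @ v   # (T,)
    alpha = softmax(e)                           # alpha_ij, normalized over j
    c = alpha @ H                                # c_i = sum_j alpha_ij h_j
    return c, alpha

rng = np.random.default_rng(0)
T, d_s, d_h, d_a = 6, 8, 10, 16
s_prev = rng.standard_normal(d_s)
H = rng.standard_normal((T, d_h))
W_s = rng.standard_normal((d_a, d_s)) * 0.1
W_h = rng.standard_normal((d_a, d_h)) * 0.1
v = rng.standard_normal(d_a)

c, alpha = bahdanau_attention(s_prev, H, W_s, W_h, v)
print(c.shape, alpha.shape)  # (10,) (6,)
```

Since `H @ W_h.T` does not depend on the decoder step, an implementation can compute it once per input sequence and reuse it at every step.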

Luong Attention

Definition: Luong (Multiplicative) Attention

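Proposed by Luong et al. (2015), this family scores the current decoder state $s_i$ (rather than $s_{i-1}$, as in Bahdanau attention) against each encoder state, using simpler multiplicative functions (dot, general) or a concat variant: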

Dot: $e_{ij} = s_i^T h_j$

General: $e_{ij} = s_i^T W_a h_j$

Concat: $e_{ij} = v^T \tanh(W_a [s_i; h_j])$

Luong General Score

$$e_{ij} = s_i^T W_a h_j$$

Here,

  • $s_i$ = Decoder hidden state at step $i$
  • $W_a$ = Learnable alignment matrix
  • $h_j$ = Encoder hidden state at position $j$
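
| Type | Score function | Cost per score (naive) | Parameters |
|---|---|---|---|
| Dot | $s^T h$ | $O(d)$ | None |
| General | $s^T W h$ | $O(d^2)$ | $W \in \mathbb{R}^{d \times d}$ |
| Concat | $v^T \tanh(W[s;h])$ | $O(d\,d_a)$ | $W, v$ |
| Additive | $v^T \tanh(W_1 s + W_2 h)$ | $O(d\,d_a)$ | $W_1, W_2, v$ |

Here $d$ is the hidden-state size and $d_a$ the size of the tanh layer. Concat and additive scoring are the same function (split $W$ into the blocks acting on $s$ and on $h$). Caching $W h_j$ (general) or $W_2 h_j$ (additive) once per encoder state reduces the per-pair cost to $O(d)$ and $O(d_a)$ respectively.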
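The three Luong scoring functions can be compared on one fixed sequence with a short NumPy sketch. The random weights stand in for learned parameters; `W_c` (the concat matrix, named separately from the general-score `W_a`) and the size `d_a` are hypothetical names and choices for this example.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - np.max(x))
    return z / z.sum()

def score_dot(s, H):
    # e_j = s^T h_j for every row h_j of H (requires equal state sizes)
    return H @ s

def score_general(s, H, W_a):
    # e_j = s^T W_a h_j, computed as h_j . (W_a^T s); W_a: (d, d)
    return H @ (W_a.T @ s)

def score_concat(s, H, W_c, v):
    # e_j = v^T tanh(W_c [s; h_j]); W_c: (d_a, 2d), v: (d_a,)
    SH = np.concatenate([np.tile(s, (H.shape[0], 1)), H], axis=1)  # row j is [s; h_j]
    return np.tanh(SH @ W_c.T) @ v

rng = np.random.default_rng(1)
T, d, d_a = 5, 8, 16
s = rng.standard_normal(d)           # decoder state s_i
H = rng.standard_normal((T, d))      # encoder states h_1..h_T
W_a = rng.standard_normal((d, d))
W_c = rng.standard_normal((d_a, 2 * d))
v = rng.standard_normal(d_a)

for name, e in [("dot", score_dot(s, H)),
                ("general", score_general(s, H, W_a)),
                ("concat", score_concat(s, H, W_c, v))]:
    print(name, softmax(e).round(3))
```

Setting `W_a` to the identity makes the general score identical to the dot score, which is a quick sanity check for an implementation.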

Self-Attention

Definition: Self-Attention

Self-attention computes attention within a single sequence, allowing each position to attend to all other positions:

$$\text{Self-Attention}(X) = \text{softmax}\left(\frac{XW_Q (XW_K)^T}{\sqrt{d_k}}\right) XW_V$$

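Each token can attend directly to every other token, so long-range dependencies are captured with $O(1)$ sequential operations (an RNN needs $O(n)$ for a sequence of length $n$), at the cost of $O(n^2)$ time and memory for the attention matrix.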

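Writing $Q = XW_Q$, $K = XW_K$, and $V = XW_V$ gives the general form, scaled dot-product attention, with the softmax applied row-wise (one row per query):

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$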
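A NumPy sketch of self-attention for a single unbatched sequence. The projection matrices are random stand-ins for learned weights, and the sizes `n`, `d_model`, and `d_k` are arbitrary choices for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    z = np.exp(x - x.max(axis=axis, keepdims=True))
    return z / z.sum(axis=axis, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    """X: (n, d_model); W_Q, W_K: (d_model, d_k); W_V: (d_model, d_v)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)   # (n, n): entry [i, j] scores query i against key j
    A = softmax(scores, axis=-1)      # each row is a distribution over positions
    return A @ V, A                   # (n, d_v), (n, n)

rng = np.random.default_rng(0)
n, d_model, d_k = 4, 16, 8
X = rng.standard_normal((n, d_model))
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))

out, A = self_attention(X, W_Q, W_K, W_V)
print(out.shape, A.shape)  # (4, 8) (4, 4)
```

Every output row is computed from the same matrix products, which is why all positions can be processed in parallel.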

ℹ️ Why Scale by sqrt(d_k)?

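Without scaling, dot products grow in magnitude with the dimension $d_k$: if the components of $q$ and $k$ are independent with mean 0 and variance 1, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has mean 0 and variance $d_k$. Large scores push the softmax into saturated regions where its gradients are extremely small. Dividing by $\sqrt{d_k}$ restores unit variance, which keeps training effective. This was a critical insight in "Attention Is All You Need" (Vaswani et al., 2017).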
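The variance argument can be checked numerically. This sketch assumes, as above, query and key components that are independent standard normals; real learned queries and keys need not satisfy this exactly.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

rng = np.random.default_rng(0)
d_k, n = 512, 5_000
# Assumed: query and key components are independent with mean 0, variance 1.
q = rng.standard_normal((n, d_k))
k = rng.standard_normal((n, d_k))
dots = (q * k).sum(axis=1)               # n independent dot products q . k

print(dots.var())                        # close to d_k = 512
print((dots / np.sqrt(d_k)).var())       # close to 1

# Treat the first 10 dot products as one row of attention scores.
row = dots[:10]
print(softmax(row).max())                  # unscaled: typically close to 1 (nearly one-hot)
print(softmax(row / np.sqrt(d_k)).max())   # scaled: a smoother distribution
```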


Multi-Head Attention

Definition: Multi-Head Attention

Instead of a single attention function, project queries, keys, and values into hh different subspaces, compute attention in parallel, and concatenate:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$

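where each head is

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

with per-head projections $W_i^Q, W_i^K \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$, and output projection $W^O \in \mathbb{R}^{h d_v \times d_{model}}$.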

💡 Why Multiple Heads?

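Different heads can learn to attend to different types of relationships, such as syntactic structure, semantic similarity, or positional patterns. The base Transformer uses $h = 8$ with $d_k = d_{model} / h = 64$; larger models often use 12 or 16 heads. Because each head works in a reduced dimension, the total cost is similar to that of single-head attention with full dimensionality.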
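With $d_k = d_v = d_{model}/h$, the number of projection weights does not depend on $h$: the $h$ heads each get three $d_{model} \times d_k$ matrices, and $W^O$ maps $h d_v = d_{model}$ back to $d_{model}$. A quick arithmetic check, ignoring bias terms and assuming $d_v = d_k$ as in the original Transformer (the function name is hypothetical):

```python
def mha_projection_params(d_model, h):
    """Weight count (biases ignored) for all per-head projections plus W^O."""
    assert d_model % h == 0, "d_model must be divisible by h"
    d_k = d_v = d_model // h
    per_head = 2 * d_model * d_k + d_model * d_v   # W_i^Q, W_i^K, W_i^V
    w_o = (h * d_v) * d_model                      # W^O
    return h * per_head + w_o

for h in (1, 8, 16):
    print(h, mha_projection_params(512, h))        # 1048576 for every h
```

Changing the number of heads therefore redistributes capacity across subspaces rather than adding parameters.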


Theorem: Attention as Differentiable Lookup

Theorem: Attention as Soft Lookup

Hard lookup (dictionary retrieval) returns $v_i$ for the key $k_i$ that matches the query. Attention generalizes this to a differentiable soft lookup that returns a weighted average of all values, with weights determined by query-key similarity. As the temperature $\tau \to 0$, the weights become one-hot on the highest-scoring key (assuming that maximum is unique), so attention converges to hard lookup.

Temperature-Dependent Attention

$$\alpha_i = \frac{\exp(e_i / \tau)}{\sum_j \exp(e_j / \tau)}$$

Here,

  • $\tau$ = Temperature parameter
  • $e_i$ = Attention score for position $i$
  • $\alpha_i$ = Attention weight
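The limit can be seen by sweeping $\tau$ on a toy example. The scores and values below are hypothetical numbers chosen for illustration; the highest score sits at index 1, so the output should approach the value stored there (20.0) as $\tau$ shrinks, and approach the plain average as $\tau$ grows. If two keys tied for the maximum score, the limit would instead average their values.

```python
import numpy as np

def soft_lookup(scores, values, tau):
    """Temperature-scaled softmax over scores, then a weighted average of values."""
    z = scores / tau
    w = np.exp(z - z.max())   # stable: the largest exponent is 0
    w = w / w.sum()
    return w @ values, w

scores = np.array([1.0, 3.0, 2.0])    # hypothetical key-query scores e_i
values = np.array([10.0, 20.0, 30.0]) # hypothetical stored values v_i

for tau in (10.0, 1.0, 0.1, 0.01):
    out, w = soft_lookup(scores, values, tau)
    print(f"tau={tau:<5} weights={w.round(3)} output={out:.3f}")
```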

Full PyTorch Implementation

📝Example: Multi-Head Attention from Scratch

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, dropout=0.1):
        super().__init__()
        self.scale = math.sqrt(d_k)
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        # Q, K, V: (batch, num_heads, seq_len, d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            # mask must broadcast to (batch, num_heads, seq_len_q, seq_len_k);
            # entries where mask == 0 are blocked. A fully masked row yields NaN.
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)

        # Dropout is applied only to the weights used for the output, so the
        # returned attn_weights stay proper distributions (each row sums to 1).
        output = torch.matmul(self.dropout(attn_weights), V)
        return output, attn_weights


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(self.d_k, dropout)

    def forward(self, query, key, value, mask=None):
        # mask (optional): broadcastable to (batch, num_heads, seq_len_q, seq_len_k),
        # e.g. (batch, 1, 1, seq_len_k) for a padding mask
        batch_size = query.size(0)

        # Linear projections: (batch, seq_len, d_model) -> (batch, seq_len, d_model)
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # Reshape: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Apply attention
        output, attn_weights = self.attention(Q, K, V, mask)

        # Reshape: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )

        # Final linear projection
        output = self.W_o(output)
        return output, attn_weights


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Sinusoidal encodings: sin on even dimensions, cos on odd (d_model must be even)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


# Test multi-head attention
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)  # batch=2, seq_len=10, d_model=512
output, weights = mha(x, x, x)  # self-attention
print(f"Output shape: {output.shape}")   # (2, 10, 512)
print(f"Weights shape: {weights.shape}") # (2, 8, 10, 10)
```

Attention Patterns

ℹ️ Understanding Attention Patterns

  • Local attention: Tokens attend mainly to nearby positions — common in early layers
  • Global attention: Some tokens attend to all positions — e.g., [CLS] token in BERT
  • Distributed attention: Weights spread across many positions — captures semantic similarity
  • Block attention: Attention confined to windows — used in efficient Transformers (Longformer, BigBird)
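The local and block patterns above can be expressed as an explicit mask over the score matrix. This is a simplified illustration assuming a symmetric window of half-width `w`; Longformer and BigBird combine such windows with global (and, in BigBird, random) attention and use specialized kernels instead of a dense mask, so this shows only which entries are kept, not the efficiency gains. The function names are hypothetical.

```python
import numpy as np

def sliding_window_mask(n, w):
    """True where query i may attend to key j, i.e. |i - j| <= w."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= w

def masked_softmax(scores, mask):
    # Disallowed positions get -inf, so exp() gives them exactly zero weight.
    s = np.where(mask, scores, -np.inf)
    s = s - s.max(axis=-1, keepdims=True)   # each row keeps its diagonal, so max is finite
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

n, w = 6, 1
mask = sliding_window_mask(n, w)
print(mask.astype(int))

A = masked_softmax(np.random.default_rng(0).standard_normal((n, n)), mask)
print(A.round(2))   # nonzero only on the main diagonal and its immediate neighbours
```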

Practice Exercises

  1. Implement attention variants: Code dot, general, and concat attention. Compare on a fixed sequence.

  2. Visualize attention: Train a seq2seq model and plot attention heatmaps. What patterns emerge?

  3. Multi-head analysis: Train with different numbers of heads (1, 4, 8, 16). How does performance change?

  4. Efficient attention: Implement linear attention (kernel-based) and compare with standard attention on long sequences.


Key Takeaways

📋Summary: Attention Mechanisms

  • Bahdanau attention: Additive scoring, learned alignment
  • Luong attention: Multiplicative scoring, simpler and faster
  • Self-attention: Each position attends to all positions in the same sequence
  • Scaled dot-product: $\text{softmax}(QK^T / \sqrt{d_k})\,V$, the standard attention mechanism
  • Multi-head attention: Parallel attention heads capture different relationship types
  • Attention as soft lookup: Generalizes hard dictionary lookup to differentiable weighted average
  • Self-attention needs only $O(1)$ sequential operations per layer, so it is fully parallelizable across positions (at $O(n^2)$ cost in sequence length $n$)
  • Foundation of Transformers, BERT, GPT, and modern NLP
  • See also: Transformers for the complete architecture
