Multimodal LLMs
Multimodal large language models (MLLMs) extend the capabilities of text-only LLMs to process and reason about multiple modalities, primarily images and text. These models can understand visual content, answer questions about images, and generate text descriptions.
Why Multimodal?
Human cognition naturally integrates information from multiple senses. Multimodal LLMs aim to:
- Ground language in perception: Connect words to visual concepts
- Enable visual question answering: Answer questions about images
- Support document understanding: Process charts, tables, and diagrams
- Facilitate embodied AI: Enable robots and agents to perceive and act
Architecture Patterns
Early Fusion
Combine visual and textual tokens before the transformer:
An architecture where visual features (from a vision encoder) are projected into the same embedding space as text tokens and concatenated before entering the transformer. The model processes all modalities jointly from the start.
Image → Vision Encoder → [IMG] [IMG] [IMG] [IMG] → ┐
                                                   ├→ Transformer → Text
Text  → Tokenizer      → [CLS] [the] [cat] [sat] → ┘
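The early-fusion data flow can be sketched without any framework. The example below is a toy illustration, not a real model: the function names, the 3-dimensional visual features, the 2-dimensional text embeddings, and the weight matrix are all assumptions chosen so the arithmetic is easy to follow.

```python
# Minimal early-fusion sketch in plain Python (no framework).
# All names and sizes are illustrative assumptions.

def project(features, weight):
    """Map each visual feature (length d_v) to the text embedding size d_t.

    `weight` is a d_v x d_t matrix stored as a list of rows.
    """
    d_t = len(weight[0])
    return [
        [sum(f[k] * weight[k][j] for k in range(len(f))) for j in range(d_t)]
        for f in features
    ]


def early_fusion(visual_features, text_embeddings, weight):
    """Project visual features, then prepend them to the text embeddings."""
    visual_tokens = project(visual_features, weight)
    return visual_tokens + text_embeddings  # one joint sequence


# Toy sizes: 4 image patches with d_v = 3, 4 text tokens with d_t = 2.
visual_features = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
text_embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # d_v x d_t

sequence = early_fusion(visual_features, text_embeddings, weight)
print(len(sequence), len(sequence[0]))  # 8 tokens, each of width 2
```

After projection, the transformer sees a single sequence in which visual and text tokens have the same width and can attend to each other in every layer.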
Late Fusion
Process each modality independently, then combine representations:
An architecture where each modality is processed by separate encoders, and the resulting representations are combined (e.g., via cross-attention) at later layers for reasoning.
Image → Vision Encoder → Visual Features → ┐
                                           ├→ Cross-Attention → Transformer → Text
Text  → Text Encoder   → Text Features   → ┘
Cross-Attention Fusion
Inject visual information into the language model via cross-attention layers:
A hybrid approach where visual features are injected into specific layers of the language model through cross-attention mechanisms, allowing the language model to attend to visual information at multiple processing stages.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttentionFusion(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, text_features: torch.Tensor, visual_features: torch.Tensor):
        # Cross-attention: text attends to visual features
        attended, _ = self.cross_attn(
            query=text_features,
            key=visual_features,
            value=visual_features
        )
        # Residual connection followed by layer norm
        text_features = self.norm(text_features + attended)
        # Feed-forward block with its own residual connection
        text_features = text_features + self.ffn(text_features)
        return text_features
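To see what the attention step computes, here is a dependency-free sketch of single-head scaled dot-product cross-attention. It deliberately omits the learned query, key, and value projections and the multiple heads that `nn.MultiheadAttention` applies, and the toy vectors are assumptions picked for readability.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries, keys, values):
    """Single-head scaled dot-product attention without learned projections."""
    d = len(keys[0])
    outputs, all_weights = [], []
    for q in queries:
        # One score per key: how relevant each visual token is to this text token
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        # Output is the attention-weighted average of the value vectors
        out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights


text = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 text tokens (queries)
visual = [[4.0, 0.0], [0.0, 4.0]]            # 2 visual tokens (keys and values)

out, weights = cross_attention(text, visual, visual)
print(len(out))  # 3: one output per text token, regardless of how many visual tokens exist
```

The output has one row per text token, so the text sequence length is unchanged; this is what lets a cross-attention layer be inserted into a language model without altering the shape of its hidden states.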
Vision-Language Models
GPT-4V Architecture
GPT-4V (GPT-4 Vision) accepts images alongside text, but OpenAI has not published its architecture. A plausible design, by analogy with open models such as LLaVA, is a vision encoder whose outputs are projected into the language model's embedding space (the early fusion pattern); treat this as an assumption rather than a documented fact.
LLaVA (Large Language and Vision Assistant)
LLaVA is a prominent open-source vision-language model:
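LLaVA Visual Projection
$$H_v = W \cdot Z_v, \qquad Z_v = g(X_v)$$
Here,
- $X_v$ = input image
- $g$ = pre-trained vision encoder (CLIP ViT)
- $Z_v$ = visual features produced by the encoder
- $W$ = trainable projection matrix
- $H_v$ = visual tokens in the language model's embedding space
The original LLaVA uses this single linear projection; the code that follows uses the two-layer MLP variant introduced in LLaVA-1.5.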
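To make the sequence lengths concrete, the short calculation below counts the tokens a Vision Transformer emits for a square image. It assumes the patch-14, 336-pixel CLIP encoder that the model code in this section loads by default; the helper name is hypothetical.

```python
def vit_token_count(image_size: int, patch_size: int, include_cls: bool = True) -> int:
    """Number of tokens a ViT produces for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    # One token per patch, plus an optional class ([CLS]) token
    return patches_per_side ** 2 + (1 if include_cls else 0)


with_cls = vit_token_count(336, 14)                        # 24 * 24 patches + 1 class token
patches_only = vit_token_count(336, 14, include_cls=False)
print(with_cls, patches_only)  # 577 576
```

Every one of these visual tokens occupies a position in the language model's context window, so image resolution directly affects how much room is left for text.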
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer

IGNORE_INDEX = -100  # label value skipped by the language-model loss


class LLaVAModel(nn.Module):
    """Simplified LLaVA-style model for illustration, not the reference implementation."""

    def __init__(
        self,
        vision_model_name: str = "openai/clip-vit-large-patch14-336",
        language_model_name: str = "meta-llama/Llama-2-7b-hf",
        vision_dim: int = 1024,
        language_dim: int = 4096
    ):
        super().__init__()
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_model_name)
        self.language_model = LlamaForCausalLM.from_pretrained(language_model_name)
        self.tokenizer = LlamaTokenizer.from_pretrained(language_model_name)
        # Visual projection layer (two-layer MLP, as in LLaVA-1.5)
        self.visual_projection = nn.Sequential(
            nn.Linear(vision_dim, language_dim),
            nn.GELU(),
            nn.Linear(language_dim, language_dim)
        )
        # Register the image placeholder; otherwise "<img>" maps to the unknown-token id
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["<img>"]})
        self.language_model.resize_token_embeddings(len(self.tokenizer))
        self.img_token_id = self.tokenizer.convert_tokens_to_ids("<img>")

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """Encode images to visual tokens."""
        # The vision encoder acts as a frozen feature extractor
        with torch.no_grad():
            vision_features = self.vision_encoder(images).last_hidden_state
        # Project to language model dimension (outside no_grad so the projection can train)
        visual_tokens = self.visual_projection(vision_features)
        return visual_tokens

    def forward(
        self,
        images: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None
    ):
        # Encode images (one image per sample is assumed)
        visual_tokens = self.encode_image(images)
        num_visual = visual_tokens.shape[1]
        # Get text embeddings
        text_embeddings = self.language_model.get_input_embeddings()(input_ids)
        batch_size = input_ids.shape[0]
        combined_embeddings, combined_masks, combined_labels = [], [], []
        for i in range(batch_size):
            # 1-D tensor of placeholder positions (stays 1-D even for a single match)
            img_positions = (input_ids[i] == self.img_token_id).nonzero(as_tuple=True)[0]
            text_mask = attention_mask[i] if attention_mask is not None else torch.ones_like(input_ids[i])
            emb_parts, mask_parts, label_parts = [], [], []
            prev_pos = 0
            # Replace each placeholder with the visual tokens, keeping mask and labels aligned
            for pos in img_positions.tolist():
                emb_parts += [text_embeddings[i, prev_pos:pos], visual_tokens[i]]
                mask_parts += [text_mask[prev_pos:pos], text_mask.new_ones((num_visual,))]
                if labels is not None:
                    # Visual positions never contribute to the loss
                    label_parts += [labels[i, prev_pos:pos], labels.new_full((num_visual,), IGNORE_INDEX)]
                prev_pos = pos + 1
            emb_parts.append(text_embeddings[i, prev_pos:])
            mask_parts.append(text_mask[prev_pos:])
            combined_embeddings.append(torch.cat(emb_parts, dim=0))
            combined_masks.append(torch.cat(mask_parts, dim=0))
            if labels is not None:
                label_parts.append(labels[i, prev_pos:])
                combined_labels.append(torch.cat(label_parts, dim=0))
        # Pad to same length
        max_len = max(e.shape[0] for e in combined_embeddings)
        padded_embeddings = text_embeddings.new_zeros((batch_size, max_len, text_embeddings.shape[-1]))
        padded_mask = input_ids.new_zeros((batch_size, max_len))
        padded_labels = None
        if labels is not None:
            padded_labels = labels.new_full((batch_size, max_len), IGNORE_INDEX)
        for i, emb in enumerate(combined_embeddings):
            padded_embeddings[i, :emb.shape[0]] = emb
            padded_mask[i, :emb.shape[0]] = combined_masks[i]
            if labels is not None:
                padded_labels[i, :emb.shape[0]] = combined_labels[i]
        # Forward through language model
        outputs = self.language_model(
            inputs_embeds=padded_embeddings,
            attention_mask=padded_mask,
            labels=padded_labels
        )
        return outputs

    @torch.no_grad()
    def generate(
        self,
        images: torch.Tensor,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7
    ) -> str:
        """Generate text from a single image and prompt."""
        visual_tokens = self.encode_image(images)
        # Tokenize prompt
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(visual_tokens.device)
        # Insert visual tokens directly after the BOS token
        text_embeddings = self.language_model.get_input_embeddings()(input_ids)
        combined = torch.cat([
            text_embeddings[:, :1],  # BOS token
            # Limit visual tokens: keeps only the first 256 encoder outputs,
            # whereas forward() uses all of them; drop the slice to match training
            visual_tokens[:, :256],
            text_embeddings[:, 1:]   # Rest of prompt
        ], dim=1)
        attention_mask = torch.ones(combined.shape[:2], dtype=torch.long, device=combined.device)
        # Generate
        outputs = self.language_model.generate(
            inputs_embeds=combined,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
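Splicing visual tokens into the text sequence changes its length, so the labels and attention mask have to be expanded to match before the loss is computed. The following framework-free sketch shows that bookkeeping on plain lists. It assumes each image placeholder is replaced by a fixed number of visual tokens; the function names and the token ids in the example are hypothetical.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def expand_for_visual_tokens(input_ids, labels, img_token_id, num_visual_tokens):
    """Expand labels and build an attention mask for a spliced sequence.

    Every occurrence of `img_token_id` in `input_ids` is assumed to be
    replaced by `num_visual_tokens` visual embeddings.
    """
    new_labels, attention_mask = [], []
    for token_id, label in zip(input_ids, labels):
        if token_id == img_token_id:
            # Visual positions are attended to but never scored by the loss
            new_labels.extend([IGNORE_INDEX] * num_visual_tokens)
            attention_mask.extend([1] * num_visual_tokens)
        else:
            new_labels.append(label)
            attention_mask.append(1)
    return new_labels, attention_mask


def pad_batch(sequences, pad_value):
    """Right-pad variable-length lists to the longest length in the batch."""
    max_len = max(len(s) for s in sequences)
    return [s + [pad_value] * (max_len - len(s)) for s in sequences]


# Illustrative ids: 1 = BOS, 32000 = image placeholder, the rest = caption tokens.
input_ids = [1, 32000, 306, 1074]
labels = [IGNORE_INDEX, IGNORE_INDEX, 306, 1074]

new_labels, mask = expand_for_visual_tokens(input_ids, labels, img_token_id=32000, num_visual_tokens=3)
print(new_labels)  # [-100, -100, -100, -100, 306, 1074]
print(mask)        # [1, 1, 1, 1, 1, 1]
```

When sequences in a batch end up with different lengths, labels would be padded with `IGNORE_INDEX` and the mask with `0`, for example `pad_batch(all_labels, IGNORE_INDEX)`.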
Visual Projection Layers
The projection layer maps visual features to the language model's embedding space:
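Multi-Layer Projection
$$H_v = W_2 \cdot \sigma(W_1 \cdot Z_v)$$
Here,
- $Z_v$ = visual features from the vision encoder
- $W_1$ = first linear layer (vision dimension to hidden dimension)
- $\sigma$ = nonlinear activation (GELU in the code above)
- $W_2$ = second linear layer (hidden dimension to language model dimension)
- $H_v$ = projected visual tokens
Bias terms are omitted for brevity.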
Projection Variants
| Type | Architecture | Pros | Cons |
|---|---|---|---|
| Linear | Single linear layer | Simple, fast | Limited capacity |
| MLP | 2-3 layer MLP | Better alignment | More parameters |
| Q-Former | Query-based (BLIP-2) | Flexible, efficient | Complex |
| Perceiver | Cross-attention | Handles variable length | Slower |
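The "More parameters" entry for the MLP can be quantified with a little arithmetic. The sketch below counts weights and biases for the linear and two-layer MLP variants, using the `vision_dim=1024` and `language_dim=4096` defaults from the `LLaVAModel` code above; the helper names are hypothetical, and the MLP's hidden width is assumed equal to the language dimension, as in that code.

```python
def linear_params(in_dim: int, out_dim: int, bias: bool = True) -> int:
    """Weights plus optional biases of one fully connected layer."""
    return in_dim * out_dim + (out_dim if bias else 0)


def mlp_params(in_dim: int, hidden_dim: int, out_dim: int) -> int:
    """Two linear layers; the activation between them has no parameters."""
    return linear_params(in_dim, hidden_dim) + linear_params(hidden_dim, out_dim)


VISION_DIM = 1024
LANGUAGE_DIM = 4096

linear = linear_params(VISION_DIM, LANGUAGE_DIM)
mlp = mlp_params(VISION_DIM, LANGUAGE_DIM, LANGUAGE_DIM)
print(f"linear: {linear:,}")  # linear: 4,198,400
print(f"mlp:    {mlp:,}")     # mlp:    20,979,712
```

Even the MLP variant is tiny next to a 7-billion-parameter language model, which is why the projection can be trained on its own in the alignment stage.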
Contrastive Learning for Vision-Language Models
Many vision encoders are pre-trained using contrastive learning:
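InfoNCE Contrastive Loss
$$\mathcal{L}_{i2t} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\mathrm{sim}(v_i, t_i) / \tau)}{\sum_{j=1}^{N} \exp(\mathrm{sim}(v_i, t_j) / \tau)}$$
Here,
- $v_i$ = embedding of the $i$-th image in the batch
- $t_j$ = embedding of the $j$-th caption in the batch
- $\mathrm{sim}$ = cosine similarity
- $\tau$ = temperature
- $N$ = batch size
CLIP averages this image-to-text loss with the symmetric text-to-image loss, in which the roles of images and captions are swapped.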
import torch
import torch.nn.functional as F

def clip_contrastive_loss(
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    temperature: float = 0.07  # CLIP learns this value; 0.07 is its initial setting
) -> torch.Tensor:
    """Compute CLIP contrastive loss."""
    # Normalize features so the dot product equals cosine similarity
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    # Compute similarity matrix: entry (i, j) compares image i with caption j
    logits = torch.matmul(image_features, text_features.T) / temperature
    # Labels: diagonal (each image matches its caption)
    labels = torch.arange(len(logits), device=logits.device)
    # Symmetric loss: image-to-text over rows, text-to-image over columns
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)
    return (loss_i2t + loss_t2i) / 2
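A small worked example shows how the symmetric loss behaves. The sketch below re-implements the same computation in plain Python on 2×2 cosine-similarity matrices; the helper names and the toy matrices are assumptions made for illustration.

```python
import math


def cross_entropy_rows(logits):
    """Mean cross-entropy where the correct class for row i is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def symmetric_info_nce(similarity, temperature=0.07):
    """Average of image-to-text (rows) and text-to-image (columns) losses."""
    logits = [[s / temperature for s in row] for row in similarity]
    transposed = [list(col) for col in zip(*logits)]
    return 0.5 * (cross_entropy_rows(logits) + cross_entropy_rows(transposed))


aligned = [[1.0, 0.0], [0.0, 1.0]]  # each image is most similar to its own caption
uniform = [[0.5, 0.5], [0.5, 0.5]]  # the model cannot tell the captions apart
swapped = [[0.0, 1.0], [1.0, 0.0]]  # each image prefers the wrong caption

for name, sim in [("aligned", aligned), ("uniform", uniform), ("swapped", swapped)]:
    print(name, symmetric_info_nce(sim))
```

The aligned case gives a loss close to zero, the uniform case gives exactly ln 2 (a coin flip between two captions), and the swapped case is heavily penalized because the low temperature sharpens the softmax.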
Training Pipeline for Multi-Modal Models
Stage 1: Vision Encoder Pre-training
Pre-train the vision encoder on image-text pairs or, more commonly, reuse a pre-trained encoder such as CLIP and keep it frozen:
# Using CLIP as vision encoder
from transformers import CLIPModel
clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
# Freeze CLIP and use as feature extractor
for param in clip.parameters():
    param.requires_grad = False
Stage 2: Visual-Language Alignment
Train the projection layer to align visual and text representations:
def train_alignment_stage(
    model: LLaVAModel,
    dataloader,
    epochs: int = 1,
    lr: float = 1e-3
):
    # Freeze vision encoder and language model
    for param in model.vision_encoder.parameters():
        param.requires_grad = False
    for param in model.language_model.parameters():
        param.requires_grad = False
    # Only train projection layer
    optimizer = torch.optim.AdamW(
        model.visual_projection.parameters(),
        lr=lr
    )
    for epoch in range(epochs):
        for batch in dataloader:
            # Each batch is assumed to hold images, tokenized captions, and labels;
            # without labels the language model returns no loss
            images, input_ids, labels = batch
            outputs = model(images, input_ids, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
Stage 3: End-to-End Fine-tuning
Fine-tune the projection layer and the language model on instruction-following data, keeping the vision encoder frozen:
def train_instruction_stage(
    model: LLaVAModel,
    dataloader,
    epochs: int = 3,
    lr: float = 2e-5
):
    # Unfreeze language model, keep vision encoder frozen
    for param in model.vision_encoder.parameters():
        param.requires_grad = False
    for param in model.language_model.parameters():
        param.requires_grad = True
    # Two parameter groups: the projection uses lr, the language model a 10x smaller rate
    optimizer = torch.optim.AdamW(
        [
            {"params": model.visual_projection.parameters()},
            {"params": model.language_model.parameters(), "lr": lr * 0.1}
        ],
        lr=lr
    )
    for epoch in range(epochs):
        for batch in dataloader:
            images, input_ids, labels = batch
            outputs = model(images, input_ids, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
Applications and Limitations
Key Applications
| Application | Description | Example Models |
|---|---|---|
| Visual QA | Answer questions about images | LLaVA, GPT-4V |
| Document Understanding | Parse charts, tables, diagrams | DocVQA models |
| Image Captioning | Generate text descriptions | BLIP-2, Flamingo |
| Visual Reasoning | Multi-step reasoning about images | GPT-4V |
| Medical Imaging | Analyze X-rays, MRIs | Med-PaLM M |
Limitations
- Hallucination: Models may describe objects not present in images
- Spatial Reasoning: Struggle with precise spatial relationships
- Counting: Difficulty counting objects accurately
- Fine-grained Details: May miss small or subtle visual elements
- Compositional Reasoning: Struggle with complex compositions
Research shows that multimodal LLMs still lag behind human performance on tasks requiring precise spatial reasoning, counting, and compositional understanding. These limitations are active areas of research.
Summary
- Multimodal LLMs combine vision encoders with language models
- Early fusion concatenates visual and text tokens before the transformer
- Cross-attention fusion injects visual information at multiple layers
- Visual projection layers map visual features to the language model's embedding space
- Contrastive learning (InfoNCE) pre-trains vision encoders on image-text pairs
- Training typically involves three stages: vision pre-training, alignment, and instruction tuning
- Key limitations include hallucination, spatial reasoning, and counting
Practice Exercises
- Projection Analysis: Compare linear vs MLP projection layers for aligning CLIP features with a language model. Which achieves better alignment?
- Visual QA: Build a simple visual QA system using LLaVA. Test it on 20 images with different question types.
- Hallucination Detection: Analyze when the model hallucinates objects. What types of images trigger hallucination?
- Fine-tuning: Fine-tune a vision-language model on a specific domain (e.g., medical images). How does domain-specific training affect performance?
- Architecture Comparison: Compare early fusion vs cross-attention fusion on the same task. What are the tradeoffs?