Multimodal LLMs

Multimodal large language models (MLLMs) extend the capabilities of text-only LLMs to process and reason about multiple modalities, primarily images and text. These models can understand visual content, answer questions about images, and generate text descriptions.

Why Multimodal?

Human cognition naturally integrates information from multiple senses. Multimodal LLMs aim to:

  • Ground language in perception: Connect words to visual concepts
  • Enable visual question answering: Answer questions about images
  • Support document understanding: Process charts, tables, and diagrams
  • Facilitate embodied AI: Enable robots and agents to perceive and act

Architecture Patterns

Early Fusion

Combine visual and textual tokens before the transformer:

An architecture where visual features (from a vision encoder) are projected into the same embedding space as text tokens and concatenated before entering the transformer. The model processes all modalities jointly from the start.

Architecture Diagram
Image → Vision Encoder → [IMG] [IMG] [IMG] [IMG] → ┐
                                                     ├→ Transformer → Text
Text → Tokenizer → [CLS] [the] [cat] [sat]       → ┘
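The following is a minimal PyTorch sketch of early fusion. It assumes precomputed vision-encoder features; in the usage lines a random tensor stands in for them. A small `nn.TransformerEncoder` with a causal mask stands in for a pretrained decoder LLM. The class name `EarlyFusionModel` and all sizes are hypothetical.

```python
import torch
import torch.nn as nn


class EarlyFusionModel(nn.Module):
    """Toy early-fusion LM (hypothetical): visual tokens are prepended to text tokens."""

    def __init__(self, vocab_size: int, vision_dim: int, dim: int,
                 num_layers: int = 2, num_heads: int = 4):
        super().__init__()
        self.text_embed = nn.Embedding(vocab_size, dim)
        # Projects vision-encoder features into the text embedding space
        self.visual_proj = nn.Linear(vision_dim, dim)
        layer = nn.TransformerEncoderLayer(dim, num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers)
        self.lm_head = nn.Linear(dim, vocab_size)

    def forward(self, visual_features: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        # visual_features: (batch, num_patches, vision_dim); input_ids: (batch, text_len)
        img_tokens = self.visual_proj(visual_features)
        txt_tokens = self.text_embed(input_ids)
        # One joint sequence: [IMG] ... [IMG] followed by the text tokens
        joint = torch.cat([img_tokens, txt_tokens], dim=1)
        # Causal mask (True = blocked): each position sees only itself and earlier ones
        seq_len = joint.shape[1]
        causal = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=joint.device), diagonal=1
        )
        hidden = self.transformer(joint, mask=causal)
        return self.lm_head(hidden)  # (batch, num_patches + text_len, vocab_size)


model = EarlyFusionModel(vocab_size=100, vision_dim=32, dim=64)
logits = model(torch.randn(2, 9, 32), torch.randint(0, 100, (2, 5)))
print(logits.shape)  # torch.Size([2, 14, 100])
```

Every layer of the shared transformer sees both modalities, which is the defining property of early fusion. The cost is that the sequence grows by one position per visual token.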

Late Fusion

Process each modality independently, then combine representations:

An architecture where each modality is processed by separate encoders, and the resulting representations are combined (e.g., via cross-attention) at later layers for reasoning.

Architecture Diagram
Image → Vision Encoder → Visual Features → ┐
                                             ├→ Cross-Attention → Transformer → Text
Text → Text Encoder → Text Features       → ┘
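The diagram can be sketched in PyTorch as follows. Each modality gets its own encoder stack, and the two streams meet only once, through a single cross-attention step, before a small joint transformer. `LateFusionModel`, `make_encoder` and every size are hypothetical, and the untrained encoders stand in for real pretrained ones.

```python
import torch
import torch.nn as nn


def make_encoder(dim: int, num_layers: int, num_heads: int) -> nn.TransformerEncoder:
    layer = nn.TransformerEncoderLayer(dim, num_heads, batch_first=True)
    return nn.TransformerEncoder(layer, num_layers)


class LateFusionModel(nn.Module):
    """Toy late-fusion model (hypothetical): separate encoders, fused afterwards."""

    def __init__(self, vocab_size: int, vision_dim: int, dim: int, num_heads: int = 4):
        super().__init__()
        # Modality-specific branches that never see each other's inputs
        self.visual_proj = nn.Linear(vision_dim, dim)
        self.vision_encoder = make_encoder(dim, num_layers=2, num_heads=num_heads)
        self.text_embed = nn.Embedding(vocab_size, dim)
        self.text_encoder = make_encoder(dim, num_layers=2, num_heads=num_heads)
        # Fusion: encoded text tokens query the encoded visual features
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        # Joint reasoning over the fused representation
        self.reasoner = make_encoder(dim, num_layers=1, num_heads=num_heads)

    def forward(self, visual_features: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        v = self.vision_encoder(self.visual_proj(visual_features))  # (batch, num_patches, dim)
        t = self.text_encoder(self.text_embed(input_ids))            # (batch, text_len, dim)
        attended, _ = self.cross_attn(query=t, key=v, value=v)
        fused = self.norm(t + attended)
        return self.reasoner(fused)                                  # (batch, text_len, dim)


model = LateFusionModel(vocab_size=100, vision_dim=32, dim=64)
out = model(torch.randn(2, 9, 32), torch.randint(0, 100, (2, 5)))
print(out.shape)  # torch.Size([2, 5, 64])
```

Unlike early fusion, the output length equals the text length here. The visual features are consulted through attention rather than carried along as extra sequence positions.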

Cross-Attention Fusion

Inject visual information into the language model via cross-attention layers:

A hybrid approach where visual features are injected into specific layers of the language model through cross-attention mechanisms, allowing the language model to attend to visual information at multiple processing stages.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttentionFusion(nn.Module):
    """Post-norm block in which text tokens attend to visual tokens."""
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        self.ffn_norm = nn.LayerNorm(dim)
    
    def forward(self, text_features: torch.Tensor, visual_features: torch.Tensor) -> torch.Tensor:
        # text_features: (batch, text_len, dim); visual_features: (batch, num_patches, dim)
        # Cross-attention: queries come from text, keys/values from visual features
        attended, _ = self.cross_attn(
            query=text_features,
            key=visual_features,
            value=visual_features
        )
        text_features = self.norm(text_features + attended)
        # Position-wise feed-forward with its own residual connection and LayerNorm
        text_features = self.ffn_norm(text_features + self.ffn(text_features))
        return text_features
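A single fusion block applies cross-attention once. To inject visual information at several processing stages, one option is the Flamingo-style design. There, gated cross-attention blocks are inserted between language-model layers, and their outputs are scaled by `tanh` of a learned gate that starts at zero, so the untouched language model is recovered exactly at initialization. The sketch below is a simplification under stated assumptions. The `nn.TransformerEncoderLayer` modules stand in for frozen pretrained decoder layers (no causal mask). The names `GatedCrossAttentionBlock`, `InterleavedLM` and the `every` parameter are hypothetical. Flamingo's image-specific masking and Perceiver Resampler are omitted.

```python
import torch
import torch.nn as nn


class GatedCrossAttentionBlock(nn.Module):
    """Pre-norm cross-attention + FFN, each scaled by tanh of a learned gate."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        # Gates start at 0 -> tanh(0) = 0, so the block is an identity map at init
        self.attn_gate = nn.Parameter(torch.zeros(1))
        self.ffn_gate = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor, visual: torch.Tensor) -> torch.Tensor:
        attended, _ = self.cross_attn(self.attn_norm(x), visual, visual)
        x = x + torch.tanh(self.attn_gate) * attended
        return x + torch.tanh(self.ffn_gate) * self.ffn(self.ffn_norm(x))


class InterleavedLM(nn.Module):
    """Language layers with a gated cross-attention block before every `every`-th one."""

    def __init__(self, dim: int, num_layers: int = 4, every: int = 2, num_heads: int = 8):
        super().__init__()
        # Stand-ins for (frozen) pretrained decoder layers
        self.lm_layers = nn.ModuleList(
            nn.TransformerEncoderLayer(dim, num_heads, batch_first=True)
            for _ in range(num_layers)
        )
        self.xattn = nn.ModuleDict(
            {str(i): GatedCrossAttentionBlock(dim, num_heads) for i in range(0, num_layers, every)}
        )

    def forward(self, x: torch.Tensor, visual: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.lm_layers):
            if str(i) in self.xattn:
                x = self.xattn[str(i)](x, visual)  # look at the image again at this depth
            x = layer(x)
        return x


torch.manual_seed(0)
block = GatedCrossAttentionBlock(64, num_heads=4)
x, visual = torch.randn(2, 5, 64), torch.randn(2, 9, 64)
print(torch.equal(block(x, visual), x))  # True: gates are still closed

lm = InterleavedLM(64, num_layers=4, every=2, num_heads=4)
print(sorted(lm.xattn.keys()), lm(x, visual).shape)  # ['0', '2'] torch.Size([2, 5, 64])
```

The zero-initialized gates let training open the visual pathway gradually, instead of perturbing a pretrained language model with randomly initialized attention from the first step.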

Vision-Language Models

GPT-4V Architecture

GPT-4V (GPT-4 with Vision) processes images alongside text, but OpenAI has not published its architecture. It is commonly assumed to pair a vision encoder (possibly CLIP-based) with a projection into the language model's embedding space, i.e. the early fusion pattern, but this is speculation rather than documented fact.

LLaVA (Large Language and Vision Assistant)

LLaVA is a prominent open-source vision-language model:

LLaVA Visual Projection

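$$H_v = W \cdot E_v(X_v) + b$$

Here,

  • $X_v$ = input image
  • $E_v$ = vision encoder (a CLIP ViT-L/14 in LLaVA), so $E_v(X_v)$ are the visual features
  • $W$ = learned projection matrix from the vision feature dimension to the language model's embedding dimension
  • $b$ = bias vector
  • $H_v$ = visual tokens, living in the same embedding space as the text token embeddings

The original LLaVA uses this single linear projection. LLaVA-1.5 replaces it with a two-layer MLP with a GELU activation, which is what the code below implements.
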
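import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer

IGNORE_INDEX = -100  # label value excluded from the language-modeling loss

class LLaVAModel(nn.Module):
    def __init__(
        self,
        vision_model_name: str = "openai/clip-vit-large-patch14-336",
        language_model_name: str = "meta-llama/Llama-2-7b-hf",
        vision_dim: int = 1024,
        language_dim: int = 4096
    ):
        super().__init__()
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_model_name)
        self.language_model = LlamaForCausalLM.from_pretrained(language_model_name)
        self.tokenizer = LlamaTokenizer.from_pretrained(language_model_name)

        # "<img>" is not in the Llama vocabulary: register it as a special token
        # and grow the embedding matrix, otherwise it would map to <unk>
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["<img>"]})
        self.language_model.resize_token_embeddings(len(self.tokenizer))
        self.img_token_id = self.tokenizer.convert_tokens_to_ids("<img>")

        # Visual projection layer (two-layer MLP, as in LLaVA-1.5)
        self.visual_projection = nn.Sequential(
            nn.Linear(vision_dim, language_dim),
            nn.GELU(),
            nn.Linear(language_dim, language_dim)
        )

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """Encode images (batch, 3, H, W) to visual tokens (batch, num_patches, language_dim)."""
        with torch.no_grad():  # the vision encoder stays frozen
            outputs = self.vision_encoder(images, output_hidden_states=True)
            # LLaVA uses patch features from the penultimate layer and drops the CLS token
            vision_features = outputs.hidden_states[-2][:, 1:]

        # Project to language model dimension (trainable)
        return self.visual_projection(vision_features)

    def forward(
        self,
        images: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None
    ):
        # Assumes one image per sample: each <img> placeholder in sample i
        # is replaced by all visual tokens of images[i]
        text_embeddings = self.language_model.get_input_embeddings()(input_ids)
        visual_tokens = self.encode_image(images).to(text_embeddings.dtype)
        num_visual = visual_tokens.shape[1]

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        # Dummy labels keep the splicing uniform; they are discarded if labels is None
        label_source = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)

        combined_embeddings, combined_masks, combined_labels = [], [], []
        for i in range(input_ids.shape[0]):
            # nonzero(as_tuple=True)[0] is always 1-D, even for a single match
            img_positions = (input_ids[i] == self.img_token_id).nonzero(as_tuple=True)[0].tolist()

            emb_parts, mask_parts, label_parts = [], [], []
            prev_pos = 0
            for pos in img_positions:
                # Text up to the placeholder, then the visual tokens in its place
                emb_parts += [text_embeddings[i, prev_pos:pos], visual_tokens[i]]
                mask_parts += [attention_mask[i, prev_pos:pos], attention_mask.new_ones(num_visual)]
                # Visual positions never contribute to the loss
                label_parts += [label_source[i, prev_pos:pos],
                                label_source.new_full((num_visual,), IGNORE_INDEX)]
                prev_pos = pos + 1

            # Remaining text (the whole sequence if there was no placeholder)
            emb_parts.append(text_embeddings[i, prev_pos:])
            mask_parts.append(attention_mask[i, prev_pos:])
            label_parts.append(label_source[i, prev_pos:])

            combined_embeddings.append(torch.cat(emb_parts, dim=0))
            combined_masks.append(torch.cat(mask_parts, dim=0))
            combined_labels.append(torch.cat(label_parts, dim=0))

        # Right-pad to the longest spliced sequence; mask and labels are rebuilt
        # so their length matches the new embedding sequence
        padded_embeddings = pad_sequence(combined_embeddings, batch_first=True)
        padded_mask = pad_sequence(combined_masks, batch_first=True, padding_value=0)
        padded_labels = pad_sequence(combined_labels, batch_first=True, padding_value=IGNORE_INDEX)

        # Forward through language model
        return self.language_model(
            inputs_embeds=padded_embeddings,
            attention_mask=padded_mask,
            labels=padded_labels if labels is not None else None
        )

    @torch.no_grad()
    def generate(
        self,
        images: torch.Tensor,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7
    ) -> str:
        """Generate text from a single image and a prompt."""
        # Tokenize prompt (the Llama tokenizer prepends BOS)
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(images.device)
        text_embeddings = self.language_model.get_input_embeddings()(input_ids)
        visual_tokens = self.encode_image(images).to(text_embeddings.dtype)

        # Insert all visual tokens right after BOS; truncating them would
        # discard part of the image
        combined = torch.cat([
            text_embeddings[:, :1],  # BOS token
            visual_tokens,           # image
            text_embeddings[:, 1:]   # Rest of prompt
        ], dim=1)
        attention_mask = torch.ones(combined.shape[:2], dtype=torch.long, device=combined.device)

        # Generate
        outputs = self.language_model.generate(
            inputs_embeds=combined,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True
        )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)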
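To make the tensor shapes concrete, here is a quick check of the numbers implied by the `LLaVAModel` defaults: a 336-pixel CLIP ViT-L/14 encoder, 1024-dimensional vision features, and 4096-dimensional Llama-2-7B embeddings. A random tensor stands in for real CLIP patch features. If the encoder output is used unmodified, it also carries one extra CLS token.

```python
import torch
import torch.nn as nn

image_size, patch_size = 336, 14                     # clip-vit-large-patch14-336
num_patch_tokens = (image_size // patch_size) ** 2   # 24 x 24 grid of patches
print(num_patch_tokens)                              # 576

# Same shape as LLaVAModel.visual_projection: 1024 -> 4096 -> 4096
projector = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 4096))
patch_features = torch.randn(1, num_patch_tokens, 1024)  # stand-in for CLIP patch features
visual_tokens = projector(patch_features)
print(visual_tokens.shape)                           # torch.Size([1, 576, 4096])

num_params = sum(p.numel() for p in projector.parameters())
print(num_params)                                    # 20979712 (about 21M)

context_length = 4096                                # Llama-2 context window
print(context_length - num_patch_tokens)             # 3520 tokens left for text
```

The projector is tiny next to the 7B-parameter language model, which is why training it alone is cheap. Each image, however, consumes 576 positions of the context window.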

Visual Projection Layers

The projection layer maps visual features to the language model's embedding space:

Multi-Layer Projection

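$$H_v = W_n \cdot \text{GELU}\big(W_{n-1} \cdot \ldots \cdot \text{GELU}(W_1 \cdot E_v(X_v) + b_1) \ldots + b_{n-1}\big) + b_n$$

Here,

  • $H_v$ = projected visual tokens fed to the language model
  • $E_v(X_v)$ = features produced by the vision encoder $E_v$ for image $X_v$
  • $W_k, b_k$ = weight matrix and bias of the $k$-th linear layer
  • $\text{GELU}$ = activation applied between consecutive linear layers
  • $n$ = number of linear layers; $n = 1$ recovers the linear projection above, and $n = 2$ is the two-layer MLP used in the LLaVAModel code
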
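The formula can be turned into a small builder that produces either the linear projection (`depth=1`) or a deeper MLP. The helper name `build_mlp_projector` is hypothetical. It assumes, as the LLaVA code above does, that every hidden layer has the language model's width.

```python
import torch
import torch.nn as nn


def build_mlp_projector(in_dim: int, out_dim: int, depth: int) -> nn.Sequential:
    """Stack `depth` linear layers (W_k, b_k) with GELU between consecutive ones."""
    if depth < 1:
        raise ValueError("depth must be at least 1")
    layers: list[nn.Module] = [nn.Linear(in_dim, out_dim)]   # W_1, b_1
    for _ in range(depth - 1):
        layers += [nn.GELU(), nn.Linear(out_dim, out_dim)]    # GELU, then W_k, b_k
    return nn.Sequential(*layers)


linear_proj = build_mlp_projector(1024, 4096, depth=1)  # H_v = W E_v(X_v) + b
mlp_proj = build_mlp_projector(1024, 4096, depth=2)     # Linear -> GELU -> Linear
features = torch.randn(1, 576, 1024)                    # stand-in vision features
print(linear_proj(features).shape, mlp_proj(features).shape)
```

Both variants map each patch feature independently, so the number of visual tokens is unchanged. Only the capacity of the mapping grows with `depth`.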

Projection Variants

| Type | Architecture | Pros | Cons |
|------|--------------|------|------|
| Linear | Single linear layer | Simple, fast | Limited capacity |
| MLP | 2-3 layer MLP | Better alignment | More parameters |
| Q-Former | Query-based (BLIP-2) | Flexible, efficient | Complex |
| Perceiver | Cross-attention | Handles variable length | Slower |
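Query-based projectors work differently from Linear/MLP ones. A fixed set of learned query vectors cross-attends to however many visual features the encoder produces, so the language model always receives the same number of visual tokens. The sketch below is a deliberately simplified, single-layer version of that idea. BLIP-2's Q-Former and Flamingo's Perceiver Resampler both stack several layers and also let the queries interact with each other. The name `QueryResampler` and its parameters are hypothetical.

```python
from typing import Optional

import torch
import torch.nn as nn


class QueryResampler(nn.Module):
    """Learned queries cross-attend to visual features -> fixed number of output tokens."""

    def __init__(self, vision_dim: int, dim: int, num_queries: int = 32, num_heads: int = 8):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, dim) * 0.02)
        self.input_proj = nn.Linear(vision_dim, dim)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, visual_features: torch.Tensor,
                key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # visual_features: (batch, num_patches, vision_dim); num_patches can vary
        kv = self.input_proj(visual_features)
        q = self.queries.unsqueeze(0).expand(kv.shape[0], -1, -1)
        # key_padding_mask: (batch, num_patches), True marks padded patches to ignore
        attended, _ = self.cross_attn(q, kv, kv, key_padding_mask=key_padding_mask)
        x = q + attended
        return x + self.ffn(self.norm(x))  # (batch, num_queries, dim)


resampler = QueryResampler(vision_dim=32, dim=64, num_queries=8, num_heads=4)
small = resampler(torch.randn(2, 50, 32))
large = resampler(torch.randn(2, 200, 32))
print(small.shape, large.shape)  # both torch.Size([2, 8, 64])
```

The trade-off is visible in the shapes. The output length is set by `num_queries` rather than by image resolution, which keeps the language model's sequence short, at the cost of an extra attention module that must learn what to keep.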

Contrastive Learning for Vision-Language Models

Many vision encoders are pre-trained using contrastive learning:

InfoNCE Contrastive Loss

$$\mathcal{L}_{\text{InfoNCE}} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\text{sim}(v_i, t_i) / \tau)}{\sum_{j=1}^{N} \exp(\text{sim}(v_i, t_j) / \tau)}$$

Here,

  • $N$ = number of image-text pairs in the batch
  • $v_i$ = embedding of the $i$-th image
  • $t_i$ = embedding of the caption paired with image $i$ (the positive); the captions $t_j$ with $j \neq i$ act as negatives
  • $\text{sim}(\cdot, \cdot)$ = cosine similarity
  • $\tau$ = temperature

As written, this is the image-to-text direction. CLIP averages it with the text-to-image direction, obtained by swapping the roles of $v$ and $t$:

$$\mathcal{L}_{\text{CLIP}} = \frac{1}{2}\left(\mathcal{L}_{\text{image} \to \text{text}} + \mathcal{L}_{\text{text} \to \text{image}}\right)$$

import torch
import torch.nn.functional as F

def clip_contrastive_loss(
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    temperature: float = 0.07
) -> torch.Tensor:
    """Compute the symmetric CLIP contrastive loss for N matched pairs.

    image_features, text_features: (N, D); row i of each tensor forms a positive pair.
    """
    # L2-normalize so dot products are cosine similarities
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    
    # (N, N) similarity matrix: logits[i, j] = sim(v_i, t_j) / tau
    logits = torch.matmul(image_features, text_features.T) / temperature
    
    # Labels: diagonal (each image matches its own caption)
    labels = torch.arange(logits.shape[0], device=logits.device)
    
    # Rows give image->text; the transpose gives text->image
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)
    
    return (loss_i2t + loss_t2i) / 2
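CLIP does not fix $\tau$. It learns the inverse temperature as a log-parameterized scale, initialized to the equivalent of $\tau = 0.07$ and clipped so logits are never scaled by more than 100. Below is a sketch of that variant as a module. The class name `CLIPLoss` is hypothetical. This sketch clamps inside `forward`, whereas some training code (e.g. OpenCLIP) instead clamps the parameter in place after each optimizer step.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CLIPLoss(nn.Module):
    """Symmetric contrastive loss with a learnable temperature (log-scale parameter)."""

    def __init__(self, init_temperature: float = 0.07, max_scale: float = 100.0):
        super().__init__()
        # Stored as log(1/tau) so the scale stays positive during optimization
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1 / init_temperature)))
        self.max_log_scale = math.log(max_scale)

    def forward(self, image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        scale = self.logit_scale.clamp(max=self.max_log_scale).exp()  # 1 / tau, at most 100
        logits = scale * image_features @ text_features.T
        labels = torch.arange(logits.shape[0], device=logits.device)
        return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2


torch.manual_seed(0)
loss_fn = CLIPLoss()
feats = torch.randn(8, 16)
matched = loss_fn(feats, feats)                  # every pair perfectly aligned
mismatched = loss_fn(feats, torch.randn(8, 16))  # unrelated "captions"
print(matched.item() < mismatched.item())        # True
```

Because `logit_scale` is an `nn.Parameter`, it must be passed to the optimizer together with the encoder parameters for the temperature to actually be learned.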

Training Pipeline for Multi-Modal Models

Stage 1: Vision Encoder Pre-training

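Train the vision encoder contrastively on image-text pairs, as CLIP does. In practice, LLaVA-style models skip this stage: they reuse a publicly released CLIP checkpoint and keep it frozen: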

# Using CLIP as vision encoder
from transformers import CLIPModel

clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
# Freeze CLIP and use as feature extractor
for param in clip.parameters():
    param.requires_grad = False

Stage 2: Visual-Language Alignment

Train the projection layer to align visual and text representations:

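def train_alignment_stage(
    model: LLaVAModel,
    dataloader,
    epochs: int = 1,
    lr: float = 1e-3
):
    # Freeze vision encoder and language model
    for param in model.vision_encoder.parameters():
        param.requires_grad = False
    for param in model.language_model.parameters():
        param.requires_grad = False
    
    # Only train projection layer
    for param in model.visual_projection.parameters():
        param.requires_grad = True
    optimizer = torch.optim.AdamW(
        model.visual_projection.parameters(),
        lr=lr
    )
    
    model.train()
    for epoch in range(epochs):
        for batch in dataloader:
            # Each sample: "<img>" placeholder followed by its caption; labels are
            # the caption token ids (other positions set to -100), so the loss
            # trains the projector to make the frozen LLM produce the caption
            images, input_ids, labels = batch
            outputs = model(images, input_ids, labels=labels)
            loss = outputs.loss
            
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()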
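Before launching a stage, it is worth confirming which submodules will actually be updated. The helper `parameter_report` below is hypothetical. It counts trainable and frozen parameters per top-level submodule. The usage runs it on a toy stand-in; on a real `LLaVAModel` after the freezing above, one would expect only `visual_projection` to report trainable parameters.

```python
import torch.nn as nn


def parameter_report(model: nn.Module) -> dict[str, dict[str, int]]:
    """Count trainable vs frozen parameters for each top-level submodule."""
    report = {}
    for name, child in model.named_children():
        trainable = sum(p.numel() for p in child.parameters() if p.requires_grad)
        frozen = sum(p.numel() for p in child.parameters() if not p.requires_grad)
        report[name] = {"trainable": trainable, "frozen": frozen}
    return report


toy = nn.Module()
toy.vision_encoder = nn.Linear(4, 4)       # stand-ins for the real submodules
toy.visual_projection = nn.Linear(4, 8)
toy.vision_encoder.requires_grad_(False)
print(parameter_report(toy))
# {'vision_encoder': {'trainable': 0, 'frozen': 20}, 'visual_projection': {'trainable': 40, 'frozen': 0}}
```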

Stage 3: End-to-End Fine-tuning

Fine-tune the entire model on instruction-following data:

def train_instruction_stage(
    model: LLaVAModel,
    dataloader,
    epochs: int = 3,
    lr: float = 2e-5
):
    # Unfreeze language model, keep vision encoder frozen
    for param in model.vision_encoder.parameters():
        param.requires_grad = False
    for param in model.language_model.parameters():
        param.requires_grad = True
    
    optimizer = torch.optim.AdamW(
        [
            {"params": model.visual_projection.parameters()},
            {"params": model.language_model.parameters(), "lr": lr * 0.1}
        ],
        lr=lr
    )
    
    for epoch in range(epochs):
        for batch in dataloader:
            images, input_ids, labels = batch
            outputs = model(images, input_ids, labels=labels)
            loss = outputs.loss
            
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
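In instruction tuning, LLaVA computes the loss only on the assistant's answer tokens, not on the instruction. A common way to do this is to set the label of every prompt position to -100, the default `ignore_index` of `F.cross_entropy` and of the Hugging Face causal-LM loss. The helper `build_instruction_example` and the token ids below are made up for illustration. Hugging Face causal-LM classes shift labels by one position internally, so labels are passed aligned with `input_ids`; the loss comparison at the end skips that shift and only demonstrates the masking.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # default ignore_index of F.cross_entropy


def build_instruction_example(prompt_ids: list[int], answer_ids: list[int]):
    """Concatenate prompt and answer ids; only the answer tokens are supervised."""
    input_ids = torch.tensor(prompt_ids + answer_ids)
    labels = input_ids.clone()
    labels[: len(prompt_ids)] = IGNORE_INDEX  # no loss on the instruction/prompt
    return input_ids, labels


# Made-up token ids: 3 prompt tokens, 3 answer tokens
input_ids, labels = build_instruction_example([1, 100, 200], [300, 400, 2])
print(labels.tolist())  # [-100, -100, -100, 300, 400, 2]

# Ignored positions drop out of the average: only the 3 answer tokens count
logits = torch.randn(6, 500)
masked_loss = F.cross_entropy(logits, labels)
answer_only = F.cross_entropy(logits[3:], labels[3:])
print(torch.allclose(masked_loss, answer_only))  # True
```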

Applications and Limitations

Key Applications

| Application | Description | Example Models |
|-------------|-------------|----------------|
| Visual QA | Answer questions about images | LLaVA, GPT-4V |
| Document Understanding | Parse charts, tables, diagrams | DocVQA models |
| Image Captioning | Generate text descriptions | BLIP-2, Flamingo |
| Visual Reasoning | Multi-step reasoning about images | GPT-4V |
| Medical Imaging | Analyze X-rays, MRIs | Med-PaLM M |

Limitations

  1. Hallucination: Models may describe objects not present in images
  2. Spatial Reasoning: Struggle with precise spatial relationships
  3. Counting: Difficulty counting objects accurately
  4. Fine-grained Details: May miss small or subtle visual elements
  5. Compositional Reasoning: Struggle with complex compositions

Published benchmark results show that multimodal LLMs still lag behind human performance on tasks requiring precise spatial reasoning, counting, and compositional understanding. These limitations remain active areas of research.
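Object hallucination can be quantified with CHAIR-style metrics (Caption Hallucination Assessment with Image Relevance). The instance-level rate is the fraction of mentioned objects that are absent from the image, and the caption-level rate is the fraction of captions containing at least one such object. The sketch below assumes object extraction has already been done: the real metric maps caption words to MS COCO categories through synonym lists, and it treats each caption's mentions as a set. The function name `object_hallucination_rates` is hypothetical.

```python
def object_hallucination_rates(mentioned: list[set[str]], ground_truth: list[set[str]]):
    """CHAIR-style rates for a batch of generated captions.

    mentioned[k]:    objects the model mentions in caption k
    ground_truth[k]: objects actually present in image k
    Returns (instance-level rate, caption-level rate).
    """
    total_mentions = 0
    hallucinated_mentions = 0
    captions_with_hallucination = 0
    for said, present in zip(mentioned, ground_truth):
        hallucinated = said - present  # mentioned but not in the image
        total_mentions += len(said)
        hallucinated_mentions += len(hallucinated)
        captions_with_hallucination += bool(hallucinated)
    instance_rate = hallucinated_mentions / total_mentions if total_mentions else 0.0
    caption_rate = captions_with_hallucination / len(mentioned) if mentioned else 0.0
    return instance_rate, caption_rate


mentioned = [{"cat", "sofa", "dog"}, {"car", "tree"}]
present = [{"cat", "sofa"}, {"car", "tree", "person"}]
print(object_hallucination_rates(mentioned, present))  # (0.2, 0.5)
```

Omitting a present object (the "person" above) does not count as hallucination, so these rates measure precision of mentions rather than completeness of the description.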

Summary

  • Multimodal LLMs combine vision encoders with language models
  • Early fusion concatenates visual and text tokens before the transformer
  • Cross-attention fusion injects visual information at multiple layers
  • Visual projection layers map visual features to the language model's embedding space
  • Contrastive learning (InfoNCE) pre-trains vision encoders on image-text pairs
  • Training typically involves three stages: vision pre-training, alignment, and instruction tuning
  • Key limitations include hallucination, spatial reasoning, and counting

Practice Exercises

  1. Projection Analysis: Compare linear vs MLP projection layers for aligning CLIP features with a language model. Which achieves better alignment?

  2. Visual QA: Build a simple visual QA system using LLaVA. Test it on 20 images with different question types.

  3. Hallucination Detection: Analyze when the model hallucinates objects. What types of images trigger hallucination?

  4. Fine-tuning: Fine-tune a vision-language model on a specific domain (e.g., medical images). How does domain-specific training affect performance?

  5. Architecture Comparison: Compare early fusion vs cross-attention fusion on the same task. What are the tradeoffs?

