Self-RAG and Adaptive Retrieval

Self-RAG — Teaching LLMs When to Retrieve

Not every question requires retrieval. Self-RAG trains models to decide when to retrieve, whether the retrieved passages are actually relevant, and whether the generated text is supported by them, all through special reflection tokens.

  • Adaptive Retrieval — Retrieve only when the model needs external knowledge
  • Reflection Tokens — Special tokens that enable self-evaluation
  • Dynamic Retrieval — Adjust retrieval strategy per query

The smartest retriever knows when not to retrieve.

Standard RAG always retrieves, even when the model already knows the answer. This wastes compute and can introduce irrelevant information. Self-RAG (Asai et al., 2023) trains the model to make retrieval decisions dynamically.

Definition: Self-RAG

Self-RAG is a framework that trains LLMs to adaptively retrieve documents on-demand using special reflection tokens. The model learns to generate tokens that indicate when retrieval is needed, whether retrieved passages are relevant, and whether the generation is supported by evidence.

Reflection Tokens

Definition: Reflection Tokens

Reflection tokens are special tokens that the model generates as part of its output, interleaved with ordinary text, to evaluate its own process. They are trained to predict: (1) whether to retrieve [Retrieve], (2) whether a passage is relevant [IsRel], (3) whether the generation is supported [IsSup], and (4) whether the answer is useful [IsUse].

Token Types

Token | Values | Meaning
[Retrieve] | Yes / No / Continue | Whether to retrieve for the next segment (Continue keeps using the evidence already retrieved)
[IsRel] | Relevant / Irrelevant | Whether a retrieved passage is relevant to the query
[IsSup] | Fully Supported / Partially Supported / No Support | Whether the generated segment is supported by the passage
[IsUse] | 5 / 4 / 3 / 2 / 1 | How useful the overall response is (5 = most useful)
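To connect the token table to classification-style outputs, here is a minimal, framework-free sketch that turns one probability vector per reflection token into a token value. It assumes the simplified label order used by this lesson's training code (index 0 is the negative value, and [Retrieve] and [IsUse] are binary), not the original Self-RAG vocabulary; `REFLECTION_LABELS` and `decode_reflection` are hypothetical names.

```python
# Label order per head, matching this lesson's simplified training labels
# (binary [Retrieve] and [IsUse], unlike the original Self-RAG tokens).
REFLECTION_LABELS = {
    "retrieve": ["No", "Yes"],
    "isrel": ["Irrelevant", "Relevant"],
    "issup": ["No Support", "Partially Supported", "Fully Supported"],
    "isuse": ["Not Useful", "Useful"],
}


def decode_reflection(probs):
    """Map each head's probability vector to its most likely token value."""
    decoded = {}
    for head, p in probs.items():
        labels = REFLECTION_LABELS[head]
        if len(p) != len(labels):
            raise ValueError(f"{head}: expected {len(labels)} probabilities, got {len(p)}")
        # argmax over the classes of this head
        best = max(range(len(p)), key=lambda i: p[i])
        decoded[head] = labels[best]
    return decoded


example = {
    "retrieve": [0.2, 0.8],
    "isrel": [0.6, 0.4],
    "issup": [0.1, 0.3, 0.6],
    "isuse": [0.3, 0.7],
}
print(decode_reflection(example))
# {'retrieve': 'Yes', 'isrel': 'Irrelevant', 'issup': 'Fully Supported', 'isuse': 'Useful'}
```

Keeping the label order in one table makes it harder for the head definitions, the training labels, and the inference code to drift apart.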

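Implementation

The code below is a simplified approximation. Instead of adding reflection tokens to the vocabulary and generating them, as the original Self-RAG does, it attaches one classification head per reflection token to the base model's final hidden state, and it treats [Retrieve] and [IsUse] as binary decisions.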

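import torch
import torch.nn as nn


class SelfRAG(nn.Module):
    """Base model plus one classification head per reflection token."""

    def __init__(self, base_model):
        super().__init__()
        # base_model must return .last_hidden_state
        # (e.g. a Hugging Face AutoModel rather than a *ForCausalLM model)
        self.base_model = base_model
        hidden_size = base_model.config.hidden_size

        # Reflection-token classifiers; class order matches the training labels
        self.retrieve_head = nn.Linear(hidden_size, 2)  # 0=No, 1=Yes
        self.isrel_head = nn.Linear(hidden_size, 2)     # 0=Irrelevant, 1=Relevant
        self.issup_head = nn.Linear(hidden_size, 3)     # 0=No, 1=Partial, 2=Full
        self.isuse_head = nn.Linear(hidden_size, 2)     # 0=Not Useful, 1=Useful

    def forward(self, input_ids, attention_mask=None):
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # (batch, seq_len, hidden)

        # Pool the last real (non-padding) token of each sequence
        batch_size, seq_len = input_ids.shape
        if attention_mask is not None:
            last_idx = attention_mask.sum(dim=1) - 1  # assumes right padding
        else:
            last_idx = torch.full((batch_size,), seq_len - 1, device=input_ids.device)
        batch_idx = torch.arange(batch_size, device=input_ids.device)
        pooled = hidden_states[batch_idx, last_idx]  # (batch, hidden)

        # Reflection logits from the pooled representation
        return {
            "retrieve": self.retrieve_head(pooled),
            "isrel": self.isrel_head(pooled),
            "issup": self.issup_head(pooled),
            "isuse": self.isuse_head(pooled),
            "hidden": hidden_states,
        }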

Training Self-RAG

Retrieval Decision Training

Definition: Retrieval Decision Training

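Train the heads to predict whether retrieval would improve the response, whether a passage is relevant, and whether a generation is supported by its evidence. The training data contains queries that benefit from retrieval and queries that do not, each labeled for every head. In the original paper, these labels are obtained by prompting GPT-4, distilled into a critic model, and inserted into the generator's training corpus offline.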

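import torch.nn.functional as F


def train_retrieval_decision(model, data, optimizer):
    """Train the [Retrieve], [IsRel], and [IsSup] heads for one epoch.

    Assumes every batch carries all three label tensors.
    The [IsUse] head is not trained here.
    """
    model.train()
    total_loss = 0.0

    for batch in data:
        # Forward pass (attention_mask lets the model ignore padding)
        outputs = model(batch["input_ids"], attention_mask=batch.get("attention_mask"))

        # Retrieval decision loss
        retrieve_loss = F.cross_entropy(
            outputs["retrieve"],
            batch["retrieve_labels"]  # 0=No, 1=Yes
        )

        # Relevance prediction loss
        isrel_loss = F.cross_entropy(
            outputs["isrel"],
            batch["isrel_labels"]  # 0=Irrelevant, 1=Relevant
        )

        # Support prediction loss
        issup_loss = F.cross_entropy(
            outputs["issup"],
            batch["issup_labels"]  # 0=No, 1=Partial, 2=Full
        )

        loss = retrieve_loss + isrel_loss + issup_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() stores a plain float so the autograd graph is not kept alive
        total_loss += loss.item()

    # Mean loss per batch (data must support len(), e.g. a DataLoader)
    return total_loss / len(data)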
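The training loop above assumes `retrieve_labels` already exist. In the paper they come from a critic model; if no critic is available, one rough heuristic (an assumption of this sketch, not the Self-RAG procedure) is to label a query Yes only when the answer produced with retrieved context is correct and the closed-book answer is not. The helper names `normalize` and `make_retrieve_label` and the exact-match check are hypothetical.

```python
def normalize(text):
    """Lowercase and collapse whitespace for a crude exact-match comparison."""
    return " ".join(text.lower().split())


def make_retrieve_label(gold, closed_book_answer, rag_answer):
    """Heuristic [Retrieve] label: 1 (Yes) if retrieval fixes a wrong answer, else 0 (No)."""
    closed_ok = normalize(closed_book_answer) == normalize(gold)
    rag_ok = normalize(rag_answer) == normalize(gold)
    # Retrieval counts as "needed" only when the model fails alone but succeeds with context
    return 1 if (rag_ok and not closed_ok) else 0


examples = [
    # (gold answer, closed-book answer, answer with retrieved context)
    ("Paris", "Paris", "Paris"),             # model already knows: No
    ("1969", "1972", "1969"),                # retrieval fixes it: Yes
    ("Ada Lovelace", "Unknown", "Babbage"),  # retrieval does not help: No
]
labels = [make_retrieve_label(*ex) for ex in examples]
print(labels)  # [0, 1, 0]
```

Exact match is strict; a fuzzier correctness check (token-level F1 or a judge model) would change which queries end up labeled Yes.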

Adaptive Retrieval Strategy

Dynamic Retrieval Threshold

Definition: Dynamic Threshold

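Rather than retrieving for every query, the model compares its predicted probability of [Retrieve]=Yes with a threshold. When the model expects its parametric knowledge to suffice, that probability stays low and retrieval is skipped; when it signals missing knowledge, retrieval is triggered. The decision is made per query, which is what makes it dynamic; the threshold itself is a tunable hyperparameter: raising it makes retrieval rarer, and lowering it makes retrieval more aggressive.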

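import torch
import torch.nn.functional as F


def adaptive_retrieve(query, model, tokenizer, retriever,
                      retrieve_threshold=0.7, relevance_threshold=0.5):
    """Retrieve only when the model predicts [Retrieve]=Yes with high probability."""
    model.eval()
    with torch.no_grad():
        enc = tokenizer(query, return_tensors="pt")
        outputs = model(enc["input_ids"], attention_mask=enc["attention_mask"])
        # Probability of class 1 ("Yes, retrieve")
        retrieve_prob = F.softmax(outputs["retrieve"], dim=-1)[0, 1].item()

    if retrieve_prob <= retrieve_threshold:
        # Model expects its parametric knowledge to suffice: skip retrieval
        return []

    # Model signals missing knowledge: fetch candidate passages
    documents = retriever.retrieve(query, top_k=5)

    # Keep only passages the [IsRel] head scores as relevant
    relevant_docs = []
    with torch.no_grad():
        for doc in documents:
            # Encode the (query, passage) pair as one input
            enc = tokenizer(query, doc, return_tensors="pt")
            rel_outputs = model(enc["input_ids"], attention_mask=enc["attention_mask"])
            isrel_prob = F.softmax(rel_outputs["isrel"], dim=-1)[0, 1].item()
            if isrel_prob > relevance_threshold:
                relevant_docs.append(doc)

    return relevant_docs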
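The threshold in `adaptive_retrieve` directly sets how often retrieval fires. This small sketch assumes you have logged the [Retrieve]=Yes probability for a sample of queries (the values below are made up for illustration) and shows the trade-off you would tune; `retrieval_rate` is a hypothetical helper.

```python
def retrieval_rate(retrieve_probs, threshold):
    """Fraction of queries whose P([Retrieve]=Yes) exceeds the threshold."""
    if not retrieve_probs:
        return 0.0
    fired = sum(1 for p in retrieve_probs if p > threshold)
    return fired / len(retrieve_probs)


# Illustrative P(Yes) values for eight logged queries (not real measurements)
logged = [0.05, 0.2, 0.35, 0.5, 0.65, 0.72, 0.88, 0.95]

for t in (0.3, 0.5, 0.7, 0.9):
    print(f"threshold={t:.1f} -> retrieves on {retrieval_rate(logged, t):.0%} of queries")
```

Raising the threshold lowers the retrieval rate; whether answer quality holds up still has to be checked on a labeled evaluation set, which is what Exercise 2 explores.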

Iterative Retrieval

Definition: Iterative Self-RAG

Iterative Self-RAG performs multiple rounds of retrieval and generation. After each generation step, the model evaluates whether additional retrieval would improve the response, and retrieves again if needed.

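import torch
import torch.nn.functional as F


def iterative_self_rag(query, model, tokenizer, retriever, generate_fn,
                       max_iterations=3, stop_threshold=0.3,
                       relevance_threshold=0.5):
    """Alternate generation and retrieval until [Retrieve] signals stop.

    generate_fn is a hypothetical callable (prompt -> text), e.g. a wrapper
    around the base LM's generate(); the reflection heads cannot decode text.
    """
    context_docs = []
    generated = ""

    for i in range(max_iterations):
        # Draft an answer with the evidence gathered so far
        context = "\n".join(context_docs)
        prompt = f"Query: {query}\nContext: {context}\nAnswer:"
        generated = generate_fn(prompt)

        # Ask the [Retrieve] head whether the draft still needs more evidence
        with torch.no_grad():
            enc = tokenizer(prompt + " " + generated, return_tensors="pt")
            outputs = model(enc["input_ids"], attention_mask=enc["attention_mask"])
            retrieve_prob = F.softmax(outputs["retrieve"], dim=-1)[0, 1].item()

        if retrieve_prob < stop_threshold or i == max_iterations - 1:
            # No more retrieval needed (or budget exhausted)
            break

        # Use the current draft to steer the next search
        new_docs = retriever.retrieve(query + " " + generated, top_k=3)

        # Add only new passages the [IsRel] head accepts
        with torch.no_grad():
            for doc in new_docs:
                if doc in context_docs:
                    continue
                enc = tokenizer(query, doc, return_tensors="pt")
                rel = model(enc["input_ids"], attention_mask=enc["attention_mask"])
                if F.softmax(rel["isrel"], dim=-1)[0, 1].item() > relevance_threshold:
                    context_docs.append(doc)

    return generated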
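To make the control flow easy to test in isolation, here is the same iterate-until-confident loop with every model call replaced by a plain callable. The callables (`generate`, `needs_more`, `search`, `is_relevant`) and the stub behavior in the usage part are hypothetical stand-ins, not part of Self-RAG.

```python
from typing import Callable, List


def iterative_loop(
    query: str,
    generate: Callable[[str, List[str]], str],
    needs_more: Callable[[str, str], float],
    search: Callable[[str], List[str]],
    is_relevant: Callable[[str, str], bool],
    max_iterations: int = 3,
    stop_threshold: float = 0.3,
):
    """Alternate drafting and retrieval; return (answer, evidence, rounds used)."""
    evidence: List[str] = []
    draft = ""
    for i in range(max_iterations):
        draft = generate(query, evidence)  # draft with the current evidence
        if needs_more(query, draft) < stop_threshold or i == max_iterations - 1:
            return draft, evidence, i + 1
        # Steer the next search with the draft; keep only new, relevant passages
        for doc in search(query + " " + draft):
            if doc not in evidence and is_relevant(query, doc):
                evidence.append(doc)
    return draft, evidence, max_iterations


# Stubs: the "model" keeps asking for evidence until it has two passages.
def generate(query, evidence):
    return f"draft using {len(evidence)} passages"


def needs_more(query, draft):
    return 0.9 if "using 2" not in draft else 0.1


corpus = ["passage A", "passage B", "off-topic C"]


def search(q):
    return corpus


def is_relevant(query, doc):
    return not doc.startswith("off-topic")


answer, evidence, rounds = iterative_loop(
    "Who proposed Self-RAG?", generate, needs_more, search, is_relevant
)
print(answer, evidence, rounds)  # draft using 2 passages ['passage A', 'passage B'] 2
```

The first round gathers the two relevant passages and drops the off-topic one; the second draft is confident enough to stop, so only two of the three allowed rounds are used.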

Self-RAG vs Standard RAG

Feature | Standard RAG | Self-RAG
Retrieval decision | Always retrieve | Adaptive (on demand)
Relevance filtering | External reranker, if any | Model self-evaluates ([IsRel])
Support verification | None | Model critiques support for its claims ([IsSup])
Usefulness evaluation | None built in | Model self-evaluates ([IsUse])
Computational cost | Always pays retrieval cost | Retrieves only when needed
Quality control | Post hoc | During generation

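Because Self-RAG skips retrieval whenever [Retrieve] predicts No, it issues fewer retrieval calls than standard RAG on queries the model can answer from parametric knowledge. How many calls are saved, and whether quality holds, depends on the query mix and the retrieval threshold, so both should be measured on your own workload. Each skipped call removes retrieval latency and cost in production, although the reflection-token predictions add some overhead of their own.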
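The latency effect is simple expected-value arithmetic. In this sketch every number is a placeholder to replace with your own measurements, and `expected_latency_ms` is a hypothetical helper: if a fraction r of queries trigger retrieval, average latency is generation time plus reflection overhead plus r times retrieval time.

```python
def expected_latency_ms(generation_ms, retrieval_ms, retrieval_rate, reflection_ms=0.0):
    """Average per-query latency when only a fraction of queries retrieve."""
    if not 0.0 <= retrieval_rate <= 1.0:
        raise ValueError("retrieval_rate must be in [0, 1]")
    return generation_ms + reflection_ms + retrieval_rate * retrieval_ms


# Placeholder numbers, not benchmarks
always = expected_latency_ms(generation_ms=800, retrieval_ms=200, retrieval_rate=1.0)
adaptive = expected_latency_ms(generation_ms=800, retrieval_ms=200,
                               retrieval_rate=0.6, reflection_ms=30)
print(always, adaptive)  # 1000.0 950.0
```

The `reflection_ms` term is the point of the exercise: adaptive retrieval only pays off when the retrieval time saved exceeds the added cost of the reflection checks.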

CRAG: Corrective RAG

Definition: Corrective RAG (CRAG)

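CRAG (Yan et al., 2024) adds a lightweight retrieval evaluator that scores each retrieved document and triggers one of three actions: (1) Correct, when at least one document scores above an upper threshold: refine the retrieved documents and use them; (2) Ambiguous, when the scores fall between the two thresholds: combine the retrieved documents with web search results; (3) Incorrect, when every document scores below a lower threshold: discard them and rely on web search only.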

def crag_retrieve(query, retriever, evaluator, web_searcher, upper=0.8, lower=0.4):
    """Corrective RAG: choose an action from per-document evaluator scores."""
    # Retrieve initial documents
    documents = retriever.retrieve(query, top_k=5)

    # One relevance-confidence score per document from the retrieval evaluator
    scores = evaluator.score_documents(query, documents) if documents else []
    best = max(scores, default=0.0)

    if best > upper:
        # Correct: at least one document is confidently relevant
        # (CRAG also refines these documents before use; omitted here)
        return documents, "direct"
    elif best < lower:
        # Incorrect: every document scores low, so fall back to web search only
        web_docs = web_searcher.search(query, top_k=5)
        return web_docs, "web_only"
    else:
        # Ambiguous: combine retrieved documents with web results
        web_docs = web_searcher.search(query, top_k=3)
        return documents + web_docs, "combined"
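In the CRAG paper, the Correct action does not pass documents through unchanged: a decompose-then-recompose step splits them into small strips, scores each strip, and keeps only the relevant ones. Below is a minimal sketch of that idea under stated assumptions: sentence splitting and word-overlap scoring stand in for the paper's fine-tuned evaluator, and `split_strips`, `strip_score`, `refine_documents`, and the 0.3 cut-off are hypothetical.

```python
import re


def split_strips(document):
    """Decompose a document into sentence-level strips."""
    return [s.strip() for s in re.split(r"(?<=[.!?])\s+", document) if s.strip()]


def strip_score(query, strip):
    """Stand-in relevance score: fraction of query words that appear in the strip."""
    q_words = set(re.findall(r"\w+", query.lower()))
    s_words = set(re.findall(r"\w+", strip.lower()))
    return len(q_words & s_words) / len(q_words) if q_words else 0.0


def refine_documents(query, documents, threshold=0.3):
    """Keep only strips scoring above the threshold, then recompose them."""
    kept = []
    for doc in documents:
        for strip in split_strips(doc):
            if strip_score(query, strip) > threshold:
                kept.append(strip)
    return " ".join(kept)


docs = [
    "CRAG was proposed in 2024. The authors also like hiking.",
    "The retrieval evaluator in CRAG scores each document.",
]
print(refine_documents("When was CRAG proposed?", docs))  # CRAG was proposed in 2024.
```

Strip-level filtering removes the off-topic sentence inside an otherwise relevant document, which document-level scoring alone cannot do.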

Practice Exercises

  1. Self-RAG Training: Train a Self-RAG model on a dataset with retrieval labels. How does the retrieval decision accuracy compare to a fixed threshold?

  2. Adaptive Threshold: Implement a dynamic retrieval threshold that adjusts based on query complexity. How does this affect retrieval frequency and answer quality?

  3. Iterative Retrieval: Compare single-round vs iterative Self-RAG on questions requiring multiple pieces of evidence. How many iterations are typically needed?

  4. CRAG Implementation: Implement CRAG with a quality evaluator. How does the quality threshold affect the balance between direct use and web search?

Key Takeaways

Summary: Self-RAG and Adaptive Retrieval

  • Self-RAG trains models to decide when to retrieve using reflection tokens
  • Reflection tokens evaluate relevance, support, and usefulness
  • Adaptive retrieval skips unnecessary retrieval calls; the savings depend on the workload and the threshold
  • Iterative Self-RAG performs multiple retrieval rounds for complex questions
  • CRAG adds quality evaluation to trigger different retrieval strategies
  • Dynamic thresholds adjust retrieval based on model confidence
  • Quality filtering removes irrelevant documents before generation
  • Self-evaluation moves quality control into generation instead of leaving it to post-hoc checks

What to Learn Next

-> RAG System Design: Advanced RAG architecture and design patterns.

-> Retrieval-Augmented Generation: RAG fundamentals and basic implementation.

-> Graph RAG and Knowledge Graphs: Structured knowledge for better reasoning.

-> Agentic RAG Systems: Agent-based approaches to retrieval.

-> Chain-of-Thought Reasoning: Teaching models to reason step by step.

-> Prompt Engineering: Getting the most out of language models.