Advanced RAG
Self-RAG — Teaching LLMs When to Retrieve
Not every question requires retrieval. Self-RAG trains models to decide when to retrieve, what to retrieve, and whether the retrieved information is actually useful — all through special reflection tokens.
- Adaptive Retrieval — Retrieve only when the model needs external knowledge
- Reflection Tokens — Special tokens that enable self-evaluation
- Dynamic Retrieval — Adjust retrieval strategy per query
The smartest retriever knows when not to retrieve.
Self-RAG and Adaptive Retrieval
Standard RAG always retrieves, even when the model already knows the answer. This wastes compute and can introduce irrelevant information. Self-RAG (Asai et al., 2023) trains the model to make retrieval decisions dynamically.
Definition: Self-RAG
Self-RAG is a framework that trains LLMs to retrieve documents adaptively, on demand, using special reflection tokens. The model learns to generate tokens that indicate when retrieval is needed, whether retrieved passages are relevant, and whether the generation is supported by evidence.
Reflection Tokens
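Definition: Reflection Tokens
Reflection tokens are special tokens that the model emits alongside its normal output so that it can evaluate itself. They are trained to predict: (1) whether to retrieve [Retrieve], (2) whether a passage is relevant [IsRel], (3) whether the generation is supported [IsSup], (4) whether the answer is useful [IsUse]. In the original paper, [Retrieve] also has a Continue value and [IsUse] is a 1 to 5 rating; this page uses simplified binary versions of both.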
Token Types
| Token | Values | Meaning |
|---|---|---|
| [Retrieve] | Yes / No | Whether to retrieve for this generation step |
| [IsRel] | Relevant / Irrelevant | Whether retrieved passage is relevant to query |
| [IsSup] | Fully Supported / Partially Supported / No Support | Whether generation is supported by passages |
| [IsUse] | Useful / Not Useful | Whether the overall response is useful |
Implementation
import torch.nn as nn


class SelfRAG(nn.Module):
    def __init__(self, base_model, vocab_size=None):
        super().__init__()
        self.base_model = base_model
        hidden_size = base_model.config.hidden_size
        # vocab_size is unused by the heads; kept so existing callers do not break

        # Classification heads that stand in for the reflection tokens
        # (a simplification: the paper adds the tokens to the vocabulary)
        self.retrieve_head = nn.Linear(hidden_size, 2)  # No / Yes
        self.isrel_head = nn.Linear(hidden_size, 2)     # Irrelevant / Relevant
        self.issup_head = nn.Linear(hidden_size, 3)     # No / Partial / Full
        self.isuse_head = nn.Linear(hidden_size, 2)     # Not Useful / Useful

    def forward(self, input_ids, attention_mask=None):
        # base_model is assumed to return an object with `last_hidden_state`
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state

        # Use the final position as the sequence summary
        # (assumes no right padding; otherwise index the last real token)
        last_hidden = hidden_states[:, -1, :]

        # Reflection logits, one set per head
        return {
            "retrieve": self.retrieve_head(last_hidden),
            "isrel": self.isrel_head(last_hidden),
            "issup": self.issup_head(last_hidden),
            "isuse": self.isuse_head(last_hidden),
            "hidden": hidden_states,
        }
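The heads above return raw logits. The following is a minimal sketch of turning those logits into the labels from the token table. It assumes the index order used by the training-label comments on this page (index 0 is the negative class); the order for `isuse` is not stated anywhere here and is a guess. `LABELS` and `decode_reflection` are illustrative names, and plain Python lists stand in for tensors so the sketch runs without PyTorch.

```python
import math

# Label order per head. retrieve/isrel/issup follow the training-label
# comments on this page; the isuse order is an assumption.
LABELS = {
    "retrieve": ["No", "Yes"],
    "isrel": ["Irrelevant", "Relevant"],
    "issup": ["No Support", "Partially Supported", "Fully Supported"],
    "isuse": ["Not Useful", "Useful"],
}


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_reflection(logits_by_head):
    """Map each head's logits to (label, probability) for the most likely class."""
    decoded = {}
    for head, logits in logits_by_head.items():
        probs = softmax(logits)
        best = max(range(len(probs)), key=probs.__getitem__)
        decoded[head] = (LABELS[head][best], probs[best])
    return decoded


# Made-up logits for a single example.
example = decode_reflection({
    "retrieve": [0.2, 1.4],
    "isrel": [-0.5, 0.9],
    "issup": [0.1, 0.3, 2.0],
    "isuse": [1.0, -1.0],
})
for head, (label, prob) in example.items():
    print(f"{head}: {label} ({prob:.2f})")
```

Keeping the label order in one place makes it harder for the training labels and the inference-time decoding to drift apart.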
Training Self-RAG
Retrieval Decision Training
Definition: Retrieval Decision Training
Train the model to predict whether retrieval would improve the response. The training data consists of (query, context) examples, each labeled by whether retrieval helps: some benefit from retrieval and others do not.
import torch.nn.functional as F


def train_retrieval_decision(model, data, optimizer):
    """Train the retrieve, relevance, and support heads for one pass over `data`."""
    model.train()
    total_loss = 0.0
    for batch in data:
        # Forward pass (attention_mask is optional in the batch)
        outputs = model(batch["input_ids"], attention_mask=batch.get("attention_mask"))

        # Retrieval decision loss
        retrieve_loss = F.cross_entropy(
            outputs["retrieve"],
            batch["retrieve_labels"],  # 0=No, 1=Yes
        )
        # Relevance prediction loss
        isrel_loss = F.cross_entropy(
            outputs["isrel"],
            batch["isrel_labels"],  # 0=Irrelevant, 1=Relevant
        )
        # Support prediction loss
        issup_loss = F.cross_entropy(
            outputs["issup"],
            batch["issup_labels"],  # 0=No, 1=Partial, 2=Full
        )
        loss = retrieve_loss + isrel_loss + issup_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a Python float so the autograd graph is not kept alive
        total_loss += loss.item()
    return total_loss / len(data)
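One practical wrinkle, assuming a dataset in which some examples have no retrieved passage: those examples have no meaningful relevance or support label. The sketch below shows one way to skip them, using a sentinel label that is excluded from the average. It is written in plain Python for clarity; `masked_mean_loss` is an illustrative name, and returning 0.0 for a fully masked batch is a choice made here, not a convention taken from the source.

```python
import math

IGNORE_INDEX = -100  # same sentinel value PyTorch's cross_entropy ignores by default


def cross_entropy(logits, label):
    """Cross-entropy for one example: -log softmax(logits)[label]."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_norm - logits[label]


def masked_mean_loss(batch_logits, batch_labels):
    """Average cross-entropy over examples whose label is not IGNORE_INDEX."""
    losses = [
        cross_entropy(logits, label)
        for logits, label in zip(batch_logits, batch_labels)
        if label != IGNORE_INDEX
    ]
    # Returns 0.0 when every label is masked (a simplification made here).
    return sum(losses) / len(losses) if losses else 0.0


# Two examples: the first retrieved a passage, the second did not,
# so its relevance label is masked out.
isrel_logits = [[0.0, 2.0], [1.0, 1.0]]
isrel_labels = [1, IGNORE_INDEX]
loss = masked_mean_loss(isrel_logits, isrel_labels)
print(f"isrel loss: {loss:.4f}")
```

In PyTorch the same masking is available by setting the missing labels to -100, since `F.cross_entropy` has an `ignore_index` argument that defaults to that value.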
Adaptive Retrieval Strategy
Dynamic Retrieval Threshold
Definition: Dynamic Threshold
A dynamic threshold adjusts how readily the system retrieves, based on the model's confidence. When the model is confident in its parametric knowledge, it retrieves less often; when it is uncertain, it retrieves more aggressively.
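import torch
import torch.nn.functional as F


def adaptive_retrieve(query, model, tokenizer, retriever, retrieve_threshold=0.7):
    """Retrieve only when the predicted P(Retrieve=Yes) exceeds the threshold.

    `tokenizer` is an assumed Hugging Face-style tokenizer; the model takes
    token IDs, not raw strings. Documents are assumed to be plain strings.
    """
    # Score the query alone to get the retrieval decision
    enc = tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        outputs = model(enc["input_ids"], attention_mask=enc["attention_mask"])
    retrieve_prob = F.softmax(outputs["retrieve"], dim=-1)[0, 1].item()

    if retrieve_prob <= retrieve_threshold:
        # Model expects to answer from parametric knowledge: no retrieval
        return []

    # Model signals that it needs external knowledge: retrieve
    documents = retriever.retrieve(query, top_k=5)

    # Filter by relevance, scoring each (query, passage) pair
    relevant_docs = []
    for doc in documents:
        pair = tokenizer(query, doc, return_tensors="pt", truncation=True)
        with torch.no_grad():
            rel_outputs = model(pair["input_ids"], attention_mask=pair["attention_mask"])
        isrel_prob = F.softmax(rel_outputs["isrel"], dim=-1)[0, 1].item()
        if isrel_prob > 0.5:
            relevant_docs.append(doc)
    return relevant_docs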
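The function above compares the retrieval probability against a fixed cutoff. The sketch below shows one way the cutoff itself could move with the model's confidence. It assumes uncertainty is measured as the normalized entropy of the model's answer distribution. The linear rule and the names `dynamic_threshold`, `base`, and `swing` are illustrative choices made here, not something specified by Self-RAG.

```python
import math


def normalized_entropy(probs):
    """Entropy of a distribution scaled to [0, 1] (1 = uniform, 0 = one-hot)."""
    if len(probs) < 2:
        return 0.0
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return entropy / math.log(len(probs))


def dynamic_threshold(answer_probs, base=0.7, swing=0.4):
    """Lower the retrieval threshold as the answer distribution gets flatter.

    Confident model (low entropy)  -> threshold near `base`, retrieves rarely.
    Uncertain model (high entropy) -> threshold near `base - swing`, retrieves more.
    """
    uncertainty = normalized_entropy(answer_probs)
    return base - swing * uncertainty


def should_retrieve(retrieve_prob, answer_probs):
    """Retrieve when P(Retrieve=Yes) clears the uncertainty-adjusted threshold."""
    return retrieve_prob > dynamic_threshold(answer_probs)


# Same retrieval probability, two different confidence profiles.
confident = [0.97, 0.01, 0.01, 0.01]
uncertain = [0.25, 0.25, 0.25, 0.25]
print(should_retrieve(0.5, confident))
print(should_retrieve(0.5, uncertain))
```

With the same retrieval probability of 0.5, the confident distribution stays below its threshold and skips retrieval, while the uniform distribution lowers the threshold enough to trigger it.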
Iterative Retrieval
Definition: Iterative Self-RAG
Iterative Self-RAG performs multiple rounds of retrieval and generation. After each generation step, the model evaluates whether additional retrieval would improve the response, and retrieves again if needed.
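import torch
import torch.nn.functional as F


def iterative_self_rag(query, model, tokenizer, retriever, generate_fn, max_iterations=3):
    """Perform iterative retrieval and generation.

    `tokenizer` (Hugging Face-style) and `generate_fn(prompt) -> str` are assumed
    helpers: the SelfRAG module exposes only reflection heads, so the text itself
    must come from a separate decoding routine.
    """
    context = ""
    generated = ""
    for i in range(max_iterations):
        # Draft an answer with the current context
        full_prompt = f"Query: {query}\nContext: {context}\nAnswer:"
        generated = generate_fn(full_prompt)

        # Check if more retrieval is needed, given the draft
        enc = tokenizer(full_prompt + " " + generated, return_tensors="pt")
        with torch.no_grad():
            outputs = model(enc["input_ids"], attention_mask=enc["attention_mask"])
        retrieve_prob = F.softmax(outputs["retrieve"], dim=-1)[0, 1].item()
        if retrieve_prob < 0.3 or i == max_iterations - 1:
            # No more retrieval needed, or iteration budget exhausted
            break

        # Retrieve more information, using the draft to expand the query
        new_docs = retriever.retrieve(query + " " + generated, top_k=3)

        # Filter by relevance before adding to the context
        for doc in new_docs:
            pair = tokenizer(query, doc, return_tensors="pt", truncation=True)
            with torch.no_grad():
                rel_outputs = model(pair["input_ids"], attention_mask=pair["attention_mask"])
            if F.softmax(rel_outputs["isrel"], dim=-1)[0, 1].item() > 0.5:
                context += f"\n{doc}"
    return generated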
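To see the control flow of the loop above without a trained model, here is a small sketch that replaces every model call with a stand-in callable. All names (`iterative_loop`, `generate`, `needs_retrieval`, `retrieve`, `is_relevant`) and the toy corpus are hypothetical. The sketch also skips passages already in the context, a small addition that the loop above does not make.

```python
def iterative_loop(query, generate, needs_retrieval, retrieve, is_relevant,
                   max_iterations=3):
    """Draft, check, retrieve, repeat. Returns (answer, context, rounds used)."""
    context = []
    answer = ""
    for round_index in range(max_iterations):
        answer = generate(query, context)
        last_round = round_index == max_iterations - 1
        if last_round or not needs_retrieval(query, answer, context):
            return answer, context, round_index + 1
        # Use the draft to expand the query, then keep only new relevant passages
        for doc in retrieve(f"{query} {answer}"):
            if doc not in context and is_relevant(query, doc):
                context.append(doc)
    return answer, context, max_iterations


# Toy stand-ins: a two-passage corpus and a rule that asks for
# retrieval until both passages are in the context.
CORPUS = ["Paris is the capital of France.", "France uses the euro."]


def generate(query, context):
    return " ".join(context) if context else "I am not sure."


def needs_retrieval(query, answer, context):
    return len(context) < len(CORPUS)


def retrieve(expanded_query):
    return CORPUS


def is_relevant(query, doc):
    return "France" in doc


answer, context, rounds = iterative_loop(
    "What currency does the country whose capital is Paris use?",
    generate, needs_retrieval, retrieve, is_relevant,
)
print(rounds, answer)
```

Passing the components in as callables keeps the stopping logic testable on its own; in a real system they would wrap the generator, the [Retrieve] and [IsRel] predictions, and the retriever.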
Self-RAG vs Standard RAG
| Feature | Standard RAG | Self-RAG |
|---|---|---|
| Retrieval decision | Always retrieve | Adaptive (on-demand) |
| Relevance filtering | External reranker | Model self-evaluates |
| Support verification | None | Model verifies claims |
| Usefulness evaluation | Human feedback | Model self-evaluates |
| Computational cost | Always pays retrieval cost | Retrieves only when needed |
| Quality control | Post-hoc | During generation |
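By skipping retrieval for queries the model can answer from its own parameters, Self-RAG can reduce retrieval calls by roughly 30-50% while maintaining or improving quality. The exact saving depends on the query mix and the retrieval threshold, and it translates directly to lower latency and cost in production systems.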
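As a back-of-the-envelope check on the cost row of the table, the sketch below computes average per-query latency when only a fraction of queries trigger retrieval. Every number in it is hypothetical and chosen only to illustrate the arithmetic; `decision_ms` is an assumed overhead for making the retrieval decision, and `expected_latency_ms` is an illustrative name.

```python
def expected_latency_ms(retrieve_rate, generation_ms, retrieval_ms, decision_ms=0.0):
    """Average per-query latency when only a fraction of queries retrieve."""
    return generation_ms + decision_ms + retrieve_rate * retrieval_ms


# Hypothetical numbers chosen only to illustrate the arithmetic.
always = expected_latency_ms(retrieve_rate=1.0, generation_ms=800, retrieval_ms=200)
adaptive = expected_latency_ms(retrieve_rate=0.6, generation_ms=800,
                               retrieval_ms=200, decision_ms=20)
print(f"always retrieve: {always:.0f} ms, adaptive: {adaptive:.0f} ms")
```

Under this simple model, adaptive retrieval is faster only when the time saved, `(1 - retrieve_rate) * retrieval_ms`, exceeds the decision overhead `decision_ms`.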
CRAG: Corrective RAG
Definition: Corrective RAG (CRAG)
CRAG (Yan et al., 2024) adds a lightweight retrieval evaluator that assesses the quality of the retrieved documents and triggers one of three actions: (1) if the retrieval is judged correct, use the retrieved documents; (2) if the judgment is uncertain, combine them with web search results; (3) if the retrieval is judged incorrect, discard them and rely on web search only.
def crag_retrieve(query, retriever, evaluator, web_searcher):
    """Corrective RAG with quality-aware retrieval."""
    # Retrieve initial documents
    documents = retriever.retrieve(query, top_k=5)

    # Evaluate retrieval quality; an empty result counts as lowest quality,
    # which also avoids dividing by zero
    quality_scores = evaluator.score_documents(query, documents) if documents else []
    avg_quality = sum(quality_scores) / len(quality_scores) if quality_scores else 0.0

    if avg_quality > 0.8:
        # High quality: use documents directly
        return documents, "direct"
    elif avg_quality > 0.4:
        # Uncertain: combine with web search
        web_docs = web_searcher.search(query, top_k=3)
        return documents + web_docs, "combined"
    else:
        # Low quality: web search only
        web_docs = web_searcher.search(query, top_k=5)
        return web_docs, "web_only"
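The function above averages the scores, which can hide a single strong document among several weak ones. The sketch below shows a per-document alternative with an upper and a lower threshold, which resembles the correct / ambiguous / incorrect split that the CRAG paper describes. The threshold values, the function names, and the rule for which documents to keep are assumptions made for illustration.

```python
def classify_retrieval(scores, upper=0.8, lower=0.4):
    """Return 'correct', 'incorrect', or 'ambiguous' from per-document scores."""
    if not scores:
        return "incorrect"  # nothing retrieved: fall back to web search
    if any(score > upper for score in scores):
        return "correct"
    if all(score < lower for score in scores):
        return "incorrect"
    return "ambiguous"


def select_documents(documents, scores, upper=0.8, lower=0.4):
    """Keep the documents that fit the chosen action."""
    action = classify_retrieval(scores, upper, lower)
    if action == "correct":
        kept = [d for d, s in zip(documents, scores) if s > upper]
    elif action == "ambiguous":
        kept = [d for d, s in zip(documents, scores) if s >= lower]
    else:
        kept = []
    return action, kept


docs = ["doc_a", "doc_b", "doc_c"]
print(select_documents(docs, [0.9, 0.2, 0.1]))  # one strong document
print(select_documents(docs, [0.5, 0.3, 0.1]))  # nothing strong, one middling
print(select_documents(docs, [0.2, 0.3, 0.1]))  # all weak
```

In the first case the average score is far below 0.8, so the averaging rule would not use the documents directly, even though one of them is a strong match.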
Practice Exercises
- Self-RAG Training: Train a Self-RAG model on a dataset with retrieval labels. How does the retrieval decision accuracy compare to a fixed threshold?
- Adaptive Threshold: Implement a dynamic retrieval threshold that adjusts based on query complexity. How does this affect retrieval frequency and answer quality?
- Iterative Retrieval: Compare single-round vs iterative Self-RAG on questions requiring multiple pieces of evidence. How many iterations are typically needed?
- CRAG Implementation: Implement CRAG with a quality evaluator. How does the quality threshold affect the balance between direct use and web search?
Key Takeaways
Summary: Self-RAG and Adaptive Retrieval
- Self-RAG trains models to decide when to retrieve using reflection tokens
- Reflection tokens evaluate relevance, support, and usefulness
- Adaptive retrieval can reduce unnecessary retrieval calls by roughly 30-50%, depending on the query mix
- Iterative Self-RAG performs multiple retrieval rounds for complex questions
- CRAG adds quality evaluation to trigger different retrieval strategies
- Dynamic thresholds adjust retrieval based on model confidence
- Quality filtering removes irrelevant documents before generation
- Self-evaluation enables continuous quality improvement
What to Learn Next
- RAG System Design: Advanced RAG architecture and design patterns.
- Retrieval-Augmented Generation: RAG fundamentals and basic implementation.
- Graph RAG and Knowledge Graphs: Structured knowledge for better reasoning.
- Agentic RAG Systems: Agent-based approaches to retrieval.
- Chain-of-Thought Reasoning: Teaching models to reason step by step.
- Prompt Engineering: Getting the most out of language models.