
Curriculum Learning for LLMs


Curriculum Learning for LLMs — The Order of Data Matters

The order in which data is presented during training can significantly affect model performance. Curriculum learning orders training examples deliberately, typically from easy to hard, with the goal of faster convergence and better final performance.

  • Difficulty-Based Ordering — Present easy examples first, hard examples later
  • Domain Curriculum — Start with general knowledge, add specialized domains
  • Epoch Scheduling — Optimal data repetition strategies across training

How you teach matters as much as what you teach.

Curriculum learning in LLMs draws inspiration from human education, where fundamentals are learned before advanced topics. Studies suggest that presenting training data in a deliberate order (from simple to complex, or from general to specific) can improve both convergence speed and final model quality.

Definition: Curriculum Learning

Curriculum learning is a training strategy where data is presented to the model in a meaningful order rather than randomly. The ordering is based on a difficulty metric, and the model progressively learns from easy to hard examples.
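The definition leaves open how fast to move from easy to hard. One common option is a pacing function: at each step, sample only from the easiest fraction of the difficulty-sorted data, and let that fraction grow to 1 over training. The sketch below is a minimal illustration, assuming a square-root pacing function with a starting fraction `c0 = 0.1`; `competence` and `sample_batch_indices` are hypothetical helper names, not part of any library.

```python
import math
import random

def competence(step, total_steps, c0=0.1):
    """Fraction of the easiest data available at `step` (square-root pacing).

    Grows from c0 at step 0 to 1.0 at total_steps, fastest early on.
    """
    progress = min(step / total_steps, 1.0)
    return min(1.0, math.sqrt(progress * (1 - c0 ** 2) + c0 ** 2))

def sample_batch_indices(difficulties, step, total_steps, batch_size, rng=random):
    """Sample a batch uniformly from the easiest competence(step) fraction."""
    # Indices sorted from easiest (lowest difficulty) to hardest
    order = sorted(range(len(difficulties)), key=lambda i: difficulties[i])
    # Size of the currently unlocked pool; always keep at least one example
    pool_size = max(1, int(competence(step, total_steps) * len(order)))
    pool = order[:pool_size]
    return [rng.choice(pool) for _ in range(batch_size)]

# Usage: 10 examples whose difficulty equals their index
difficulties = [float(i) for i in range(10)]
rng = random.Random(0)
early = sample_batch_indices(difficulties, step=0, total_steps=100, batch_size=8, rng=rng)
late = sample_batch_indices(difficulties, step=100, total_steps=100, batch_size=8, rng=rng)
```

The square-root shape widens the pool quickly at first and slowly later; a linear pacing function is the obvious alternative. Re-sorting on every call is wasteful, so a real pipeline would sort once and reuse the order.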

Why Data Order Matters

Theoretical Foundation

Curriculum Learning Objective

$$\mathcal{L}_{\text{curriculum}} = \sum_{t=1}^{T} w(t) \cdot \mathcal{L}(x_{\sigma(t)}, \theta_t)$$

Here,

  • $w(t)$ = Difficulty weight at training step $t$
  • $\sigma(t)$ = Permutation (ordering) of training examples
  • $\theta_t$ = Model parameters at step $t$
  • $T$ = Total training steps

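The key insight: random sampling presents every example with equal probability throughout training, whereas a curriculum changes the ordering $\sigma(t)$ and the weights $w(t)$ over time, shifting more of the training budget toward harder examples as the model becomes more capable.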

Empirical Evidence

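Training Strategy            | Convergence Speed | Final Performance (vs. baseline) | Training Stability
---------------------------- | ----------------- | -------------------------------- | ------------------
Random (baseline)            | 1.0x              | Baseline                         | Normal
Easy-first curriculum        | 1.3x              | +2 to +5%                        | More stable
Hard-first (anti-curriculum) | 0.8x              | -1 to -3%                        | Less stable
Self-paced curriculum        | 1.2x              | +3 to +7%                        | Most stable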

Difficulty Metrics for LLM Training

Loss-Based Difficulty

Definition: Loss-Based Difficulty

Define difficulty as the loss on a reference model. Examples with higher loss are considered harder because the reference model struggles with them.

Difficulty Score

$$d(x_i) = \mathcal{L}(x_i, \theta_{\text{ref}})$$

Here,

  • $d(x_i)$ = Difficulty of example $x_i$
  • $\theta_{\text{ref}}$ = Reference model parameters
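```python
import torch

def compute_difficulty_scores(dataset, reference_model, tokenizer):
    """Score each example by the reference model's mean token loss on it."""
    scores = []
    reference_model.eval()
    device = next(reference_model.parameters()).device
    with torch.no_grad():
        for example in dataset:
            inputs = tokenizer(example["text"], return_tensors="pt", truncation=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            # For causal LMs, labels=input_ids yields the mean
            # next-token cross-entropy over the sequence
            outputs = reference_model(**inputs, labels=inputs["input_ids"])
            scores.append(outputs.loss.item())
    return scores
```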

Perplexity-Based Difficulty

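```python
import torch

def perplexity_difficulty(example, language_model, tokenizer):
    """Perplexity = exp(mean token loss). Because exp is monotonic,
    ranking by perplexity gives the same order as ranking by loss."""
    language_model.eval()
    inputs = tokenizer(example["text"], return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = language_model(**inputs, labels=inputs["input_ids"])
    return torch.exp(outputs.loss).item()
```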

Quality-Weighted Difficulty

$$d_{\text{qw}}(x_i) = \alpha \cdot d(x_i) + (1 - \alpha) \cdot q(x_i)$$

Here,

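  • $d_{\text{qw}}(x_i)$ = Quality-weighted difficulty score
  • $d(x_i)$ = Raw loss-based difficulty, rescaled to [0, 1] so it is comparable with $q(x_i)$
  • $q(x_i)$ = Inverted quality score (1 minus the document's quality in [0, 1]), so low-quality documents count as harder
  • $\alpha$ = Balance parameter in [0, 1]; $\alpha = 1$ uses difficulty alone, $\alpha = 0$ uses quality alone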
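Raw losses and quality scores usually live on different scales, so they need to be normalized before they are mixed. The sketch below assumes min-max normalization of the loss-based difficulties to [0, 1] and quality scores already in [0, 1] (1 = highest quality); the function names and the example `alpha` are illustrative choices, not a prescribed recipe.

```python
def min_max_normalize(values):
    """Rescale a list of numbers to [0, 1]; a constant list maps to all zeros."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]

def quality_weighted_difficulty(losses, quality, alpha=0.7):
    """Combine loss-based difficulty d(x_i) with inverted quality q(x_i).

    losses:  raw reference-model losses (any scale)
    quality: document quality scores in [0, 1], where 1 is best
    alpha:   weight on difficulty; (1 - alpha) goes to the quality term
    """
    d = min_max_normalize(losses)      # d(x_i) rescaled to [0, 1]
    q = [1.0 - s for s in quality]     # q(x_i) = 1 - quality
    return [alpha * di + (1 - alpha) * qi for di, qi in zip(d, q)]

# Three documents: easy/high-quality, medium/mediocre, hard/high-quality
scores = quality_weighted_difficulty(
    losses=[2.0, 3.0, 4.0], quality=[0.9, 0.5, 0.9], alpha=0.5
)
```

Min-max scaling is sensitive to outlier losses; a rank-based (percentile) normalization is a more robust alternative if a few documents have extreme loss.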

Curriculum Strategies

Linear Curriculum

Definition: Linear Curriculum

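A linear curriculum increases the emphasis on hard examples linearly over time. At step t, each example's weight interpolates linearly between a uniform value of 0.5 (at the start of training) and its difficulty score (at the end), and the probability of sampling example i is proportional to this weight. The code assumes difficulty scores normalized to [0, 1], so 0.5 is the midpoint.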

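```python
def linear_curriculum_weight(difficulty, step, total_steps):
    """Blend from a uniform weight (0.5) to the difficulty score, linearly."""
    temperature = step / total_steps  # 0 at the start, 1 at the end
    weight = difficulty * temperature + (1 - temperature) * 0.5
    return weight
```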

Exponential Curriculum

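```python
import math

def exponential_curriculum_weight(difficulty, step, total_steps):
    """Like the linear schedule, but the blend factor rises quickly early on."""
    # Rises from 0 to 1 - e^-5 (about 0.993) over training
    temperature = 1 - math.exp(-5 * step / total_steps)
    weight = difficulty * temperature + (1 - temperature) * 0.5
    return weight
```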
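Both schedules only produce per-example weights; a sampler still has to turn them into batches. A minimal sketch, assuming difficulties already normalized to [0, 1] and sampling with replacement via Python's `random.choices`; the helper `sample_curriculum_batch` is hypothetical, and the two weight functions are restated so the block runs on its own.

```python
import math
import random

def linear_curriculum_weight(difficulty, step, total_steps):
    progress = step / total_steps  # the "temperature" of the linear schedule
    return difficulty * progress + (1 - progress) * 0.5

def exponential_curriculum_weight(difficulty, step, total_steps):
    progress = 1 - math.exp(-5 * step / total_steps)
    return difficulty * progress + (1 - progress) * 0.5

def sample_curriculum_batch(difficulties, step, total_steps, batch_size,
                            weight_fn=linear_curriculum_weight, rng=random):
    """Sample indices with probability proportional to the schedule weight."""
    weights = [weight_fn(d, step, total_steps) for d in difficulties]
    # random.choices normalizes the weights itself (they need not sum to 1)
    return rng.choices(range(len(difficulties)), weights=weights, k=batch_size)

# Late in training, the hardest example (index 2) is drawn most often
difficulties = [0.1, 0.5, 0.9]
rng = random.Random(0)
batch = sample_curriculum_batch(difficulties, step=900, total_steps=1000,
                                batch_size=16, rng=rng)
```

Passing `weight_fn=exponential_curriculum_weight` switches schedules. Near the end of training an example with difficulty 0 gets almost no weight, and `random.choices` raises an error if every weight is exactly zero.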

Self-Paced Curriculum

Definition: Self-Paced Learning

Self-paced learning lets the model decide which examples to focus on based on its current capability: only examples whose loss falls below a pace threshold λ receive nonzero weight. Training starts with the easy examples, and harder ones are admitted as the model's losses fall and λ is increased.

```python
def self_paced_weights(losses, lambda_param):
    """Self-paced learning (linear weighting): the weight falls linearly from 1
    at zero loss to 0 at loss == lambda_param; harder examples get weight 0."""
    weights = []
    for loss in losses:
        if loss < lambda_param:
            weights.append(1.0 - loss / lambda_param)
        else:
            weights.append(0.0)  # too hard at the current pace
    return weights
```

Self-Paced Weight Function

$$w(x_i, \lambda) = \max\left(0,\ 1 - \frac{\mathcal{L}(x_i)}{\lambda}\right)$$

Here,

  • $w(x_i, \lambda)$ = Self-paced weight for example $x_i$
  • $\lambda$ = Pace parameter (increases over time)
  • $\mathcal{L}(x_i)$ = Loss on example $x_i$
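The text says λ increases over time but does not fix a schedule. A minimal sketch, assuming λ grows by a constant multiplicative factor each epoch (the factor 1.3 and the helper `weighted_mean_loss` are hypothetical choices), shows how the weights admit harder examples as λ grows. `self_paced_weights` is restated in a compact form equivalent to the function above, and the per-example losses are fixed numbers for illustration; in real training they would be recomputed with the current model each epoch.

```python
def self_paced_weights(losses, lambda_param):
    """w_i = max(0, 1 - loss_i / lambda): easy examples weigh more."""
    return [max(0.0, 1.0 - loss / lambda_param) for loss in losses]

def weighted_mean_loss(losses, weights):
    """Objective for this step, normalized by total weight (one possible choice)."""
    total_weight = sum(weights)
    if total_weight == 0:
        return 0.0  # no example is easy enough yet
    return sum(w * l for w, l in zip(weights, losses)) / total_weight

losses = [0.5, 1.0, 2.0, 4.0]   # per-example losses (illustrative)
lam, growth = 1.0, 1.3          # initial pace and growth factor (assumed)
admitted = []
for epoch in range(6):
    weights = self_paced_weights(losses, lam)
    admitted.append(sum(w > 0 for w in weights))  # examples with nonzero weight
    lam *= growth                                 # raise the pace parameter
```

With these numbers the count of admitted examples rises from 1 to 3 as λ grows; the example with loss 4.0 stays excluded until λ exceeds 4.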

Domain Curriculum for LLMs

Progressive Domain Introduction

Definition: Domain Curriculum

Domain curriculum introduces training domains in a deliberate order: it starts with general web text, then adds specialized domains (code, math, books, academic papers). Keeping general web text in the mix throughout is meant to limit catastrophic forgetting of general capabilities while specialized skills are built.

Example Phase Schedule
Phase 1 (Steps 0-30%):   Web text only (general language)
Phase 2 (Steps 30-60%):  Web + Code (add programming)
Phase 3 (Steps 60-80%):  Web + Code + Math (add reasoning)
Phase 4 (Steps 80-100%): Web + Code + Math + Books (add deep knowledge)

Domain Mixing Schedules

```python
def domain_weights(step, total_steps):
    """Compute domain mixing weights for the four-phase schedule above.

    Each new domain is ramped in linearly over its phase, and the weights
    are renormalized to sum to 1.
    """
    progress = step / total_steps

    if progress < 0.3:
        # Phase 1: web text only
        weights = {"web": 1.0, "code": 0.0, "math": 0.0, "books": 0.0}
    elif progress < 0.6:
        # Phase 2: ramp code in while web text shrinks
        t = (progress - 0.3) / 0.3
        weights = {"web": 1.0 - 0.2 * t, "code": 0.2 * t, "math": 0.0, "books": 0.0}
    elif progress < 0.8:
        # Phase 3: ramp math in
        t = (progress - 0.6) / 0.2
        weights = {"web": 0.8, "code": 0.2, "math": 0.15 * t, "books": 0.0}
    else:
        # Phase 4: ramp books in
        t = min((progress - 0.8) / 0.2, 1.0)
        weights = {"web": 0.8, "code": 0.2, "math": 0.15, "books": 0.05 * t}

    total = sum(weights.values())
    return {k: v / total for k, v in weights.items()}
```
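To use mixing weights like these, each slot in a batch can be assigned a source domain by weighted sampling. A minimal sketch, assuming one data iterator per domain (not shown); `draw_domains` is a hypothetical helper, and the weights are the end-of-training mix from the schedule above, left unnormalized because `random.choices` does not require them to sum to 1.

```python
import random
from collections import Counter

def draw_domains(weights, batch_size, rng=random):
    """Assign a source domain to each batch slot, proportional to its weight."""
    domains = list(weights)
    return rng.choices(domains, weights=[weights[d] for d in domains], k=batch_size)

# End-of-training mix from the schedule above (need not sum to 1)
weights = {"web": 0.8, "code": 0.2, "math": 0.15, "books": 0.05}
rng = random.Random(0)
counts = Counter(draw_domains(weights, batch_size=1200, rng=rng))
# In expectation, web fills 0.8 / 1.2 of the slots (about 800 of 1200)
```

Per-slot sampling keeps the mix correct only on average; small domains such as books can be missing from an individual batch, so some pipelines use fixed per-batch quotas instead.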

Epoch Scheduling

Optimal Data Repetition

Definition: Epoch Scheduling

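Epoch scheduling determines how many times each example is seen during training. The "Scaling Data-Constrained Language Models" paper found that, for a fixed compute budget, training on up to about 4 epochs of repeated data gives nearly the same loss as training on unique data, while further repetition yields rapidly diminishing returns. The optimal amount of repetition therefore depends on dataset size.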

Diminishing Returns of Repetition

$$L(E) = L_0 + \frac{C}{E^{\beta}}$$

Here,

  • $L(E)$ = Loss after $E$ epochs
  • $L_0$ = Irreducible loss
  • $C$ = Constant depending on data quality
  • $\beta$ = Decay rate (typically 0.1 to 0.3)
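Plugging numbers into this curve makes the diminishing returns concrete. The constants below ($L_0 = 1.8$, $C = 0.5$, $\beta = 0.2$) are made up for illustration, not fitted values; the point is that each additional epoch buys a smaller loss reduction than the one before.

```python
def loss_after_epochs(epochs, l0=1.8, c=0.5, beta=0.2):
    """L(E) = L0 + C / E**beta, with illustrative (not fitted) constants."""
    return l0 + c / epochs ** beta

losses = {e: loss_after_epochs(e) for e in range(1, 9)}
# Marginal gain from each additional epoch: L(E - 1) - L(E)
gains = {e: losses[e - 1] - losses[e] for e in range(2, 9)}
for e in range(2, 9):
    print(f"epoch {e}: L = {losses[e]:.4f}, gain over epoch {e - 1} = {gains[e]:.4f}")
```

Because $C / E^{\beta}$ is convex and decreasing in $E$, the marginal gain shrinks with every epoch for any positive $C$ and $\beta$, not just these values.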

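The "Chinchilla" scaling analysis assumes training data is not repeated and suggests roughly 20 training tokens per parameter for compute-optimal training, so with a sufficiently large dataset a single epoch is the natural choice. When data is limited, 2 to 4 epochs of repetition can still be beneficial. Beyond about 4 epochs, returns fall off quickly and the model increasingly risks memorizing rather than generalizing.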

Adaptive Epoch Scheduling

```python
import math

def adaptive_epoch_schedule(dataset_tokens, model_params, max_epochs=4):
    """Pick an epoch count from dataset size (in tokens) and model size.

    Heuristic: aim for the Chinchilla rule of thumb of ~20 training tokens
    per parameter, repeating data only as much as needed to get there,
    and never more than max_epochs.
    """
    chinchilla_tokens = model_params * 20  # compute-optimal token budget
    if dataset_tokens >= chinchilla_tokens:
        return 1  # enough unique data: a single epoch
    # Epochs needed to reach the token budget, capped at max_epochs
    epochs_needed = math.ceil(chinchilla_tokens / dataset_tokens)
    return min(epochs_needed, max_epochs)
```

Implementation with PyTorch

```python
import torch
from torch.utils.data import Sampler, DataLoader

class CurriculumSampler(Sampler):
    """Yields dataset indices from easiest to hardest, with optional jitter."""

    def __init__(self, difficulties, noise_std=0.0, seed=0):
        self.difficulties = torch.as_tensor(difficulties, dtype=torch.float)
        self.noise_std = noise_std  # jitter, measured in rank positions
        self.generator = torch.Generator().manual_seed(seed)

    def __iter__(self):
        # Rank examples by difficulty (ascending: easy first)
        sorted_indices = torch.argsort(self.difficulties)
        if self.noise_std > 0:
            # Perturb each rank position with Gaussian noise and re-sort, so
            # examples of similar difficulty can swap places; the noise must
            # be comparable to the rank spacing of 1 to have any effect
            n = len(sorted_indices)
            ranks = torch.arange(n, dtype=torch.float)
            noise = torch.randn(n, generator=self.generator) * self.noise_std
            sorted_indices = sorted_indices[torch.argsort(ranks + noise)]
        return iter(sorted_indices.tolist())

    def __len__(self):
        return len(self.difficulties)

# Usage (assumes dataset, reference_model and tokenizer are already defined)
difficulties = compute_difficulty_scores(dataset, reference_model, tokenizer)
sampler = CurriculumSampler(difficulties, noise_std=50.0)  # illustrative jitter
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
```

Practice Exercises

  1. Curriculum Design: Design a curriculum learning strategy for training an LLM on a mix of web text, code, and mathematical proofs. What order would you use and why?

  2. Difficulty Metric: Implement a difficulty metric that combines loss-based difficulty with document quality scores. How would you balance these two signals?

  3. Epoch Analysis: If you have 100B tokens of data and need to train a 7B parameter model, how many epochs would you use? Justify based on scaling laws.

  4. Self-Paced Implementation: Implement a self-paced curriculum that automatically adjusts the difficulty threshold based on the model's current training loss.

Key Takeaways

Summary: Curriculum Learning for LLMs

  • Data order matters: strategic ordering can improve convergence and final performance
  • Loss-based difficulty uses a reference model to score example hardness
  • Self-paced learning lets the model determine its own curriculum
  • Domain curriculum introduces specialized domains progressively
  • Epoch scheduling: when data is limited, 2 to 4 epochs of repetition can help, with rapidly diminishing returns beyond that
  • Chinchilla-optimal training prefers 1 epoch on sufficiently large datasets
  • Quality-weighted difficulty combines difficulty with data quality scores
  • Random sampling is a strong baseline, but curriculum learning can provide gains over it

What to Learn Next

-> Data Quality and Curation for LLMs The foundations of data quality, deduplication, and filtering.

-> Synthetic Data Generation Using LLMs to create high-quality training data for themselves.

-> Distributed Training for LLMs Scaling training across hundreds of GPUs with parallelism strategies.

-> Knowledge Distillation for LLMs Compressing large models into smaller, faster ones.

-> Scaling Laws and Chinchilla How model performance scales with compute, data, and parameters.

-> Pretraining Language Models The fundamentals of training language models on large corpora.
