Fine-tuning LLMs

Fine-tuning adapts a pre-trained language model to specific tasks or domains by continuing training on task-specific data. This tutorial covers the methods, objectives, and practical considerations.

Definition: Fine-tuning

Fine-tuning is the process of continuing to train a pre-trained language model on a smaller, task-specific dataset. The model reuses the knowledge it acquired during pre-training and adapts it to the target task through gradient-based optimization.

Full Fine-tuning

Full fine-tuning updates all model parameters on the target dataset.
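Because every parameter is trained, the optimizer has to keep state for every parameter. As a rough sketch, assuming mixed-precision training with Adam (2 bytes for half-precision weights, 2 for gradients, and 12 for the FP32 master weights plus two Adam moments, about 16 bytes per parameter in total), the memory for model and optimizer state can be estimated as below. The function name is illustrative, and activations are excluded because they depend on batch size, sequence length, and gradient checkpointing.

```python
def training_state_gb(num_params, bytes_per_param=16):
    """Estimate memory (in GB) for weights, gradients, and Adam state.

    bytes_per_param=16 is an assumption for mixed-precision Adam:
    2 (half-precision weights) + 2 (gradients) + 4 (FP32 master weights)
    + 4 + 4 (Adam first and second moments). Activations are not included.
    """
    return num_params * bytes_per_param / 1e9


# A 7B-parameter model under these assumptions.
full_ft_gb = training_state_gb(7e9)
print(f"{full_ft_gb:.0f} GB before activations")
```

Under these assumptions a 7B model needs on the order of 112 GB before any activations are stored, which is why full fine-tuning usually requires several GPUs or memory-saving techniques.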

Fine-tuning Loss

\mathcal{L}_{\text{ft}}(\theta) = -\sum_{(x,y) \in \mathcal{D}_{\text{ft}}} \sum_{t=1}^{|y|} \log P_\theta(y_t \mid x, y_{<t})

Here,

  • \theta = All model parameters
  • \mathcal{D}_{\text{ft}} = Fine-tuning dataset
  • x = Input (instruction/context)
  • y = Output (response)
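As a minimal sketch of this loss, the following pure-Python functions sum the negative log-probabilities that a model assigns to the correct response tokens. The names `sequence_nll` and `fine_tuning_loss` are illustrative, and the probabilities are made-up numbers rather than real model outputs.

```python
import math


def sequence_nll(token_probs):
    """Negative log-likelihood of one response y given its input x.

    token_probs[t] is assumed to be P_theta(y_t | x, y_<t): the probability
    the model gave to the correct token at position t of the response.
    """
    return -sum(math.log(p) for p in token_probs)


def fine_tuning_loss(dataset_token_probs):
    """Sum the per-example losses over the whole fine-tuning dataset."""
    return sum(sequence_nll(probs) for probs in dataset_token_probs)


# Two toy (x, y) pairs with hypothetical per-token probabilities.
toy_probs = [[0.5, 0.25], [1.0, 0.5]]
loss = fine_tuning_loss(toy_probs)
print(round(loss, 4))
```

In practice, frameworks compute the same quantity from logits with a cross-entropy loss and usually average over tokens rather than summing.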

Learning Rate Schedule

Cosine Learning Rate Schedule

\eta(t) = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{t}{T}\pi\right)\right)

Here,

  • \eta_{\min} = Minimum learning rate
  • \eta_{\max} = Maximum learning rate
  • t = Current step
  • T = Total steps
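The schedule translates directly into code. This is a small sketch of the formula only; the function name and the default values (a peak of 2e-5 and a minimum of 0) are assumptions chosen for illustration.

```python
import math


def cosine_lr(step, total_steps, eta_min=0.0, eta_max=2e-5):
    """Cosine-annealed learning rate at `step` out of `total_steps`."""
    progress = step / total_steps  # t / T, runs from 0 to 1
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * progress))


# The rate starts at eta_max, passes the midpoint halfway, and ends at eta_min.
for step in (0, 500, 1000):
    print(step, cosine_lr(step, total_steps=1000))
```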

For fine-tuning, use a peak learning rate 10-100x smaller than the one used in pre-training (typically 1e-5 to 5e-5). Always include warmup steps (5-10% of total steps) before the decay begins to avoid early instability.
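One way to combine the two recommendations is a linear warmup followed by cosine decay. The sketch below is an illustration under assumed defaults (a peak of 2e-5 and a 5% warmup); `lr_with_warmup` is a hypothetical helper, not a library function.

```python
import math


def lr_with_warmup(step, total_steps, peak_lr=2e-5, warmup_ratio=0.05, min_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay down to min_lr."""
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    if step < warmup_steps:
        # Ramp up linearly; the last warmup step reaches peak_lr.
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, from 0 to 1.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))


schedule = [lr_with_warmup(s, total_steps=1000) for s in range(1001)]
print(max(schedule), schedule[0], schedule[-1])
```

Starting from a small rate gives the optimizer's moment estimates time to stabilize before large updates are applied.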

Overall, full fine-tuning seeks the parameters that minimize the expected negative log-likelihood over the fine-tuning dataset:

\theta^* = \arg\min_\theta \mathbb{E}_{(x,y) \sim \mathcal{D}_{\text{ft}}} [-\log P_\theta(y \mid x)]

Instruction Tuning

Definition: Instruction Tuning

Instruction tuning trains a language model to follow natural language instructions. The training data consists of (instruction, input, output) triples, where the model learns to generate the appropriate response given an instruction and optional input.

Chat Format

Modern instruction-tuned models use a structured chat format with role tokens. The system message sets behavior, user messages provide instructions, and assistant messages contain the model's responses.
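To illustrate the idea, the sketch below flattens a list of role-tagged messages into a single training string. The `<|role|>` and `<|end|>` markers are hypothetical placeholders, since each model family defines its own special tokens, and `render_chat` is an illustrative helper rather than a library function.

```python
def render_chat(messages, add_generation_prompt=False):
    """Flatten a list of {"role", "content"} dicts into one string.

    The <|role|> and <|end|> markers are made up for this example; real
    models define their own role tokens in the tokenizer's chat template.
    """
    parts = []
    for message in messages:
        parts.append(f"<|{message['role']}|>\n{message['content']}<|end|>\n")
    if add_generation_prompt:
        # Leave an open assistant turn for the model to complete.
        parts.append("<|assistant|>\n")
    return "".join(parts)


conversation = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Name one fine-tuning dataset."},
    {"role": "assistant", "content": "Alpaca."},
]
text = render_chat(conversation)
print(text)
```

With Hugging Face tokenizers, `tokenizer.apply_chat_template` plays this role and uses the exact format the model was trained with, so it is generally preferable to a hand-written template.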

Training Datasets

| Dataset | Size | Source | Quality |
|---|---|---|---|
| Alpaca | 52K | Self-instruct (GPT-3.5) | Medium |
| ShareGPT | 90K | User-shared conversations | High |
| OpenAssistant | 161K | Human annotations | High |
| Dolly | 15K | Databricks employees | High |
| FLAN Collection | 1.8M | Aggregated NLP tasks | Medium |

Full Fine-tuning Example

```python
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Llama-2 ships without a pad token; reuse EOS so the collator can pad batches.
tokenizer.pad_token = tokenizer.eos_token

# Load in bfloat16. Combining float16 weights with fp16=True fails in Trainer
# because the gradient scaler cannot unscale float16 gradients.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

dataset = load_dataset("tatsu-lab/alpaca", split="train")


def format_prompt(instruction, input_text, output):
    """Render one Alpaca record as a single training string."""
    if input_text:
        return (
            f"### Instruction:\n{instruction}\n\n"
            f"### Input:\n{input_text}\n\n"
            f"### Response:\n{output}"
        )
    return f"### Instruction:\n{instruction}\n\n### Response:\n{output}"


def tokenize(examples):
    # With batched=True, `examples` is a dict of columns, not a list of rows.
    texts = [
        format_prompt(instruction, input_text, output)
        for instruction, input_text, output in zip(
            examples["instruction"], examples["input"], examples["output"]
        )
    ]
    return tokenizer(texts, truncation=True, max_length=2048)


tokenized = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

training_args = TrainingArguments(
    output_dir="./alpaca-finetuned",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,  # effective batch size of 32 per device
    learning_rate=2e-5,
    warmup_steps=100,
    max_steps=5000,
    bf16=True,  # requires a GPU with bfloat16 support
    logging_steps=50,
    save_steps=500,
    optim="adamw_torch",
    weight_decay=0.1,
    lr_scheduler_type="cosine",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    # mlm=False selects the causal (next-token) objective and pads each batch.
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()
```
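The loss defined earlier sums only over response tokens, whereas `DataCollatorForLanguageModeling` computes the loss on every non-padding token, prompt included. A common variation, sketched here as an assumption rather than something the example above does, is to mask prompt positions with -100, the label value that PyTorch's cross-entropy loss ignores by default. `build_labels` is a hypothetical helper that works on plain lists of token ids.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_labels(prompt_ids, response_ids):
    """Return (input_ids, labels) so that only response tokens are scored."""
    input_ids = list(prompt_ids) + list(response_ids)
    # Prompt positions get IGNORE_INDEX; response positions keep their ids.
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels


# Toy token ids standing in for a tokenized prompt and response.
input_ids, labels = build_labels([11, 12, 13], [21, 22])
print(input_ids)
print(labels)
```

Whether masking the prompt helps depends on the dataset; with short prompts and long responses the difference is usually small.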

For parameter-efficient alternatives to full fine-tuning, see our modules on "LoRA and PEFT" and "QLoRA and Quantization".

When to Fine-tune

Fine-tuning Decision

\text{Fine-tune if: } \frac{\text{Performance}_{\text{ft}} - \text{Performance}_{\text{prompt}}}{\text{Cost}_{\text{ft}}} > \tau

Here,

  • \tau = Performance-cost threshold
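The rule can be written as a one-line check. The numbers below are hypothetical (accuracy on a 0 to 1 scale, cost in GPU-hours), and the threshold has to be chosen in matching units, so treat this as an illustration of the inequality rather than a calibrated decision procedure.

```python
def should_fine_tune(perf_ft, perf_prompt, cost_ft, tau):
    """Return True when the performance gain per unit cost exceeds tau."""
    if cost_ft <= 0:
        raise ValueError("cost_ft must be positive")
    return (perf_ft - perf_prompt) / cost_ft > tau


# Hypothetical case: +10 accuracy points for 20 GPU-hours.
decision = should_fine_tune(perf_ft=0.90, perf_prompt=0.80, cost_ft=20, tau=0.004)
print(decision)
```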

Fine-tune when:

  • You have 100+ high-quality examples
  • The task requires specific formatting or domain knowledge
  • Prompt engineering alone is insufficient
  • You need lower latency (no few-shot examples in prompt)

Prompt instead when:

  • Few labeled examples are available
  • The task is general-purpose
  • You need rapid iteration
  • Compute budget is limited

Practice Exercises

  1. Mathematical: Calculate the total VRAM required to fine-tune a 7B parameter model in FP16 with gradient checkpointing. Assume sequence length 2048 and batch size 4.

  2. Implementation: Fine-tune Llama-2-7B on a small custom dataset using the Alpaca format. Evaluate the before/after performance on 10 held-out examples.

  3. Analysis: Compare the training curves (loss, learning rate) of fine-tuning with learning rates of 1e-5, 2e-5, and 5e-5. Which converges fastest without overfitting?

  4. Research: What are the failure modes of instruction tuning? Investigate cases where fine-tuning degrades the base model's capabilities.

Key Takeaways:

  • Full fine-tuning updates all model parameters on task-specific data
  • Instruction tuning trains models to follow natural language instructions
  • Learning rate should be 10-100x smaller than pre-training with warmup
  • Alpaca, ShareGPT, and OpenAssistant are popular fine-tuning datasets
  • Full fine-tuning is expensive; consider LoRA/QLoRA for efficiency
  • Fine-tune when you have sufficient data and need task-specific behavior

Advanced Fine-tuning Techniques

Data Quality and Curation

The quality of fine-tuning data can be quantified by measuring diversity, accuracy, and relevance. Always prioritize data quality over quantity for instruction tuning.
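As one possible starting point for curation, the sketch below removes exact duplicates and computes a simple lexical diversity score (the fraction of distinct word n-grams). Both helpers are illustrative assumptions; accuracy and relevance still need human review or task-specific checks.

```python
def deduplicate(examples):
    """Drop exact duplicates after lowercasing and whitespace normalization."""
    seen, unique = set(), []
    for text in examples:
        key = " ".join(text.lower().split())
        if key not in seen:
            seen.add(key)
            unique.append(text)
    return unique


def distinct_n(examples, n=2):
    """Fraction of word n-grams that are unique across the dataset (0 to 1)."""
    ngrams = []
    for text in examples:
        words = text.lower().split()
        ngrams.extend(tuple(words[i:i + n]) for i in range(len(words) - n + 1))
    return len(set(ngrams)) / len(ngrams) if ngrams else 0.0


instructions = [
    "Summarize the article.",
    "summarize  the article.",
    "Translate the sentence to French.",
]
cleaned = deduplicate(instructions)
diversity = distinct_n(cleaned)
print(len(cleaned), diversity)
```

A low distinct-n score suggests that many examples share the same phrasing, which is a hint to add more varied instructions.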

Hyperparameter Sensitivity

Fine-tuning is highly sensitive to hyperparameters. The most critical are learning rate, batch size, and number of epochs. Always perform a learning rate sweep.

Recommended hyperparameter ranges:

| Parameter | Recommended Range | Notes |
|---|---|---|
| Learning rate | 1e-6 to 5e-5 | Start with 2e-5 |
| Batch size | 4-32 | Larger is more stable |
| Epochs | 1-5 | Monitor validation loss |
| Warmup ratio | 0.03-0.1 | 5-10% of total steps |
| Weight decay | 0.0-0.2 | Regularization |
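A learning rate sweep over the recommended range can be organized as below. This is a sketch: `fake_validation_loss` is a stand-in for a short real training run that returns validation loss, and the log-spaced grid is one reasonable choice among several.

```python
import math


def log_spaced(low, high, count):
    """Return `count` values evenly spaced on a log scale from low to high."""
    if count == 1:
        return [low]
    step = (math.log10(high) - math.log10(low)) / (count - 1)
    return [10 ** (math.log10(low) + i * step) for i in range(count)]


def sweep(train_and_validate, learning_rates):
    """Run one trial per learning rate and return (best_lr, results)."""
    results = {lr: train_and_validate(lr) for lr in learning_rates}
    best_lr = min(results, key=results.get)
    return best_lr, results


def fake_validation_loss(lr):
    # Placeholder for a real run: a bowl-shaped curve centered on 2e-5.
    return (math.log10(lr) - math.log10(2e-5)) ** 2


candidates = log_spaced(1e-6, 5e-5, 5)
best_lr, results = sweep(fake_validation_loss, candidates)
print(f"best learning rate: {best_lr:.2e}")
```

Log spacing is used because learning rates matter on a multiplicative scale: the step from 1e-6 to 2e-6 is as significant as the step from 1e-5 to 2e-5.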

Common Failure Modes

  1. Catastrophic forgetting: The model loses pre-trained knowledge. Mitigation: lower learning rate, fewer epochs, use LoRA.
  2. Overfitting: Model memorizes training data. Mitigation: more data, dropout, weight decay, early stopping.
  3. Mode collapse: Model produces the same output for all inputs. Mitigation: diverse training data, label smoothing.
  4. Alignment tax: Fine-tuning improves one task but degrades others. Mitigation: multi-task training, elastic weight consolidation.
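Of the mitigations listed, elastic weight consolidation is the least self-explanatory. A minimal sketch of its penalty term follows, using plain lists in place of tensors; the Fisher values are invented for illustration, and in practice they are estimated from gradients on the original task.

```python
def ewc_penalty(params, anchor_params, fisher, strength=1.0):
    """Quadratic penalty that pulls params toward their pre-trained values.

    Computes (strength / 2) * sum_i fisher_i * (param_i - anchor_i)^2, so
    parameters with high Fisher importance are more expensive to move.
    """
    return 0.5 * strength * sum(
        f * (p - a) ** 2 for p, a, f in zip(params, anchor_params, fisher)
    )


# The first parameter is important but unchanged; the second moved by 1.0.
penalty = ewc_penalty(params=[1.0, 2.0], anchor_params=[1.0, 1.0], fisher=[10.0, 0.5])
print(penalty)
```

The penalty is added to the fine-tuning loss, so the optimizer trades off fitting the new task against drifting away from weights that mattered for the old one.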

Always evaluate fine-tuned models on both the target task and general capabilities. A model that excels at the target task but loses general reasoning is not useful in practice.

Evaluation During Fine-tuning

Monitor both training and validation metrics throughout fine-tuning. Key metrics include:

  • Training loss: Should decrease steadily
  • Validation loss: Should decrease then plateau (watch for overfitting)
  • Task-specific metrics: BLEU, ROUGE, accuracy, F1 depending on the task
  • Perplexity: Lower is better for language modeling tasks
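Perplexity and loss are two views of the same quantity: perplexity is the exponential of the mean per-token negative log-likelihood. A short sketch, with made-up token losses:

```python
import math


def perplexity(token_nlls):
    """Perplexity = exp(mean negative log-likelihood per token)."""
    return math.exp(sum(token_nlls) / len(token_nlls))


# If every token has probability 1/4, the perplexity is about 4:
# the model is as uncertain as a uniform choice among four tokens.
ppl = perplexity([math.log(4), math.log(4), math.log(4)])
print(ppl)
```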

Use early stopping when validation loss stops improving to prevent overfitting.
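A patience-based version of this rule can be sketched in a few lines. `EarlyStopper` is a hypothetical helper and the validation losses are invented; the patience value is an assumption to tune per task.

```python
class EarlyStopper:
    """Signal a stop after `patience` evaluations without improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_evals = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            # New best: remember it and reset the counter.
            self.best = val_loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


stopper = EarlyStopper(patience=2)
val_losses = [1.20, 1.05, 1.01, 1.03, 1.02, 1.04]
stopped_at = next(
    (i for i, loss in enumerate(val_losses) if stopper.should_stop(loss)), None
)
print(stopped_at, stopper.best)
```

When training with the Hugging Face `Trainer`, the built-in `EarlyStoppingCallback` offers similar behavior, provided evaluation runs periodically and the best checkpoint is tracked.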
