Autoregressive Models
Autoregressive Language Model
import torch
import torch.nn as nn
import torch.nn.functional as F
class AutoregressiveLM(nn.Module):
    """Transformer language model that attends only to earlier positions.

    Token embeddings are summed with learned positional embeddings, run
    through a stack of ``nn.TransformerEncoderLayer`` blocks under a
    square-subsequent (causal) mask, and projected to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        embed_dim: Width of the embeddings and of every Transformer layer.
        num_heads: Attention heads per layer.
        num_layers: Number of stacked Transformer layers.
        max_seq_len: Longest sequence the learned positional table covers.
    """

    def __init__(self, vocab_size, embed_dim=256, num_heads=8, num_layers=6,
                 max_seq_len=512):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional = nn.Embedding(max_seq_len, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.output = nn.Linear(embed_dim, vocab_size)

    def forward(self, x, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.
            targets: Optional token ids with one entry per position of ``x``.
                Any shifting of targets relative to inputs is left to the
                caller.

        Returns:
            ``logits`` of shape ``(batch, seq_len, vocab_size)`` when
            ``targets`` is ``None``; otherwise ``(logits, loss)`` where
            ``loss`` is the mean cross-entropy over all positions.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_seq_len:
            # Fail early with a readable message instead of an opaque
            # out-of-range lookup inside the positional embedding.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len "
                f"{self.max_seq_len}"
            )

        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        hidden = self.embedding(x) + self.positional(positions)

        # Additive mask with -inf above the diagonal: position i cannot
        # attend to any position j > i.
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
        hidden = self.transformer(hidden, mask=mask)
        logits = self.output(hidden)

        if targets is None:
            return logits

        # reshape (not view) so non-contiguous targets, e.g. a shifted
        # slice of a larger batch tensor, are accepted.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
        )
        return logits, loss
# Instantiate the language model with a 50k-token vocabulary; every other
# hyper-parameter keeps its constructor default.
model = AutoregressiveLM(vocab_size=50_000)
Teacher Forcing Training
class TeacherForcingTrainer:
    """Run training steps that mix teacher forcing with free-running inputs.

    On each step, with probability ``teacher_forcing_ratio`` the model is
    conditioned on the ground-truth ``input_ids`` (teacher forcing).
    Otherwise it is conditioned on its own greedy predictions from a first,
    gradient-free pass (two-pass scheduled sampling).

    Args:
        model: Callable module. ``model(ids)`` must return logits of shape
            ``(batch, seq_len, vocab)`` and ``model(ids, targets)`` must
            return ``(logits, loss)``.
        optimizer: Optimizer over the model's parameters.
        teacher_forcing_ratio: Probability in ``[0, 1]`` of using the
            ground-truth inputs on a given step.
    """

    def __init__(self, model, optimizer, teacher_forcing_ratio=0.5):
        self.model = model
        self.optimizer = optimizer
        self.teacher_forcing_ratio = teacher_forcing_ratio

    def train_step(self, input_ids, targets):
        """Perform one optimisation step and return the scalar loss.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``.
            targets: Token ids with one entry per position of ``input_ids``.

        Returns:
            The loss of this step as a Python float.
        """
        self.model.train()
        use_teacher_forcing = torch.rand(1).item() < self.teacher_forcing_ratio

        if use_teacher_forcing:
            _, loss = self.model(input_ids, targets)
        else:
            # Condition on the model's own guesses rather than the ground
            # truth; the loss is still measured against the true targets.
            logits = self.model(self._free_running_inputs(input_ids))
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _free_running_inputs(self, input_ids):
        """Replace every input after the first with the model's prediction.

        Assumes the usual next-token setup in which the prediction made at
        position ``t`` is the model's guess for the input at ``t + 1``.
        """
        with torch.no_grad():
            predicted = self.model(input_ids).argmax(dim=-1)
        # Keep the true first token (nothing precedes it to predict it from)
        # and shift the predictions one step to the right.
        return torch.cat([input_ids[:, :1], predicted[:, :-1]], dim=1)
# Example setup: an optimizer plus a random toy batch so the snippet runs
# end to end. Targets are the inputs shifted one position to the left.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
_batch = torch.randint(0, model.embedding.num_embeddings, (4, 33))
input_ids, targets = _batch[:, :-1], _batch[:, 1:].contiguous()

trainer = TeacherForcingTrainer(model, optimizer, teacher_forcing_ratio=0.5)
loss = trainer.train_step(input_ids, targets)
Sampling Strategies
class SamplingStrategies:
    """Token-by-token decoding strategies for an autoregressive model.

    Every method expects ``model`` to map a ``(1, seq_len)`` tensor of token
    ids to logits of shape ``(1, seq_len, vocab_size)``. Each returns the
    generated ids as a plain list that begins with ``start_token`` and
    contains ``max_len`` further tokens.
    """

    @staticmethod
    def _last_step_logits(model, tokens):
        """Return the 1-D logits for the token that follows ``tokens``."""
        # Build the input on the model's own device so decoding also works
        # for models that were moved off the CPU.
        parameters = getattr(model, "parameters", None)
        first_param = next(iter(parameters()), None) if callable(parameters) else None
        device = first_param.device if first_param is not None else None
        x = torch.tensor([tokens], device=device)
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            return model(x)[0, -1]

    @staticmethod
    def greedy_decode(model, start_token, max_len=50):
        """Always pick the highest-scoring next token."""
        model.eval()
        tokens = [start_token]
        for _ in range(max_len):
            step_logits = SamplingStrategies._last_step_logits(model, tokens)
            tokens.append(step_logits.argmax().item())
        return tokens

    @staticmethod
    def top_k_decode(model, start_token, k=10, max_len=50):
        """Sample the next token from the ``k`` highest-scoring candidates.

        ``k`` is clamped to the vocabulary size.
        """
        model.eval()
        tokens = [start_token]
        for _ in range(max_len):
            step_logits = SamplingStrategies._last_step_logits(model, tokens)
            top_k_logits, top_k_indices = torch.topk(
                step_logits, min(k, step_logits.size(-1))
            )
            probs = F.softmax(top_k_logits, dim=-1)
            idx = torch.multinomial(probs, 1)
            tokens.append(top_k_indices[idx].item())
        return tokens

    @staticmethod
    def nucleus_decode(model, start_token, p=0.9, max_len=50):
        """Sample from the smallest set of tokens whose mass exceeds ``p``.

        Tokens are ranked by probability. A token is kept while the
        cumulative probability of the tokens ranked above it does not exceed
        ``p``, so the token that crosses the threshold, and always the top
        token, stay eligible.
        """
        model.eval()
        tokens = [start_token]
        for _ in range(max_len):
            step_logits = SamplingStrategies._last_step_logits(model, tokens)
            sorted_logits, sorted_indices = torch.sort(step_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Shift the threshold test right by one rank. Writing into a
            # fresh mask avoids an overlapping in-place copy.
            remove = torch.zeros_like(cumulative_probs, dtype=torch.bool)
            remove[1:] = cumulative_probs[:-1] > p

            # Mask in sorted order, sample there, then map the sampled rank
            # back to a vocabulary id.
            kept_logits = sorted_logits.masked_fill(remove, float("-inf"))
            probs = F.softmax(kept_logits, dim=-1)
            rank = torch.multinomial(probs, 1)
            tokens.append(sorted_indices[rank].item())
        return tokens
# Example usage: decode from start token 1 with each strategy in turn.
# The decode methods are static, so the instance below carries no state and
# only serves as a namespace; `model` is the instance created earlier.
strategies = SamplingStrategies()
greedy_tokens = strategies.greedy_decode(model, start_token=1)
# Sample among the 10 highest-scoring candidates at every step.
top_k_tokens = strategies.top_k_decode(model, start_token=1, k=10)
# Sample from the top-ranked tokens up to a cumulative probability of 0.9.
nucleus_tokens = strategies.nucleus_decode(model, start_token=1, p=0.9)
KV Cache for Efficiency
class KVCache:
    """Per-layer store of attention keys and values.

    Entry ``i`` of each list holds the tensor cached for layer ``i``. New
    keys and values are concatenated along dimension 2, presumably the
    sequence axis of a ``(batch, heads, seq, head_dim)`` tensor; confirm
    against the attention implementation that fills the cache.
    """

    def __init__(self):
        self.key_cache = []
        self.value_cache = []

    def update(self, layer_idx, new_key, new_value):
        """Append ``new_key`` / ``new_value`` to the cache of ``layer_idx``.

        The first update for a layer stores the tensors as they are; later
        updates concatenate along dimension 2.

        Raises:
            IndexError: If ``layer_idx`` skips ahead of the next free slot.
                Appending in that case would store the tensors under the
                wrong layer.
        """
        if layer_idx > len(self.key_cache):
            raise IndexError(
                f"layer_idx {layer_idx} skips ahead of the next free slot "
                f"{len(self.key_cache)}"
            )
        if layer_idx < len(self.key_cache):
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], new_key], dim=2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], new_value], dim=2
            )
        else:
            self.key_cache.append(new_key)
            self.value_cache.append(new_value)

    def get(self, layer_idx):
        """Return the ``(key, value)`` tensors cached for ``layer_idx``."""
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def clear(self):
        """Drop everything cached for every layer."""
        self.key_cache = []
        self.value_cache = []
def generate_with_cache(model, start_tokens, max_len=100):
    """Greedily extend ``start_tokens`` by ``max_len`` tokens using a KV cache.

    The prompt is fed once. After that only the newly generated token is
    passed to the model, because the keys and values of earlier positions
    already sit in the cache. Re-feeding the whole sequence would append
    those positions to the cache again on every step.

    Args:
        model: Object exposing ``forward_with_cache(tokens, kv_cache)`` that
            returns ``(logits, kv_cache)``. That method is not defined in
            this file; it is assumed to accept only the not-yet-cached
            tokens and to return logits of shape ``(batch, n_new, vocab)``.
        start_tokens: Prompt token ids of shape ``(batch, prompt_len)``.
        max_len: Number of tokens to generate.

    Returns:
        Tensor of shape ``(batch, prompt_len + max_len)`` holding the prompt
        followed by the generated tokens.
    """
    kv_cache = KVCache()
    tokens = start_tokens.clone()
    step_input = tokens  # first step: the whole prompt
    with torch.no_grad():  # generation only: no gradients needed
        for _ in range(max_len):
            logits, kv_cache = model.forward_with_cache(step_input, kv_cache)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
            # Later steps feed only the new token; the cache covers the rest.
            step_input = next_token
    return tokens
Best Practices
- Combine teacher forcing with scheduled sampling (gradually feed the model its own predictions during training) to reduce exposure bias
- Implement KV cache for efficient generation
- Apply temperature scaling for diversity control
- Use top-k or nucleus sampling for quality
- Monitor perplexity during training
- Implement early stopping based on validation loss