Chapter 10: Training a Transformer from Scratch¶
Learning Outcome¶
Train a small GPT-style language model from scratch on a text corpus, understand the pre-training objective, implement the training loop with modern best practices, and fine-tune it on a downstream task.
Concepts¶
Causal Language Modeling Loss¶
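The training objective is to predict the next token at every position. For a sequence \(x_1, \dots, x_T\):
\[ \mathcal{L} = -\frac{1}{T-1} \sum_{t=1}^{T-1} \log p_\theta(x_{t+1} \mid x_1, \dots, x_t) \]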
This is cross-entropy over next-token predictions. Each sequence of length T
provides T-1 training examples (each token predicts the next).
Data Pipeline¶
For pre-training:
- Tokenize the entire corpus.
- Pack documents end-to-end (separated by an end-of-sequence token) to fill every context window, so no compute is wasted on padding.
- Shuffle at the chunk level.
# Sequence packing example
tokens = [1, 23, 45, 2, 67, 89, 3, ...]  # concatenated corpus
context_length = 512
# Slide a window of context_length + 1 tokens with stride context_length.
# Consecutive windows share one boundary token, so the inputs never overlap.
stop = len(tokens) - context_length  # exclusive bound: every window stays in range
for start in range(0, stop, context_length):
    window = tokens[start : start + context_length + 1]
    x, y = window[:-1], window[1:]  # input, target (shifted by 1)
AdamW Optimizer¶
AdamW decouples weight decay from the adaptive gradient scaling of Adam:
# Adam + L2 penalty (what torch.optim.Adam's `weight_decay` does): the decay
# term is added to the gradient *before* the moment estimates, so it is also
# divided by sqrt(v) -- parameters with large gradients get decayed less.
grad = grad + weight_decay * param
param -= lr * m / (sqrt(v) + eps)  # m, v: moment estimates of the decayed grad
# AdamW: weight decay is applied directly to the parameters, outside the
# adaptive scaling (m, v here are computed from the raw gradient), so every
# weight shrinks by the same factor each step.
param = param * (1 - lr * weight_decay) - lr * m / (sqrt(v) + eps)
Typical hyperparameters:
- lr = 3e-4 (peak)
- weight_decay = 0.1
- betas = (0.9, 0.95)
- Apply weight decay only to weight matrices, not biases or LayerNorm params.
Learning Rate Schedule¶
Warmup + cosine decay:
def get_lr(step, warmup_steps, max_steps, max_lr, min_lr=0.0):
    """Return the learning rate for ``step`` under linear warmup + cosine decay.

    The rate rises linearly from 0 to ``max_lr`` over ``warmup_steps`` steps,
    then follows half a cosine from ``max_lr`` down to ``min_lr`` at
    ``max_steps``, and stays at ``min_lr`` from then on.
    """
    import math

    if step < warmup_steps:
        return max_lr * step / warmup_steps
    # `>=` rather than `>`: when max_steps == warmup_steps the cosine branch
    # would divide by zero. At step == max_steps the cosine term is 0 anyway,
    # so the returned value there is unchanged (min_lr).
    if step >= max_steps:
        return min_lr
    decay = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay))
    return min_lr + coeff * (max_lr - min_lr)
Mixed-Precision Training¶
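Run the forward pass in BF16 with `torch.autocast` (Ampere-generation GPUs and newer). Autocast does not change parameter dtypes, so FP32 parameters still receive FP32 gradients. BF16 has the same exponent range as FP32, so it does not need the loss scaling that FP16 requires: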
# Forward pass and loss under BF16 autocast (autocast does not change parameter dtypes).
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    logits = model(input_ids)
    flat_logits = logits.view(-1, vocab_size)  # one row of vocabulary scores per position
    flat_targets = targets.view(-1)
    loss = criterion(flat_logits, flat_targets)
Gradient Clipping¶
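Clip the global gradient norm before the optimizer step, e.g. `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)`. When a gradient scaler is in use, call `scaler.unscale_(optimizer)` first so the threshold applies to the true gradients.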
Perplexity¶
Perplexity is the exponential of the average per-token cross-entropy loss (measured in nats): \[ \text{PPL} = e^{\mathcal{L}} \]
Lower is better. A model with perplexity 100 is, on average, as uncertain as if it had to choose uniformly among 100 equally likely tokens at each step.
Exercise 1 — Complete Training Script¶
Guided Exercise
Build a training loop for a small GPT model on WikiText-2.
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import GPT2Tokenizer
# ── Data Pipeline ──────────────────────────────────────────────────────────────
class PackedTextDataset(Dataset):
    """Concatenate tokenized texts and serve fixed-length (input, target) pairs.

    Each text is encoded and followed by the tokenizer's EOS id, and all of
    them form one token stream. Windows of ``context_length + 1`` tokens are
    taken every ``context_length`` tokens; a trailing remainder too short to
    fill a window is dropped.
    """

    def __init__(self, texts: list[str], tokenizer, context_length: int = 256):
        self.context_length = context_length
        stream = [
            tok
            for text in texts
            for tok in (*tokenizer.encode(text), tokenizer.eos_token_id)
        ]
        window = context_length + 1
        starts = range(0, len(stream) - context_length, context_length)
        self.chunks = [
            torch.tensor(stream[s : s + window], dtype=torch.long) for s in starts
        ]

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        tokens = self.chunks[idx]
        # (input, target): the target is the input shifted left by one token.
        return tokens[:-1], tokens[1:]
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token

wikitext = load_dataset("wikitext", "wikitext-2-raw-v1")


def _long_lines(split_name):
    """Return the lines of a WikiText split whose stripped length exceeds 50 chars."""
    return [line for line in wikitext[split_name]["text"] if len(line.strip()) > 50]


train_texts = _long_lines("train")
val_texts = _long_lines("validation")

CONTEXT = 128  # Use smaller context for training speed
# Only a prefix of each split is used so the exercise trains quickly.
train_set = PackedTextDataset(train_texts[:500], tokenizer, CONTEXT)
val_set = PackedTextDataset(val_texts[:100], tokenizer, CONTEXT)
train_loader = DataLoader(train_set, batch_size=16, shuffle=True, drop_last=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)
print(f"Train chunks: {len(train_set)}, Val chunks: {len(val_set)}")
# ── Model ──────────────────────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gpt_config = {
    "vocab_size": 50257,  # size of the GPT-2 tokenizer vocabulary
    "d_model": 256,
    "n_heads": 8,
    "n_layers": 4,
    "max_len": CONTEXT,
    "dropout": 0.1,
}
model = GPTModel(**gpt_config).to(device)

n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")
# ── Optimizer and LR Schedule ──────────────────────────────────────────────────
# Weight decay goes only on tensors with 2+ dims (weight matrices). 1-D tensors
# (biases, norm gains) and anything whose name contains 'norm' get none.
decay_params = []
nodecay_params = []
for param_name, param in model.named_parameters():
    if param.ndim >= 2 and 'norm' not in param_name:
        decay_params.append(param)
    else:
        nodecay_params.append(param)

param_groups = [
    {"params": decay_params, "weight_decay": 0.1},
    {"params": nodecay_params, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(param_groups, lr=3e-4, betas=(0.9, 0.95))
def get_lr(step, warmup=200, max_steps=2000, max_lr=3e-4, min_lr=3e-5):
    """Learning rate for ``step``: linear warmup, then cosine decay to ``min_lr``.

    The rate rises linearly from 0 to ``max_lr`` over ``warmup`` steps, follows
    half a cosine from ``max_lr`` down to ``min_lr`` at ``max_steps``, and stays
    at ``min_lr`` afterwards.
    """
    if step < warmup:
        return max_lr * step / warmup
    # `>=` rather than `>`: with max_steps == warmup the cosine branch would
    # divide by zero. At step == max_steps the cosine term is 0 anyway, so the
    # returned value there is unchanged (min_lr).
    if step >= max_steps:
        return min_lr
    decay = (step - warmup) / (max_steps - warmup)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay))
    return min_lr + coeff * (max_lr - min_lr)
# ── Training Loop ──────────────────────────────────────────────────────────────
criterion = nn.CrossEntropyLoss()
# GradScaler exists for FP16, whose narrow exponent range can underflow
# gradients; BF16 does not need loss scaling. scaler.step still skips any
# update whose gradients contain inf/NaN. Disabled (a pass-through) on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))
global_step = 0
MAX_STEPS = 2000
for epoch in range(20):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        # Update LR (schedule is applied manually to every param group)
        lr = get_lr(global_step)
        for pg in optimizer.param_groups:
            pg["lr"] = lr
        # Forward pass with mixed precision
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                            enabled=(device.type == 'cuda')):
            logits = model(x)
            # One row per predicted token; the vocabulary size is read from
            # the logits instead of being hard-coded.
            loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        # Backward pass
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        global_step += 1
        if global_step >= MAX_STEPS:
            break

    # Validation perplexity. val_loader keeps its last, possibly smaller batch
    # (no drop_last), so each batch loss is weighted by its token count:
    # averaging per-batch means would over-weight that short batch.
    model.eval()
    val_loss_sum = 0.0
    val_tokens = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            batch_loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            val_loss_sum += batch_loss.item() * y.numel()
            val_tokens += y.numel()
    # NaN instead of a ZeroDivisionError when the validation set is empty.
    val_loss = val_loss_sum / val_tokens if val_tokens else float("nan")
    val_ppl = math.exp(val_loss)
    print(f"Epoch {epoch+1:2d} | step={global_step:5d} | "
          f"lr={lr:.2e} | val_loss={val_loss:.4f} | val_ppl={val_ppl:.1f}")
    if global_step >= MAX_STEPS:
        break
Exercise 2 — Log and Visualize Learning Curves¶
import matplotlib.pyplot as plt
# Collect metrics during training (add to the loop above)
# NOTE: as written, these lists are created empty right before plotting, so
# the figures are blank. Create them before the training loop and append to
# them inside it (train loss per step, validation perplexity per epoch).
train_losses = []
val_perplexities = []
# After training, plot:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
panels = [
    (train_losses, "Step", "Train Loss", "Training Loss"),
    (val_perplexities, "Epoch", "Perplexity", "Validation Perplexity"),
]
for ax, (series, x_label, y_label, title) in zip(axes, panels):
    ax.plot(series)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('learning_curves.png', dpi=100)
plt.show()
# Qualitative check: continue a short prompt with the trained model.
model.eval()
seed = "The history of"
seed_ids = torch.tensor([tokenizer.encode(seed)]).to(device)  # shape (1, prompt_len)
sampling = dict(max_new_tokens=50, temperature=0.8, top_k=40)
with torch.no_grad():
    out = model.generate(seed_ids, **sampling)
generated_text = tokenizer.decode(out[0])
print(generated_text)
Exercise 3 — Fine-Tune on AG News Classification¶
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import torch
# AG News: its "train" and "test" splits are used for fine-tuning below.
ag_news = load_dataset(path="ag_news")
# Classification head on the last token embedding
class GPTForClassification(nn.Module):
    """Sequence classifier: a GPT backbone plus a linear head.

    GPT has no [CLS] token, so the final-norm hidden state of one position
    represents the whole sequence. By default that is the last position,
    which holds a real token only for unpadded or left-padded inputs. If
    ``pad_token_id`` is given, the last position whose id differs from it is
    used instead, so right-padded batches are classified from their last real
    token. Text that itself contains ``pad_token_id`` would confuse this.

    Args:
        gpt_model: Backbone exposing ``token_embed``, ``pos_embed``, ``drop``,
            ``blocks`` and ``norm_f`` (all used in ``forward``). It is stored
            by reference, so training this module updates it in place.
        n_classes: Number of output classes.
        pad_token_id: Optional padding id used to locate the last real token.
    """

    def __init__(self, gpt_model: "GPTModel", n_classes: int, pad_token_id=None):
        super().__init__()
        self.gpt = gpt_model
        self.pad_token_id = pad_token_id
        d_model = gpt_model.token_embed.embedding_dim
        self.classifier = nn.Linear(d_model, n_classes)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (batch, n_classes) for 2-D ``input_ids``."""
        batch, seq = input_ids.shape
        positions = torch.arange(seq, device=input_ids.device)
        x = self.gpt.drop(
            self.gpt.token_embed(input_ids) + self.gpt.pos_embed(positions)
        )
        for block in self.gpt.blocks:
            x = block(x)
        hidden = self.gpt.norm_f(x)
        if self.pad_token_id is None:
            last_hidden = hidden[:, -1, :]  # (batch, d_model)
        else:
            # Highest position holding a non-pad id (0 if a row is all padding);
            # works for both left- and right-padded batches.
            is_real = (input_ids != self.pad_token_id).long()
            last_idx = (is_real * positions).argmax(dim=-1)
            rows = torch.arange(batch, device=input_ids.device)
            last_hidden = hidden[rows, last_idx]  # (batch, d_model)
        return self.classifier(last_hidden)
class AGNewsDataset(Dataset):
    """AG News examples as fixed-length, left-padded token ids plus labels.

    Texts are truncated to ``max_len`` tokens and padded on the *left* with the
    tokenizer's EOS id. Left padding keeps each text's last real token at the
    final position, which is where a last-position classifier reads its
    sequence representation. With right padding, that position holds a pad
    token for every text shorter than ``max_len``. No attention mask is
    produced, so real tokens attend to the leading EOS padding. EOS also
    precedes each new document in the packed pre-training stream.

    Args:
        data: Indexable split whose items have "text" and "label" keys; it must
            support ``.select`` when ``n_samples`` is given.
        tokenizer: Tokenizer with ``encode(text, max_length=, truncation=)``
            and ``eos_token_id``.
        max_len: Length of every returned id sequence.
        n_samples: If given, only the first ``n_samples`` items are used.
    """

    def __init__(self, data, tokenizer, max_len=64, n_samples=None):
        self.data = data if n_samples is None else data.select(range(n_samples))
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        ids = self.tokenizer.encode(item["text"], max_length=self.max_len,
                                    truncation=True)
        # Left-pad to max_len so the final position is always a real token.
        padding = [self.tokenizer.eos_token_id] * (self.max_len - len(ids))
        return torch.tensor(padding + ids, dtype=torch.long), item["label"]
# Few-shot (100 samples) vs full fine-tuning
for n_train in [100, len(ag_news["train"])]:
clf_model = GPTForClassification(model, n_classes=4).to(device)
optimizer = torch.optim.AdamW(clf_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
train_ds = AGNewsDataset(ag_news["train"], tokenizer, n_samples=n_train)
val_ds = AGNewsDataset(ag_news["test"], tokenizer)
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=64)
for epoch in range(5):
clf_model.train()
for ids, labels in train_dl:
ids, labels = ids.to(device), labels.to(device)
logits = clf_model(ids)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
clf_model.eval()
correct = total = 0
with torch.no_grad():
for ids, labels in val_dl:
ids, labels = ids.to(device), labels.to(device)
preds = clf_model(ids).argmax(dim=-1)
correct += (preds == labels).sum().item()
total += len(labels)
print(f"n_train={n_train:7d} val_acc={correct/total:.3f}")
Summary¶
- Pre-training minimizes next-token cross-entropy (causal language modeling loss).
- Data pipelines pack sequences end-to-end to maximize GPU utilization.
- AdamW with decoupled weight decay, linear warmup, and cosine decay is the standard recipe.
- Mixed-precision (BF16) and gradient clipping are essential for stable training.
- Fine-tuning adds a task-specific head on top of a pre-trained checkpoint and trains the combined model at a lower learning rate.