Skip to content

Chapter 10: Training a Transformer from Scratch

Learning Outcome

Train a small GPT-style language model from scratch on a text corpus, understand the pre-training objective, implement the training loop with modern best practices, and fine-tune it on a downstream task.


Concepts

Causal Language Modeling Loss

The training objective is to predict the next token at every position:

\[ \mathcal{L} = -\frac{1}{N} \sum_{t=1}^{N} \log P(x_t \mid x_{<t}) \]

This is cross-entropy over next-token predictions. Each sequence of length T provides T-1 training examples (each token predicts the next).

Data Pipeline

For pre-training:

  1. Tokenize the entire corpus.
  2. Pack sequences end-to-end to fill context windows: avoids wasted padding.
  3. Shuffle at the chunk level.
# Sequence packing example
tokens = [1, 23, 45, 2, 67, 89, 3, ...]  # concatenated corpus

# Slide a window of context_length + 1 tokens forward by context_length each
# time; neighbouring windows share exactly one token, so every token serves as
# a prediction target once.
context_length = 512
window = context_length + 1
last_start = len(tokens) - context_length  # exclusive bound: only full windows
for i in range(0, last_start, context_length):
    chunk = tokens[i : i + window]
    # input is the window minus its last token; target is the same window
    # shifted left by one position
    x, y = chunk[:-1], chunk[1:]

AdamW Optimizer

AdamW decouples weight decay from the adaptive gradient scaling of Adam:

# (Pseudo-code: momentum and bias correction are omitted for brevity.)

# Adam + L2 regularization (coupled): the decay term is folded into the
# gradient, so it also feeds the second-moment estimate v and is divided by
# sqrt(v) + eps like the rest of the gradient. Parameters with large gradient
# history therefore get decayed less than intended.
g = grad + weight_decay * param
param -= lr * g / (sqrt(v) + eps)   # v is accumulated from g, not from grad

# AdamW (decoupled): v is accumulated from the raw gradient only, and the
# decay shrinks the parameters directly, outside the adaptive scaling.
param = param * (1 - lr * weight_decay) - lr * grad / (sqrt(v) + eps)

Typical hyperparameters:

  • lr = 3e-4 (peak)
  • weight_decay = 0.1
  • betas = (0.9, 0.95)
  • Apply weight decay only to weight matrices, not biases or LayerNorm params.

Learning Rate Schedule

Warmup + cosine decay:

lr
  ╭──────╮
  │      ╲
  │       ╲___________
  0     warmup     max_steps
def get_lr(step, warmup_steps, max_steps, max_lr, min_lr=0.0):
    """Return the learning rate for `step` under linear warmup + cosine decay.

    Args:
        step: Current optimizer step, counted from 0.
        warmup_steps: Number of steps over which the rate rises linearly
            from 0 to `max_lr`.
        max_steps: Step at which the cosine decay reaches `min_lr`.
        max_lr: Peak learning rate, reached at `step == warmup_steps`.
        min_lr: Floor returned at and beyond `max_steps`.

    Returns:
        The learning rate to use for this step.
    """
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    # `>=` rather than `>`: the cosine coefficient is exactly 0 at
    # step == max_steps, so the result is unchanged there, and returning early
    # avoids a 0/0 division below when warmup_steps == max_steps.
    if step >= max_steps:
        return min_lr
    # Fraction of the decay phase completed, in [0, 1).
    decay = (step - warmup_steps) / (max_steps - warmup_steps)
    # Half-cosine that goes from 1 (start of decay) down to 0 (end of decay).
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay))
    return min_lr + coeff * (max_lr - min_lr)

Mixed-Precision Training

Use BF16 (on Ampere/Hopper GPUs) for the forward pass; accumulate gradients in FP32:

# Run the forward pass and the loss computation under BF16 autocast
bf16_autocast = torch.autocast(device_type='cuda', dtype=torch.bfloat16)
with bf16_autocast:
    logits = model(input_ids)
    # Flatten to one row of vocabulary scores per token and one target id per
    # token, which is the layout the loss is called with
    flat_logits = logits.view(-1, vocab_size)
    flat_targets = targets.view(-1)
    loss = criterion(flat_logits, flat_targets)

Gradient Clipping

Clip the global gradient norm before the optimizer step:

# Clip the global gradient norm over all model parameters to at most max_norm;
# run this after backward() and before the optimizer step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Perplexity

Perplexity is the exponential of the average cross-entropy loss:

\[ \text{PPL} = e^{\mathcal{L}} \]

Lower is better. A model with perplexity 100 is as surprised as if it randomly chose from 100 equally probable tokens.


Exercise 1 — Complete Training Script

Guided Exercise

Build a training loop for a small GPT model on WikiText-2.

import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import GPT2Tokenizer


# ── Data Pipeline ──────────────────────────────────────────────────────────────

class PackedTextDataset(Dataset):
    """Pack a tokenized corpus into fixed-size next-token training samples.

    Every text is encoded, followed by the tokenizer's EOS id, and appended to
    one long token stream. The stream is then cut into windows of
    ``context_length + 1`` tokens taken every ``context_length`` tokens, so
    neighbouring windows share one token. A trailing remainder too short to
    fill a whole window is dropped.
    """

    def __init__(self, texts: list[str], tokenizer, context_length: int = 256):
        self.context_length = context_length

        # Build the concatenated stream: <text ids> <eos> <text ids> <eos> ...
        stream = []
        for text in texts:
            stream += tokenizer.encode(text)
            stream.append(tokenizer.eos_token_id)

        # One extra token per window so that inputs and targets can both be
        # context_length long after the one-position shift in __getitem__.
        window = context_length + 1
        starts = range(0, len(stream) - context_length, context_length)
        self.chunks = [
            torch.tensor(stream[start : start + window], dtype=torch.long)
            for start in starts
        ]

    def __len__(self):
        """Return the number of packed windows."""
        return len(self.chunks)

    def __getitem__(self, idx):
        """Return ``(input, target)`` where target is input shifted by one."""
        sample = self.chunks[idx]
        inputs = sample[:-1]
        targets = sample[1:]
        return inputs, targets


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Reuse the EOS token as the padding token
tokenizer.pad_token = tokenizer.eos_token

wikitext = load_dataset("wikitext", "wikitext-2-raw-v1")


def _is_long_enough(line):
    """Keep only lines with more than 50 characters after stripping."""
    return len(line.strip()) > 50


train_texts = list(filter(_is_long_enough, wikitext["train"]["text"]))
val_texts = list(filter(_is_long_enough, wikitext["validation"]["text"]))

CONTEXT = 128  # Use smaller context for training speed

# Only a slice of each split is packed: 500 training lines, 100 validation lines
train_set = PackedTextDataset(
    texts=train_texts[:500], tokenizer=tokenizer, context_length=CONTEXT
)
val_set = PackedTextDataset(
    texts=val_texts[:100], tokenizer=tokenizer, context_length=CONTEXT
)

# Training batches are shuffled and the ragged final batch is dropped;
# validation keeps its order and every sample
train_loader = DataLoader(train_set, batch_size=16, shuffle=True, drop_last=True)
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)

print(f"Train chunks: {len(train_set)}, Val chunks: {len(val_set)}")


# ── Model ──────────────────────────────────────────────────────────────────────

# Train on the GPU when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Small GPT configuration; max_len is tied to the packing context so every
# packed sample fits.
# NOTE(review): 50257 is presumably the GPT-2 tokenizer's vocabulary size — confirm.
model = GPTModel(
    vocab_size=50257,
    d_model=256,
    n_heads=8,
    n_layers=4,
    max_len=CONTEXT,
    dropout=0.1,
)
model = model.to(device)

# Total element count across all parameter tensors
parameter_sizes = [p.numel() for p in model.parameters()]
n_params = sum(parameter_sizes)
print(f"Model parameters: {n_params:,}")


# ── Optimizer and LR Schedule ──────────────────────────────────────────────────

# Split the parameters into two optimizer groups in a single pass: tensors
# with 2 or more dimensions whose name does not contain 'norm' are decayed;
# everything else (1-D tensors and anything named '*norm*') is not.
decay_params = []
nodecay_params = []
for name, param in model.named_parameters():
    if param.ndim >= 2 and 'norm' not in name:
        decay_params.append(param)
    else:
        nodecay_params.append(param)

param_groups = [
    {"params": decay_params,   "weight_decay": 0.1},
    {"params": nodecay_params, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(param_groups, lr=3e-4, betas=(0.9, 0.95))


def get_lr(step, warmup=200, max_steps=2000, max_lr=3e-4, min_lr=3e-5):
    """Return the learning rate for `step` under linear warmup + cosine decay.

    Args:
        step: Current optimizer step, counted from 0.
        warmup: Number of steps over which the rate rises linearly from 0
            to `max_lr`.
        max_steps: Step at which the cosine decay reaches `min_lr`.
        max_lr: Peak learning rate, reached at `step == warmup`.
        min_lr: Floor returned at and beyond `max_steps`.

    Returns:
        The learning rate to use for this step.
    """
    if step < warmup:
        return max_lr * step / warmup
    # `>=` rather than `>`: the cosine coefficient is exactly 0 at
    # step == max_steps, so the result is unchanged there, and returning early
    # avoids a 0/0 division below when warmup == max_steps.
    if step >= max_steps:
        return min_lr
    # Fraction of the decay phase completed, in [0, 1).
    decay = (step - warmup) / (max_steps - warmup)
    # Half-cosine that goes from 1 (start of decay) down to 0 (end of decay).
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay))
    return min_lr + coeff * (max_lr - min_lr)


# ── Training Loop ──────────────────────────────────────────────────────────────

criterion = nn.CrossEntropyLoss()
# NOTE(review): loss scaling is normally only needed for FP16 autocast; with
# BF16 this scaler is presumably redundant — confirm before removing it.
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))

global_step = 0
MAX_STEPS = 2000

# Bind lr before the loop so the per-epoch summary can still print it when
# train_loader yields no batches (it drops incomplete batches).
lr = get_lr(global_step, max_steps=MAX_STEPS)

for epoch in range(20):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # Update LR; the schedule horizon is passed explicitly so it cannot
        # drift from the MAX_STEPS stopping condition below
        lr = get_lr(global_step, max_steps=MAX_STEPS)
        for pg in optimizer.param_groups:
            pg["lr"] = lr

        # Forward pass with mixed precision (autocast is a no-op on CPU here)
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                            enabled=(device.type == 'cuda')):
            logits = model(x)
            # Flatten to (tokens, vocab); the vocab size is read from the
            # logits instead of being hard-coded
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))

        # Backward pass: gradients are unscaled before clipping so the norm
        # threshold applies to the true gradient magnitudes
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        global_step += 1
        if global_step >= MAX_STEPS:
            break

    # Validation perplexity
    model.eval()
    val_loss_sum = 0.0
    val_tokens = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            batch_loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            # criterion returns the mean over the batch's tokens; weight it by
            # the token count so a smaller final batch is not over-represented
            val_loss_sum += batch_loss.item() * y.numel()
            val_tokens += y.numel()
    # NaN (rather than a ZeroDivisionError) when there is no validation data
    val_loss = val_loss_sum / val_tokens if val_tokens else float("nan")
    val_ppl = math.exp(val_loss)

    print(f"Epoch {epoch+1:2d} | step={global_step:5d} | "
          f"lr={lr:.2e} | val_loss={val_loss:.4f} | val_ppl={val_ppl:.1f}")

    # The inner break only leaves the batch loop; stop the epoch loop as well
    if global_step >= MAX_STEPS:
        break

Exercise 2 — Log and Visualize Learning Curves

import matplotlib.pyplot as plt

# Metric buffers — append to these from inside the training loop above
train_losses = []
val_perplexities = []

# Once training has finished, draw both curves side by side
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# One row per panel: (series, x-axis label, y-axis label, title)
panels = (
    (train_losses, "Step", "Train Loss", "Training Loss"),
    (val_perplexities, "Epoch", "Perplexity", "Validation Perplexity"),
)
for ax, (series, x_label, y_label, title) in zip(axes, panels):
    ax.plot(series)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('learning_curves.png', dpi=100)
plt.show()

# Sample a continuation from a short prompt to judge the model qualitatively
model.eval()
seed = "The history of"
prompt_batch = [tokenizer.encode(seed)]  # outer list adds the batch dimension
seed_ids = torch.tensor(prompt_batch).to(device)
with torch.no_grad():
    out = model.generate(seed_ids, max_new_tokens=50, temperature=0.8, top_k=40)
generated_text = tokenizer.decode(out[0])
print(generated_text)

Exercise 3 — Fine-Tune on AG News Classification

from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import torch

# Load AG News; its "train" and "test" splits are consumed by the fine-tuning
# loop below, which trains a 4-class classifier on them
ag_news = load_dataset("ag_news")

# Sequence classifier: linear head over the final position's hidden state
class GPTForClassification(nn.Module):
    """Wrap a GPT backbone with a linear classification head.

    The backbone's embeddings, dropout, blocks and final norm are applied
    directly, and the hidden state at the last sequence position is fed to the
    head. The backbone is held by reference, so training this module also
    updates the model that was passed in.
    """

    def __init__(self, gpt_model: GPTModel, n_classes: int):
        super().__init__()
        self.gpt = gpt_model
        # Head input width follows the backbone's token-embedding width
        hidden_size = gpt_model.token_embed.embedding_dim
        self.classifier = nn.Linear(hidden_size, n_classes)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return class logits for a 2-D batch of token ids."""
        backbone = self.gpt
        _, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len, device=input_ids.device)

        token_part = backbone.token_embed(input_ids)
        position_part = backbone.pos_embed(position_ids)
        hidden = backbone.drop(token_part + position_part)

        for layer in backbone.blocks:
            hidden = layer(hidden)

        # GPT-style models have no [CLS] token, so the last position stands in
        # as the sequence summary.
        # NOTE(review): with right-padded inputs this position holds a padding
        # token — confirm that is intended.
        normed = backbone.norm_f(hidden)
        summary = normed[:, -1, :]  # (batch, d_model)
        return self.classifier(summary)


class AGNewsDataset(Dataset):
    """Serve ``(token_ids, label)`` pairs with a fixed sequence length.

    Each item's ``"text"`` is encoded with truncation to ``max_len`` tokens and
    then right-padded with the tokenizer's EOS id up to ``max_len``. When
    ``n_samples`` is given, only the first ``n_samples`` rows are kept, using
    the data object's ``select`` method.
    """

    def __init__(self, data, tokenizer, max_len=64, n_samples=None):
        if n_samples is None:
            self.data = data
        else:
            self.data = data.select(range(n_samples))
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of rows served."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the padded token-id tensor and the raw label for row `idx`."""
        record = self.data[idx]
        token_ids = self.tokenizer.encode(
            record["text"], max_length=self.max_len, truncation=True
        )
        # Fill the remaining slots with EOS so every sample has the same length
        shortfall = self.max_len - len(token_ids)
        padding = [self.tokenizer.eos_token_id] * shortfall
        padded = token_ids + padding
        return torch.tensor(padded), record["label"]


import copy

# Few-shot (100 samples) vs full fine-tuning

# The evaluation split and its loader do not depend on n_train: build them once
val_ds = AGNewsDataset(ag_news["test"], tokenizer)
val_dl = DataLoader(val_ds, batch_size=64)

for n_train in [100, len(ag_news["train"])]:
    # GPTForClassification keeps a reference to the backbone it is given, so
    # each run gets its own deep copy. Otherwise the 100-sample run would
    # update `model` in place and the full run would start from an already
    # fine-tuned backbone instead of the pre-trained one.
    backbone = copy.deepcopy(model)
    clf_model = GPTForClassification(backbone, n_classes=4).to(device)
    optimizer = torch.optim.AdamW(clf_model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    train_ds = AGNewsDataset(ag_news["train"], tokenizer, n_samples=n_train)
    train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)

    # Train backbone and head jointly for a fixed number of epochs
    for epoch in range(5):
        clf_model.train()
        for ids, labels in train_dl:
            ids, labels = ids.to(device), labels.to(device)
            logits = clf_model(ids)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    # Accuracy on the test split: count argmax predictions that match
    clf_model.eval()
    correct = total = 0
    with torch.no_grad():
        for ids, labels in val_dl:
            ids, labels = ids.to(device), labels.to(device)
            preds = clf_model(ids).argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += len(labels)

    print(f"n_train={n_train:7d}  val_acc={correct/total:.3f}")

Summary

  • Pre-training minimizes next-token cross-entropy (causal language modeling loss).
  • Data pipelines pack sequences end-to-end to maximize GPU utilization.
  • AdamW with decoupled weight decay, linear warmup, and cosine decay is the standard recipe.
  • Mixed precision (BF16) speeds up training and reduces memory use, while gradient clipping keeps training stable.
  • Fine-tuning adds a task head to a pre-trained checkpoint at a lower learning rate.

← Chapter 9 Chapter 11: HuggingFace Internals →