Chapter 13: BERT Pre-Training — Masked Language Modeling from Scratch¶
Learning Outcome¶
Build a complete BERT pre-training pipeline: implement the Masked Language Modeling (MLM) data collator with the canonical 80/10/10 masking strategy, train a compact BERT-style encoder on publicly available Wikipedia text, and probe the quality of the learned embeddings on a downstream classification task.
Concepts¶
The MLM Objective¶
BERT's key pre-training innovation is Masked Language Modeling: randomly mask a fraction of input tokens and train the model to reconstruct the originals from the surrounding context.
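The loss is computed only at the masked positions:

\[
\mathcal{L}_{\text{MLM}} = -\frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \log P\left(x_i \mid \tilde{x}\right)
\]

where \(\mathcal{M}\) is the set of masked positions, \(x_i\) is the original token at position \(i\), and \(\tilde{x}\) is the corrupted input sequence. Every other position contributes zero loss, so the gradient signal comes exclusively from predicting masked tokens.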
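As a minimal sketch of the masked-only loss described above, the snippet below compares PyTorch's `cross_entropy` (with `ignore_index=-100`) against a manual average over the masked positions. The tensor values, sizes, and variable names are made up for illustration and are not taken from the exercises.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, seq_len = 10, 6
logits = torch.randn(1, seq_len, vocab_size)           # (B, T, V) toy model outputs
original = torch.tensor([[3, 7, 1, 4, 9, 2]])          # toy "true" token ids
masked_positions = torch.tensor([[False, True, False, False, True, False]])

# Labels hold the original id at masked positions and -100 everywhere else.
labels = original.clone()
labels[~masked_positions] = -100

# cross_entropy skips every position whose label equals ignore_index and
# averages over the remaining (masked) positions.
loss = F.cross_entropy(
    logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100
)

# Manual check: mean negative log-probability of the true token, masked set only.
log_probs = F.log_softmax(logits, dim=-1)
picked = log_probs[masked_positions].gather(
    1, original[masked_positions].unsqueeze(1)
)
manual = -picked.mean()

print(f"cross_entropy: {loss.item():.6f}  manual: {manual.item():.6f}")
```

The two numbers should agree, which is why setting unselected labels to `-100` is enough to restrict the gradient to masked tokens.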
The 80/10/10 Masking Strategy¶
Of the 15% of tokens selected for prediction, not all are replaced with `[MASK]`:

| Treatment | Fraction | Rationale |
|---|---|---|
| Replace with `[MASK]` | 80% | Primary masking signal |
| Replace with a random token | 10% | Forces real-token representations at all positions |
| Keep the original token | 10% | Forces prediction even without a visible mask |
If every selected token were replaced with `[MASK]`, the model would learn to
produce useful predictions only at `[MASK]` positions during pre-training, yet
`[MASK]` never appears at fine-tuning time: a distribution mismatch between the
two stages. Mixing in random and unchanged tokens (10% each) means the model
cannot tell which positions it will be asked to predict, so it must maintain a
meaningful hidden state for every token, whether or not it appears masked.
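One common way to realize the 80/10/10 split is a two-stage draw: first decide `[MASK]` with probability 0.8, then split the remaining 20% in half. The simulation below is a sketch using only the standard library; the function name `choose_treatment` is hypothetical and exists only to check that the two-stage draw yields the intended fractions.

```python
import random


def choose_treatment(rng: random.Random) -> str:
    """Two-stage draw for one token that was already selected for prediction."""
    if rng.random() < 0.8:
        return "mask"      # 80%: replace with [MASK]
    # Of the remaining 20%, half become random tokens and half stay unchanged.
    if rng.random() < 0.5:
        return "random"    # 0.2 * 0.5 = 10%
    return "keep"          # 0.2 * 0.5 = 10%


rng = random.Random(0)
n = 100_000
counts = {"mask": 0, "random": 0, "keep": 0}
for _ in range(n):
    counts[choose_treatment(rng)] += 1

fractions = {name: count / n for name, count in counts.items()}
print(fractions)
```

The second draw uses 0.5 rather than 0.1 because it only applies to the 20% of selected tokens that did not become `[MASK]`.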
Dynamic vs. Static Masking¶
The original BERT used static masking: mask positions were fixed during preprocessing and reused every epoch. RoBERTa introduced dynamic masking: a fresh mask is sampled for each batch, so each sequence is seen with different masked positions across epochs. Dynamic masking is now standard and is what we implement here.
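The difference between the two schemes can be sketched without any model at all. In the toy example below (standard library only; the sequence length, seed, and helper names are illustrative assumptions), static masking reuses one sampled mask for every epoch, while dynamic masking draws a new one each time the sequence is served.

```python
import random


def sample_mask(length: int, prob: float, rng: random.Random) -> list[bool]:
    """Mark each position for prediction independently with probability `prob`."""
    return [rng.random() < prob for _ in range(length)]


def distinct_positions(masks: list[list[bool]]) -> int:
    """Count positions that are masked in at least one epoch."""
    return sum(any(column) for column in zip(*masks))


seq_len, mask_prob, n_epochs = 200, 0.15, 3

# Static masking: sample once during preprocessing, then reuse every epoch.
static_mask = sample_mask(seq_len, mask_prob, random.Random(0))
static_epochs = [static_mask for _ in range(n_epochs)]

# Dynamic masking: sample again each time the sequence is served.
rng = random.Random(0)
dynamic_epochs = [sample_mask(seq_len, mask_prob, rng) for _ in range(n_epochs)]

print("positions ever masked (static): ", distinct_positions(static_epochs))
print("positions ever masked (dynamic):", distinct_positions(dynamic_epochs))
```

Across epochs the dynamic variant covers more positions of the same sequence, which is the extra training signal dynamic masking is meant to provide.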
Pre-Training Datasets¶
BERT was originally trained on BookCorpus (~800M words) and English Wikipedia (~2,500M words). Publicly available alternatives:
| Dataset | Size | HuggingFace identifier |
|---|---|---|
| WikiText-2 | ~13 MB (~2M tokens) | wikitext, wikitext-2-raw-v1 |
| WikiText-103 | ~500 MB | wikitext, wikitext-103-raw-v1 |
| English Wikipedia | ~20 GB | wikimedia/wikipedia |
For this tutorial we use WikiText-103 — large enough to demonstrate learning, small enough to pre-train a compact model on a single GPU in minutes to hours.
The MLM Pre-Training Head¶
A lightweight head sits on top of the BERT encoder and is discarded after pre-training:
```text
sequence_output  (B, T, d_model)
    → Linear(d_model, d_model) + GELU + LayerNorm   # transform layer
    → Linear(d_model, vocab_size)                   # vocabulary projection
    → CrossEntropyLoss (only at masked positions)
```
Only the encoder weights are kept; the head is thrown away before fine-tuning.
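To make the head diagram concrete, the sketch below builds the two stages with plain `torch.nn` modules and counts their parameters. The sizes `d_model=256` and `vocab_size=30522` are assumed here to match the values used later in Exercise 2; the variable names are illustrative.

```python
import torch
import torch.nn as nn

d_model, vocab_size = 256, 30522  # assumed sizes, matching Exercise 2

# Transform layer: Linear + GELU + LayerNorm
transform = nn.Sequential(
    nn.Linear(d_model, d_model),
    nn.GELU(),
    nn.LayerNorm(d_model),
)
# Vocabulary projection: one logit per vocabulary entry
projection = nn.Linear(d_model, vocab_size)


def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


transform_params = count_params(transform)    # 256*256 + 256 + 2*256
projection_params = count_params(projection)  # 256*30522 + 30522

hidden = torch.randn(2, 5, d_model)           # (B, T, d_model) dummy encoder output
logits = projection(transform(hidden))        # (B, T, vocab_size)

print(f"transform:  {transform_params:,} parameters")
print(f"projection: {projection_params:,} parameters")
print(f"logits shape: {tuple(logits.shape)}")
```

With these sizes almost all of the head's parameters sit in the vocabulary projection, so discarding the head after pre-training removes a sizeable share of a compact model's weights.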
From Pre-Training to Embeddings¶
Once pre-trained, the BERT encoder produces contextual representations:
- Token embeddings: `sequence_output[:, i, :]` is the representation of token `i` in its full left and right context.
- Sentence embedding: `sequence_output[:, 0, :]` is the hidden state at the `[CLS]` position, which attends to the whole sequence and is used for sentence-level tasks.
The quality of these embeddings is measured by linear probing: freeze the
encoder, extract [CLS] vectors, and train a simple linear classifier on top.
Improvement over a randomly initialized encoder shows how much semantic content
the MLM objective has encoded.
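The logic of linear probing can be illustrated on synthetic features before touching a real encoder. In the sketch below, "structured" features are class centers plus noise and "unstructured" features are pure noise; both are stand-ins invented for this example, not outputs of any BERT model, and `probe_accuracy` is a hypothetical helper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def probe_accuracy(train_x, train_y, test_x, test_y, n_classes, steps=200):
    """Fit a single linear layer on fixed features; return test accuracy."""
    probe = nn.Linear(train_x.size(1), n_classes)
    optimizer = torch.optim.Adam(probe.parameters(), lr=0.05)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = F.cross_entropy(probe(train_x), train_y)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        return (probe(test_x).argmax(dim=-1) == test_y).float().mean().item()


n_classes, dim = 4, 16
centers = 3.0 * torch.randn(n_classes, dim)
train_y = torch.randint(0, n_classes, (400,))
test_y = torch.randint(0, n_classes, (200,))

# Features that encode the class (center + noise) vs. features that do not.
structured_train = centers[train_y] + torch.randn(400, dim)
structured_test = centers[test_y] + torch.randn(200, dim)
noise_train = torch.randn(400, dim)
noise_test = torch.randn(200, dim)

acc_structured = probe_accuracy(structured_train, train_y, structured_test, test_y, n_classes)
acc_noise = probe_accuracy(noise_train, train_y, noise_test, test_y, n_classes)

print(f"structured features: {acc_structured:.3f}")
print(f"noise features:      {acc_noise:.3f}")
```

Because the probe itself has so little capacity, any accuracy above chance has to come from information already present in the frozen features, which is what makes the probe a measure of representation quality.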
Exercise 1 — MLM Data Collator¶
Guided Exercise
Implement the 80/10/10 masking strategy as a PyTorch collate function that applies a fresh random mask to each batch.
```python
import torch
from transformers import BertTokenizerFast


def mlm_collate_fn(
    batch: list[dict],
    tokenizer: BertTokenizerFast,
    mask_probability: float = 0.15,
) -> dict[str, torch.Tensor]:
    """
    Collate pre-tokenized examples and apply dynamic MLM masking.

    Returns a dict with keys:
        input_ids      (B, T): corrupted token ids
        attention_mask (B, T): 1 for real tokens, 0 for padding
        labels         (B, T): original ids at selected positions, -100 elsewhere
    """
    # torch.stack copies, so the per-example tensors in `batch` stay unmodified
    input_ids = torch.stack([item["input_ids"] for item in batch])            # (B, T)
    attention_mask = torch.stack([item["attention_mask"] for item in batch])  # (B, T)
    labels = input_ids.clone()

    # Build a sampling probability matrix; zero out special tokens and padding
    probability_matrix = torch.full(input_ids.shape, mask_probability)
    special_tokens_mask = torch.tensor([
        tokenizer.get_special_tokens_mask(ids.tolist(), already_has_special_tokens=True)
        for ids in input_ids
    ], dtype=torch.bool)
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    probability_matrix.masked_fill_(attention_mask == 0, value=0.0)

    # Sample which positions to predict (~15% of non-special, non-padding tokens)
    selected = torch.bernoulli(probability_matrix).bool()

    # 80% of selected → replace with [MASK]
    replace_with_mask = selected & torch.bernoulli(
        torch.full(input_ids.shape, 0.8)
    ).bool()
    input_ids[replace_with_mask] = tokenizer.mask_token_id

    # 10% of selected (not already [MASK]) → replace with a random token.
    # Half of the remaining 20% gives 10% overall.
    replace_with_random = selected & ~replace_with_mask & torch.bernoulli(
        torch.full(input_ids.shape, 0.5)
    ).bool()
    random_tokens = torch.randint(
        low=0, high=len(tokenizer), size=input_ids.shape, dtype=torch.long
    )
    input_ids[replace_with_random] = random_tokens[replace_with_random]

    # 10% of selected → kept unchanged (no action required)

    # CrossEntropyLoss ignores positions where label == -100
    labels[~selected] = -100

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


# ── Sanity check ──────────────────────────────────────────────────────────────
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

sentence = "The transformer architecture revolutionized natural language processing."
encoded = tokenizer(sentence, return_tensors="pt",
                    padding="max_length", max_length=32, truncation=True)
batch = [{"input_ids": encoded["input_ids"][0],
          "attention_mask": encoded["attention_mask"][0]}]

result = mlm_collate_fn(batch, tokenizer)

orig_tokens = tokenizer.convert_ids_to_tokens(batch[0]["input_ids"].tolist())
masked_tokens = tokenizer.convert_ids_to_tokens(result["input_ids"][0].tolist())

print(f"{'Original':>20} {'Masked':>20} Label")
print("-" * 60)
for orig, masked, label in zip(
    orig_tokens, masked_tokens, result["labels"][0].tolist()
):
    marker = " ←" if label != -100 else ""
    label_str = str(label) if label != -100 else "(ignored)"
    print(f"{orig:>20} {masked:>20} {label_str}{marker}")
```
Run this cell and confirm that:

- Roughly 15% of the non-special, non-padding tokens are selected (marked with `←`). With a single short sentence the count is noisy, so rerun the cell a few times; some runs will select no tokens at all.
- Most selected tokens appear as `[MASK]`; occasionally you will see a random token or the original token left in place.
- Every non-selected position shows the label `(ignored)`.
Exercise 2 — Pre-Train a Small BERT on WikiText-103¶
Guided Exercise
Build a complete MLM pre-training loop for a compact 6-layer, 256-dimensional BERT encoder and train it on publicly available Wikipedia text.
This exercise uses the BertModel implemented in Chapter 6 and the
mlm_collate_fn from Exercise 1.
```python
import math
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
from datasets import load_dataset

# BertModel (Chapter 6) and mlm_collate_fn (Exercise 1) must already be defined
# or imported in this session.

# ── Hyperparameters ───────────────────────────────────────────────────────────
VOCAB_SIZE   = 30522   # bert-base-uncased vocabulary
D_MODEL      = 256
N_HEADS      = 8
N_LAYERS     = 6
MAX_LEN      = 128
BATCH_SIZE   = 64
LR           = 1e-3
WARMUP_STEPS = 500
MAX_STEPS    = 5_000
MASK_PROB    = 0.15

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
```
```python
# ── MLM Pre-Training Head ─────────────────────────────────────────────────────
class MLMHead(nn.Module):
    """Transform BERT hidden states into vocabulary logits."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.dense = nn.Linear(d_model, d_model)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(d_model, eps=1e-12)
        self.decoder = nn.Linear(d_model, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        self.decoder.bias = self.bias  # tie bias as a parameter

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (B, T, d_model)
        x = self.act(self.dense(hidden_states))
        x = self.norm(x)
        return self.decoder(x)  # (B, T, vocab_size)


class BertForMLM(nn.Module):
    """BertModel (Chapter 6) wrapped with an MLM head for pre-training."""

    def __init__(self, bert: "BertModel", vocab_size: int):
        super().__init__()
        self.bert = bert
        self.mlm_head = MLMHead(
            bert.embeddings.word_embeddings.embedding_dim, vocab_size
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        sequence_output, _ = self.bert(input_ids, attention_mask)
        logits = self.mlm_head(sequence_output)  # (B, T, vocab_size)

        loss = None
        if labels is not None:
            # Default ignore_index=-100 skips every non-selected position
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
            )
        return loss, logits
```
```python
# ── Instantiate model ─────────────────────────────────────────────────────────
bert = BertModel(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    max_len=MAX_LEN,
    dropout=0.1,
).to(device)

model = BertForMLM(bert, VOCAB_SIZE).to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")

# ── Dataset ───────────────────────────────────────────────────────────────────
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
wikitext = load_dataset("wikitext", "wikitext-103-raw-v1")


def tokenize(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_LEN,
        padding="max_length",
    )


def is_long_enough(example):
    # Drop blank lines and short section headings
    return len(example["text"].strip()) > 50


train_raw = wikitext["train"].filter(is_long_enough)
val_raw = wikitext["validation"].filter(is_long_enough)

train_ds = train_raw.map(tokenize, batched=True, remove_columns=["text"])
val_ds = val_raw.map(tokenize, batched=True, remove_columns=["text"])

train_ds.set_format("torch", columns=["input_ids", "attention_mask"])
val_ds.set_format("torch", columns=["input_ids", "attention_mask"])

collate = partial(mlm_collate_fn, tokenizer=tokenizer, mask_probability=MASK_PROB)

train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=collate, drop_last=True,
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False,
    collate_fn=collate,
)
print(f"Train examples: {len(train_ds):,} | Val examples: {len(val_ds):,}")
```
```python
# ── Optimizer and LR schedule ─────────────────────────────────────────────────
optimizer = torch.optim.AdamW(
    model.parameters(), lr=LR, betas=(0.9, 0.999), weight_decay=0.01
)


def get_lr_scale(step: int) -> float:
    """Linear warmup followed by cosine decay."""
    if step < WARMUP_STEPS:
        return step / max(1, WARMUP_STEPS)
    progress = (step - WARMUP_STEPS) / max(1, MAX_STEPS - WARMUP_STEPS)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# ── Training loop ─────────────────────────────────────────────────────────────
global_step = 0

for epoch in range(100):  # breaks early when MAX_STEPS is reached
    model.train()
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Adjust learning rate
        lr_scale = get_lr_scale(global_step)
        for pg in optimizer.param_groups:
            pg["lr"] = LR * lr_scale

        loss, _ = model(input_ids, attention_mask, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        global_step += 1

        if global_step % 500 == 0:
            # Compute validation loss over up to 50 batches.
            # The collator re-samples masks, so this estimate is slightly noisy.
            model.eval()
            val_loss, val_steps = 0.0, 0
            with torch.no_grad():
                for vbatch in val_loader:
                    vloss, _ = model(
                        vbatch["input_ids"].to(device),
                        vbatch["attention_mask"].to(device),
                        vbatch["labels"].to(device),
                    )
                    val_loss += vloss.item()
                    val_steps += 1
                    if val_steps >= 50:
                        break
            val_loss /= val_steps
            print(
                f"step={global_step:5d} "
                f"train_loss={loss.item():.4f} "
                f"val_loss={val_loss:.4f} "
                f"val_ppl={math.exp(val_loss):.1f} "
                f"lr={LR * lr_scale:.2e}"
            )
            model.train()

        if global_step >= MAX_STEPS:
            break
    if global_step >= MAX_STEPS:
        break

print("Pre-training complete. Saving encoder weights...")
torch.save(bert.state_dict(), "bert_pretrained.pt")
```
Expected learning curve:
| Step | Val MLM perplexity |
|---|---|
| 500 | ~200–500 |
| 1000 | ~50–100 |
| 2500 | ~30–50 |
| 5000 | ~20–40 |
A randomly initialized encoder starts near the full vocabulary size (~30,000) and rapidly learns token co-occurrence patterns in the first few hundred steps. After 5,000 steps on WikiText-103 the model reaches roughly 20–40 perplexity on the MLM task — confirming it has learned to exploit bidirectional context to reconstruct masked tokens.
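The link between loss and perplexity in the paragraph above can be checked with a few lines of arithmetic. This sketch uses only the standard library and assumes the 30,522-token `bert-base-uncased` vocabulary from Exercise 2; the variable names are illustrative.

```python
import math

vocab_size = 30522  # bert-base-uncased vocabulary size

# A model that assigns probability 1/V to every token has cross-entropy ln(V),
# so its perplexity exp(loss) equals the vocabulary size.
uniform_loss = -math.log(1.0 / vocab_size)
uniform_ppl = math.exp(uniform_loss)

# Convert the 20-40 perplexity range back to the loss values the loop prints.
loss_at_ppl_20 = math.log(20)
loss_at_ppl_40 = math.log(40)

print(f"uniform loss = {uniform_loss:.3f} nats, perplexity = {uniform_ppl:.0f}")
print(f"perplexity 20 <-> loss {loss_at_ppl_20:.2f}")
print(f"perplexity 40 <-> loss {loss_at_ppl_40:.2f}")
```

Reading the training log this way, a first reported `val_loss` near 10.3 indicates the model is still guessing uniformly, while values around 3.0 to 3.7 correspond to the 20–40 perplexity range.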
Exercise 3 — Evaluate Embeddings via Linear Probing¶
Guided Exercise
Freeze the pre-trained encoder, extract [CLS] embeddings, and train a
linear classifier on AG News (4-class topic classification). Compare against
a randomly initialized BERT to quantify how much pre-training helped.
```python
from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset
import torch
import torch.nn as nn

# Reuses BertModel, tokenizer, device, and the hyperparameters from Exercise 2.


# ── Helpers ───────────────────────────────────────────────────────────────────
def extract_cls_embeddings(
    bert_model: "BertModel",
    data_loader: DataLoader,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return CLS embeddings and labels stacked into tensors."""
    bert_model.eval()
    all_emb, all_lbl = [], []
    with torch.no_grad():
        for batch in data_loader:
            seq_out, _ = bert_model(
                batch["input_ids"].to(device),
                batch["attention_mask"].to(device),
            )
            all_emb.append(seq_out[:, 0, :].cpu())  # [CLS] = position 0
            all_lbl.append(batch["label"].cpu())
    return torch.cat(all_emb), torch.cat(all_lbl)


def linear_probe(
    train_emb: torch.Tensor,
    train_labels: torch.Tensor,
    test_emb: torch.Tensor,
    test_labels: torch.Tensor,
    d_model: int,
    n_classes: int = 4,
    n_epochs: int = 20,
) -> float:
    """Train a linear head on frozen embeddings; return test accuracy."""
    probe = nn.Linear(d_model, n_classes)
    optimizer = torch.optim.AdamW(probe.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    train_dl = DataLoader(
        TensorDataset(train_emb, train_labels), batch_size=64, shuffle=True
    )
    for _ in range(n_epochs):
        probe.train()
        for emb_batch, lbl_batch in train_dl:
            loss = criterion(probe(emb_batch), lbl_batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    probe.eval()
    with torch.no_grad():
        preds = probe(test_emb).argmax(dim=-1)
    return (preds == test_labels).float().mean().item()
```
```python
# ── AG News dataset ───────────────────────────────────────────────────────────
ag_news = load_dataset("ag_news")


def encode_texts(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_LEN,
        padding="max_length",
    )


ag_train_ds = ag_news["train"].select(range(2_000)).map(encode_texts, batched=True)
ag_test_ds = ag_news["test"].select(range(500)).map(encode_texts, batched=True)

for ds in (ag_train_ds, ag_test_ds):
    ds.set_format("torch", columns=["input_ids", "attention_mask", "label"])

ag_train_loader = DataLoader(ag_train_ds, batch_size=64)
ag_test_loader = DataLoader(ag_test_ds, batch_size=64)

# ── Compare pre-trained vs. random initialization ─────────────────────────────
results = {}

for name, checkpoint in [("random_init", None), ("pretrained", "bert_pretrained.pt")]:
    eval_bert = BertModel(
        vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        max_len=MAX_LEN,
    ).to(device)

    if checkpoint is not None:
        eval_bert.load_state_dict(torch.load(checkpoint, map_location=device))

    train_emb, train_lbl = extract_cls_embeddings(eval_bert, ag_train_loader, device)
    test_emb, test_lbl = extract_cls_embeddings(eval_bert, ag_test_loader, device)

    acc = linear_probe(train_emb, train_lbl, test_emb, test_lbl, D_MODEL)
    results[name] = acc
    print(f"{name:20s} accuracy: {acc:.3f}")

improvement = results["pretrained"] - results["random_init"]
print(f"\nPre-training gain: {improvement:+.3f} ({improvement * 100:+.1f} pp)")
```
Interpreting the results:
A randomly initialized BERT encoder is not entirely useless: its structured
computation (LayerNorm, residual connections, random projections of the token
embeddings) still produces non-trivial embeddings, typically giving ~30–40%
linear-probe accuracy on AG News against a 25% chance level for four balanced
classes. After pre-training for 5,000 steps the same architecture should reach
60–75%, demonstrating that the MLM objective has encoded meaningful semantic and
topical content into the `[CLS]` vector.
For reference, bert-base-uncased pre-trained on the full Wikipedia +
BookCorpus achieves ~92% after fine-tuning, illustrating how much larger-scale
pre-training helps. That figure comes from full fine-tuning rather than linear
probing, so treat it as an upper reference point, not a like-for-like comparison.
Visualize the embedding space¶
Plot a 2-D projection of the `[CLS]` embeddings to confirm that pre-training
produces a more structured embedding space. The code below uses PCA from
scikit-learn; UMAP or t-SNE can be substituted for the projection step:
```python
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


# Simple PCA projection (no extra library required)
def plot_pca_embeddings(embeddings, labels, title):
    pca = PCA(n_components=2)
    proj = pca.fit_transform(embeddings.numpy())
    labels = labels.numpy()  # boolean masks below index a NumPy array
    colors = ["tab:blue", "tab:orange", "tab:green", "tab:red"]
    class_names = ["World", "Sports", "Business", "Sci/Tech"]
    fig, ax = plt.subplots(figsize=(8, 6))
    for cls_idx in range(4):
        mask = labels == cls_idx
        ax.scatter(proj[mask, 0], proj[mask, 1],
                   c=colors[cls_idx], label=class_names[cls_idx],
                   alpha=0.5, s=10)
    ax.set_title(title)
    ax.legend()
    plt.tight_layout()
    plt.savefig(f"{title.lower().replace(' ', '_')}.png", dpi=100)
    plt.show()


# After the comparison loop, `test_emb` holds the embeddings of the last model
# evaluated (the pre-trained one), so compute the random-init embeddings here.
random_bert = BertModel(
    vocab_size=VOCAB_SIZE, d_model=D_MODEL, n_heads=N_HEADS,
    n_layers=N_LAYERS, max_len=MAX_LEN,
).to(device)
test_emb_random, _ = extract_cls_embeddings(random_bert, ag_test_loader, device)

plot_pca_embeddings(test_emb_random, test_lbl, "Random BERT - PCA")
plot_pca_embeddings(test_emb, test_lbl, "Pre-trained BERT - PCA")
```
The pre-trained embeddings should show clear topic clusters in PCA space; random embeddings will appear as an undifferentiated blob.
Summary¶
- BERT's MLM objective trains the model to reconstruct the roughly 15% of tokens selected for prediction from their left and right context, yielding rich bidirectional representations.
- The 80/10/10 masking strategy (`[MASK]`/random/unchanged) prevents the model from exploiting `[MASK]` as a shortcut and improves transfer to fine-tuning, where `[MASK]` never appears.
- Dynamic masking samples a fresh mask every batch, giving each sequence multiple masking views across epochs; it is now the standard approach.
- The MLM head (a dense transform layer followed by a vocabulary projection) is attached only during pre-training and discarded afterwards.
- Linear probing on the `[CLS]` embedding measures representation quality without fine-tuning: even a small encoder trained for 5,000 steps on WikiText-103 is expected to clearly outperform a random baseline.