Chapter 13: BERT Pre-Training — Masked Language Modeling from Scratch

Learning Outcome

Build a complete BERT pre-training pipeline: implement the Masked Language Modeling (MLM) data collator with the canonical 80/10/10 masking strategy, train a compact BERT-style encoder on publicly available Wikipedia text, and probe the quality of the learned embeddings on a downstream classification task.


Concepts

The MLM Objective

BERT's key pre-training innovation is Masked Language Modeling: randomly mask a fraction of input tokens and train the model to reconstruct the originals from the surrounding context.

The loss is computed only at the masked positions:

\[ \mathcal{L}_\text{MLM} = -\frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \log P(x_i \mid \tilde{x}) \]

where \(\mathcal{M}\) is the set of masked positions and \(\tilde{x}\) is the corrupted input sequence. Every other position contributes zero loss, so the gradient signal comes exclusively from predicting masked tokens.
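As a small worked example of the formula above, the sketch below computes the loss from the probability a model assigns to the correct token at each position. The probabilities and masked positions are made-up values chosen for illustration, and `mlm_loss` is a hypothetical helper, not a library function.

```python
import math


def mlm_loss(true_token_probs, masked_positions):
    """Mean negative log-likelihood over the masked positions only."""
    if not masked_positions:
        raise ValueError("at least one masked position is required")
    total = -sum(math.log(true_token_probs[i]) for i in masked_positions)
    return total / len(masked_positions)


# Hypothetical probabilities assigned to the correct token at 6 positions.
probs = [0.9, 0.5, 0.8, 0.25, 0.7, 0.6]
masked = [1, 3]  # only these positions were selected for prediction

loss = mlm_loss(probs, masked)

# Changing the probabilities at unmasked positions leaves the loss untouched.
probs_changed = [0.1, 0.5, 0.1, 0.25, 0.1, 0.1]
loss_changed = mlm_loss(probs_changed, masked)

print(f"MLM loss: {loss:.4f}")
print(f"Loss after changing unmasked positions: {loss_changed:.4f}")
```

In PyTorch the same effect is usually obtained by setting the label to -100 at unselected positions, which CrossEntropyLoss skips by default.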

The 80/10/10 Masking Strategy

Of the 15% of tokens selected for prediction, not all are replaced with [MASK]:

| Treatment | Fraction | Rationale |
| --- | --- | --- |
| Replace with [MASK] | 80% | Primary masking signal |
| Replace with a random token | 10% | Forces real-token representations at all positions |
| Keep the original token | 10% | Forces prediction even without a visible mask |

If every selected token were replaced with [MASK], the model would see [MASK] at every prediction position during pre-training but never at fine-tuning time, which is a distribution mismatch. With the two 10% cases mixed in, the model cannot tell from the input alone which positions it will be asked to predict, so it must maintain a meaningful hidden state for every token, whether or not it appears masked.
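One common way to get the 80/10/10 split is two-stage sampling: draw once with probability 0.8 for [MASK], then split the remaining 20% in half. The simulation below is a minimal sketch of that arithmetic using only the standard library; `sample_treatment` is a hypothetical name used for illustration.

```python
import random
from collections import Counter


def sample_treatment(rng: random.Random) -> str:
    """Choose the treatment for one token already selected for prediction."""
    if rng.random() < 0.8:      # 80%: replace with [MASK]
        return "mask"
    if rng.random() < 0.5:      # half of the remaining 20%: random token
        return "random"
    return "keep"               # other half of the remaining 20%: unchanged


rng = random.Random(0)          # fixed seed so the run is repeatable
n_draws = 100_000
counts = Counter(sample_treatment(rng) for _ in range(n_draws))
fractions = {name: counts[name] / n_draws for name in ("mask", "random", "keep")}

for name, fraction in fractions.items():
    print(f"{name:>6}: {fraction:.3f}")
```

The second draw uses 0.5 rather than 0.1 because it applies only to the 20% of tokens that were not replaced with [MASK]: 0.2 × 0.5 = 0.1.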

Dynamic vs. Static Masking

The original BERT used static masking: mask positions were fixed during preprocessing and reused every epoch. RoBERTa introduced dynamic masking: a fresh mask is sampled for each batch, so each sequence is seen with different masked positions across epochs. Dynamic masking is now standard and is what we implement here.
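The difference between the two schemes can be sketched without any model at all. The toy example below (standard library only; `sample_mask` and the constants are illustrative assumptions) samples mask positions for one 200-token sequence over five epochs under each scheme and counts how many distinct masks the sequence is seen with.

```python
import random


def sample_mask(num_tokens, mask_probability, rng):
    """Return the positions selected for prediction, in ascending order."""
    return [i for i in range(num_tokens) if rng.random() < mask_probability]


NUM_TOKENS, NUM_EPOCHS = 200, 5

# Static masking: sample once during preprocessing, reuse every epoch.
static_mask = sample_mask(NUM_TOKENS, 0.15, random.Random(0))
static_views = [static_mask for _ in range(NUM_EPOCHS)]

# Dynamic masking: sample a fresh mask each time the sequence is batched.
rng = random.Random(0)
dynamic_views = [sample_mask(NUM_TOKENS, 0.15, rng) for _ in range(NUM_EPOCHS)]

distinct_static = len({tuple(mask) for mask in static_views})
distinct_dynamic = len({tuple(mask) for mask in dynamic_views})

print(f"distinct static masks:  {distinct_static}")
print(f"distinct dynamic masks: {distinct_dynamic}")
```

In a PyTorch pipeline, dynamic masking typically falls out of doing the sampling inside the collate function, since that function runs every time a batch is assembled.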

Pre-Training Datasets

BERT was originally trained on BooksCorpus (~800M words) and English Wikipedia (~2,500M words). Publicly available alternatives:

| Dataset | Size | HuggingFace identifier |
| --- | --- | --- |
| WikiText-2 | ~2M tokens | wikitext, wikitext-2-raw-v1 |
| WikiText-103 | ~103M tokens (~500 MB) | wikitext, wikitext-103-raw-v1 |
| English Wikipedia | ~20 GB | wikimedia/wikipedia |

For this tutorial we use WikiText-103 — large enough to demonstrate learning, small enough to pre-train a compact model on a single GPU in minutes to hours.

The MLM Pre-Training Head

A lightweight head sits on top of the BERT encoder and is discarded after pre-training:

sequence_output  (B, T, d_model)
  → Linear(d_model, d_model) + GELU + LayerNorm   # transform layer
  → Linear(d_model, vocab_size)                    # vocabulary projection
  → CrossEntropyLoss (only at masked positions)

Only the encoder weights are kept; the head is thrown away before fine-tuning.
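To get a feel for how the head's parameters are distributed, the sketch below counts them for the layer layout shown above, using d_model = 256 and the 30,522-token bert-base-uncased vocabulary from Exercise 2. It assumes the vocabulary projection has its own weights and bias and is not tied to the input embedding matrix (BERT's reference implementation does tie them); the helper names are hypothetical.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameter count of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


def mlm_head_params(d_model: int, vocab_size: int) -> dict[str, int]:
    """Parameter counts per component, assuming an untied projection."""
    return {
        "transform": linear_params(d_model, d_model),      # Linear(d_model, d_model)
        "layer_norm": 2 * d_model,                          # scale and shift vectors
        "projection": linear_params(d_model, vocab_size),  # Linear(d_model, vocab_size)
    }


counts = mlm_head_params(d_model=256, vocab_size=30522)
total = sum(counts.values())

for name, n in counts.items():
    print(f"{name:>10}: {n:>10,}  ({n / total:.1%})")
print(f"{'total':>10}: {total:>10,}")
```

Nearly all of the head's parameters sit in the vocabulary projection, which is one reason the head is cheap to discard but expensive to carry around when it is not needed.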

From Pre-Training to Embeddings

Once pre-trained, the BERT encoder produces contextual representations:

  • Token embeddings: sequence_output[:, i, :] is the representation of token i, conditioned on both its left and right context.
  • Sentence embedding: sequence_output[:, 0, :] is the hidden state at the [CLS] position, which attends to the whole sequence and is used for sentence-level tasks.

The quality of these embeddings is measured by linear probing: freeze the encoder, extract [CLS] vectors, and train a simple linear classifier on top. Improvement over a randomly initialized encoder shows how much semantic content the MLM objective has encoded.


Exercise 1 — MLM Data Collator

Guided Exercise

Implement the 80/10/10 masking strategy as a PyTorch collate function that applies a fresh random mask to each batch.

import torch
from transformers import BertTokenizerFast


def mlm_collate_fn(
    batch: list[dict],
    tokenizer: BertTokenizerFast,
    mask_probability: float = 0.15,
) -> dict[str, torch.Tensor]:
    """
    Collate pre-tokenized examples and apply dynamic MLM masking.

    Returns a dict with keys:
        input_ids      (B, T) — corrupted token ids
        attention_mask (B, T) — 1 for real tokens, 0 for padding
        labels         (B, T) — original ids at selected positions, -100 elsewhere
    """
    input_ids      = torch.stack([item["input_ids"]      for item in batch])  # (B, T)
    attention_mask = torch.stack([item["attention_mask"] for item in batch])  # (B, T)

    labels = input_ids.clone()

    # Build a sampling probability matrix; zero out special tokens and padding
    probability_matrix = torch.full(input_ids.shape, mask_probability)

    special_tokens_mask = torch.tensor([
        tokenizer.get_special_tokens_mask(ids.tolist(), already_has_special_tokens=True)
        for ids in input_ids
    ], dtype=torch.bool)
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    probability_matrix.masked_fill_(attention_mask == 0, value=0.0)

    # Sample which positions to predict (~15% of non-special, non-padding tokens)
    selected = torch.bernoulli(probability_matrix).bool()

    # 80% of selected → replace with [MASK]
    replace_with_mask = selected & torch.bernoulli(
        torch.full(input_ids.shape, 0.8)
    ).bool()
    input_ids[replace_with_mask] = tokenizer.mask_token_id

    # 10% of selected (not already [MASK]) → replace with a random token
    replace_with_random = selected & ~replace_with_mask & torch.bernoulli(
        torch.full(input_ids.shape, 0.5)
    ).bool()
    random_tokens = torch.randint(
        low=0, high=len(tokenizer), size=input_ids.shape, dtype=torch.long
    )
    input_ids[replace_with_random] = random_tokens[replace_with_random]

    # 10% of selected → kept unchanged (no action required)

    # CrossEntropyLoss ignores positions where label == -100
    labels[~selected] = -100

    return {
        "input_ids":      input_ids,
        "attention_mask": attention_mask,
        "labels":         labels,
    }


# ── Sanity check ──────────────────────────────────────────────────────────────

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

sentence = "The transformer architecture revolutionized natural language processing."
encoded = tokenizer(sentence, return_tensors="pt",
                    padding="max_length", max_length=32, truncation=True)

batch = [{"input_ids":      encoded["input_ids"][0],
          "attention_mask": encoded["attention_mask"][0]}]

result = mlm_collate_fn(batch, tokenizer)

orig_tokens   = tokenizer.convert_ids_to_tokens(batch[0]["input_ids"].tolist())
masked_tokens = tokenizer.convert_ids_to_tokens(result["input_ids"][0].tolist())

print(f"{'Original':>20}  {'Masked':>20}  Label")
print("-" * 60)
for orig, masked, label in zip(
    orig_tokens, masked_tokens, result["labels"][0].tolist()
):
    marker = " ←" if label != -100 else ""
    label_str = str(label) if label != -100 else "(ignored)"
    print(f"{orig:>20}  {masked:>20}  {label_str}{marker}")

Run this cell and confirm that:

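  1. Approximately 15% of the non-padding, non-special tokens are selected (marked with ←). For a single short sentence that is only one or two tokens, and some runs select none, so re-run the cell a few times.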
  2. Most selected tokens appear as [MASK]; occasionally you will see a random token or the original token left in place.
  3. Every non-selected position shows label (ignored).
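Eyeballing one sentence only goes so far, so it can help to count the treatments programmatically. The helper below is a minimal sketch that works on plain Python lists (for example the result of `.tolist()` on the tensors above); `masking_stats` is a hypothetical name, and the toy ids are invented for illustration, with 103 standing in for the [MASK] id.

```python
def masking_stats(original_ids, corrupted_ids, labels, mask_token_id, ignore_index=-100):
    """Count how selected positions were treated and check unselected ones."""
    stats = {"selected": 0, "mask": 0, "random": 0, "keep": 0, "unselected_changed": 0}
    for orig, new, label in zip(original_ids, corrupted_ids, labels):
        if label == ignore_index:
            # Unselected positions must pass through unchanged.
            stats["unselected_changed"] += int(new != orig)
            continue
        stats["selected"] += 1
        if new == mask_token_id:
            stats["mask"] += 1
        elif new == orig:
            stats["keep"] += 1
        else:
            stats["random"] += 1
    return stats


# Toy example with made-up token ids.
original  = [101, 2000, 2001, 2002, 2003, 2004, 102]
corrupted = [101,  103, 2001, 7592, 2003, 2004, 102]
labels    = [-100, 2000, -100, 2002, 2003, -100, -100]

stats = masking_stats(original, corrupted, labels, mask_token_id=103)
print(stats)
```

With the sanity-check variables above, a call might look like masking_stats(batch[0]["input_ids"].tolist(), result["input_ids"][0].tolist(), result["labels"][0].tolist(), tokenizer.mask_token_id). A random replacement that happens to equal the original token is counted as "keep", which slightly inflates that bucket but is rare with a 30,000-token vocabulary.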

Exercise 2 — Pre-Train a Small BERT on WikiText-103

Guided Exercise

Build a complete MLM pre-training loop for a compact 6-layer, 256-dimensional BERT encoder and train it on publicly available Wikipedia text.

This exercise uses the BertModel implemented in Chapter 6 and the mlm_collate_fn from Exercise 1.

import math
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
from datasets import load_dataset


# ── Hyperparameters ───────────────────────────────────────────────────────────

VOCAB_SIZE   = 30522   # bert-base-uncased vocabulary
D_MODEL      = 256
N_HEADS      = 8
N_LAYERS     = 6
MAX_LEN      = 128
BATCH_SIZE   = 64
LR           = 1e-3
WARMUP_STEPS = 500
MAX_STEPS    = 5_000
MASK_PROB    = 0.15

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# ── MLM Pre-Training Head ─────────────────────────────────────────────────────

class MLMHead(nn.Module):
    """Transform BERT hidden states into vocabulary logits."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.dense   = nn.Linear(d_model, d_model)
        self.act     = nn.GELU()
        self.norm    = nn.LayerNorm(d_model, eps=1e-12)
        self.decoder = nn.Linear(d_model, vocab_size, bias=False)
        self.bias    = nn.Parameter(torch.zeros(vocab_size))
        self.decoder.bias = self.bias  # one shared parameter, reachable as self.bias and decoder.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (B, T, d_model)
        x = self.act(self.dense(hidden_states))
        x = self.norm(x)
        return self.decoder(x)          # (B, T, vocab_size)


class BertForMLM(nn.Module):
    """BertModel (Chapter 6) wrapped with an MLM head for pre-training."""

    def __init__(self, bert: "BertModel", vocab_size: int):
        super().__init__()
        self.bert     = bert
        self.mlm_head = MLMHead(
            bert.embeddings.word_embeddings.embedding_dim, vocab_size
        )

    def forward(
        self,
        input_ids:      torch.Tensor,
        attention_mask: torch.Tensor,
        labels:         torch.Tensor | None = None,
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        sequence_output, _ = self.bert(input_ids, attention_mask)
        logits = self.mlm_head(sequence_output)     # (B, T, vocab_size)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
            )
        return loss, logits


# ── Instantiate model ─────────────────────────────────────────────────────────

bert = BertModel(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    max_len=MAX_LEN,
    dropout=0.1,
).to(device)

model = BertForMLM(bert, VOCAB_SIZE).to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")


# ── Dataset ───────────────────────────────────────────────────────────────────

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

wikitext = load_dataset("wikitext", "wikitext-103-raw-v1")

def tokenize(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_LEN,
        padding="max_length",
    )

def is_long_enough(example):
    return len(example["text"].strip()) > 50

train_raw = wikitext["train"].filter(is_long_enough)
val_raw   = wikitext["validation"].filter(is_long_enough)

train_ds = train_raw.map(tokenize, batched=True, remove_columns=["text"])
val_ds   = val_raw.map(tokenize,   batched=True, remove_columns=["text"])

train_ds.set_format("torch", columns=["input_ids", "attention_mask"])
val_ds.set_format("torch",   columns=["input_ids", "attention_mask"])

collate = partial(mlm_collate_fn, tokenizer=tokenizer, mask_probability=MASK_PROB)

train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=collate, drop_last=True,
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False,
    collate_fn=collate,
)

print(f"Train examples: {len(train_ds):,}  |  Val examples: {len(val_ds):,}")


# ── Optimizer and LR schedule ─────────────────────────────────────────────────

optimizer = torch.optim.AdamW(
    model.parameters(), lr=LR, betas=(0.9, 0.999), weight_decay=0.01
)


def get_lr_scale(step: int) -> float:
    """Linear warmup followed by cosine decay."""
    if step < WARMUP_STEPS:
        return step / max(1, WARMUP_STEPS)
    progress = (step - WARMUP_STEPS) / max(1, MAX_STEPS - WARMUP_STEPS)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# ── Training loop ─────────────────────────────────────────────────────────────

global_step = 0

for epoch in range(100):           # breaks early when MAX_STEPS is reached
    model.train()

    for batch in train_loader:
        input_ids      = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels         = batch["labels"].to(device)

        # Adjust learning rate
        lr_scale = get_lr_scale(global_step)
        for pg in optimizer.param_groups:
            pg["lr"] = LR * lr_scale

        loss, _ = model(input_ids, attention_mask, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        global_step += 1

        if global_step % 500 == 0:
            # Compute validation loss over up to 50 batches
            model.eval()
            val_loss, val_steps = 0.0, 0
            with torch.no_grad():
                for vbatch in val_loader:
                    vloss, _ = model(
                        vbatch["input_ids"].to(device),
                        vbatch["attention_mask"].to(device),
                        vbatch["labels"].to(device),
                    )
                    val_loss  += vloss.item()
                    val_steps += 1
                    if val_steps >= 50:
                        break
            val_loss /= val_steps

            print(
                f"step={global_step:5d}  "
                f"train_loss={loss.item():.4f}  "
                f"val_loss={val_loss:.4f}  "
                f"val_ppl={math.exp(val_loss):.1f}  "
                f"lr={LR * lr_scale:.2e}"
            )
            model.train()

        if global_step >= MAX_STEPS:
            break
    if global_step >= MAX_STEPS:
        break

print("Pre-training complete. Saving encoder weights...")
torch.save(bert.state_dict(), "bert_pretrained.pt")
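To see what the warmup and cosine schedule in the script does, the stand-alone sketch below re-implements the same formula with the warmup and step counts passed as arguments, and evaluates it at a few checkpoints. The function name `lr_scale_at` is an illustrative assumption; the defaults mirror WARMUP_STEPS = 500 and MAX_STEPS = 5,000 from the hyperparameters above.

```python
import math


def lr_scale_at(step: int, warmup_steps: int = 500, max_steps: int = 5_000) -> float:
    """Linear warmup to 1.0, then cosine decay to 0.0 at max_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


BASE_LR = 1e-3
checkpoints = [0, 250, 500, 2_750, 5_000]
scales = {step: lr_scale_at(step) for step in checkpoints}

for step, scale in scales.items():
    print(f"step={step:5d}  scale={scale:.3f}  lr={BASE_LR * scale:.2e}")
```

The scale is exactly zero at step 0, so the very first optimizer update in the loop above changes nothing; evaluating the schedule at step + 1 instead would avoid that wasted step.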

Expected learning curve:

| Step | Val MLM perplexity |
| --- | --- |
| 500 | ~200–500 |
| 1000 | ~50–100 |
| 2500 | ~30–50 |
| 5000 | ~20–40 |

A randomly initialized encoder starts at a perplexity near the full vocabulary size (~30,000), because its predictions are close to uniform, and rapidly learns token co-occurrence patterns in the first few hundred steps. After 5,000 steps on WikiText-103 the model should reach roughly 20–40 perplexity on the MLM task, which indicates that it has learned to exploit bidirectional context to reconstruct masked tokens.
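Because the training script prints both val_loss and val_ppl, it helps to be able to convert between them. The sketch below shows the arithmetic: a model that predicts uniformly over the vocabulary has loss ln(V) and perplexity V, and any perplexity in the table corresponds to a loss of ln(perplexity). The helper name `ppl_to_loss` is hypothetical.

```python
import math

VOCAB_SIZE = 30522  # bert-base-uncased vocabulary size used in this chapter

# A uniform prediction assigns probability 1/V to the correct token.
uniform_loss = -math.log(1.0 / VOCAB_SIZE)
uniform_ppl = math.exp(uniform_loss)


def ppl_to_loss(perplexity: float) -> float:
    """Cross-entropy (in nats) that corresponds to a given perplexity."""
    return math.log(perplexity)


print(f"uniform baseline: loss={uniform_loss:.3f}  ppl={uniform_ppl:,.0f}")
for ppl in (500, 100, 40, 20):
    print(f"ppl={ppl:>4}  loss={ppl_to_loss(ppl):.3f}")
```

If the first logged losses are far above about 10.3, the loss is probably being computed over the wrong positions or the logits are badly scaled at initialization.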


Exercise 3 — Evaluate Embeddings via Linear Probing

Guided Exercise

Freeze the pre-trained encoder, extract [CLS] embeddings, and train a linear classifier on AG News (4-class topic classification). Compare against a randomly initialized BERT to quantify how much pre-training helped.

from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset
import torch
import torch.nn as nn


# ── Helpers ───────────────────────────────────────────────────────────────────

def extract_cls_embeddings(
    bert_model: "BertModel",
    data_loader: DataLoader,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return CLS embeddings and labels stacked into tensors."""
    bert_model.eval()
    all_emb, all_lbl = [], []
    with torch.no_grad():
        for batch in data_loader:
            seq_out, _ = bert_model(
                batch["input_ids"].to(device),
                batch["attention_mask"].to(device),
            )
            all_emb.append(seq_out[:, 0, :].cpu())   # [CLS] = position 0
            all_lbl.append(batch["label"].cpu())
    return torch.cat(all_emb), torch.cat(all_lbl)


def linear_probe(
    train_emb:    torch.Tensor,
    train_labels: torch.Tensor,
    test_emb:     torch.Tensor,
    test_labels:  torch.Tensor,
    d_model:      int,
    n_classes:    int = 4,
    n_epochs:     int = 20,
) -> float:
    """Train a linear head on frozen embeddings; return test accuracy."""
    probe     = nn.Linear(d_model, n_classes)
    optimizer = torch.optim.AdamW(probe.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    train_dl = DataLoader(
        TensorDataset(train_emb, train_labels), batch_size=64, shuffle=True
    )

    for _ in range(n_epochs):
        probe.train()
        for emb_batch, lbl_batch in train_dl:
            loss = criterion(probe(emb_batch), lbl_batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    probe.eval()
    with torch.no_grad():
        preds = probe(test_emb).argmax(dim=-1)
    return (preds == test_labels).float().mean().item()


# ── AG News dataset ───────────────────────────────────────────────────────────

ag_news = load_dataset("ag_news")

def encode_texts(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_LEN,
        padding="max_length",
    )

ag_train_ds = ag_news["train"].select(range(2_000)).map(encode_texts, batched=True)
ag_test_ds  = ag_news["test"].select(range(500)).map(encode_texts, batched=True)

for ds in (ag_train_ds, ag_test_ds):
    ds.set_format("torch", columns=["input_ids", "attention_mask", "label"])

ag_train_loader = DataLoader(ag_train_ds, batch_size=64)
ag_test_loader  = DataLoader(ag_test_ds,  batch_size=64)


# ── Compare pre-trained vs. random initialization ─────────────────────────────

results = {}

for name, checkpoint in [("random_init", None), ("pretrained", "bert_pretrained.pt")]:
    eval_bert = BertModel(
        vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        max_len=MAX_LEN,
    ).to(device)

    if checkpoint is not None:
        eval_bert.load_state_dict(torch.load(checkpoint, map_location=device))

    train_emb, train_lbl = extract_cls_embeddings(eval_bert, ag_train_loader, device)
    test_emb,  test_lbl  = extract_cls_embeddings(eval_bert, ag_test_loader,  device)

    acc = linear_probe(train_emb, train_lbl, test_emb, test_lbl, D_MODEL)
    results[name] = acc
    print(f"{name:20s}  accuracy: {acc:.3f}")

improvement = results["pretrained"] - results["random_init"]
print(f"\nPre-training gain: +{improvement:.3f} ({improvement * 100:.1f} pp)")

Interpreting the results:

A randomly initialized BERT encoder is not entirely useless: its structured computation (LayerNorm, residual connections, random projections of the token embeddings) means it produces some non-trivial embeddings, typically giving ~30–40% linear-probe accuracy on AG News, compared with 25% for chance on four balanced classes. After pre-training for 5,000 steps the same architecture should reach 60–75%, which suggests that the MLM objective has encoded meaningful semantic and topical content into the [CLS] vector.

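For reference, bert-base-uncased, pre-trained on the full English Wikipedia and BooksCorpus, reaches about 94% on AG News after full fine-tuning. That figure is not directly comparable to a linear probe trained on 2,000 examples, but it illustrates how much larger-scale pre-training helps.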
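With only 500 test examples, the measured accuracies carry noticeable sampling noise. The sketch below estimates a rough 95% interval using the normal approximation to the binomial; the 0.35 and 0.70 inputs are hypothetical values picked from the ranges quoted above, not measured results, and `accuracy_interval` is an illustrative helper.

```python
import math


def accuracy_interval(accuracy: float, n_examples: int, z: float = 1.96):
    """Approximate confidence interval for an accuracy measured on n examples."""
    standard_error = math.sqrt(accuracy * (1.0 - accuracy) / n_examples)
    return accuracy - z * standard_error, accuracy + z * standard_error


N_TEST = 500  # size of the AG News test subset used in this exercise

low_random, high_random = accuracy_interval(0.35, N_TEST)          # hypothetical
low_pretrained, high_pretrained = accuracy_interval(0.70, N_TEST)  # hypothetical

print(f"random init: 0.35  (approx. 95% interval {low_random:.3f} to {high_random:.3f})")
print(f"pretrained:  0.70  (approx. 95% interval {low_pretrained:.3f} to {high_pretrained:.3f})")
```

Intervals of roughly ±4 percentage points mean that a gain of 25 points or more is well outside the noise, whereas differences of a few points between two runs should not be over-interpreted.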

Visualize the embedding space

Plot a 2-D projection of the [CLS] embeddings to check whether pre-training produces a more structured embedding space. The code below uses PCA because it needs no extra library; UMAP or t-SNE can be substituted:

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Simple PCA projection (no extra library required)
def plot_pca_embeddings(embeddings, labels, title):
    pca   = PCA(n_components=2)
    proj  = pca.fit_transform(embeddings.numpy())
    colors = ["tab:blue", "tab:orange", "tab:green", "tab:red"]
    class_names = ["World", "Sports", "Business", "Sci/Tech"]

    fig, ax = plt.subplots(figsize=(8, 6))
    for cls_idx in range(4):
        mask = (labels == cls_idx).numpy()   # boolean NumPy mask for indexing proj
        ax.scatter(proj[mask, 0], proj[mask, 1],
                   c=colors[cls_idx], label=class_names[cls_idx],
                   alpha=0.5, s=10)
    ax.set_title(title)
    ax.legend()
    plt.tight_layout()
    plt.savefig(f"{title.lower().replace(' ', '_')}.png", dpi=100)
    plt.show()

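# After the comparison loop, test_emb and test_lbl hold the pre-trained encoder's outputs.
# Build a fresh randomly initialized encoder to get the baseline embeddings.
random_bert = BertModel(
    vocab_size=VOCAB_SIZE, d_model=D_MODEL, n_heads=N_HEADS,
    n_layers=N_LAYERS, max_len=MAX_LEN,
).to(device)
test_emb_random, _ = extract_cls_embeddings(random_bert, ag_test_loader, device)

plot_pca_embeddings(test_emb_random, test_lbl, "Random BERT PCA")
plot_pca_embeddings(test_emb,        test_lbl, "Pre-trained BERT PCA")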

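The pre-trained embeddings should show visibly better-separated topic clusters in PCA space, while the random embeddings will appear largely undifferentiated. A 2-D linear projection discards most of the variance, so treat the plot as a qualitative check alongside the probe accuracy.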


Summary

  • BERT's MLM objective trains the model to reconstruct a randomly selected 15% of tokens from their left and right context, yielding rich bidirectional representations.
  • The 80/10/10 masking strategy ([MASK]/random/unchanged) prevents the model from relying on [MASK] as a cue and improves transfer to fine-tuning, where [MASK] never appears.
  • Dynamic masking samples a fresh mask every batch, giving each sequence multiple masking views across epochs; it is now the standard approach.
  • The MLM head (a dense transform layer followed by a vocabulary projection) is attached only during pre-training and discarded afterwards.
  • Linear probing on the [CLS] embedding measures representation quality without fine-tuning: even a small encoder trained for 5,000 steps on WikiText-103 significantly outperforms a random baseline.
