Chapter 6: Encoder-Only Architecture — BERT and the [CLS] Token¶
Learning Outcome¶
Understand the encoder-only transformer design, implement a BERT-style model, and learn
how the [CLS] token embedding is used for sequence-level classification tasks.
Concepts¶
Encoder-Only Architecture¶
An encoder-only model processes input tokens with bidirectional attention — every token can attend to every other token. There is no causal mask.
Input: [CLS] The cat sat [SEP]
↓ ↓ ↓ ↓ ↓
Stacked Transformer Blocks (no causal mask)
↓ ↓ ↓ ↓ ↓
Output: h_CLS h_1 h_2 h_3 h_SEP
This bidirectional context makes encoder-only models excellent at understanding tasks: classification, named entity recognition, question answering.
The [CLS] Token¶
BERT prepends a special [CLS] token to every input. Because it attends to all other
tokens (and they attend to it), its final hidden state serves as a global sequence
representation.
This representation is passed through the pooler head (a linear + tanh) and used for sequence-level tasks like sentiment classification and NLI.
The [SEP] Token¶
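For two-sentence tasks (e.g., NLI, QA), both sentences are packed into a single input sequence, and each sentence is terminated by [SEP]:
[CLS] sentence A tokens [SEP] sentence B tokens [SEP]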
Token Type IDs (Segment Embeddings)¶
BERT adds a second embedding layer (token_type_embeddings) that distinguishes sentence
A (type 0) from sentence B (type 1). This helps the model understand sentence pair
structure.
BERT's Pre-training Objectives¶
Masked Language Modeling (MLM):
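- Randomly select 15% of tokens as prediction targets; of these, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged.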
- Train the model to predict the original token from context.
- This forces bidirectional understanding.
Next Sentence Prediction (NSP):
- Given two sentences A and B, predict whether B follows A in the corpus.
- The [CLS] token's pooled output is used for this binary prediction.
- (Later work found NSP less useful; RoBERTa dropped it.)
Fine-tuning Patterns¶
| Task | How to use BERT |
|---|---|
| Text classification | Linear layer on [CLS] hidden state |
| Token classification (NER) | Linear layer on each token's hidden state |
| Question answering | Two linear layers predicting span start/end |
Exercise 1 — Implement BertModel from Scratch¶
Guided Exercise
Build a BERT model using the TransformerBlock from Chapter 5 (the class must be defined or imported before running the code below).
import torch
import torch.nn as nn
class BertEmbeddings(nn.Module):
def __init__(self, vocab_size: int, d_model: int, max_len: int,
n_token_types: int = 2, dropout: float = 0.1):
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, d_model, padding_idx=0)
self.position_embeddings = nn.Embedding(max_len, d_model)
self.token_type_embeddings = nn.Embedding(n_token_types, d_model)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-12)
self.dropout = nn.Dropout(dropout)
def forward(
self,
input_ids: torch.Tensor,
token_type_ids: torch.Tensor | None = None,
) -> torch.Tensor:
batch, seq_len = input_ids.shape
positions = torch.arange(seq_len, device=input_ids.device)
x = self.word_embeddings(input_ids)
x = x + self.position_embeddings(positions)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
x = x + self.token_type_embeddings(token_type_ids)
return self.dropout(self.layer_norm(x))
class BertPooler(nn.Module):
    """Turn the [CLS] hidden state into a pooled sequence representation.

    The first position of the sequence is projected by a square linear layer
    and squashed with tanh.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.dense = nn.Linear(d_model, d_model)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_states`` of shape (batch, seq_len, d_model).

        Returns:
            Tensor of shape (batch, d_model).
        """
        # [CLS] is always the first token of the sequence.
        first_token = hidden_states[:, 0]
        projected = self.dense(first_token)
        return projected.tanh()
class BertModel(nn.Module):
    """BERT-style encoder: embeddings -> stacked non-causal blocks -> pooler.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Hidden dimension.
        n_heads: Attention heads per block.
        n_layers: Number of stacked transformer blocks.
        max_len: Maximum sequence length (passed to embeddings and blocks).
        dropout: Dropout probability shared by embeddings and blocks.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        d_model: int = 768,
        n_heads: int = 12,
        n_layers: int = 12,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embeddings = BertEmbeddings(
            vocab_size, d_model, max_len, dropout=dropout
        )

        # Bidirectional attention: every block is built with causal=False.
        blocks = [
            TransformerBlock(
                d_model, n_heads, dropout=dropout, causal=False, max_len=max_len
            )
            for _ in range(n_layers)
        ]
        self.encoder = nn.ModuleList(blocks)

        self.pooler = BertPooler(d_model)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of token ids.

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: Forwarded unchanged to every block.
                NOTE(review): its exact semantics are defined by
                TransformerBlock (Chapter 5) — confirm there.
            token_type_ids: Optional segment ids, forwarded to the embeddings.

        Returns:
            ``(sequence_output, pooled_output)``: the per-token hidden states
            and the pooled first-token representation.
        """
        hidden = self.embeddings(input_ids, token_type_ids)
        for layer in self.encoder:
            hidden = layer(hidden, attention_mask=attention_mask)
        return hidden, self.pooler(hidden)
# Smoke test: run one random batch through a default-sized model and
# confirm the output shapes.
batch_size, seq_len = 2, 32

bert = BertModel()
ids = torch.randint(1, 30522, (batch_size, seq_len))  # avoid id 0 (padding)
mask = torch.ones(batch_size, seq_len)

outputs = bert(ids, attention_mask=mask)
seq_out, pooled = outputs

print("Sequence output:", seq_out.shape)  # (2, 32, 768)
print("Pooled output:", pooled.shape)  # (2, 768)
Add a classification head¶
class BertForSequenceClassification(nn.Module):
    """BertModel with a linear classification head on the pooled output.

    Args:
        n_classes: Number of target classes (size of the logit vector).
        **bert_kwargs: Forwarded unchanged to ``BertModel``.
    """

    def __init__(self, n_classes: int, **bert_kwargs):
        super().__init__()
        self.bert = BertModel(**bert_kwargs)

        # Read the hidden size and dropout rate from the encoder that was
        # actually built, rather than re-deriving them from bert_kwargs with
        # duplicated defaults. This keeps the head in sync with the encoder,
        # including when a custom `dropout` is passed through bert_kwargs.
        d_model = self.bert.pooler.dense.out_features
        dropout_p = self.bert.embeddings.dropout.p

        self.classifier = nn.Linear(d_model, n_classes)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None):
        """Classify a batch of sequences.

        Args:
            input_ids: Token ids, forwarded to the encoder.
            attention_mask: Optional mask, forwarded to the encoder.
            token_type_ids: Optional segment ids, forwarded to the encoder.
            labels: Optional targets; when given, a cross-entropy loss against
                the logits is computed.

        Returns:
            ``(loss, logits)`` where ``loss`` is ``None`` if no labels were
            supplied.
        """
        # Only the pooled [CLS] representation feeds the classifier.
        _, pooled = self.bert(input_ids, attention_mask, token_type_ids)
        logits = self.classifier(self.dropout(pooled))

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return loss, logits
Exercise 2 — Fine-Tune on SST-2 Sentiment Classification¶
from transformers import BertTokenizer, BertForSequenceClassification as HfBert
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
# SST-2 (part of GLUE) plus the tokenizer matching the pretrained checkpoint
dataset = load_dataset("glue", "sst2")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def preprocess(examples):
    """Tokenize a batch of sentences, truncated/padded to exactly 128 tokens."""
    sentences = examples["sentence"]
    return tokenizer(
        sentences,
        truncation=True,
        padding="max_length",
        max_length=128,
    )


# The HuggingFace model expects the target column to be called "labels".
tokenized = dataset.map(preprocess, batched=True).rename_column("label", "labels")
tensor_columns = ["input_ids", "attention_mask", "token_type_ids", "labels"]
tokenized.set_format("torch", columns=tensor_columns)

# Train on the first 1000 examples only; validate on the full split.
train_subset = tokenized["train"].select(range(1000))
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
val_loader = DataLoader(tokenized["validation"], batch_size=64)

# Use HuggingFace's BERT for simplicity (same architecture)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HfBert.from_pretrained("bert-base-uncased", num_labels=2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# Fine-tune for 3 epochs, reporting mean training loss and validation
# accuracy after each one.
for epoch in range(3):
    # ---- training pass ----
    model.train()
    train_loss = 0
    for batch in train_loader:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()

    # ---- validation pass (no gradients needed) ----
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in val_loader:
            batch = {name: tensor.to(device) for name, tensor in batch.items()}
            targets = batch["labels"]
            predictions = model(**batch).logits.argmax(dim=-1)
            correct += (predictions == targets).sum().item()
            total += len(targets)

    print(f"Epoch {epoch+1}: loss={train_loss/len(train_loader):.3f}, "
          f"val_acc={correct/total:.3f}")
Exercise 3 — Visualize [CLS] Attention Across Layers¶
from transformers import BertTokenizer, BertModel
import matplotlib.pyplot as plt
import torch
import numpy as np

# Load a pretrained BERT that also returns its attention weights.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", output_attentions=True)
model.eval()

text = "The film was surprisingly moving and beautifully shot."
inputs = tokenizer(text, return_tensors="pt")
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

with torch.no_grad():
    outputs = model(**inputs)

# For each layer, show how much [CLS] attends to each token (mean over heads).
# Each entry of outputs.attentions is (batch, n_heads, seq, seq): take the
# single batch item, average over heads, then keep the [CLS] row (position 0).
cls_attention_per_layer = [
    layer_attn[0].mean(0)[0].numpy()  # (seq_len,)
    for layer_attn in outputs.attentions
]

# Derive the layer count from the data so the y-axis labels always match the
# matrix rows, whatever checkpoint is loaded.
n_layers = len(cls_attention_per_layer)
matrix = np.array(cls_attention_per_layer)  # (n_layers, seq_len)

fig, ax = plt.subplots(figsize=(12, 6))
im = ax.imshow(matrix, aspect='auto', cmap='Blues')
ax.set_xticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha='right')
ax.set_yticks(range(n_layers))
ax.set_yticklabels([f"Layer {i+1}" for i in range(n_layers)])
plt.colorbar(im, ax=ax)
ax.set_title("[CLS] Token Attention to Each Position Across Layers")
plt.tight_layout()
plt.savefig('cls_attention.png', dpi=100)
plt.show()
Summary¶
- Encoder-only models use bidirectional attention — every token sees every other token.
- The [CLS] token's final hidden state is a learned sentence representation.
- BERT adds token type embeddings to distinguish sentence A from sentence B.
- BERT was pre-trained with MLM (predict masked tokens) and NSP (predict whether sentence B follows sentence A).
- Fine-tuning adds a task-specific head on top of the frozen or lightly-tuned BERT.