Skip to content

Chapter 6: Encoder-Only Architecture — BERT and the [CLS] Token

Learning Outcome

Understand the encoder-only transformer design, implement a BERT-style model, and learn how the [CLS] token embedding is used for sequence-level classification tasks.


Concepts

Encoder-Only Architecture

An encoder-only model processes input tokens with bidirectional attention — every token can attend to every other token. There is no causal mask (only a padding mask, which hides [PAD] positions).

Input:  [CLS] The cat sat [SEP]
         ↓     ↓    ↓   ↓    ↓
        Stacked Transformer Blocks (no causal mask)
         ↓     ↓    ↓   ↓    ↓
Output: h_CLS  h_1  h_2 h_3  h_SEP

This bidirectional context makes encoder-only models excellent at understanding tasks: classification, named entity recognition, question answering.

The [CLS] Token

BERT prepends a special [CLS] token to every input. Because it attends to all other tokens (and they attend to it), its final hidden state serves as a global sequence representation.

This hidden state is passed through the pooler head (a linear layer followed by tanh), and the pooled output is used for sequence-level tasks such as sentiment classification and NLI.

The [SEP] Token

For sentence-pair tasks (e.g., NLI, or a question paired with a passage in QA), the two segments are separated by [SEP], and a final [SEP] closes the input:

[CLS] Sentence A [SEP] Sentence B [SEP]

Token Type IDs (Segment Embeddings)

BERT adds a second embedding table (token_type_embeddings) that distinguishes sentence A (type 0) from sentence B (type 1). These segment embeddings are added to the word and position embeddings, which helps the model understand sentence-pair structure.

BERT's Pre-training Objectives

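Masked Language Modeling (MLM):
- Randomly select 15% of the input tokens as prediction targets. Of these, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged (so the model cannot rely on seeing [MASK], which never appears at fine-tuning time).
- Train the model to predict the original token at each selected position from its context.
- Because the context comes from both sides, this forces bidirectional understanding.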

Next Sentence Prediction (NSP):
- Given two sentences A and B, predict whether B is the sentence that actually follows A in the corpus (50% of pairs) or a random sentence from elsewhere (the other 50%).
- The [CLS] token's pooled output is used for this binary prediction.
- (Later work found NSP less useful; RoBERTa dropped it.)

Fine-tuning Patterns

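| Task | How to use BERT |
|------|-----------------|
| Text classification | Linear layer on the pooled [CLS] output |
| Token classification (NER) | Linear layer on each token's hidden state |
| Question answering (extractive) | Per-token start and end scores (a linear layer with two outputs on each token's hidden state) locating the answer span |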

Exercise 1 — Implement BertModel from Scratch

Guided Exercise

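Build a BERT model using the TransformerBlock from Chapter 5. The code below assumes its constructor accepts `causal` and `max_len` keyword arguments and its forward method accepts an `attention_mask` keyword argument.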

import torch
import torch.nn as nn


class BertEmbeddings(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, max_len: int,
                 n_token_types: int = 2, dropout: float = 0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_len, d_model)
        self.token_type_embeddings = nn.Embedding(n_token_types, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        batch, seq_len = input_ids.shape
        positions = torch.arange(seq_len, device=input_ids.device)

        x = self.word_embeddings(input_ids)
        x = x + self.position_embeddings(positions)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        x = x + self.token_type_embeddings(token_type_ids)

        return self.dropout(self.layer_norm(x))


class BertPooler(nn.Module):
    """Summarize a sequence by projecting its first ([CLS]) hidden state.

    A ``d_model -> d_model`` linear layer followed by ``tanh`` is applied to
    the hidden state at position 0, where the [CLS] token sits.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.dense = nn.Linear(d_model, d_model)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map (batch, seq_len, d_model) states to a (batch, d_model) summary."""
        first_token = hidden_states[:, 0, :]  # [CLS] always occupies position 0
        projected = self.dense(first_token)
        return projected.tanh()


class BertModel(nn.Module):
    """BERT encoder: embeddings, a stack of bidirectional blocks, and a pooler.

    The default sizes are the BERT-base configuration. ``TransformerBlock`` is
    the Chapter 5 block (not defined in this chapter); it is built with
    ``causal=False`` so every position can attend to every other position.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        d_model: int = 768,
        n_heads: int = 12,
        n_layers: int = 12,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embeddings = BertEmbeddings(vocab_size, d_model, max_len, dropout=dropout)
        # No causal mask: attention is fully bidirectional in every layer.
        blocks = [
            TransformerBlock(d_model, n_heads, dropout=dropout, causal=False, max_len=max_len)
            for _ in range(n_layers)
        ]
        self.encoder = nn.ModuleList(blocks)
        self.pooler = BertPooler(d_model)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of token ids.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: Forwarded unchanged to every block. Its expected
                format (presumably 1 = real token, 0 = padding) is defined by
                Chapter 5's TransformerBlock — confirm there.
            token_type_ids: Optional segment ids (all zeros when omitted).

        Returns:
            ``(sequence_output, pooled_output)``: per-token hidden states
            (expected shape (batch, seq_len, d_model)) and the pooled [CLS]
            vector of shape (batch, d_model).
        """
        hidden = self.embeddings(input_ids, token_type_ids)
        for layer in self.encoder:
            hidden = layer(hidden, attention_mask=attention_mask)
        return hidden, self.pooler(hidden)


# Smoke test: a batch of 2 random 32-token sequences. Ids are drawn from
# [1, 30522), so none is the padding id 0, and the mask is all ones (no
# padding, assuming 1 = attend). Only shapes are checked, so run in eval mode
# without autograd — there is no need to keep activations for a backward pass.
bert = BertModel()
bert.eval()
ids = torch.randint(1, 30522, (2, 32))
mask = torch.ones(2, 32)
with torch.no_grad():
    seq_out, pooled = bert(ids, attention_mask=mask)
print("Sequence output:", seq_out.shape)   # (2, 32, 768)
print("Pooled output:", pooled.shape)       # (2, 768)

Add a classification head

class BertForSequenceClassification(nn.Module):
    """BertModel plus a dropout + linear classifier on the pooled [CLS] output.

    Args:
        n_classes: Number of output classes.
        **bert_kwargs: Forwarded to ``BertModel`` (e.g. ``d_model``,
            ``n_layers``, ``dropout``).
    """

    def __init__(self, n_classes: int, **bert_kwargs):
        super().__init__()
        self.bert = BertModel(**bert_kwargs)
        # Read the width and dropout rate off the built encoder instead of
        # restating BertModel's defaults here, so they cannot drift apart.
        d_model = self.bert.pooler.dense.out_features
        self.classifier = nn.Linear(d_model, n_classes)
        self.dropout = nn.Dropout(self.bert.embeddings.dropout.p)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None):
        """Classify a batch of sequences.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: Passed through to ``BertModel``.
            token_type_ids: Optional segment ids, passed through to ``BertModel``.
            labels: Optional targets for ``cross_entropy`` (typically class
                indices of shape (batch,)).

        Returns:
            ``(loss, logits)``: ``logits`` has shape (batch, n_classes);
            ``loss`` is the mean cross-entropy, or ``None`` without labels.
        """
        _, pooled = self.bert(input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids)
        logits = self.classifier(self.dropout(pooled))

        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
        return loss, logits

Exercise 2 — Fine-Tune on SST-2 Sentiment Classification

from transformers import BertTokenizer, BertForSequenceClassification as HfBert
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch

# Load the SST-2 dataset
dataset = load_dataset("glue", "sst2")
# Only the first 1,000 training sentences are used and the GLUE test split is
# never touched, so trim the dataset *before* tokenizing rather than
# tokenizing every split in full and discarding most of the work.
dataset["train"] = dataset["train"].select(range(1000))
dataset.pop("test", None)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def preprocess(examples):
    """Tokenize a batch of SST-2 sentences into fixed-length (128) BERT inputs."""
    return tokenizer(examples["sentence"], truncation=True,
                     padding="max_length", max_length=128)


tokenized = dataset.map(preprocess, batched=True)
# model(**batch) passes each column as a keyword argument; the model expects
# the targets under the name "labels".
tokenized = tokenized.rename_column("label", "labels")
tokenized.set_format("torch", columns=["input_ids", "attention_mask",
                                        "token_type_ids", "labels"])

train_loader = DataLoader(tokenized["train"], batch_size=32, shuffle=True)
val_loader   = DataLoader(tokenized["validation"], batch_size=64)

# Use HuggingFace's BERT for simplicity (same architecture)
model = HfBert.from_pretrained("bert-base-uncased", num_labels=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Train for 3 epochs
for epoch in range(3):
    model.train()
    train_loss = 0
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)  # "labels" is in the batch, so .loss is set
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()

    # Validation accuracy
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(**batch).logits
            preds = logits.argmax(dim=-1)
            correct += (preds == batch["labels"]).sum().item()
            total += len(batch["labels"])

    print(f"Epoch {epoch+1}: loss={train_loss/len(train_loader):.3f}, "
          f"val_acc={correct/total:.3f}")

Exercise 3 — Visualize [CLS] Attention Across Layers

from transformers import BertTokenizer, BertModel as HfBertModel
import matplotlib.pyplot as plt
import numpy as np
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Alias the Hugging Face class so it does not shadow the from-scratch
# BertModel defined in Exercise 1 (Exercise 2 aliases its HF import the same way).
# NOTE(review): on transformers versions where BERT defaults to SDPA attention,
# attention weights may only be returned with attn_implementation="eager" —
# check your installed version if outputs.attentions is empty.
model = HfBertModel.from_pretrained("bert-base-uncased", output_attentions=True)
model.eval()

text = "The film was surprisingly moving and beautifully shot."
inputs = tokenizer(text, return_tensors="pt")
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions holds one (batch, n_heads, seq, seq) tensor per layer.
# For each layer, average over heads and keep the [CLS] row (query position 0),
# giving how much [CLS] attends to each token.
n_layers = len(outputs.attentions)
cls_attention_per_layer = [
    layer_attn[0].mean(0)[0].numpy() for layer_attn in outputs.attentions
]
matrix = np.array(cls_attention_per_layer)  # (n_layers, seq_len)

fig, ax = plt.subplots(figsize=(12, 6))
im = ax.imshow(matrix, aspect='auto', cmap='Blues')
ax.set_xticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha='right')
ax.set_yticks(range(n_layers))
ax.set_yticklabels([f"Layer {i+1}" for i in range(n_layers)])
plt.colorbar(im, ax=ax)
ax.set_title("[CLS] Token Attention to Each Position Across Layers")
plt.tight_layout()
plt.savefig('cls_attention.png', dpi=100)
plt.show()

Summary

  • Encoder-only models use bidirectional attention — every token sees every other token.
  • The [CLS] token's final hidden state serves as a learned, sequence-level representation (shaped by pre-training and fine-tuning).
  • BERT adds token type embeddings to distinguish sentence A from sentence B.
  • BERT was pre-trained with MLM (predict masked tokens) and NSP (predict whether sentence B actually follows sentence A).
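  • Fine-tuning adds a task-specific head on top of BERT and usually updates all weights end-to-end with a small learning rate (as in Exercise 2); freezing BERT and training only the head is a cheaper alternative.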

← Chapter 5 Chapter 7: Decoder-Only →