Chapter 3: Scaled Dot-Product Attention
Learning Outcome
Derive and implement the core attention operation from first principles, understand the
role of the scaling factor, and implement both single-head and multi-head attention as
nn.Module classes that match HuggingFace weight shapes.
Concepts
Query, Key, Value Projections
Given an input tensor X of shape (batch, seq_len, d_model), three learned linear projections
create queries, keys, and values:
    Q = X W_Q,    K = X W_K,    V = X W_V
where W_Q, W_K, and W_V are learned weight matrices.
Intuitively:
- Queries represent what a token is looking for.
- Keys represent what a token is advertising about itself.
- Values represent the content to be retrieved when a query matches a key.
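The projections above can be sketched as three bias-free linear layers. This is a minimal illustration rather than the chapter's reference implementation: the tensor sizes and the names `proj_q`, `proj_k`, and `proj_v` are assumptions chosen for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, seq_len, d_model = 2, 5, 8   # illustrative sizes, not taken from the text
X = torch.randn(batch, seq_len, d_model)

# One learned linear map per role (hypothetical names).
proj_q = nn.Linear(d_model, d_model, bias=False)
proj_k = nn.Linear(d_model, d_model, bias=False)
proj_v = nn.Linear(d_model, d_model, bias=False)

Q, K, V = proj_q(X), proj_k(X), proj_v(X)

# nn.Linear stores its weight as (out_features, in_features) and computes X @ W.T,
# so the stored matrix is the transpose of W_Q in the notation Q = X W_Q.
assert torch.allclose(Q, X @ proj_q.weight.T, atol=1e-6)

print(Q.shape, K.shape, V.shape)
```

All three outputs keep the input's shape here because each layer maps d_model to d_model.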
Scaled Dot-Product Attention
    Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
Steps:
- Compute similarity scores: `scores = Q @ K.transpose(-2, -1)` (transposing the last two dimensions of K), shape `(batch, seq_len, seq_len)`.
- Scale: divide by `sqrt(d_k)`, where `d_k` is the head dimension.
- Apply softmax row-wise: each row sums to 1.
- Weighted sum of values: `attn_weights @ V`.
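To make the four steps concrete, here is a small worked example in plain Python with one query and three key/value pairs. The numbers are made up for illustration, and the `softmax` helper is written for this sketch only.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Toy inputs (illustrative): 1 query, 3 keys/values, d_k = 2
q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
d_k = len(q)

# Step 1: similarity scores (dot product of the query with each key)
scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]

# Step 2: scale by sqrt(d_k)
scaled = [s / math.sqrt(d_k) for s in scores]

# Step 3: softmax turns the scaled scores into weights that sum to 1
weights = softmax(scaled)

# Step 4: weighted sum of the value vectors
output = [
    sum(w * v[j] for w, v in zip(weights, values))
    for j in range(len(values[0]))
]

print("scores :", scores)
print("weights:", [round(w, 3) for w in weights])
print("output :", [round(o, 3) for o in output])
```

The key aligned with the query receives the largest weight, so the output is pulled toward that key's value vector.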
Why Scale by 1/sqrt(d_k)?
For query and key vectors of dimension d_k whose elements are independent draws from
N(0, 1), the dot product has mean 0 and variance d_k, so its typical magnitude grows as
sqrt(d_k). Without scaling, large dot products push softmax into saturation, where one
weight is close to 1 and gradients are near zero. Dividing by sqrt(d_k) brings the
variance back to 1.
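The variance argument can be checked numerically. The sketch below uses only the standard library; the sample count, the seed, and the helper names are assumptions made for this example, and the estimates are approximate.

```python
import math
import random
import statistics

random.seed(0)


def dot_product_variance(d_k, n_samples=2000):
    """Estimate Var(q . k) for q, k with independent N(0, 1) entries."""
    dots = []
    for _ in range(n_samples):
        q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    return statistics.variance(dots)


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


results = {}
for d_k in (4, 16, 64):
    var_raw = dot_product_variance(d_k)
    var_scaled = var_raw / d_k   # Var(x / sqrt(d_k)) = Var(x) / d_k
    results[d_k] = (var_raw, var_scaled)
    print(f"d_k={d_k:3d}  raw variance ~ {var_raw:6.2f}  scaled variance ~ {var_scaled:.2f}")

# With d_k = 64 the raw scores have a standard deviation of about 8.
# Scores one standard deviation apart saturate softmax before scaling
# but not after it.
saturated = softmax([8.0, 0.0, -8.0])
moderate = softmax([1.0, 0.0, -1.0])
print("unscaled:", [round(w, 4) for w in saturated])
print("scaled  :", [round(w, 4) for w in moderate])
```

The raw variance should track d_k closely, while the scaled variance stays near 1 for every head size.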
Multi-Head Attention
Instead of one large attention computation, split d_model into h smaller heads of
dimension d_head = d_model / h. Run attention in each head independently, then
concatenate results and project back:
    Input X:                (batch, seq, d_model)
        ↓ project with W_Q, W_K, W_V, then split into h heads
    Heads Q_i, K_i, V_i:    (batch, seq, d_head) for i = 1..h
        ↓ attention in each head
    Output per head:        (batch, seq, d_head)
        ↓ concatenate
    Concat:                 (batch, seq, d_model)
        ↓ output projection W_O
    Output:                 (batch, seq, d_model)
Each head can attend to different aspects of the input. Some heads learn syntactic patterns; others learn semantic relationships.
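Splitting and concatenating heads is a pure reshape, with no learned parameters involved. The sketch below, with illustrative sizes, checks that head i is simply a contiguous slice of the feature dimension and that merging recovers the original tensor.

```python
import torch

batch, seq, n_heads, d_head = 2, 4, 3, 5   # illustrative sizes
d_model = n_heads * d_head

# Distinct values everywhere make it easy to see where each element ends up.
x = torch.arange(batch * seq * d_model, dtype=torch.float32).view(batch, seq, d_model)

# Split: (batch, seq, d_model) -> (batch, n_heads, seq, d_head)
heads = x.view(batch, seq, n_heads, d_head).transpose(1, 2)

# Head i holds features [i * d_head, (i + 1) * d_head) of every token.
assert torch.equal(heads[:, 1], x[:, :, d_head:2 * d_head])

# Merge: undo the transpose, make memory contiguous, then flatten the head axes.
merged = heads.transpose(1, 2).contiguous().view(batch, seq, d_model)
assert torch.equal(merged, x)

print(heads.shape, merged.shape)
```

The `.contiguous()` call matters: after `transpose`, the tensor's memory layout no longer matches its shape, and `view` would raise an error without it.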
Weight Matrix Shapes (BERT-base)
| Matrix | Shape | Notes |
|---|---|---|
| W_Q | (768, 768) | Projects to all heads' queries |
| W_K | (768, 768) | Projects to all heads' keys |
| W_V | (768, 768) | Projects to all heads' values |
| W_O | (768, 768) | Final output projection |
BERT-base: d_model=768, h=12, d_head=64.
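The arithmetic behind these shapes is short enough to spell out. This sketch counts the attention parameters of one BERT-base layer from the numbers above; the variable names are chosen for the example.

```python
d_model, n_heads = 768, 12
d_head = d_model // n_heads                    # 64

weights_per_matrix = d_model * d_model         # one (768, 768) projection
attn_weights_total = 4 * weights_per_matrix    # W_Q, W_K, W_V, W_O
attn_biases_total = 4 * d_model                # BERT's linear layers include biases

# Each head owns a (d_head x d_model) slice of W_Q, W_K, and W_V.
weights_per_head_slice = d_head * d_model

print(f"d_head                  = {d_head}")
print(f"weights per matrix      = {weights_per_matrix:,}")
print(f"attention weights/layer = {attn_weights_total:,}")
print(f"attention biases/layer  = {attn_biases_total:,}")
print(f"weights per head slice  = {weights_per_head_slice:,}")
```

Note that the total does not depend on the number of heads: using 12 heads of size 64 costs the same as one head of size 768.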
Exercise 1: Implement Scaled Dot-Product Attention
Guided Exercise
Build the core attention function and verify it against PyTorch's built-in.
    import math

    import torch
    import torch.nn.functional as F


    def scaled_dot_product_attention(
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            Q: (batch, n_heads, seq_q, d_head)
            K: (batch, n_heads, seq_k, d_head)
            V: (batch, n_heads, seq_k, d_head)
            mask: (batch, 1, seq_q, seq_k) additive mask; -inf blocks attention
        Returns:
            output: (batch, n_heads, seq_q, d_head)
            weights: (batch, n_heads, seq_q, seq_k)
        """
        d_k = Q.size(-1)
        # Step 1: similarity scores, scaled by sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        # Step 2: apply mask (add -inf to masked positions)
        if mask is not None:
            scores = scores + mask
        # Step 3: softmax over the key dimension
        weights = torch.softmax(scores, dim=-1)
        # Step 4: weighted sum of values
        output = torch.matmul(weights, V)
        return output, weights


    # Verify against PyTorch's built-in (requires PyTorch 2.0 or later)
    batch, n_heads, seq_len, d_head = 2, 8, 16, 64
    Q = torch.randn(batch, n_heads, seq_len, d_head)
    K = torch.randn(batch, n_heads, seq_len, d_head)
    V = torch.randn(batch, n_heads, seq_len, d_head)

    our_output, _ = scaled_dot_product_attention(Q, K, V)
    ref_output = F.scaled_dot_product_attention(Q, K, V)

    max_diff = (our_output - ref_output).abs().max().item()
    print(f"Max difference from reference: {max_diff:.2e}")
    assert max_diff < 1e-5, "Attention outputs don't match!"
    print("✓ Matches F.scaled_dot_product_attention")
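The verification above never exercises the `mask` argument. As a hedged sketch of how an additive mask behaves, the block below builds a causal mask (each position may attend only to itself and earlier positions) and compares the result with PyTorch's `is_causal=True` path. It recomputes the attention steps inline so that it runs on its own; the sizes are illustrative.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_heads, seq_len, d_head = 1, 2, 5, 8   # illustrative sizes
Q = torch.randn(batch, n_heads, seq_len, d_head)
K = torch.randn(batch, n_heads, seq_len, d_head)
V = torch.randn(batch, n_heads, seq_len, d_head)

# Additive causal mask: 0 on and below the diagonal, -inf above it.
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), float("-inf")), diagonal=1
).view(1, 1, seq_len, seq_len)

# Same four steps as the exercise, with the mask added before softmax.
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)
weights = torch.softmax(scores + causal_mask, dim=-1)
output = weights @ V

# Future positions receive exactly zero weight, and rows still sum to 1.
assert torch.all(weights.triu(diagonal=1) == 0)
assert torch.allclose(weights.sum(dim=-1), torch.ones(batch, n_heads, seq_len))

# Cross-check against the built-in causal path.
ref = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print("max diff vs is_causal=True:", (output - ref).abs().max().item())
```

Adding -inf before softmax works because exp(-inf) is exactly 0, so masked positions drop out of both the numerator and the normalizing sum.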
Exercise 2: Implement Multi-Head Attention
    import torch
    import torch.nn as nn

    # Reuses scaled_dot_product_attention from Exercise 1.


    class MultiHeadAttention(nn.Module):
        def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
            super().__init__()
            assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
            self.d_model = d_model
            self.n_heads = n_heads
            self.d_head = d_model // n_heads
            # Four projection matrices. bias=False keeps the example minimal;
            # BERT's own attention layers include biases.
            self.W_q = nn.Linear(d_model, d_model, bias=False)
            self.W_k = nn.Linear(d_model, d_model, bias=False)
            self.W_v = nn.Linear(d_model, d_model, bias=False)
            self.W_o = nn.Linear(d_model, d_model, bias=False)
            # Defined for completeness; this minimal forward pass does not apply it.
            self.dropout = nn.Dropout(dropout)

        def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
            """Reshape (batch, seq, d_model) → (batch, n_heads, seq, d_head)."""
            batch, seq, _ = x.shape
            x = x.view(batch, seq, self.n_heads, self.d_head)
            return x.transpose(1, 2)  # (batch, n_heads, seq, d_head)

        def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
            """Reshape (batch, n_heads, seq, d_head) → (batch, seq, d_model)."""
            batch, _, seq, _ = x.shape
            x = x.transpose(1, 2).contiguous()
            return x.view(batch, seq, self.d_model)

        def forward(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            mask: torch.Tensor | None = None,
        ) -> torch.Tensor:
            # Project to Q, K, V and split into heads
            Q = self._split_heads(self.W_q(query))
            K = self._split_heads(self.W_k(key))
            V = self._split_heads(self.W_v(value))
            # Attend (all heads in parallel)
            attn_out, _ = scaled_dot_product_attention(Q, K, V, mask=mask)
            # Merge heads and project
            out = self._merge_heads(attn_out)
            return self.W_o(out)


    # Quick shape test (self-attention: query, key, and value are the same tensor)
    mha = MultiHeadAttention(d_model=768, n_heads=12)
    x = torch.randn(2, 16, 768)
    out = mha(x, x, x)
    print("MHA output shape:", out.shape)  # (2, 16, 768)
Load weights from HuggingFace BERT and verify
    import torch
    from transformers import BertModel

    # Reuses MultiHeadAttention from the code above.
    bert = BertModel.from_pretrained("bert-base-uncased")
    bert_layer = bert.encoder.layer[0].attention.self

    # Extract HuggingFace weights. BERT's layers also carry biases; our
    # projections are weight-only, so only the weight matrices are used here.
    # Each projection is stored as one (768, 768) matrix covering all 12 heads.
    hf_Wq = bert_layer.query.weight.detach()  # (768, 768)
    hf_Wk = bert_layer.key.weight.detach()
    hf_Wv = bert_layer.value.weight.detach()
    print("HuggingFace W_Q shape:", hf_Wq.shape)
    print("HuggingFace W_K shape:", hf_Wk.shape)
    print("HuggingFace W_V shape:", hf_Wv.shape)

    # Load into our module
    mha = MultiHeadAttention(d_model=768, n_heads=12)
    with torch.no_grad():
        mha.W_q.weight.copy_(hf_Wq)
        mha.W_k.weight.copy_(hf_Wk)
        mha.W_v.weight.copy_(hf_Wv)

    # Compare the query projection (ignoring biases and the output projection)
    x = torch.randn(1, 10, 768)
    with torch.no_grad():
        our_q = mha._split_heads(mha.W_q(x))
        hf_q_raw = torch.nn.functional.linear(x, hf_Wq)  # no bias
        hf_q = mha._split_heads(hf_q_raw)
    diff = (our_q - hf_q).abs().max().item()
    print(f"\nQ projection match: max diff = {diff:.2e}")
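Because BERT's projections include biases, a weight-only module cannot reproduce BERT's actual query, key, and value tensors. One way to close that gap is to build the projection with `bias=True` and copy both tensors. The sketch below uses a randomly initialized `nn.Linear` as a stand-in for a pretrained layer, so it runs without downloading a checkpoint; `source` and `ours` are hypothetical names.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 768

# Stand-in for a pretrained query projection (random weights, for illustration only).
source = nn.Linear(d_model, d_model, bias=True)

# Our projection, this time with a bias term so nothing is dropped.
ours = nn.Linear(d_model, d_model, bias=True)
with torch.no_grad():
    ours.weight.copy_(source.weight)
    ours.bias.copy_(source.bias)

x = torch.randn(1, 10, d_model)
with torch.no_grad():
    diff = (ours(x) - source(x)).abs().max().item()

print(f"max diff with bias copied: {diff:.2e}")
```

With a real checkpoint, the same two `copy_` calls would take their tensors from the pretrained layer's `weight` and `bias` attributes instead of from `source`.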
Exercise 3: Visualize Attention Weights
    import matplotlib.pyplot as plt
    import torch
    from transformers import BertTokenizer, BertForMaskedLM

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForMaskedLM.from_pretrained("bert-base-uncased", output_attentions=True)
    model.eval()

    sentence = "The cat sat on the [MASK] because it was tired."
    inputs = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # attentions is a tuple with one (batch, n_heads, seq, seq) tensor per layer
    attentions = outputs.attentions  # 12 layers
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Plot attention for layer 0, head 0 (0-indexed; the title shows 1-indexed numbers)
    layer, head = 0, 0
    attn = attentions[layer][0, head].numpy()  # (seq, seq)

    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(attn, cmap='Blues')
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90, fontsize=9)
    ax.set_yticklabels(tokens, fontsize=9)
    plt.colorbar(im, ax=ax)
    ax.set_title(f"Attention weights: Layer {layer+1}, Head {head+1}")
    plt.tight_layout()
    plt.savefig('attention_weights.png', dpi=100)
    plt.show()
Try different layer/head combinations. Some heads attend strongly to syntactic relationships (e.g., a verb attends to its subject); others focus on nearby tokens.
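Browsing heatmaps one at a time is slow, so it can help to rank heads with a summary statistic first. The sketch below computes two simple ones per head: the entropy of each attention row (low means focused) and the attention-weighted distance between query and key positions (low means local). These statistics are one possible choice, not the chapter's method, and the attention tensor here is synthetic; with a real model it would be `attentions[layer][0]`.

```python
import torch

torch.manual_seed(0)
n_heads, seq = 12, 10

# Synthetic stand-in for one layer's attention: (n_heads, seq, seq), rows sum to 1.
attn = torch.softmax(torch.randn(n_heads, seq, seq) * 3.0, dim=-1)

# Entropy of each query's distribution over keys, averaged per head.
entropy = -(attn * torch.log(attn.clamp_min(1e-12))).sum(dim=-1)   # (n_heads, seq)
mean_entropy = entropy.mean(dim=-1)                                # (n_heads,)

# Attention-weighted |query position - key position|, averaged per head.
pos = torch.arange(seq, dtype=torch.float32)
distance = (pos.view(seq, 1) - pos.view(1, seq)).abs()             # (seq, seq)
mean_distance = (attn * distance).sum(dim=-1).mean(dim=-1)         # (n_heads,)

most_focused = int(mean_entropy.argmin())
most_local = int(mean_distance.argmin())
print("most focused head:", most_focused)
print("most local head  :", most_local)
```

Heads that stand out on either statistic are reasonable first candidates to plot.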
Summary
- Attention computes a weighted average of values, where weights are determined by query-key similarity scaled by `1/sqrt(d_k)`.
- Multi-head attention runs `h` independent attention heads in parallel, allowing the model to jointly attend to different representation subspaces.
- The four projection matrices (`W_Q`, `W_K`, `W_V`, `W_O`) each have shape `(d_model, d_model)`.
- Attention weight matrices reveal which tokens each position attends to, and different heads specialize in different patterns.