Chapter 3: Scaled Dot-Product Attention

Learning Outcome

Derive and implement the core attention operation from first principles, understand the role of the scaling factor, and implement both single-head and multi-head attention as nn.Module classes that match HuggingFace weight shapes.


Concepts

Query, Key, Value Projections

Given an input tensor X of shape (batch, seq_len, d_model), three linear projections create queries, keys, and values:

\[ Q = X W_Q, \quad K = X W_K, \quad V = X W_V \]

Intuitively:

  • Queries represent what a token is looking for.
  • Keys represent what a token is advertising about itself.
  • Values represent the content to be retrieved when a query matches a key.
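The projections above can be sketched in a few lines of PyTorch. The tensor sizes here are illustrative assumptions rather than values from the chapter, and bias-free nn.Linear layers stand in for W_Q, W_K, and W_V.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, seq_len, d_model = 2, 5, 16  # illustrative sizes, not from the chapter

X = torch.randn(batch, seq_len, d_model)

# nn.Linear stores its weight as (out_features, in_features) and computes
# x @ weight.T, so each layer below plays the role of one projection matrix.
W_Q = nn.Linear(d_model, d_model, bias=False)
W_K = nn.Linear(d_model, d_model, bias=False)
W_V = nn.Linear(d_model, d_model, bias=False)

Q, K, V = W_Q(X), W_K(X), W_V(X)
print(Q.shape, K.shape, V.shape)  # each is (batch, seq_len, d_model)

# The same projection written as an explicit matrix product
Q_manual = X @ W_Q.weight.T
print(torch.allclose(Q, Q_manual, atol=1e-6))
```

Note that the W_Q in the formula corresponds to the transpose of the stored nn.Linear weight, which matters when copying weights between implementations.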

Scaled Dot-Product Attention

\[ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \]

Steps:

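  1. Compute similarity scores: scores = Q @ K.transpose(-2, -1), with shape (batch, seq_len, seq_len). Only the last two dimensions of K are swapped, so the batch dimension is preserved.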
  2. Scale: divide by sqrt(d_k) where d_k is the head dimension.
  3. Apply softmax row-wise: each row sums to 1.
  4. Weighted sum of values: attn_weights @ V.

Why Scale by 1/sqrt(d_k)?

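For a query and key vector of dimension d_k whose elements are independent with mean 0 and variance 1 (for example, drawn from N(0, 1)), the dot product is a sum of d_k terms that each have variance 1, so it has mean 0 and variance d_k. Without scaling, the scores grow with d_k and push softmax into saturation, where one weight is close to 1 and the gradients are near zero. Dividing by sqrt(d_k) brings the variance back to about 1.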
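The variance argument can be checked empirically. The following is a minimal sketch, assuming standard-normal queries and keys; the sample count and the d_k values are arbitrary choices made for illustration.

```python
import torch

torch.manual_seed(0)
n_samples = 20_000

variances = {}
for d_k in (16, 64, 256):
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)      # one dot product per sample
    scaled = dots / d_k ** 0.5
    variances[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k:4d}  var(q·k)={variances[d_k][0]:8.2f}  "
          f"var(scaled)={variances[d_k][1]:.2f}")

# Effect on softmax: the unscaled scores give a peakier distribution
d_k = 256
q = torch.randn(1, d_k)
keys = torch.randn(8, d_k)
raw = q @ keys.T                    # (1, 8) unscaled scores
p_raw = torch.softmax(raw, dim=-1)
p_scaled = torch.softmax(raw / d_k ** 0.5, dim=-1)
print("max weight, unscaled:", p_raw.max().item())
print("max weight, scaled:  ", p_scaled.max().item())
```

The unscaled variance should come out close to d_k and the scaled variance close to 1. The largest softmax weight is never smaller for the unscaled scores, since dividing the scores by a constant greater than 1 can only flatten the distribution.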

Multi-Head Attention

Instead of one large attention computation, split d_model into h smaller heads of dimension d_head = d_model / h. Run attention in each head independently, then concatenate results and project back:

Input X: (batch, seq, d_model)
          ↓  split into h heads
Heads Q_i, K_i, V_i: (batch, seq, d_head)  for i=1..h
          ↓  attention in each head
Output per head: (batch, seq, d_head)
          ↓  concatenate
Concat: (batch, seq, d_model)
          ↓  output projection W_O
Output: (batch, seq, d_model)

Each head can attend to different aspects of the input. Some heads learn syntactic patterns; others learn semantic relationships.
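To make the diagram concrete, the sketch below runs attention one head at a time in a Python loop and compares the concatenated result with the batched reshape-and-transpose version that is used in practice. The sizes and the helper names attend and split are assumptions made for this illustration, and random tensors stand in for the already-projected Q, K, V.

```python
import math
import torch

torch.manual_seed(0)
batch, seq, d_model, h = 2, 6, 32, 4   # illustrative sizes
d_head = d_model // h

# Stand-ins for the already-projected Q, K, V of shape (batch, seq, d_model)
Q = torch.randn(batch, seq, d_model)
K = torch.randn(batch, seq, d_model)
V = torch.randn(batch, seq, d_model)


def attend(q, k, v):
    """Scaled dot-product attention over the last two dimensions."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v


# Version 1: explicit loop, one head at a time, then concatenate
head_outputs = []
for i in range(h):
    cols = slice(i * d_head, (i + 1) * d_head)
    head_outputs.append(attend(Q[..., cols], K[..., cols], V[..., cols]))
concat_loop = torch.cat(head_outputs, dim=-1)    # (batch, seq, d_model)


# Version 2: all heads at once via an extra head dimension
def split(x):
    return x.view(batch, seq, h, d_head).transpose(1, 2)


out = attend(split(Q), split(K), split(V))       # (batch, h, seq, d_head)
concat_batched = out.transpose(1, 2).reshape(batch, seq, d_model)

print(concat_loop.shape)
print(torch.allclose(concat_loop, concat_batched, atol=1e-5))
```

Both versions compute the same result; the batched form simply avoids the Python loop by treating the head index as another batch dimension.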

Weight Matrix Shapes (BERT-base)

Matrix   Shape        Notes
W_Q      (768, 768)   Projects to all heads' queries
W_K      (768, 768)   Projects to all heads' keys
W_V      (768, 768)   Projects to all heads' values
W_O      (768, 768)   Final output projection

BERT-base: d_model=768, h=12, d_head=64.
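One way to read "projects to all heads' queries" is that each head owns a contiguous block of 64 output features. The sketch below illustrates this with a randomly initialized nn.Linear standing in for W_Q (no pretrained weights are involved), and the head index i = 3 is an arbitrary choice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, h = 768, 12
d_head = d_model // h            # 64

W_Q = nn.Linear(d_model, d_model, bias=False)
print(tuple(W_Q.weight.shape))   # (768, 768)

x = torch.randn(1, 10, d_model)
q_all = W_Q(x)                            # (1, 10, 768)
q_heads = q_all.view(1, 10, h, d_head)    # last dim split into 12 heads

# Head i corresponds to rows [i*d_head, (i+1)*d_head) of the stored
# (out_features, in_features) weight matrix.
i = 3
W_Q_i = W_Q.weight[i * d_head:(i + 1) * d_head]   # (64, 768)
q_i = x @ W_Q_i.T                                 # (1, 10, 64)
print(torch.allclose(q_i, q_heads[:, :, i], atol=1e-4))

# Size of the four attention projections in one layer
n_weights = 4 * d_model * d_model
n_biases = 4 * d_model           # BERT's projections also carry biases
print(n_weights, n_weights + n_biases)  # 2359296 2362368
```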


Exercise 1 — Implement Scaled Dot-Product Attention

Guided Exercise

Build the core attention function and verify it against PyTorch's built-in.

import torch
import torch.nn.functional as F
import math


def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        Q: (batch, n_heads, seq_q, d_head)
        K: (batch, n_heads, seq_k, d_head)
        V: (batch, n_heads, seq_k, d_head)
        mask: (batch, 1, seq_q, seq_k) additive mask; -inf blocks attention
    Returns:
        output: (batch, n_heads, seq_q, d_head)
        weights: (batch, n_heads, seq_q, seq_k)
    """
    d_k = Q.size(-1)
    # Steps 1-2: similarity scores, scaled by 1/sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # Optional additive mask: -inf entries get zero weight after softmax
    if mask is not None:
        scores = scores + mask
    # Step 3: softmax over the key dimension, so each row sums to 1
    weights = torch.softmax(scores, dim=-1)
    # Step 4: weighted sum of values
    output = torch.matmul(weights, V)
    return output, weights


# Verify against PyTorch's built-in (requires PyTorch 2.0 or later)
batch, n_heads, seq_len, d_head = 2, 8, 16, 64

Q = torch.randn(batch, n_heads, seq_len, d_head)
K = torch.randn(batch, n_heads, seq_len, d_head)
V = torch.randn(batch, n_heads, seq_len, d_head)

our_output, _ = scaled_dot_product_attention(Q, K, V)
ref_output = F.scaled_dot_product_attention(Q, K, V)

max_diff = (our_output - ref_output).abs().max().item()
print(f"Max difference from reference: {max_diff:.2e}")
assert max_diff < 1e-5, "Attention outputs don't match!"
print("✓ Matches F.scaled_dot_product_attention")
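The mask argument is not exercised by the check above. As a minimal sketch of how an additive mask behaves, the example below builds a padding mask for a hypothetical batch in which the second sequence has only 4 real tokens, then repeats the same four steps inline so the block runs on its own. The mask has shape (batch, 1, 1, seq_k), which broadcasts to the (batch, 1, seq_q, seq_k) shape described in the docstring.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_heads, seq_len, d_head = 2, 4, 6, 16   # illustrative sizes

Q = torch.randn(batch, n_heads, seq_len, d_head)
K = torch.randn(batch, n_heads, seq_len, d_head)
V = torch.randn(batch, n_heads, seq_len, d_head)

# Hypothetical padding pattern: sequence 0 has 6 real tokens, sequence 1 has 4
lengths = torch.tensor([6, 4])
is_real = torch.arange(seq_len)[None, :] < lengths[:, None]   # (batch, seq_k)

# Additive mask: 0 where attention is allowed, -inf on padded key positions.
# Shape (batch, 1, 1, seq_k) broadcasts over heads and query positions.
mask = torch.zeros(batch, 1, 1, seq_len)
mask = mask.masked_fill(~is_real[:, None, None, :], float("-inf"))

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head) + mask
weights = torch.softmax(scores, dim=-1)
output = weights @ V

# Padded keys receive exactly zero weight, and rows still sum to 1
print(weights[1, :, :, 4:].abs().max().item())  # 0.0

# The built-in accepts the same additive float mask via attn_mask
ref = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)
print(torch.allclose(output, ref, atol=1e-5))
```

An additive mask is convenient because it composes with the scores by simple addition; a row in which every key is masked would produce NaN after softmax, so at least one key per query should remain unmasked.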

Exercise 2 — Implement Multi-Head Attention

import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Four projection matrices
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, d_model) → (batch, n_heads, seq, d_head)."""
        batch, seq, _ = x.shape
        x = x.view(batch, seq, self.n_heads, self.d_head)
        return x.transpose(1, 2)  # (batch, n_heads, seq, d_head)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, n_heads, seq, d_head) → (batch, seq, d_model)."""
        batch, _, seq, _ = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(batch, seq, self.d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Project to Q, K, V
        Q = self._split_heads(self.W_q(query))
        K = self._split_heads(self.W_k(key))
        V = self._split_heads(self.W_v(value))

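        # Attend. Dropout is applied to the attention weights before the
        # weighted sum, so self.dropout is actually used (no-op at p=0.0).
        _, weights = scaled_dot_product_attention(Q, K, V, mask=mask)
        attn_out = torch.matmul(self.dropout(weights), V)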

        # Merge heads and project
        out = self._merge_heads(attn_out)
        return self.W_o(out)


# Quick shape test
mha = MultiHeadAttention(d_model=768, n_heads=12)
x = torch.randn(2, 16, 768)
out = mha(x, x, x)
print("MHA output shape:", out.shape)  # (2, 16, 768)

Load weights from HuggingFace BERT and verify

from transformers import BertModel
import torch

bert = BertModel.from_pretrained("bert-base-uncased")
bert_layer = bert.encoder.layer[0].attention.self

# Extract HuggingFace weights. BERT's projections also have biases, which our
# weight-only W_q, W_k, W_v omit, so only the weight matrices are copied here.
# Each matrix covers all 12 heads at once; shape: (768, 768)
hf_Wq = bert_layer.query.weight.detach()  # (768, 768)
hf_Wk = bert_layer.key.weight.detach()
hf_Wv = bert_layer.value.weight.detach()

print("HuggingFace W_Q shape:", hf_Wq.shape)
print("HuggingFace W_K shape:", hf_Wk.shape)
print("HuggingFace W_V shape:", hf_Wv.shape)

# Load into our module
mha = MultiHeadAttention(d_model=768, n_heads=12)
mha.W_q.weight.data.copy_(hf_Wq)
mha.W_k.weight.data.copy_(hf_Wk)
mha.W_v.weight.data.copy_(hf_Wv)

# Compare the Q projection against HuggingFace's own layer. Our layer has no
# bias, so BERT's query bias is subtracted first; the output projection is
# not checked here.
x = torch.randn(1, 10, 768)
with torch.no_grad():
    our_q = mha._split_heads(mha.W_q(x))
    hf_q_raw = bert_layer.query(x) - bert_layer.query.bias
    hf_q = mha._split_heads(hf_q_raw)

diff = (our_q - hf_q).abs().max().item()
print(f"\nQ projection match: max diff = {diff:.2e}")
assert diff < 1e-4, "Q projection doesn't match HuggingFace!"

Exercise 3 — Visualize Attention Weights

from transformers import BertTokenizer, BertForMaskedLM
import matplotlib.pyplot as plt
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased", output_attentions=True)
model.eval()

sentence = "The cat sat on the [MASK] because it was tired."
inputs = tokenizer(sentence, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# attentions is a tuple of (batch, n_heads, seq, seq) — one per layer
attentions = outputs.attentions  # 12 layers
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

# Plot attention for layer 0, head 0
layer, head = 0, 0
attn = attentions[layer][0, head].numpy()  # (seq, seq)

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(attn, cmap='Blues')
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=90, fontsize=9)
ax.set_yticklabels(tokens, fontsize=9)
plt.colorbar(im, ax=ax)
ax.set_title(f"Attention weights — Layer {layer+1}, Head {head+1}")
plt.tight_layout()
plt.savefig('attention_weights.png', dpi=100)
plt.show()

Try different layer/head combinations. Some heads attend strongly to syntactic relationships (e.g., a verb attends to its subject); others focus on nearby tokens.


Summary

  • Attention computes a weighted average of values, where weights are determined by query-key similarity scaled by 1/sqrt(d_k).
  • Multi-head attention runs h independent attention heads in parallel, allowing the model to jointly attend to different representation subspaces.
  • The four projection matrices (W_Q, W_K, W_V, W_O) each have shape (d_model, d_model).
  • Attention weight matrices reveal which tokens each position attends to, and different heads specialize in different patterns.
