Chapter 4: Causal (Masked) Attention and the Attention Mask
Learning Outcome
Understand the difference between full (bidirectional) attention and causal (unidirectional) attention, implement the causal mask, and connect this to how BERT and GPT-style models differ fundamentally.
Concepts
Two Kinds of Masking
Transformers use masks for two distinct purposes:
| Mask type | Purpose | Used by |
|---|---|---|
| Padding mask | Ignore [PAD] tokens in variable-length batches | BERT, GPT-2 |
| Causal mask | Prevent attending to future tokens | GPT-2, LLaMA |
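The Causal Mask
For autoregressive generation, position i must only attend to positions ≤ i.
We implement this as an additive mask: a matrix with `-inf` strictly above the diagonal and 0 everywhere else. Rows index the attending (query) position and columns index the attended (key) position: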
seq_len = 4:
Positions:    0   1   2   3
          ┌─────────────────┐
        0 │   0  -∞  -∞  -∞ │
        1 │   0   0  -∞  -∞ │
        2 │   0   0   0  -∞ │
        3 │   0   0   0   0 │
          └─────────────────┘
Adding this to the attention scores before softmax:
- Scores at `-inf` → softmax output 0 (no attention to those positions).
- This is mathematically equivalent to zeroing out those weights and renormalizing the remaining ones so that each row still sums to 1.
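The equivalence can be checked numerically. The following is a minimal sketch that uses random scores as a stand-in for real query-key products; the variable names are illustrative and not part of the chapter's code.

```python
import torch

torch.manual_seed(0)
seq_len = 4

# Random stand-in for the (seq_len, seq_len) matrix of attention scores.
scores = torch.randn(seq_len, seq_len)

# Additive causal mask: 0 on and below the diagonal, -inf strictly above it.
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# Masked softmax: future positions receive weight exactly 0.
weights = torch.softmax(scores + causal_mask, dim=-1)

# Reference: unmasked softmax, zero the future positions, renormalize each row.
allowed = torch.tril(torch.ones(seq_len, seq_len))
reference = torch.softmax(scores, dim=-1) * allowed
reference = reference / reference.sum(dim=-1, keepdim=True)

print(weights)
print("Matches zero-and-renormalize:", torch.allclose(weights, reference, atol=1e-6))
```

The additive form is usually preferred in practice because it needs one softmax pass and no separate renormalization step.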
Why -inf Instead of a Binary Mask?
Because attention uses softmax. Setting a masked score to 0 does not remove it: `exp(0) = 1`, so the position still contributes to the softmax denominator and receives a non-zero weight. Using
`-inf` ensures `exp(-inf) = 0`, so those positions truly contribute nothing.
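A small illustration of the difference, assuming a single query row with three hand-picked scores where the last position should be masked:

```python
import torch

# One query row with three scores; suppose the last position must be masked.
scores = torch.tensor([2.0, 1.0, 3.0])

# Overwriting the masked score with 0: exp(0) = 1, so it still receives weight.
zeroed = scores.clone()
zeroed[2] = 0.0
w_zero = torch.softmax(zeroed, dim=-1)

# Overwriting with -inf: exp(-inf) = 0, so the weight is exactly zero.
masked = scores.clone()
masked[2] = float("-inf")
w_inf = torch.softmax(masked, dim=-1)

print(w_zero)  # roughly [0.665, 0.245, 0.090]
print(w_inf)   # roughly [0.731, 0.269, 0.000]
```

In the first case about 9% of the attention leaks to the position that was supposed to be hidden.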
The HuggingFace attention_mask
HuggingFace uses a different convention at the input:
- Input `attention_mask`: `1` for attend, `0` for ignore (padding).
- This is converted to an additive mask inside the model: `1 → 0.0` (no change), `0 → -10000.0` (effectively `-inf`). The exact constant varies between library versions.
# HuggingFace converts padding mask to additive format
mask = (1 - attention_mask) * -10000.0
# Then adds to scores before softmax
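To see the conversion end to end, here is a minimal sketch with a single right-padded sequence. The random scores are a stand-in for real attention scores, and the constant `-10000.0` is taken from the snippet above rather than from any particular library version.

```python
import torch

# HuggingFace-style mask for one sequence: 3 real tokens followed by 2 padding tokens.
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# The conversion shown above: 1 -> 0.0 (keep), 0 -> -10000.0 (suppress).
additive = (1.0 - attention_mask.float()) * -10000.0

# Random stand-in scores for a single query position.
torch.manual_seed(0)
scores = torch.randn(1, 5)
weights = torch.softmax(scores + additive, dim=-1)

print(additive)
print(weights)
```

A finite constant such as `-10000.0` behaves like `-inf` here because `exp` of such a large negative number underflows to 0 in float32, while avoiding the NaN that arises when an entire row is `-inf`.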
Combining Padding and Causal Masks
In GPT-2 with batched padded inputs, both masks are needed:
causal_mask = build_causal_mask(seq_len) # (1, 1, seq, seq)
padding_mask = build_padding_mask(attn_mask) # (batch, 1, 1, seq)
combined = causal_mask + padding_mask # broadcasting handles the combination
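The broadcasting step can be made concrete with a small self-contained sketch. The two masks are built inline here as stand-ins for `build_causal_mask` and `build_padding_mask`, which Exercise 1 defines; the batch contents are made up for illustration.

```python
import torch

batch, seq_len = 2, 4

# Inline stand-in for build_causal_mask: -inf strictly above the diagonal.
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
causal_mask = causal_mask[None, None, :, :]        # (1, 1, seq, seq)

# Inline stand-in for build_padding_mask; the second sequence is right-padded.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])
padding_mask = torch.zeros(batch, seq_len).masked_fill(attention_mask == 0, float("-inf"))
padding_mask = padding_mask[:, None, None, :]      # (batch, 1, 1, seq)

# Broadcasting expands both operands to (batch, 1, seq, seq).
combined = causal_mask + padding_mask
print(combined.shape)
print(combined[1, 0])  # mask applied to the padded sequence
```

In the padded sequence, every query row is blocked from key positions 2 and 3 by the padding mask, in addition to the future positions blocked by the causal mask.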
Exercise 1: Extend Attention with Both Masks
Guided Exercise
Add causal masking support to the attention function from Chapter 3.
Step 1: Build mask constructors
import torch
import math


def build_causal_mask(seq_len: int, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    """
    Build an upper-triangular causal mask.
    Returns tensor of shape (1, 1, seq_len, seq_len) with 0 or -inf values.
    """
    # Upper triangle (excluding diagonal) is masked; torch.triu zeros the rest
    full = torch.full((seq_len, seq_len), float('-inf'), device=device)
    mask = torch.triu(full, diagonal=1)
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)


def build_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Convert a HuggingFace-style attention mask (1=attend, 0=ignore) to
    an additive mask compatible with our attention function.

    Args:
        attention_mask: (batch, seq_len) with values 0 or 1
    Returns:
        additive mask: (batch, 1, 1, seq_len) with values 0 or -inf
    """
    # Flip: 0 → -inf, 1 → 0.
    # Use masked_fill rather than multiplying by -inf: 0 * -inf evaluates to NaN.
    pad_mask = torch.zeros_like(attention_mask, dtype=torch.float)
    pad_mask = pad_mask.masked_fill(attention_mask == 0, float('-inf'))
    return pad_mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)


# Test causal mask
causal = build_causal_mask(5)
print("Causal mask (seq_len=5):\n", causal.squeeze())

# Test padding mask
attn = torch.tensor([[1, 1, 1, 0, 0]])  # 3 real tokens, 2 padding
pad = build_padding_mask(attn)
print("\nPadding mask:", pad)
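One pitfall worth checking when building an additive padding mask: in IEEE floating point, `0 * -inf` is NaN rather than 0, so a mask formed by multiplying by `-inf` would corrupt the attended positions. A short sketch comparing multiplication with `masked_fill` (variable names are illustrative):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Multiplying by -inf: attended positions compute 0 * -inf, which is NaN.
by_multiplication = (1.0 - attention_mask.float()) * float("-inf")

# Filling instead: start from zeros and write -inf only where the mask is 0.
by_fill = torch.zeros(attention_mask.shape).masked_fill(attention_mask == 0, float("-inf"))

print(by_multiplication)  # the first three entries are NaN
print(by_fill)            # the first three entries are 0
```

A single NaN in the mask propagates through softmax to every weight in that row, so the fill-based form is the safer construction whenever the mask uses a true `-inf`.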
Step 2: Update MultiHeadAttention to use causal masking
import math
import torch
import torch.nn as nn

# Reuses build_padding_mask from Step 1.


class CausalMultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_len: int = 2048,
                 dropout: float = 0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Register causal mask as buffer (not a parameter)
        causal = torch.triu(torch.full((max_len, max_len), float('-inf')), diagonal=1)
        self.register_buffer('causal_mask', causal)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        batch, seq, d_model = x.shape
        # Project
        Q = self._split_heads(self.W_q(x))
        K = self._split_heads(self.W_k(x))
        V = self._split_heads(self.W_v(x))
        # Build combined mask
        mask = self.causal_mask[:seq, :seq].unsqueeze(0).unsqueeze(0)  # (1,1,seq,seq)
        if attention_mask is not None:
            pad = build_padding_mask(attention_mask)  # (batch,1,1,seq)
            mask = mask + pad                         # (batch,1,seq,seq)
        # Attention
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        scores = scores + mask
        weights = torch.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        out = torch.matmul(weights, V)
        return self.W_o(self._merge_heads(out))

    def _split_heads(self, x):
        # (batch, seq, d_model) -> (batch, n_heads, seq, d_head)
        b, s, _ = x.shape
        return x.view(b, s, self.n_heads, self.d_head).transpose(1, 2)

    def _merge_heads(self, x):
        # (batch, n_heads, seq, d_head) -> (batch, seq, d_model)
        b, _, s, _ = x.shape
        return x.transpose(1, 2).contiguous().view(b, s, -1)


# Verify
attn = CausalMultiHeadAttention(d_model=256, n_heads=8, max_len=64)
x = torch.randn(2, 10, 256)
out = attn(x)
print("Causal MHA output shape:", out.shape)  # (2, 10, 256)
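As a sanity check on the masking logic, the core of the forward pass can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` with `is_causal=True`. This is a sketch that assumes PyTorch 2.0 or later, uses random tensors in place of projected activations, and omits padding and dropout.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_heads, seq_len, d_head = 2, 4, 6, 8

# Random stand-ins for the per-head Q, K, V tensors.
Q = torch.randn(batch, n_heads, seq_len, d_head)
K = torch.randn(batch, n_heads, seq_len, d_head)
V = torch.randn(batch, n_heads, seq_len, d_head)

# Manual causal attention: scaled scores plus an additive -inf mask.
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head) + mask
manual = torch.softmax(scores, dim=-1) @ V

# Built-in attention with causal masking applied internally.
builtin = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print("Outputs match:", torch.allclose(manual, builtin, atol=1e-5))
```

Agreement between the two suggests the additive-mask formulation matches the library's notion of causal attention; the built-in function may also be faster, since it can dispatch to fused kernels.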
Exercise 2: Intercept Attention Masks in GPT-2
from transformers import GPT2Tokenizer, GPT2Model
import torch
import matplotlib.pyplot as plt
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token # GPT-2 has no pad token by default
model = GPT2Model.from_pretrained("gpt2", output_attentions=True)
model.eval()
# Create a padded batch: two sentences of different lengths
sentences = ["The quick brown fox", "Hello world"]
inputs = tokenizer(
    sentences,
    return_tensors="pt",
    padding=True,
    truncation=True,
)
print("Input IDs shape:", inputs["input_ids"].shape)
print("Attention mask:\n", inputs["attention_mask"])
with torch.no_grad():
    outputs = model(**inputs)
# Inspect attention for layer 0, head 0
attn_weights = outputs.attentions[0][0, 0].numpy() # (seq, seq) for batch[0]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(attn_weights, cmap='Blues', vmin=0, vmax=1)
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha='right')
ax.set_yticklabels(tokens)
ax.set_title("GPT-2 Causal Attention — Layer 1, Head 1\n(note: lower triangle only)")
plt.tight_layout()
plt.savefig('causal_attention.png', dpi=100)
plt.show()
Note that the attention weight matrix has zeros everywhere above the diagonal: tokens cannot attend to future positions.
Exercise 3: Verify Causal Independence
Demonstrate that a causally-masked model produces identical results regardless of whether future tokens are present.
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
full_sentence = "The transformer architecture was introduced in"
short_sentence = "The transformer architecture was"
full_ids = tokenizer.encode(full_sentence, return_tensors="pt")
short_ids = tokenizer.encode(short_sentence, return_tensors="pt")
n_short = short_ids.shape[1]
with torch.no_grad():
    full_logits = model(full_ids).logits
    short_logits = model(short_ids).logits
# The logits for the short prefix should be identical in both cases
diff = (full_logits[:, :n_short, :] - short_logits).abs().max().item()
print(f"Max difference in prefix logits: {diff:.2e}")
assert diff < 1e-4, "Causal masking is not working correctly!"
print("✓ Causal mask ensures prefix logits are independent of future tokens")
This confirms the fundamental property of causal models: each position's output depends only on its context (positions ≤ i), never on future tokens.
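The same property can be reproduced without downloading a pretrained model. The sketch below uses a hypothetical single-head attention function with fixed random projections (not GPT-2's actual weights) and checks that the outputs on a prefix do not change when extra tokens are appended.

```python
import math
import torch

torch.manual_seed(0)
d_model = 16

# Fixed random projections stand in for a trained attention layer.
W_q, W_k, W_v = (torch.randn(d_model, d_model) / math.sqrt(d_model) for _ in range(3))


def causal_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Single-head causal self-attention over x of shape (seq_len, d_model)."""
    seq_len = x.shape[0]
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    scores = Q @ K.T / math.sqrt(d_model) + mask
    return torch.softmax(scores, dim=-1) @ V


full = torch.randn(6, d_model)  # a 6-token input
prefix = full[:4]               # its first 4 tokens

out_full = causal_self_attention(full)
out_prefix = causal_self_attention(prefix)

diff = (out_full[:4] - out_prefix).abs().max().item()
print(f"Max difference on the shared prefix: {diff:.2e}")
```

Removing the `+ mask` term makes the two outputs diverge, because the prefix positions would then also attend to tokens 4 and 5.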
Summary
- The causal mask is a strictly upper-triangular matrix of `-inf` that prevents positions from attending to future tokens.
- The padding mask zeros out attention to `[PAD]` tokens.
- HuggingFace's `attention_mask` (1/0 integers) is converted to an additive mask: 0 for attended positions and a large negative value (effectively `-inf`) for padding.
- Both masks can be summed and added to attention scores before softmax.
- GPT-style (decoder-only) models use causal masks; BERT-style (encoder-only) models use only padding masks.