
Chapter 4: Causal (Masked) Attention and the Attention Mask

Learning Outcome

Understand the difference between full (bidirectional) attention and causal (unidirectional) attention, implement the causal mask, and connect this to how BERT and GPT-style models differ fundamentally.


Concepts

Two Kinds of Masking

Transformers use masks for two distinct purposes:

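| Mask type    | Purpose                                        | Used by      |
|--------------|------------------------------------------------|--------------|
| Padding mask | Ignore [PAD] tokens in variable-length batches | BERT, GPT-2  |
| Causal mask  | Prevent attending to future tokens             | GPT-2, LLaMA |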

The Causal Mask

For autoregressive generation, position i may only attend to positions ≤ i. We implement this as an additive mask with -inf strictly above the diagonal (key position j > query position i) and 0 everywhere else:

For seq_len = 4 (rows are query positions i, columns are key positions j):

Positions:  0   1   2   3
         ┌────────────────┐
   0     │  0  -∞  -∞  -∞ │
   1     │  0   0  -∞  -∞ │
   2     │  0   0   0  -∞ │
   3     │  0   0   0   0 │
         └────────────────┘
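Written as a formula, with M the mask shown above and d_k the per-head key dimension, causal attention computes the following (this matches the code later in the chapter, which divides the scores by sqrt(d_k) and then adds the mask):

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right) V,
\qquad
M_{ij} = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}
```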

Adding this to the attention scores before softmax:

  • Scores at -inf → softmax output 0 (no attention to those positions).
  • This is mathematically equivalent to setting those weights to zero and renormalizing the remaining ones, so each row of weights still sums to 1.
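A quick numerical check of that equivalence, as a sketch with random scores (the 4×4 size and the seed are arbitrary choices): adding the mask before softmax gives the same weights as taking a full softmax, zeroing the future positions, and renormalizing each row.

```python
import torch

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(seq_len, seq_len)  # arbitrary raw attention scores

# Additive causal mask: -inf strictly above the diagonal, 0 elsewhere
mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
masked_weights = torch.softmax(scores + mask, dim=-1)

# Alternative route: full softmax, zero out future positions, renormalize each row
keep = torch.tril(torch.ones(seq_len, seq_len))
manual = torch.softmax(scores, dim=-1) * keep
manual = manual / manual.sum(dim=-1, keepdim=True)

print(masked_weights)
print("Same result:", torch.allclose(masked_weights, manual, atol=1e-6))
```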

Why -inf Instead of a Binary Mask?

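Because attention weights come from a softmax, and mathematically a softmax assigns a weight of exactly 0 only to a score of -inf. If masked positions merely had their scores set to 0 (for example, by multiplying the scores by a 0/1 mask), each would still contribute exp(0) = 1 to both its own numerator and the shared denominator, so it would keep a non-zero weight and dilute the weights of the real positions. Adding -inf instead gives exp(-inf) = 0, so masked positions truly contribute nothing.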
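To make this concrete, here is a minimal plain-Python sketch with three made-up scores for a single query, where the last position is meant to be masked. The softmax helper is written out by hand for illustration; it is not a library function.

```python
import math

def softmax(xs):
    # Subtract the max for numerical stability; math.exp(-inf) evaluates to 0.0
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, 1.0, 3.0]   # made-up scores; position 2 should be masked
keep = [1, 1, 0]           # binary mask: 1 = keep, 0 = mask

# Multiplicative (binary) masking: the masked score becomes 0, which is not "nothing"
binary = softmax([s * k for s, k in zip(scores, keep)])
# Additive masking with -inf: the masked weight becomes exactly 0
additive = softmax([s if k else float('-inf') for s, k in zip(scores, keep)])

print("binary:  ", binary)    # the masked position still gets a weight > 0
print("additive:", additive)  # the masked position gets weight 0.0
```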

The HuggingFace attention_mask

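HuggingFace models accept the mask in a different form and convert it internally:

  • Input attention_mask: 1 for attend, 0 for ignore (padding).
  • Inside the model this becomes an additive mask: 1 → 0.0 (score unchanged), 0 → -10000.0 (a large negative number that behaves like -inf after softmax). More recent versions of transformers use the most negative finite value of the dtype, torch.finfo(dtype).min, instead of -10000.0.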
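import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # 1 = attend, 0 = padding
# HuggingFace-style conversion to additive format: 1 -> 0.0, 0 -> -10000.0
mask = (1.0 - attention_mask.float()) * -10000.0
# mask is then reshaped, broadcast, and added to the attention scores before softmax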
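The sketch below applies that conversion to one made-up row of scores in plain Python (the scores and the hand-written softmax helper are illustrative assumptions). It also shows a pitfall of writing the same formula with float('-inf') in place of -10000.0: every attended position then computes 0 × -inf, which is NaN under IEEE 754 arithmetic.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

attention_mask = [1, 1, 1, 0]        # HuggingFace convention: last token is padding
scores = [0.5, 1.2, -0.3, 2.0]       # made-up raw scores for one query position

# Additive conversion: attend (1) -> 0.0, padding (0) -> -10000.0
additive = [(1 - m) * -10000.0 for m in attention_mask]
weights = softmax([s + a for s, a in zip(scores, additive)])
print(weights)  # the padding weight is 0.0 because exp(-10000) underflows

# Pitfall: the same formula with -inf yields NaN wherever the mask is 1,
# because 0 * -inf is NaN in IEEE 754 arithmetic
naive = [(1 - m) * float('-inf') for m in attention_mask]
print(naive)  # [nan, nan, nan, -inf]
```

Building the mask with a fill operation (write -inf only where the mask is 0) avoids the multiplication and therefore the NaN.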

Combining Padding and Causal Masks

In GPT-2 with batched padded inputs, both masks are needed:

causal_mask = build_causal_mask(seq_len)     # (1, 1, seq, seq)
padding_mask = build_padding_mask(attn_mask) # (batch, 1, 1, seq)
combined = causal_mask + padding_mask        # broadcasts to (batch, 1, seq, seq)
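A self-contained sketch of that broadcast, with causal_mask and padding_mask as hypothetical stand-ins for the build_causal_mask and build_padding_mask helpers defined in Exercise 1. It also exposes a caveat: with right padding every query row still has at least one unmasked key, but with left padding the padded query rows at the start of a sequence have every key masked, and a softmax over a row that is entirely -inf returns NaN.

```python
import torch

def causal_mask(seq_len):
    # Stand-in for build_causal_mask: (1, 1, seq, seq), -inf above the diagonal
    m = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
    return m[None, None]

def padding_mask(attention_mask):
    # Stand-in for build_padding_mask: (batch, 1, 1, seq), 0 = attend, -inf = pad
    m = torch.zeros(attention_mask.shape).masked_fill(attention_mask == 0, float('-inf'))
    return m[:, None, None, :]

batches = {
    "right": torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]]),  # padding at the end
    "left": torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]]),   # padding at the start
}

nan_rows = {}
for side, am in batches.items():
    combined = causal_mask(4) + padding_mask(am)  # (1,1,4,4) + (2,1,1,4) -> (2,1,4,4)
    scores = torch.zeros(2, 1, 4, 4)              # dummy scores: (batch, heads, seq, seq)
    weights = torch.softmax(scores + combined, dim=-1)
    # A query row whose keys are all -inf has no finite entry, so softmax returns NaN
    nan_rows[side] = int(torch.isnan(weights).any(dim=-1).sum())
    print(side, tuple(combined.shape), "rows with NaN:", nan_rows[side])
```

Those NaN rows belong to padding positions, but they do not stay put: in the next layer they become NaN keys and values, and 0 × NaN is still NaN, so they can contaminate real tokens. A large finite negative fill value, like the HuggingFace constant above, keeps such rows finite.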

Exercise 1 — Extend Attention with Both Masks

Guided Exercise

Add causal masking support to the attention function from Chapter 3.

Step 1: Build mask constructors

import torch
import torch.nn as nn
import math


def build_causal_mask(seq_len: int, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    """
    Build an upper-triangular causal mask.
    Returns tensor of shape (1, 1, seq_len, seq_len) with 0 or -inf values.
    """
    # Upper triangle (excluding diagonal) is masked
    mask = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=device), diagonal=1)
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)


def build_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Convert a HuggingFace-style attention mask (1=attend, 0=ignore) to
    an additive mask compatible with our attention function.

    Args:
        attention_mask: (batch, seq_len) with values 0 or 1
    Returns:
        additive mask: (batch, 1, 1, seq_len) with values 0 or -inf
    """
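    # Flip: 1 (attend) -> 0.0, 0 (pad) -> -inf.
    # masked_fill avoids computing 0 * -inf, which would produce NaN.
    pad_mask = torch.zeros_like(attention_mask, dtype=torch.float)
    pad_mask = pad_mask.masked_fill(attention_mask == 0, float('-inf'))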
    return pad_mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)


# Test causal mask
causal = build_causal_mask(5)
print("Causal mask (seq_len=5):\n", causal.squeeze())

# Test padding mask
attn = torch.tensor([[1, 1, 1, 0, 0]])  # 3 real tokens, 2 padding
pad = build_padding_mask(attn)
print("\nPadding mask:", pad)

Step 2: Update MultiHeadAttention to use causal masking

class CausalMultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_len: int = 2048,
                 dropout: float = 0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

        # Register causal mask as buffer (not a parameter)
        causal = torch.triu(torch.full((max_len, max_len), float('-inf')), diagonal=1)
        self.register_buffer('causal_mask', causal)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        batch, seq, d_model = x.shape

        # Project
        Q = self._split_heads(self.W_q(x))
        K = self._split_heads(self.W_k(x))
        V = self._split_heads(self.W_v(x))

        # Build combined mask
        mask = self.causal_mask[:seq, :seq].unsqueeze(0).unsqueeze(0)  # (1,1,seq,seq)
        if attention_mask is not None:
            pad = build_padding_mask(attention_mask)
            mask = mask + pad

        # Attention
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        scores = scores + mask
        weights = torch.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        out = torch.matmul(weights, V)

        return self.W_o(self._merge_heads(out))

    def _split_heads(self, x):
        b, s, _ = x.shape
        return x.view(b, s, self.n_heads, self.d_head).transpose(1, 2)

    def _merge_heads(self, x):
        b, _, s, _ = x.shape
        return x.transpose(1, 2).contiguous().view(b, s, -1)


# Verify
attn = CausalMultiHeadAttention(d_model=256, n_heads=8, max_len=64)
x = torch.randn(2, 10, 256)
out = attn(x)
print("Causal MHA output shape:", out.shape)  # (2, 10, 256)
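As a cross-check, a sketch assuming PyTorch 2.0 or later: torch.nn.functional.scaled_dot_product_attention accepts is_causal=True and applies the same lower-triangular masking internally, so the manual computation used in forward (without the projections and dropout) should match it to floating-point tolerance. The tensor sizes below are arbitrary.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, d_head = 2, 8, 10, 32
Q = torch.randn(batch, heads, seq, d_head)
K = torch.randn(batch, heads, seq, d_head)
V = torch.randn(batch, heads, seq, d_head)

# Manual path, mirroring the attention step in CausalMultiHeadAttention.forward
mask = torch.triu(torch.full((seq, seq), float('-inf')), diagonal=1)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head) + mask
manual = torch.softmax(scores, dim=-1) @ V

# PyTorch's built-in attention with its own causal mask (default scale is 1/sqrt(d_head))
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print("max abs difference:", (manual - fused).abs().max().item())
```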

Exercise 2 — Intercept Attention Masks in GPT-2

from transformers import GPT2Tokenizer, GPT2Model
import torch
import matplotlib.pyplot as plt

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
model = GPT2Model.from_pretrained("gpt2", output_attentions=True)
model.eval()

# Create a padded batch: two sentences of different lengths
sentences = ["The quick brown fox", "Hello world"]
inputs = tokenizer(
    sentences,
    return_tensors="pt",
    padding=True,
    truncation=True,
)

print("Input IDs shape:", inputs["input_ids"].shape)
print("Attention mask:\n", inputs["attention_mask"])

with torch.no_grad():
    outputs = model(**inputs)

# Inspect attention for layer 0, head 0 of batch element 0 (the longer sentence, no padding)
attn_weights = outputs.attentions[0][0, 0].numpy()  # (seq, seq) for batch[0]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(attn_weights, cmap='Blues', vmin=0, vmax=1)
ax.set_xticks(range(len(tokens)))
ax.set_yticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha='right')
ax.set_yticklabels(tokens)
ax.set_title("GPT-2 Causal Attention: Layer 0, Head 0\n(note: lower triangle only)")
plt.tight_layout()
plt.savefig('causal_attention.png', dpi=100)
plt.show()

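The upper triangle of the weight matrix is zero: no token attends to a later position. Batch element 0 is the longer sentence and contains no padding; to see the padding mask at work, plot element 1 instead (outputs.attentions[0][1, 0], with tokens from inputs["input_ids"][1]). GPT-2's tokenizer pads on the right by default, so that sequence's padded columns also receive zero weight.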


Exercise 3 — Verify Causal Independence

Demonstrate that, in a causally masked model, the outputs at prefix positions are the same whether or not later tokens are present.

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

full_sentence = "The transformer architecture was introduced in"
short_sentence = "The transformer architecture was"

full_ids = tokenizer.encode(full_sentence, return_tensors="pt")
short_ids = tokenizer.encode(short_sentence, return_tensors="pt")

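n_short = short_ids.shape[1]
# Sanity check: the short sentence must tokenize to an exact prefix of the full one
assert torch.equal(full_ids[:, :n_short], short_ids)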

with torch.no_grad():
    full_logits = model(full_ids).logits
    short_logits = model(short_ids).logits

# The logits for the shared prefix should match (up to floating-point rounding) in both runs
diff = (full_logits[:, :n_short, :] - short_logits).abs().max().item()
print(f"Max difference in prefix logits: {diff:.2e}")
assert diff < 1e-4, "Causal masking is not working correctly!"
print("✓ Causal mask ensures prefix logits are independent of future tokens")

This confirms the fundamental property of causal models: each position's output depends only on its context (positions ≤ i), never on future tokens.
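The same property can be checked without downloading GPT-2. The sketch below uses a toy single-head self-attention with no learned projections (Q = K = V = x, an assumption made purely to keep it short; self_attention is a hypothetical helper) and contrasts it with the same layer run without the causal mask, which is how a bidirectional, BERT-style encoder layer behaves. Changing only the last token leaves the earlier outputs untouched under the causal mask but changes them without it.

```python
import math
import torch

def self_attention(x, causal):
    # Toy single-head self-attention with no learned projections (Q = K = V = x)
    seq, d = x.shape
    scores = x @ x.T / math.sqrt(d)
    if causal:
        scores = scores + torch.triu(torch.full((seq, seq), float('-inf')), diagonal=1)
    return torch.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(6, 16)              # 6 "token" vectors of width 16
x_changed = x.clone()
x_changed[-1] = torch.randn(16)     # replace only the last (future) token

changes = {}
for causal in (True, False):
    a = self_attention(x, causal)[:-1]          # outputs for positions 0..4
    b = self_attention(x_changed, causal)[:-1]
    changes[causal] = (a - b).abs().max().item()
    print(f"causal={causal}: max change in earlier outputs = {changes[causal]:.2e}")
```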


Summary

  • The causal mask adds -inf above the diagonal (and 0 elsewhere), preventing each position from attending to future tokens.
  • The padding mask adds -inf at the [PAD] key positions, so they receive zero attention weight.
  • HuggingFace's attention_mask (1 = attend, 0 = ignore) is converted internally to an additive mask, with a large negative value standing in for -inf.
  • Both masks can be summed and added to attention scores before softmax.
  • GPT-style (decoder-only) models use causal masks; BERT-style (encoder-only) uses only padding masks.
