Chapter 8: Encoder-Decoder Architecture — T5 and Cross-Attention¶
Learning Outcome¶
Understand the encoder-decoder transformer (seq2seq) design, implement cross-attention, and see how tasks like translation, summarization, and question-answering map to this architecture.
Concepts¶
Encoder-Decoder Overview¶
In a seq2seq model, the encoder processes the full input sequence bidirectionally. The decoder generates the output autoregressively, attending to the encoder's output at each step via cross-attention.
Input: "translate English to French: The sky is blue"
↓
Encoder (bidirectional)
↓
Encoder hidden states: (seq_in, d_model)
↓ ←————————————————————————————————
Decoder (causal self-attn + cross-attn)
↓
Output: "Le ciel est bleu"
Cross-Attention¶
Cross-attention differs from self-attention in one key way:
- Queries come from the decoder's hidden states.
- Keys and Values come from the encoder's output.
This allows each decoder position to "look up" relevant encoder context — e.g., when generating "bleu", the decoder attends to "blue" in the input.
Decoder Block Structure¶
Each decoder block has three sublayers — two attention sublayers followed by a feed-forward network:
- Causal self-attention on the decoder's own (so-far-generated) output.
- Cross-attention with the encoder's output (keys and values are fixed).
- FFN.
Each sublayer is wrapped with pre-norm and residual connections.
T5: Text-to-Text Transfer Transformer¶
T5 frames every NLP task as text-in → text-out:
| Task | Input | Output |
|---|---|---|
| Translation | translate English to French: ... | ... |
| Summarization | summarize: ... | ... |
| Classification | sst2 sentence: ... | positive or negative |
| QA | question: ... context: ... | answer string |
T5 uses relative position biases instead of absolute positional embeddings: a learned scalar bias, chosen by the relative distance between the query and key positions, is added to each attention logit.
Teacher Forcing¶
During training, the decoder receives the ground truth previous tokens as input (even if it would have predicted something different). This avoids error accumulation during training and is called "teacher forcing". At inference, the model's own predictions are fed back.
Exercise 1 — Implement DecoderBlock with Cross-Attention¶
Guided Exercise
Extend the TransformerBlock from Chapter 5 to include cross-attention.
import torch
import torch.nn as nn
class DecoderBlock(nn.Module):
    """Pre-norm decoder block: causal self-attention, cross-attention, FFN.

    Every sublayer reads a LayerNorm-ed copy of the running hidden state and
    its result is added back onto that state (residual connection).
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int | None = None,
                 dropout: float = 0.1, max_len: int = 512):
        """Build the three sublayers and their LayerNorms.

        Args:
            d_model: Hidden size shared by every sublayer.
            n_heads: Number of heads handed to both attention modules.
            d_ff: Inner width of the feed-forward network; ``None`` selects
                ``4 * d_model``.
            dropout: Dropout rate handed to the attention and FFN modules.
            max_len: Maximum length handed to the causal self-attention.
        """
        super().__init__()
        ff_width = 4 * d_model if d_ff is None else d_ff

        # One LayerNorm in front of each of the three sublayers.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.self_attn = CausalMultiHeadAttention(
            d_model, n_heads, max_len=max_len, dropout=dropout
        )
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout=dropout)
        self.ff = FeedForward(d_model, ff_width, dropout=dropout)

    def forward(
        self,
        x: torch.Tensor,
        enc_out: torch.Tensor,
        tgt_mask: torch.Tensor | None = None,
        src_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the decoder hidden state through the three sublayers.

        Args:
            x: Decoder input hidden state.
            enc_out: Encoder output, used as both keys and values in
                cross-attention.
            tgt_mask: Causal mask for the decoder, forwarded to
                self-attention as ``attention_mask``.
            src_mask: Padding mask for the encoder, forwarded to
                cross-attention as ``mask``.

        Returns:
            The hidden state after all three residual updates.
        """
        # Sublayer 1: the decoder attends to its own positions.
        self_attended = self.self_attn(self.norm1(x), attention_mask=tgt_mask)
        x = x + self_attended

        # Sublayer 2: queries come from the decoder, keys/values from the
        # encoder output (which is passed in without a further norm here).
        decoder_query = self.norm2(x)
        cross_attended = self.cross_attn(
            query=decoder_query, key=enc_out, value=enc_out, mask=src_mask
        )
        x = x + cross_attended

        # Sublayer 3: position-wise feed-forward network.
        return x + self.ff(self.norm3(x))
# Minimal encoder-decoder model
class EncoderDecoder(nn.Module):
    """Small seq2seq transformer: encoder stack, decoder stack, LM head.

    Source and target each get their own token embedding and learned
    position embedding. The decoder layers receive the (normalized) encoder
    output so they can cross-attend to it.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        n_heads: int = 8,
        n_encoder_layers: int = 6,
        n_decoder_layers: int = 6,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        """Create embeddings, both layer stacks, final norms and the LM head.

        Args:
            vocab_size: Number of token ids (shared by source and target).
            d_model: Hidden size.
            n_heads: Attention heads per layer.
            n_encoder_layers: Number of encoder blocks.
            n_decoder_layers: Number of decoder blocks.
            max_len: Size of the position-embedding tables; also handed to
                every block.
            dropout: Dropout rate handed to every block.
        """
        super().__init__()

        # Token and position lookup tables, separate for each side.
        self.src_embed = nn.Embedding(vocab_size, d_model)
        self.tgt_embed = nn.Embedding(vocab_size, d_model)
        self.src_pos = nn.Embedding(max_len, d_model)
        self.tgt_pos = nn.Embedding(max_len, d_model)

        self.encoder_layers = nn.ModuleList(
            TransformerBlock(
                d_model, n_heads, causal=False, max_len=max_len, dropout=dropout
            )
            for _ in range(n_encoder_layers)
        )
        self.decoder_layers = nn.ModuleList(
            DecoderBlock(d_model, n_heads, max_len=max_len, dropout=dropout)
            for _ in range(n_decoder_layers)
        )

        # Final LayerNorm after each stack.
        self.norm_enc = nn.LayerNorm(d_model)
        self.norm_dec = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def encode(self, src: torch.Tensor, src_mask=None) -> torch.Tensor:
        """Embed ``src`` and run it through the encoder stack.

        Args:
            src: Source token ids; dimension 1 is used as the sequence length.
            src_mask: Optional mask forwarded to every encoder layer as
                ``attention_mask``.

        Returns:
            The encoder hidden states after the final LayerNorm.
        """
        positions = torch.arange(src.shape[1], device=src.device)
        hidden = self.src_embed(src) + self.src_pos(positions)
        for block in self.encoder_layers:
            hidden = block(hidden, attention_mask=src_mask)
        return self.norm_enc(hidden)

    def decode(self, tgt: torch.Tensor, enc_out: torch.Tensor,
               tgt_mask=None, src_mask=None) -> torch.Tensor:
        """Embed ``tgt`` and run it through the decoder stack.

        Args:
            tgt: Target token ids; dimension 1 is used as the sequence length.
            enc_out: Encoder output handed to every decoder layer.
            tgt_mask: Optional mask forwarded to every decoder layer.
            src_mask: Optional mask forwarded to every decoder layer.

        Returns:
            The decoder hidden states after the final LayerNorm.
        """
        positions = torch.arange(tgt.shape[1], device=tgt.device)
        hidden = self.tgt_embed(tgt) + self.tgt_pos(positions)
        for block in self.decoder_layers:
            hidden = block(hidden, enc_out, tgt_mask=tgt_mask, src_mask=src_mask)
        return self.norm_dec(hidden)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None) -> torch.Tensor:
        """Encode ``src``, decode ``tgt`` against it, and project to logits.

        Returns:
            Logits of shape (batch, tgt_len, vocab_size).
        """
        memory = self.encode(src, src_mask)
        decoded = self.decode(tgt, memory, tgt_mask, src_mask)
        return self.lm_head(decoded)  # (batch, tgt_len, vocab_size)
# Smoke test: push random token ids through a small model and check the
# shape of the returned logits.
model = EncoderDecoder(
    vocab_size=1000,
    d_model=128,
    n_heads=4,
    n_encoder_layers=2,
    n_decoder_layers=2,
)

# Two source sequences of length 10 and two target sequences of length 8,
# with ids drawn from [1, 999].
src = torch.randint(low=1, high=1000, size=(2, 10))
tgt = torch.randint(low=1, high=1000, size=(2, 8))

logits = model(src, tgt)
print("Encoder-decoder logits shape:", logits.shape)  # (2, 8, 1000)
Exercise 2 — Trace Cross-Attention in T5¶
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import matplotlib.pyplot as plt
# Load the pretrained T5-small tokenizer and model. Attention weights are
# requested at load time so the forward pass returns them.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained(
    "t5-small",
    output_attentions=True,
)
model.eval()

# Translation task
input_text = "translate English to French: The sky is blue and the sun is shining."
target_text = "Le ciel est bleu et le soleil brille."

inputs = tokenizer(input_text, return_tensors="pt")
targets = tokenizer(target_text, return_tensors="pt")["input_ids"]

# Single forward pass with the reference translation supplied as labels;
# no gradients are needed for inspection.
with torch.no_grad():
    outputs = model(**inputs, labels=targets)

print("T5-small output keys:", outputs.keys())

# Cross-attention weights: shape (batch, n_heads, tgt_len, src_len)
cross_attentions = outputs.cross_attentions
print(f"\nNumber of decoder layers with cross-attention: {len(cross_attentions)}")
print(f"Cross-attention shape per layer: {cross_attentions[0].shape}")

# Token strings used as axis tick labels.
src_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
tgt_tokens = tokenizer.convert_ids_to_tokens(targets[0])

# Visualize cross-attention for layer 0, head 0 (first batch element).
layer, head = 0, 0
layer_weights = cross_attentions[layer]
ca = layer_weights[0][head].numpy()  # (tgt_len, src_len)

fig, ax = plt.subplots(figsize=(12, 5))
ax.imshow(ca, cmap='Blues', aspect='auto')

# One tick per token on each axis: source along x, target along y.
ax.set_xticks(range(len(src_tokens)))
ax.set_xticklabels(src_tokens, rotation=45, ha='right', fontsize=8)
ax.set_yticks(range(len(tgt_tokens)))
ax.set_yticklabels(tgt_tokens, fontsize=9)

ax.set(
    xlabel="Source tokens (encoder)",
    ylabel="Target tokens (decoder)",
    title=f"T5 Cross-Attention — Layer {layer+1}, Head {head+1}",
)

plt.tight_layout()
plt.savefig('cross_attention.png', dpi=100)
plt.show()
Exercise 3 — Train on a Toy Sequence Reversal Task¶
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
class ReversalDataset(Dataset):
"""Reverse a sequence of integers."""
def __init__(self, n_samples: int, seq_len: int = 8, vocab_size: int = 20):
self.data = []
for _ in range(n_samples):
src = [random.randint(2, vocab_size - 1) for _ in range(seq_len)]
tgt = [1] + src[::-1] + [0] # 1=BOS, 0=EOS
self.data.append((torch.tensor(src), torch.tensor(tgt)))
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
VOCAB_SIZE = 22  # 0=EOS, 1=BOS, 2–21=tokens
SEQ_LEN = 8

# Data: 5000 training pairs and 500 validation pairs, in batches of 64.
train_set = ReversalDataset(5000, SEQ_LEN, VOCAB_SIZE)
val_set = ReversalDataset(500, SEQ_LEN, VOCAB_SIZE)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64)

# Model, optimizer and loss.
model = EncoderDecoder(
    vocab_size=VOCAB_SIZE,
    d_model=64,
    n_heads=4,
    n_encoder_layers=2,
    n_decoder_layers=2,
    max_len=32,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

for epoch in range(10):
    # ---- Training pass (teacher forcing) ----
    model.train()
    total_loss = 0
    for source_batch, target_batch in train_loader:
        decoder_input = target_batch[:, :-1]   # decoder input (without EOS)
        expected = target_batch[:, 1:]         # expected output (without BOS)

        logits = model(source_batch, decoder_input)  # (batch, tgt_len-1, vocab)
        # Collapse batch and time so each position is one classification.
        loss = criterion(logits.flatten(0, 1), expected.flatten())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

    # ---- Validation pass, every second epoch only ----
    if (epoch + 1) % 2 != 0:
        continue

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for source_batch, target_batch in val_loader:
            # Predictions are made with the ground-truth prefix as decoder
            # input, as in training.
            preds = model(source_batch, target_batch[:, :-1]).argmax(dim=-1)
            # A sequence counts only if every position matches.
            sequence_hits = (preds == target_batch[:, 1:]).all(dim=-1)
            correct += sequence_hits.sum().item()
            total += source_batch.size(0)

    print(f"Epoch {epoch+1:2d}: loss={total_loss/len(train_loader):.4f}, "
          f"exact_match={correct/total:.3f}")
Summary¶
- Encoder-decoder models process input (encoder) and output (decoder) separately.
- Cross-attention lets decoder positions attend to encoder outputs using encoder K/V.
- T5 unifies all NLP tasks as text-to-text generation using task-specific prefixes.
- Teacher forcing uses ground-truth tokens during training to avoid compounding errors.
- Cross-attention weight matrices reveal which source tokens each target position attends to.