
Chapter 8: Encoder-Decoder Architecture — T5 and Cross-Attention

Learning Outcome

Understand the encoder-decoder transformer (seq2seq) design, implement cross-attention, and see how tasks like translation, summarization, and question-answering map to this architecture.


Concepts

Encoder-Decoder Overview

In a seq2seq model, the encoder processes the full input sequence bidirectionally. The decoder generates the output autoregressively, attending to the encoder's output at each step via cross-attention.

Input:  "translate English to French: The sky is blue"
               ↓
        Encoder (bidirectional self-attention)
               ↓
        Encoder hidden states: (seq_in, d_model)
               ↓  keys and values for cross-attention
        Decoder (causal self-attn + cross-attn)  ← previously generated tokens
               ↓
Output: "Le ciel est bleu"

Cross-Attention

Cross-attention differs from self-attention in one key way:

  • Queries come from the decoder's hidden states.
  • Keys and values come from the encoder's output.

\[ \text{CrossAttention}(Q_{dec}, K_{enc}, V_{enc}) = \text{softmax}\!\left(\frac{Q_{dec} K_{enc}^\top}{\sqrt{d_k}}\right) V_{enc} \]

This allows each decoder position to "look up" relevant encoder context. For example, when generating "bleu", a trained translation model will typically place much of its attention weight on "blue" in the input.
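The formula can be made concrete with a minimal sketch of multi-head cross-attention in PyTorch. The class name `CrossAttention` and the boolean `src_pad_mask` convention (True marks a padding token) are illustrative choices, not the Chapter 5 API. Unlike the decoder's self-attention, no causal mask is applied: every decoder position may attend to every non-padding source position.

```python
import math
from typing import Optional

import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    """Minimal multi-head cross-attention: queries from the decoder, keys/values from the encoder."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)  # applied to decoder states
        self.k_proj = nn.Linear(d_model, d_model)  # applied to encoder output
        self.v_proj = nn.Linear(d_model, d_model)  # applied to encoder output
        self.out_proj = nn.Linear(d_model, d_model)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        b, t, _ = x.shape
        return x.view(b, t, self.n_heads, self.d_k).transpose(1, 2)  # (b, h, t, d_k)

    def forward(self, dec: torch.Tensor, enc: torch.Tensor,
                src_pad_mask: Optional[torch.Tensor] = None):
        q = self._split_heads(self.q_proj(dec))  # (b, h, tgt_len, d_k)
        k = self._split_heads(self.k_proj(enc))  # (b, h, src_len, d_k)
        v = self._split_heads(self.v_proj(enc))  # (b, h, src_len, d_k)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)  # (b, h, tgt_len, src_len)
        if src_pad_mask is not None:
            # src_pad_mask: (b, src_len), True where the source token is padding
            scores = scores.masked_fill(src_pad_mask[:, None, None, :], float("-inf"))
        weights = scores.softmax(dim=-1)  # no causal mask: all source positions are visible
        out = (weights @ v).transpose(1, 2).reshape(dec.size(0), dec.size(1), -1)
        return self.out_proj(out), weights


attn = CrossAttention(d_model=32, n_heads=4)
dec = torch.randn(2, 5, 32)  # 5 decoder positions
enc = torch.randn(2, 7, 32)  # 7 encoder positions
out, w = attn(dec, enc)
print(out.shape, w.shape)  # torch.Size([2, 5, 32]) torch.Size([2, 4, 5, 7])
```

The output keeps the decoder's length (5), while each row of the weight matrix is a distribution over the encoder's 7 positions.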

Decoder Block Structure

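Each decoder block has three sublayers, two of which are attention:

  1. Causal self-attention over the decoder's own (so-far-generated) tokens.
  2. Cross-attention over the encoder's output. The keys and values come from the encoder and stay the same at every decoding step; only the queries change.
  3. A position-wise feed-forward network (FFN).

Each sublayer is wrapped in a pre-norm residual connection: x = x + Sublayer(LayerNorm(x)).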
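For readers without the Chapter 5 classes at hand, the same three-sublayer structure can be sketched with PyTorch's built-in `nn.MultiheadAttention`. This is an illustrative stand-in, not the Exercise 1 implementation: the class name `PreNormDecoderBlock`, the GELU feed-forward, and the absence of residual dropout are assumptions of this sketch.

```python
from typing import Optional

import torch
import torch.nn as nn


class PreNormDecoderBlock(nn.Module):
    """Decoder block: causal self-attn -> cross-attn -> FFN, each as x + sublayer(LayerNorm(x))."""

    def __init__(self, d_model: int, n_heads: int, d_ff: Optional[int] = None,
                 dropout: float = 0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model)
        )

    def forward(self, x: torch.Tensor, enc_out: torch.Tensor,
                src_pad_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        t = x.size(1)
        # Boolean mask: True above the diagonal means position j > i is NOT visible to i.
        causal = torch.triu(torch.ones(t, t, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h, attn_mask=causal, need_weights=False)[0]
        h = self.norm2(x)
        # Q from the decoder; K and V are the same encoder output at every decoding step.
        x = x + self.cross_attn(h, enc_out, enc_out, key_padding_mask=src_pad_mask,
                                need_weights=False)[0]
        return x + self.ff(self.norm3(x))


blk = PreNormDecoderBlock(d_model=32, n_heads=4).eval()
y = blk(torch.randn(2, 6, 32), torch.randn(2, 9, 32))
print(y.shape)  # torch.Size([2, 6, 32])
```

Because of the causal mask, changing a later decoder token cannot change the outputs at earlier positions, while every position still sees the full encoder output.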

T5: Text-to-Text Transfer Transformer

T5 frames every NLP task as text-in → text-out:

Task             Input                                  Output
---------------  -------------------------------------  --------------------
Translation      translate English to French: ...       ...
Summarization    summarize: ...                         ...
Classification   sst2 sentence: ...                     positive or negative
QA               question: ... context: ...             answer string
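As a small illustration of the text-to-text framing, the hypothetical helper below turns structured examples into (input, target) string pairs using the prefixes from the table. The function name and the field names are made up for this sketch; the point is that even a class label becomes a target word the decoder has to generate.

```python
from typing import Tuple


# Hypothetical helper: field names (text, target, sentence, label, ...) are illustrative.
def to_text_to_text(task: str, **f) -> Tuple[str, str]:
    """Return (input_text, target_text) for a few T5-style tasks."""
    if task == "translation":
        return f"translate English to French: {f['text']}", f["target"]
    if task == "summarization":
        return f"summarize: {f['text']}", f["summary"]
    if task == "classification":
        # SST-2 labels: 1 = positive, 0 = negative; the label is emitted as a word.
        return f"sst2 sentence: {f['sentence']}", "positive" if f["label"] == 1 else "negative"
    if task == "qa":
        return f"question: {f['question']} context: {f['context']}", f["answer"]
    raise ValueError(f"unknown task: {task}")


print(to_text_to_text("classification", sentence="a gripping, well-acted film", label=1))
# ('sst2 sentence: a gripping, well-acted film', 'positive')
```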

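T5 uses relative position biases instead of absolute positional embeddings: each attention logit gets a learned scalar bias that depends only on the (bucketed) relative distance between the query and key positions. Each head has its own biases, and the bias table is shared across layers. (The EncoderDecoder model in Exercise 1 keeps learned absolute position embeddings for simplicity.)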
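A minimal sketch of the idea, modeled on the bidirectional (encoder-side) bucketing in Hugging Face's T5 implementation, with T5's default of 32 buckets and a maximum distance of 128. The names `relative_position_bucket` and `RelativePositionBias` are illustrative, and details such as the decoder's unidirectional bucketing are omitted.

```python
import math

import torch
import torch.nn as nn


def relative_position_bucket(rel_pos: torch.Tensor, num_buckets: int = 32,
                             max_distance: int = 128) -> torch.Tensor:
    """Map signed offsets (key_pos - query_pos) to bucket ids (bidirectional case).

    Half of the buckets hold positive offsets, half hold zero/negative ones.
    Within each half, small distances get one bucket each; larger distances share
    log-spaced buckets, and everything far enough away lands in the last bucket.
    """
    num_buckets //= 2
    buckets = (rel_pos > 0).long() * num_buckets  # offset into the "key is ahead" half
    n = rel_pos.abs()
    max_exact = num_buckets // 2
    is_small = n < max_exact
    # Log-spaced bucket index for large distances (clamp avoids log(0); small
    # distances are taken directly from `n` by torch.where below).
    large = max_exact + (
        torch.log(n.float().clamp(min=1) / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    large = large.clamp(max=num_buckets - 1)
    return buckets + torch.where(is_small, n, large)


class RelativePositionBias(nn.Module):
    """One learned scalar per (bucket, head), added to the attention logits."""

    def __init__(self, n_heads: int, num_buckets: int = 32, max_distance: int = 128):
        super().__init__()
        self.num_buckets, self.max_distance = num_buckets, max_distance
        self.table = nn.Embedding(num_buckets, n_heads)

    def forward(self, q_len: int, k_len: int) -> torch.Tensor:
        q_pos = torch.arange(q_len)[:, None]
        k_pos = torch.arange(k_len)[None, :]
        buckets = relative_position_bucket(k_pos - q_pos, self.num_buckets, self.max_distance)
        return self.table(buckets).permute(2, 0, 1)  # (n_heads, q_len, k_len)


# Usage: attention scores of shape (batch, n_heads, q_len, k_len); the bias broadcasts over batch.
rpb = RelativePositionBias(n_heads=4)
scores = torch.randn(2, 4, 5, 7) + rpb(5, 7)
print(scores.shape)  # torch.Size([2, 4, 5, 7])
```

Since the bias depends only on key_pos - query_pos, every diagonal of the (q_len, k_len) bias matrix holds a single value per head.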

Teacher Forcing

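During training, the decoder receives the ground-truth previous tokens as input, even where it would have predicted something different. This is called "teacher forcing". It keeps an early mistake from derailing the rest of the sequence during training, and because every target prefix is known in advance, all positions can be trained in parallel in a single forward pass (the causal mask stops each position from seeing its own target). At inference, the model's own predictions are fed back in, one token at a time.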
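The two regimes can be contrasted in a short sketch. `teacher_forcing_split` mirrors the one-position shift used in Exercise 3, and `greedy_decode` feeds predictions back in. Both helper names, the `step_fn` callable, and the toy model are hypothetical; BOS=1 and EOS=0 follow Exercise 3's convention.

```python
import torch
import torch.nn.functional as F

BOS, EOS = 1, 0  # token ids borrowed from Exercise 3's convention (an assumption here)


def teacher_forcing_split(tgt: torch.Tensor):
    """Training: the decoder reads the ground-truth prefix and predicts the next token."""
    return tgt[:, :-1], tgt[:, 1:]  # (decoder input, labels), offset by one position


def greedy_decode(step_fn, batch_size: int, max_new_tokens: int) -> torch.Tensor:
    """Inference: feed the model's own argmax predictions back in, one token at a time.

    step_fn(prefix) must return logits of shape (batch, prefix_len, vocab); it stands in
    for something like `model(src, prefix)` with the source already fixed.
    """
    out = torch.full((batch_size, 1), BOS, dtype=torch.long)
    for _ in range(max_new_tokens):
        next_tok = step_fn(out)[:, -1].argmax(dim=-1, keepdim=True)  # (batch, 1)
        out = torch.cat([out, next_tok], dim=1)
        if (next_tok == EOS).all():  # stop once every sequence has emitted EOS
            break
    return out


tgt = torch.tensor([[BOS, 5, 6, 7, EOS]])
dec_in, labels = teacher_forcing_split(tgt)
print(dec_in.tolist(), labels.tolist())  # [[1, 5, 6, 7]] [[5, 6, 7, 0]]


def toy_step(prefix: torch.Tensor) -> torch.Tensor:
    # Hypothetical "model" that always predicts (last token + 1) mod 10.
    return F.one_hot((prefix + 1) % 10, num_classes=10).float()


print(greedy_decode(toy_step, batch_size=1, max_new_tokens=20).tolist())
# [[1, 2, 3, 4, 5, 6, 7, 8, 9, 0]]
```

For simplicity the loop re-runs the decoder over the whole prefix at every step and keeps generating for sequences that have already emitted EOS; a production decoder would cache keys and values and mask finished sequences.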


Exercise 1 — Implement DecoderBlock with Cross-Attention

Guided Exercise

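Extend the TransformerBlock from Chapter 5 to include cross-attention. The code below assumes that CausalMultiHeadAttention, MultiHeadAttention, FeedForward, and TransformerBlock from Chapter 5 are already defined in scope.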

import torch
import torch.nn as nn


class DecoderBlock(nn.Module):
    """Transformer decoder block with causal self-attention + cross-attention."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int | None = None,
                 dropout: float = 0.1, max_len: int = 512):
        super().__init__()
        if d_ff is None:
            d_ff = 4 * d_model

        # Three sublayers
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.self_attn = CausalMultiHeadAttention(d_model, n_heads,
                                                   max_len=max_len, dropout=dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout=dropout)
        self.ff = FeedForward(d_model, d_ff, dropout=dropout)

    def forward(
        self,
        x: torch.Tensor,                           # decoder input
        enc_out: torch.Tensor,                     # encoder output (K, V source)
        tgt_mask: torch.Tensor | None = None,      # causal mask for decoder
        src_mask: torch.Tensor | None = None,      # padding mask for encoder
    ) -> torch.Tensor:
        # 1. Causal self-attention
        x = x + self.self_attn(self.norm1(x), attention_mask=tgt_mask)
        # 2. Cross-attention: Q from decoder, K/V from encoder
        x = x + self.cross_attn(
            query=self.norm2(x),
            key=enc_out,
            value=enc_out,
            mask=src_mask,
        )
        # 3. FFN
        x = x + self.ff(self.norm3(x))
        return x


# Minimal encoder-decoder model
class EncoderDecoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        n_heads: int = 8,
        n_encoder_layers: int = 6,
        n_decoder_layers: int = 6,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.src_embed = nn.Embedding(vocab_size, d_model)
        self.tgt_embed = nn.Embedding(vocab_size, d_model)
        self.src_pos = nn.Embedding(max_len, d_model)
        self.tgt_pos = nn.Embedding(max_len, d_model)

        self.encoder_layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, causal=False, max_len=max_len,
                             dropout=dropout)
            for _ in range(n_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderBlock(d_model, n_heads, max_len=max_len, dropout=dropout)
            for _ in range(n_decoder_layers)
        ])

        self.norm_enc = nn.LayerNorm(d_model)
        self.norm_dec = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def encode(self, src: torch.Tensor, src_mask=None) -> torch.Tensor:
        pos = torch.arange(src.size(1), device=src.device)
        x = self.src_embed(src) + self.src_pos(pos)
        for layer in self.encoder_layers:
            x = layer(x, attention_mask=src_mask)
        return self.norm_enc(x)

    def decode(self, tgt: torch.Tensor, enc_out: torch.Tensor,
               tgt_mask=None, src_mask=None) -> torch.Tensor:
        pos = torch.arange(tgt.size(1), device=tgt.device)
        x = self.tgt_embed(tgt) + self.tgt_pos(pos)
        for layer in self.decoder_layers:
            x = layer(x, enc_out, tgt_mask=tgt_mask, src_mask=src_mask)
        return self.norm_dec(x)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None) -> torch.Tensor:
        enc_out = self.encode(src, src_mask)
        dec_out = self.decode(tgt, enc_out, tgt_mask, src_mask)
        return self.lm_head(dec_out)  # (batch, tgt_len, vocab_size)


# Test
model = EncoderDecoder(vocab_size=1000, d_model=128, n_heads=4,
                       n_encoder_layers=2, n_decoder_layers=2)
src = torch.randint(1, 1000, (2, 10))
tgt = torch.randint(1, 1000, (2, 8))
logits = model(src, tgt)
print("Encoder-decoder logits shape:", logits.shape)  # (2, 8, 1000)

Exercise 2 — Trace Cross-Attention in T5

from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import matplotlib.pyplot as plt

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small", output_attentions=True)
model.eval()

# Translation task
input_text = "translate English to French: The sky is blue and the sun is shining."
target_text = "Le ciel est bleu et le soleil brille."

inputs = tokenizer(input_text, return_tensors="pt")
targets = tokenizer(target_text, return_tensors="pt").input_ids

with torch.no_grad():
    outputs = model(**inputs, labels=targets)

print("T5-small output keys:", outputs.keys())

# Cross-attention weights: shape (batch, n_heads, tgt_len, src_len)
cross_attentions = outputs.cross_attentions
print(f"\nNumber of decoder layers with cross-attention: {len(cross_attentions)}")
print(f"Cross-attention shape per layer: {cross_attentions[0].shape}")

# Visualize cross-attention for layer 0, head 0
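src_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
# With labels=..., T5 builds the decoder input by shifting the labels right
# (prepending the pad token), so row i is the step that predicts label token i.
tgt_tokens = tokenizer.convert_ids_to_tokens(targets[0])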

layer, head = 0, 0
ca = cross_attentions[layer][0, head].numpy()  # (tgt_len, src_len)

fig, ax = plt.subplots(figsize=(12, 5))
ax.imshow(ca, cmap='Blues', aspect='auto')
ax.set_xticks(range(len(src_tokens)))
ax.set_yticks(range(len(tgt_tokens)))
ax.set_xticklabels(src_tokens, rotation=45, ha='right', fontsize=8)
ax.set_yticklabels(tgt_tokens, fontsize=9)
ax.set_xlabel("Source tokens (encoder)")
ax.set_ylabel("Target tokens (decoder)")
ax.set_title(f"T5 Cross-Attention — Layer {layer+1}, Head {head+1}")
plt.tight_layout()
plt.savefig('cross_attention.png', dpi=100)
plt.show()
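A single head in a single layer can be noisy. One rough way to summarize the heatmaps is to average the weights over all layers and heads and report, for each target token, the most-attended source token. The helper name `top_source_per_target` is hypothetical, and averaging is a simplification (individual heads often behave quite differently), so treat the output as a diagnostic rather than a word alignment. It is checked here on synthetic weights; with the Exercise 2 variables in scope it could be called as `top_source_per_target(cross_attentions, src_tokens, tgt_tokens)`.

```python
import torch


def top_source_per_target(cross_attentions, src_tokens, tgt_tokens, batch_index=0):
    """For each target position, return the source token with the highest attention
    after averaging the weights over all layers and heads (a rough heuristic)."""
    # cross_attentions: sequence of tensors, each (batch, n_heads, tgt_len, src_len)
    stacked = torch.stack([a[batch_index] for a in cross_attentions])  # (layers, heads, tgt, src)
    avg = stacked.mean(dim=(0, 1))  # (tgt_len, src_len)
    best = avg.argmax(dim=-1).tolist()
    return [(t, src_tokens[j]) for t, j in zip(tgt_tokens, best)]


# Synthetic check: 2 layers, 2 heads, 2 target positions x 3 source positions.
w = torch.zeros(1, 2, 2, 3)
w[..., 0, 2] = 1.0  # target position 0 attends to source position 2
w[..., 1, 0] = 1.0  # target position 1 attends to source position 0
print(top_source_per_target([w, w], ["a", "b", "c"], ["x", "y"]))
# [('x', 'c'), ('y', 'a')]
```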

Exercise 3 — Train on a Toy Sequence Reversal Task

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random

class ReversalDataset(Dataset):
    """Reverse a sequence of integers."""
    def __init__(self, n_samples: int, seq_len: int = 8, vocab_size: int = 20):
        self.data = []
        for _ in range(n_samples):
            src = [random.randint(2, vocab_size - 1) for _ in range(seq_len)]
            tgt = [1] + src[::-1] + [0]  # 1=BOS, 0=EOS
            self.data.append((torch.tensor(src), torch.tensor(tgt)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


VOCAB_SIZE = 22  # 0=EOS, 1=BOS, 2–21=tokens
SEQ_LEN = 8

train_set = ReversalDataset(5000, SEQ_LEN, VOCAB_SIZE)
val_set   = ReversalDataset(500, SEQ_LEN, VOCAB_SIZE)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_set, batch_size=64)

model = EncoderDecoder(vocab_size=VOCAB_SIZE, d_model=64, n_heads=4,
                       n_encoder_layers=2, n_decoder_layers=2, max_len=32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

for epoch in range(10):
    model.train()
    total_loss = 0
    for src, tgt in train_loader:
        tgt_in  = tgt[:, :-1]  # decoder input (without EOS)
        tgt_out = tgt[:, 1:]   # expected output (without BOS)

        logits = model(src, tgt_in)  # (batch, tgt_len-1, vocab)
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

    if (epoch + 1) % 2 == 0:
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for src, tgt in val_loader:
                tgt_in = tgt[:, :-1]
                tgt_out = tgt[:, 1:]
                logits = model(src, tgt_in)
                preds = logits.argmax(dim=-1)
                correct += (preds == tgt_out).all(dim=-1).sum().item()
                total += src.size(0)
        print(f"Epoch {epoch+1:2d}: loss={total_loss/len(train_loader):.4f}, "
              f"exact_match={correct/total:.3f}")

Summary

  • Encoder-decoder models encode the input bidirectionally in one stack and generate the output autoregressively in a separate decoder stack.
  • Cross-attention lets decoder positions attend to encoder outputs using encoder K/V.
  • T5 unifies all NLP tasks as text-to-text generation using task-specific prefixes.
  • Teacher forcing uses ground-truth tokens during training to avoid compounding errors.
  • Cross-attention weight matrices reveal which source tokens each target position attends to.
