Chapter 5: The Transformer Block — Feed-Forward, LayerNorm, and Residuals

Learning Outcome

Implement a complete transformer block (attention + FFN + residual + normalization) and understand the design choices that make deep transformer stacks trainable.


Concepts

Position-Wise Feed-Forward Network (FFN)

After attention, each position is independently passed through a two-layer MLP:

\[ \text{FFN}(x) = \text{GELU}(x W_1 + b_1) W_2 + b_2 \]
  • Width: 4 × d_model by convention. The FFN holds roughly two-thirds of each block's parameters. For GPT-2 small (d_model=768): 768 → 3072 → 768.
  • Activation: GELU in BERT/GPT-2; SwiGLU in LLaMA (uses a gating mechanism).
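
To make the 4× width concrete, the short calculation below counts the weights and biases of one FFN at GPT-2 small's dimensions and compares them with the attention sublayer. It is only a sketch: it assumes attention consists of four biased d_model × d_model projections (Q, K, V, output), which is an assumption about the attention module rather than something stated above.

```python
# Parameter arithmetic for one block at GPT-2 small sizes (pure Python).
d_model = 768
d_ff = 4 * d_model  # 3072

# FFN: two linear layers with biases, d_model -> d_ff -> d_model
ffn_params = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)

# Attention (assumed layout): Q, K, V and output projections,
# each d_model x d_model plus a bias vector
attn_params = 4 * (d_model * d_model + d_model)

ffn_share = ffn_params / (ffn_params + attn_params)
print(f"FFN params:       {ffn_params:,}")
print(f"Attention params: {attn_params:,}")
print(f"FFN share of the two sublayers: {ffn_share:.1%}")
```

Under these assumptions the FFN accounts for about two-thirds of the sublayer parameters, which is why widening or narrowing d_ff has such a large effect on model size.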

GELU vs. SwiGLU

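GELU (Gaussian Error Linear Unit):

\[ \text{GELU}(x) = x \cdot \Phi(x) \]

where \(\Phi\) is the standard normal CDF. GELU is a smooth alternative to ReLU: it scales each input by the probability mass below it instead of applying a hard cutoff at zero.

SwiGLU (used in LLaMA, PaLM):

\[ \text{SwiGLU}(x, W, V) = \text{Swish}(xW) \odot xV \]

where \(\odot\) is the element-wise product. A SwiGLU FFN uses three projections: the gate \(W\), the up projection \(V\), and a down projection that maps the gated result back to d_model. Because of the third matrix, the hidden width is usually reduced to about 2/3 of 4 × d_model so the parameter count stays comparable to a GELU FFN; at that matched size SwiGLU is reported to reach lower perplexity.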
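
The two activations can be sketched in a few lines of PyTorch. The block below writes GELU directly from its definition and checks it against `F.gelu`, then builds a small SwiGLU feed-forward module. The class name `SwiGLUFeedForward`, its attribute names, and the choice of bias-free layers are illustrative assumptions, not taken from any particular model's code.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    # x * Phi(x), with the normal CDF Phi written via the error function
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))


class SwiGLUFeedForward(nn.Module):
    """Hypothetical SwiGLU FFN sketch: down(Swish(x W) * (x V))."""

    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_hidden, bias=False)  # W (gate)
        self.w_up = nn.Linear(d_model, d_hidden, bias=False)    # V (up)
        self.w_down = nn.Linear(d_hidden, d_model, bias=False)  # down

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.silu is Swish with beta = 1
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


x = torch.linspace(-3.0, 3.0, 7)
gelu_matches = torch.allclose(gelu_exact(x), F.gelu(x), atol=1e-6)
print("gelu_exact matches F.gelu:", gelu_matches)

d_model = 768
d_hidden = (2 * 4 * d_model) // 3  # 2/3 of the usual 4 * d_model width
swiglu = SwiGLUFeedForward(d_model, d_hidden)
swiglu_params = sum(p.numel() for p in swiglu.parameters())
gelu_ffn_weights = 2 * d_model * (4 * d_model)  # bias-free GELU FFN, for comparison
out = swiglu(torch.randn(2, 5, d_model))
print("SwiGLU params:", swiglu_params, "| GELU FFN weights:", gelu_ffn_weights)
print("SwiGLU output shape:", tuple(out.shape))
```

With the 2/3 width, the three SwiGLU matrices hold the same number of weights as the two matrices of a bias-free 4× GELU FFN, so comparisons between the two are made at equal size.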

Layer Normalization

LayerNorm normalizes across the feature dimension (not the batch dimension):

\[ \text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]

where the mean μ and variance σ² are computed over the d_model dimension for each token independently, and ε is a small constant added for numerical stability. γ and β are learnable per-dimension parameters.
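
As a quick check of the formula, the sketch below normalizes a tensor by hand and compares the result with `nn.LayerNorm`. The helper name `layer_norm_manual` is made up for this example; the variance is the biased (population) variance, and ε sits inside the square root, which is how PyTorch computes it.

```python
import torch
import torch.nn as nn


def layer_norm_manual(x, gamma, beta, eps=1e-5):
    # Statistics are taken over the last (feature) dimension, per token
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return gamma * (x - mu) / torch.sqrt(var + eps) + beta


torch.manual_seed(0)
x = torch.randn(2, 4, 16)  # (batch, seq, d_model)
ln = nn.LayerNorm(16)      # gamma initialised to 1, beta to 0

with torch.no_grad():
    manual = layer_norm_manual(x, ln.weight, ln.bias, eps=ln.eps)
    matches = torch.allclose(manual, ln(x), atol=1e-5)

per_token_mean = manual.mean(dim=-1)
per_token_std = manual.std(dim=-1, unbiased=False)
print("matches nn.LayerNorm:", matches)
print("largest |mean| per token:", per_token_mean.abs().max().item())
```

Each of the 2 × 4 token vectors ends up with mean close to 0 and standard deviation close to 1, regardless of what the other tokens in the batch contain.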

BatchNorm vs. LayerNorm:

Property                              BatchNorm         LayerNorm
Normalizes over                       Batch dimension   Feature dimension
Works with batch size 1               No                Yes
Requires running stats at inference   Yes               No
Works for variable-length sequences   Poorly            Yes
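
The first two rows of the table can be observed directly. In the sketch below, the first row of a batch is kept fixed while the other rows are replaced; LayerNorm's output for that row is unchanged, whereas BatchNorm in training mode gives a different answer because its statistics come from the whole batch. The tensor sizes are arbitrary choices for the demonstration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
ln = nn.LayerNorm(d_model)
bn = nn.BatchNorm1d(d_model)
bn.train()  # training mode: normalise with the current batch's statistics

a = torch.randn(4, d_model)
b = a.clone()
b[1:] = 5.0 * torch.randn(3, d_model) + 2.0  # change every row except row 0

with torch.no_grad():
    ln_row0_same = torch.allclose(ln(a)[0], ln(b)[0])
    bn_row0_same = torch.allclose(bn(a)[0], bn(b)[0])

    # Batch size 1: LayerNorm is fine, BatchNorm cannot estimate batch statistics
    ln_single = ln(a[:1])
    try:
        bn(a[:1])
        bn_batch1_failed = False
    except ValueError:
        bn_batch1_failed = True

print("LayerNorm row 0 unaffected by other rows:", ln_row0_same)
print("BatchNorm row 0 unaffected by other rows:", bn_row0_same)
print("BatchNorm rejects batch size 1 in training:", bn_batch1_failed)
```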

Pre-Norm vs. Post-Norm

The original transformer paper (Vaswani 2017) used post-norm:

x = LayerNorm(x + Attention(x))
x = LayerNorm(x + FFN(x))

Modern models (GPT-2, LLaMA) use pre-norm (norm before the sublayer):

x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))

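Pre-norm is more stable during training because the residual path is never normalized: each block only adds to x, so there is an identity path from the input to the final layer and gradients reach the early layers directly, even at the start of training. In post-norm, every LayerNorm sits on that path, which is why post-norm stacks typically need careful learning-rate warmup.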
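
The difference in the residual path can be made visible with a toy sublayer. In the sketch below a zero-initialised `nn.Linear` stands in for attention or the FFN (an illustrative simplification, not how real blocks are initialised). With a sublayer that contributes nothing, the pre-norm block passes x through untouched, while the post-norm block still rescales it.

```python
import torch
import torch.nn as nn

d_model = 16
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or the FFN
nn.init.zeros_(sublayer.weight)
nn.init.zeros_(sublayer.bias)


def post_norm(x: torch.Tensor) -> torch.Tensor:
    # Original transformer ordering: normalise after the residual add
    return norm(x + sublayer(x))


def pre_norm(x: torch.Tensor) -> torch.Tensor:
    # Modern ordering: normalise only the sublayer input
    return x + sublayer(norm(x))


torch.manual_seed(0)
x = 3.0 * torch.randn(2, 4, d_model) + 1.0  # deliberately not zero-mean, unit-variance

with torch.no_grad():
    pre_out = pre_norm(x)
    post_out = post_norm(x)

pre_is_identity = torch.equal(pre_out, x)
post_is_identity = torch.allclose(post_out, x, atol=1e-3)
print("pre-norm leaves x unchanged: ", pre_is_identity)
print("post-norm leaves x unchanged:", post_is_identity)
```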

Residual Connections

Residual (skip) connections add the input directly to the output:

x = x + sublayer(x)

This creates a "gradient highway": the derivative of x + sublayer(x) with respect to x is the identity plus the sublayer's Jacobian, so gradients can flow straight back through the addition without passing through the sublayer's non-linearities. In practice, very deep networks (100+ layers) are extremely hard to train without residuals.
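
A small experiment illustrates the effect. The sketch below pushes an input through 50 tanh layers, once as a plain stack and once with a residual add around each layer, and measures the gradient norm that arrives back at the input. The depth, width, and the deliberately small weight scale (std 0.01) are assumptions chosen to exaggerate vanishing gradients; real initialisations are less extreme.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_layers = 16, 50

layers = nn.ModuleList(
    [nn.Linear(d_model, d_model, bias=False) for _ in range(n_layers)]
)
for layer in layers:
    nn.init.normal_(layer.weight, std=0.01)  # deliberately small weights


def input_grad_norm(use_residual: bool) -> float:
    """Norm of d(sum of outputs)/d(input) for the 50-layer stack."""
    x = torch.randn(1, d_model, requires_grad=True)
    h = x
    for layer in layers:
        out = torch.tanh(layer(h))
        h = h + out if use_residual else out
    h.sum().backward()
    return x.grad.norm().item()


plain = input_grad_norm(use_residual=False)
residual = input_grad_norm(use_residual=True)
print(f"gradient norm at input, plain stack:    {plain:.3e}")
print(f"gradient norm at input, residual stack: {residual:.3e}")
```

In the plain stack every layer shrinks the gradient, so almost nothing reaches the input. With residuals, each layer's Jacobian is the identity plus a small term, and the gradient arrives at a usable magnitude.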


Exercise 1 — Implement a Pre-Norm TransformerBlock

Guided Exercise

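Build a complete transformer block using the attention module from Chapter 3. The code below assumes that CausalMultiHeadAttention and MultiHeadAttention from that chapter are already defined or imported in the same session.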

import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.dropout(F.gelu(self.fc1(x))))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block (GPT-2 / LLaMA style)."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int | None = None,
        dropout: float = 0.1,
        causal: bool = True,
        max_len: int = 2048,
    ):
        super().__init__()
        if d_ff is None:
            d_ff = 4 * d_model

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        if causal:
            self.attn = CausalMultiHeadAttention(
                d_model, n_heads, max_len=max_len, dropout=dropout
            )
        else:
            self.attn = MultiHeadAttention(d_model, n_heads, dropout=dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Pre-norm attention + residual
        x = x + self.attn(self.norm1(x), attention_mask=attention_mask)
        # Pre-norm FFN + residual
        x = x + self.ff(self.norm2(x))
        return x


# Test
block = TransformerBlock(d_model=256, n_heads=8, dropout=0.0)
x = torch.randn(2, 16, 256)
out = block(x)
print("TransformerBlock output shape:", out.shape)  # (2, 16, 256)

# Count parameters in one block (d_model=768, n_heads=12, d_ff=3072)
block_768 = TransformerBlock(d_model=768, n_heads=12, d_ff=3072)
n_params = sum(p.numel() for p in block_768.parameters())
print(f"\nParameters in one TransformerBlock (d=768): {n_params:,}")
# Expected: ~7.1M per block for GPT-2 small

Exercise 2 — Stack 12 Blocks and Verify Parameter Count

class SmallGPT(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        n_heads: int,
        n_layers: int,
        max_len: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, dropout=dropout, max_len=max_len)
            for _ in range(n_layers)
        ])
        self.norm_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        batch, seq_len = input_ids.shape
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.drop(self.token_embed(input_ids) + self.pos_embed(positions))
        for block in self.blocks:
            x = block(x)
        x = self.norm_f(x)
        return self.lm_head(x)  # (batch, seq, vocab_size) logits


# GPT-2 small configuration
gpt2_small = SmallGPT(
    vocab_size=50257,
    d_model=768,
    n_heads=12,
    n_layers=12,
    max_len=1024,
    dropout=0.1,
)

total_params = sum(p.numel() for p in gpt2_small.parameters())
print(f"Total parameters: {total_params:,}")
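# Prints ~163M here (assuming biased attention projections), because lm_head is
# a separate, untied matrix. GPT-2 small ties lm_head to token_embed, which
# gives ~124M (the figure quoted as 117M in the original release).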

# Quick forward pass
ids = torch.randint(0, 50257, (1, 32))
logits = gpt2_small(ids)
print(f"Logits shape: {logits.shape}")  # (1, 32, 50257)
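
The printed total can be cross-checked by hand. The sketch below adds up the parameter count of each component for the GPT-2 small configuration, again assuming four biased d_model × d_model attention projections (an assumption about the Chapter 3 module), and shows how much the total changes when the LM head shares its weights with the token embedding.

```python
# Hand count of SmallGPT parameters for the GPT-2 small configuration.
vocab_size, d_model, n_layers, max_len = 50257, 768, 12, 1024
d_ff = 4 * d_model

tok_embed = vocab_size * d_model
pos_embed = max_len * d_model
attn = 4 * (d_model * d_model + d_model)  # assumed: biased Q, K, V, output projections
ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
norms = 2 * (2 * d_model)                 # two LayerNorms, each with gamma and beta
per_block = attn + ffn + norms
final_norm = 2 * d_model
lm_head = vocab_size * d_model            # bias=False

untied_total = tok_embed + pos_embed + n_layers * per_block + final_norm + lm_head
tied_total = untied_total - lm_head       # lm_head reuses the token embedding matrix

print(f"per block:    {per_block:,}")
print(f"untied total: {untied_total:,}")
print(f"tied total:   {tied_total:,}")
print(f"all FFNs:     {n_layers * ffn:,}")
```

The untied figure corresponds to the model as written above; the tied figure corresponds to GPT-2 small with a shared embedding and output matrix.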

Exercise 3 — Load GPT-2 Weights and Verify Output

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
hf_model = GPT2LMHeadModel.from_pretrained("gpt2")
hf_model.eval()

sentence = "The transformer architecture changed natural language"
input_ids = tokenizer.encode(sentence, return_tensors="pt")

# Get HuggingFace logits
with torch.no_grad():
    hf_logits = hf_model(input_ids).logits

print("HuggingFace logits shape:", hf_logits.shape)
print("Top 5 next-token predictions:")
top5 = hf_logits[0, -1].topk(5)
for tok_id, score in zip(top5.indices, top5.values):
    # .item() converts the 0-dim tensor to a plain int before decoding
    print(f"  {tokenizer.decode([tok_id.item()])!r:20s}  logit={score.item():.3f}")

Now let's load the weights into our custom implementation:

def load_gpt2_weights(custom_model: SmallGPT, hf_model: GPT2LMHeadModel):
    """Load GPT-2 weights from HuggingFace into our custom model."""
    hf_state = hf_model.state_dict()
    custom_state = custom_model.state_dict()

    mapping = {
        "token_embed.weight": "transformer.wte.weight",
        "pos_embed.weight": "transformer.wpe.weight",
        "norm_f.weight": "transformer.ln_f.weight",
        "norm_f.bias": "transformer.ln_f.bias",
    }

    # Map transformer blocks
    for i in range(12):
        prefix = f"blocks.{i}"
        hf_prefix = f"transformer.h.{i}"
        block_mapping = {
            f"{prefix}.norm1.weight": f"{hf_prefix}.ln_1.weight",
            f"{prefix}.norm1.bias": f"{hf_prefix}.ln_1.bias",
            f"{prefix}.norm2.weight": f"{hf_prefix}.ln_2.weight",
            f"{prefix}.norm2.bias": f"{hf_prefix}.ln_2.bias",
            f"{prefix}.ff.fc1.weight": f"{hf_prefix}.mlp.c_fc.weight",
            f"{prefix}.ff.fc1.bias": f"{hf_prefix}.mlp.c_fc.bias",
            f"{prefix}.ff.fc2.weight": f"{hf_prefix}.mlp.c_proj.weight",
            f"{prefix}.ff.fc2.bias": f"{hf_prefix}.mlp.c_proj.bias",
        }
        mapping.update(block_mapping)

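    # Load mapped weights. GPT-2 stores its MLP projections as Conv1D layers with
    # weight shape (in, out); nn.Linear expects (out, in), so those are transposed.
    for custom_key, hf_key in mapping.items():
        if custom_key in custom_state and hf_key in hf_state:
            w = hf_state[hf_key]
            if ".mlp." in hf_key and hf_key.endswith(".weight"):
                w = w.T
            custom_state[custom_key].copy_(w)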

    # Attention weights (GPT-2 uses a fused QKV projection)
    # c_attn.weight is (d_model, 3*d_model); we split it into W_q, W_k, W_v
    for i in range(12):
        prefix = f"blocks.{i}"
        hf_prefix = f"transformer.h.{i}"
        c_attn_w = hf_state[f"{hf_prefix}.attn.c_attn.weight"]  # (768, 2304)
        c_attn_b = hf_state[f"{hf_prefix}.attn.c_attn.bias"]    # (2304,)
        # Note: GPT-2 uses Conv1D (transposed), split along last dim
        W_q, W_k, W_v = c_attn_w.split(768, dim=1)
        b_q, b_k, b_v = c_attn_b.split(768, dim=0)

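        # Our Linear has shape (out, in); GPT-2 Conv1D is (in, out) → need to transpose
        custom_state[f"{prefix}.attn.W_q.weight"].copy_(W_q.T)
        custom_state[f"{prefix}.attn.W_k.weight"].copy_(W_k.T)
        custom_state[f"{prefix}.attn.W_v.weight"].copy_(W_v.T)

        # Copy the biases too, if the Chapter 3 projections were built with bias=True
        for name, b in (("W_q", b_q), ("W_k", b_k), ("W_v", b_v)):
            bias_key = f"{prefix}.attn.{name}.bias"
            if bias_key in custom_state:
                custom_state[bias_key].copy_(b)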

    custom_model.load_state_dict(custom_state)
    print("Weights loaded successfully.")


# This weight loading is illustrative; a full implementation would also handle
# the output projection (c_proj) and LM head weight tying.
print("Weight loading structure demonstrated above.")
print("Full implementation left as exercise to complete in session.")
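
The split-and-transpose step is easy to get wrong, so it helps to test it in isolation. The sketch below uses small random tensors as a stand-in for a Conv1D-style fused QKV projection (weight of shape (in, 3 × out), applied as x @ W + b) and checks that an `nn.Linear` loaded with the transposed query slice reproduces the query output. No pretrained weights are involved, and the sizes are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 8
x = torch.randn(2, 5, d_model)

# Stand-in for a fused Conv1D projection: weight is (in, 3 * out)
fused_w = torch.randn(d_model, 3 * d_model)
fused_b = torch.randn(3 * d_model)
q_ref, k_ref, v_ref = (x @ fused_w + fused_b).split(d_model, dim=-1)

# Split along the output dimension, as in the loader above
W_q, W_k, W_v = fused_w.split(d_model, dim=1)
b_q, b_k, b_v = fused_b.split(d_model, dim=0)

proj_q = nn.Linear(d_model, d_model)
with torch.no_grad():
    proj_q.weight.copy_(W_q.T)  # nn.Linear stores (out, in)
    proj_q.bias.copy_(b_q)
    q_ours = proj_q(x)
    q_no_transpose = F.linear(x, W_q, b_q)  # same shapes, wrong values

q_matches = torch.allclose(q_ours, q_ref, atol=1e-5)
q_wrong_matches = torch.allclose(q_no_transpose, q_ref, atol=1e-5)
print("transposed slice reproduces Q:  ", q_matches)
print("untransposed slice reproduces Q:", q_wrong_matches)
```

Because the query slice is square, forgetting the transpose raises no shape error; it silently produces wrong activations, which is why comparing outputs against the reference model is worthwhile.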

Summary

  • The transformer block has two sublayers: multi-head attention and FFN, each wrapped with a residual connection and layer normalization.
  • Pre-norm (modern) places LayerNorm before each sublayer; post-norm (original) places it after.
  • The FFN uses 4 × d_model hidden dimension and GELU (or SwiGLU in modern models).
  • Residual connections are critical for training deep networks — they create a gradient highway that bypasses the non-linearities.
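  • A 12-layer GPT-2 small has ~124M parameters with tied input and output embeddings (often quoted as 117M); the 12 FFN layers hold about 57M of them, the largest single share.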
