
Chapter 9: The KV Cache — Efficient Autoregressive Inference

Learning Outcome

Understand why autoregressive decoding is expensive without caching, implement a KV cache from scratch, and measure the concrete speedup on GPU.


Concepts

The Problem: Redundant Computation

Naive autoregressive decoding recomputes keys and values for all previous tokens at every step.

For a sequence of length n with model width d_model and L layers:

  • Attention compute without a cache: \(O(n^2 \cdot d_{model} \cdot L)\) per decode step, because the full forward pass rebuilds every query-key score.
  • Memory reads: proportional to \(n^2\), since each new token pays again for all previous tokens.

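Generating a 1000-token response with naive decoding re-processes already-seen tokens about 500,000 times per layer (0 + 1 + 2 + ... + 999 = 499,500). Every one of those passes recomputes a key and a value that an earlier step had already produced.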
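The count above can be reproduced in a few lines. This is only a sketch: the helper name `redundant_kv_recomputations` is hypothetical, and it assumes generation starts from an empty prompt, so step k re-processes exactly the k tokens generated before it.

```python
def redundant_kv_recomputations(n_new_tokens: int) -> int:
    """Count how many earlier positions are re-processed across all decode steps.

    Assumes an empty prompt: step k (0-indexed) re-processes the k tokens
    that were generated before it. The count is per layer.
    """
    return sum(range(n_new_tokens))


naive = redundant_kv_recomputations(1000)
print(naive)  # 499500
```

With a cache, none of these positions is re-processed, so the redundant count drops to zero.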

The KV Cache

At each decode step, we only need to compute Q, K, V for the new token. The K and V vectors for all previous tokens are cached and reused.

Step 1 (prefill): process [t_1, t_2, t_3] → compute K/V for all 3, cache them
Step 2 (decode):  process [t_4] only → compute new K/V, append to cache, attend over all 4
Step 3 (decode):  process [t_5] only → compute new K/V, append to cache, attend over all 5
...

Per-step cost: O(n · d_model · L) instead of O(n² · d_model · L).
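The prefill and decode steps traced above can be sketched for a single attention head. The sizes, the random weights, and the helper `full_attention_last_row` are illustrative assumptions, not part of the chapter's model; the point is only that attending over cached K/V gives the same output for the newest token as recomputing everything.

```python
import math
import torch

torch.manual_seed(0)
d = 16  # head dimension (arbitrary)
W_q, W_k, W_v = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))
x = torch.randn(5, d)  # embeddings standing in for tokens t_1..t_5


def full_attention_last_row(x):
    """Recompute K/V for every token and return the output for the last token."""
    q = x[-1:] @ W_q
    K, V = x @ W_k, x @ W_v
    w = torch.softmax(q @ K.T / math.sqrt(d), dim=-1)
    return w @ V


# Prefill: compute and cache K/V for the first three tokens
K_cache, V_cache = x[:3] @ W_k, x[:3] @ W_v

# Decode: project only the new token, append to the cache, attend over all of it
for t in range(3, 5):
    new = x[t:t + 1]
    K_cache = torch.cat([K_cache, new @ W_k], dim=0)
    V_cache = torch.cat([V_cache, new @ W_v], dim=0)
    w = torch.softmax((new @ W_q) @ K_cache.T / math.sqrt(d), dim=-1)
    cached_out = w @ V_cache

# cached_out now belongs to t_5; compare with a full recomputation
max_diff = (cached_out - full_attention_last_row(x)).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

No causal mask is needed here because only the last query row is computed, and the last token may attend to every position.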

Memory Cost

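The KV cache stores, per sequence:

\[ 2 \times L \times H \times d_{head} \times S \times \text{bytes\_per\_elem} \]

Here the factor 2 covers keys and values, L is the number of layers, H the number of K/V heads, d_head the per-head dimension, and S the sequence length.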

For LLaMA 3 8B (L=32, H=8 KV heads, d_head=128, S=4096, BF16 at 2 bytes per element):

\[ 2 \times 32 \times 8 \times 128 \times 4096 \times 2 = 536{,}870{,}912 \text{ bytes} \approx 537 \text{ MB (512 MiB)} \]

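The cost grows linearly with S: at a 128k-token context (S = 131,072) the same model needs 32× as much, about 16 GiB per sequence. This is why long context windows require careful memory management (for example paged attention, as in vLLM).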
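The memory formula is easy to turn into a small calculator. The function name `kv_cache_bytes` and the `batch` parameter are additions for illustration (the formula in the text is per sequence); the model numbers are the LLaMA 3 8B values quoted above.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_elem=2, batch=1):
    """Bytes needed to hold cached keys and values for `batch` sequences."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem * batch


at_4k = kv_cache_bytes(32, 8, 128, 4096)
at_128k = kv_cache_bytes(32, 8, 128, 128 * 1024)
print(f"4k context:   {at_4k / 2**20:.0f} MiB")    # 512 MiB
print(f"128k context: {at_128k / 2**30:.0f} GiB")  # 16 GiB
```

Because the cache is held per sequence, serving a batch multiplies these figures by the batch size.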

HuggingFace past_key_values

HuggingFace returns the KV cache as past_key_values:

outputs = model(input_ids, use_cache=True)
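past_kv = outputs.past_key_values  # per-layer (K, V) pairs; recent transformers releases wrap them in a Cache object
# Shape of each tensor: (batch, n_kv_heads, seq_len, d_head); n_kv_heads equals n_heads for MHA models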

At each subsequent step:

outputs = model(next_token_ids, past_key_values=past_kv, use_cache=True)
past_kv = outputs.past_key_values  # now one position longer

Grouped-Query Attention (GQA)

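Many recent models, including LLaMA 2 70B, LLaMA 3, and Mistral, use fewer K/V heads than Q heads. With n_heads query heads split into G groups, each group of n_heads/G Q heads shares one K/V head, so only G K/V heads are cached.

  • Multi-Head Attention (MHA): n_heads query heads, n_heads K/V heads (G = n_heads).
  • Multi-Query Attention (MQA): n_heads query heads, 1 K/V head (G = 1).
  • Grouped-Query Attention (GQA): n_heads query heads, G K/V heads, with 1 < G < n_heads.

GQA reduces KV cache size by a factor of n_heads/G relative to MHA, with minimal quality loss. LLaMA 3 8B, for example, has 32 query heads and 8 K/V heads, so its cache is 4× smaller than an MHA equivalent.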
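To make the head sharing concrete, the sketch below maps each query head to the K/V head it reads from. The helper name is hypothetical, and it assumes consecutive query heads form a group, which is one common layout rather than the only possible one. The 32 query heads and 8 K/V heads are the LLaMA 3 8B configuration.

```python
def kv_head_for_query_head(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Index of the K/V head that a given query head reads from.

    Assumes consecutive query heads form a group (an illustrative layout).
    """
    assert n_q_heads % n_kv_heads == 0, "query heads must divide evenly into groups"
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size


n_q_heads, n_kv_heads = 32, 8
mapping = [kv_head_for_query_head(h, n_q_heads, n_kv_heads) for h in range(n_q_heads)]
print(mapping[:8])              # [0, 0, 0, 0, 1, 1, 1, 1]
print(n_q_heads // n_kv_heads)  # 4: the cache shrinks by this factor versus MHA
```

Only the 8 K/V heads are stored in the cache; the sharing happens when scores are computed, so the saving applies to memory and to the bytes read at every decode step.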


Exercise 1 — Add KV Cache to the GPT Implementation

Guided Exercise

Modify MultiHeadAttention and GPTModel to accept and return a KV cache.
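The listing in this exercise refers to a `FeedForward` module that is not defined in this chapter, presumably because it carries over from the earlier GPT implementation. If it is not at hand, a stand-in along these lines should be enough to run the exercises. The Linear, GELU, Linear, Dropout layout is an assumption, not necessarily the original definition.

```python
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear -> Dropout (assumed layout)."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Applied independently at every position; output shape equals input shape
        return self.net(x)
```

The feed-forward block works on each position in isolation, so it needs no changes for cached decoding.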

import torch
import torch.nn as nn
import math
from typing import Optional


class CachedMultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        past_kv: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, Optional[tuple]]:
        batch, seq_len, d_model = x.shape

        # Compute Q, K, V for the current input
        Q = self._split(self.W_q(x))  # (batch, n_heads, seq, d_head)
        K = self._split(self.W_k(x))
        V = self._split(self.W_v(x))

        # Append past K/V if cache is provided
        if past_kv is not None:
            past_K, past_V = past_kv
            K = torch.cat([past_K, K], dim=2)  # extend along seq dimension
            V = torch.cat([past_V, V], dim=2)

        # Save current K/V to cache
        new_kv = (K, V) if use_cache else None

        # Attention
        total_len = K.size(2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_head)

        # Causal mask: a query may attend to all cached positions and to new
        # positions up to and including its own, but not to later ones.
        # A single-token decode step needs no mask: its query is the newest position.
        if seq_len > 1:  # prefill or uncached
            causal = torch.triu(
                torch.full((seq_len, total_len), float('-inf'),
                           device=x.device, dtype=scores.dtype),  # match dtype to avoid upcasting
                diagonal=total_len - seq_len + 1,
            )
            scores = scores + causal.unsqueeze(0).unsqueeze(0)

        weights = torch.softmax(scores, dim=-1)
        weights = self.attn_drop(weights)
        out = torch.matmul(weights, V)

        return self.W_o(self._merge(out)), new_kv

    def _split(self, x):
        b, s, _ = x.shape
        return x.view(b, s, self.n_heads, self.d_head).transpose(1, 2)

    def _merge(self, x):
        b, _, s, _ = x.shape
        return x.transpose(1, 2).contiguous().view(b, s, -1)


class CachedTransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: Optional[int] = None,
                 dropout: float = 0.0):
        super().__init__()
        if d_ff is None:
            d_ff = 4 * d_model
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn = CachedMultiHeadAttention(d_model, n_heads, dropout)
        # FeedForward is not defined in this listing; it is assumed to be the
        # position-wise MLP from the earlier GPT implementation and must be in scope.
        self.ff = FeedForward(d_model, d_ff, dropout)

    def forward(self, x, past_kv=None, use_cache=False):
        attn_out, new_kv = self.attn(self.norm1(x), past_kv=past_kv, use_cache=use_cache)
        x = x + attn_out
        x = x + self.ff(self.norm2(x))
        return x, new_kv


class CachedGPT(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_len=1024):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList([
            CachedTransformerBlock(d_model, n_heads)
            for _ in range(n_layers)
        ])
        self.norm_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.token_embed.weight

    def forward(self, input_ids, past_key_values=None, use_cache=False):
        batch, seq = input_ids.shape
        start_pos = 0 if past_key_values is None else past_key_values[0][0].size(2)
        positions = torch.arange(start_pos, start_pos + seq, device=input_ids.device)

        x = self.token_embed(input_ids) + self.pos_embed(positions)

        new_kvs = []
        for i, block in enumerate(self.blocks):
            past_kv = past_key_values[i] if past_key_values is not None else None
            x, new_kv = block(x, past_kv=past_kv, use_cache=use_cache)
            new_kvs.append(new_kv)

        logits = self.lm_head(self.norm_f(x))
        return logits, (new_kvs if use_cache else None)

    @torch.no_grad()
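    def generate_cached(self, input_ids, max_new_tokens):
        """Greedy decoding that reuses the KV cache between steps."""
        if max_new_tokens <= 0:
            return input_ids
        # Prefill phase: process entire prompt at once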
        logits, past_kv = self(input_ids, use_cache=True)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        all_ids = torch.cat([input_ids, next_token], dim=1)

        # Decode phase: one token at a time
        for _ in range(max_new_tokens - 1):
            logits, past_kv = self(next_token, past_key_values=past_kv, use_cache=True)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            all_ids = torch.cat([all_ids, next_token], dim=1)

        return all_ids
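The `diagonal=total_len - seq_len + 1` offset in the causal mask is easiest to see on a small case. The sketch below builds the same mask in isolation for a hypothetical chunk of 2 new tokens arriving after 3 cached ones; the sizes are chosen only for illustration.

```python
import torch

seq_len, total_len = 2, 5  # 2 new queries; keys are 3 cached + 2 new
mask = torch.triu(
    torch.full((seq_len, total_len), float("-inf")),
    diagonal=total_len - seq_len + 1,
)

# Row 0 is the first new token (position 3): it sees keys 0..3, key 4 is masked.
# Row 1 is the second new token (position 4): it sees every key.
print(mask)
```

With no cache, `total_len == seq_len` and the offset reduces to the usual `diagonal=1` upper-triangular mask.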

Exercise 2 — Benchmark Cached vs. Uncached Decoding

import time
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Benchmarking on: {device}")

model = CachedGPT(vocab_size=1000, d_model=256, n_heads=8, n_layers=6).to(device)
model.eval()

def naive_generate(model, input_ids, n_tokens):
    """No cache: re-runs full forward at each step."""
    ids = input_ids
    for _ in range(n_tokens):
        logits, _ = model(ids, use_cache=False)
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_tok], dim=1)
    return ids

prompt = torch.randint(1, 1000, (1, 8), device=device)

print(f"\n{'Tokens':>8}  {'Naive (s)':>12}  {'Cached (s)':>12}  {'Speedup':>8}")
print("-" * 50)
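def sync():
    """Wait for queued CUDA kernels so wall-clock timings cover the real work."""
    if device.type == "cuda":
        torch.cuda.synchronize()

# Warm-up so one-time CUDA initialization is not charged to the first measurement
with torch.no_grad():
    model.generate_cached(prompt, 5)

for n_tokens in [10, 50, 100, 200]:
    # Naive (greedy decoding is deterministic, so no seeding is needed)
    sync()
    start = time.perf_counter()
    with torch.no_grad():
        naive_generate(model, prompt, n_tokens)
    sync()
    t_naive = time.perf_counter() - start

    # Cached
    sync()
    start = time.perf_counter()
    with torch.no_grad():
        model.generate_cached(prompt, n_tokens)
    sync()
    t_cached = time.perf_counter() - start

    speedup = t_naive / t_cached
    print(f"{n_tokens:>8}  {t_naive:>12.3f}  {t_cached:>12.3f}  {speedup:>7.1f}x")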

Exercise 3 — Verify KV Cache Correctness with HuggingFace

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

prompt = "The quick brown fox"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

with torch.no_grad():
    # Without cache
    out_no_cache = model(input_ids, use_cache=False)

    # With cache — prefill
    out_with_cache = model(input_ids, use_cache=True)
    past_kv = out_with_cache.past_key_values

    # Verify logits are identical
    diff = (out_no_cache.logits - out_with_cache.logits).abs().max().item()
    print(f"Logit diff (no cache vs with cache, prefill): {diff:.2e}")

    # Check that KV tensors grow by 1 each decode step
    print(f"\nKV cache shape after prefill: {past_kv[0][0].shape}")
    # (batch=1, n_heads=12, seq_len=4, d_head=64)

    # Decode one step
    next_tok = out_with_cache.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    out_step2 = model(next_tok, past_key_values=past_kv, use_cache=True)
    past_kv2 = out_step2.past_key_values

    print(f"KV cache shape after 1 decode step: {past_kv2[0][0].shape}")
    # seq_len should be 5 now — increased by 1
    assert past_kv2[0][0].shape[2] == past_kv[0][0].shape[2] + 1
    print("✓ KV cache grows by 1 position per decode step")
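The prefill comparison in this exercise runs the same computation twice, so a stricter check is to compare a cached decode step against a full recomputation over the extended sequence. The sketch below does that on a tiny, randomly initialised GPT-2 built from `GPT2Config`, so no weights are downloaded. The model sizes and the name `tiny_model` are arbitrary assumptions chosen to keep the check fast.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

torch.manual_seed(0)
# Tiny random model; sizes are arbitrary and only keep the check fast
config = GPT2Config(vocab_size=100, n_positions=32, n_embd=32, n_layer=2, n_head=2)
tiny_model = GPT2LMHeadModel(config).eval()

input_ids = torch.randint(0, 100, (1, 6))

with torch.no_grad():
    # Cached path: prefill on the first 5 tokens, then decode the 6th on its own
    prefill = tiny_model(input_ids[:, :5], use_cache=True)
    step = tiny_model(
        input_ids[:, 5:],
        past_key_values=prefill.past_key_values,
        use_cache=True,
    )

    # Reference path: one uncached forward pass over all 6 tokens
    full = tiny_model(input_ids, use_cache=False)

diff = (step.logits[:, -1] - full.logits[:, -1]).abs().max().item()
print(f"Logit diff (cached decode vs full recompute): {diff:.2e}")
```

A difference at the level of floating-point noise indicates that the cached step reproduces the last-position logits of the uncached forward pass.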

Summary

  • Without a KV cache, each generation step recomputes K/V for all previous tokens: O(n²) K/V projections over a generation of n tokens.
  • The KV cache stores past K/V tensors, so each step projects only the new token and attends over the n cached positions: O(n) per step.
  • Memory cost is proportional to n_layers × n_kv_heads × d_head × seq_len.
  • GQA (LLaMA 2 70B, LLaMA 3, Mistral) reduces KV head count to save cache memory.
  • HuggingFace implements caching via past_key_values; use_cache=True enables it.
