Chapter 9: The KV Cache — Efficient Autoregressive Inference¶
Learning Outcome¶
Understand why autoregressive decoding is expensive without caching, implement a KV cache from scratch, and measure the concrete speedup on GPU.
Concepts¶
The Problem: Redundant Computation¶
Naive autoregressive decoding recomputes keys and values for all previous tokens at every step.
For a sequence of length n with d_model dimensions and L layers:
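- Attention compute without cache: \(O(n^2 \cdot d_{model} \cdot L)\) per decode step (every step re-runs attention over the whole sequence), i.e. \(O(n^3 \cdot d_{model} \cdot L)\) summed over all n steps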
- Memory reads: proportional to \(n^2\) — each token pays for all previous tokens
Generating a 1000-token response with naive decoding computes ~500,000 attention operations (summing 1+2+...+999 = 499,500).
The KV Cache¶
At each decode step, we only need to compute Q, K, V for the new token. The K and V vectors for all previous tokens are cached and reused.
Step 1 (prefill): process [t_1, t_2, t_3] → compute K/V for all 3, cache them
Step 2 (decode): process [t_4] only → compute new K/V, append to cache, attend over all 4
Step 3 (decode): process [t_5] only → compute new K/V, append to cache, attend over all 5
...
Per-step cost: O(n · d_model · L) instead of O(n² · d_model · L).
Memory Cost¶
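The KV cache stores, per sequence: \[ 2 \times L \times H \times d_{head} \times S \times \text{bytes\_per\_elem} \] where the leading 2 accounts for storing both K and V, L is the number of layers, H the number of K/V heads, d_head the per-head dimension, and S the sequence length.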
For LLaMA 3 8B (L=32, H=8 KV heads, d_head=128, S=4096, BF16 = 2 bytes per element): \[ 2 \times 32 \times 8 \times 128 \times 4096 \times 2 = 536{,}870{,}912 \text{ bytes} \approx 537 \text{ MB (512 MiB)} \]
This is why long context windows (128k tokens) require careful memory management (paged attention, as in vLLM).
HuggingFace past_key_values¶
HuggingFace returns the KV cache as past_key_values:
# Prefill: run the prompt through the model and ask it to return the KV cache.
outputs = model(input_ids, use_cache=True)
past_kv = outputs.past_key_values # tuple of (K, V) per layer
# Shape of each tensor: (batch, n_heads, seq_len, d_head)
# NOTE(review): newer transformers releases may return a Cache object instead
# of a plain tuple — confirm against the installed version.
At each subsequent step:
# Decode: feed only the newest token(s) and hand the previous cache back in.
outputs = model(next_token_ids, past_key_values=past_kv, use_cache=True)
past_kv = outputs.past_key_values # now one position longer
Grouped-Query Attention (GQA)¶
LLaMA 2/3, Mistral, and others use fewer K/V heads than Q heads.
If each K/V head is shared by a group of G query heads, a model with H query heads needs only H/G K/V heads.
- Multi-Head Attention (MHA): Q=K=V=H heads.
- Multi-Query Attention (MQA): Q=H, K=V=1 head.
- Grouped-Query Attention (GQA): Q=H, K=V=H/G heads.
GQA reduces KV cache size by G× with minimal quality loss.
Exercise 1 — Add KV Cache to the GPT Implementation¶
Guided Exercise
Modify MultiHeadAttention and GPTModel to accept and return a KV cache.
import torch
import torch.nn as nn
import math
from typing import Optional
class CachedMultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with an optional key/value cache.

    On every call, Q, K and V are projected for the tokens in ``x`` only.
    When ``past_kv`` is supplied, the cached K/V tensors are prepended along
    the sequence axis so the new queries attend over the full history without
    re-projecting it.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.0):
        """Create the projections.

        Args:
            d_model: Model width; must be divisible by ``n_heads``.
            n_heads: Number of attention heads.
            dropout: Dropout probability applied to the attention weights.
        """
        super().__init__()
        assert d_model % n_heads == 0, (
            f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        )
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        past_kv: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
        """Attend from the tokens in ``x`` to themselves and any cached tokens.

        Args:
            x: Input of shape (batch, seq, d_model) holding only the new tokens.
            past_kv: Optional ``(K, V)`` pair from a previous call, each of
                shape (batch, n_heads, past_len, d_head).
            use_cache: When true, also return the extended ``(K, V)`` pair.

        Returns:
            ``(output, new_kv)`` where ``output`` has shape
            (batch, seq, d_model) and ``new_kv`` is the extended cache, or
            ``None`` when ``use_cache`` is false.
        """
        # Project only the current input: (batch, n_heads, seq, d_head).
        Q = self._split(self.W_q(x))
        K = self._split(self.W_k(x))
        V = self._split(self.W_v(x))
        seq_len = Q.size(2)

        # Prepend cached keys/values along the sequence axis.
        if past_kv is not None:
            past_K, past_V = past_kv
            K = torch.cat([past_K, K], dim=2)
            V = torch.cat([past_V, V], dim=2)

        new_kv = (K, V) if use_cache else None

        total_len = K.size(2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_head)

        # Causal mask. Query row i sits at absolute position
        # (total_len - seq_len + i) and must not see keys beyond it. A single
        # new token is the last position, so it may attend to everything and
        # needs no mask. A boolean mask with masked_fill keeps the dtype of
        # `scores` (fp16/bf16 safe); adding a float32 -inf matrix would
        # silently promote the scores and break the matmul with V.
        if seq_len > 1:
            query_pos = torch.arange(
                total_len - seq_len, total_len, device=scores.device
            ).unsqueeze(1)
            key_pos = torch.arange(total_len, device=scores.device).unsqueeze(0)
            scores = scores.masked_fill(key_pos > query_pos, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        weights = self.attn_drop(weights)
        out = torch.matmul(weights, V)
        return self.W_o(self._merge(out)), new_kv

    def _split(self, x):
        """Reshape (batch, seq, d_model) -> (batch, n_heads, seq, d_head)."""
        b, s, _ = x.shape
        return x.view(b, s, self.n_heads, self.d_head).transpose(1, 2)

    def _merge(self, x):
        """Reshape (batch, n_heads, seq, d_head) -> (batch, seq, d_model)."""
        b, _, s, _ = x.shape
        return x.transpose(1, 2).contiguous().view(b, s, -1)
class CachedTransformerBlock(nn.Module):
    """Transformer block whose attention sub-layer can reuse a KV cache.

    LayerNorm is applied before each sub-layer and the sub-layer output is
    added back onto the residual stream.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int | None = None,
                 dropout: float = 0.0):
        """Build the two norms, the cached attention and the feed-forward layer.

        Args:
            d_model: Width of the residual stream.
            n_heads: Number of attention heads.
            d_ff: Feed-forward width; ``4 * d_model`` when left as ``None``.
            dropout: Dropout probability forwarded to both sub-layers.
        """
        super().__init__()
        hidden_width = 4 * d_model if d_ff is None else d_ff
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn = CachedMultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, hidden_width, dropout)

    def forward(self, x, past_kv=None, use_cache=False):
        """Run attention then feed-forward, each with a residual connection.

        Returns ``(hidden, new_kv)``; ``new_kv`` is whatever the attention
        layer returns for this call's ``use_cache`` setting.
        """
        normed_in = self.norm1(x)
        attended, new_kv = self.attn(
            normed_in, past_kv=past_kv, use_cache=use_cache
        )
        after_attn = x + attended
        after_ff = after_attn + self.ff(self.norm2(after_attn))
        return after_ff, new_kv
class CachedGPT(nn.Module):
    """Decoder-only GPT with learned positional embeddings and a KV cache.

    The output projection shares its weight matrix with the token embedding.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_len=1024):
        """Build the embeddings, the stack of cached blocks and the LM head.

        Args:
            vocab_size: Number of token ids.
            d_model: Width of the residual stream.
            n_heads: Attention heads per block.
            n_layers: Number of transformer blocks.
            max_len: Size of the positional-embedding table, i.e. the longest
                total sequence (prompt + generated) the model can process.
        """
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList([
            CachedTransformerBlock(d_model, n_heads)
            for _ in range(n_layers)
        ])
        self.norm_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the LM head reuses the token-embedding matrix.
        self.lm_head.weight = self.token_embed.weight

    def forward(self, input_ids, past_key_values=None, use_cache=False):
        """Compute next-token logits for ``input_ids``.

        Args:
            input_ids: Integer tensor of shape (batch, seq) with the *new*
                tokens only when ``past_key_values`` is given.
            past_key_values: Optional per-layer sequence of ``(K, V)`` pairs
                returned by a previous call.
            use_cache: When true, also return the updated per-layer cache.

        Returns:
            ``(logits, cache)`` where ``logits`` has shape
            (batch, seq, vocab_size) and ``cache`` is a list with one
            ``(K, V)`` pair per layer, or ``None`` when ``use_cache`` is false.

        Raises:
            IndexError: If the sequence would run past the positional table.
        """
        seq = input_ids.size(1)
        # New tokens continue where the cache ends; cached K is
        # (batch, n_heads, past_len, d_head), so axis 2 is the offset.
        start_pos = 0 if past_key_values is None else past_key_values[0][0].size(2)

        # Fail with a readable message instead of an opaque embedding lookup
        # error (which on CUDA surfaces as a device-side assert).
        max_len = self.pos_embed.num_embeddings
        if start_pos + seq > max_len:
            raise IndexError(
                f"sequence length {start_pos + seq} exceeds max_len={max_len}"
            )

        positions = torch.arange(start_pos, start_pos + seq, device=input_ids.device)
        x = self.token_embed(input_ids) + self.pos_embed(positions)

        new_kvs = []
        for i, block in enumerate(self.blocks):
            past_kv = past_key_values[i] if past_key_values is not None else None
            x, new_kv = block(x, past_kv=past_kv, use_cache=use_cache)
            new_kvs.append(new_kv)

        logits = self.lm_head(self.norm_f(x))
        return logits, (new_kvs if use_cache else None)

    @torch.no_grad()
    def generate_cached(self, input_ids, max_new_tokens):
        """Greedily generate ``max_new_tokens`` tokens using the KV cache.

        Args:
            input_ids: Prompt of shape (batch, prompt_len).
            max_new_tokens: Number of tokens to append; values <= 0 return
                the prompt unchanged.

        Returns:
            Tensor of shape (batch, prompt_len + max_new_tokens).
        """
        if max_new_tokens <= 0:
            return input_ids

        # Prefill phase: process the entire prompt at once.
        logits, past_kv = self(input_ids, use_cache=True)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)

        # Collect pieces and concatenate once at the end rather than copying
        # the growing sequence on every step.
        pieces = [input_ids, next_token]

        # Decode phase: feed only the newest token, reusing cached K/V.
        for _ in range(max_new_tokens - 1):
            logits, past_kv = self(next_token, past_key_values=past_kv, use_cache=True)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            pieces.append(next_token)

        return torch.cat(pieces, dim=1)
Exercise 2 — Benchmark Cached vs. Uncached Decoding¶
import time
import torch
# Use the GPU when CUDA is usable; otherwise run the benchmark on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Benchmarking on: {device}")

# Randomly initialised model: only its speed matters here, not its output.
# Both .to() and .eval() return the module itself, so they can be chained.
model = CachedGPT(vocab_size=1000, d_model=256, n_heads=8, n_layers=6)
model = model.to(device).eval()
def naive_generate(model, input_ids, n_tokens):
    """Greedy decoding with no KV cache.

    Each step feeds the entire sequence generated so far back through
    ``model`` and appends the arg-max token of the final position.
    Returns the prompt followed by ``n_tokens`` generated ids.
    """
    sequence = input_ids
    for _step in range(n_tokens):
        # Full forward pass over everything produced so far.
        step_logits, _unused_cache = model(sequence, use_cache=False)
        final_position = step_logits[:, -1, :]
        chosen = torch.argmax(final_position, dim=-1, keepdim=True)
        sequence = torch.cat((sequence, chosen), dim=1)
    return sequence
def _sync():
    """Wait for queued CUDA work to finish; no-op when running on CPU."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _timed(fn, *args):
    """Return the wall-clock seconds taken by ``fn(*args)`` under no_grad.

    CUDA kernels are launched asynchronously, so the device is synchronised
    both before starting and before stopping the clock; otherwise the timer
    would mostly measure kernel-launch overhead.
    """
    _sync()
    start = time.perf_counter()
    with torch.no_grad():
        fn(*args)
    _sync()
    return time.perf_counter() - start


prompt = torch.randint(1, 1000, (1, 8), device=device)

# Warm-up: the first calls pay one-time costs (CUDA context creation, kernel
# loading, allocator growth) that would otherwise inflate the first row.
_timed(naive_generate, model, prompt, 2)
_timed(model.generate_cached, prompt, 2)

print(f"\n{'Tokens':>8} {'Naive (s)':>12} {'Cached (s)':>12} {'Speedup':>8}")
print("-" * 50)
for n_tokens in [10, 50, 100, 200]:
    # Naive: full forward pass over the whole sequence at every step.
    torch.manual_seed(0)
    t_naive = _timed(naive_generate, model, prompt, n_tokens)

    # Cached: one prefill, then single-token decode steps.
    torch.manual_seed(0)
    t_cached = _timed(model.generate_cached, prompt, n_tokens)

    speedup = t_naive / t_cached
    print(f"{n_tokens:>8} {t_naive:>12.3f} {t_cached:>12.3f} {speedup:>7.1f}x")
Exercise 3 — Verify KV Cache Correctness with HuggingFace¶
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

prompt = "The quick brown fox"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

with torch.no_grad():
    # Without cache
    out_no_cache = model(input_ids, use_cache=False)
    # With cache — prefill
    out_with_cache = model(input_ids, use_cache=True)
    past_kv = out_with_cache.past_key_values

    # Verify logits are identical
    diff = (out_no_cache.logits - out_with_cache.logits).abs().max().item()
    print(f"Logit diff (no cache vs with cache, prefill): {diff:.2e}")

    # Check that KV tensors grow by 1 each decode step.
    # Record the prefill length BEFORE decoding: depending on the installed
    # transformers version the cache may be an object that is updated in
    # place, in which case `past_kv` itself would grow during the next call
    # and a comparison made afterwards would compare the cache with itself.
    prefill_len = past_kv[0][0].shape[2]
    print(f"\nKV cache shape after prefill: {past_kv[0][0].shape}")
    # (batch=1, n_heads=12, seq_len=4, d_head=64)

    # Decode one step
    next_tok = out_with_cache.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    out_step2 = model(next_tok, past_key_values=past_kv, use_cache=True)
    past_kv2 = out_step2.past_key_values
    print(f"KV cache shape after 1 decode step: {past_kv2[0][0].shape}")
    # seq_len should be 5 now — increased by 1
    assert past_kv2[0][0].shape[2] == prefill_len + 1
    print("✓ KV cache grows by 1 position per decode step")

    # The cached decode step should match a full, uncached forward pass over
    # the extended sequence (up to floating-point noise).
    full_ids = torch.cat([input_ids, next_tok], dim=1)
    out_full = model(full_ids, use_cache=False)
    step_diff = (
        (out_full.logits[:, -1, :] - out_step2.logits[:, -1, :]).abs().max().item()
    )
    print(f"Logit diff (cached decode step vs full recompute): {step_diff:.2e}")
Summary¶
- Without a KV cache, each generation step recomputes K/V for all previous tokens: O(n²) total compute.
- The KV cache stores past K/V tensors, reducing each step to O(n) amortized.
- Memory cost is proportional to n_layers × n_kv_heads × d_head × seq_len.
- GQA (LLaMA 2/3, Mistral) reduces KV head count to save cache memory.
- HuggingFace implements caching via past_key_values; use_cache=True enables it.