Chapter 12: Efficient Training and Inference at Scale
Learning Outcome
Apply techniques that make training and serving large transformer models feasible: parameter-efficient fine-tuning (LoRA), quantization, and an overview of distributed training strategies.
Concepts
LoRA — Low-Rank Adaptation
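Fine-tuning all parameters of a large model is expensive. LoRA instead freezes the pretrained weight W_0 ∈ ℝ^{d×k} and injects a trainable low-rank update alongside it:

$$h = W_0 x + \Delta W x = W_0 x + B A x$$

where B ∈ ℝ^{d×r}, A ∈ ℝ^{r×k}, and r ≪ min(d, k) is the rank.
Only A and B are trained; W_0 is frozen. B is initialized to zero, so ΔW = BA = 0 and training starts from the unmodified base model; peft additionally scales the update by α/r (lora_alpha / r).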
Parameter savings example (GPT-2 attention W_Q, d=k=768, r=8):
- Full: 768 × 768 = 589,824 parameters
- LoRA: 768×8 + 8×768 = 12,288 parameters — 98% reduction
The peft library wraps every module whose name matches target_modules with a LoRA layer and freezes all other parameters, so the model code itself does not need to change.
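The update rule can be sketched in NumPy. LoRALinearSketch is a hypothetical class, not peft's implementation: it follows the formulation above (frozen W_0, zero-initialized B, α/r scaling), assumes a small Gaussian initialization for A, and checks the GPT-2 parameter count from the example.

```python
import numpy as np

class LoRALinearSketch:
    """Hypothetical minimal LoRA layer: frozen W0 plus a trainable low-rank update B @ A."""

    def __init__(self, w0, r=8, alpha=32, seed=0):
        d, k = w0.shape
        rng = np.random.default_rng(seed)
        self.w0 = w0                               # frozen (d, k) pretrained weight
        self.A = rng.normal(0, 0.01, size=(r, k))  # trainable; small random init (assumed scheme)
        self.B = np.zeros((d, r))                  # trainable; zero init, so delta W = 0 at start
        self.scaling = alpha / r                   # alpha / r scaling, as in peft

    def __call__(self, x):
        # h = W0 x + (alpha / r) * B (A x); the d x k product BA is never formed here
        return x @ self.w0.T + self.scaling * (x @ self.A.T) @ self.B.T

    def merged_weight(self):
        # After training, the update can be folded in: W = W0 + (alpha / r) * B A
        return self.w0 + self.scaling * self.B @ self.A

    def num_trainable(self):
        return self.A.size + self.B.size

d = k = 768
rng = np.random.default_rng(1)
w0 = rng.normal(size=(d, k))
x = rng.normal(size=(4, k))

layer = LoRALinearSketch(w0, r=8)
print("full params:", w0.size, "| LoRA params:", layer.num_trainable())
print("zero-init B leaves output unchanged:", np.allclose(layer(x), x @ w0.T))

layer.B = rng.normal(size=layer.B.shape)  # pretend training has updated B
print("merged == unmerged:", np.allclose(layer(x), x @ layer.merged_weight().T))
```

Folding BA into W_0 after training, as merged_weight does, is why a merged LoRA model adds no extra inference latency.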
QLoRA
QLoRA combines 4-bit NF4 quantization with LoRA:
- Base model weights are quantized to 4-bit NF4 (stored frozen).
- LoRA adapters are trained in BF16.
- During the forward pass, base weights are dequantized on-the-fly.
This allows fine-tuning a 7B model on a single 24 GB GPU.
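A back-of-the-envelope check of the 24 GB claim, counting only the frozen base weights. The assumptions are not from the text above: exactly 7 × 10⁹ parameters, 1 GB = 10⁹ bytes, and one FP32 scale constant per block of 64 weights (double quantization ignored). The helper weight_memory_gb is hypothetical.

```python
def weight_memory_gb(n_params, bits_per_param, block_size=None, scale_bits=32):
    """Approximate weight storage in GB (1 GB = 1e9 bytes).

    If block_size is given, add one scale constant of scale_bits bits per
    block of block_size weights (the overhead of blockwise quantization).
    """
    bits = n_params * bits_per_param
    if block_size is not None:
        bits += (n_params / block_size) * scale_bits
    return bits / 8 / 1e9

n = 7e9  # assumed: a "7B" model with exactly 7 billion parameters
bf16 = weight_memory_gb(n, 16)
nf4 = weight_memory_gb(n, 4, block_size=64)  # assumed block size 64 with FP32 scales

print(f"BF16 weights: {bf16:.2f} GB")
print(f"NF4 weights:  {nf4:.2f} GB")
print(f"Left on a 24 GB GPU after NF4 weights: {24 - nf4:.2f} GB")
```

The remaining ~20 GB must hold the LoRA adapters, their optimizer state, and activations, which is why QLoRA setups commonly enable gradient checkpointing as well.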
Post-Training Quantization (PTQ)
Quantize a trained model's weights to INT8 or INT4:
- INT8 weight quantization: ~2× memory savings relative to FP16/BF16, usually with minimal accuracy loss.
- INT4/NF4: ~4× savings, with larger accuracy loss (mitigated by careful calibration).
bitsandbytes provides drop-in replacements for nn.Linear (Linear8bitLt and Linear4bit); transformers swaps them in and quantizes the weights at load time when you pass a BitsAndBytesConfig.
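The core of weight-only INT8 quantization can be sketched with per-row absmax scaling in NumPy. This is a simplified scheme chosen for illustration, not bitsandbytes' LLM.int8() algorithm (which additionally keeps outlier features in 16-bit). The random 768×768 matrix is a stand-in for a real FP16 weight.

```python
import numpy as np

def quantize_absmax_int8(w):
    """Per-row symmetric absmax quantization of a 2-D float matrix to INT8."""
    scale = np.abs(w).max(axis=1, keepdims=True) / 127.0  # one scale per output row
    scale = np.where(scale == 0, 1.0, scale)              # guard all-zero rows
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def dequantize(q, scale):
    # Reconstruct approximate float weights: w_hat = q * scale
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.normal(0, 0.02, size=(768, 768)).astype(np.float16)  # stand-in FP16 weight
q, scale = quantize_absmax_int8(w.astype(np.float32))
w_hat = dequantize(q, scale)

err = np.abs(w.astype(np.float32) - w_hat).max()
print(f"FP16 bytes: {w.nbytes:,}  INT8 bytes: {q.nbytes:,} (+ {scale.nbytes:,} for scales)")
print(f"max abs reconstruction error: {err:.2e}")
```

Rounding bounds each element's error by half its row's scale, so rows with a single large outlier lose the most precision; that is the problem LLM.int8()'s outlier handling and calibration-based INT4 methods try to address.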
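Flash Attention
Standard attention materializes the full (seq, seq) attention matrix in GPU HBM (high-bandwidth memory). For long sequences, reading and writing this matrix, rather than the arithmetic, becomes the bottleneck.
Flash Attention uses a tiling algorithm: it computes attention block by block, keeps each block in fast on-chip SRAM, and uses an online softmax so the full matrix is never written to HBM:
- Memory complexity: O(n) instead of O(n²) in sequence length n. The FLOPs are still O(n²); the speedup comes from far fewer HBM reads and writes.
- Exact, not an approximation: it produces the same result as standard attention (up to floating-point rounding), so no model changes are needed; in HuggingFace you opt in with a single argument.
- Enables much longer context windows (32k, 128k tokens).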
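```python
import torch
from transformers import AutoModelForCausalLM

# Enable Flash Attention in HuggingFace.
# Requires the flash-attn package, an Ampere-or-newer GPU, and FP16/BF16 weights.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="cuda",  # Flash Attention kernels run only on the GPU
)
```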
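The tiling idea can be checked in NumPy. This sketch (non-causal, single head, hypothetical function names) shows only the math: keys and values are processed block by block with a running maximum and a running softmax denominator (the "online softmax"), so no n×n score matrix is ever built, yet the output matches standard attention. It does not model SRAM or HBM; the real kernel fuses these loops on the GPU.

```python
import numpy as np

def standard_attention(q, k, v):
    """Reference implementation: materializes the full (n, n) score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=64):
    """Online-softmax attention over query and key/value tiles of size `block`."""
    n, d = q.shape
    out = np.empty_like(q)
    for qs in range(0, n, block):
        qb = q[qs:qs + block]                      # (bq, d) query tile
        m = np.full(qb.shape[0], -np.inf)          # running row-wise max of scores
        l = np.zeros(qb.shape[0])                  # running softmax denominator
        acc = np.zeros((qb.shape[0], v.shape[1]))  # running unnormalized output
        for ks in range(0, n, block):
            s = qb @ k[ks:ks + block].T / np.sqrt(d)  # (bq, bk) score tile only
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            correction = np.exp(m - m_new)            # rescales earlier partial sums
            l = l * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ v[ks:ks + block]
            m = m_new
        out[qs:qs + block] = acc / l[:, None]        # normalize once at the end
    return out

rng = np.random.default_rng(0)
n, d = 300, 32
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
ref = standard_attention(q, k, v)
tiled = tiled_attention(q, k, v, block=64)
print("max abs difference:", np.abs(ref - tiled).max())
```

Peak extra memory per query tile is one block × block score tile plus O(block) running statistics, which is where the O(n) overall memory comes from.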
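Gradient Checkpointing
Trade compute for memory: during the forward pass, store only selected activations (for example, each transformer block's input) and recompute the discarded ones during backpropagation. This enables training larger models on smaller GPUs at roughly 33% extra compute: the backward pass costs about twice the forward pass, so one extra forward pass turns about 3 units of work into 4.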
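A small PyTorch sketch using torch.utils.checkpoint.checkpoint: the stack of MLP blocks is a hypothetical stand-in for transformer layers. It checks that checkpointing changes only what is kept in memory, not the gradients. For HuggingFace models, model.gradient_checkpointing_enable() applies the same mechanism layer by layer.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Hypothetical toy model: four MLP blocks standing in for transformer layers
blocks = nn.ModuleList(
    [nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64)) for _ in range(4)]
)
x = torch.randn(8, 64)

def loss_fn(use_checkpoint: bool) -> torch.Tensor:
    h = x
    for block in blocks:
        if use_checkpoint:
            # Keep only the block input; activations inside the block are
            # recomputed during backward instead of being stored.
            h = checkpoint(block, h, use_reentrant=False)
        else:
            h = block(h)
    return h.pow(2).mean()

grads = {}
for use_ckpt in (False, True):
    blocks.zero_grad()
    loss_fn(use_ckpt).backward()
    grads[use_ckpt] = [p.grad.clone() for p in blocks.parameters()]

same = all(torch.allclose(a, b) for a, b in zip(grads[False], grads[True]))
print("gradients match:", same)
```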
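Distributed Training Overview
| Strategy | What it splits | Libraries |
|---|---|---|
| Data Parallelism (DDP) | Each input batch across GPUs; every GPU holds a full model replica | torch.nn.parallel.DistributedDataParallel |
| Tensor Parallelism | Individual weight matrices across GPUs | Megatron-LM, DeepSpeed, torch.distributed.tensor.parallel |
| Pipeline Parallelism | Consecutive layers (stages) across GPUs | DeepSpeed, Megatron-LM |
| ZeRO / FSDP | Optimizer states (stage 1), plus gradients (stage 2), plus parameters (stage 3) across data-parallel GPUs | DeepSpeed ZeRO, PyTorch FSDP |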
For most fine-tuning use cases, DDP + LoRA + gradient checkpointing is sufficient.
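The memory effect of the ZeRO row can be estimated with the accounting the ZeRO authors use for mixed-precision Adam: 2 bytes per parameter for FP16 weights, 2 for FP16 gradients, and 12 for FP32 optimizer state (master weights, momentum, variance). zero_memory_gb is a hypothetical helper; stage 0 stands for plain DDP, and activations and temporary buffers are ignored.

```python
def zero_memory_gb(n_params, n_gpus, stage, k_optimizer=12):
    """Per-GPU memory for model states in GB (1 GB = 1e9 bytes)."""
    params = 2 * n_params             # FP16 weights
    grads = 2 * n_params              # FP16 gradients
    optim = k_optimizer * n_params    # FP32 master weights + Adam momentum + variance
    if stage >= 1:
        optim /= n_gpus               # stage 1: partition optimizer states
    if stage >= 2:
        grads /= n_gpus               # stage 2: also partition gradients
    if stage >= 3:
        params /= n_gpus              # stage 3: also partition parameters
    return (params + grads + optim) / 1e9

# Example: a 7.5B-parameter model on 64 GPUs
for stage in range(4):
    print(f"ZeRO stage {stage}: {zero_memory_gb(7.5e9, 64, stage):6.1f} GB per GPU")
```

The jump from stage 0 to stage 1 is the largest because the FP32 optimizer state is 12 of the 16 bytes per parameter.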
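Continuous Batching and PagedAttention (vLLM)
Naive inference servers process requests one at a time, or in static batches that cannot admit new requests until the longest sequence finishes, and they often reserve KV-cache memory for the maximum sequence length. vLLM uses:
- Continuous batching: the scheduler adds new requests and removes finished ones at every decoding step instead of waiting for the whole batch.
- PagedAttention: the KV cache is stored in fixed-size blocks that need not be contiguous, tracked by a per-sequence block table (like pages in OS virtual memory), so memory is allocated on demand and little is wasted.
The vLLM authors report roughly 3–24× higher throughput than the baseline servers they compared against; actual gains depend on the model and workload.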
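To make the block-table idea concrete, here is a toy pure-Python allocator. PagedKVCache and its methods are hypothetical illustrations, not vLLM's API; it tracks only which physical block holds each token's keys and values, not the tensors themselves.

```python
class PagedKVCache:
    """Toy block table mapping each sequence's token positions to physical blocks."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # pool of physical block ids
        self.block_tables = {}                      # seq_id -> list of physical block ids
        self.lengths = {}                           # seq_id -> tokens stored so far

    def append_token(self, seq_id):
        """Reserve a slot for one new token; allocate a block only when the last one is full."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.block_size == 0:  # no block yet, or the last block is full
            if not self.free_blocks:
                raise MemoryError("KV cache exhausted; a real scheduler would preempt a sequence")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = length + 1
        # (physical block, offset inside that block) where this token's K/V would go
        return table[length // self.block_size], length % self.block_size

    def free(self, seq_id):
        """Return a finished sequence's blocks to the pool for new requests."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

cache = PagedKVCache(num_blocks=8, block_size=16)
for _ in range(20):
    cache.append_token("req-A")  # 20 tokens need 2 blocks
for _ in range(5):
    cache.append_token("req-B")  # 5 tokens need 1 block
print(cache.block_tables, "free:", len(cache.free_blocks))
cache.free("req-A")              # finished request releases its blocks immediately
print("free after req-A finishes:", len(cache.free_blocks))
```

Because blocks are allocated only when the last one fills, each sequence wastes at most one partially filled block, and a finished request's blocks are reusable at once, which is what lets continuous batching admit new requests mid-batch.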
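Exercise 1 — LoRA Fine-Tuning with peft
Guided Exercise
Apply LoRA to Llama-3.2-1B and compare memory and parameter counts. The meta-llama checkpoints are gated on the Hugging Face Hub: accept the license on the model page and log in (for example with huggingface-cli login) before running.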
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
import torch

model_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load base model in BF16
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Print trainable parameters before LoRA
def count_params(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

trainable, total = count_params(base_model)
print(f"Before LoRA: {trainable:,} trainable / {total:,} total "
      f"({100*trainable/total:.1f}%)")

# Apply LoRA
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                                  # LoRA rank
    lora_alpha=32,                        # scaling factor
    target_modules=["q_proj", "v_proj"],  # apply to Q and V projections
    lora_dropout=0.05,
    bias="none",
)
lora_model = get_peft_model(base_model, lora_config)

trainable, total = count_params(lora_model)
print(f"After LoRA: {trainable:,} trainable / {total:,} total "
      f"({100*trainable/total:.2f}%)")
lora_model.print_trainable_parameters()

# Memory footprint
if torch.cuda.is_available():
    print(f"\nGPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
```
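Fine-tune on a custom instruction dataset
```python
from datasets import Dataset

# Simple instruction-following dataset
instructions = [
    {"prompt": "Summarize the following text: ...", "response": "..."},
    {"prompt": "Translate to French: Hello world", "response": "Bonjour le monde"},
    # ... more examples
]

def format_example(example):
    text = f"### Instruction:\n{example['prompt']}\n\n### Response:\n{example['response']}"
    return {"text": text}

# Llama tokenizers ship without a pad token; reuse EOS for padding
tokenizer.pad_token = tokenizer.eos_token

def tokenize(example):
    # Append EOS so the model learns where a response ends
    return tokenizer(example["text"] + tokenizer.eos_token, truncation=True, max_length=512)

dataset = Dataset.from_list(instructions).map(format_example)
tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)

# Train with a standard training loop (similar to Chapter 10), passing lora_model:
# only the LoRA A/B matrices receive gradients, so far fewer parameters are trained.
```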
Exercise 2 — Quantize GPT-2 and Benchmark Inference
```python
# Requires a CUDA GPU and the bitsandbytes package.
from transformers import GPT2LMHeadModel, GPT2Tokenizer, BitsAndBytesConfig
import torch
import time

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
prompt = "The future of artificial intelligence is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()

def measure_inference(model, n_tokens=50, n_runs=3):
    """Return (mean seconds per generate call, tokens per second)."""
    # Warmup so one-time CUDA/kernel setup is not included in the timing
    with torch.no_grad():
        model.generate(input_ids, max_new_tokens=5, do_sample=False,
                       pad_token_id=tokenizer.eos_token_id)
    torch.cuda.synchronize()
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        with torch.no_grad():
            model.generate(
                input_ids, max_new_tokens=n_tokens,
                min_new_tokens=n_tokens,  # force exactly n_tokens so tokens/s is accurate
                do_sample=False, pad_token_id=tokenizer.eos_token_id, use_cache=True,
            )
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    mean_time = sum(times) / n_runs
    return mean_time, n_tokens / mean_time

# ── FP16 baseline ──────────────────────────────────────────────────────────────
model_fp16 = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float16).cuda()
model_fp16.eval()
fp16_time, fp16_tps = measure_inference(model_fp16)
fp16_mem = model_fp16.get_memory_footprint() / 1e9  # model weights + buffers, GB
del model_fp16
torch.cuda.empty_cache()

# ── INT8 quantization ─────────────────────────────────────────────────────────
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model_int8 = GPT2LMHeadModel.from_pretrained(
    "gpt2", quantization_config=bnb_config, device_map="auto",
    torch_dtype=torch.float16,  # modules that are not quantized stay in FP16
)
model_int8.eval()
int8_time, int8_tps = measure_inference(model_int8)
int8_mem = model_int8.get_memory_footprint() / 1e9
del model_int8
torch.cuda.empty_cache()

print(f"\n{'':12} {'Time (s)':>10} {'Tokens/s':>10} {'Model GB':>8}")
print("-" * 43)
print(f"{'FP16':<12} {fp16_time:>10.3f} {fp16_tps:>10.1f} {fp16_mem:>8.2f}")
print(f"{'INT8':<12} {int8_time:>10.3f} {int8_tps:>10.1f} {int8_mem:>8.2f}")
print(f"\nINT8 memory savings: {fp16_mem/int8_mem:.1f}×")
# Expect well under 2× for GPT-2: its tied embedding/LM-head matrix stays in FP16.
# LLM.int8() inference is also often slower than FP16 because of its outlier handling.
```
Exercise 3 — Flash Attention Benchmark
```python
from transformers import AutoModelForCausalLM
import torch
import time

def load_model(attn_impl: str):
    """Load GPT-2 in FP16 with the specified attention implementation."""
    return AutoModelForCausalLM.from_pretrained(
        "gpt2",
        attn_implementation=attn_impl,
        torch_dtype=torch.float16,
    ).cuda().eval()

def benchmark_attention(model, seq_len: int, n_runs: int = 5):
    """Return (mean forward-pass time in s, peak GPU memory in GB) for one sequence length."""
    input_ids = torch.randint(1, 50257, (1, seq_len), device="cuda")
    # Warmup
    with torch.no_grad():
        model(input_ids)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()  # track the peak for this model and length only
    # Measure
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        with torch.no_grad():
            model(input_ids)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    return sum(times) / n_runs, torch.cuda.max_memory_allocated() / 1e9

# "sdpa" is PyTorch's fused scaled dot-product attention, which can dispatch to a
# FlashAttention kernel on supported GPUs. To test "flash_attention_2" directly, install:
#   pip install flash-attn --no-build-isolation
# GPT-2 is small enough that differences may be subtle;
# the effect is most dramatic on long sequences with large models.
model_eager = load_model("eager")
model_sdpa = load_model("sdpa")

print("Benchmarking attention implementations...")
print(f"\n{'Seq len':>10} {'Eager (ms)':>12} {'SDPA (ms)':>12} {'Eager GB':>10} {'SDPA GB':>10}")
print("-" * 58)
for seq_len in [256, 512, 1024]:  # GPT-2 supports at most 1024 positions
    t_eager, m_eager = benchmark_attention(model_eager, seq_len)
    t_sdpa, m_sdpa = benchmark_attention(model_sdpa, seq_len)
    print(f"{seq_len:>10} {t_eager * 1000:>12.2f} {t_sdpa * 1000:>12.2f} "
          f"{m_eager:>10.3f} {m_sdpa:>10.3f}")
# Peak memory includes the weights of both loaded models, so compare the columns'
# difference: it reflects activation and attention memory.
```
Summary
- LoRA adds trainable low-rank matrices to frozen base weights, typically reducing trainable parameters by more than 99% with little quality loss.
- QLoRA combines 4-bit NF4 quantization of the frozen base model with BF16 LoRA adapters, enabling fine-tuning of large models on consumer GPUs.
- INT8/INT4 weight quantization reduces weight memory by roughly 2–4× relative to 16-bit weights, with accuracy trade-offs that grow as precision drops.
- Flash Attention computes exact attention in tiles held in on-chip SRAM, reducing attention memory from O(n²) to O(n) and enabling much longer contexts.
- Gradient checkpointing trades ~33% extra compute for significantly reduced activation memory during training.
- vLLM combines continuous batching with PagedAttention, which manages the KV cache in fixed-size blocks like virtual-memory pages, to achieve high serving throughput.