Chapter 12: Efficient Training and Inference at Scale

Learning Outcome

Apply techniques that make training and serving large transformer models feasible, such as parameter-efficient fine-tuning (LoRA) and quantization, and survey distributed training strategies.


Concepts

LoRA — Low-Rank Adaptation

Fine-tuning all parameters of a large model is expensive. LoRA instead injects trainable low-rank matrices alongside frozen original weights:

\[ W' = W_0 + \Delta W = W_0 + BA \]

where B ∈ ℝ^{d×r}, A ∈ ℝ^{r×k}, and r ≪ min(d, k) is the rank. Only A and B are trained; W_0 is frozen.
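The update rule can be sketched in plain Python. This is a minimal illustration with tiny, made-up dimensions, not the peft implementation: the helper names `matvec` and `lora_forward` are hypothetical, and the `alpha / r` scaling and the zero initialization of B follow the convention common LoRA implementations use, so the adapted layer starts out identical to the frozen one.

```python
import random

def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]

def lora_forward(W0, A, B, x, alpha, r):
    """Compute W0 x + (alpha / r) * B (A x) without ever forming B A."""
    base = matvec(W0, x)                 # frozen path, shape (d,)
    low_rank = matvec(B, matvec(A, x))   # A x has shape (r,), B (A x) has shape (d,)
    scale = alpha / r
    return [b + scale * l for b, l in zip(base, low_rank)]

random.seed(0)
d, k, r, alpha = 4, 3, 2, 4
W0 = [[random.gauss(0, 1) for _ in range(k)] for _ in range(d)]     # frozen weights
A = [[random.gauss(0, 0.01) for _ in range(k)] for _ in range(r)]   # trainable, random init
B = [[0.0] * r for _ in range(d)]                                   # trainable, zero init
x = [1.0, 2.0, 3.0]

y_frozen = matvec(W0, x)
y_lora = lora_forward(W0, A, B, x, alpha, r)
print(y_lora == y_frozen)  # B is zero, so the adapter contributes nothing yet
```

Applying A first and B second keeps the intermediate vector at size r, which is why the extra cost of the adapter stays small.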

Parameter savings example (GPT-2 attention W_Q, d = k = 768, r = 8):

  • Full: 768 × 768 = 589,824 parameters
  • LoRA: 768×8 + 8×768 = 12,288 parameters, a reduction of about 98%
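The same arithmetic as a small helper, useful for trying other ranks and layer shapes. The function name `lora_param_counts` is hypothetical.

```python
def lora_param_counts(d: int, k: int, r: int) -> tuple[int, int, float]:
    """Return (full parameters, LoRA parameters, fractional reduction)."""
    full = d * k              # dense d x k weight matrix
    lora = d * r + r * k      # B is d x r, A is r x k
    reduction = 1.0 - lora / full
    return full, lora, reduction

full, lora, reduction = lora_param_counts(d=768, k=768, r=8)
print(f"full={full:,}  lora={lora:,}  reduction={reduction:.1%}")
```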

The peft library wraps every linear layer whose name matches the configured target_modules, so the model code itself does not have to change.

QLoRA

QLoRA combines 4-bit NF4 quantization with LoRA:

  1. Base model weights are quantized to 4-bit NF4 (stored frozen).
  2. LoRA adapters are trained in BF16.
  3. During the forward pass, base weights are dequantized on-the-fly.

This allows fine-tuning a 7B model on a single 24 GB GPU.
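A rough, weights-only estimate shows where the savings come from. This sketch deliberately ignores activations, LoRA adapter weights, optimizer states, and the small overhead of quantization constants, so real usage is higher. The helper name `weight_memory_gb` is hypothetical.

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Memory needed to store the weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

n_params = 7e9
bf16_gb = weight_memory_gb(n_params, 16)  # 16-bit base weights
nf4_gb = weight_memory_gb(n_params, 4)    # 4-bit quantized base weights
print(f"BF16 weights: {bf16_gb:.1f} GB, 4-bit weights: {nf4_gb:.1f} GB")
```

The 16-bit weights alone would take more than half of a 24 GB card, while the 4-bit copy leaves most of it free for activations and adapter training state.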

Post-Training Quantization (PTQ)

Quantize a trained model's weights to INT8 or INT4 without further training. Savings below are relative to 16-bit (FP16/BF16) weights:

  • INT8 weight quantization: ~2× memory savings, minimal accuracy loss.
  • INT4/NF4: ~4× savings, larger accuracy loss (mitigated with careful calibration).

bitsandbytes provides drop-in replacements for nn.Linear that quantize on load.
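To make the idea concrete, here is a minimal sketch of symmetric absmax quantization, the simplest INT8 scheme. It is an illustration only: libraries such as bitsandbytes use more elaborate variants (for example, finer-grained scales and outlier handling), and the function names below are hypothetical.

```python
def quantize_absmax_int8(weights):
    """Map floats to integers in [-127, 127] using one shared scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [round(w / scale) for w in weights]
    return q, scale

def dequantize(q, scale):
    """Recover approximate float values from the stored integers."""
    return [qi * scale for qi in q]

weights = [0.12, -0.53, 0.98, -1.27, 0.004]
q, scale = quantize_absmax_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(q, f"scale={scale:.4f}", f"max_error={max_error:.4f}")
```

The rounding error per weight is at most half the scale, which is why a single large outlier (a large `max_abs`) hurts precision for every other weight sharing that scale.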

Flash Attention

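Standard attention materializes the full (seq, seq) attention matrix in GPU HBM (high-bandwidth memory). For long sequences, reading and writing this matrix becomes the bottleneck.

Flash Attention uses a tiling algorithm to compute attention block by block, keeping each tile in fast on-chip SRAM so the full matrix is never written to HBM:

  • Memory complexity: O(n) instead of O(n²) in sequence length.
  • Same mathematical result: it is exact attention, not an approximation, so model code does not change beyond selecting the implementation.
  • Enables much longer context windows (32k, 128k tokens).

# Enable Flash Attention in HuggingFace
# (requires the flash-attn package, a supported GPU, and FP16/BF16 weights)
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)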
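The reason tiling can give the same result is the "online softmax" trick: a running maximum and running normalizer let partial results be rescaled as each new block arrives. The sketch below shows this for a single query vector in plain Python. It illustrates the arithmetic only, under the assumption of one query and CPU lists, and is not the fused GPU kernel; the function names are hypothetical.

```python
import math
import random

def attention_full(q, K, V):
    """Reference: materialize all scores, then softmax, then weighted sum."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    return [sum(e * v[j] for e, v in zip(exps, V)) / denom for j in range(len(V[0]))]

def attention_tiled(q, K, V, block_size):
    """Process keys/values block by block, keeping only running statistics."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")
    running_denom = 0.0
    acc = [0.0] * len(V[0])  # unnormalized output accumulator
    for start in range(0, len(K), block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new reference maximum.
        correction = math.exp(running_max - new_max)
        exps = [math.exp(s - new_max) for s in scores]
        running_denom = running_denom * correction + sum(exps)
        acc = [a * correction + sum(e * v[j] for e, v in zip(exps, V_blk))
               for j, a in enumerate(acc)]
        running_max = new_max
    return [a / running_denom for a in acc]

random.seed(0)
n, d_head = 10, 4
q = [random.gauss(0, 1) for _ in range(d_head)]
K = [[random.gauss(0, 1) for _ in range(d_head)] for _ in range(n)]
V = [[random.gauss(0, 1) for _ in range(d_head)] for _ in range(n)]

full = attention_full(q, K, V)
tiled = attention_tiled(q, K, V, block_size=3)
max_diff = max(abs(a - b) for a, b in zip(full, tiled))
print(f"max difference between full and tiled: {max_diff:.2e}")
```

The tiled version never holds more than `block_size` scores at once, which is the property that lets the real kernel keep its working set in SRAM.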

Gradient Checkpointing

Trade compute for memory: keep only a subset of activations during the forward pass and recompute the rest during backpropagation. This enables training larger models on smaller GPUs at roughly 33% extra compute cost, about one additional forward pass per step.

model.gradient_checkpointing_enable()
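What that call does conceptually can be shown with a toy chain of scalar layers, h_out = tanh(w · h_in). This is a sketch under simplifying assumptions (scalar activations, hand-written gradients, hypothetical function names); the real mechanism operates on autograd graphs. It stores only one activation per segment and recomputes the rest during the backward pass.

```python
import math

def forward_layer(w, h):
    return math.tanh(w * h)

def backward_layer(w, h_in, h_out, grad_out):
    """Return (grad wrt w, grad wrt h_in) for h_out = tanh(w * h_in)."""
    local = grad_out * (1.0 - h_out * h_out)
    return local * h_in, local * w

def grads_store_all(weights, x):
    """Baseline: keep every activation from the forward pass."""
    acts = [x]
    for w in weights:
        acts.append(forward_layer(w, acts[-1]))
    grad = 1.0  # d(loss)/d(output), with loss = final activation
    grads = [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grads[i], grad = backward_layer(weights[i], acts[i], acts[i + 1], grad)
    return grads, len(acts)

def grads_checkpointed(weights, x, segment):
    """Keep only each segment's input; recompute the rest when needed."""
    checkpoints = {}
    h = x
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = h
        h = forward_layer(w, h)
    grad = 1.0
    grads = [0.0] * len(weights)
    peak_stored = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(weights))
        acts = [checkpoints[start]]           # recompute this segment only
        for i in range(start, end):
            acts.append(forward_layer(weights[i], acts[-1]))
        peak_stored = max(peak_stored, len(checkpoints) + len(acts) - 1)
        for i in reversed(range(start, end)):
            grads[i], grad = backward_layer(
                weights[i], acts[i - start], acts[i - start + 1], grad)
    return grads, peak_stored

weights = [0.5 + 0.1 * i for i in range(12)]
g_full, stored_full = grads_store_all(weights, x=0.3)
g_ckpt, stored_ckpt = grads_checkpointed(weights, x=0.3, segment=4)
print(f"activations held: {stored_full} (store all) vs {stored_ckpt} (checkpointed)")
```

Both functions return the same gradients; the checkpointed one holds fewer activations at its peak and pays for it by running each layer's forward computation a second time.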

Distributed Training Overview

| Strategy | What it splits | Library |
|---|---|---|
| Data Parallelism (DDP) | Input batches across GPUs | torch.nn.parallel.DistributedDataParallel |
| Tensor Parallelism | Weight matrices across GPUs | accelerate, DeepSpeed |
| Pipeline Parallelism | Layers across GPUs | DeepSpeed |
| ZeRO | Optimizer states across GPUs (plus gradients and parameters at higher stages) | DeepSpeed ZeRO |

For most fine-tuning use cases, DDP + LoRA + gradient checkpointing is sufficient.
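To see when plain DDP stops being enough, it helps to estimate model-state memory per GPU. The sketch below assumes mixed-precision Adam with the usual accounting of 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and 12 for optimizer states (FP32 weight copy, momentum, variance). It ignores activations and temporary buffers, and the function name is hypothetical.

```python
def zero_memory_per_gpu_gb(n_params: float, n_gpus: int, stage: int) -> float:
    """Model-state memory per GPU in GB for ZeRO stage 0 (plain DDP) to 3."""
    params, grads, optim = 2.0, 2.0, 12.0  # bytes per parameter (assumed)
    if stage >= 1:
        optim /= n_gpus    # stage 1: partition optimizer states
    if stage >= 2:
        grads /= n_gpus    # stage 2: also partition gradients
    if stage >= 3:
        params /= n_gpus   # stage 3: also partition parameters
    return n_params * (params + grads + optim) / 1e9

for stage in range(4):
    gb = zero_memory_per_gpu_gb(7.5e9, n_gpus=64, stage=stage)
    print(f"stage {stage}: {gb:7.2f} GB per GPU")
```

With LoRA, only the adapter weights carry gradients and optimizer states, so the 14 bytes per parameter for gradients and optimizer states apply to a tiny fraction of the model. That is the main reason DDP is usually enough for fine-tuning.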

Continuous Batching and PagedAttention (vLLM)

Naive inference servers process one request at a time, or in static batches that stall until the longest sequence finishes. vLLM uses:

  • Continuous batching: dynamically add/remove requests mid-batch.
  • PagedAttention: manage KV cache in fixed-size pages (like OS virtual memory).

Together these can yield roughly 3–24× higher throughput than naive serving, depending on the model, hardware, and request mix.
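The paging idea can be sketched as a small block allocator. This is a toy model of the bookkeeping only, not vLLM's implementation: the class and method names are hypothetical, and no attention is computed. Each sequence owns a list of fixed-size pages, a new page is taken from the free pool only when the current one fills up, and finished sequences return their pages immediately.

```python
class PagedKVCache:
    """Toy block allocator: each sequence maps to a list of physical page ids."""

    def __init__(self, num_pages: int, page_size: int):
        self.page_size = page_size
        self.free_pages = list(range(num_pages))
        self.page_tables = {}   # seq_id -> list of physical page ids
        self.lengths = {}       # seq_id -> number of cached tokens

    def append_token(self, seq_id: str) -> None:
        """Reserve cache space for one more token, allocating a page if needed."""
        length = self.lengths.get(seq_id, 0)
        if length % self.page_size == 0:  # current pages are full (or none yet)
            if not self.free_pages:
                raise MemoryError("no free KV cache pages")
            self.page_tables.setdefault(seq_id, []).append(self.free_pages.pop())
        self.lengths[seq_id] = length + 1

    def free(self, seq_id: str) -> None:
        """Return a finished sequence's pages to the pool."""
        self.free_pages.extend(self.page_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

cache = PagedKVCache(num_pages=8, page_size=16)
for _ in range(40):
    cache.append_token("request-a")   # 40 tokens occupy 3 pages
for _ in range(10):
    cache.append_token("request-b")   # 10 tokens occupy 1 page
pages_in_use = 8 - len(cache.free_pages)
cache.free("request-a")               # its pages become reusable immediately
print(pages_in_use, len(cache.free_pages))
```

Because memory is reserved one page at a time, at most one partially filled page per sequence is wasted, instead of a whole buffer preallocated for the maximum sequence length.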


Exercise 1 — LoRA Fine-Tuning with peft

Guided Exercise

Apply LoRA to LLaMA-3.2-1B and compare memory and parameter counts.

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
import torch


model_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load base model in BF16
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Print trainable parameters before LoRA
def count_params(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

trainable, total = count_params(base_model)
print(f"Before LoRA: {trainable:,} trainable / {total:,} total "
      f"({100*trainable/total:.1f}%)")

# Apply LoRA
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                          # LoRA rank
    lora_alpha=32,                # scaling factor
    target_modules=["q_proj", "v_proj"],  # apply to Q and V projections
    lora_dropout=0.05,
    bias="none",
)

lora_model = get_peft_model(base_model, lora_config)

trainable, total = count_params(lora_model)
print(f"After LoRA:  {trainable:,} trainable / {total:,} total "
      f"({100*trainable/total:.2f}%)")
lora_model.print_trainable_parameters()

# Memory footprint (torch is already imported above)
if torch.cuda.is_available():
    print(f"\nGPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"GPU memory reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")
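As a sanity check on the printed numbers, the expected LoRA parameter count can be worked out by hand: each adapted projection adds r × (in_features + out_features) parameters. The shapes below are assumptions about this checkpoint (hidden size 2048, 16 decoder layers, and a 512-wide value projection from grouped-query attention); verify them against `base_model.config` before relying on the result. The helper name is hypothetical.

```python
def lora_trainable_params(r: int, n_layers: int,
                          shapes: dict[str, tuple[int, int]]) -> int:
    """Sum r * (in_features + out_features) over every adapted projection."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
    return n_layers * per_layer

# Assumed Llama-3.2-1B projection shapes (in_features, out_features);
# check these against base_model.config.
assumed_shapes = {
    "q_proj": (2048, 2048),
    "v_proj": (2048, 512),
}
expected = lora_trainable_params(r=8, n_layers=16, shapes=assumed_shapes)
print(f"Expected LoRA parameters: {expected:,}")
```

If the count reported by `print_trainable_parameters()` differs, the most likely causes are different `target_modules` or a model configuration that does not match the assumed shapes.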

Fine-tune on a custom instruction dataset

from datasets import Dataset


# Simple instruction-following dataset
instructions = [
    {"prompt": "Summarize the following text: ...", "response": "..."},
    {"prompt": "Translate to French: Hello world", "response": "Bonjour le monde"},
    # ... more examples
]

def format_example(example):
    text = f"### Instruction:\n{example['prompt']}\n\n### Response:\n{example['response']}"
    return {"text": text}

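# Build a Dataset and add the formatted "text" column
dataset = Dataset.from_list(instructions).map(format_example)

# Tokenize and train with a standard training loop
# (similar to Chapter 10, but with lora_model and far fewer trainable parameters)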

Exercise 2 — Quantize GPT-2 and Benchmark Inference

from transformers import GPT2LMHeadModel, GPT2Tokenizer, BitsAndBytesConfig
import torch
import time

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

prompt = "The future of artificial intelligence is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()

# ── FP16 baseline ──────────────────────────────────────────────────────────────
model_fp16 = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float16).cuda()
model_fp16.eval()

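def measure_inference(model, n_tokens=50, n_runs=3):
    """Return (mean seconds per run, tokens per second) for greedy decoding."""
    def generate():
        with torch.no_grad():
            return model.generate(
                input_ids, max_new_tokens=n_tokens, min_new_tokens=n_tokens,
                do_sample=False, pad_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )

    generate()  # warmup run so one-time CUDA setup is not timed
    torch.cuda.synchronize()
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        generate()
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
        times.append(time.perf_counter() - start)
    mean_time = sum(times) / n_runs
    return mean_time, n_tokens / mean_time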

fp16_time, fp16_tps = measure_inference(model_fp16)
fp16_mem = torch.cuda.memory_allocated() / 1e9
del model_fp16
torch.cuda.empty_cache()

# ── INT8 quantization ─────────────────────────────────────────────────────────
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model_int8 = GPT2LMHeadModel.from_pretrained(
    "gpt2", quantization_config=bnb_config, device_map="auto"
)
model_int8.eval()

int8_time, int8_tps = measure_inference(model_int8)
int8_mem = torch.cuda.memory_allocated() / 1e9
del model_int8
torch.cuda.empty_cache()

print(f"\n{'':12}  {'Time (s)':>10}  {'Tokens/s':>10}  {'GPU GB':>8}")
print("-" * 45)
print(f"FP16       {fp16_time:>10.3f}  {fp16_tps:>10.1f}  {fp16_mem:>8.2f}")
print(f"INT8       {int8_time:>10.3f}  {int8_tps:>10.1f}  {int8_mem:>8.2f}")
print(f"\nINT8 memory savings: {fp16_mem/int8_mem:.1f}×")

Exercise 3 — Flash Attention Benchmark

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time


def load_model(attn_impl: str):
    """Load GPT-2 with specified attention implementation."""
    return AutoModelForCausalLM.from_pretrained(
        "gpt2",
        attn_implementation=attn_impl,
        torch_dtype=torch.float16,
    ).cuda().eval()


def benchmark_attention(model, seq_len: int, n_runs: int = 5) -> float:
    """Measure forward pass time for a given sequence length."""
    input_ids = torch.randint(1, 50257, (1, seq_len), device="cuda")
    # Warmup
    with torch.no_grad():
        model(input_ids)
    torch.cuda.synchronize()
    # Measure
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        with torch.no_grad():
            model(input_ids)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    return sum(times) / n_runs


# Note: flash_attention_2 requires flash-attn package:
# pip install flash-attn --no-build-isolation
# GPT-2 is small enough that differences may be subtle;
# the effect is most dramatic on long sequences with large models.

print("Benchmarking attention implementations...")
print(f"\n{'Seq len':>10}  {'Eager (ms)':>12}  {'SDPA (ms)':>12}")
print("-" * 38)

model_eager = load_model("eager")
model_sdpa  = load_model("sdpa")  # PyTorch scaled dot-product attention (fused)

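for seq_len in [256, 512, 1024]:
    # Reset the peak counter before each run so the two measurements stay separate.
    # Both models stay resident, so each peak includes both sets of weights.
    torch.cuda.reset_peak_memory_stats()
    t_eager = benchmark_attention(model_eager, seq_len) * 1000
    mem_eager = torch.cuda.max_memory_allocated() / 1e9

    torch.cuda.reset_peak_memory_stats()
    t_sdpa = benchmark_attention(model_sdpa, seq_len) * 1000
    mem_sdpa = torch.cuda.max_memory_allocated() / 1e9

    print(f"{seq_len:>10}  {t_eager:>12.2f}  {t_sdpa:>12.2f}")
    print(f"{'':10}  peak GPU mem: eager {mem_eager:.3f} GB, sdpa {mem_sdpa:.3f} GB")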

Summary

  • LoRA adds trainable low-rank matrices to frozen base weights, cutting trainable parameters by about 98% per adapted matrix and typically by more than 99% across the whole model, with minimal quality loss.
  • QLoRA combines 4-bit NF4 quantization with LoRA, enabling fine-tuning of large models on consumer GPUs.
  • INT8/INT4 quantization reduces weight memory by 2–4× relative to 16-bit weights, with moderate accuracy trade-offs.
  • Flash Attention computes attention tile by tile in on-chip SRAM without materializing the full score matrix, reducing memory from O(n²) to O(n) and enabling much longer contexts.
  • Gradient checkpointing trades ~33% extra compute for significantly reduced activation memory during training.
  • vLLM's PagedAttention and continuous batching achieve high serving throughput by treating the KV cache like virtual memory and keeping the batch full.
