Understanding GPU Memory
The dreaded CUDA out of memory error stops ML engineers daily. Understanding what consumes GPU memory is the first step to optimization.
Memory Breakdown During Training
| Component | Memory Usage | Notes |
|---|---|---|
| Model Parameters | 2 bytes/param (FP16) | 7B model = 14GB |
| Gradients | Same as params | +14GB for 7B |
| Optimizer States | 4-8 bytes/param (Adam) | +28-56GB for 7B |
| Activations | Varies by batch/seq | Often largest! |
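Total for training a 7B model in FP16 with Adam: ~84GB for weights, gradients, and optimizer states alone (more if an FP32 master copy of the weights is kept), and ~100GB or more once activations are added!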
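As a rough sanity check on the table, the per-parameter costs can be added up in a few lines. This is a back-of-the-envelope sketch, assuming 16-bit weights and gradients, Adam moments kept in FP32 (8 bytes/param), and an optional FP32 master copy of the weights (4 bytes/param) as some mixed-precision setups keep. Activations are excluded because they depend on batch size and sequence length; `training_memory_gb` is an illustrative helper.

```python
def training_memory_gb(num_params, weight_bytes=2, grad_bytes=2,
                       optimizer_bytes=8, master_weight_bytes=0):
    """Rough GPU memory in GB (1e9 bytes) for weights, gradients, and optimizer
    states. Activations are NOT included: they depend on batch size and sequence length."""
    bytes_per_param = weight_bytes + grad_bytes + optimizer_bytes + master_weight_bytes
    return num_params * bytes_per_param / 1e9


if __name__ == "__main__":
    n = 7_000_000_000  # 7B parameters
    print(training_memory_gb(n))                         # 84.0: FP16 weights/grads + FP32 Adam moments
    print(training_memory_gb(n, master_weight_bytes=4))  # 112.0: plus an FP32 master copy of the weights
    print(training_memory_gb(n, optimizer_bytes=2))      # 42.0: with 8-bit Adam moments
```

Activations come on top of these totals, which is how a 7B run climbs to around 100GB.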
💡 What you'll learn
Techniques to reduce memory by 50-80%, letting you train larger models on smaller GPUs or increase batch sizes on existing hardware.
Technique 1: Mixed Precision Training
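Run most of the forward and backward pass in 16-bit instead of 32-bit precision. Each 16-bit value takes half the memory of an FP32 value, and Tensor Core GPUs compute in 16-bit much faster:
```python
import torch

# Model, criterion, and dataloader are placeholders for your own code
model = Model().cuda()
optimizer = torch.optim.AdamW(model.parameters())
scaler = torch.amp.GradScaler("cuda")  # torch.cuda.amp.GradScaler() on older PyTorch

for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()

    # Forward pass: eligible ops run in FP16, the weights stay FP32
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # Backward pass with a scaled loss so small FP16 gradients don't underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscales gradients, skips the step on inf/NaN
    scaler.update()
```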
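Loss scaling exists because small FP16 gradients underflow to zero. The scaler multiplies the loss by a large factor before `backward()`, divides the gradients by the same factor before the optimizer step, and adjusts the factor over time. The pure-Python sketch below mimics that update rule so the mechanics are visible. `DynamicLossScaler` is a hypothetical class, not PyTorch code, and its defaults are assumed to mirror `GradScaler`'s documented defaults (initial scale 2**16, growth factor 2, backoff factor 0.5, growth interval 2000).

```python
import math


class DynamicLossScaler:
    """Hypothetical, simplified dynamic loss scaler modeled on GradScaler's update rule."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def step(self, scaled_grads):
        """Unscale gradients; return them if all finite, or None to skip the step."""
        grads = [g / self.scale for g in scaled_grads]
        if not all(math.isfinite(g) for g in grads):
            # Overflow detected: skip this optimizer step and shrink the scale
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return None
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # Enough clean steps in a row: try a larger scale
            self.scale *= self.growth_factor
            self._good_steps = 0
        return grads


if __name__ == "__main__":
    scaler = DynamicLossScaler(growth_interval=3)
    print(scaler.step([float("inf"), 1.0]))  # None: step skipped
    print(scaler.scale)                      # 32768.0
```

The design trade-off: a large scale protects tiny gradients, and the backoff on overflow keeps it from growing past what FP16 can hold.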
BF16 vs FP16
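- FP16: Broadly supported on NVIDIA GPUs, but its narrow range means it needs loss scaling
- BF16: Same exponent range as FP32, so no loss scaling is needed; requires an Ampere-or-newer GPU (A100, H100, RTX 30/40 series)
```python
# BF16 is simpler (no scaler needed)
for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
```
Memory savings: ~50% on activations. Weights and gradients are halved only when the model itself is stored in 16-bit (e.g. loaded with `torch_dtype=torch.bfloat16`); under autocast alone they stay FP32.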
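To see why FP16 needs loss scaling and BF16 usually does not, the sketch below round-trips a few values through both formats using only the standard library. `to_fp16` relies on `struct`'s half-precision `'e'` format. `to_bf16` is a hypothetical helper that keeps the top 16 bits of an FP32 value with round-to-nearest-even, matching BF16's layout (8 exponent bits like FP32, 7 mantissa bits). It is simplified and ignores NaN handling.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to IEEE FP16 and back; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def to_bf16(x):
    """Round a Python float to BF16 (upper 16 bits of FP32, round-to-nearest-even).
    Simplified: no NaN handling, and x must fit in FP32."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round the 16 dropped bits to nearest even
    bits &= 0xFFFF0000                   # keep sign, 8 exponent bits, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


if __name__ == "__main__":
    for value in (1e5, 1e-8, 0.1):
        print(value, to_fp16(value), to_bf16(value))
```

With FP16, 1e-8 flushes to zero and 1e5 overflows to infinity. BF16 keeps both in range but with fewer significant digits: 0.1 becomes 0.10009765625, versus 0.0999755859375 in FP16.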
Technique 2: Gradient Checkpointing
Trade compute for memory by recomputing activations during backward pass:
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# For sequential models (Layer1, Layer2, ... are placeholders)
model = nn.Sequential(
    Layer1(),
    Layer2(),
    # ... many layers
)

def forward(x):
    # Split into 4 segments; only segment boundaries are stored and the
    # activations inside each segment are recomputed during backward
    return checkpoint_sequential(model, segments=4, input=x, use_reentrant=False)

# For transformer models (HuggingFace)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
)
model.gradient_checkpointing_enable()  # Single line!
```
⚠️ Speed Tradeoff
Gradient checkpointing increases training time by ~20-30% due to recomputation. Worth it when memory-constrained.
Memory savings: 50-70% on activations
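A simplified counting model shows why checkpointing in segments helps. It assumes every layer's activation is the same size and ignores temporary buffers during recomputation. Without checkpointing, all `num_layers` activations are stored. With `segments` checkpoints, only each segment's input is kept, plus one segment's activations while it is recomputed in the backward pass. `stored_activations` is a hypothetical helper.

```python
import math


def stored_activations(num_layers, segments=None):
    """Simplified peak count of layer activations held in memory."""
    if segments is None:
        return num_layers  # no checkpointing: every layer's output is kept
    # one saved input per segment + one segment's activations during recompute
    return segments + math.ceil(num_layers / segments)


if __name__ == "__main__":
    layers = 16
    print(stored_activations(layers))
    for k in (2, 4, 8):
        print(k, stored_activations(layers, k))
```

The minimum lands near segments ≈ √num_layers, which is why checkpointing can cut activation memory from O(n) to roughly O(√n) in the number of layers.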
Technique 3: Flash Attention
Flash Attention computes attention in tiles, dramatically reducing memory:
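```python
import torch
from torch.nn.functional import scaled_dot_product_attention
from transformers import AutoModelForCausalLM

# Standard attention: O(n²) memory in sequence length n
# Flash Attention: O(n) memory!

# HuggingFace models (needs the flash-attn package and FP16/BF16 weights)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    attn_implementation="flash_attention_2",  # Enable Flash Attention
    torch_dtype=torch.bfloat16,
)

# PyTorch native (2.0+): picks the Flash Attention backend when the inputs allow it
# (CUDA tensors in FP16/BF16); shapes are (batch, heads, seq_len, head_dim)
query = torch.randn(1, 32, 8192, 128, device="cuda", dtype=torch.bfloat16)
key = torch.randn_like(query)
value = torch.randn_like(query)
output = scaled_dot_product_attention(query, key, value)
```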
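The O(n²) term is the n × n score matrix that standard attention materializes for every head. The arithmetic below sizes one such matrix per layer under hypothetical settings (32 heads, FP16, batch size 1). The softmax probabilities have the same shape, so an eager implementation may keep more than one of these alive. `attention_scores_bytes` is an illustrative helper.

```python
def attention_scores_bytes(seq_len, num_heads, bytes_per_elem=2, batch_size=1):
    """Bytes for one materialized (seq_len x seq_len) score matrix per head,
    for a single layer. Standard attention stores this; Flash Attention does not."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem


if __name__ == "__main__":
    # Hypothetical setting: 32 heads, FP16 (2 bytes), batch size 1, one layer
    for n in (2048, 8192, 32768):
        print(n, attention_scores_bytes(n, num_heads=32) / 1e9, "GB")
```

At 8,192 tokens a single score matrix is about 4.3GB per layer, and doubling the context quadruples it.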
Memory Comparison (8K context)
| Attention Type | Memory | Speed |
|---|---|---|
| Standard | ~8GB | 1x |
| Flash Attention 2 | ~0.5GB | 2-4x faster |
Memory savings: 10-20x for attention layers!
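The idea that lets Flash Attention skip the n × n matrix is an online softmax: process keys and values in blocks, keep only a running maximum, a running normalizer, and a running weighted sum, and rescale them whenever a new block raises the maximum. The pure-Python sketch below shows that math for a single query vector and checks it against a naive implementation. It illustrates the algorithm only, not the fused GPU kernel, and `tiled_attention` and `block_size` are hypothetical names.

```python
import math


def naive_attention(q, keys, values):
    """Reference: materialize every score, softmax, weighted sum (single query)."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def tiled_attention(q, keys, values, block_size=2):
    """Online-softmax attention: only O(block_size) scores exist at any time."""
    scale = 1.0 / math.sqrt(len(q))
    running_max, running_sum = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) * scale for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated so far to the new maximum
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


if __name__ == "__main__":
    q = [0.5, -1.0]
    keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, 0.3]]
    values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
    print(naive_attention(q, keys, values))
    print(tiled_attention(q, keys, values, block_size=2))
```

Both functions return the same vector up to floating-point rounding; only the tiled one avoids holding all scores at once.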
Technique 4: 8-bit Optimizers
Adam keeps two FP32 moment estimates per parameter, 8 bytes in total. Use 8-bit versions:
```python
import bitsandbytes as bnb

# Replace AdamW with the 8-bit version
optimizer = bnb.optim.AdamW8bit(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
)

# Drop-in replacement: the training loop is unchanged
# (batch is assumed to be a dict of tensors including labels, HuggingFace-style)
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
```
Memory savings: 75% on optimizer states
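bitsandbytes stores the moment tensors in 8 bits using block-wise quantization: each block of values shares one scale, so an outlier only costs precision inside its own block. The sketch below is a simplified linear absmax version for illustration. `quantize_blockwise`, `dequantize_blockwise`, and `block_size=4` are assumptions; bitsandbytes' actual kernels use much larger blocks and a non-linear dynamic quantization map.

```python
def quantize_blockwise(values, block_size=4):
    """Simplified block-wise absmax quantization to signed 8-bit codes.
    Each block stores 1-byte codes plus one float scale."""
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # avoid dividing by zero
        scales.append(absmax)
        codes.extend(round(v / absmax * 127) for v in block)  # codes in [-127, 127]
    return codes, scales


def dequantize_blockwise(codes, scales, block_size=4):
    """Map codes back to floats using each block's own scale."""
    return [c / 127 * scales[i // block_size] for i, c in enumerate(codes)]


if __name__ == "__main__":
    moments = [0.001, -0.002, 0.0005, 0.0, 10.0, -3.0, 2.5, 0.1]
    codes, scales = quantize_blockwise(moments)
    print(codes)
    print(dequantize_blockwise(codes, scales))
```

The first block holds only small values, so its own small scale keeps them accurate even though the second block contains an outlier (10.0). With a single scale of 10.0 for the whole tensor, 0.0005 would round to code 0.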
Technique 5: Gradient Accumulation
Simulate larger batches without more memory:
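```python
# batch_size, model, optimizer, and dataloader are defined elsewhere
accumulation_steps = 4
effective_batch_size = batch_size * accumulation_steps

optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    # Forward pass; scale the loss so accumulated gradients average over micro-batches
    loss = model(**batch).loss / accumulation_steps
    loss.backward()  # gradients add up in .grad across iterations

    # Update weights every N steps
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```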
This doesn't reduce memory per batch but lets you achieve the same effective batch size with smaller actual batches.
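Dividing each micro-batch loss by `accumulation_steps` is what makes the accumulated gradient equal the gradient of one large batch. The pure-Python check below uses a hypothetical one-parameter linear model with mean squared error; `grad_mse` and `accumulated_grad` are illustrative names.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)**2) for a hypothetical one-parameter model y = w*x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accumulation_steps):
    """Add up per-micro-batch gradients of loss / accumulation_steps, the way
    repeated loss.backward() calls accumulate into .grad."""
    if len(xs) % accumulation_steps:
        raise ValueError("use equal-sized micro-batches")
    micro = len(xs) // accumulation_steps
    total = 0.0
    for i in range(accumulation_steps):
        part = slice(i * micro, (i + 1) * micro)
        total += grad_mse(w, xs[part], ys[part]) / accumulation_steps
    return total


if __name__ == "__main__":
    xs = [float(i) for i in range(1, 9)]  # 8 examples
    ys = [2.0 * x + 0.1 for x in xs]
    print(grad_mse(0.5, xs, ys))             # one batch of 8
    print(accumulated_grad(0.5, xs, ys, 4))  # 4 micro-batches of 2
```

The equivalence assumes equal-sized micro-batches and a mean-reduced loss. Layers that depend on batch statistics, such as BatchNorm, still see only the small micro-batch.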
Technique 6: CPU Offloading
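Move optimizer states (and, with ZeRO stage 3, parameters) to CPU RAM while they are not needed on the GPU:
```python
# DeepSpeed ZeRO with CPU offload
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True
        },
        "offload_param": {  # parameter offload is only available with stage 3
            "device": "cpu",
            "pin_memory": True
        }
    }
}

# HuggingFace Accelerate: offloading is configured through a DeepSpeed plugin,
# not an Accelerator argument (launch the script with `accelerate launch`)
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=3,
    gradient_accumulation_steps=4,
    offload_optimizer_device="cpu",
    offload_param_device="cpu",
)
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=deepspeed_plugin)
```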
⚠️ Speed Impact
CPU offloading significantly slows training (2-5x). Use only when you can't fit the model otherwise.
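Part of that slowdown is simply data movement over PCIe. In ZeRO-Offload-style optimizer offloading, 16-bit gradients go to the CPU, the optimizer step runs there, and updated 16-bit weights come back. The estimate below assumes an effective 25GB/s per direction (a hypothetical figure, below PCIe 4.0 x16's roughly 32GB/s peak; measure your own system), no overlap with compute, and ignores the extra traffic that parameter offloading (`offload_param`) adds. `offload_transfer_seconds` is an illustrative helper.

```python
def offload_transfer_seconds(num_params, bytes_per_value=2, bandwidth_gb_s=25.0):
    """Rough transfer time per optimizer step: 16-bit gradients GPU->CPU plus
    updated 16-bit weights CPU->GPU, with no overlap with compute (assumption)."""
    one_way_bytes = num_params * bytes_per_value
    return 2 * one_way_bytes / (bandwidth_gb_s * 1e9)


if __name__ == "__main__":
    for n in (1_000_000_000, 7_000_000_000, 70_000_000_000):
        print(f"{n / 1e9:.0f}B params: {offload_transfer_seconds(n):.2f} s per step")
```

For 7B parameters this simplified model gives about 1.1 seconds of transfer per optimizer step, before any CPU-side optimizer compute.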
Technique 7: Model Sharding
Split model across multiple GPUs:
```python
import torch
from transformers import AutoModelForCausalLM

# Automatic device mapping (requires the accelerate package)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    device_map="auto",  # Automatically splits layers across available GPUs
    torch_dtype=torch.bfloat16
)

# Manual device mapping (Llama-2-70B has 80 decoder layers, 0-79)
device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 0,
    # ... assign layers to GPUs
    "model.layers.78": 1,
    "model.layers.79": 1,
    "model.norm": 1,
    "lm_head": 1,
}
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    device_map=device_map,
    torch_dtype=torch.bfloat16
)
```
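Writing the manual map by hand gets tedious for 80 layers, so a small helper can generate an even split. `build_device_map` is hypothetical; its module names follow the Llama layout used in the manual map above, and other architectures or library versions may use different names. It places the embeddings on the first GPU and the final norm and LM head on the last, without accounting for their size.

```python
def build_device_map(num_layers, num_gpus, prefix="model.layers"):
    """Hypothetical helper: spread decoder layers evenly over GPUs
    (embeddings on the first GPU, final norm and LM head on the last)."""
    device_map = {"model.embed_tokens": 0}
    per_gpu = -(-num_layers // num_gpus)  # ceiling division
    for layer in range(num_layers):
        device_map[f"{prefix}.{layer}"] = min(layer // per_gpu, num_gpus - 1)
    device_map["model.norm"] = num_gpus - 1
    device_map["lm_head"] = num_gpus - 1
    return device_map


if __name__ == "__main__":
    dm = build_device_map(num_layers=80, num_gpus=2)
    print(dm["model.layers.39"], dm["model.layers.40"], dm["lm_head"])
```

If you only need to cap per-GPU usage, `from_pretrained` also accepts a `max_memory` dict (e.g. `{0: "70GiB", 1: "70GiB"}`) alongside `device_map="auto"`.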
Technique 8: Efficient Data Loading
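Data loading mostly consumes host (CPU) RAM rather than GPU memory, but a poorly configured pipeline can exhaust system memory or leave the GPU idle:
```python
from torch.utils.data import DataLoader
from datasets import load_dataset

# Stream the dataset instead of downloading and preparing it all up front
dataset = load_dataset(
    "path/to/data",
    split="train",
    streaming=True,  # Returns an IterableDataset; examples are read lazily
)
# (Non-streaming datasets are already memory-mapped Arrow files on disk)

# Pin memory for faster transfers
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    pin_memory=True,  # Page-locked host buffers: faster CPU→GPU transfer
    persistent_workers=True  # Keep workers alive between epochs
)
```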
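Hugging Face `datasets` keeps data in memory-mapped Arrow files, so a training step only touches the pages it reads instead of loading the whole dataset into RAM. The stdlib sketch below shows that underlying idea on a hypothetical flat file of token ids; `write_tokens` and `read_window` are illustrative names, not part of any library.

```python
import array
import mmap
import os
import tempfile

ITEM = array.array("I").itemsize  # bytes per token id (platform C unsigned int)


def write_tokens(path, tokens):
    """Write token ids as a flat binary file of unsigned ints."""
    with open(path, "wb") as f:
        array.array("I", tokens).tofile(f)


def read_window(path, start, length):
    """Read tokens[start:start+length] through a memory map; the OS pages in
    only the bytes that are touched instead of loading the whole file."""
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        window = array.array("I")
        window.frombytes(mm[start * ITEM:(start + length) * ITEM])
        return window.tolist()


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "tokens.bin")
    write_tokens(path, range(1_000_000))
    print(read_window(path, 500_000, 5))
```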
Memory Profiling
Find what's consuming memory:
```python
import torch

# Check current allocation
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved (cached): {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Detailed breakdown
print(torch.cuda.memory_summary())

# Profile memory over time
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True
) as prof:
    # Your training code
    pass

print(prof.key_averages().table(sort_by="cuda_memory_usage"))
```
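`memory_allocated()` counts bytes backing live tensors, while `memory_reserved()` (the "cached" figure) also includes blocks PyTorch's caching allocator keeps for reuse. The small helper below, a hypothetical `memory_report`, turns the two numbers into a summary; pass it the values returned by those two functions.

```python
def memory_report(allocated_bytes, reserved_bytes):
    """Summarize CUDA allocator numbers, e.g.
    memory_report(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())."""
    cached_free = reserved_bytes - allocated_bytes
    return {
        "allocated_gb": allocated_bytes / 1e9,
        "reserved_gb": reserved_bytes / 1e9,
        # Held by the caching allocator but not backing any live tensor
        "cached_free_gb": cached_free / 1e9,
        "cached_free_fraction": cached_free / reserved_bytes if reserved_bytes else 0.0,
    }


if __name__ == "__main__":
    print(memory_report(6_000_000_000, 8_000_000_000))
```

A large `cached_free_gb` usually means freed blocks are being held for reuse. `torch.cuda.empty_cache()` returns them to the driver, which helps other processes but, per the PyTorch docs, does not increase the memory available to PyTorch itself. To measure the peak over a training step, call `torch.cuda.reset_peak_memory_stats()` before it and read `torch.cuda.max_memory_allocated()` after.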
Clear Cache
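```python
import gc
import torch

# Return unused cached blocks to the driver (live tensors are not freed)
torch.cuda.empty_cache()

# More aggressive cleanup: drop references first (e.g. `del outputs, loss`),
# collect unreachable Python objects, then release the freed blocks
gc.collect()
torch.cuda.empty_cache()
```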
Complete Example: Memory-Optimized Training
```python
import torch
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model

# Load model with all optimizations
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# Enable gradient checkpointing; with frozen base weights, the inputs must
# require grad so checkpointed segments still backpropagate into LoRA adapters
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
model.config.use_cache = False  # the KV cache is not used during training

# Add LoRA for parameter-efficient training
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Training arguments (the 8-bit AdamW from bitsandbytes is selected via `optim`)
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,  # Small batch
    gradient_accumulation_steps=16,  # Effective batch = 16
    learning_rate=2e-4,
    bf16=True,
    optim="adamw_bnb_8bit",
    gradient_checkpointing=True,
    max_grad_norm=1.0,
)

# Train! `dataset` is a placeholder: a tokenized dataset whose examples include labels
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```
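Most of the memory win in this example comes from LoRA: only the small adapter matrices get gradients and optimizer states. Each adapted linear layer with input size `d_in` and output size `d_out` gains two matrices, A (`r × d_in`) and B (`d_out × r`), so `r * (d_in + d_out)` parameters. The count below is a sketch assuming the publicly documented Llama 3.1 8B attention dimensions (hidden size 4096, 32 layers, grouped-query attention with 8 key/value heads of 128 dims); check them against the model's `config.json`. `lora_trainable_params` is an illustrative helper.

```python
def lora_trainable_params(rank, module_shapes, num_layers):
    """LoRA adds r * (d_in + d_out) parameters per adapted linear layer."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in module_shapes)
    return per_layer * num_layers


# Assumed Llama-3.1-8B attention projection shapes as (d_in, d_out)
LLAMA_8B_ATTENTION = [
    (4096, 4096),  # q_proj
    (4096, 1024),  # k_proj: 8 KV heads x 128 dims
    (4096, 1024),  # v_proj
    (4096, 4096),  # o_proj
]

if __name__ == "__main__":
    n = lora_trainable_params(rank=16, module_shapes=LLAMA_8B_ATTENTION, num_layers=32)
    print(f"{n:,} trainable parameters")
    print(f"{n * 8 / 1e9:.2f} GB for FP32 Adam moments")
```

That is roughly 13.6M trainable parameters, so even FP32 Adam states for them take only about 0.1GB, compared with tens of GB for full fine-tuning.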
Optimization Cheat Sheet
| Technique | Memory Savings | Speed Impact | Complexity |
|---|---|---|---|
| Mixed Precision | ~50% | +20-50% | Easy |
| Flash Attention | 10-20x (attention) | +100-300% | Easy |
| Gradient Checkpointing | 50-70% (activations) | -20-30% | Easy |
| 8-bit Optimizer | 75% (optimizer) | ~Same | Easy |
| CPU Offload | Variable | -50-80% | Medium |
| Model Sharding | ~1/N per GPU across N GPUs | -10-20% | Medium |
Conclusion
GPU memory optimization is essential for modern ML. Combine multiple techniques for maximum effect:
- Start with: Mixed precision + Flash Attention (free performance!)
- Add if needed: Gradient checkpointing + 8-bit optimizer
- Last resort: CPU offloading or model sharding
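Combined with LoRA as in the complete example, these techniques let you fine-tune 7B-class models on a single 24GB GPU, and 70B models on 80GB GPUs when sharded across several of them or paired with CPU offloading. Deploy on GPUBrazil to maximize your compute budget.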