Understanding GPU Memory

The dreaded "CUDA out of memory" error is a daily obstacle for ML engineers. Understanding what consumes GPU memory is the first step toward optimizing it.

Memory Breakdown During Training

Component | Memory usage | Notes
Model parameters | 2 bytes/param (FP16) | 7B model = 14GB
Gradients | Same as params | +14GB for 7B
Optimizer states | 4-8 bytes/param (Adam's two moments in 16- or 32-bit) | +28-56GB for 7B
Activations | Varies with batch size and sequence length | Often the largest!

Total for training a 7B model in FP16 with Adam: roughly 100GB or more once activations are included!
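The static part of the table above can be written as a small calculator. This is a sketch under stated assumptions: 2-byte weights and gradients, Adam's two moments at 4 bytes each (the upper end of the table's 4-8 bytes/param), and activations left out because they depend on batch size and sequence length. The function name training_memory_gb is hypothetical.

```python
def training_memory_gb(num_params, weight_bytes=2, grad_bytes=2, optimizer_bytes=8):
    """Rough static training memory in GB (1 GB = 1e9 bytes), excluding activations.

    Assumptions: weights and gradients share one precision, and the optimizer
    stores `optimizer_bytes` per parameter (8 = Adam's two FP32 moments).
    """
    breakdown = {
        "parameters": num_params * weight_bytes / 1e9,
        "gradients": num_params * grad_bytes / 1e9,
        "optimizer_states": num_params * optimizer_bytes / 1e9,
    }
    breakdown["total"] = sum(breakdown.values())
    return breakdown


if __name__ == "__main__":
    for name, gb in training_memory_gb(7e9).items():
        print(f"{name:>17}: {gb:.0f} GB")
```

For a 7B model this gives 14 + 14 + 56 = 84GB before a single activation is stored, which is why the total climbs toward 100GB in practice.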

💡 What you'll learn

Techniques to reduce memory by 50-80%, letting you train larger models on smaller GPUs or increase batch sizes on existing hardware.

Technique 1: Mixed Precision Training

Run most operations in 16-bit instead of 32-bit precision, which halves activation memory and speeds up compute:

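import torch
from torch.amp import autocast, GradScaler  # torch.amp.GradScaler needs PyTorch 2.3+

model = Model().cuda()  # Model, dataloader, and criterion are defined elsewhere
optimizer = torch.optim.AdamW(model.parameters())
scaler = GradScaler("cuda")

for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()

    # Forward pass: autocast runs eligible ops (e.g. matmuls) in FP16
    with autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # Backward on the scaled loss to avoid FP16 gradient underflow;
    # step() unscales gradients first and skips the update on inf/NaN
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()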

BF16 vs FP16

# BF16 keeps FP32's exponent range, so no GradScaler is needed
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()

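Memory savings: up to ~50% on activations. Under autocast, parameters and gradients stay in FP32; they shrink by half only if the model itself is loaded or cast in 16-bit (e.g. torch_dtype=torch.bfloat16, as in the Hugging Face examples below).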
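Why does FP16 need a GradScaler while BF16 does not? FP16 has only 5 exponent bits, so very small gradient values underflow to zero, while BF16 keeps FP32's 8-bit exponent. The sketch below uses Python's struct half-precision format ('e') for FP16 and emulates BF16 by keeping the top 16 bits of an FP32 value. That truncation is a simplification (hardware conversion rounds to nearest), and the gradient value is made up for illustration.

```python
import struct

def to_fp16(x):
    """Round-trip a Python float through IEEE half precision (FP16)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x):
    """Approximate BF16 by truncating an FP32 value to its top 16 bits.
    (Simplification: real conversion rounds to nearest instead.)"""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

grad = 1e-8       # a tiny gradient value
scale = 65536.0   # loss-scale factor (GradScaler's default initial scale is 2**16)

print("FP16:", to_fp16(grad))   # underflows to zero
print("BF16:", to_bf16(grad))   # stays non-zero thanks to the wider exponent
print("FP16, scaled then unscaled:", to_fp16(grad * scale) / scale)
```

Multiplying the loss by the scale shifts gradients into FP16's representable range; scaler.step() divides the scale back out before updating the FP32 weights.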

Technique 2: Gradient Checkpointing

Trade compute for memory by recomputing activations during the backward pass instead of storing them:

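import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# For sequential models
model = nn.Sequential(
    nn.Linear(1024, 1024),
    nn.ReLU(),
    # ... many layers
)

def forward(x):
    # Split into 4 segments; activations inside checkpointed segments are
    # recomputed during backward instead of being stored
    return checkpoint_sequential(model, segments=4, input=x, use_reentrant=False)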

# For transformer models (HuggingFace)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model.gradient_checkpointing_enable()  # Single line!

⚠️ Speed Tradeoff

Gradient checkpointing increases training time by ~20-30% due to recomputation. Worth it when memory-constrained.

Memory savings: 50-70% on activations
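A back-of-envelope model of why segmenting helps: with n layers split into s segments, roughly one saved input per segment is kept during the forward pass, and during backward only one segment's activations are rematerialized at a time. Peak stored activations are then about s + n/s layer outputs instead of n, smallest near s = sqrt(n). This counting model is a simplification (every layer output is treated as the same size, and temporaries are ignored), and the function name is hypothetical.

```python
import math

def peak_stored_activations(num_layers, segments):
    """Rough count of layer outputs held at peak memory.

    segments <= 1 means no checkpointing: every layer output is kept.
    Otherwise: one saved input per segment, plus the activations of the
    single segment currently being recomputed in the backward pass.
    """
    if segments <= 1:
        return num_layers
    segment_size = math.ceil(num_layers / segments)
    return segments + segment_size

n = 64
for s in (1, 2, 4, 8, 16):
    print(f"segments={s:>2}: ~{peak_stored_activations(n, s)} layer outputs")
```

For 64 layers this drops from 64 stored outputs to about 16 at 8 segments, at the cost of roughly one extra forward pass, which is where the slowdown above comes from.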

Technique 3: Flash Attention

Flash Attention computes attention in tiles, dramatically reducing memory:

# Standard attention: O(n²) memory
# Flash Attention: O(n) memory!

# HuggingFace models
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    attn_implementation="flash_attention_2",  # Enable Flash Attention
    torch_dtype=torch.bfloat16,
)

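# PyTorch native (2.0+)
import torch
from torch.nn.functional import scaled_dot_product_attention

# Flash kernels need FP16/BF16 tensors on CUDA
shape = (1, 32, 8192, 128)  # (batch, heads, seq_len, head_dim)
query, key, value = (torch.randn(shape, device="cuda", dtype=torch.bfloat16) for _ in range(3))

# Uses a fused kernel (Flash or memory-efficient) when inputs allow, else the math fallback
output = scaled_dot_product_attention(query, key, value, is_causal=True)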

Memory Comparison (8K context)

Attention type | Memory | Speed
Standard | ~8GB | 1x
Flash Attention 2 | ~0.5GB | 2-4x faster

Memory savings: 10-20x for attention layers!
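The tiling idea can be shown for a single query row in plain Python. This is a sketch of the online-softmax recurrence that Flash Attention builds on, not its CUDA kernel: keys and values are visited one tile at a time while only a running maximum, a running softmax denominator, and a running weighted sum are kept. The function names, the tile_size parameter, and the toy vectors are illustrative.

```python
import math

def naive_attention_row(q, keys, values):
    """Reference: materializes all n scores for one query before the softmax."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def tiled_attention_row(q, keys, values, tile_size=2):
    """Online softmax: visit keys/values one tile at a time, keeping only a
    running max, a running denominator, and a running weighted sum of values."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])  # unnormalized output accumulator
    for start in range(0, len(keys), tile_size):
        k_tile = keys[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        new_max = max(running_max, max(scores))
        # Rescale what has been accumulated so far to the new maximum
        correction = math.exp(running_max - new_max)  # exp(-inf) == 0.0 on the first tile
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_tile):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]

q = [0.3, -1.2, 0.8]
keys = [[0.1, 0.4, -0.5], [1.0, 0.2, 0.3], [-0.7, 0.9, 0.0], [0.5, -0.5, 1.5], [0.2, 0.2, 0.2]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]
print(naive_attention_row(q, keys, values))
print(tiled_attention_row(q, keys, values, tile_size=2))
```

Only tile_size scores exist at any moment, so the score buffer scales with the tile size rather than with n², which is the source of the O(n) memory figure above.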

Technique 4: 8-bit Optimizers

Adam keeps two optimizer states per parameter (first and second moments), which take 8 bytes per parameter in FP32. Use 8-bit versions instead:

import bitsandbytes as bnb

# Replace AdamW with 8-bit version
optimizer = bnb.optim.AdamW8bit(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999)
)

# Works exactly like regular AdamW
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(**batch).loss  # HF-style model; batch is a dict that includes labels
    loss.backward()
    optimizer.step()

Memory savings: 75% on optimizer states
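Where the 75% comes from: two FP32 moments cost 8 bytes per parameter, while two 8-bit moments cost 2 bytes plus a small per-block scale. The sketch below quantizes an optimizer-state tensor block by block with a linear absmax scheme. This is a deliberate simplification: bitsandbytes uses block-wise dynamic (non-linear) quantization, so its codebook and error profile differ. The block size of 256, the 4-byte scale per block, the random data, and the function names are illustrative assumptions.

```python
import random

def quantize_blockwise(values, block_size=256):
    """Linear absmax int8 quantization with one scale per block."""
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # avoid dividing by zero
        scales.append(absmax)
        codes.extend(round(v / absmax * 127) for v in block)  # ints in [-127, 127]
    return codes, scales

def dequantize_blockwise(codes, scales, block_size=256):
    return [c / 127 * scales[i // block_size] for i, c in enumerate(codes)]

def state_bytes(num_params, bytes_per_value, num_states=2, block_size=256):
    """Bytes for Adam's two moments; 8-bit storage adds a 4-byte scale per block."""
    total = num_params * num_states * bytes_per_value
    if bytes_per_value == 1:
        total += num_states * -(-num_params // block_size) * 4  # ceil division
    return total

random.seed(0)
exp_avg = [random.gauss(0.0, 1e-3) for _ in range(1000)]  # fake first-moment values
codes, scales = quantize_blockwise(exp_avg)
restored = dequantize_blockwise(codes, scales)
print(f"max abs error: {max(abs(a - b) for a, b in zip(exp_avg, restored)):.2e}")

fp32 = state_bytes(7_000_000_000, 4)
int8 = state_bytes(7_000_000_000, 1)
print(f"FP32 states: {fp32 / 1e9:.1f} GB, 8-bit states: {int8 / 1e9:.1f} GB")
```

Per-block scales keep one outlier from crushing the precision of every other value in the tensor, which is the main reason block-wise schemes are used for optimizer states.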

Technique 5: Gradient Accumulation

Simulate larger batches without more memory:

accumulation_steps = 4
effective_batch_size = batch_size * accumulation_steps

optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    # Forward pass; scale the loss so summed gradients average over micro-batches
    loss = model(**batch).loss / accumulation_steps
    loss.backward()  # gradients accumulate in .grad across iterations

    # Update weights every N steps
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

This doesn't reduce memory per batch but lets you achieve the same effective batch size with smaller actual batches.
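Dividing each micro-batch loss by accumulation_steps is what makes the accumulated gradient match the large-batch gradient when micro-batches are equal-sized. The sketch checks this numerically with a one-parameter linear model and a mean-squared-error loss, using a hand-derived gradient instead of autograd; the data and names are made up for illustration.

```python
def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over a batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.1, 1.9, 3.2, 3.9, 5.1, 6.2, 6.8, 8.1]
w = 0.3
accumulation_steps = 4
micro = len(xs) // accumulation_steps  # 2 examples per micro-batch

# One big batch of 8 examples
full_grad = mse_grad(w, xs, ys)

# Four micro-batches of 2, each loss scaled by 1/accumulation_steps
accumulated = 0.0  # plays the role of .grad, summed across backward() calls
for i in range(accumulation_steps):
    batch_x = xs[i * micro:(i + 1) * micro]
    batch_y = ys[i * micro:(i + 1) * micro]
    accumulated += mse_grad(w, batch_x, batch_y) / accumulation_steps

print(full_grad, accumulated)  # equal up to floating-point rounding
```

With unequal micro-batch sizes (for example a short final batch), dividing by accumulation_steps only approximates the full-batch mean.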

Technique 6: CPU Offloading

Move optimizer states and parameters that aren't in active use to CPU RAM:

# DeepSpeed ZeRO with CPU offload
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": True
        }
    }
}

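# HuggingFace Accelerate: ZeRO-3 with CPU offload via the DeepSpeed plugin
from accelerate import Accelerator, DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=3,
    gradient_accumulation_steps=4,
    offload_optimizer_device="cpu",
    offload_param_device="cpu",
)
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=deepspeed_plugin)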

⚠️ Speed Impact

CPU offloading significantly slows training (2-5x). Use only when you can't fit the model otherwise.
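The "stage": 3 setting in the DeepSpeed config partitions parameters, gradients, and optimizer states across data-parallel GPUs; offloading then moves those partitions to CPU RAM. A rough per-GPU estimate of model-state memory, following the ZeRO paper's accounting for mixed-precision Adam (2 bytes of FP16 weights, 2 bytes of FP16 gradients, and K = 12 bytes of FP32 master weights plus moments per parameter), is sketched below. It ignores activations, buffers, and fragmentation, and the function name is hypothetical.

```python
def zero_memory_gb_per_gpu(num_params, num_gpus, stage, k=12):
    """Model-state memory per GPU under ZeRO, in GB (1e9 bytes).

    stage 0: nothing partitioned (plain data parallelism)
    stage 1: optimizer states partitioned
    stage 2: + gradients partitioned
    stage 3: + parameters partitioned
    """
    params, grads, opt = 2 * num_params, 2 * num_params, k * num_params
    if stage >= 1:
        opt /= num_gpus
    if stage >= 2:
        grads /= num_gpus
    if stage >= 3:
        params /= num_gpus
    return (params + grads + opt) / 1e9

for stage in range(4):
    print(f"ZeRO stage {stage}: {zero_memory_gb_per_gpu(7.5e9, 64, stage):.1f} GB per GPU")
```

For a 7.5B-parameter model on 64 GPUs this works out to about 120, 31.4, 16.6, and 1.9 GB per GPU for stages 0 through 3. With offloading enabled, much of what remains moves to host RAM; the price is the CPU-GPU transfer traffic behind the slowdown described above.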

Technique 7: Model Sharding

Split model across multiple GPUs:

# Automatic device mapping
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    device_map="auto",  # Automatically splits layers across available GPUs
    torch_dtype=torch.bfloat16
)

# Manual device mapping
device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 0,
    # ... assign layers to GPUs
    "model.layers.78": 1,
    "model.layers.79": 1,
    "model.norm": 1,
    "lm_head": 1,
}

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    device_map=device_map,
    torch_dtype=torch.bfloat16
)
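Writing all 80 layer entries of a manual map by hand is tedious. A hypothetical helper such as build_device_map below can generate a contiguous split with the same keys as the manual map; it assumes the Llama module names used there and an even split, whereas device_map="auto" also accounts for each GPU's available memory. For scale: 70B parameters at 2 bytes each is about 140GB of BF16 weights, so the model needs at least two 80GB GPUs before counting activations.

```python
def build_device_map(num_layers, num_gpus):
    """Assign contiguous blocks of decoder layers to GPUs 0..num_gpus-1.

    Hypothetical helper: embeddings go on the first GPU; the final norm
    and lm_head go on the last GPU, matching the manual map's layout.
    """
    per_gpu = -(-num_layers // num_gpus)  # ceiling division
    device_map = {"model.embed_tokens": 0}
    for layer in range(num_layers):
        device_map[f"model.layers.{layer}"] = layer // per_gpu
    device_map["model.norm"] = num_gpus - 1
    device_map["lm_head"] = num_gpus - 1
    return device_map

device_map = build_device_map(num_layers=80, num_gpus=2)
print(device_map["model.layers.39"], device_map["model.layers.40"])
```

The resulting dict has the same shape as the manual example and can be passed as device_map=device_map to from_pretrained.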

Technique 8: Efficient Data Loading

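Data loading mostly uses host (CPU) RAM rather than GPU memory, but a poorly configured pipeline can exhaust system memory or leave the GPU waiting: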

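from torch.utils.data import DataLoader

# Pin host memory for faster, asynchronous transfers
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    pin_memory=True,  # Faster CPU→GPU transfer; enables non_blocking copies
    persistent_workers=True  # Keep workers alive between epochs (needs num_workers > 0)
)
# In the training loop: batch = batch.to("cuda", non_blocking=True)

# Stream large datasets instead of loading them fully into RAM
from datasets import load_dataset

dataset = load_dataset(
    "path/to/data",
    split="train",
    streaming=True  # Returns an IterableDataset that reads records on the fly
)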

Memory Profiling

Find what's consuming memory:

import torch

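# Check current allocation
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")  # held by live tensors
print(f"Reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")   # held by the caching allocator
print(f"Peak:      {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")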

# Detailed breakdown
print(torch.cuda.memory_summary())

# Profile memory over time
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True
) as prof:
    # Your training code
    pass

print(prof.key_averages().table(sort_by="cuda_memory_usage"))

Clear Cache

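# Return unused cached blocks to the driver (memory held by live tensors is not freed)
torch.cuda.empty_cache()

# More aggressive cleanup: drop references first (e.g. del outputs, loss), then collect
import gc
gc.collect()
torch.cuda.empty_cache()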

Complete Example: Memory-Optimized Training

import torch
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model

# Load model in BF16 with Flash Attention 2 (requires the flash-attn package)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# Enable gradient checkpointing; non-reentrant mode still produces gradients
# for the LoRA weights even though the base model's inputs are frozen
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# Add LoRA for parameter-efficient training
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Training arguments
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=1,       # Small batch
    gradient_accumulation_steps=16,      # Effective batch = 16
    bf16=True,
    optim="adamw_bnb_8bit",              # 8-bit AdamW from bitsandbytes
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    max_grad_norm=1.0,
)

# Train! `dataset` is a tokenized dataset (input_ids, labels) prepared elsewhere
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()


Optimization Cheat Sheet

Technique | Memory savings | Speed impact | Complexity
Mixed precision | ~50% | +20-50% | Easy
Flash Attention | 10-20x (attention) | +100-300% | Easy
Gradient checkpointing | 50-70% (activations) | -20-30% | Easy
8-bit optimizer | 75% (optimizer states) | ~Same | Easy
CPU offload | Variable | -50-80% | Medium
Model sharding | Scales with GPU count | -10-20% | Medium

Conclusion

GPU memory optimization is essential for modern ML. Combine multiple techniques for maximum effect:

  1. Start with: Mixed precision + Flash Attention (free performance!)
  2. Add if needed: Gradient checkpointing + 8-bit optimizer
  3. Last resort: CPU offloading or model sharding

With these techniques, you can fine-tune 7B models on 24GB GPUs (using LoRA, as in the complete example) and train 70B models on 80GB GPUs by adding sharding or CPU offloading. Deploy on GPUBrazil to maximize your compute budget.