Bug description
I was reading the "Identify large layers" section of the FSDP docs: https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#identify-large-layers. Where does the memory profile table there come from, and how can I reproduce it?

I ran the toy script from that page:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
from lightning.pytorch.demos import Transformer, WikiText2


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(  # 1B parameters
            vocab_size=vocab_size,
            nlayers=32,
            nhid=4096,
            ninp=1024,
            nhead=64,
        )

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1)


L.seed_everything(42)

# Data
dataset = WikiText2()
train_dataloader = DataLoader(dataset)

# Model
model = LanguageModel(vocab_size=dataset.vocab_size)

# Trainer
trainer = L.Trainer(accelerator="cuda", devices=8, strategy=FSDPStrategy(), limit_train_batches=2, max_epochs=1)
trainer.fit(model, train_dataloader)
trainer.print(torch.cuda.memory_summary())
With 8 H100s, I get the following memory profile:
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   7737 MiB |  17357 MiB |  51663 MiB |  43926 MiB |
|       from large pool |   7733 MiB |  17349 MiB |  49592 MiB |  41859 MiB |
|       from small pool |      4 MiB |    187 MiB |   2071 MiB |   2067 MiB |
|---------------------------------------------------------------------------|
| Active memory         |  11585 MiB |  17357 MiB |  51663 MiB |  40078 MiB |
|       from large pool |  11581 MiB |  17349 MiB |  49592 MiB |  38011 MiB |
|       from small pool |      4 MiB |    187 MiB |   2071 MiB |   2067 MiB |
|---------------------------------------------------------------------------|
| Requested memory      |  11584 MiB |  17356 MiB |  51482 MiB |  39898 MiB |
|       from large pool |  11580 MiB |  17348 MiB |  49411 MiB |  37830 MiB |
|       from small pool |      4 MiB |    187 MiB |   2071 MiB |   2067 MiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |  19712 MiB |  19712 MiB |  19712 MiB |        0 B |
|       from large pool |  19522 MiB |  19522 MiB |  19522 MiB |        0 B |
|       from small pool |    190 MiB |    190 MiB |    190 MiB |        0 B |
|---------------------------------------------------------------------------|
That is around 20 GB reserved (19712 MiB, with a peak of 17357 MiB allocated), considerably more than the 9.6 GB cited in the docs. Am I missing something about how these numbers were computed? I think this is either a regression or the docs are out of date.
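To make the comparison concrete, here is a small sketch that converts the two headline numbers from the summary (peak allocated and reserved memory, both reported in MiB) into GiB and decimal GB. The helper names `mib_to_gib` and `mib_to_gb` are hypothetical, and treating the 9.6 GB docs figure as decimal GB is an assumption, since I don't know which unit or which metric the docs table reports.

```python
MIB = 1024 ** 2  # bytes per MiB, the unit used by torch.cuda.memory_summary()


def mib_to_gib(mib: float) -> float:
    """Convert MiB to GiB (binary, 1024**3 bytes)."""
    return mib * MIB / 1024 ** 3


def mib_to_gb(mib: float) -> float:
    """Convert MiB to GB (decimal, 1000**3 bytes)."""
    return mib * MIB / 1000 ** 3


# Values copied from the summary table in this issue.
reported_mib = {"peak allocated": 17357, "reserved": 19712}

# Assumption: the docs figure is in decimal GB; the actual unit is not known.
docs_figure_gb = 9.6

for name, mib in reported_mib.items():
    ratio = mib_to_gb(mib) / docs_figure_gb
    print(
        f"{name}: {mib} MiB = {mib_to_gib(mib):.2f} GiB = {mib_to_gb(mib):.2f} GB "
        f"({ratio:.2f}x the docs figure)"
    )
```

Whichever unit and metric the docs use, both numbers here come out at roughly twice the cited value, so the gap does not look like a unit mismatch.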
cc @awaelchli (who I think wrote these docs)
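In case it helps with reproducing, below is a sketch of how the peak numbers could be pulled out per rank instead of reading the full summary. It relies on `torch.cuda.memory_stats()`, whose returned mapping includes the keys `allocated_bytes.all.peak` and `reserved_bytes.all.peak`. The helper names `peak_memory_mib` and `report_current_device` are hypothetical, and I am not claiming this is how the docs table was produced.

```python
from typing import Dict, Mapping

MIB = 1024 ** 2  # bytes per MiB


def peak_memory_mib(stats: Mapping[str, int]) -> Dict[str, float]:
    """Extract peak allocated and reserved memory (in MiB) from a
    torch.cuda.memory_stats()-style mapping of byte counters."""
    return {
        "peak_allocated_mib": stats["allocated_bytes.all.peak"] / MIB,
        "peak_reserved_mib": stats["reserved_bytes.all.peak"] / MIB,
    }


def report_current_device() -> Dict[str, float]:
    """Read the peaks for the current CUDA device (requires a GPU)."""
    # Imported lazily so peak_memory_mib() can be used without PyTorch installed.
    import torch

    return peak_memory_mib(torch.cuda.memory_stats())


# Example without a GPU: feed in the peak values from the summary in this issue.
example = peak_memory_mib(
    {
        "allocated_bytes.all.peak": 17357 * MIB,
        "reserved_bytes.all.peak": 19712 * MIB,
    }
)
print(example)
```

In the script, calling `print(trainer.global_rank, report_current_device())` after `trainer.fit(...)` would print the peaks from every process, whereas `trainer.print(...)` only prints from the first one.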
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
Error messages and logs
No response
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0): 2.5.2
#- PyTorch Version (e.g., 2.5): 2.7.1
#- Python version (e.g., 3.12): 3.12.9
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: 12.8
#- GPU models and configuration: 8x H100
#- How you installed Lightning(`conda`, `pip`, source): uv
More info
No response