
Parallelism config + TP + HSDP + BYODM (Bring Your Own Device Mesh) #3682


Merged · 76 commits · Jul 30, 2025

Commits
2f471e3
Feat: init
S1ro1 Jun 24, 2025
43b1ca7
Feat: add validation + init from kwargs
S1ro1 Jun 24, 2025
79faa13
Fix: minor fixes
S1ro1 Jun 24, 2025
16f348b
Feat: more cleanup
S1ro1 Jul 4, 2025
53ef524
Minor refactor
S1ro1 Jul 4, 2025
cd31b02
remove import
S1ro1 Jul 7, 2025
2d89210
adding support for pre-configured device mesh
SalmanMohammadi Jul 15, 2025
afaafef
adding device mesh to fsdp2
SalmanMohammadi Jul 16, 2025
2d952cb
moving mesh dim defn to parralismconfig
SalmanMohammadi Jul 16, 2025
91ca626
tests
SalmanMohammadi Jul 17, 2025
910368b
WIP device mesh/accelerator validation
SalmanMohammadi Jul 18, 2025
b7d154e
WIP more tests
SalmanMohammadi Jul 18, 2025
8a0de72
Test Driven Development (TDD)
SalmanMohammadi Jul 18, 2025
1c68efb
fixing build_device_mesh
SalmanMohammadi Jul 18, 2025
e01abf1
FSDP dim names
SalmanMohammadi Jul 20, 2025
69b523c
adding example
Jul 21, 2025
c765a44
WIP
Jul 21, 2025
8d97930
fixing HSDP
Jul 21, 2025
57c0d9e
Feat: add back old options
S1ro1 Jul 21, 2025
c93285a
working example
Jul 21, 2025
cb40d36
debugging
Jul 21, 2025
b76ee67
adding parallelism config to partialstate
Jul 21, 2025
9aa2612
Feat: revert ddp changes
S1ro1 Jul 21, 2025
de96e74
Revert DDP
S1ro1 Jul 21, 2025
fd05e3b
Feat: (untested) update mesh dims and some minor tweaks
S1ro1 Jul 22, 2025
efc903e
adding dp_cp dims
Jul 22, 2025
7c3d0e3
updating comments
Jul 22, 2025
3cfce25
WIP
Jul 22, 2025
1bbdb75
wip 2
Jul 22, 2025
aa749ad
reverting
Jul 22, 2025
aa74576
storing state in accelerator rather than acceleratorstate
Jul 22, 2025
4e99b9c
Fix: minor tweaks
S1ro1 Jul 22, 2025
3d235cb
wip example update
Jul 22, 2025
61868c2
merging
Jul 22, 2025
f96fea3
Fixes for non-fsdp2 case
S1ro1 Jul 22, 2025
dd89452
Feat: ensure ddp/tp only works
S1ro1 Jul 22, 2025
7f243e0
updating example
Jul 23, 2025
4a2dd58
updating example
Jul 23, 2025
dc145c2
updating examples, fixing state
Jul 23, 2025
f21547f
fixed state
Jul 23, 2025
1a49c16
comments
Jul 23, 2025
07bf2b3
fixing partial state check
Jul 23, 2025
f274b35
linting
Jul 23, 2025
a6feca9
comments
Jul 23, 2025
80deb7e
removing fn
Jul 23, 2025
52c178f
merging
Jul 23, 2025
133ef5f
WIP: fix tp
S1ro1 Jul 23, 2025
74009ea
comments
Jul 24, 2025
379daa0
removing return
Jul 24, 2025
168b520
reverting upcast
Jul 24, 2025
76a546f
add guards
winglian Jul 25, 2025
e8963dc
guards for empty self.parallelism_config
winglian Jul 25, 2025
a402faf
use len on tuple to check if empty
winglian Jul 25, 2025
235d29f
Feat: cleanup example
S1ro1 Jul 26, 2025
1017752
Feat: some cleanup of example
S1ro1 Jul 27, 2025
36a1234
Merge branch 'main' into device_mesh_parallelism_config
S1ro1 Jul 27, 2025
7ddb3ab
Feat: add trackio
S1ro1 Jul 27, 2025
9fdc320
Fix: improve trackio
S1ro1 Jul 27, 2025
00dd4af
Feat: TP works
S1ro1 Jul 27, 2025
d21ff9f
Feat: some fsdp2 improv
S1ro1 Jul 27, 2025
d260842
Feat: working examples
S1ro1 Jul 28, 2025
8b89d27
handle clipping for tensor parallel
winglian Jul 29, 2025
4709fc8
Implicit replicate
S1ro1 Jul 29, 2025
353b559
Refactor: move to separate file + cleanup + basic comments
S1ro1 Jul 29, 2025
7364440
Fix: add unadded files, fix circular import
S1ro1 Jul 29, 2025
e90f832
Feat: better readme
S1ro1 Jul 29, 2025
044c713
Feat: add blog + ultrascale links
S1ro1 Jul 29, 2025
464a642
Tmp: should_save_model now returns only true
S1ro1 Jul 29, 2025
f85eadf
Fix: remove implicit_replication and style
S1ro1 Jul 30, 2025
86771e2
Fix: remove optional
S1ro1 Jul 30, 2025
c80aae0
add guard on parallelism_config.tp_enabled
winglian Jul 30, 2025
c8a2ae5
fix import
winglian Jul 30, 2025
ec59f84
fixing empty parallelism_config
Jul 30, 2025
0afb69f
fix import path for test patch
winglian Jul 30, 2025
89aad7a
fixing patch
Jul 30, 2025
c570f7c
merging
Jul 30, 2025
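
The two example scripts added below show the new API end to end. As a quick orientation, here is a minimal sketch of the Accelerator-managed path (the first example), using only the calls that appear in the diff. The sizes, the launcher, and the reading of the dimensions (dp_replicate as HSDP replication, dp_shard as FSDP2 sharding, tp as tensor parallelism) are assumptions for illustration; the product of the sizes is assumed to match the number of processes launched, e.g. with accelerate launch or torchrun across 8 GPUs.

# Minimal sketch of the ParallelismConfig + Accelerator path from this PR.
# Illustrative sizes: 2 (HSDP replicas) * 2 (FSDP2 shards) * 2 (TP) = 8,
# which is assumed to equal the number of launched processes.
from accelerate import Accelerator
from accelerate.utils.dataclasses import ParallelismConfig

parallelism_config = ParallelismConfig(
    dp_replicate_size=2,  # replication (DDP-like) dimension of HSDP (assumed reading)
    dp_shard_size=2,      # FSDP2 sharding dimension (assumed reading)
    tp_size=2,            # tensor parallelism
)

# The Accelerator builds and owns the device mesh when given the config;
# FSDP2 itself is configured via FullyShardedDataParallelPlugin, exactly as
# in the first example below.
accelerator = Accelerator(parallelism_config=parallelism_config)

# The mesh is exposed for callers that need it, e.g. transformers' tp_plan="auto".
device_mesh = accelerator.torch_device_mesh
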
191 changes: 191 additions & 0 deletions examples/fsdp2/nd_parallel.py
@@ -0,0 +1,191 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example of training with ND parallel using accelerate's ParallelismConfig
"""

import argparse

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from accelerate.utils.dataclasses import ParallelismConfig
from accelerate.utils.fsdp_utils import save_fsdp_optimizer
from utils import (
    PerformanceTracker,
    create_collate_fn,
    get_dataset,
    gpu_memory_usage_all,
    setup_tokenizer,
)


MODEL_ID = "NousResearch/Llama-3.2-1B"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--fsdp2-cls-name-to-wrap", type=str, default="LlamaDecoderLayer")
    parser.add_argument("--dp-replicate-size", type=int, default=1)
    parser.add_argument("--dp-shard-size", type=int, default=1)
    parser.add_argument("--tp-size", type=int, default=1)
    parser.add_argument("--sequence-length", type=int, default=128)
    parser.add_argument("--model-save-dir", type=str, default="./outputs")
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="Whether to save the model after training.",
    )
    parser.add_argument(
        "--save-optimizer",
        action="store_true",
        default=False,
        help="Whether to save the optimizer state after training.",
    )
    return parser.parse_args()


def main():
"""
Main function to train the model.
"""
args = parse_args()

set_seed(42)

if args.model:
model_id = args.model
else:
model_id = MODEL_ID

model_kwargs = {}
accelerator_kwargs = {}

parallelism_config = ParallelismConfig(
dp_replicate_size=args.dp_replicate_size,
dp_shard_size=args.dp_shard_size,
tp_size=args.tp_size,
)

if parallelism_config.fsdp_enabled:
fsdp2_plugin = FullyShardedDataParallelPlugin(
fsdp_version=2,
cpu_ram_efficient_loading=False,
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=[args.fsdp2_cls_name_to_wrap],
reshard_after_forward=True,
activation_checkpointing=True,
state_dict_type="FULL_STATE_DICT",
)
accelerator_kwargs["fsdp_plugin"] = fsdp2_plugin

accelerator = Accelerator(
mixed_precision="no",
parallelism_config=parallelism_config,
**accelerator_kwargs,
)

    if args.tp_size > 1:
        model_kwargs["tp_size"] = args.tp_size
        model_kwargs["tp_plan"] = "auto"
        model_kwargs["device_mesh"] = accelerator.torch_device_mesh

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        use_cache=False,
        **model_kwargs,
    )
    accelerator.print("Memory usage after model load")
    accelerator.print(gpu_memory_usage_all())
    accelerator.print("=" * 20)
    tokenizer = setup_tokenizer(model_id)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

    model, optimizer = accelerator.prepare(model, optimizer)
    accelerator.print("Memory usage after model prepare")
    accelerator.print(gpu_memory_usage_all())
    accelerator.print("=" * 20)

    dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
    dataloader = accelerator.prepare(dataloader)

    model.train()

    total_num_steps = min(10, len(dataloader))
    performance_tracker = PerformanceTracker(warmup_steps=2)

    accelerator.print("Starting training...")
    for step, batch in enumerate(dataloader):
        if step >= total_num_steps:
            break

        outputs = model(**batch)
        loss = outputs.loss

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        dist.all_reduce(loss, op=dist.ReduceOp.AVG)

        batch_tokens = batch["input_ids"].shape[1]
        metrics = performance_tracker.step(batch_tokens)

        print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
        log_metrics = {"loss": loss.item()}

        if "warmup_completed" in metrics:
            accelerator.print("Warm up completed! Starting performance tracking...")
        elif metrics:
            print_msg += f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}\n"
            print_msg += (
                f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
                f"alloc={metrics['peak_memory_alloc']:.1f}, "
                f"reserved={metrics['peak_memory_reserved']:.1f}"
            )
        if step % 2 == 0 or step == total_num_steps - 1:
            accelerator.print(print_msg)

        accelerator.log(log_metrics)

    accelerator.wait_for_everyone()
    accelerator.end_training()
accelerator.print("Training completed!")
if parallelism_config.fsdp_enabled and args.save_optimizer:
accelerator.print("Saving optimizer state...")
save_fsdp_optimizer(
fsdp2_plugin,
accelerator,
optimizer,
model,
args.model_save_dir + "/opt",
)
accelerator.print("Optimizer state saved.")
accelerator.print("Saving model state...")
if args.save_model:
model.save_pretrained(args.model_save_dir)
accelerator.print(f"Model saved to {args.model_save_dir}")


if __name__ == "__main__":
    main()
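
The second file, which follows, exercises the BYODM ("bring your own device mesh") path: the mesh is built from the same ParallelismConfig before the Accelerator exists and registered via PartialState, so the model can be loaded tensor-parallel first. Below is a condensed sketch of that hand-off using only calls from the next file; the sizes are illustrative assumptions, and the same launch assumption applies (world size equal to the product of the sizes).

# Sketch of the pre-built ("bring your own") device mesh flow used below.
from accelerate import Accelerator
from accelerate.state import PartialState
from accelerate.utils.dataclasses import ParallelismConfig

parallelism_config = ParallelismConfig(dp_shard_size=4, tp_size=2)  # illustrative sizes
device_mesh = parallelism_config.build_device_mesh("cuda")

# Register the mesh and config with the shared state up front ...
PartialState(device_mesh=device_mesh, parallelism_config=parallelism_config)

# ... then create the Accelerator without passing parallelism_config again,
# as the example below does; it relies on the state registered above.
accelerator = Accelerator()
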
187 changes: 187 additions & 0 deletions examples/fsdp2/nd_parallel_prepared_device_mesh.py
@@ -0,0 +1,187 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example of training with ND parallel using accelerate's ParallelismConfig
"""

import argparse

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

from accelerate import Accelerator
from accelerate.state import PartialState
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from accelerate.utils.dataclasses import ParallelismConfig
from accelerate.utils.fsdp_utils import save_fsdp_optimizer
from utils import PerformanceTracker, create_collate_fn, get_dataset, gpu_memory_usage_all, setup_tokenizer


MODEL_ID = "NousResearch/Llama-3.2-1B"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--fsdp2-cls-name-to-wrap", type=str, default="LlamaDecoderLayer")
    parser.add_argument("--dp-replicate-size", type=int, default=1)
    parser.add_argument("--dp-shard-size", type=int, default=1)
    parser.add_argument("--tp-size", type=int, default=1)
    parser.add_argument("--sequence-length", type=int, default=128)
    parser.add_argument("--model-save-dir", type=str, default="./outputs")
    parser.add_argument(
        "--save-model", action="store_true", default=False, help="Whether to save the model after training."
    )
    parser.add_argument(
        "--save-optimizer",
        action="store_true",
        default=False,
        help="Whether to save the optimizer state after training.",
    )
    return parser.parse_args()


def main():
"""
Main function to train the model.
"""
args = parse_args()

set_seed(42)

if args.model:
model_id = args.model
else:
model_id = MODEL_ID

model_kwargs = {}
accelerator_kwargs = {}

parallelism_config = ParallelismConfig(
dp_replicate_size=args.dp_replicate_size,
dp_shard_size=args.dp_shard_size,
tp_size=args.tp_size,
)

device_mesh = parallelism_config.build_device_mesh("cuda")

    if args.tp_size > 1:
        model_kwargs["tp_size"] = args.tp_size
        model_kwargs["tp_plan"] = "auto"
        model_kwargs["device_mesh"] = device_mesh

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        use_cache=False,
        **model_kwargs,
    )

    PartialState(device_mesh=device_mesh, parallelism_config=parallelism_config)

    if parallelism_config.fsdp_enabled:
        fsdp2_plugin = FullyShardedDataParallelPlugin(
            fsdp_version=2,
            cpu_ram_efficient_loading=False,
            auto_wrap_policy="transformer_based_wrap",
            transformer_cls_names_to_wrap=[args.fsdp2_cls_name_to_wrap],
            reshard_after_forward=True,
            activation_checkpointing=True,
            state_dict_type="FULL_STATE_DICT",
        )
        accelerator_kwargs["fsdp_plugin"] = fsdp2_plugin

    accelerator = Accelerator(
        mixed_precision="no",
        **accelerator_kwargs,
    )

accelerator.print("Memory usage after model load")
accelerator.print(gpu_memory_usage_all())
accelerator.print("=" * 20)
tokenizer = setup_tokenizer(model_id)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

model, optimizer = accelerator.prepare(model, optimizer)
accelerator.print("Memory usage after model prepare")
accelerator.print(gpu_memory_usage_all())
accelerator.print("=" * 20)

dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
dataloader = accelerator.prepare(dataloader)

    model.train()

    total_num_steps = min(100, len(dataloader))
    performance_tracker = PerformanceTracker(warmup_steps=10)

    accelerator.print("Starting training...")
    for step, batch in enumerate(dataloader):
        if step >= total_num_steps:
            break

        outputs = model(**batch)
        loss = outputs.loss

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        dist.all_reduce(loss, op=dist.ReduceOp.AVG)

        batch_tokens = batch["input_ids"].shape[1]
        metrics = performance_tracker.step(batch_tokens)

        print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
        log_metrics = {"loss": loss.item()}

        if "warmup_completed" in metrics:
            accelerator.print("Warm up completed! Starting performance tracking...")
        elif metrics:
            print_msg += f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}\n"
            print_msg += (
                f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
                f"alloc={metrics['peak_memory_alloc']:.1f}, "
                f"reserved={metrics['peak_memory_reserved']:.1f}"
            )
        if step % 10 == 0 or step == total_num_steps - 1:
            accelerator.print(print_msg)

        accelerator.log(log_metrics)

    accelerator.wait_for_everyone()
    accelerator.end_training()
accelerator.print("Training completed!")
if parallelism_config.fsdp_enabled and args.save_optimizer:
accelerator.print("Saving optimizer state...")
save_fsdp_optimizer(
fsdp2_plugin,
accelerator,
optimizer,
model,
args.model_save_dir + "/opt",
)
accelerator.print("Optimizer state saved.")
accelerator.print("Saving model state...")
if args.save_model:
model.save_pretrained(args.model_save_dir)
accelerator.print(f"Model saved to {args.model_save_dir}")


if __name__ == "__main__":
    main()
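
A quick way to sanity-check an HSDP + TP layout before training is to inspect the mesh that ParallelismConfig builds. The sketch below is not part of the PR: the mesh dimension names are an assumption based on the config's option names, and the snippet is expected to run under a distributed launcher with a world size equal to the product of the configured sizes.

# Illustrative sanity check (not shipped in this PR): inspect the built mesh.
from accelerate.utils.dataclasses import ParallelismConfig

parallelism_config = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2)
device_mesh = parallelism_config.build_device_mesh("cuda")

print(device_mesh.mesh_dim_names)  # assumed to resemble ("dp_replicate", "dp_shard", "tp")
print(device_mesh.shape)           # (2, 2, 2) for the sizes above
print(device_mesh.size())          # 8 -- must equal the launched world size
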