
[WIP] Async checkpointing #3701


Draft: wants to merge 1 commit into base: main
15 changes: 13 additions & 2 deletions src/accelerate/accelerator.py
@@ -165,6 +165,8 @@
 if is_npu_available(check_device=False):
     import torch_npu  # noqa: F401

+if is_torch_version(">=", "2.6.0"):
+    from .dist_checkpointing import save_model_and_optimizer

 try:
     from torch.optim.lr_scheduler import LRScheduler
@@ -3453,10 +3455,17 @@ def _inner(folder):
         # Finish running the previous step before checkpointing
         xm.mark_step()

+        # TODO: Siro - how to properly decide when to do this
+        _dist_save = self.parallelism_config is not None and self.parallelism_config.total_size > 1 and True
+        if _dist_save:
+            save_model_and_optimizer(self, self._models[0], self._optimizers[0], output_dir, True)
+
         # Save the models taking care of FSDP and DeepSpeed nuances
         weights = []
         for i, model in enumerate(self._models):
-            if self.distributed_type == DistributedType.FSDP:
+            if _dist_save:
+                pass
+            elif self.distributed_type == DistributedType.FSDP:
                 logger.info("Saving FSDP model")
                 save_fsdp_model(self.state.fsdp_plugin, self, model, output_dir, i)
                 logger.info(f"FSDP Model saved to output dir {output_dir}")
@@ -3474,7 +3483,9 @@ def _inner(folder):

         # Save the optimizers taking care of FSDP and DeepSpeed nuances
         optimizers = []
-        if self.distributed_type == DistributedType.FSDP:
+        if _dist_save:
+            pass
+        elif self.distributed_type == DistributedType.FSDP:
             for i, opt in enumerate(self._optimizers):
                 logger.info("Saving FSDP Optimizer")
                 save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
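For context on the new branch above: instead of the per-model FSDP helpers, save_model_and_optimizer hands a single {"model": ..., "optimizer": ...} state dict to torch.distributed.checkpoint (DCP). Below is a minimal sketch, not part of the PR, of what a stock DCP save of that shape looks like: everything lands in one checkpoint directory with a single .metadata file, which is exactly what the custom writer in the next file changes. The tiny model and the "ckpt" path are placeholders.

import torch
import torch.distributed.checkpoint as dcp
import torch.distributed.checkpoint.state_dict as dcs

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters())
model(torch.randn(2, 4)).sum().backward()
optimizer.step()  # populate optimizer state so it appears in the checkpoint

# DCP-friendly state dicts for the model and its optimizer
model_sd, optim_sd = dcs.get_state_dict(model, optimizer)

# Stock behaviour: one directory, one .metadata file covering both objects
dcp.save({"model": model_sd, "optimizer": optim_sd}, checkpoint_id="ckpt")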
182 changes: 182 additions & 0 deletions src/accelerate/dist_checkpointing.py
@@ -0,0 +1,182 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import os
import pickle
import queue
from io import UnsupportedOperation
from pathlib import Path
from typing import TYPE_CHECKING, cast

import torch
import torch.distributed.checkpoint as dcp
import torch.distributed.checkpoint.state_dict as dcs
from torch.distributed.checkpoint.filesystem import (
    FileSystemWriter,
    SavePlan,
    SavePlanner,
    _generate_uuid,
    _split_by_size_and_type,
)
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
from torch.distributed.checkpoint.storage import WriteResult


if TYPE_CHECKING:
    from accelerate import Accelerator


Comment from the PR author (S1ro1):
This class is the issue: I'm overriding quite a lot of PyTorch internals that I'm not sure I should (I asked on their Slack whether it's safe). Without this, we can't save the optimizer into one directory and the model into another, which we currently do.

class AccelerateStorageWriter(FileSystemWriter):

    _DEFAULT_SUFFIX = ".distcp"
    _OPTIM_FILE_PATH = "optimizer_0"
    _MODEL_FILE_PATH = "pytorch_model_fsdp_0"

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        self.optim_path = self.fs.concat_path(self.path, self._OPTIM_FILE_PATH)
        self.model_path = self.fs.concat_path(self.path, self._MODEL_FILE_PATH)
        self.fs.mkdir(self.optim_path)
        self.fs.mkdir(self.model_path)
        return super().prepare_local_plan(plan)

    def write_data(
        self,
        plan: SavePlan,
        planner: SavePlanner,
    ):
        storage_plan = plan.storage_data
        optim_file_count = 0
        model_file_count = 0

        def gen_file(is_optimizer: bool = False) -> str:
            nonlocal optim_file_count, model_file_count
            if is_optimizer:
                optim_file_count += 1
                return f"{storage_plan.prefix}{optim_file_count}{self._DEFAULT_SUFFIX}"
            else:
                model_file_count += 1
                return f"{storage_plan.prefix}{model_file_count}{self._DEFAULT_SUFFIX}"

        file_queue: queue.Queue = queue.Queue()

        for bucket in _split_by_size_and_type(1, plan.items):
            optim_states = [wi for wi in bucket if "optim" in wi.index.fqn]
            model_states = [wi for wi in bucket if "model" in wi.index.fqn]

            for state, path in zip([optim_states, model_states], [self.optim_path, self.model_path]):
                file_name = gen_file()
                path = self.fs.concat_path(path, file_name)
                file_queue.put((path, file_name, state))

        return self._write_data(planner, file_queue)

    def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
        try:
            metadata = dataclasses.replace(metadata, version="1.0.0")
        except TypeError:
            pass

        def _split_metadata(
            metadata: Metadata,
        ) -> tuple[Metadata, Metadata]:
            result = []
            for to_get in ["model", "optim"]:
                result.append(
                    Metadata(
Comment from the PR author (S1ro1), Aug 1, 2025:
By default DCP thinks we're saving a single object called "state" into one directory, which we're not: we're saving "optimizer" into one subdirectory and "model" into another. That's why we have to update the metadata here (remove the "state." prefix and split it into two).

                        state_dict_metadata={
                            k.removeprefix("state."): v for k, v in metadata.state_dict_metadata.items() if to_get in k
                        },
                        planner_data={
                            k.removeprefix("state."): tuple([x for x in v if x != "state"])
                            for k, v in metadata.planner_data.items()
                            if to_get in k
                        },
                    )
                )

            return tuple(result)

        model_metadata, optim_metadata = _split_metadata(metadata)
        model_storage_md, optim_storage_md = {}, {}
        for wr_list in results:
            for wr in wr_list:
                new_index = dataclasses.asdict(wr.index)
Comment from the PR author (S1ro1), Aug 1, 2025:
The WriteResult dataclass is frozen (which says a lot about the kind of war crimes happening here), so we have to use some fancy Python to work around it: rebuild each instance from its fields instead of mutating it. A toy sketch of this and of the metadata split appears after the class below.

                new_index["fqn"] = new_index["fqn"].removeprefix("state.")
                wr = WriteResult(
                    index=MetadataIndex(**new_index),
                    size_in_bytes=wr.size_in_bytes,
                    storage_data=wr.storage_data,
                )
                if "optim" in wr.index.fqn:
                    optim_storage_md.update({wr.index: wr.storage_data})
                else:
                    model_storage_md.update({wr.index: wr.storage_data})

        model_metadata.storage_data = model_storage_md
        optim_metadata.storage_data = optim_storage_md

        model_metadata.storage_meta = StorageMeta(self.model_path, save_id=_generate_uuid())
        optim_metadata.storage_meta = StorageMeta(self.optim_path, save_id=_generate_uuid())

        tmp_optim_path = cast(Path, self.fs.concat_path(self.optim_path, ".metadata.tmp"))
        tmp_model_path = cast(Path, self.fs.concat_path(self.model_path, ".metadata.tmp"))

        for meta, tmp_path, final_path in zip(
            [model_metadata, optim_metadata],
            [tmp_model_path, tmp_optim_path],
            [self.model_path, self.optim_path],
        ):
            with self.fs.create_stream(tmp_path, "wb") as metadata_file:
                pickle.dump(meta, metadata_file)
                if self.sync_files:
                    try:
                        os.fsync(metadata_file.fileno())
                    except (AttributeError, UnsupportedOperation):
                        os.sync()

            metadata_path = self.fs.concat_path(final_path, ".metadata")
            if self.fs.exists(metadata_path):
                self.fs.rm_file(metadata_path)

            self.fs.rename(tmp_path, metadata_path)
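To make the two review comments above concrete, here is a small self-contained sketch, not part of the PR, of the two tricks finish() leans on: partitioning flat metadata keys by "model"/"optim" while stripping the "state." prefix, and rebuilding a frozen dataclass because it cannot be mutated in place. The Index class and the sample keys are invented for illustration.

import dataclasses

# 1) Split flat DCP metadata keys (which arrive prefixed with "state.")
#    into a model-only view and an optimizer-only view.
state_dict_metadata = {
    "state.model.weight": "tensor-meta",
    "state.model.bias": "tensor-meta",
    "state.optimizer.state.weight.exp_avg": "tensor-meta",
}
model_md = {k.removeprefix("state."): v for k, v in state_dict_metadata.items() if "model" in k}
optim_md = {k.removeprefix("state."): v for k, v in state_dict_metadata.items() if "optim" in k}
assert set(model_md) == {"model.weight", "model.bias"}
assert set(optim_md) == {"optimizer.state.weight.exp_avg"}

# 2) Frozen dataclasses (like DCP's WriteResult and MetadataIndex) cannot be
#    mutated, so the code above rebuilds them from their fields instead.
@dataclasses.dataclass(frozen=True)
class Index:  # stand-in for MetadataIndex
    fqn: str

idx = Index(fqn="state.model.weight")
fields = dataclasses.asdict(idx)
fields["fqn"] = fields["fqn"].removeprefix("state.")
idx = Index(**fields)  # new instance instead of in-place mutation
assert idx.fqn == "model.weight"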


Comment from the PR author (S1ro1):
This is the only "public" API we expose, and not even that: we only use it internally in accelerator.save_state.

def save_model_and_optimizer(

    accelerator: "Accelerator",
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    save_path: str,
    async_save: bool = False,
) -> None:
    if getattr(accelerator, "_async_save_handle", None) is not None:
        accelerator._async_save_handle.result()

    options = dcs.StateDictOptions()

    model_sd, optimizer_sd = dcs.get_state_dict(model, optimizer, options=options)

    stateful = {
        "model": model_sd,
        "optimizer": optimizer_sd,
    }

    save_fn = dcp.save if not async_save else dcp.async_save

    potential_handle = save_fn(
        state_dict=stateful,
        storage_writer=AccelerateStorageWriter(save_path),
    )

    if async_save:
        accelerator._async_save_handle = potential_handle
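A hedged usage sketch, not part of the PR, of how this function is driven from a training loop: kick off a background save with async_save=True, keep training, and rely on the handle check at the top of the function (or an explicit wait at the end) to ensure the previous checkpoint has finished before the next one starts. The training objects are assumed to come from a normal Accelerator.prepare() setup; the loss computation and checkpoint paths are placeholders.

from accelerate.dist_checkpointing import save_model_and_optimizer  # module path introduced by this PR

for step, batch in enumerate(dataloader):
    loss = compute_loss(model, batch)  # placeholder
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

    if step % 1000 == 0:
        # Returns immediately; the file writes happen in the background.
        # The next call waits on accelerator._async_save_handle first.
        save_model_and_optimizer(accelerator, model, optimizer, f"ckpt/step_{step}", async_save=True)

# Before exiting, make sure the last background save has landed on disk.
if getattr(accelerator, "_async_save_handle", None) is not None:
    accelerator._async_save_handle.result()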
1 change: 1 addition & 0 deletions src/accelerate/state.py
@@ -232,6 +232,7 @@ def __init__(self, cpu: bool = False, **kwargs):
             and (
                 os.environ.get("FSDP_OFFLOAD_PARAMS", "false") == "true"
                 or os.environ.get("FSDP_STATE_DICT_TYPE", "SHARDED_STATE_DICT") == "FULL_STATE_DICT"
+                or True
             )
         ):
             self.backend = "cuda:nccl,cpu:gloo"
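The `or True` above reads like a WIP switch that unconditionally forces the mixed backend. The likely motivation, which is my assumption rather than something stated in the diff, is that the async DCP path stages and coordinates CPU tensors, so the process group needs a CPU-capable backend alongside NCCL. A minimal sketch of what that backend string gives you (run under torchrun on a CUDA machine):

import torch
import torch.distributed as dist

# One process group, two backends: NCCL serves collectives on CUDA tensors,
# gloo serves collectives on CPU tensors (such as state dicts staged to CPU
# during an async checkpoint).
dist.init_process_group(backend="cuda:nccl,cpu:gloo")

rank = dist.get_rank()
gpu_tensor = torch.ones(1, device=f"cuda:{rank % torch.cuda.device_count()}")
cpu_tensor = torch.ones(1)

dist.all_reduce(gpu_tensor)  # dispatched to NCCL
dist.all_reduce(cpu_tensor)  # dispatched to gloo
dist.destroy_process_group()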
4 changes: 3 additions & 1 deletion src/accelerate/utils/fsdp_utils.py
@@ -303,13 +303,15 @@ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, o
         optim_state = torch.load(input_optimizer_file, weights_only=True)
         logger.info(f"Optimizer state loaded from {input_optimizer_file}")
     else:
+        from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
+
         ckpt_dir = (
             os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
             if f"{OPTIMIZER_NAME}" not in input_dir
             else input_dir
         )
         logger.info(f"Loading Optimizer from {ckpt_dir}")
-        optim_state = {"optimizer": optimizer.state_dict()}
+        optim_state = {"optimizer": get_optimizer_state_dict(model, optimizer)}
         dist_cp.load(
             optim_state,
             checkpoint_id=ckpt_dir,
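The change above matters because optimizer.state_dict() and DCP's state-dict helpers use different key layouts; the sharded checkpoint was written with FQN-style keys from get_state_dict, so the load target has to be built the same way. A load-side sketch under the same assumptions (an already-wrapped `model` and `optimizer`, and a checkpoint directory produced by this PR; the path is a placeholder, and the final set call is one plausible way to complete the flow that the truncated diff does not show):

from torch.distributed.checkpoint import load
from torch.distributed.checkpoint.state_dict import (
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)

# Build an optimizer state dict with the FQN-keyed layout DCP saved,
# fill it in place from the sharded files, then push it back into the optimizer.
optim_state = {"optimizer": get_optimizer_state_dict(model, optimizer)}
load(optim_state, checkpoint_id="checkpoint/optimizer_0")
set_optimizer_state_dict(model, optimizer, optim_state_dict=optim_state["optimizer"])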