-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[WIP] Async checkpointing #3701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
# Copyright 2025 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import dataclasses | ||
import os | ||
import pickle | ||
import queue | ||
from io import UnsupportedOperation | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING, cast | ||
|
||
import torch | ||
import torch.distributed.checkpoint as dcp | ||
import torch.distributed.checkpoint.state_dict as dcs | ||
from torch.distributed.checkpoint.filesystem import ( | ||
FileSystemWriter, | ||
SavePlan, | ||
SavePlanner, | ||
_generate_uuid, | ||
_split_by_size_and_type, | ||
) | ||
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta | ||
from torch.distributed.checkpoint.storage import WriteResult | ||
|
||
|
||
if TYPE_CHECKING: | ||
from accelerate import Accelerator | ||
|
||
|
||
class AccelerateStorageWriter(FileSystemWriter): | ||
_DEFAULT_SUFFIX = ".distcp" | ||
_OPTIM_FILE_PATH = "optimizer_0" | ||
_MODEL_FILE_PATH = "pytorch_model_fsdp_0" | ||
|
||
def prepare_local_plan(self, plan: SavePlan) -> SavePlan: | ||
self.optim_path = self.fs.concat_path(self.path, self._OPTIM_FILE_PATH) | ||
self.model_path = self.fs.concat_path(self.path, self._MODEL_FILE_PATH) | ||
self.fs.mkdir(self.optim_path) | ||
self.fs.mkdir(self.model_path) | ||
return super().prepare_local_plan(plan) | ||
|
||
def write_data( | ||
self, | ||
plan: SavePlan, | ||
planner: SavePlanner, | ||
): | ||
storage_plan = plan.storage_data | ||
optim_file_count = 0 | ||
model_file_count = 0 | ||
|
||
def gen_file(is_optimizer: bool = False) -> str: | ||
nonlocal optim_file_count, model_file_count | ||
if is_optimizer: | ||
optim_file_count += 1 | ||
return f"{storage_plan.prefix}{optim_file_count}{self._DEFAULT_SUFFIX}" | ||
else: | ||
model_file_count += 1 | ||
return f"{storage_plan.prefix}{model_file_count}{self._DEFAULT_SUFFIX}" | ||
|
||
file_queue: queue.Queue = queue.Queue() | ||
|
||
for bucket in _split_by_size_and_type(1, plan.items): | ||
optim_states = [wi for wi in bucket if "optim" in wi.index.fqn] | ||
model_states = [wi for wi in bucket if "model" in wi.index.fqn] | ||
|
||
for state, path in zip([optim_states, model_states], [self.optim_path, self.model_path]): | ||
file_name = gen_file() | ||
path = self.fs.concat_path(path, file_name) | ||
file_queue.put((path, file_name, state)) | ||
|
||
return self._write_data(planner, file_queue) | ||
|
||
def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: | ||
try: | ||
metadata = dataclasses.replace(metadata, version="1.0.0") | ||
except TypeError: | ||
pass | ||
|
||
def _split_metadata( | ||
metadata: Metadata, | ||
) -> tuple[Metadata, Metadata]: | ||
result = [] | ||
for to_get in ["model", "optim"]: | ||
result.append( | ||
Metadata( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By default DCP thinks we're saving an object called "state" into 1 directory, which we're not. We're saving "optimizer" into 1 subdirectory and "model" into another. That's why we have to update the metadata (remove the "state" prefix and split it into 2) |
||
state_dict_metadata={ | ||
k.removeprefix("state."): v for k, v in metadata.state_dict_metadata.items() if to_get in k | ||
}, | ||
planner_data={ | ||
k.removeprefix("state."): tuple([x for x in v if x != "state"]) | ||
for k, v in metadata.planner_data.items() | ||
if to_get in k | ||
}, | ||
) | ||
) | ||
|
||
return tuple(result) | ||
|
||
model_metadata, optim_metadata = _split_metadata(metadata) | ||
model_storage_md, optim_storage_md = {}, {} | ||
for wr_list in results: | ||
for wr in wr_list: | ||
new_index = dataclasses.asdict(wr.index) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
new_index["fqn"] = new_index["fqn"].removeprefix("state.") | ||
wr = WriteResult( | ||
index=MetadataIndex(**new_index), | ||
size_in_bytes=wr.size_in_bytes, | ||
storage_data=wr.storage_data, | ||
) | ||
if "optim" in wr.index.fqn: | ||
optim_storage_md.update({wr.index: wr.storage_data}) | ||
else: | ||
model_storage_md.update({wr.index: wr.storage_data}) | ||
|
||
model_metadata.storage_data = model_storage_md | ||
optim_metadata.storage_data = optim_storage_md | ||
|
||
model_metadata.storage_meta = StorageMeta(self.model_path, save_id=_generate_uuid()) | ||
optim_metadata.storage_meta = StorageMeta(self.optim_path, save_id=_generate_uuid()) | ||
|
||
tmp_optim_path = cast(Path, self.fs.concat_path(self.optim_path, ".metadata.tmp")) | ||
tmp_model_path = cast(Path, self.fs.concat_path(self.model_path, ".metadata.tmp")) | ||
|
||
for meta, tmp_path, final_path in zip( | ||
[model_metadata, optim_metadata], | ||
[tmp_model_path, tmp_optim_path], | ||
[self.model_path, self.optim_path], | ||
): | ||
with self.fs.create_stream(tmp_path, "wb") as metadata_file: | ||
pickle.dump(meta, metadata_file) | ||
if self.sync_files: | ||
try: | ||
os.fsync(metadata_file.fileno()) | ||
except (AttributeError, UnsupportedOperation): | ||
os.sync() | ||
|
||
metadata_path = self.fs.concat_path(final_path, ".metadata") | ||
if self.fs.exists(metadata_path): | ||
self.fs.rm_file(metadata_path) | ||
|
||
self.fs.rename(tmp_path, metadata_path) | ||
|
||
|
||
def save_model_and_optimizer(
    accelerator: "Accelerator",
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    save_path: str,
    async_save: bool = False,
) -> None:
    """Save ``model`` and ``optimizer`` to ``save_path`` via DCP.

    The model and optimizer are written into separate subdirectories by
    ``AccelerateStorageWriter``. When ``async_save`` is True the save runs in
    the background via ``dcp.async_save`` and the returned future is stashed
    on ``accelerator._async_save_handle`` so the next call can wait on it.

    Args:
        accelerator: the Accelerator instance; used only to store the pending
            async save handle between calls.
        model: the model whose state dict is saved.
        optimizer: the optimizer whose state dict is saved.
        save_path: destination checkpoint directory.
        async_save: if True, save asynchronously and return immediately.
    """
    # Wait for any in-flight async save before starting a new one.
    # FIX: also clear the handle afterwards — the original left the
    # completed future in place, so it was re-awaited on every later call.
    if getattr(accelerator, "_async_save_handle", None) is not None:
        accelerator._async_save_handle.result()
        accelerator._async_save_handle = None

    options = dcs.StateDictOptions()

    model_sd, optimizer_sd = dcs.get_state_dict(model, optimizer, options=options)

    stateful = {
        "model": model_sd,
        "optimizer": optimizer_sd,
    }

    save_fn = dcp.async_save if async_save else dcp.save

    potential_handle = save_fn(
        state_dict=stateful,
        storage_writer=AccelerateStorageWriter(save_path),
    )

    if async_save:
        accelerator._async_save_handle = potential_handle
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This class is the issue: I'm overriding quite interesting stuff from Pytorch that idk if I should (asked on their slack if it's safe). If we don't have this, we can't save optimizer into 1 directory and model into another, which we currently do