diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 56e738a01c1..c38754dd6e9 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -165,6 +165,8 @@
 if is_npu_available(check_device=False):
     import torch_npu  # noqa: F401
 
+if is_torch_version(">=", "2.6.0"):
+    from .dist_checkpointing import save_model_and_optimizer
 
 try:
     from torch.optim.lr_scheduler import LRScheduler
@@ -3453,10 +3455,17 @@ def _inner(folder):
             # Finish running the previous step before checkpointing
             xm.mark_step()
 
+        # TODO: Siro - how to properly decide when to do this
+        _dist_save = self.parallelism_config is not None and self.parallelism_config.total_size > 1 and True
+        if _dist_save:
+            save_model_and_optimizer(self, self._models[0], self._optimizers[0], output_dir, True)
+
         # Save the models taking care of FSDP and DeepSpeed nuances
         weights = []
         for i, model in enumerate(self._models):
-            if self.distributed_type == DistributedType.FSDP:
+            if _dist_save:
+                pass
+            elif self.distributed_type == DistributedType.FSDP:
                 logger.info("Saving FSDP model")
                 save_fsdp_model(self.state.fsdp_plugin, self, model, output_dir, i)
                 logger.info(f"FSDP Model saved to output dir {output_dir}")
@@ -3474,7 +3483,9 @@ def _inner(folder):
 
         # Save the optimizers taking care of FSDP and DeepSpeed nuances
         optimizers = []
-        if self.distributed_type == DistributedType.FSDP:
+        if _dist_save:
+            pass
+        elif self.distributed_type == DistributedType.FSDP:
             for i, opt in enumerate(self._optimizers):
                 logger.info("Saving FSDP Optimizer")
                 save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
diff --git a/src/accelerate/dist_checkpointing.py b/src/accelerate/dist_checkpointing.py
new file mode 100644
index 00000000000..b8cb8009952
--- /dev/null
+++ b/src/accelerate/dist_checkpointing.py
@@ -0,0 +1,182 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import dataclasses
+import os
+import pickle
+import queue
+from io import UnsupportedOperation
+from pathlib import Path
+from typing import TYPE_CHECKING, cast
+
+import torch
+import torch.distributed.checkpoint as dcp
+import torch.distributed.checkpoint.state_dict as dcs
+from torch.distributed.checkpoint.filesystem import (
+    FileSystemWriter,
+    SavePlan,
+    SavePlanner,
+    _generate_uuid,
+    _split_by_size_and_type,
+)
+from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
+from torch.distributed.checkpoint.storage import WriteResult
+
+
+if TYPE_CHECKING:
+    from accelerate import Accelerator
+
+
+class AccelerateStorageWriter(FileSystemWriter):
+    _DEFAULT_SUFFIX = ".distcp"
+    _OPTIM_FILE_PATH = "optimizer_0"
+    _MODEL_FILE_PATH = "pytorch_model_fsdp_0"
+
+    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
+        self.optim_path = self.fs.concat_path(self.path, self._OPTIM_FILE_PATH)
+        self.model_path = self.fs.concat_path(self.path, self._MODEL_FILE_PATH)
+        self.fs.mkdir(self.optim_path)
+        self.fs.mkdir(self.model_path)
+        return super().prepare_local_plan(plan)
+
+    def write_data(
+        self,
+        plan: SavePlan,
+        planner: SavePlanner,
+    ):
+        storage_plan = plan.storage_data
+        optim_file_count = 0
+        model_file_count = 0
+
+        def gen_file(is_optimizer: bool = False) -> str:
+            nonlocal optim_file_count, model_file_count
+            if is_optimizer:
+                optim_file_count += 1
+                return f"{storage_plan.prefix}{optim_file_count}{self._DEFAULT_SUFFIX}"
+            else:
+                model_file_count += 1
+                return f"{storage_plan.prefix}{model_file_count}{self._DEFAULT_SUFFIX}"
+
+        file_queue: queue.Queue = queue.Queue()
+
+        for bucket in _split_by_size_and_type(1, plan.items):
+            optim_states = [wi for wi in bucket if "optim" in wi.index.fqn]
+            model_states = [wi for wi in bucket if "model" in wi.index.fqn]
+
+            for state, path in zip([optim_states, model_states], [self.optim_path, self.model_path]):
+                file_name = gen_file(is_optimizer=state is optim_states)
+                path = self.fs.concat_path(path, file_name)
+                file_queue.put((path, file_name, state))
+
+        return self._write_data(planner, file_queue)
+
+    def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
+        try:
+            metadata = dataclasses.replace(metadata, version="1.0.0")
+        except TypeError:
+            pass
+
+        def _split_metadata(
+            metadata: Metadata,
+        ) -> tuple[Metadata, Metadata]:
+            result = []
+            for to_get in ["model", "optim"]:
+                result.append(
+                    Metadata(
+                        state_dict_metadata={
+                            k.removeprefix("state."): v
+                            for k, v in metadata.state_dict_metadata.items()
+                            if to_get in k
+                        },
+                        planner_data={
+                            k.removeprefix("state."): tuple([x for x in v if x != "state"])
+                            for k, v in metadata.planner_data.items()
+                            if to_get in k
+                        },
+                    )
+                )
+
+            return tuple(result)
+
+        model_metadata, optim_metadata = _split_metadata(metadata)
+        model_storage_md, optim_storage_md = {}, {}
+        for wr_list in results:
+            for wr in wr_list:
+                new_index = dataclasses.asdict(wr.index)
+                new_index["fqn"] = new_index["fqn"].removeprefix("state.")
+                wr = WriteResult(
+                    index=MetadataIndex(**new_index),
+                    size_in_bytes=wr.size_in_bytes,
+                    storage_data=wr.storage_data,
+                )
+                if "optim" in wr.index.fqn:
+                    optim_storage_md.update({wr.index: wr.storage_data})
+                else:
+                    model_storage_md.update({wr.index: wr.storage_data})
+
+        model_metadata.storage_data = model_storage_md
+        optim_metadata.storage_data = optim_storage_md
+
+        model_metadata.storage_meta = StorageMeta(self.model_path, save_id=_generate_uuid())
+        optim_metadata.storage_meta = StorageMeta(self.optim_path, save_id=_generate_uuid())
+
+        tmp_optim_path = cast(Path, self.fs.concat_path(self.optim_path, ".metadata.tmp"))
+        tmp_model_path = cast(Path, self.fs.concat_path(self.model_path, ".metadata.tmp"))
+
+        for meta, tmp_path, final_path in zip(
+            [model_metadata, optim_metadata],
+            [tmp_model_path, tmp_optim_path],
+            [self.model_path, self.optim_path],
+        ):
+            with self.fs.create_stream(tmp_path, "wb") as metadata_file:
+                pickle.dump(meta, metadata_file)
+                if self.sync_files:
+                    try:
+                        os.fsync(metadata_file.fileno())
+                    except (AttributeError, UnsupportedOperation):
+                        os.sync()
+
+            metadata_path = self.fs.concat_path(final_path, ".metadata")
+            if self.fs.exists(metadata_path):
+                self.fs.rm_file(metadata_path)
+
+            self.fs.rename(tmp_path, metadata_path)
+
+
+def save_model_and_optimizer(
+    accelerator: "Accelerator",
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    save_path: str,
+    async_save: bool = False,
+) -> None:
+    if getattr(accelerator, "_async_save_handle", None) is not None:
+        accelerator._async_save_handle.result()
+
+    options = dcs.StateDictOptions()
+
+    model_sd, optimizer_sd = dcs.get_state_dict(model, optimizer, options=options)
+
+    stateful = {
+        "model": model_sd,
+        "optimizer": optimizer_sd,
+    }
+
+    save_fn = dcp.save if not async_save else dcp.async_save
+
+    potential_handle = save_fn(
+        state_dict=stateful,
+        storage_writer=AccelerateStorageWriter(save_path),
+    )
+
+    if async_save:
+        accelerator._async_save_handle = potential_handle
diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index d07344c67ee..87e993edf0a 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -232,6 +232,7 @@ def __init__(self, cpu: bool = False, **kwargs):
             and (
                 os.environ.get("FSDP_OFFLOAD_PARAMS", "false") == "true"
                 or os.environ.get("FSDP_STATE_DICT_TYPE", "SHARDED_STATE_DICT") == "FULL_STATE_DICT"
+                or True
             )
         ):
             self.backend = "cuda:nccl,cpu:gloo"
diff --git a/src/accelerate/utils/fsdp_utils.py b/src/accelerate/utils/fsdp_utils.py
index ef1ef7372ba..d9919f3ba70 100644
--- a/src/accelerate/utils/fsdp_utils.py
+++ b/src/accelerate/utils/fsdp_utils.py
@@ -303,13 +303,15 @@ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, o
         optim_state = torch.load(input_optimizer_file, weights_only=True)
         logger.info(f"Optimizer state loaded from {input_optimizer_file}")
     else:
+        from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
+
         ckpt_dir = (
             os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
             if f"{OPTIMIZER_NAME}" not in input_dir
             else input_dir
         )
         logger.info(f"Loading Optimizer from {ckpt_dir}")
-        optim_state = {"optimizer": optimizer.state_dict()}
+        optim_state = {"optimizer": get_optimizer_state_dict(model, optimizer)}
        dist_cp.load(
            optim_state,
            checkpoint_id=ckpt_dir,
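
For context, a minimal sketch of driving the new entry point by hand, assuming a model and optimizer prepared under FSDP; the `ckpt/step_100` path and the toy model are hypothetical (in the patch itself, `Accelerator.save_state` calls this automatically when `_dist_save` is true):

    import torch
    from accelerate import Accelerator
    from accelerate.dist_checkpointing import save_model_and_optimizer

    accelerator = Accelerator()
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters())
    model, optimizer = accelerator.prepare(model, optimizer)

    # async_save=True returns immediately; the pending future is parked on
    # accelerator._async_save_handle and awaited at the top of the next call.
    save_model_and_optimizer(accelerator, model, optimizer, "ckpt/step_100", async_save=True)

    # Block on the last in-flight save before shutting the process down.
    if getattr(accelerator, "_async_save_handle", None) is not None:
        accelerator._async_save_handle.result()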
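
On the load side, the patch only updates the optimizer path (`load_fsdp_optimizer`). A sketch of how the model half of the two-directory layout written by `AccelerateStorageWriter` (`pytorch_model_fsdp_0` / `optimizer_0`) could be read back, mirroring the `get_optimizer_state_dict` pattern above; `model` is the prepared module from the previous sketch, the path is hypothetical, and the `"model"` key matches the one used in `save_model_and_optimizer`:

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import (
        get_model_state_dict,
        set_model_state_dict,
    )

    # Build a state dict with the same (possibly sharded) layout used at save
    # time, let DCP fill it in-place from the model subdirectory, then load it
    # back into the module.
    state = {"model": get_model_state_dict(model)}
    dcp.load(state, checkpoint_id="ckpt/step_100/pytorch_model_fsdp_0")
    set_model_state_dict(model, state["model"])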