Skip to content

[WIP] Async checkpointing #3701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft

[WIP] Async checkpointing #3701

wants to merge 1 commit into from

Conversation

S1ro1
Copy link
Member

@S1ro1 S1ro1 commented Aug 1, 2025

Very much WIP, overrides bunch of stuff I'm not sure that is stable to do.
TODO: discuss if we want to do a bit different approach (and more easily maintainable)

from accelerate import Accelerator


class AccelerateStorageWriter(FileSystemWriter):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is the issue: I'm overriding quite interesting stuff from Pytorch that idk if I should (asked on their slack if it's safe). If we don't have this, we can't save optimizer into 1 directory and model into another, which we currently do

model_storage_md, optim_storage_md = {}, {}
for wr_list in results:
for wr in wr_list:
new_index = dataclasses.asdict(wr.index)
Copy link
Member Author

@S1ro1 S1ro1 Aug 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WriteResult dataclass is frozen (which tells a lot about what kind of war crimes I do here), so we have to use some fancy python things to avoid that.

result = []
for to_get in ["model", "optim"]:
result.append(
Metadata(
Copy link
Member Author

@S1ro1 S1ro1 Aug 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default DCP thinks we're saving an object called "state" into 1 directory, which we're not. We're saving "optimizer" into 1 subdirectory and "model" into another. That's why we have to update the metadata (remove the "state" prefix and split it into 2)

self.fs.rename(tmp_path, metadata_path)


def save_model_and_optimizer(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only "public" api that we expose, not even. We only use this internally in accelerator.save_state.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@S1ro1 S1ro1 changed the title WIP: [Async checkpointing] [WIP] Async checkpointing Aug 1, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants