[WIP] Async checkpointing #3701
base: main
from accelerate import Accelerator


class AccelerateStorageWriter(FileSystemWriter):
This class is the issue: I'm overriding some fairly internal PyTorch behavior, and I don't know whether that's safe (I've asked on their Slack). Without this class, we can't save the optimizer into one directory and the model into another, which is what we currently do.
model_storage_md, optim_storage_md = {}, {}
for wr_list in results:
    for wr in wr_list:
        new_index = dataclasses.asdict(wr.index)
The `WriteResult` dataclass is frozen (which says a lot about the kind of war crimes I'm committing here), so we have to use some Python tricks to get around that.
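For illustration, here is a minimal sketch of one way to "edit" a frozen dataclass without mutating it: convert the nested index to a dict with `dataclasses.asdict`, change the dict, and build new instances. `FakeIndex`, `FakeWriteResult`, and `strip_prefix` are hypothetical stand-ins invented for this example; they are not the PyTorch classes and this is not necessarily what the PR does.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class FakeIndex:
    # Hypothetical stand-in for the index carried by a write result.
    fqn: str
    offset: int = 0


@dataclasses.dataclass(frozen=True)
class FakeWriteResult:
    # Hypothetical stand-in for a frozen, WriteResult-like dataclass.
    index: FakeIndex
    size_in_bytes: int


def strip_prefix(wr: FakeWriteResult, prefix: str) -> FakeWriteResult:
    """Return a copy of `wr` whose index name has `prefix` removed."""
    # asdict gives a plain, mutable dict of the frozen index's fields.
    new_index = dataclasses.asdict(wr.index)
    if new_index["fqn"].startswith(prefix):
        new_index["fqn"] = new_index["fqn"][len(prefix):]
    # replace builds a new frozen instance instead of mutating the old one.
    return dataclasses.replace(wr, index=FakeIndex(**new_index))


original = FakeWriteResult(FakeIndex("state.model.weight"), size_in_bytes=16)
updated = strip_prefix(original, "state.")
```

The original object stays untouched, so this sidesteps the frozen check rather than bypassing it with `object.__setattr__`.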
result = []
for to_get in ["model", "optim"]:
    result.append(
        Metadata(
By default, DCP assumes we're saving a single object called "state" into one directory, which we're not: we're saving "optimizer" into one subdirectory and "model" into another. That's why we have to update the metadata (remove the "state" prefix and split it in two).
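The prefix handling described above can be sketched on a plain dictionary. This is only an illustration under stated assumptions: the dot-separated key layout (`state.model...`, `state.optim...`), the helper name `split_by_component`, and the choice to keep the component name in the resulting keys are all hypothetical, and the real metadata object holds more than a flat mapping.

```python
def split_by_component(flat_md, root="state", components=("model", "optim")):
    """Strip the `root` prefix and group entries by their top-level component."""
    root_prefix = root + "."
    split = {name: {} for name in components}
    for key, value in flat_md.items():
        if not key.startswith(root_prefix):
            raise KeyError(f"unexpected key without {root!r} prefix: {key}")
        stripped = key[len(root_prefix):]
        # The first path segment after the root decides the target bucket.
        component = stripped.split(".", 1)[0]
        if component not in split:
            raise KeyError(f"unknown component {component!r} in key: {key}")
        split[component][stripped] = value
    return split


# Hypothetical metadata entries; the values are placeholders.
flat = {
    "state.model.linear.weight": "chunk-a",
    "state.optim.state.0.exp_avg": "chunk-b",
}
per_component = split_by_component(flat)
```

Each bucket can then be written out as its own metadata file in the matching subdirectory.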
self.fs.rename(tmp_path, metadata_path)


def save_model_and_optimizer(
This is the only "public" API we expose, and not even really: we only use it internally in `accelerator.save_state`.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Very much WIP; this overrides a bunch of stuff that I'm not sure is stable to override.
TODO: discuss whether we want a slightly different (and more easily maintainable) approach.