Skip to content

Commit 4eed2be

Browse files
authored
FSDP bug fix for load_state_dict (#18596)
1 parent d344534 commit 4eed2be

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/transformers/trainer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1935,7 +1935,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
19351935
else:
19361936
# We load the model state dict on the CPU to avoid an OOM error.
19371937
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
1938-
load_result = model.load_state_dict(state_dict, strict=False)
1938+
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
1939+
# which takes *args instead of **kwargs
1940+
load_result = model.load_state_dict(state_dict, False)
19391941
# release memory
19401942
del state_dict
19411943
self._issue_warnings_after_load(load_result)
@@ -1989,7 +1991,9 @@ def _load_best_model(self):
19891991
# We load the model state dict on the CPU to avoid an OOM error.
19901992
state_dict = torch.load(best_model_path, map_location="cpu")
19911993
# If the model is on the GPU, it still works!
1992-
load_result = model.load_state_dict(state_dict, strict=False)
1994+
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
1995+
# which takes *args instead of **kwargs
1996+
load_result = model.load_state_dict(state_dict, False)
19931997
if not is_sagemaker_mp_enabled():
19941998
self._issue_warnings_after_load(load_result)
19951999
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):

0 commit comments

Comments
 (0)