@@ -1935,7 +1935,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
         else:
             # We load the model state dict on the CPU to avoid an OOM error.
             state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
-            load_result = model.load_state_dict(state_dict, strict=False)
+            # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+            # which takes *args instead of **kwargs
+            load_result = model.load_state_dict(state_dict, False)
             # release memory
             del state_dict
             self._issue_warnings_after_load(load_result)
@@ -1989,7 +1991,9 @@ def _load_best_model(self):
             # We load the model state dict on the CPU to avoid an OOM error.
             state_dict = torch.load(best_model_path, map_location="cpu")
             # If the model is on the GPU, it still works!
-            load_result = model.load_state_dict(state_dict, strict=False)
+            # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+            # which takes *args instead of **kwargs
+            load_result = model.load_state_dict(state_dict, False)
             if not is_sagemaker_mp_enabled():
                 self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
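
For context on the workaround in both hunks, here is a minimal, self-contained sketch of the failure mode. Everything below is hypothetical illustration, not the Trainer or FSDP code: the `PositionalOnlyWrapper` class mimics the bug reported in pytorch/pytorch#82963, where the wrapper's `load_state_dict` override took positional `*args` instead of `**kwargs`, so calling it with the `strict=False` keyword raised a `TypeError`. Passing `False` positionally works on both the buggy and the fixed versions.

import torch
from torch import nn


class PositionalOnlyWrapper(nn.Module):
    """Hypothetical wrapper mimicking the FSDP bug: the override only
    forwards positional *args, so a strict= keyword is rejected."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def load_state_dict(self, *args):  # no **kwargs: strict= raises TypeError
        return self.module.load_state_dict(*args)


model = PositionalOnlyWrapper(nn.Linear(4, 4))
state_dict = {"weight": torch.zeros(4, 4)}  # deliberately missing "bias"

# model.load_state_dict(state_dict, strict=False)  # TypeError under the bug
load_result = model.load_state_dict(state_dict, False)  # positional strict works either way
print(load_result.missing_keys)  # ['bias']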