Add defence for offload_states and reload_states w/o optimizer (#7211)
When no optimizer is specified (e.g. for ZeRO-3 pure inference), the
engine's optimizer is of type `DeepSpeedZeRoOffload` instead of
`DeepSpeedZeroOptimizer_Stage3`, and `DeepSpeedZeRoOffload` does not
implement the methods `reload_states` and `offload_states`.
https://github.com/deepspeedai/DeepSpeed/blob/56005d2b256eb81a88cba0a1984375f9663a3110/deepspeed/runtime/engine.py#L1684-L1707
```log
File "deepspeed/runtime/engine.py", line 3904, in offload_states
self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'DeepSpeedZeRoOffload' object has no attribute 'offload_states'
```
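The failure mode can be reproduced in isolation. The sketch below is not DeepSpeed code: `FakeZeroOffload`, `FakeStage3Optimizer`, `FakeEngine`, and `error_type` are hypothetical stand-ins that only mimic the attribute layout described above. The guarded method mirrors the `isinstance` assertion this commit adds.

```python
class FakeZeroOffload:
    """Hypothetical stand-in for DeepSpeedZeRoOffload: defines no
    offload_states / reload_states methods."""


class FakeStage3Optimizer:
    """Hypothetical stand-in for DeepSpeedZeroOptimizer_Stage3."""

    def __init__(self):
        self.offloaded = False

    def offload_states(self):
        self.offloaded = True

    def reload_states(self):
        self.offloaded = False


class FakeEngine:
    """Hypothetical stand-in for the engine; holds whichever object
    ended up in the optimizer slot."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def offload_states_unguarded(self):
        # Behaviour before the fix: the call is forwarded blindly.
        self.optimizer.offload_states()

    def offload_states_guarded(self):
        # Behaviour after the fix: fail early with a descriptive message.
        assert not isinstance(self.optimizer, FakeZeroOffload), \
            "Moving states across devices is not supported without an optimizer."
        self.optimizer.offload_states()


def error_type(fn):
    """Return the name of the exception raised by fn(), or None if it succeeds."""
    try:
        fn()
    except Exception as exc:
        return type(exc).__name__
    return None
```

Without the guard, the stand-in engine raises `AttributeError` as in the log above. With the guard it raises `AssertionError` with a message that names the actual cause. Note that `assert` statements are skipped when Python runs with `-O`.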
In addition, #6855 seems to have forgotten to remove the
`assert not self.zero_offload_param()` check, as suggested in
#6833 (comment).
`zero_offload_param()` returns None when offload_param is not given, and
the newly added assertions already cover these cases, so this PR also
removes the old check.
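To see how the two checks relate, here is a minimal sketch using stand-in types. `Device`, `ParamOffloadConfig`, `old_check`, and `new_check` are hypothetical names for illustration, not DeepSpeed's own. The sketch also assumes the real config object is truthy like an ordinary Python object, which the source does not state.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Device(str, Enum):
    """Hypothetical stand-in for OffloadDeviceEnum."""
    none = "none"
    cpu = "cpu"


@dataclass
class ParamOffloadConfig:
    """Hypothetical stand-in for the value returned by zero_offload_param()."""
    device: Device = Device.none


def old_check(cfg: Optional[ParamOffloadConfig]) -> bool:
    # Condition of the removed assertion: passes only when cfg is falsy.
    return not cfg


def new_check(cfg: Optional[ParamOffloadConfig]) -> bool:
    # Condition of the retained assertion: passes for a missing config
    # or a config whose device is "none".
    return cfg is None or cfg.device == Device.none


cases = {
    "no offload_param": None,
    "device=none": ParamOffloadConfig(Device.none),
    "device=cpu": ParamOffloadConfig(Device.cpu),
}
results = {name: (old_check(cfg), new_check(cfg)) for name, cfg in cases.items()}
```

Under these assumptions, every input accepted by the old condition is also accepted by the retained one, and both reject a config that offloads parameters to CPU. The only input they disagree on is an explicit config with device `none`.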
Signed-off-by: Hollow Man <[email protected]>
deepspeed/runtime/engine.py (+8 -1)
@@ -3894,7 +3894,9 @@ def offload_states(self,
         param_offload_config = self.zero_offload_param()
         assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters."
 
-        assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters."
+        assert not isinstance(
+            self.optimizer,
+            DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer."
 
         if device == OffloadDeviceEnum.none:
             logger.warning("No device specified for offloading states.")