Commit a21e5b9

Add defence for offload_states and reload_states w/o optimizer (#7211)
When no optimizer is specified (e.g. for ZeRO-3 pure inference), `self.optimizer` is a `DeepSpeedZeRoOffload` instead of a `DeepSpeedZeroOptimizer_Stage3`, and `DeepSpeedZeRoOffload` does not implement the methods `reload_states` and `offload_states`.

https://github.com/deepspeedai/DeepSpeed/blob/56005d2b256eb81a88cba0a1984375f9663a3110/deepspeed/runtime/engine.py#L1684-L1707

```log
  File "deepspeed/runtime/engine.py", line 3904, in offload_states
    self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'DeepSpeedZeRoOffload' object has no attribute 'offload_states'
```

In addition, #6855 appears to have left the old `assert not self.zero_offload_param()` check in place, although #6833 (comment) suggested removing it. `zero_offload_param()` returns None when offload_param is not given, and the newly added assertions already cover these cases, so this PR also removes the old check.

Signed-off-by: Hollow Man <[email protected]>
1 parent 185330c commit a21e5b9

1 file changed, +8 -1 lines changed

deepspeed/runtime/engine.py

@@ -3894,7 +3894,9 @@ def offload_states(self,
         param_offload_config = self.zero_offload_param()
         assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters."
 
-        assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters."
+        assert not isinstance(
+            self.optimizer,
+            DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer."
 
         if device == OffloadDeviceEnum.none:
             logger.warning("No device specified for offloading states.")
@@ -3913,4 +3915,9 @@ def reload_states(self, non_blocking: bool = False) -> None:
         """
         assert self.zero_optimization_stage(
         ) == ZeroStageEnum.weights, "Moving buffers back is supported only for ZeRO stage 3."
+
+        assert not isinstance(
+            self.optimizer,
+            DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer."
+
         self.optimizer.reload_states(non_blocking=non_blocking)
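
The guard added in both hunks can be illustrated outside DeepSpeed with a minimal sketch. Every class and helper below (`ZeroOffloadStandIn`, `Stage3OptimizerStandIn`, `EngineStandIn`, `try_reload`) is a hypothetical stand-in written for this illustration, not DeepSpeed code. The sketch only aims to show how an up-front `isinstance` assertion turns the `AttributeError` from the log above into a descriptive `AssertionError`.

```python
class ZeroOffloadStandIn:
    """Hypothetical stand-in for DeepSpeedZeRoOffload.

    It deliberately defines neither offload_states nor reload_states.
    """


class Stage3OptimizerStandIn:
    """Hypothetical stand-in for DeepSpeedZeroOptimizer_Stage3."""

    def __init__(self):
        self.calls = []  # records which methods were invoked

    def offload_states(self, **kwargs):
        self.calls.append(("offload_states", kwargs))

    def reload_states(self, non_blocking=False):
        self.calls.append(("reload_states", {"non_blocking": non_blocking}))


class EngineStandIn:
    """Hypothetical engine that forwards state movement to its optimizer."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def _assert_has_optimizer(self):
        # Same shape as the assertion in the diff: fail early with a clear
        # message instead of hitting a missing attribute later.
        assert not isinstance(self.optimizer, ZeroOffloadStandIn), \
            "Moving states across devices is not supported without an optimizer."

    def offload_states(self, **kwargs):
        self._assert_has_optimizer()
        self.optimizer.offload_states(**kwargs)

    def reload_states(self, non_blocking=False):
        self._assert_has_optimizer()
        self.optimizer.reload_states(non_blocking=non_blocking)


def try_reload(engine):
    """Call reload_states and report the outcome as a string."""
    try:
        engine.reload_states()
    except AssertionError as exc:
        return f"AssertionError: {exc}"
    return "ok"


if __name__ == "__main__":
    print(try_reload(EngineStandIn(ZeroOffloadStandIn())))
    print(try_reload(EngineStandIn(Stage3OptimizerStandIn())))
```

Note that Python strips `assert` statements when run with `-O`, in which case a guard like this disappears and the underlying `AttributeError` would surface again.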
