Fix issues with unwrapping modules with accelerate #963

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 1 commit on May 16, 2023
CHANGES.md: 2 additions & 0 deletions
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add support for compiled PyTorch modules using the `torch.compile` function, introduced in the [PyTorch 2.0 release](https://pytorch.org/get-started/pytorch-2.0/), which can greatly improve performance on new GPU architectures. To use it, initialize your net with the `compile=True` argument; further compilation arguments can be specified using the dunder notation, e.g. `compile__dynamic=True`
- Add a class [`DistributedHistory`](https://skorch.readthedocs.io/en/latest/history.html#skorch.history.DistributedHistory) which should be used when training in a multi GPU setting (#955)
- `SkorchDoctor`: A helper class that assists in understanding and debugging the neural net training, see [this notebook](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Skorch_Doctor.ipynb) (#912)
- When using `AccelerateMixin`, it is now possible to prevent the modules from being unwrapped after training by setting `unwrap_after_train=False`

### Changed

@@ -21,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `_get_param_names` returns a list instead of a generator so that subsequent
error messages return useful information instead of a generator `repr`
string (#925)
- Fixed a bug that caused modules not to be fully unwrapped at the end of training when using `AccelerateMixin`, which could prevent them from being pickled

## [0.12.1] - 2022-11-18

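As an aside, the `torch.compile` entry above can be illustrated with a minimal sketch (`MyModule`, `X`, and `y` are hypothetical placeholders; the `compile` arguments are as described in the changelog):

```python
from skorch import NeuralNetClassifier

# compile=True routes the module(s) through torch.compile;
# dunder arguments like compile__dynamic are forwarded to torch.compile
net = NeuralNetClassifier(
    MyModule,             # hypothetical torch.nn.Module
    compile=True,
    compile__dynamic=True,
)
net.fit(X, y)             # X, y are placeholder training data
```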
skorch/hf.py: 24 additions & 1 deletion
@@ -888,6 +888,19 @@ class was added: Using this mixin in conjunction with the accelerate library
leave device handling to accelerate. Therefore, it is best to leave this
argument as None, which means that skorch does not set the device.

unwrap_after_train : bool (default=True)
By default, with this option being ``True``, the module(s) and criterion
are automatically "unwrapped" after training. This means that their
initial state -- from before they were prepared by the ``accelerator`` --
is restored. This is necessary to pickle the net.

There are circumstances where you might want to disable this behavior: for
instance, if you want to continue training the model with AMP enabled (using
``net.partial_fit`` or ``warm_start=True``), or if you want to keep the
benefit of mixed precision during inference, which is lost once the modules
are unwrapped. In those cases, if you don't need to pickle the net, set
``unwrap_after_train=False``.

callbacks__print_log__sink : 'auto' or callable
If 'auto', uses the ``print`` function of the accelerator, if it has one.
This avoids printing the same output multiple times when training
@@ -900,6 +913,7 @@ def __init__(
*args,
accelerator,
device=None,
unwrap_after_train=True,
callbacks__print_log__sink='auto',
**kwargs
):
@@ -910,6 +924,7 @@
**kwargs
)
self.accelerator = accelerator
self.unwrap_after_train = unwrap_after_train

def _validate_params(self):
super()._validate_params()
@@ -1009,7 +1024,15 @@ def _step_optimizer(self, step_fn):
# pylint: disable=unused-argument
def on_train_end(self, net, X=None, y=None, **kwargs):
    super().on_train_end(net, X=X, y=y, **kwargs)
    # previously, only module_ was unwrapped here:
    #   self.module_ = self.accelerator.unwrap_model(self.module_)
    if not self.unwrap_after_train:
        return self

    # unwrap all modules and criteria, not just module_; dropping the
    # fp32 forward wrapper is what makes the net pickleable again
    for name in self._modules + self._criteria:
        module = getattr(self, name + '_')
        if isinstance(module, torch.nn.Module):
            orig = self.accelerator.unwrap_model(module, keep_fp32_wrapper=False)
            setattr(self, name + '_', orig)
    return self

def evaluation_step(self, batch, training=False):
# More context:
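To make the new option concrete, here is a minimal usage sketch (`MyModule`, `X_train`, and `y_train` are placeholders; the mixin composition mirrors the pattern used in the tests below). Note that the fix passes `keep_fp32_wrapper=False` when unwrapping, so the AMP forward wrapper is removed and the net becomes pickleable again; opting out keeps the wrappers in place:

```python
from accelerate import Accelerator
from skorch import NeuralNetClassifier
from skorch.hf import AccelerateMixin

class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
    """Classifier net with accelerate support"""

accelerator = Accelerator(mixed_precision='fp16')

# keep the accelerate wrappers after training, e.g. to continue
# training with AMP or to run inference in mixed precision; the
# trade-off is that the net cannot be pickled while wrapped
net = AcceleratedNet(
    MyModule,                 # placeholder torch.nn.Module
    accelerator=accelerator,
    unwrap_after_train=False,
)
net.fit(X_train, y_train)
net.partial_fit(X_train, y_train)  # further training, AMP still active
```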
skorch/tests/test_hf.py: 46 additions & 4 deletions
@@ -562,9 +562,9 @@ def test_mixed_precision(self, net_cls, accelerator_cls, data, mixed_precision):
assert np.isfinite(net.history[:, "train_loss"]).all()

@pytest.mark.parametrize('mixed_precision', [
    # before the fix, 'fp16' and 'bf16' both xfailed with
    # pickle.PicklingError; 'fp16' is now expected to pass
    'no',  # no acceleration works because forward is left the same
    'fp16',
    pytest.param('bf16', marks=pytest.mark.xfail(raises=pickle.PicklingError)),
])
def test_mixed_precision_pickling(
self, net_cls, accelerator_cls, data, mixed_precision
@@ -601,9 +601,51 @@ def test_mixed_precision_pickling(

accelerator = accelerator_cls(mixed_precision=mixed_precision)
net = net_cls(accelerator=accelerator)
# previously the net was merely initialized here; fitting it ensures
# the modules actually get wrapped (and unwrapped) before pickling
X, y = data
net.fit(X[:100], y[:100])
pickle.loads(pickle.dumps(net))

def test_unwrapping_all_modules(self, module_cls, accelerator_cls, data):
    # This test is for a bug we had previously where only 'module_' was
    # unwrapped, not all possible modules and criteria.
    if not torch.cuda.is_available():
        pytest.skip('skipping test because device does not support it')

    class MyNet(AcceleratedNet):
        """Net with two different modules"""
        def initialize_module(self):
            super().initialize_module()
            self.module2_ = module_cls()
            return self

    accelerator = accelerator_cls(mixed_precision='fp16')
    net = MyNet(module_cls, accelerator=accelerator, unwrap_after_train=True)
    X, y = data
    net.fit(X[:100], y[:100])

    # there isn't really an elegant way to check if the modules have been
    # correctly unwrapped
    assert not hasattr(net.criterion_.forward, '__wrapped__')
    assert not hasattr(net.module_.forward, '__wrapped__')
    assert not hasattr(net.module2_.forward, '__wrapped__')

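For context on the `__wrapped__` probes above: when accelerate prepares a module for mixed precision, it replaces `forward` with a wrapper built via `functools.wraps`-style decoration, which records the original function under `__wrapped__`. A minimal sketch of that mechanism (illustrative only, not accelerate's actual code):

```python
import functools

def amp_wrap(forward):
    # functools.wraps copies the metadata of `forward` onto the wrapper
    # and sets wrapped.__wrapped__ = forward, which is the attribute the
    # tests probe to tell wrapped from unwrapped modules
    @functools.wraps(forward)
    def wrapped(*args, **kwargs):
        return forward(*args, **kwargs)
    return wrapped
```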
def test_not_unwrapping_modules(self, net_cls, accelerator_cls, data):
    # Make it possible not to unwrap the modules after training. This is
    # useful, e.g., to allow further training with warm start or to do
    # inference with AMP, but it prevents the model from being pickled.
    if not torch.cuda.is_available():
        pytest.skip('skipping test because device does not support it')

    accelerator = accelerator_cls(mixed_precision='fp16')
    net = net_cls(accelerator=accelerator, unwrap_after_train=False)
    X, y = data
    net.fit(X[:100], y[:100])

    # there isn't really an elegant way to check that the modules have
    # been left wrapped
    assert hasattr(net.criterion_.forward, '__wrapped__')
    assert hasattr(net.module_.forward, '__wrapped__')

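Since the wrappers stay in place with `unwrap_after_train=False`, pickling such a net is expected to fail, as the xfail markers elsewhere in this file suggest. A sketch of the failure mode (hypothetical; the exact exception depends on the wrapper and pickle protocol):

```python
import pickle

# `net` stands for a net trained with unwrap_after_train=False; its
# forward methods still reference accelerate's wrapper functions
try:
    pickle.dumps(net)
except (pickle.PicklingError, AttributeError) as exc:
    print(f'pickling failed as expected: {exc}')
```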
@pytest.mark.parametrize('mixed_precision', ['fp16', 'bf16', 'no'])
def test_mixed_precision_save_load_params(
self, net_cls, accelerator_cls, data, mixed_precision, tmp_path
@@ -693,7 +735,7 @@ def backward(self, loss, **kwargs):
loss.backward(**kwargs)
loss.backward_was_called = True

# previously: def unwrap_model(self, model)
def unwrap_model(self, model, keep_fp32_wrapper=True):
    return model

def gather_for_metrics(self, output):