Experimental support for accelerate #826

Merged 10 commits on Mar 4, 2022
1 change: 1 addition & 0 deletions CHANGES.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Added `load_best` attribute to `EarlyStopping` callback to automatically load module weights of the best result at the end of training
- Added experimental support for [huggingface accelerate](https://github.com/huggingface/accelerate); use the provided mixin class to add the advanced training capabilities of the accelerate library to skorch

### Changed

49 changes: 49 additions & 0 deletions docs/user/helper.rst
@@ -48,6 +48,53 @@ argument ``idx=0``, the default) and one for y (with argument
    gs.fit(X_sl, y_sl)


AccelerateMixin
---------------

This mixin class can be used to add support for huggingface accelerate_ to
skorch. For example, it allows you to use mixed precision training (AMP),
multi-GPU training, or training with a TPU. For the time being, this feature
should be considered experimental.

To use this feature, create a new subclass of the neural net class you want to
use and inherit from the mixin class. E.g., if you want to use a
:class:`.NeuralNet`, it would look like this:

.. code:: python

    from skorch import NeuralNet
    from skorch.helper import AccelerateMixin

    class AcceleratedNet(AccelerateMixin, NeuralNet):
        """NeuralNet with accelerate support"""

The same would work for :class:`.NeuralNetClassifier`,
:class:`.NeuralNetRegressor`, etc. Then pass an instance of Accelerator_ with
the desired parameters and you're good to go:

.. code:: python

    from accelerate import Accelerator

    accelerator = Accelerator(...)
    net = AcceleratedNet(
        MyModule,
        accelerator=accelerator,
        device=None,
        callbacks__print_log__sink=accelerator.print)
@thomasjpfan (Member) commented on Feb 16, 2022:

Should we set callbacks__print_log__sink automatically? It feels like more boilerplate for a user to write.

Author (Collaborator) replied:

Yes, true. On the other hand, since the accelerator does not exist when we define AccelerateMixin, we cannot set it as a default argument. We would need to do something else, like:

def initialize(self):
    super().initialize()
    self.set_params(callbacks__print_log__sink=self.accelerator.print)
    return self

Not sure if the user would have any reason not to set this, but if they did, this hard-coding would prove annoying.

@thomasjpfan (Member) replied:

I think having good defaults is generally better. If a user does not want to set it like this, I'm guessing they are more advanced and would need to override initialize to undo it.

It's a toss-up for me.

Author (Collaborator) replied:

I tried a different solution which hopefully lets us have our cake and eat it too. Please take a look: 9d3e5e0
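
For reference, one shape such a solution could take (a hypothetical sketch only; the actual change lives in 9d3e5e0 and is not shown here) is an 'auto' sentinel default that is resolved during callback initialization, so a sink explicitly passed by the user is left untouched:

def __init__(self, *args, accelerator, callbacks__print_log__sink='auto', **kwargs):
    super().__init__(
        *args, callbacks__print_log__sink=callbacks__print_log__sink, **kwargs)
    self.accelerator = accelerator

def initialize_callbacks(self, *args, **kwargs):
    super().initialize_callbacks(*args, **kwargs)
    # hypothetical: resolve the 'auto' sentinel to accelerator.print;
    # a sink explicitly set by the user is not touched
    for _, callback in self.callbacks_:
        if getattr(callback, 'sink', None) == 'auto':
            callback.sink = self.accelerator.print
    return self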

    net.fit(X, y)

accelerate_ recommends leaving the device handling to the Accelerator_, which
is why we set ``device=None`` (thus telling skorch not to change the device).
Furthermore, using ``accelerator.print`` should avoid printing the same output
multiple times when training concurrently on multiple machines.
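
As a quick illustration (a minimal sketch, assuming a run with multiple
processes, e.g. started via ``accelerate launch``): plain ``print`` executes in
every process, whereas ``accelerator.print`` only executes in the main process:

.. code:: python

    from accelerate import Accelerator

    accelerator = Accelerator()
    print("printed once per process")  # duplicated when run on N processes
    accelerator.print("printed only by the main process")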

To install accelerate_, run the following command inside your Python environment:

.. code:: bash

    python -m pip install accelerate

Command line interface helpers
------------------------------

@@ -201,6 +248,8 @@ callbacks through the command line (but you can modify existing ones
as usual).


.. _accelerate: https://github.com/huggingface/accelerate
.. _Accelerator: https://huggingface.co/docs/accelerate/accelerator.html
.. _fire: https://github.com/google/python-fire
.. _numpydoc: https://github.com/numpy/numpydoc
.. _example: https://github.com/skorch-dev/skorch/tree/master/examples/cli
114 changes: 114 additions & 0 deletions skorch/helper.py
@@ -12,6 +12,7 @@
import torch

from skorch.cli import parse_args # pylint: disable=unused-import
from skorch.dataset import unpack_data
from skorch.utils import _make_split
from skorch.utils import is_torch_data_type
from skorch.utils import to_tensor
@@ -508,3 +509,116 @@ def describe_signature(self, df):
        )

        return signature


class AccelerateMixin:
    """Mixin class to add support for huggingface accelerate

    This is an *experimental* feature.

    Use this mixin class with one of the neural net classes (e.g. ``NeuralNet``,
    ``NeuralNetClassifier``, or ``NeuralNetRegressor``) and pass an instance of
    ``Accelerator`` for mixed precision, multi-GPU, or TPU training.

    Install the accelerate library using:

    .. code-block::

        python -m pip install accelerate

    skorch does not itself provide any facilities to enable these training
    features. A lot of them can still be implemented by the user with a little
    bit of extra work, but it can be a daunting task. That is why this helper
    class was added: using this mixin in conjunction with the accelerate
    library should cover a lot of common use cases.

    Since accelerate is still quite young and changes that break backwards
    compatibility might be added, we treat its integration as an experimental
    feature. When accelerate's API stabilizes, we will consider adding it to
    skorch proper.

    Examples
    --------
    >>> from skorch import NeuralNetClassifier
    >>> from skorch.helper import AccelerateMixin
    >>> from accelerate import Accelerator
    >>>
    >>> class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
    ...     '''NeuralNetClassifier with accelerate support'''
    >>>
    >>> accelerator = Accelerator(...)
    >>> net = AcceleratedNet(
    ...     MyModule,
    ...     accelerator=accelerator,
    ...     device=None,
@thomasjpfan (Member) commented:

We likely can make device="auto" too.

Author (Collaborator) replied:

I'm not so sure on this one, as it's not as clear-cut as the other default. Also, changing device is not as obscure as the print log sink.

@thomasjpfan (Member) replied:

As with print, I cannot really think of a case where device is not None when using Accelerator.

Author (Collaborator) replied:

Okay, you convinced me, I changed the default to None.
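
For illustration, a hypothetical sketch of what that default could look like in the mixin's __init__ (the actual commit is not part of this diff):

def __init__(self, *args, accelerator, device=None, **kwargs):
    # device=None by default: device placement is left to the accelerator
    super().__init__(*args, device=device, **kwargs)
    self.accelerator = accelerator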

    ...     callbacks__print_log__sink=accelerator.print)
    >>> net.fit(X, y)

    The same approach works with all the other skorch net classes.

    Parameters
    ----------
    accelerator : accelerate.Accelerator
        In addition to the usual parameters, pass an instance of
        ``accelerate.Accelerator`` with the desired settings.

    """
    def __init__(self, *args, accelerator, **kwargs):
        super().__init__(*args, **kwargs)
        self.accelerator = accelerator

    def _check_kwargs(self, kwargs):
        super()._check_kwargs(kwargs)

        if self.accelerator.device_placement and (self.device is not None):
            raise ValueError(
                "When device placement is performed by the accelerator, set device=None"
            )

    def _initialize_criterion(self, *args, **kwargs):
        super()._initialize_criterion(*args, **kwargs)

        with self._current_init_context('criterion'):
            for name in self._criteria:
                criterion = getattr(self, name + '_')
                if isinstance(criterion, torch.nn.Module):
                    # store the prepared criterion back on the net
                    setattr(self, name + '_', self.accelerator.prepare(criterion))

        return self

    def _initialize_module(self, *args, **kwargs):
        super()._initialize_module(*args, **kwargs)

        with self._current_init_context('module'):
            for name in self._modules:
                module = getattr(self, name + '_')
                if isinstance(module, torch.nn.Module):
                    # store the prepared module back on the net
                    setattr(self, name + '_', self.accelerator.prepare(module))

        return self

    def _initialize_optimizer(self, *args, **kwargs):
        super()._initialize_optimizer(*args, **kwargs)

        with self._current_init_context('optimizer'):
            for name in self._optimizers:
                optimizer = getattr(self, name + '_')
                if isinstance(optimizer, torch.optim.Optimizer):
                    # store the prepared optimizer back on the net
                    setattr(self, name + '_', self.accelerator.prepare(optimizer))
        return self

    def train_step_single(self, batch, **fit_params):
        self._set_training(True)
        Xi, yi = unpack_data(batch)
        y_pred = self.infer(Xi, **fit_params)
        loss = self.get_loss(y_pred, yi, X=Xi, training=True)
        # use accelerator.backward instead of loss.backward(), so that e.g.
        # loss scaling for mixed precision training is handled by accelerate
        self.accelerator.backward(loss)
        return {
            'loss': loss,
            'y_pred': y_pred,
        }

    def get_iterator(self, *args, **kwargs):
        iterator = super().get_iterator(*args, **kwargs)
        iterator = self.accelerator.prepare(iterator)
        return iterator
156 changes: 156 additions & 0 deletions skorch/tests/test_helper.py
@@ -1,5 +1,7 @@
"""Test for helper.py"""
import pickle
from distutils.version import LooseVersion
from functools import partial

import numpy as np
import pytest
@@ -708,3 +710,157 @@ def test_describe_signature_other_dtypes(self, transformer_cls, df):
            'col_cats': {"dtype": torch.int32, "input_units": 2},
        }
        assert result == expected


class TestAccelerate:
    @pytest.fixture(scope='module')
    def data(self, classifier_data):
        return classifier_data

    @pytest.fixture(scope='module')
    def module_cls(self, classifier_module):
        return classifier_module

    @pytest.fixture
    def accelerator_cls(self):
        pytest.importorskip('accelerate')

        from accelerate import Accelerator

        return Accelerator

    @pytest.fixture
    def net_cls(self, module_cls):
        from skorch import NeuralNetClassifier
        from skorch.helper import AccelerateMixin

        class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
            pass

        return partial(
            AcceleratedNet,
            module=module_cls,
            max_epochs=2,
            lr=0.1,
            device=None,
        )

    @pytest.mark.parametrize('mixed_precision', ['no', 'fp16', 'bf16'])
    def test_mixed_precision(self, net_cls, accelerator_cls, data, mixed_precision):
        # Only test if training works at all, no specific test of whether the
        # indicated precision is actually used, since that depends on the
        # underlying hardware.
        import accelerate

        if LooseVersion(accelerate.__version__) > '0.5.1':
            accelerator = accelerator_cls(mixed_precision=mixed_precision)
        elif mixed_precision == 'bf16':
            pytest.skip('bf16 only supported in accelerate version > 0.5.1')
        else:
            fp16 = mixed_precision == 'fp16'
            accelerator = accelerator_cls(fp16=fp16)

        net = net_cls(
            accelerator=accelerator,
            callbacks__print_log__sink=accelerator.print,
        )
        X, y = data
        net.fit(X, y)  # does not raise

    def test_force_cpu(self, net_cls, accelerator_cls, data):
        accelerator = accelerator_cls(device_placement=False, cpu=True)
        net = net_cls(accelerator=accelerator)
        net.set_params(device='cpu')
        net.fit(*data)  # does not raise

    def test_device_placement(self, net_cls, accelerator_cls, data):
        accelerator = accelerator_cls(device_placement=True)
        net = net_cls(accelerator=accelerator)
        net.set_params(device='cpu')
        msg = "When device placement is performed by the accelerator, set device=None"
        with pytest.raises(ValueError, match=msg):
            net.fit(*data)

    def test_all_components_prepared(self, module_cls, data):
        # We cannot test whether accelerate is really performing its job.
        # Instead, we test that all modules and optimizers, even custom
        # user-defined ones, are properly prepared. We also test that
        # loss.backward() is called. This means that we do test implementation
        # details of accelerate that may change in the future.
        from skorch import NeuralNetClassifier
        from skorch.helper import AccelerateMixin

        # pylint: disable=missing-class-docstring
        class MockAccelerator:
            def __init__(self):
                self.device_placement = True
                self.print = print

            def prepare(self, *args):
                for arg in args:
                    arg.is_prepared = True
                return args if len(args) > 1 else args[0]

            def backward(self, loss, **kwargs):
                loss.backward(**kwargs)
                loss.backward_was_called = True

        # pylint: disable=missing-class-docstring
        class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
            def get_iterator(self, *args, **kwargs):
                iterator = super().get_iterator(*args, **kwargs)
                assert iterator.is_prepared
                return iterator

            def initialize_criterion(self):
                super().initialize_criterion()
                kwargs = self.get_params_for('criterion')
                # pylint: disable=attribute-defined-outside-init
                self.criterion2_ = self.criterion(**kwargs)
                return self

            def initialize_module(self):
                super().initialize_module()
                kwargs = self.get_params_for('module')
                # pylint: disable=attribute-defined-outside-init
                self.module2_ = self.module(**kwargs)
                return self

            def initialize_optimizer(self, *args, **kwargs):
                super().initialize_optimizer(*args, **kwargs)
                named_parameters = self.module2_.named_parameters()
                args, kwargs = self.get_params_for_optimizer(
                    'optimizer', named_parameters)
                # pylint: disable=attribute-defined-outside-init
                self.optimizer2_ = self.optimizer(*args, **kwargs)
                return self

            def infer(self, *args, **kwargs):
                # check that all modules and criteria are prepared
                assert self.module_.is_prepared
                assert self.module2_.is_prepared
                assert self.criterion_.is_prepared
                assert self.criterion2_.is_prepared
                return super().infer(*args, **kwargs)

            def train_step_single(self, *args, **kwargs):
                # check that all optimizers are prepared and that
                # loss.backward() was called
                assert self.optimizer_.is_prepared
                assert self.optimizer2_.is_prepared
                output = super().train_step_single(*args, **kwargs)
                assert output['loss'].backward_was_called
                return output

        accelerator = MockAccelerator()
        net = AcceleratedNet(
            module_cls,
            device=None,
            accelerator=accelerator,
            max_epochs=2,
            callbacks__print_log__sink=accelerator.print,
        )
        X, y = data
        # does not raise
        net.fit(X, y)
        net.predict(X)