Experimental support for accelerate #826
@@ -12,6 +12,7 @@
import torch

from skorch.cli import parse_args  # pylint: disable=unused-import
from skorch.dataset import unpack_data
from skorch.utils import _make_split
from skorch.utils import is_torch_data_type
from skorch.utils import to_tensor
@@ -508,3 +509,116 @@ def describe_signature(self, df):
        )

        return signature


class AccelerateMixin:
"""Mixin class to add support for huggingface accelerate | ||
|
||
This is an *experimental* feature. | ||
|
||
Use this mixin class with one of the neural net classes (e.g. ``NeuralNet``, | ||
``NeuralNetClassifier``, or ``NeuralNetRegressor``) and pass an instance of | ||
``Accelerator`` for mixed precision, multi-GPU, or TPU training. | ||
|
||
Install the accelerate library using: | ||
|
||
.. code-block:: | ||
|
||
python -m pip install accelerate | ||
|
||
skorch does not itself provide any facilities to enable these training | ||
features. A lot of them can still be implemented by the user with a little | ||
bit of extra work but it can be a daunting task. That is why this helper | ||
class was added: Using this mixin in conjunction with the accelerate library | ||
should cover a lot of common use cases. | ||
|
||
Since accelerate is still quite young and backwards compatiblity breaking | ||
features might be added, we treat its integration as an experimental | ||
feature. When accelerate's API stabilizes, we will consider adding it to | ||
skorch proper. | ||
|
||
    Examples
    --------
    >>> from skorch import NeuralNetClassifier
    >>> from skorch.helper import AccelerateMixin
    >>> from accelerate import Accelerator
    >>>
    >>> class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
    ...     '''NeuralNetClassifier with accelerate support'''
    >>>
    >>> accelerator = Accelerator(...)
    >>> net = AcceleratedNet(
    ...     MyModule,
    ...     accelerator=accelerator,
    ...     device=None,
    ...     callbacks__print_log__sink=accelerator.print)
    >>> net.fit(X, y)

    The same approach works with all the other skorch net classes.

Review thread on the ``device=None`` argument:

Comment: We likely can make ``device=None`` the default as well.

Reply: I'm not so sure on this one, as it's not as clear-cut as the other default. Also, changing device is not as obscure as the print log sink.

Reply: As with the print log sink, I think a good default helps here.

Reply: Okay, you convinced me, I changed the default to ``None``.
    Parameters
    ----------
    accelerator : accelerate.Accelerator
        In addition to the usual parameters, pass an instance of
        ``accelerate.Accelerator`` with the desired settings.

    """
    def __init__(self, *args, accelerator, **kwargs):
        super().__init__(*args, **kwargs)
        self.accelerator = accelerator
    def _check_kwargs(self, kwargs):
        super()._check_kwargs(kwargs)

        if self.accelerator.device_placement and (self.device is not None):
            raise ValueError(
                "When device placement is performed by the accelerator, set device=None"
            )
    def _initialize_criterion(self, *args, **kwargs):
        super()._initialize_criterion(*args, **kwargs)

        with self._current_init_context('criterion'):
            for name in self._criteria:
                criterion = getattr(self, name + '_')
                if isinstance(criterion, torch.nn.Module):
                    self.accelerator.prepare(criterion)

        return self
    def _initialize_module(self, *args, **kwargs):
        super()._initialize_module(*args, **kwargs)

        with self._current_init_context('module'):
            for name in self._modules:
                module = getattr(self, name + '_')
                if isinstance(module, torch.nn.Module):
                    self.accelerator.prepare(module)

        return self
    def _initialize_optimizer(self, *args, **kwargs):
        super()._initialize_optimizer(*args, **kwargs)

        with self._current_init_context('optimizer'):
            for name in self._optimizers:
                optimizer = getattr(self, name + '_')
                if isinstance(optimizer, torch.optim.Optimizer):
                    self.accelerator.prepare(optimizer)

        return self
    def train_step_single(self, batch, **fit_params):
        self._set_training(True)
        Xi, yi = unpack_data(batch)
        y_pred = self.infer(Xi, **fit_params)
        loss = self.get_loss(y_pred, yi, X=Xi, training=True)
        self.accelerator.backward(loss)
        return {
            'loss': loss,
            'y_pred': y_pred,
        }
    def get_iterator(self, *args, **kwargs):
        iterator = super().get_iterator(*args, **kwargs)
        iterator = self.accelerator.prepare(iterator)
        return iterator
Review thread on the ``callbacks__print_log__sink`` default:

Comment: Should we set ``callbacks__print_log__sink`` automatically? It feels like more boilerplate for a user to write.

Reply: Yes, true. On the other hand, since the accelerator does not exist when we define ``AccelerateMixin``, we cannot set it as a default argument. We would need to do something else, like the following:
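A minimal hypothetical sketch of what such hard-coding could look like, assuming skorch's usual handling of double-underscore arguments passed to ``__init__`` (illustrative only, not necessarily the exact snippet from this thread):

    # Hypothetical sketch: hard-code the PrintLog sink at instantiation
    # time, because the accelerator instance does not exist yet when the
    # mixin class itself is defined.
    class AccelerateMixin:
        def __init__(self, *args, accelerator, **kwargs):
            # Force all PrintLog output through the accelerator, overriding
            # whatever sink the user may have passed themselves.
            kwargs['callbacks__print_log__sink'] = accelerator.print
            super().__init__(*args, **kwargs)
            self.accelerator = accelerator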
Not sure if the user would have any reason not to set this, but if they had, this hard-coding would prove annoying.

Reply: I think having good defaults is generally better. If a user does not want it set like this, I'm guessing they are more advanced and would need to override ``initialize`` to undo it. It's a toss-up for me.
Reply: I tried a different solution which hopefully lets us have our cake and eat it too. Please take a look: 9d3e5e0
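The linked commit is not shown here; one hypothetical shape for such a best-of-both-worlds solution is an ``'auto'`` sentinel that is resolved once the accelerator is known, sketched under the assumption that skorch exposes a ``_initialize_callbacks`` hook alongside the other ``_initialize_*`` methods used above:

    # Hypothetical sketch: default the sink to an 'auto' sentinel and only
    # resolve it to accelerator.print during initialization, so users can
    # still pass any sink they like.
    class AccelerateMixin:
        def __init__(
                self,
                *args,
                accelerator,
                callbacks__print_log__sink='auto',
                **kwargs,
        ):
            super().__init__(
                *args,
                callbacks__print_log__sink=callbacks__print_log__sink,
                **kwargs,
            )
            self.accelerator = accelerator

        def _initialize_callbacks(self):
            if self.callbacks__print_log__sink == 'auto':
                # The accelerator exists by now, so the sentinel can be
                # replaced with the real sink.
                self.callbacks__print_log__sink = self.accelerator.print
            super()._initialize_callbacks()
            return self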