
Commit 9d3e5e0

Automatically use accelerator.print by default
This is achieved by setting the default of callbacks__print_log__sink to 'auto'. When this value is detected during callback initialization, it is replaced with self.accelerator.print if available. This way, we get the sane default even though we cannot set it directly as the parameter default (the accelerator instance does not exist yet at that point), while still giving the user the option to pass a different sink.
1 parent 9c22965 commit 9d3e5e0
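
The resolution logic is easiest to see in isolation. Below is a minimal, self-contained sketch of the sentinel pattern described above; it is not the actual skorch code, and the class names are invented for illustration:

class NetSketch:
    """Toy stand-in for a net class that defers resolving its print sink."""
    def __init__(self, accelerator, callbacks__print_log__sink='auto'):
        self.accelerator = accelerator
        self.callbacks__print_log__sink = callbacks__print_log__sink

    def _initialize_callbacks(self):
        # 'auto' is only resolved once the accelerator instance exists;
        # fall back to the built-in print if it has no print of its own
        if self.callbacks__print_log__sink == 'auto':
            self.callbacks__print_log__sink = getattr(
                self.accelerator, 'print', print
            )
        return self

class FakeAccelerator:
    """Toy accelerator exposing a print method."""
    def print(self, *args, **kwargs):
        print(*args, **kwargs)

net = NetSketch(FakeAccelerator())._initialize_callbacks()
assert net.callbacks__print_log__sink == net.accelerator.print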

File tree: 3 files changed (+51, -11 lines)


docs/user/helper.rst

Lines changed: 1 addition & 3 deletions

@@ -81,13 +81,11 @@ the desired parameters and you're good to go:
     MyModule,
     accelerator=accelerator,
     device=None,
-    callbacks__print_log__sink=accelerator.print)
+)
 net.fit(X, y)
 
 accelerate_ recommends to leave the device handling to the Accelerator_, which
 is why we set ``device=None`` (thus telling skorch not to change the device).
-Furthermore, using ``accelerator.print`` should avoid printing the same output
-multiple times when training concurrently on multiple machines.
 
 To install accelerate_, run the following command inside your Python environment:
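
For context, the full example on that docs page now reads roughly as follows. This is a reconstruction, not part of the diff, and it assumes the AcceleratedNet subclass (built from skorch's AccelerateMixin) and MyModule defined earlier on the page:

from accelerate import Accelerator

accelerator = Accelerator()
net = AcceleratedNet(   # AccelerateMixin subclass from earlier in the docs
    MyModule,
    accelerator=accelerator,
    device=None,        # leave device placement to the Accelerator
)
net.fit(X, y)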

skorch/helper.py

Lines changed: 18 additions & 3 deletions

@@ -551,7 +551,7 @@ class was added: Using this mixin in conjunction with the accelerate library
     ...     MyModule,
     ...     accelerator=accelerator,
     ...     device=None,
-    ...     callbacks__print_log__sink=accelerator.print)
+    ... )
     >>> net.fit(X, y)
 
     The same approach works with all the other skorch net classes.
@@ -562,9 +562,17 @@ class was added: Using this mixin in conjunction with the accelerate library
     In addition to the usual parameters, pass an instance of
     ``accelerate.Accelerator`` with the desired settings.
 
+    callbacks__print_log__sink : 'auto' or callable
+      If 'auto', uses the ``print`` function of the accelerator, if it has one.
+      This avoids printing the same output multiple times when training
+      concurrently on multiple machines. If the accelerator does not have a
+      ``print`` function, use Python's ``print`` function instead.
+
     """
-    def __init__(self, *args, accelerator, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, *args, accelerator, callbacks__print_log__sink='auto', **kwargs):
+        super().__init__(
+            *args, callbacks__print_log__sink=callbacks__print_log__sink, **kwargs
+        )
         self.accelerator = accelerator
 
     def _check_kwargs(self, kwargs):
@@ -575,6 +583,13 @@ def _check_kwargs(self, kwargs):
             "When device placement is performed by the accelerator, set device=None"
         )
 
+    def _initialize_callbacks(self):
+        if self.callbacks__print_log__sink == 'auto':
+            print_func = getattr(self.accelerator, 'print', print)
+            self.callbacks__print_log__sink = print_func
+        super()._initialize_callbacks()
+        return self
+
     def _initialize_criterion(self, *args, **kwargs):
         super()._initialize_criterion(*args, **kwargs)
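
Since the new default is an ordinary parameter, passing any callable still overrides 'auto' (this is exercised by test_print_log_sink_can_be_overwritten below). A hedged sketch, with a logger-based sink as a purely hypothetical choice and AcceleratedNet/MyModule as in the docs example above:

import logging

logger = logging.getLogger('train')
net = AcceleratedNet(
    MyModule,
    accelerator=accelerator,
    device=None,
    # any callable replaces the 'auto' default for the print_log callback
    callbacks__print_log__sink=logger.info,
)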

skorch/tests/test_helper.py

Lines changed: 32 additions & 5 deletions

@@ -2,6 +2,7 @@
 import pickle
 from distutils.version import LooseVersion
 from functools import partial
+from unittest.mock import Mock
 
 import numpy as np
 import pytest
@@ -760,10 +761,7 @@ def test_mixed_precision(self, net_cls, accelerator_cls, data, mixed_precision):
         fp16 = mixed_precision == 'fp16'
         accelerator = accelerator_cls(fp16=fp16)
 
-        net = net_cls(
-            accelerator=accelerator,
-            callbacks__print_log__sink=accelerator.print,
-        )
+        net = net_cls(accelerator=accelerator)
         X, y = data
         net.fit(X, y)  # does not raise
 
@@ -781,6 +779,36 @@ def test_device_placement(self, net_cls, accelerator_cls, data):
         with pytest.raises(ValueError, match=msg):
             net.fit(*data)
 
+    def test_print_log_sink_auto_uses_accelerator_print(self, net_cls, accelerator_cls):
+        # the net defaults to using the accelerator's print function
+        accelerator = accelerator_cls()
+        net = net_cls(accelerator=accelerator)
+        net.initialize()
+        print_log = dict(net.callbacks_)['print_log']
+        assert print_log.sink == accelerator.print
+
+    def test_print_log_sink_can_be_overwritten(self, net_cls, accelerator_cls):
+        # users can still set their own sinks for print log
+        accelerator = accelerator_cls()
+        net = net_cls(accelerator=accelerator, callbacks__print_log__sink=123)
+        net.initialize()
+        print_log = dict(net.callbacks_)['print_log']
+        assert print_log.sink == 123
+
+    def test_print_log_sink_uses_print_if_accelerator_has_no_print(
+            self, net_cls, accelerator_cls
+    ):
+        # we should not depend on the accelerator having a print function
+
+        # we need to use Mock here because Accelerator does not allow attr
+        # deletion
+        accelerator = Mock(spec=accelerator_cls())
+        delattr(accelerator, 'print')
+        net = net_cls(accelerator=accelerator)
+        net.initialize()
+        print_log = dict(net.callbacks_)['print_log']
+        assert print_log.sink is print
+
     def test_all_components_prepared(self, module_cls, data):
         # We cannot test whether accelerate is really performing its job.
         # Instead, we test that all modules and optimizers, even custom
@@ -858,7 +886,6 @@ def train_step_single(self, *args, **kwargs):
             device=None,
             accelerator=accelerator,
             max_epochs=2,
-            callbacks__print_log__sink=accelerator.print,
         )
         X, y = data
         # does not raise
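
The last new test relies on a Mock detail worth spelling out: a real Accelerator does not allow attribute deletion, but a spec'd Mock does, and afterwards the attribute access raises AttributeError, so getattr falls back to the built-in print. A self-contained illustration, with FakeAccelerator as an invented stand-in:

from unittest.mock import Mock

class FakeAccelerator:
    def print(self, *args, **kwargs):
        pass

accelerator = Mock(spec=FakeAccelerator())
delattr(accelerator, 'print')  # Mock records the deletion
assert getattr(accelerator, 'print', print) is print  # falls back to built-in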
