Skip to content

Commit 3755ed4

Browse files
committed
Updated the to_device function to accept `auto` and `None` as device arguments, which respectively:
* use the fastest device available * leave the device type unmodified (skorch-dev#600)
1 parent 560e721 commit 3755ed4

File tree

6 files changed

+98
-19
lines changed

6 files changed

+98
-19
lines changed

CHANGES.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
- Added `NeptuneLogger` callback for logging experiment metadata to neptune.ai
1313
- Add DataFrameTransformer, an sklearn compatible transformer that helps working with pandas DataFrames by transforming the DataFrame into a representation that works well with neural networks (#507)
14+
- Added `auto` and `None` as accepted values for `device`: `auto` automatically uses the fastest device available, and `None` leaves the device type unmodified. (#600)
1415

1516
### Changed
1617

examples/benchmarks/mnist.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sklearn.model_selection import StratifiedKFold
3535
from sklearn.metrics import accuracy_score
3636
from sklearn.utils import shuffle
37+
from skorch.utils import to_device
3738
from skorch import NeuralNetClassifier
3839
from skorch.callbacks import EpochScoring
3940
import torch
@@ -141,7 +142,7 @@ def train_torch(
141142
lr,
142143
max_epochs,
143144
):
144-
model.to(device)
145+
model = to_device(model, device)
145146

146147
idx_train, idx_valid = next(iter(StratifiedKFold(
147148
5, random_state=0).split(np.arange(len(X)), y)))
@@ -191,7 +192,7 @@ def train_step(model, dataset, device, criterion, batch_size, optimizer):
191192
batch_sizes = []
192193
tic = time.time()
193194
for Xi, yi in torch.utils.data.DataLoader(dataset, batch_size=batch_size):
194-
Xi, yi = Xi.to(device), yi.to(device)
195+
Xi, yi = to_device(Xi, device), to_device(yi, device)
195196
optimizer.zero_grad()
196197
y_pred = model(Xi)
197198
y_pred = torch.log(y_pred)
@@ -221,7 +222,7 @@ def valid_step(model, dataset, device, criterion, batch_size):
221222
for Xi, yi in torch.utils.data.DataLoader(
222223
dataset, batch_size=batch_size,
223224
):
224-
Xi, yi = Xi.to(device), yi.to(device)
225+
Xi, yi = to_device(Xi, device), to_device(yi, device)
225226
y_pred = model(Xi)
226227
y_pred = torch.log(y_pred)
227228
loss = criterion(y_pred, yi)

examples/word_language_model/data.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22

3+
from skorch.utils import to_device
4+
35
import torch
46
from torch.autograd import Variable
57

@@ -70,7 +72,7 @@ def batchify(self, data, bsz):
7072
data = data.narrow(0, 0, nbatch * bsz)
7173
# Evenly divide the data across the bsz batches.
7274
data = data.view(bsz, -1).t().contiguous()
73-
return data.to(self.device)
75+
return to_device(data, self.device)
7476

7577
def get_batch(self, i):
7678
seq_len = min(self.bptt, len(self.batches) - 1 - i)

skorch/net.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def initialize_criterion(self):
431431
criterion_params = self._get_params_for('criterion')
432432
self.criterion_ = self.criterion(**criterion_params)
433433
if isinstance(self.criterion_, torch.nn.Module):
434-
self.criterion_ = self.criterion_.to(self.device)
434+
self.criterion_ = to_device(self.criterion_, self.device)
435435
return self
436436

437437
def _format_reinit_msg(self, name, kwargs=None, triggered_directly=True):
@@ -472,7 +472,7 @@ def initialize_module(self):
472472

473473
module = module(**kwargs)
474474

475-
self.module_ = module.to(self.device)
475+
self.module_ = to_device(module, self.device)
476476
return self
477477

478478
def _is_virtual_param(self, key):

skorch/tests/test_utils.py

+63-9
Original file line numberDiff line numberDiff line change
@@ -149,40 +149,94 @@ def x(self):
149149
def x_tup(self):
150150
return torch.zeros(3), torch.ones((4, 5))
151151

152+
@pytest.fixture
153+
def x_pad_seq(self):
154+
input = torch.zeros((5, 3)).float()
155+
length = torch.as_tensor([2, 2, 1])
156+
return pack_padded_sequence(input,length )
157+
158+
def check_device_type(self, tensor, device_input, prev_device):
159+
"""assert expected device type conditioned on the input argument for `to_device`"""
160+
if None is device_input:
161+
assert tensor.device.type == prev_device
162+
163+
elif device_input == "auto":
164+
expected = 'cuda' if torch.cuda.is_available() else 'cpu'
165+
assert tensor.device.type == expected
166+
167+
else:
168+
assert tensor.device.type == device_input
169+
170+
152171
@pytest.mark.parametrize('device_from, device_to', [
153172
('cpu', 'cpu'),
154173
('cpu', 'cuda'),
155174
('cuda', 'cpu'),
156175
('cuda', 'cuda'),
176+
(None, None),
177+
('auto', 'auto'),
157178
])
158179
def test_check_device_torch_tensor(self, to_device, x, device_from, device_to):
159180
if 'cuda' in (device_from, device_to) and not torch.cuda.is_available():
160181
pytest.skip()
161182

183+
prev_device = None
184+
if None in (device_from, device_to):
185+
prev_device = x.device.type
186+
162187
x = to_device(x, device=device_from)
163-
assert x.device.type == device_from
188+
self.check_device_type(x, device_from, prev_device)
164189

165190
x = to_device(x, device=device_to)
166-
assert x.device.type == device_to
191+
self.check_device_type(x, device_to, prev_device)
167192

168193
@pytest.mark.parametrize('device_from, device_to', [
169194
('cpu', 'cpu'),
170195
('cpu', 'cuda'),
171196
('cuda', 'cpu'),
172197
('cuda', 'cuda'),
198+
(None, None),
199+
('auto', 'auto'),
173200
])
174201
def test_check_device_tuple_torch_tensor(
175-
self, to_device, x, device_from, device_to):
202+
self, to_device, x_tup, device_from, device_to):
176203
if 'cuda' in (device_from, device_to) and not torch.cuda.is_available():
177204
pytest.skip()
178205

179-
x = to_device(x, device=device_from)
180-
for xi in x:
181-
assert xi.device.type == device_from
206+
prev_device = [None for _ in range(len(x_tup))]
207+
if None in (device_from, device_to):
208+
prev_device = [x.device.type for x in x_tup]
182209

183-
x = to_device(x, device=device_to)
184-
for xi in x:
185-
assert xi.device.type == device_to
210+
x_tup = to_device(x_tup, device=device_from)
211+
for idx, xi in enumerate(x_tup):
212+
self.check_device_type(xi, device_from, prev_device[idx])
213+
214+
x_tup = to_device(x_tup, device=device_to)
215+
for idx, xi in enumerate(x_tup):
216+
self.check_device_type(xi, device_to, prev_device[idx])
217+
218+
@pytest.mark.parametrize('device_from, device_to', [
219+
('cpu', 'cpu'),
220+
('cpu', 'cuda'),
221+
('cuda', 'cpu'),
222+
('cuda', 'cuda'),
223+
(None, None),
224+
('auto', 'auto'),
225+
])
226+
def test_check_device_packed_padded_sequence(
227+
self, to_device, x_pad_seq, device_from, device_to):
228+
if 'cuda' in (device_from, device_to) and not torch.cuda.is_available():
229+
pytest.skip()
230+
231+
prev_device = None
232+
if None in (device_from, device_to):
233+
prev_device = x_pad_seq.data.device.type
234+
235+
x_pad_seq = to_device(x_pad_seq, device=device_from)
236+
self.check_device_type(x_pad_seq.data, device_from, prev_device)
237+
238+
x_pad_seq = to_device(x_pad_seq, device=device_to)
239+
self.check_device_type(x_pad_seq.data, device_to, prev_device)
186240

187241

188242
class TestDuplicateItems:

skorch/utils.py

+25-4
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def to_tensor(X, device, accept_sparse=False):
8080
to_tensor_ = partial(to_tensor, device=device)
8181

8282
if is_torch_data_type(X):
83-
return X.to(device)
83+
return to_device(X, device)
8484
if isinstance(X, dict):
8585
return {key: to_tensor_(val) for key, val in X.items()}
8686
if isinstance(X, (list, tuple)):
@@ -126,12 +126,33 @@ def to_numpy(X):
126126

127127

128128
def to_device(X, device):
129-
"""Generic function to move module output(s) to a device.
129+
"""Generic function to modify the device type of the tensor(s)/module inputted.
130130
131-
Deals with X being a torch tensor or a tuple of torch tensors.
131+
Parameters
132+
----------
133+
X : input data
134+
Deals with X being a:
135+
136+
* torch tensor
137+
* tuple of torch tensors
138+
* PackedSequence instance
139+
* torch.nn.Module
140+
141+
device : str, torch.device
142+
The compute device to be used. If device="auto" it is set to
143+
"cuda" if available otherwise "cpu". If device=None, return
144+
the input unmodified
132145
133146
"""
134-
if isinstance(X, tuple):
147+
if device is None:
148+
return X
149+
150+
if device == "auto":
151+
use_cuda = torch.cuda.is_available()
152+
device = "cuda" if use_cuda else "cpu"
153+
154+
# PackedSequence class inherits from a namedtuple
155+
if isinstance(X, tuple) & (type(X) != PackedSequence):
135156
return tuple(x.to(device) for x in X)
136157
return X.to(device)
137158

0 commit comments

Comments
 (0)