Skip to content

Enhancement/auto device option #600

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `NeptuneLogger` callback for logging experiment metadata to neptune.ai
- Added DataFrameTransformer, an sklearn compatible transformer that helps working with pandas DataFrames by transforming the DataFrame into a representation that works well with neural networks (#507)
- Added `None` option to `device` which leaves the device(s) unmodified (#600)

### Changed

Expand Down
7 changes: 4 additions & 3 deletions examples/benchmarks/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
from skorch.utils import to_device
from skorch import NeuralNetClassifier
from skorch.callbacks import EpochScoring
import torch
Expand Down Expand Up @@ -141,7 +142,7 @@ def train_torch(
lr,
max_epochs,
):
model.to(device)
model = to_device(model, device)

idx_train, idx_valid = next(iter(StratifiedKFold(
5, random_state=0).split(np.arange(len(X)), y)))
Expand Down Expand Up @@ -191,7 +192,7 @@ def train_step(model, dataset, device, criterion, batch_size, optimizer):
batch_sizes = []
tic = time.time()
for Xi, yi in torch.utils.data.DataLoader(dataset, batch_size=batch_size):
Xi, yi = Xi.to(device), yi.to(device)
Xi, yi = to_device(Xi, device), to_device(yi, device)
optimizer.zero_grad()
y_pred = model(Xi)
y_pred = torch.log(y_pred)
Expand Down Expand Up @@ -221,7 +222,7 @@ def valid_step(model, dataset, device, criterion, batch_size):
for Xi, yi in torch.utils.data.DataLoader(
dataset, batch_size=batch_size,
):
Xi, yi = Xi.to(device), yi.to(device)
Xi, yi = to_device(Xi, device), to_device(yi, device)
y_pred = model(Xi)
y_pred = torch.log(y_pred)
loss = criterion(y_pred, yi)
Expand Down
4 changes: 3 additions & 1 deletion examples/word_language_model/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

from skorch.utils import to_device

import torch
from torch.autograd import Variable

Expand Down Expand Up @@ -70,7 +72,7 @@ def batchify(self, data, bsz):
data = data.narrow(0, 0, nbatch * bsz)
# Evenly divide the data across the bsz batches.
data = data.view(bsz, -1).t().contiguous()
return data.to(self.device)
return to_device(data, self.device)

def get_batch(self, i):
seq_len = min(self.bptt, len(self.batches) - 1 - i)
Expand Down
7 changes: 4 additions & 3 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ class NeuralNet:
device : str, torch.device (default='cpu')
The compute device to be used. If set to 'cuda', data in torch
tensors will be pushed to cuda tensors before being sent to the
module.
module. If set to None, then all compute devices will be left
unmodified.

Attributes
----------
Expand Down Expand Up @@ -431,7 +432,7 @@ def initialize_criterion(self):
criterion_params = self._get_params_for('criterion')
self.criterion_ = self.criterion(**criterion_params)
if isinstance(self.criterion_, torch.nn.Module):
self.criterion_ = self.criterion_.to(self.device)
self.criterion_ = to_device(self.criterion_, self.device)
return self

def _format_reinit_msg(self, name, kwargs=None, triggered_directly=True):
Expand Down Expand Up @@ -472,7 +473,7 @@ def initialize_module(self):

module = module(**kwargs)

self.module_ = module.to(self.device)
self.module_ = to_device(module, self.device)
return self

def _is_virtual_param(self, key):
Expand Down
64 changes: 55 additions & 9 deletions skorch/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,40 +149,86 @@ def x(self):
def x_tup(self):
return torch.zeros(3), torch.ones((4, 5))

@pytest.fixture
def x_pad_seq(self):
    """A small PackedSequence built from a 5x3 float tensor with
    batch lengths [2, 2, 1]."""
    padded = torch.zeros((5, 3)).float()
    lengths = torch.as_tensor([2, 2, 1])
    return pack_padded_sequence(padded, lengths)

def check_device_type(self, tensor, device_input, prev_device):
    """Assert that ``tensor`` lives on the device implied by ``device_input``.

    Parameters
    ----------
    tensor : torch.Tensor
      The tensor whose placement is verified.
    device_input : str or None
      The device argument that was passed to ``to_device``. ``None``
      means "leave the device unmodified".
    prev_device : str or None
      The device type the tensor was on before the call; only
      consulted when ``device_input`` is None.
    """
    if device_input is None:
        # device=None is a no-op, so the tensor must stay where it was
        assert tensor.device.type == prev_device
    else:
        assert tensor.device.type == device_input

@pytest.mark.parametrize('device_from, device_to', [
    ('cpu', 'cpu'),
    ('cpu', 'cuda'),
    ('cuda', 'cpu'),
    ('cuda', 'cuda'),
    (None, None),
])
def test_check_device_torch_tensor(self, to_device, x, device_from, device_to):
    """to_device moves a single tensor; device=None leaves it untouched."""
    if 'cuda' in (device_from, device_to) and not torch.cuda.is_available():
        pytest.skip()

    # remember the starting device so the None case can assert "unchanged"
    prev_device = None
    if None in (device_from, device_to):
        prev_device = x.device.type

    x = to_device(x, device=device_from)
    self.check_device_type(x, device_from, prev_device)

    x = to_device(x, device=device_to)
    self.check_device_type(x, device_to, prev_device)

@pytest.mark.parametrize('device_from, device_to', [
    ('cpu', 'cpu'),
    ('cpu', 'cuda'),
    ('cuda', 'cpu'),
    ('cuda', 'cuda'),
    (None, None),
])
def test_check_device_tuple_torch_tensor(
        self, to_device, x_tup, device_from, device_to):
    """to_device moves every tensor of a tuple; device=None leaves each
    tensor on its original device."""
    if 'cuda' in (device_from, device_to) and not torch.cuda.is_available():
        pytest.skip()

    # remember each tensor's starting device for the None (no-op) case
    prev_devices = [None for _ in range(len(x_tup))]
    if None in (device_from, device_to):
        prev_devices = [x.device.type for x in x_tup]

    x_tup = to_device(x_tup, device=device_from)
    for xi, prev_d in zip(x_tup, prev_devices):
        self.check_device_type(xi, device_from, prev_d)

    x_tup = to_device(x_tup, device=device_to)
    for xi, prev_d in zip(x_tup, prev_devices):
        self.check_device_type(xi, device_to, prev_d)

@pytest.mark.parametrize('device_from, device_to', [
    ('cpu', 'cpu'),
    ('cpu', 'cuda'),
    ('cuda', 'cpu'),
    ('cuda', 'cuda'),
    (None, None),
])
def test_check_device_packed_padded_sequence(
        self, to_device, x_pad_seq, device_from, device_to):
    """to_device relocates a PackedSequence's data tensor; device=None
    leaves it on its original device."""
    wants_cuda = 'cuda' in (device_from, device_to)
    if wants_cuda and not torch.cuda.is_available():
        pytest.skip()

    # record the starting device so the None (no-op) case can be checked
    original_device = None
    if None in (device_from, device_to):
        original_device = x_pad_seq.data.device.type

    for target in (device_from, device_to):
        x_pad_seq = to_device(x_pad_seq, device=target)
        self.check_device_type(x_pad_seq.data, target, original_device)


class TestDuplicateItems:
Expand Down
27 changes: 23 additions & 4 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def to_tensor(X, device, accept_sparse=False):
to_tensor_ = partial(to_tensor, device=device)

if is_torch_data_type(X):
return X.to(device)
return to_device(X, device)
if isinstance(X, dict):
return {key: to_tensor_(val) for key, val in X.items()}
if isinstance(X, (list, tuple)):
Expand Down Expand Up @@ -126,12 +126,28 @@ def to_numpy(X):


def to_device(X, device):
    """Generic function to modify the device type of the tensor(s) or module.

    Parameters
    ----------
    X : input data
      Deals with X being a:

       * torch tensor
       * tuple of torch tensors
       * PackedSequence instance
       * torch.nn.Module

    device : str, torch.device
      The compute device to be used. If device=None, return the input
      unmodified.

    """
    if device is None:
        # leave the device(s) of X unmodified
        return X

    # PackedSequence inherits from a namedtuple, so it would otherwise be
    # caught by the tuple branch below; it must be moved as a whole instead.
    if isinstance(X, tuple) and not isinstance(X, PackedSequence):
        return tuple(x.to(device) for x in X)
    return X.to(device)

Expand Down Expand Up @@ -482,6 +498,9 @@ def get_map_location(target_device, fallback_device='cpu'):
"""Determine the location to map loaded data (e.g., weights)
for a given target device (e.g. 'cuda').
"""
if target_device is None:
target_device = fallback_device

map_location = torch.device(target_device)

# The user wants to use CUDA but there is no CUDA device
Expand Down