Skip to content

Feature: support dict-like input for to_numpy and to_device #657

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 29, 2020
39 changes: 38 additions & 1 deletion skorch/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.nn.utils.rnn import pack_padded_sequence

from skorch.tests.conftest import pandas_installed

from copy import deepcopy

class TestToTensor:
@pytest.fixture
Expand Down Expand Up @@ -149,6 +149,13 @@ def x(self):
def x_tup(self):
return torch.zeros(3), torch.ones((4, 5))

@pytest.fixture
def x_dict(self):
return {
'x': torch.zeros(3),
'y': torch.ones((4, 5))
}

@pytest.fixture
def x_pad_seq(self):
value = torch.zeros((5, 3)).float()
Expand Down Expand Up @@ -207,7 +214,37 @@ def test_check_device_tuple_torch_tensor(
x_tup = to_device(x_tup, device=device_to)
for xi, prev_d in zip(x_tup, prev_devices):
self.check_device_type(xi, device_to, prev_d)

@pytest.mark.parametrize('device_from, device_to', [
('cpu', 'cpu'),
('cpu', 'cuda'),
('cuda', 'cpu'),
('cuda', 'cuda'),
(None, None),
])
def test_check_device_dict_torch_tensor(
self, to_device, x_dict, device_from, device_to):
if 'cuda' in (device_from, device_to) and not torch.cuda.is_available():
pytest.skip()

original_x_dict = deepcopy(x_dict)

prev_devices=[None for _ in range(len(list(x_dict.keys())))]
if None in (device_from, device_to):
prev_devices = [x.device.type for x in x_dict.values()]

new_x_dict = to_device(x_dict, device=device_from)
for xi, prev_d in zip(new_x_dict.values(), prev_devices):
self.check_device_type(xi, device_from, prev_d)

new_x_dict = to_device(new_x_dict, device=device_to)
for xi, prev_d in zip(new_x_dict.values(), prev_devices):
self.check_device_type(xi, device_to, prev_d)

assert x_dict.keys() == original_x_dict.keys()
for k in x_dict:
assert np.allclose(x_dict[k], original_x_dict[k])

@pytest.mark.parametrize('device_from, device_to', [
('cpu', 'cpu'),
('cpu', 'cuda'),
Expand Down
7 changes: 7 additions & 0 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def to_numpy(X):
if isinstance(X, np.ndarray):
return X

if isinstance(X, dict):
return {key: to_numpy(val) for key, val in X.items()}

if is_pandas_ndframe(X):
return X.values

Expand All @@ -135,6 +138,7 @@ def to_device(X, device):

* torch tensor
* tuple of torch tensors
* dict of torch tensors
* PackSequence instance
* torch.nn.Module

Expand All @@ -146,6 +150,9 @@ def to_device(X, device):
if device is None:
return X

if isinstance(X, dict):
return {key: to_device(val,device) for key, val in X.items()}

# PackedSequence class inherits from a namedtuple
if isinstance(X, tuple) and (type(X) != PackedSequence):
return tuple(x.to(device) for x in X)
Expand Down