Skip to content

Commit 6e9cc18

Browse files
authored
to_device: Handle nested lists/tuples recursively (#658)
to_device: Handle nested lists/tuples recursively. The previous implementation of `to_device` would break when a user decided to return a list of tensors in `forward`. This patch applies `to_device` recursively and adds support for lists in addition to tuples.
1 parent b2560a4 commit 6e9cc18

File tree

3 files changed

+110
-5
lines changed

3 files changed

+110
-5
lines changed

CHANGES.md

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2121
- Raise `FutureWarning` when using `CyclicLR` scheduler, because the default behavior has changed from taking a step every batch to taking a step every epoch. (#626)
2222
- Set train/validation on criterion if it's a PyTorch module (#621)
2323
- Don't pass `y=None` to `NeuralNet.train_split` to enable the direct use of split functions without positional `y` in their signatures. This is useful when working with unsupervised data (#605).
24+
- `to_numpy` is now able to unpack dicts and lists/tuples (#657, #658)
2425

2526
### Fixed
2627

skorch/tests/test_utils.py

+100-3
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,69 @@ def test_sparse_tensor_not_accepted_raises(self, to_tensor, device):
135135
assert exc.value.args[0] == msg
136136

137137

138+
class TestToNumpy:
    """Tests for ``skorch.utils.to_numpy``: conversion of single tensors
    and of tensors nested inside lists, tuples, and dicts."""

    @pytest.fixture
    def to_numpy(self):
        from skorch.utils import to_numpy
        return to_numpy

    @pytest.fixture
    def x_tensor(self):
        return torch.zeros(3, 4)

    @pytest.fixture
    def x_tuple(self):
        return torch.ones(3), torch.zeros(3, 4)

    @pytest.fixture
    def x_list(self):
        return [torch.ones(3), torch.zeros(3, 4)]

    @pytest.fixture
    def x_dict(self):
        # Dict with a nested tuple value to check recursive unpacking.
        return {'a': torch.ones(3), 'b': (torch.zeros(2), torch.zeros(3))}

    def compare_array_to_tensor(self, x_numpy, x_tensor):
        # Element-wise comparison of a converted array against its source
        # tensor; also checks that types and shapes survived conversion.
        assert isinstance(x_tensor, torch.Tensor)
        assert isinstance(x_numpy, np.ndarray)
        assert x_numpy.shape == x_tensor.shape
        for a, b in zip(x_numpy.flatten(), x_tensor.flatten()):
            assert np.isclose(a, b.item())

    def test_tensor(self, to_numpy, x_tensor):
        x_numpy = to_numpy(x_tensor)
        self.compare_array_to_tensor(x_numpy, x_tensor)

    def test_list(self, to_numpy, x_list):
        x_numpy = to_numpy(x_list)
        for entry_numpy, entry_torch in zip(x_numpy, x_list):
            self.compare_array_to_tensor(entry_numpy, entry_torch)

    def test_tuple(self, to_numpy, x_tuple):
        x_numpy = to_numpy(x_tuple)
        for entry_numpy, entry_torch in zip(x_numpy, x_tuple):
            self.compare_array_to_tensor(entry_numpy, entry_torch)

    def test_dict(self, to_numpy, x_dict):
        x_numpy = to_numpy(x_dict)
        self.compare_array_to_tensor(x_numpy['a'], x_dict['a'])
        self.compare_array_to_tensor(x_numpy['b'][0], x_dict['b'][0])
        self.compare_array_to_tensor(x_numpy['b'][1], x_dict['b'][1])

    @pytest.mark.parametrize('x_invalid', [
        1,
        [1, 2, 3],
        (1, 2, 3),
        {'a': 1},
    ])
    def test_invalid_inputs(self, to_numpy, x_invalid):
        # Inputs that are invalid for the scope of to_numpy.
        with pytest.raises(TypeError) as e:
            to_numpy(x_invalid)
        expected = "Cannot convert this data type to a numpy array."
        assert e.value.args[0] == expected
138201
class TestToDevice:
139202
@pytest.fixture
140203
def to_device(self):
@@ -155,13 +218,17 @@ def x_dict(self):
155218
'x': torch.zeros(3),
156219
'y': torch.ones((4, 5))
157220
}
158-
221+
159222
@pytest.fixture
160223
def x_pad_seq(self):
161224
value = torch.zeros((5, 3)).float()
162225
length = torch.as_tensor([2, 2, 1])
163226
return pack_padded_sequence(value, length)
164227

228+
@pytest.fixture
229+
def x_list(self):
230+
return [torch.zeros(3), torch.ones(2, 4)]
231+
165232
def check_device_type(self, tensor, device_input, prev_device):
166233
"""assert expected device type conditioned on the input argument for `to_device`"""
167234
if None is device_input:
@@ -214,7 +281,7 @@ def test_check_device_tuple_torch_tensor(
214281
x_tup = to_device(x_tup, device=device_to)
215282
for xi, prev_d in zip(x_tup, prev_devices):
216283
self.check_device_type(xi, device_to, prev_d)
217-
284+
218285
@pytest.mark.parametrize('device_from, device_to', [
219286
('cpu', 'cpu'),
220287
('cpu', 'cuda'),
@@ -244,7 +311,7 @@ def test_check_device_dict_torch_tensor(
244311
assert x_dict.keys() == original_x_dict.keys()
245312
for k in x_dict:
246313
assert np.allclose(x_dict[k], original_x_dict[k])
247-
314+
248315
@pytest.mark.parametrize('device_from, device_to', [
249316
('cpu', 'cpu'),
250317
('cpu', 'cuda'),
@@ -267,6 +334,36 @@ def test_check_device_packed_padded_sequence(
267334
x_pad_seq = to_device(x_pad_seq, device=device_to)
268335
self.check_device_type(x_pad_seq.data, device_to, prev_device)
269336

337+
@pytest.mark.parametrize('device_from, device_to', [
338+
('cpu', 'cpu'),
339+
('cpu', 'cuda'),
340+
('cuda', 'cpu'),
341+
('cuda', 'cuda'),
342+
(None, None),
343+
])
344+
def test_nested_data(self, to_device, x_list, device_from, device_to):
345+
# Sometimes data is nested because it would need to be padded so it's
346+
# easier to return a list of tensors with different shapes.
347+
# to_device should honor this.
348+
if 'cuda' in (device_from, device_to) and not torch.cuda.is_available():
349+
pytest.skip()
350+
351+
prev_devices = [None for _ in range(len(x_list))]
352+
if None in (device_from, device_to):
353+
prev_devices = [x.device.type for x in x_list]
354+
355+
x_list = to_device(x_list, device=device_from)
356+
assert isinstance(x_list, list)
357+
358+
for xi, prev_d in zip(x_list, prev_devices):
359+
self.check_device_type(xi, device_from, prev_d)
360+
361+
x_list = to_device(x_list, device=device_to)
362+
assert isinstance(x_list, list)
363+
364+
for xi, prev_d in zip(x_list, prev_devices):
365+
self.check_device_type(xi, device_to, prev_d)
366+
270367

271368
class TestDuplicateItems:
272369
@pytest.fixture

skorch/utils.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ def to_tensor(X, device, accept_sparse=False):
104104
def to_numpy(X):
105105
"""Generic function to convert a pytorch tensor to numpy.
106106
107+
This function tries to unpack the tensor(s) from supported
108+
data structures (e.g., dicts, lists, etc.) but doesn't go
109+
beyond.
110+
107111
Returns X when it already is a numpy array.
108112
109113
"""
@@ -116,6 +120,9 @@ def to_numpy(X):
116120
if is_pandas_ndframe(X):
117121
return X.values
118122

123+
if isinstance(X, (tuple, list)):
124+
return type(X)(to_numpy(x) for x in X)
125+
119126
if not is_torch_data_type(X):
120127
raise TypeError("Cannot convert this data type to a numpy array.")
121128

@@ -154,8 +161,8 @@ def to_device(X, device):
154161
return {key: to_device(val,device) for key, val in X.items()}
155162

156163
# PackedSequence class inherits from a namedtuple
157-
if isinstance(X, tuple) and (type(X) != PackedSequence):
158-
return tuple(x.to(device) for x in X)
164+
if isinstance(X, (tuple, list)) and (type(X) != PackedSequence):
165+
return type(X)(to_device(x, device) for x in X)
159166
return X.to(device)
160167

161168

0 commit comments

Comments (0)