Skip to content

Commit c50f3bb

Browse files
benjamin-workottonemo
authored andcommitted
Rewrite History to not use any recursion. (skorch-dev#312)
* Rewrite History to not use any recursion. Instead unroll the successive indexing steps and perform them backwards, i.e. starting with batches and following with epochs. * Raise a KeyError when history indexing deeper than 4. * Add a benchmark script to test History. * Address comments by ottonemo. * clarifying comments * deprecation warning
1 parent d6e0b9b commit c50f3bb

File tree

3 files changed

+340
-69
lines changed

3 files changed

+340
-69
lines changed

examples/benchmarks/history.py

+145
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""Benchmark to test time and memory performance of History.
2+
3+
Before #312, the timing would be roughly 5 sec and memory usage would
4+
triple. After #312, the timing would be roughly 2 sec and memory usage
5+
roughly constant.
6+
7+
For the reasons, see #306.
8+
9+
"""
10+
11+
from pprint import pprint
12+
import time
13+
14+
import numpy as np
15+
from sklearn.datasets import make_classification
16+
import torch
17+
18+
from skorch import NeuralNetClassifier
19+
from skorch.callbacks import Callback
20+
from skorch.toy import make_classifier
21+
22+
23+
side_effects = []
24+
25+
26+
class TriggerKeyError(Callback):
27+
def on_batch_end(self, net, **kwargs):
28+
try:
29+
net.history[-1, 'batches', -1, 'foobar']
30+
except Exception as e:
31+
pass
32+
33+
34+
class PrintMemory(Callback):
35+
def on_batch_end(self, net, **kwargs):
36+
side_effects.append((
37+
torch.cuda.memory_allocated() / 1e6,
38+
torch.cuda.memory_cached() / 1e6
39+
))
40+
41+
42+
def train():
43+
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
44+
X = X.astype(np.float32)
45+
y = y.astype(np.int64)
46+
47+
module = make_classifier(input_units=20)
48+
49+
net = NeuralNetClassifier(
50+
module,
51+
max_epochs=10,
52+
lr=0.1,
53+
callbacks=[TriggerKeyError(), PrintMemory()],
54+
device='cuda',
55+
)
56+
57+
return net.fit(X, y)
58+
59+
60+
def safe_slice(history, keys):
61+
# catch errors
62+
for key in keys:
63+
try:
64+
history[key]
65+
except (KeyError, IndexError):
66+
pass
67+
68+
69+
def performance_history(history):
70+
# SUCCESSFUL
71+
# level 0
72+
for i in range(len(history)):
73+
history[i]
74+
75+
# level 1
76+
keys = tuple(history[0].keys())
77+
history[0, keys]
78+
history[:, keys]
79+
for key in keys:
80+
history[0, key]
81+
history[:, key]
82+
83+
# level 2
84+
for i in range(len(history[0, 'batches'])):
85+
history[0, 'batches', i]
86+
history[:, 'batches', i]
87+
history[:, 'batches', :]
88+
89+
# level 3
90+
keys = tuple(history[0, 'batches', 0].keys())
91+
history[0, 'batches', 0, keys]
92+
history[:, 'batches', 0, keys]
93+
history[0, 'batches', :, keys]
94+
history[:, 'batches', :, keys]
95+
for key in history[0, 'batches', 0]:
96+
history[0, 'batches', 0, key]
97+
history[:, 'batches', 0, key]
98+
history[0, 'batches', :, key]
99+
history[:, 'batches', :, key]
100+
101+
# KEY ERRORS
102+
# level 0
103+
safe_slice(history, [100000])
104+
105+
# level 1
106+
safe_slice(history, [np.s_[0, 'foo'], np.s_[:, 'foo']])
107+
108+
# level 2
109+
safe_slice(history, [
110+
np.s_[0, 'batches', 0],
111+
np.s_[:, 'batches', 0],
112+
np.s_[0, 'batches', :],
113+
np.s_[:, 'batches', :],
114+
])
115+
116+
# level 3
117+
safe_slice(history, [
118+
np.s_[0, 'batches', 0, 'foo'],
119+
np.s_[:, 'batches', 0, 'foo'],
120+
np.s_[0, 'batches', :, 'foo'],
121+
np.s_[:, 'batches', :, 'foo'],
122+
np.s_[0, 'batches', 0, ('foo', 'bar')],
123+
np.s_[:, 'batches', 0, ('foo', 'bar')],
124+
np.s_[0, 'batches', :, ('foo', 'bar')],
125+
np.s_[:, 'batches', :, ('foo', 'bar')],
126+
])
127+
128+
if __name__ == '__main__':
129+
net = train()
130+
tic = time.time()
131+
for _ in range(1000):
132+
performance_history(net.history)
133+
toc = time.time()
134+
print("Time for performing 1000 runs: {:.5f} sec.".format(toc - tic))
135+
assert toc - tic < 10, "accessing history is too slow"
136+
137+
print("Allocated / cached memory")
138+
pprint(side_effects)
139+
140+
mem_start = side_effects[0][0]
141+
mem_end = side_effects[-1][0]
142+
143+
print("Memory epoch 1: {:.4f}, last epoch: {:.4f}".format(
144+
mem_start, mem_end))
145+
assert np.isclose(mem_start, mem_end, rtol=1/3), "memory use should be similar"

skorch/history.py

+121-68
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,71 @@
11
"""Contains history class and helper functions."""
22

3+
import warnings
4+
35

46
# pylint: disable=invalid-name
5-
class _missingno:
6-
def __init__(self, e):
7-
self.e = e
8-
9-
def __repr__(self):
10-
return 'missingno'
11-
12-
13-
def _incomplete_mapper(x):
14-
for xs in x:
15-
# pylint: disable=unidiomatic-typecheck
16-
if type(xs) is _missingno:
17-
return xs
18-
return x
19-
20-
21-
# pylint: disable=missing-docstring
22-
def partial_index(l, idx):
23-
needs_unrolling = (
24-
isinstance(l, list) and len(l) > 0 and isinstance(l[0], list))
25-
types = int, tuple, list, slice
26-
needs_indirection = isinstance(l, list) and not isinstance(idx, types)
27-
28-
if needs_unrolling or needs_indirection:
29-
return [partial_index(n, idx) for n in l]
30-
31-
# join results of multiple indices
32-
if isinstance(idx, (tuple, list)):
33-
zz = [partial_index(l, n) for n in idx]
34-
if isinstance(l, list):
35-
total_join = zip(*zz)
36-
inner_join = list(map(_incomplete_mapper, total_join))
37-
else:
38-
total_join = tuple(zz)
39-
inner_join = _incomplete_mapper(total_join)
40-
return inner_join
41-
42-
try:
43-
return l[idx]
44-
except KeyError as e:
45-
return _missingno(e)
46-
47-
48-
# pylint: disable=missing-docstring
49-
def filter_missing(x):
50-
if isinstance(x, list):
51-
children = [filter_missing(n) for n in x]
52-
# pylint: disable=unidiomatic-typecheck
53-
filtered = list(filter(lambda x: type(x) != _missingno, children))
54-
55-
if children and not filtered:
56-
# pylint: disable=unidiomatic-typecheck
57-
return next(filter(lambda x: type(x) == _missingno, children))
58-
return filtered
59-
return x
7+
class _none:
8+
"""Special placeholder since ``None`` is a valid value."""
9+
10+
11+
def _not_none(items):
12+
"""Whether the item is a placeholder or contains a placeholder."""
13+
if not isinstance(items, (tuple, list)):
14+
items = (items,)
15+
return all(item is not _none for item in items)
16+
17+
18+
def _filter_none(items):
19+
"""Filter special placeholder value, preserves sequence type."""
20+
type_ = list if isinstance(items, list) else tuple
21+
return type_(filter(_not_none, items))
22+
23+
24+
def _getitem(item, i):
25+
"""Extract value or values from dicts.
26+
27+
Covers the case of a single key or multiple keys. If not found,
28+
return placeholders instead.
29+
30+
"""
31+
if not isinstance(i, (tuple, list)):
32+
return item.get(i, _none)
33+
type_ = list if isinstance(item, list) else tuple
34+
return type_(item.get(j, _none) for j in i)
35+
36+
37+
def _unpack_index(i):
38+
"""Unpack index and return exactly four elements.
39+
40+
If index is more shallow than 4, return None for trailing
41+
dimensions. If index is deeper than 4, raise a KeyError.
42+
43+
"""
44+
if len(i) > 4:
45+
raise KeyError(
46+
"Tried to index history with {} indices but only "
47+
"4 indices are possible.".format(len(i)))
48+
49+
# fill trailing indices with None
50+
i_e, k_e, i_b, k_b = i + tuple([None] * (4 - len(i)))
51+
52+
# handle special case of
53+
# history[j, 'batches', somekey]
54+
# which should really be
55+
# history[j, 'batches', :, somekey]
56+
if i_b is not None and not isinstance(i_b, (int, slice)):
57+
if k_b is not None:
58+
raise KeyError("The last argument '{}' is invalid; it must be a "
59+
"string or tuple of strings.".format(k_b))
60+
warnings.warn(
61+
"Argument 3 to history slicing must be of type int or slice, e.g. "
62+
"history[:, 'batches', 'train_loss'] should be "
63+
"history[:, 'batches', :, 'train_loss'].",
64+
DeprecationWarning,
65+
)
66+
i_b, k_b = slice(None), i_b
67+
68+
return i_e, k_e, i_b, k_b
6069

6170

6271
class History(list):
@@ -128,6 +137,7 @@ def new_epoch(self):
128137

129138
def new_batch(self):
130139
"""Register a new batch row for the current epoch."""
140+
# pylint: disable=invalid-sequence-index
131141
self[-1]['batches'].append({})
132142

133143
def record(self, attr, value):
@@ -145,24 +155,67 @@ def record_batch(self, attr, value):
145155
batch.
146156
147157
"""
158+
# pylint: disable=invalid-sequence-index
148159
self[-1]['batches'][-1][attr] = value
149160

150161
def to_list(self):
151162
"""Return history object as a list."""
152163
return list(self)
153164

154165
def __getitem__(self, i):
166+
# This implementation resolves indexing backwards,
167+
# i.e. starting from the batches, then progressing to the
168+
# epochs.
155169
if isinstance(i, (int, slice)):
156-
return super().__getitem__(i)
157-
158-
x = self
159-
if isinstance(i, tuple):
160-
for part in i:
161-
x_dirty = partial_index(x, part)
162-
x = filter_missing(x_dirty)
163-
# pylint: disable=unidiomatic-typecheck
164-
if type(x) is _missingno:
165-
raise x.e
166-
return x
167-
raise ValueError("Invalid parameter type passed to index. "
168-
"Pass string, int or tuple.")
170+
i = (i,)
171+
172+
# i_e: index epoch, k_e: key epoch
173+
# i_b: index batch, k_b: key batch
174+
i_e, k_e, i_b, k_b = _unpack_index(i)
175+
keyerror_msg = "Key '{}' was not found in history."
176+
177+
if i_b is not None and k_e != 'batches':
178+
raise KeyError("History indexing beyond the 2nd level is "
179+
"only possible if key 'batches' is used, "
180+
"found key '{}'.".format(k_e))
181+
182+
items = self.to_list()
183+
184+
# extract indices of batches
185+
# handles: history[..., k_e, i_b]
186+
if i_b is not None:
187+
items = [row[k_e][i_b] for row in items]
188+
189+
# extract keys of batches
190+
# handles: history[..., k_e, i_b][k_b]
191+
if k_b is not None:
192+
items = [
193+
_filter_none([_getitem(b, k_b) for b in batches])
194+
if isinstance(batches, (list, tuple))
195+
else _getitem(batches, k_b)
196+
for batches in items
197+
]
198+
# get rid of empty batches
199+
items = [b for b in items if b not in (_none, [], ())]
200+
if not _filter_none(items):
201+
# all rows contained _none or were empty
202+
raise KeyError(keyerror_msg.format(k_b))
203+
204+
# extract epoch-level values, but only if not already done
205+
# handles: history[..., k_e]
206+
if (k_e is not None) and (i_b is None):
207+
items = [_getitem(batches, k_e)
208+
for batches in items]
209+
if not _filter_none(items):
210+
raise KeyError(keyerror_msg.format(k_e))
211+
212+
# extract the epochs
213+
# handles: history[i_b, ..., ..., ...]
214+
if i_e is not None:
215+
items = items[i_e]
216+
if isinstance(i_e, slice):
217+
items = _filter_none(items)
218+
if items is _none:
219+
raise KeyError(keyerror_msg.format(k_e))
220+
221+
return items

0 commit comments

Comments
 (0)