1
1
"""Contains history class and helper functions."""
2
2
3
+ import warnings
4
+
3
5
4
6
# pylint: disable=invalid-name
5
- class _missingno :
6
- def __init__ (self , e ):
7
- self .e = e
8
-
9
- def __repr__ (self ):
10
- return 'missingno'
11
-
12
-
13
- def _incomplete_mapper (x ):
14
- for xs in x :
15
- # pylint: disable=unidiomatic-typecheck
16
- if type (xs ) is _missingno :
17
- return xs
18
- return x
19
-
20
-
21
- # pylint: disable=missing-docstring
22
- def partial_index (l , idx ):
23
- needs_unrolling = (
24
- isinstance (l , list ) and len (l ) > 0 and isinstance (l [0 ], list ))
25
- types = int , tuple , list , slice
26
- needs_indirection = isinstance (l , list ) and not isinstance (idx , types )
27
-
28
- if needs_unrolling or needs_indirection :
29
- return [partial_index (n , idx ) for n in l ]
30
-
31
- # join results of multiple indices
32
- if isinstance (idx , (tuple , list )):
33
- zz = [partial_index (l , n ) for n in idx ]
34
- if isinstance (l , list ):
35
- total_join = zip (* zz )
36
- inner_join = list (map (_incomplete_mapper , total_join ))
37
- else :
38
- total_join = tuple (zz )
39
- inner_join = _incomplete_mapper (total_join )
40
- return inner_join
41
-
42
- try :
43
- return l [idx ]
44
- except KeyError as e :
45
- return _missingno (e )
46
-
47
-
48
- # pylint: disable=missing-docstring
49
- def filter_missing (x ):
50
- if isinstance (x , list ):
51
- children = [filter_missing (n ) for n in x ]
52
- # pylint: disable=unidiomatic-typecheck
53
- filtered = list (filter (lambda x : type (x ) != _missingno , children ))
54
-
55
- if children and not filtered :
56
- # pylint: disable=unidiomatic-typecheck
57
- return next (filter (lambda x : type (x ) == _missingno , children ))
58
- return filtered
59
- return x
7
+ class _none :
8
+ """Special placeholder since ``None`` is a valid value."""
9
+
10
+
11
+ def _not_none (items ):
12
+ """Whether the item is a placeholder or contains a placeholder."""
13
+ if not isinstance (items , (tuple , list )):
14
+ items = (items ,)
15
+ return all (item is not _none for item in items )
16
+
17
+
18
+ def _filter_none (items ):
19
+ """Filter special placeholder value, preserves sequence type."""
20
+ type_ = list if isinstance (items , list ) else tuple
21
+ return type_ (filter (_not_none , items ))
22
+
23
+
24
+ def _getitem (item , i ):
25
+ """Extract value or values from dicts.
26
+
27
+ Covers the case of a single key or multiple keys. If not found,
28
+ return placeholders instead.
29
+
30
+ """
31
+ if not isinstance (i , (tuple , list )):
32
+ return item .get (i , _none )
33
+ type_ = list if isinstance (item , list ) else tuple
34
+ return type_ (item .get (j , _none ) for j in i )
35
+
36
+
37
+ def _unpack_index (i ):
38
+ """Unpack index and return exactly four elements.
39
+
40
+ If index is more shallow than 4, return None for trailing
41
+ dimensions. If index is deeper than 4, raise a KeyError.
42
+
43
+ """
44
+ if len (i ) > 4 :
45
+ raise KeyError (
46
+ "Tried to index history with {} indices but only "
47
+ "4 indices are possible." .format (len (i )))
48
+
49
+ # fill trailing indices with None
50
+ i_e , k_e , i_b , k_b = i + tuple ([None ] * (4 - len (i )))
51
+
52
+ # handle special case of
53
+ # history[j, 'batches', somekey]
54
+ # which should really be
55
+ # history[j, 'batches', :, somekey]
56
+ if i_b is not None and not isinstance (i_b , (int , slice )):
57
+ if k_b is not None :
58
+ raise KeyError ("The last argument '{}' is invalid; it must be a "
59
+ "string or tuple of strings." .format (k_b ))
60
+ warnings .warn (
61
+ "Argument 3 to history slicing must be of type int or slice, e.g. "
62
+ "history[:, 'batches', 'train_loss'] should be "
63
+ "history[:, 'batches', :, 'train_loss']." ,
64
+ DeprecationWarning ,
65
+ )
66
+ i_b , k_b = slice (None ), i_b
67
+
68
+ return i_e , k_e , i_b , k_b
60
69
61
70
62
71
class History (list ):
@@ -128,6 +137,7 @@ def new_epoch(self):
128
137
129
138
def new_batch (self ):
130
139
"""Register a new batch row for the current epoch."""
140
+ # pylint: disable=invalid-sequence-index
131
141
self [- 1 ]['batches' ].append ({})
132
142
133
143
def record (self , attr , value ):
@@ -145,24 +155,67 @@ def record_batch(self, attr, value):
145
155
batch.
146
156
147
157
"""
158
+ # pylint: disable=invalid-sequence-index
148
159
self [- 1 ]['batches' ][- 1 ][attr ] = value
149
160
150
161
def to_list (self ):
151
162
"""Return history object as a list."""
152
163
return list (self )
153
164
154
165
def __getitem__ (self , i ):
166
+ # This implementation resolves indexing backwards,
167
+ # i.e. starting from the batches, then progressing to the
168
+ # epochs.
155
169
if isinstance (i , (int , slice )):
156
- return super ().__getitem__ (i )
157
-
158
- x = self
159
- if isinstance (i , tuple ):
160
- for part in i :
161
- x_dirty = partial_index (x , part )
162
- x = filter_missing (x_dirty )
163
- # pylint: disable=unidiomatic-typecheck
164
- if type (x ) is _missingno :
165
- raise x .e
166
- return x
167
- raise ValueError ("Invalid parameter type passed to index. "
168
- "Pass string, int or tuple." )
170
+ i = (i ,)
171
+
172
+ # i_e: index epoch, k_e: key epoch
173
+ # i_b: index batch, k_b: key batch
174
+ i_e , k_e , i_b , k_b = _unpack_index (i )
175
+ keyerror_msg = "Key '{}' was not found in history."
176
+
177
+ if i_b is not None and k_e != 'batches' :
178
+ raise KeyError ("History indexing beyond the 2nd level is "
179
+ "only possible if key 'batches' is used, "
180
+ "found key '{}'." .format (k_e ))
181
+
182
+ items = self .to_list ()
183
+
184
+ # extract indices of batches
185
+ # handles: history[..., k_e, i_b]
186
+ if i_b is not None :
187
+ items = [row [k_e ][i_b ] for row in items ]
188
+
189
+ # extract keys of batches
190
+ # handles: history[..., k_e, i_b][k_b]
191
+ if k_b is not None :
192
+ items = [
193
+ _filter_none ([_getitem (b , k_b ) for b in batches ])
194
+ if isinstance (batches , (list , tuple ))
195
+ else _getitem (batches , k_b )
196
+ for batches in items
197
+ ]
198
+ # get rid of empty batches
199
+ items = [b for b in items if b not in (_none , [], ())]
200
+ if not _filter_none (items ):
201
+ # all rows contained _none or were empty
202
+ raise KeyError (keyerror_msg .format (k_b ))
203
+
204
+ # extract epoch-level values, but only if not already done
205
+ # handles: history[..., k_e]
206
+ if (k_e is not None ) and (i_b is None ):
207
+ items = [_getitem (batches , k_e )
208
+ for batches in items ]
209
+ if not _filter_none (items ):
210
+ raise KeyError (keyerror_msg .format (k_e ))
211
+
212
+ # extract the epochs
213
+ # handles: history[i_b, ..., ..., ...]
214
+ if i_e is not None :
215
+ items = items [i_e ]
216
+ if isinstance (i_e , slice ):
217
+ items = _filter_none (items )
218
+ if items is _none :
219
+ raise KeyError (keyerror_msg .format (k_e ))
220
+
221
+ return items
0 commit comments