Commit 2f00e25

ottonemo authored and benjamin-work committed
Implement a NeuralNetRegressor.
1 parent 28fc5a3 commit 2f00e25

9 files changed: +684, -27 lines changed

README.md

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@ A scikit-learn compatible neural network library that wraps pytorch.
 
 ## Example
 
-```
+```python
 import numpy as np
 from sklearn.datasets import make_classification
 import torch
@@ -49,7 +49,7 @@ y_proba = net.predict_proba(X)
 
 In an sklearn Pipeline:
 
-```
+```python
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import StandardScaler
 
@@ -66,7 +66,7 @@ y_proba = pipe.predict_proba(X)
 
 With grid search
 
-```
+```python
 from sklearn.model_selection import GridSearchCV
 
 

environment.yml

Lines changed: 2 additions & 0 deletions
@@ -21,3 +21,5 @@ dependencies:
 - xz=5.2.2=1
 - yaml=0.1.6=0
 - zlib=1.2.8=3
+- pip:
+  - tabulate==0.7.7
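
The tabulate pin is here because the new PrintLog callback in inferno/callbacks.py (below) renders its per-epoch rows through it. A minimal sketch of the kind of call PrintLog ends up making; the headers and row values are invented for illustration:

```python
# Illustration of a tabulate call like the one PrintLog builds internally;
# headers and row values here are made-up examples, not library output.
from tabulate import tabulate

headers = ['epoch', 'train_loss', 'valid_loss', 'dur']
rows = [[1, 0.6931, 0.6922, 1.2]]
print(tabulate(rows, headers=headers, tablefmt='simple', floatfmt='.4f'))
# prints something like:
#   epoch    train_loss    valid_loss     dur
# -------  ------------  ------------  ------
#       1        0.6931        0.6922  1.2000
```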

inferno/callbacks.py

Lines changed: 83 additions & 1 deletion
@@ -1,12 +1,18 @@
+from itertools import cycle
+from numbers import Number
 import operator
+import sys
 import time
 
 import numpy as np
 from sklearn.base import BaseEstimator
 from sklearn import metrics
+from tabulate import tabulate
 
+from inferno.utils import Ansi
 from inferno.utils import to_numpy
 from inferno.utils import to_var
+from inferno.utils import check_history_slice
 
 
 class Callback:
@@ -139,11 +145,87 @@ def on_batch_end(self, net, X, y, train):
         if isinstance(self.scoring, str):  # TODO: make py2.7 compatible
             # scoring is a string
             y = self.target_extractor(y)
-            scorer = getattr(metrics, self.scoring)
+            try:
+                scorer = getattr(metrics, self.scoring)
+            except AttributeError:
+                raise NameError("Metric with name '{}' does not exist, "
+                                "use a valid sklearn metric name."
+                                "".format(self.scoring))
             y_pred = self.pred_extractor(net.module_(to_var(X)))
             score = scorer(y, y_pred)
         else:
             # scoring is a function
             score = self.scoring(net, X, y)
 
         net.history.record_batch(self.name, score)
+
+
+class PrintLog(Callback):
+    def __init__(
+            self,
+            keys=('epoch', 'train_loss', 'valid_loss', 'train_loss_best',
+                  'valid_loss_best', 'dur'),
+            sink=print,
+            tablefmt='simple',
+            floatfmt='.4f',
+    ):
+        self.keys = (keys,) if isinstance(keys, str) else keys
+        self.sink = sink
+        self.tablefmt = tablefmt
+        self.floatfmt = floatfmt
+
+    def initialize(self):
+        self.first_iteration_ = True
+        self.idx_ = {key: i for i, key in enumerate(self.keys)}
+        return self
+
+    def format_row(self, row):
+        row_formatted = []
+        colors = cycle(Ansi)
+
+        for key, item in zip(self.keys, row):
+            if key.endswith('_best'):
+                continue
+
+            if not isinstance(item, Number):
+                row_formatted.append(item)
+                continue
+
+            color = next(colors)
+            # if numeric, there could be a 'best' key
+            idx_best = self.idx_.get(key + '_best')
+
+            is_integer = float(item).is_integer()
+            template = '{}' if is_integer else '{:' + self.floatfmt + '}'
+
+            if (idx_best is not None) and row[idx_best]:
+                template = color.value + template + Ansi.ENDC.value
+            row_formatted.append(template.format(item))
+
+        return row_formatted
+
+    def table(self, data):
+        formatted = [self.format_row(row) for row in data]
+        headers = [key for key in self.keys if not key.endswith('_best')]
+        return tabulate(
+            formatted,
+            headers=headers,
+            tablefmt=self.tablefmt,
+            floatfmt=self.floatfmt,
+        )
+
+    def on_epoch_end(self, net, *args, **kwargs):
+        sl = slice(-1, None), self.keys
+        check_history_slice(net.history, sl)
+        data = net.history[sl]
+        tabulated = self.table(data)
+
+        if self.first_iteration_:
+            header, lines = tabulated.split('\n', 2)[:2]
+            self.sink(header)
+            self.sink(lines)
+            self.first_iteration_ = False
+
+        self.sink(tabulated.rsplit('\n', 1)[-1])
+        if self.sink is print:
+            sys.stdout.flush()
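
To make the new callback concrete, here is a hypothetical snippet (not part of the commit) that drives PrintLog by hand against a mocked net, mirroring the pattern used in the conftest.py helper further down. It assumes History is importable from inferno.net (as conftest.py does) and uses its new_epoch()/record() API:

```python
# Hypothetical driving of PrintLog with a hand-built history; History and
# its new_epoch()/record() API are assumed from inferno.net, per conftest.py.
from unittest.mock import Mock

from inferno.callbacks import PrintLog
from inferno.net import History

lines = []
print_log = PrintLog(
    keys=('epoch', 'train_loss', 'dur'),
    sink=lines.append,  # collect output lines instead of printing
).initialize()

net = Mock()
net.history = History()
net.history.new_epoch()
net.history.record('epoch', 1)
net.history.record('train_loss', 0.6931)
net.history.record('dur', 1.2)

print_log.on_epoch_end(net)
# the first call emits header, separator, and one data row
assert len(lines) == 3
```

On later epochs only the data row is emitted, since first_iteration_ is flipped after the first call; keys ending in '_best' are consumed for coloring and never printed as columns.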

inferno/net.py

Lines changed: 55 additions & 18 deletions
@@ -13,7 +13,9 @@
 from inferno.callbacks import BestLoss
 from inferno.callbacks import Callback
 from inferno.callbacks import EpochTimer
+from inferno.callbacks import PrintLog
 from inferno.callbacks import Scoring
+from inferno.utils import get_dim
 from inferno.utils import to_numpy
 from inferno.utils import to_tensor
 from inferno.utils import to_var
@@ -59,15 +61,20 @@ def partial_index(l, idx):
             return [partial_index(n, idx) for n in l]
 
         # join results of multiple indices
-        if type(idx) is tuple or type(idx) is list:
+        if isinstance(idx, (tuple, list)):
             def incomplete_mapper(x):
                 for xs in x:
                     if type(xs) is __missingno:
                         return xs
                 return x
-            total_join = zip(*[partial_index(l, n) for n in idx])
-            inner_join = map(incomplete_mapper, total_join)
-            return list(inner_join)
+            zz = [partial_index(l, n) for n in idx]
+            if is_list_like(l):
+                total_join = zip(*zz)
+                inner_join = list(map(incomplete_mapper, total_join))
+            else:
+                total_join = tuple(zz)
+                inner_join = incomplete_mapper(total_join)
+            return inner_join
 
         try:
             return l[idx]
@@ -104,6 +111,7 @@ class NeuralNet(Callback):
         ('epoch_timer', EpochTimer),
         ('average_loss', AverageLoss),
         ('best_loss', BestLoss),
+        ('print_log', PrintLog),
     ]
 
     def __init__(
@@ -270,24 +278,23 @@ def fit(self, X, y, **fit_params):
 
         return self
 
-    def predict_proba(self, X):
-        y_proba = self.forward(X, training_behavior=False)
-        y_proba = to_numpy(y_proba)
-        return y_proba
-
     def forward(self, X, training_behavior=False):
         self.module_.train(training_behavior)
 
         iterator = self.get_iterator(X)
-        y_probas = []
+        y_infer = []
         for x in iterator:
             x = to_var(x, use_cuda=self.use_cuda)
-            y_probas.append(self.module_(x))
-        return torch.cat(y_probas, dim=0)
+            y_infer.append(self.module_(x))
+        return torch.cat(y_infer, dim=0)
+
+    def predict_proba(self, X):
+        y_proba = self.forward(X, training_behavior=False)
+        y_proba = to_numpy(y_proba)
+        return y_proba
 
     def predict(self, X):
-        self.module_.train(False)
-        return self.predict_proba(X).argmax(1)
+        return self.predict_proba(X)
 
     def get_optimizer(self):
         kwargs = self._get_params_for('optim')
@@ -378,15 +385,19 @@ class NeuralNetClassifier(NeuralNet):
             ('valid_loss', 'valid_batch_size'),
             ('valid_acc', 'valid_batch_size'),
         ])),
-        ('best_loss', BestLoss(
-            keys_possible=['train_loss', 'valid_loss', 'valid_acc'],
-            signs=[-1, -1, 1],
-        )),
         ('accuracy', Scoring(
            name='valid_acc',
            scoring='accuracy_score',
            pred_extractor=accuracy_pred_extractor,
         )),
+        ('best_loss', BestLoss(
+            keys_possible=['train_loss', 'valid_loss', 'valid_acc'],
+            signs=[-1, -1, 1],
+        )),
+        ('print_log', PrintLog(keys=(
+            'epoch', 'train_loss', 'valid_loss', 'train_loss_best',
+            'valid_loss_best', 'valid_acc', 'valid_acc_best', 'dur'),
+        )),
     ]
 
     def __init__(
@@ -406,3 +417,29 @@ def __init__(
     def get_loss(self, y_pred, y, train=False):
         y_pred_log = torch.log(y_pred)
         return self.criterion_(y_pred_log, y)
+
+    def predict(self, X):
+        return self.predict_proba(X).argmax(1)
+
+
+class NeuralNetRegressor(NeuralNet):
+    def __init__(
+            self,
+            module,
+            criterion=torch.nn.MSELoss,
+            *args,
+            **kwargs
+    ):
+        super(NeuralNetRegressor, self).__init__(
+            module,
+            criterion=criterion,
+            *args,
+            **kwargs
+        )
+
+    def check_data(self, _, y):
+        # The problem with 1-dim float y is that the pytorch DataLoader will
+        # somehow upcast it to DoubleTensor
+        if get_dim(y) == 1:
+            raise ValueError("The target data shouldn't be 1-dimensional; "
+                             "please reshape (e.g. y.reshape(-1, 1)).")

inferno/tests/conftest.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+from unittest.mock import Mock
+
+import pytest
+
+
+@pytest.fixture
+def history_cls():
+    from inferno.net import History
+    return History
+
+
+@pytest.fixture
+def history(history_cls):
+    return history_cls()
+
+
+def get_history(*callbacks, history_cls=history_cls):
+    h = history_cls()()
+    net = Mock()
+    net.history = h
+    data = [(range(6, 10), 1, 'hi'),
+            (range(2, 6), 2, 'ho'),
+            (range(10, 14), 3, 'hu')]
+
+    for range_, epoch, text in data:
+        h.new_epoch()
+        for cb in callbacks:
+            cb.on_epoch_begin(net)
+
+        for i in range_:
+            h.new_batch()
+            for cb in callbacks:
+                cb.on_batch_begin(net)
+
+            h.record_batch('train_loss', 1 - i / 10)
+            h.record_batch('train_batch_size', 10)
+            h.record_batch('valid_loss', i)
+            h.record_batch('valid_batch_size', 1)
+
+            for cb in callbacks:
+                cb.on_batch_end(net)
+
+        h.record('epoch', epoch)
+        h.record('text', text)
+        for cb in callbacks:
+            cb.on_epoch_end(net)
+
+    return h
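
A hypothetical test (not in this commit) showing how get_history could exercise the new PrintLog callback across the three simulated epochs. It assumes the Callback base class provides no-op hooks, as the helper's unconditional on_*_begin/on_*_end calls imply:

```python
# Hypothetical test built on the get_history helper above; assumes no-op
# default hooks on the Callback base class.
from inferno.callbacks import PrintLog


def test_print_log_emits_one_row_per_epoch():
    lines = []
    print_log = PrintLog(keys=('epoch', 'text'), sink=lines.append).initialize()
    get_history(print_log)

    # header + separator once on the first epoch, then one data row per epoch
    assert len(lines) == 2 + 3
```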
