Add predict nonlinearity #662

Merged: 17 commits, Jul 28, 2020
2 changes: 2 additions & 0 deletions CHANGES.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Make it easier to add custom modules or optimizers to a neural net class by automatically registering them where necessary and by making them available to set_params
- Added the `step_every` argument for `LRScheduler` to set whether the scheduler step should be taken on every epoch or on every batch.
- Added the `scoring` module with `loss_scoring` function, which computes the net's loss (using `get_loss`) on provided input data.
- Added a parameter `predict_nonlinearity` to `NeuralNet` which allows users to control the nonlinearity to be applied to the module output when calling `predict` and `predict_proba` (#637, #661)

### Changed

@@ -23,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Set train/validation on criterion if it's a PyTorch module (#621)
- Don't pass `y=None` to `NeuralNet.train_split` to enable the direct use of split functions without positional `y` in their signatures. This is useful when working with unsupervised data (#605).
- `to_numpy` is now able to unpack dicts and lists/tuples (#657, #658)
- When using `CrossEntropyLoss`, softmax is now automatically applied to the output when calling `predict` or `predict_proba`

### Fixed

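Taken together, the two `predict_nonlinearity` entries above describe the user-facing change: with the default `predict_nonlinearity='auto'` and a `CrossEntropyLoss` criterion, the raw logits returned by the module are passed through softmax before `predict` and `predict_proba`. A minimal sketch of that behavior (the toy module and data are illustrative, not part of this PR):

```python
import numpy as np
from torch import nn
from skorch import NeuralNetClassifier

X = np.random.randn(100, 20).astype(np.float32)
y = np.random.randint(0, 2, size=100).astype(np.int64)

# the module returns raw logits; there is no softmax layer
net = NeuralNetClassifier(
    nn.Sequential(nn.Linear(20, 2)),
    criterion=nn.CrossEntropyLoss,
    max_epochs=1,
    # predict_nonlinearity='auto' is the default: the nonlinearity
    # (here softmax) is inferred from the criterion
)
net.fit(X, y)

proba = net.predict_proba(X)
assert np.allclose(proba.sum(axis=1), 1.0)  # rows are valid probabilities
```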
13 changes: 7 additions & 6 deletions docs/user/neuralnet.rst
@@ -346,13 +346,14 @@ from the ``module``. Alternatively, you may directly call
``net.module_(X)``.

In case of :class:`.NeuralNetClassifier`, the
-:func:`~skorch.net.NeuralNetClassifier.predict` method tries to return
-the class labels by applying the argmax over the last axis of the
-result of :func:`~skorch.net.NeuralNetClassifier.predict_proba`.
+:func:`~skorch.classifier.NeuralNetClassifier.predict` method tries to
+return the class labels by applying the argmax over the last axis of
+the result of
+:func:`~skorch.classifier.NeuralNetClassifier.predict_proba`.
Obviously, this only makes sense if
-:func:`~skorch.net.NeuralNetClassifier.predict_proba` returns class
-probabilities. If this is not true, you should just use
-:func:`~skorch.net.NeuralNetClassifier.predict_proba`.
+:func:`~skorch.classifier.NeuralNetClassifier.predict_proba` returns
+class probabilities. If this is not true, you should just use
+:func:`~skorch.classifier.NeuralNetClassifier.predict_proba`.
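A quick illustration of the relationship documented above: `predict` is just the argmax over the last axis of `predict_proba`. The toy module and data below are made up for the example:

```python
import numpy as np
from torch import nn
from skorch import NeuralNetClassifier

X = np.random.randn(60, 12).astype(np.float32)
y = np.random.randint(0, 3, size=60).astype(np.int64)

net = NeuralNetClassifier(
    nn.Sequential(nn.Linear(12, 3)),
    criterion=nn.CrossEntropyLoss,
    max_epochs=1,
)
net.fit(X, y)

# predict returns the argmax over the last axis of predict_proba
proba = net.predict_proba(X)
assert (net.predict(X) == proba.argmax(axis=-1)).all()
```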

score(X, y)
^^^^^^^^^^^
53 changes: 1 addition & 52 deletions skorch/classifier.py
@@ -206,12 +206,7 @@ def predict(self, X):
y_pred : numpy ndarray

"""
-y_preds = []
-for yp in self.forward_iter(X, training=False):
-    yp = yp[0] if isinstance(yp, tuple) else yp
-    y_preds.append(to_numpy(yp.max(-1)[-1]))
-y_pred = np.concatenate(y_preds, 0)
-return y_pred
+return super().predict_proba(X).argmax(axis=1)


neural_net_binary_clf_doc_start = """NeuralNet for binary classification tasks
@@ -361,49 +356,3 @@ def predict(self, X):
"""
y_proba = self.predict_proba(X)
return (y_proba[:, 1] > self.threshold).astype('uint8')

# pylint: disable=missing-docstring
def predict_proba(self, X):
"""Where applicable, return probability estimates for
samples.

If the module's forward method returns multiple outputs as a
tuple, it is assumed that the first output contains the
relevant information and the other values are ignored. If all
values are relevant, consider using
:func:`~skorch.NeuralNet.forward` instead.

Parameters
----------
X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:

* numpy arrays
* torch tensors
* pandas DataFrame or Series
* scipy sparse CSR matrices
* a dictionary of the former three
* a list/tuple of the former three
* a Dataset

If this doesn't work with your data, you have to pass a
``Dataset`` that can deal with the data.

Returns
-------
y_proba : numpy ndarray

"""
y_probas = []
self.check_is_fitted(attributes=['criterion_'])
bce_logits_loss = isinstance(
self.criterion_, torch.nn.BCEWithLogitsLoss)

for yp in self.forward_iter(X, training=False):
yp = yp[0] if isinstance(yp, tuple) else yp
if bce_logits_loss:
yp = torch.sigmoid(yp)
y_probas.append(to_numpy(yp))
y_proba = np.concatenate(y_probas, 0)
y_proba = np.stack((1 - y_proba, y_proba), axis=1)
return y_proba
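The 52 deleted lines above drop NeuralNetBinaryClassifier's hand-rolled `predict_proba`, which applied a sigmoid for `BCEWithLogitsLoss` and stacked the result into two columns; that behavior now comes from the inferred predict nonlinearity instead. A small sketch of the resulting behavior (the toy module and data are illustrative, not part of this PR):

```python
import numpy as np
from torch import nn
from skorch import NeuralNetBinaryClassifier

class ToyBinaryModule(nn.Module):
    """Toy module returning 1-d logits, as BCEWithLogitsLoss expects."""
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(20, 1)

    def forward(self, X):
        return self.lin(X).squeeze(-1)

X = np.random.randn(128, 20).astype(np.float32)
y = np.random.randint(0, 2, size=128).astype(np.float32)  # float targets for BCEWithLogitsLoss

# default criterion is BCEWithLogitsLoss, so 'auto' infers sigmoid + stacking
net = NeuralNetBinaryClassifier(ToyBinaryModule, max_epochs=1)
net.fit(X, y)

proba = net.predict_proba(X)   # shape (128, 2), columns are (1 - p, p)
assert np.allclose(proba.sum(axis=1), 1.0)
labels = net.predict(X)        # thresholded on proba[:, 1] > net.threshold (0.5 by default)
```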
72 changes: 72 additions & 0 deletions skorch/net.py
@@ -22,6 +22,8 @@
from skorch.exceptions import DeviceWarning
from skorch.history import History
from skorch.setter import optimizer_setter
from skorch.utils import _identity
from skorch.utils import _infer_predict_nonlinearty
from skorch.utils import FirstStepAccumulator
from skorch.utils import TeeGenerator
from skorch.utils import check_is_fitted
@@ -152,6 +154,33 @@ class NeuralNet:
``net.set_params(callbacks__print_log__keys_ignored=['epoch',
'train_loss'])``).

predict_nonlinearity : callable, None, or 'auto' (default='auto')
The nonlinearity to be applied to the prediction. When set to
'auto', infers the correct nonlinearity based on the criterion
(softmax for :class:`~torch.nn.CrossEntropyLoss` and sigmoid for
:class:`~torch.nn.BCEWithLogitsLoss`). If it cannot be inferred
or if the parameter is None, just use the identity
function. Don't pass a lambda function if you want the net to be
pickleable.

In case a callable is passed, it should accept the output of the
module (the first output if there is more than one), which is a
PyTorch tensor, and return the transformed PyTorch tensor.

This can be useful, e.g., when
:func:`~skorch.NeuralNetClassifier.predict_proba`
should return probabilities but a criterion is used that does
not expect probabilities. In that case, the module can return
whatever is required by the criterion and the
``predict_nonlinearity`` transforms this output into
probabilities.

The nonlinearity is applied only when calling
:func:`~skorch.classifier.NeuralNetClassifier.predict` or
:func:`~skorch.classifier.NeuralNetClassifier.predict_proba` but
not anywhere else -- notably, the loss is unaffected by this
nonlinearity.

warm_start : bool (default=False)
Whether each fit call should lead to a re-initialization of the
module (cold start) or whether the module should be trained
@@ -213,6 +242,7 @@ def __init__(
dataset=Dataset,
train_split=CVSplit(5),
callbacks=None,
predict_nonlinearity='auto',
warm_start=False,
verbose=1,
device='cpu',
@@ -229,6 +259,7 @@
self.dataset = dataset
self.train_split = train_split
self.callbacks = callbacks
self.predict_nonlinearity = predict_nonlinearity
self.warm_start = warm_start
self.verbose = verbose
self.device = device
@@ -1010,6 +1041,45 @@ def infer(self, x, **fit_params):
return self.module_(**x_dict)
return self.module_(x, **fit_params)

def _get_predict_nonlinearity(self):
"""Return the nonlinearity to be applied to the prediction

This can be useful, e.g., when
:func:`~skorch.classifier.NeuralNetClassifier.predict_proba`
should return probabilities but a criterion is used that does
not expect probabilities. In that case, the module can return
whatever is required by the criterion and the
``predict_nonlinearity`` transforms this output into
probabilities.

The nonlinearity is applied only when calling
:func:`~skorch.classifier.NeuralNetClassifier.predict` or
:func:`~skorch.classifier.NeuralNetClassifier.predict_proba`
but not anywhere else -- notably, the loss is unaffected by
this nonlinearity.

Raises
------
TypeError
Raise a TypeError if the return value is not callable.

Returns
-------
nonlin : callable
A callable that takes a single argument, which is a PyTorch
tensor, and returns a PyTorch tensor.

"""
self.check_is_fitted()
nonlin = self.predict_nonlinearity
if nonlin is None:
nonlin = _identity
elif nonlin == 'auto':
nonlin = _infer_predict_nonlinearty(self)
if not callable(nonlin):
raise TypeError("predict_nonlinearity has to be a callable, 'auto' or None")
return nonlin

def predict_proba(self, X):
"""Return the output of the module's forward method as a numpy
array.
@@ -1041,9 +1111,11 @@ def predict_proba(self, X):
y_proba : numpy ndarray

"""
nonlin = self._get_predict_nonlinearity()
y_probas = []
for yp in self.forward_iter(X, training=False):
yp = yp[0] if isinstance(yp, tuple) else yp
yp = nonlin(yp)
y_probas.append(to_numpy(yp))
y_proba = np.concatenate(y_probas, 0)
return y_proba
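Besides 'auto' and None, the docstring above also allows a user-supplied callable; as `_get_predict_nonlinearity` spells out, it receives the (first) module output as a PyTorch tensor and is applied only inside `predict` and `predict_proba`, never in the loss. A short sketch with a made-up, module-level (hence pickle-friendly) callable:

```python
import numpy as np
import torch
from torch import nn
from skorch import NeuralNetClassifier


def scaled_softmax(y_logits):
    # illustrative custom nonlinearity: temperature-scaled softmax over the
    # logits; defined at module level (not a lambda) so the net stays pickleable
    return torch.softmax(y_logits / 2.0, dim=-1)


X = np.random.randn(60, 20).astype(np.float32)
y = np.random.randint(0, 3, size=60).astype(np.int64)

net = NeuralNetClassifier(
    nn.Sequential(nn.Linear(20, 3)),
    criterion=nn.CrossEntropyLoss,
    predict_nonlinearity=scaled_softmax,  # used by predict/predict_proba only
    max_epochs=1,
)
net.fit(X, y)              # training and get_loss still see the raw logits
proba = net.predict_proba(X)   # scaled_softmax output, rows sum to 1
```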
9 changes: 8 additions & 1 deletion skorch/tests/test_classifier.py
@@ -289,7 +289,14 @@ def test_custom_loss_does_not_call_sigmoid(
mock = Mock(side_effect=lambda x: x)
monkeypatch.setattr(torch, "sigmoid", mock)

-net = net_cls(module_cls, max_epochs=1, lr=0.1, criterion=nn.MSELoss)
+# add a custom nonlinearity - note that the output must return
+# a 2d array from a 1d vector to conform to the required
+# y_proba
+def nonlin(x):
+    return torch.stack((1 - x, x), 1)
+
+net = net_cls(module_cls, max_epochs=1, lr=0.1, criterion=nn.MSELoss,
+              predict_nonlinearity=nonlin)
X, y = data
net.fit(X, y)

86 changes: 86 additions & 0 deletions skorch/tests/test_net.py
@@ -2540,6 +2540,92 @@ def forward(self, y_pred, _):
with raises:
net.fit(X, y)

def test_predict_nonlinearity_called_with_predict(
self, net_cls, module_cls, data):
side_effect = []
def nonlin(X):
side_effect.append(X)
return np.zeros_like(X)

X, y = data[0][:200], data[1][:200]
net = net_cls(
module_cls, max_epochs=1, predict_nonlinearity=nonlin).initialize()

# don't want callbacks to trigger side effects
net.callbacks_ = []
net.partial_fit(X, y)
assert not side_effect
Member: But wouldn't we expect callbacks such as accuracy scoring to call predict?

Collaborator Author: That's why I removed all callbacks two lines earlier, otherwise the test becomes very messy.


# 2 calls, since batch size == 128 and n == 200
y_proba = net.predict(X)
assert len(side_effect) == 2
assert side_effect[0].shape == (128, 2)
assert side_effect[1].shape == (72, 2)
assert (y_proba == 0).all()

net.predict(X)
assert len(side_effect) == 4

def test_predict_nonlinearity_called_with_predict_proba(
self, net_cls, module_cls, data):
side_effect = []
def nonlin(X):
side_effect.append(X)
return np.zeros_like(X)

X, y = data[0][:200], data[1][:200]
net = net_cls(
module_cls, max_epochs=1, predict_nonlinearity=nonlin).initialize()

net.callbacks_ = []
# don't want callbacks to trigger side effects
net.partial_fit(X, y)
assert not side_effect

# 2 calls, since batch size == 128 and n == 200
y_proba = net.predict_proba(X)
assert len(side_effect) == 2
assert side_effect[0].shape == (128, 2)
assert side_effect[1].shape == (72, 2)
assert np.allclose(y_proba, 0)

net.predict_proba(X)
assert len(side_effect) == 4

def test_predict_nonlinearity_none(
self, net_cls, module_cls, data):
# even though we have CrossEntropyLoss, we don't want the
# output from predict_proba to be modified, thus we set
# predict_nonlinearity to None
X = data[0][:200]
net = net_cls(
module_cls,
max_epochs=1,
criterion=nn.CrossEntropyLoss,
predict_nonlinearity=None,
).initialize()

rv = np.random.random((20, 5))
net.forward_iter = lambda *args, **kwargs: (torch.as_tensor(rv) for _ in range(2))

# 2 batches, mock return value has shape 20,5 thus y_proba has
# shape 40,5
y_proba = net.predict_proba(X)
assert y_proba.shape == (40, 5)
assert np.allclose(y_proba[:20], rv)
assert np.allclose(y_proba[20:], rv)

def test_predict_nonlinearity_type_error(self, net_cls, module_cls):
# if predict_nonlinearity is not callable, raise a TypeError
net = net_cls(module_cls, predict_nonlinearity=123).initialize()

msg = "predict_nonlinearity has to be a callable, 'auto' or None"
with pytest.raises(TypeError, match=msg):
net.predict(np.zeros((3, 3)))

with pytest.raises(TypeError, match=msg):
net.predict_proba(np.zeros((3, 3)))


class TestNetSparseInput:
@pytest.fixture(scope='module')
66 changes: 66 additions & 0 deletions skorch/tests/test_utils.py
@@ -803,3 +803,69 @@ def list_gen():

assert first_return == expected_list
assert second_return == expected_list


class TestInferPredictNonlinearity:
@pytest.fixture
def infer_predict_nonlinearity(self):
from skorch.utils import _infer_predict_nonlinearty
return _infer_predict_nonlinearty

@pytest.fixture
def net_clf_cls(self):
from skorch import NeuralNetClassifier
return NeuralNetClassifier

@pytest.fixture
def net_bin_clf_cls(self):
from skorch import NeuralNetBinaryClassifier
return NeuralNetBinaryClassifier

@pytest.fixture
def net_regr_cls(self):
from skorch import NeuralNetRegressor
return NeuralNetRegressor

def test_infer_neural_net_classifier_default(
self, infer_predict_nonlinearity, net_clf_cls, module_cls):
# default NeuralNetClassifier: no output nonlinearity
net = net_clf_cls(module_cls).initialize()
fn = infer_predict_nonlinearity(net)

X = np.random.random((20, 5))
out = fn(X)
assert out is X

def test_infer_neural_net_classifier_crossentropy_loss(
self, infer_predict_nonlinearity, net_clf_cls, module_cls):
# CrossEntropyLoss criterion: nonlinearity should return valid probabilities
net = net_clf_cls(module_cls, criterion=torch.nn.CrossEntropyLoss).initialize()
fn = infer_predict_nonlinearity(net)

X = torch.rand((20, 5))
out = fn(X).numpy()
assert np.allclose(out.sum(axis=1), 1.0)
assert ((0 <= out) & (out <= 1.0)).all()

def test_infer_neural_binary_net_classifier_default(
self, infer_predict_nonlinearity, net_bin_clf_cls, module_cls):
# BCEWithLogitsLoss should return valid probabilities
Member (suggested change):
- # BCEWithLogitsLoss should return valid probabilities
+ # BCEWithLogitsLoss criterion: nonlinearity should return valid probabilities

Collaborator Author: Fixed

net = net_bin_clf_cls(module_cls).initialize()
fn = infer_predict_nonlinearity(net)

X = torch.rand(20) # binary classifier returns 1-dim output
X = 10 * X - 5.0 # random values from -5 to 5
out = fn(X).numpy()
assert out.shape == (20, 2) # output should be 2-dim
assert np.allclose(out.sum(axis=1), 1.0)
assert ((0 <= out) & (out <= 1.0)).all()

def test_infer_neural_net_regressor_default(
self, infer_predict_nonlinearity, net_regr_cls, module_cls):
# default NeuralNetRegressor: no output nonlinearity
net = net_regr_cls(module_cls).initialize()
fn = infer_predict_nonlinearity(net)

X = np.random.random((20, 5))
out = fn(X)
assert out is X
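The `_infer_predict_nonlinearty` helper itself lives in `skorch/utils.py`, whose diff is not shown on this page. Below is a sketch that would be consistent with the tests above; the `_sigmoid_then_2d` helper and the exact dispatch are assumptions for illustration, not the PR's verbatim code:

```python
from functools import partial

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss


def _identity(x):
    # no-op nonlinearity, used when nothing should be applied
    return x


def _sigmoid_then_2d(x):
    # assumed helper: turn 1-d logits into an (n, 2) probability array,
    # matching what NeuralNetBinaryClassifier.predict_proba used to return
    prob = torch.sigmoid(x)
    return torch.stack((1 - prob, prob), 1)


def _infer_predict_nonlinearty(net):
    # dispatch on the initialized criterion, as the tests above suggest
    criterion = getattr(net, 'criterion_', None)
    if isinstance(criterion, CrossEntropyLoss):
        return partial(torch.softmax, dim=-1)
    if isinstance(criterion, BCEWithLogitsLoss):
        return _sigmoid_then_2d
    return _identity
```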