
Commit aa1b034

Allow regression with 1d targets
This change makes it possible to pass a 1-dimensional `y` to `NeuralNetRegressor`.

### Problem description

Right now, skorch requires the `y` passed to `NeuralNetRegressor.fit` to be 2-dimensional, even if there is only one target, as is the most common case. This problem has come up a few times in the past, but mostly it's just an annoyance: just do `y.reshape(-1, 1)` and you're good (the error message says as much).

There are, however, also cases where it's not so easy to solve. For instance, in #972, a user reports that they cannot use skorch with sklearn's `BaggingRegressor`. The problem is that even if `y` is reshaped, once it is passed to the net from `BaggingRegressor`, it is 1d again. I assume that `BaggingRegressor` internally squeezes `y` at some point.

This PR lifts the 2d restriction check.

### Initial motivation

Why does skorch require `y` to be 2d? I couldn't remember the initial reasoning and did some archeology. I found this comment (2f00e25#diff-66ed08bca4d171889565d0285a36b9b47e0e91e3b33d85c51352d8eb00faefac):

> # The problem with 1-dim float y is that the pytorch DataLoader will
> # somehow upcast it to DoubleTensor

This strange behavior should not be an issue anymore, so if that was the only problem, we should be able to just remove the constraint, right?

### Problems with removing the constraint

Unfortunately, it's not that easy. The issue comes down to the following: when we remove the constraint and allow the target `y` to be 1d while the prediction `y_pred` is still 2d, the criterion `nn.MSELoss` will probably do the wrong thing.

What exactly is wrong? Instead of calculating the squared error for each sample pair, the criterion will broadcast the two arrays and calculate _all squared errors_ between each pair of samples, then return the mean of that. To demonstrate, let's remove the reduction step and look at the shape:

```python
>>> import torch
>>> criterion = torch.nn.MSELoss(reduction='none')
>>> y = torch.rand(100)
>>> y_pred = torch.rand((100, 1))
>>> y.shape, y_pred.shape
(torch.Size([100]), torch.Size([100, 1]))
>>> se = criterion(y_pred, y)
/home/vinh/anaconda3/envs/skorch/lib/python3.10/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([100])) that is different to the input size (torch.Size([100, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
>>> se.shape
torch.Size([100, 100])
```

As can be seen, PyTorch broadcasts the two arrays, leading to 100x100 errors being calculated. Thankfully, PyTorch warns about potential issues with that.

The current solution is to accept this behavior and hope that users will indeed see the warning. If they don't see it or ignore it, it could be a huge issue, because they still get a scalar loss and might even see a small improvement in the loss during training. But the model will not converge, and it's going to be a huge pain to debug, if it's even identified as a bug.

Just to be clear: existing code, which uses 2d targets, will not be affected by the change introduced in this PR, and that remains the preferred way (IMO) to use regression in skorch.

### Rejected solutions

I did consider the following solutions but rejected them.

#### Raising an error when shapes mismatch

This would remove the risk of users missing the warning. The problem with this is that mismatching shapes can be okay in certain circumstances. Some criteria don't expect target and prediction to have the same shape, so we would need to check based on the criterion. Moreover, users may, in theory, actually want the broadcasting. Raising an error would prevent that, and users might have to resort to subclassing to circumvent it.

#### Automatic reshaping

We could automatically add/remove dimensions if we see that they mismatch. This has the same problem as the previous solution regarding the dependence on the type of criterion. Furthermore, automatically adjusting the user's output is prone to run into issues in some edge cases (e.g. when the broadcasting is actually desired).
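To make the accepted behavior concrete, here is a minimal sketch (not part of the commit; the module, sizes, and hyperparameters are made up for illustration) of fitting `NeuralNetRegressor` with a 1d `y` and a module whose output is squeezed to 1d, which is the combination this PR enables:

```python
# Minimal sketch, assuming a skorch version that includes this change.
import numpy as np
from torch import nn

from skorch import NeuralNetRegressor


class RegressionModule(nn.Module):
    """Toy module whose output is squeezed to 1d (hypothetical example)."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(20, 1)

    def forward(self, X):
        # squeeze the trailing dim so y_pred is 1d, matching the 1d y
        return self.linear(X).squeeze(-1)


X = np.random.rand(100, 20).astype(np.float32)
y = np.random.rand(100).astype(np.float32)  # 1d target, no reshape needed

net = NeuralNetRegressor(RegressionModule, max_epochs=2, lr=0.1)
net.fit(X, y)  # no ValueError; since y_pred is also 1d, no broadcast warning
```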
1 parent f69be0f commit aa1b034

File tree

3 files changed: +67 -23 lines changed

CHANGES.md (+2 lines)

```diff
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+- `NeuralNetRegressor` can now be fitted with 1-dimensional `y`, which is necessary in some specific circumstances (e.g. in conjunction with sklearn's `BaggingRegressor`, see #972); for this to work correctly, the output of the PyTorch module should also be 1-dimensional; the existing default, i.e. having `y` and `y_pred` be 2-dimensional, remains the recommended way of using `NeuralNetRegressor`
+
 ### Fixed
 
 ## [0.13.0] - 2023-05-17
```
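As a hedged illustration of the `BaggingRegressor` use case from #972 that this entry refers to (continuing the sketch above, so `net`, `X`, and `y` are assumed from there; the `estimator` argument exists in scikit-learn >= 1.2, where it replaced `base_estimator`):

```python
from sklearn.ensemble import BaggingRegressor

# BaggingRegressor passes y through to each sub-estimator as a 1d array,
# which is exactly the case this change enables.
regr = BaggingRegressor(estimator=net, n_estimators=2, random_state=0)
regr.fit(X, y)           # previously failed with skorch's 2d-target ValueError
y_pred = regr.predict(X)
```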

skorch/regressor.py (-9 lines)

```diff
@@ -66,15 +66,6 @@ def check_data(self, X, y):
             # The user implements its own mechanism for generating y.
             return
 
-        if get_dim(y) == 1:
-            msg = (
-                "The target data shouldn't be 1-dimensional but instead have "
-                "2 dimensions, with the second dimension having the same size "
-                "as the number of regression targets (usually 1). Please "
-                "reshape your target data to be 2-dimensional "
-                "(e.g. y = y.reshape(-1, 1).")
-            raise ValueError(msg)
-
     # pylint: disable=signature-differs
     def fit(self, X, y, **fit_params):
         """See ``NeuralNet.fit``.
```
skorch/tests/test_regressor.py (+65 -14 lines)

```diff
@@ -4,6 +4,8 @@
 
 """
 
+from functools import partial
+
 import numpy as np
 import pytest
 from sklearn.base import clone
@@ -21,6 +23,12 @@ def module_cls(self):
         from skorch.toy import make_regressor
         return make_regressor(dropout=0.5)
 
+    @pytest.fixture(scope='module')
+    def module_pred_1d_cls(self):
+        from skorch.toy import MLPModule
+        # Module that returns 1d predictions
+        return partial(MLPModule, output_units=1, squeeze_output=True)
+
     @pytest.fixture(scope='module')
     def net_cls(self):
         from skorch import NeuralNetRegressor
@@ -57,9 +65,9 @@ def net_fit(self, net, data):
     def test_clone(self, net_fit):
         clone(net_fit)
 
-    def test_fit(self, net_fit):
-        # fitting does not raise anything
-        pass
+    def test_fit(self, net_fit, recwarn):
+        # fitting does not raise anything and does not warn
+        assert not recwarn.list
 
     @pytest.mark.parametrize('method', INFERENCE_METHODS)
     def test_not_fitted_raises(self, net_cls, module_cls, data, method):
@@ -91,17 +99,6 @@ def test_history_default_keys(self, net_fit):
         for row in net_fit.history:
             assert expected_keys.issubset(row)
 
-    def test_target_1d_raises(self, net, data):
-        X, y = data
-        with pytest.raises(ValueError) as exc:
-            net.fit(X, y.flatten())
-        assert exc.value.args[0] == (
-            "The target data shouldn't be 1-dimensional but instead have "
-            "2 dimensions, with the second dimension having the same size "
-            "as the number of regression targets (usually 1). Please "
-            "reshape your target data to be 2-dimensional "
-            "(e.g. y = y.reshape(-1, 1).")
-
     def test_predict_predict_proba(self, net_fit, data):
         X = data[0]
         y_pred = net_fit.predict(X)
@@ -123,3 +120,57 @@ def test_multioutput_score(self, multioutput_net, multioutput_regression_data):
         multioutput_net.fit(X, y)
         r2_score = multioutput_net.score(X, y)
         assert r2_score <= 1.
+
+    def test_dimension_mismatch_warning(self, net_cls, module_cls, data, recwarn):
+        # When the target and the prediction have different dimensionality, mse
+        # loss will broadcast them, calculating all pairwise errors instead of
+        # only sample-wise. Since the errors are averaged at the end, there is
+        # still a valid loss, which makes the error hard to spot. Thankfully,
+        # torch gives a warning in that case. We test that this warning exists,
+        # otherwise, skorch users could run into very hard to debug issues
+        # during training.
+        net = net_cls(module_cls)
+        X, y = data
+        X, y = X[:100], y[:100].flatten()  # make y 1d
+        net.fit(X, y)
+
+        w0, w1 = recwarn.list  # one warning for train, one for valid
+        # The warning comes from PyTorch, so checking the exact wording is
+        # prone to error in future PyTorch versions. We thus check a substring
+        # of the whole message and cross our fingers that it's not changed.
+        msg_substr = (
+            "This will likely lead to incorrect results due to broadcasting. "
+            "Please ensure they have the same size"
+        )
+        assert msg_substr in str(w0.message)
+        assert msg_substr in str(w1.message)
+
+    def test_fitting_with_1d_target_and_pred(
+            self, net_cls, module_cls, data, module_pred_1d_cls, recwarn
+    ):
+        # This test relates to the previous one. In general, users should fit
+        # with target and prediction being 2d, even if the 2nd dimension is
+        # just 1. However, in some circumstances (like when using
+        # BaggingRegressor, see next test), having the ability to fit with 1d
+        # is required. In that case, the module output also needs to be 1d for
+        # correctness.
+        X, y = data
+        X, y = X[:100], y[:100]  # less data to run faster
+        y = y.flatten()
+
+        net = net_cls(module_pred_1d_cls)
+        net.fit(X, y)
+        assert not recwarn.list
+
+    def test_bagging_regressor(
+            self, net_cls, module_cls, data, module_pred_1d_cls, recwarn
+    ):
+        # https://github.com/skorch-dev/skorch/issues/972
+        from sklearn.ensemble import BaggingRegressor
+
+        net = net_cls(module_pred_1d_cls)  # module output should be 1d too
+        X, y = data
+        X, y = X[:100], y[:100]  # less data to run faster
+        y = y.flatten()  # make y 1d or else sklearn will complain
+        regr = BaggingRegressor(estimator=net, n_estimators=2, random_state=0)
+        regr.fit(X, y)  # does not raise
+        assert not recwarn.list  # ensure there is no broadcast warning from torch
```
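Since the only safety net after this change is PyTorch's warning, a hedged option for users (not part of this commit) is to escalate that specific warning to an error during development, so the mismatch exercised by `test_dimension_mismatch_warning` fails fast instead of training silently:

```python
import warnings

# Turn PyTorch's broadcast warning into an exception. The regex matches the
# same substring the test asserts on; as the test comments note, the exact
# wording may change in future PyTorch versions.
warnings.filterwarnings(
    "error",
    message=r".*This will likely lead to incorrect results due to broadcasting.*",
    category=UserWarning,
)
```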
