Skip to content

Commit 7276c93

Browse files
Fix failing tests for GPyTorch v1.10 (#956)
1 parent 51bba10 commit 7276c93

File tree

1 file changed

+95
-54
lines changed

1 file changed

+95
-54
lines changed

skorch/tests/test_probabilistic.py

+95-54
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,21 @@ def forward(self, x):
114114
return latent_pred
115115

116116

117+
class MyBernoulliLikelihood(gpytorch.likelihoods.BernoulliLikelihood):
118+
"""This class only exists to add a param to BernoulliLikelihood
119+
120+
BernoulliLikelihood used to have parameters before gpytorch v1.10, but now
121+
it does not have any parameters anymore. This is not an issue per se, but
122+
there are a few things we cannot test anymore, e.g. that parameters are
123+
passed to the likelihood correctly when using grid search. Therefore, create
124+
a custom class with a (pointless) parameter.
125+
126+
"""
127+
def __init__(self, *args, some_parameter=1, **kwargs):
128+
self.some_parameter = some_parameter
129+
super().__init__(*args, **kwargs)
130+
131+
117132
class BaseProbabilisticTests:
118133
"""Base class for all GP estimators.
119134
@@ -220,34 +235,24 @@ def pipe(self, gp):
220235
# saving and loading #
221236
######################
222237

223-
@pytest.mark.xfail(strict=True)
224-
def test_pickling(self, gp_fit):
225-
# Currently fails because of issues outside of our control, this test
226-
# should alert us to when the issue has been fixed. Some issues have
227-
# been fixed in https://github.com/cornellius-gp/gpytorch/pull/1336 but
228-
# not all.
229-
pickle.dumps(gp_fit)
238+
def test_pickling(self, gp_fit, data):
239+
loaded = pickle.loads(pickle.dumps(gp_fit))
240+
X, _ = data
230241

231-
def test_pickle_error_msg(self, gp_fit):
232-
# Should eventually be replaced by a test that saves and loads the model
233-
# using pickle and checks that the predictions are identical
234-
msg = ("This GPyTorch model cannot be pickled. The reason is probably this:"
235-
" https://github.com/pytorch/pytorch/issues/38137. "
236-
"Try using 'dill' instead of 'pickle'.")
237-
with pytest.raises(pickle.PicklingError, match=msg):
238-
pickle.dumps(gp_fit)
242+
y_pred_before = gp_fit.predict(X)
243+
y_pred_after = loaded.predict(X)
244+
assert np.allclose(y_pred_before, y_pred_after)
239245

240-
def test_deepcopy(self, gp_fit):
241-
# Should eventually be replaced by a test that saves and loads the model
242-
# using deepcopy and checks that the predictions are identical
243-
msg = ("This GPyTorch model cannot be pickled. The reason is probably this:"
244-
" https://github.com/pytorch/pytorch/issues/38137. "
245-
"Try using 'dill' instead of 'pickle'.")
246-
with pytest.raises(pickle.PicklingError, match=msg):
247-
copy.deepcopy(gp_fit) # doesn't raise
246+
def test_deepcopy(self, gp_fit, data):
247+
copied = copy.deepcopy(gp_fit)
248+
X, _ = data
249+
250+
y_pred_before = gp_fit.predict(X)
251+
y_pred_after = copied.predict(X)
252+
assert np.allclose(y_pred_before, y_pred_after)
248253

249-
def test_clone(self, gp_fit):
250-
clone(gp_fit) # doesn't raise
254+
def test_clone(self, gp_fit, data):
255+
clone(gp_fit) # does not raise
251256

252257
def test_save_load_params(self, gp_fit, tmpdir):
253258
gp2 = clone(gp_fit).initialize()
@@ -335,7 +340,8 @@ def test_grid_search_works(self, gp, data, recwarn):
335340
params = {
336341
'lr': [0.01, 0.02],
337342
'max_epochs': [10, 20],
338-
'likelihood__max_plate_nesting': [1, 2],
343+
# this parameter does not exist but that's okay
344+
'likelihood__some_parameter': [1, 2],
339345
}
340346
gp.set_params(verbose=0)
341347
gs = GridSearchCV(gp, params, refit=True, cv=3, scoring=self.scoring)
@@ -419,32 +425,29 @@ def test_multioutput_predict_proba(self, gp_multioutput, data):
419425
])
420426
def test_set_params_uninitialized_net_correct_message(
421427
self, gp, kwargs, expected, capsys):
422-
# When gp is initialized, if module or optimizer need to be
423-
# re-initialized, alert the user to the fact what parameters
424-
# were responsible for re-initialization. Note that when the
425-
# module parameters but not optimizer parameters were changed,
426-
# the optimizer is re-initialized but not because the
427-
# optimizer parameters changed.
428+
# When gp is uninitialized, there is nothing to alert the user to
428429
gp.set_params(**kwargs)
429430
msg = capsys.readouterr()[0].strip()
430431
assert msg == expected
431432

432433
@pytest.mark.parametrize('kwargs,expected', [
433434
({}, ""),
434435
(
435-
{'likelihood__max_plate_nesting': 2},
436+
# this parameter does not exist but that's okay
437+
{'likelihood__some_parameter': 2},
436438
("Re-initializing module because the following "
437-
"parameters were re-set: likelihood__max_plate_nesting.\n"
439+
"parameters were re-set: likelihood__some_parameter.\n"
438440
"Re-initializing criterion.\n"
439441
"Re-initializing optimizer.")
440442
),
441443
(
442444
{
443-
'likelihood__max_plate_nesting': 2,
445+
# this parameter does not exist but that's okay
446+
'likelihood__some_parameter': 2,
444447
'optimizer__momentum': 0.567,
445448
},
446449
("Re-initializing module because the following "
447-
"parameters were re-set: likelihood__max_plate_nesting.\n"
450+
"parameters were re-set: likelihood__some_parameter.\n"
448451
"Re-initializing criterion.\n"
449452
"Re-initializing optimizer.")
450453
),
@@ -570,23 +573,6 @@ def gp(self, gp_cls, module_cls):
570573
)
571574
return gpr
572575

573-
# pickling and deepcopy work for ExactGPRegressor but not for the others, so
574-
# override the expected failures here.
575-
576-
def test_pickling(self, gp_fit):
577-
# does not raise
578-
pickle.dumps(gp_fit)
579-
580-
def test_pickle_error_msg(self, gp_fit):
581-
# Should eventually be replaced by a test that saves and loads the model
582-
# using pickle and checks that the predictions are identical
583-
# FIXME
584-
pickle.dumps(gp_fit)
585-
586-
def test_deepcopy(self, gp_fit):
587-
# FIXME
588-
copy.deepcopy(gp_fit) # doesn't raise
589-
590576
def test_wrong_module_type_raises(self, gp_cls):
591577
# ExactGPRegressor requires the module to be an ExactGP, if it's not,
592578
# raise an appropriate error message to the user.
@@ -649,6 +635,32 @@ def gp(self, gp_cls, module_cls, data):
649635
assert gpr.batch_size < self.n_samples
650636
return gpr
651637

638+
# Since GPyTorch v1.10, GPRegressor works with pickle/deepcopy.
639+
640+
def test_pickling(self, gp_fit, data):
641+
# TODO: remove once Python 3.7 is no longer supported
642+
if version_gpytorch < Version('1.10'):
643+
pytest.skip("GPyTorch < 1.10 does not support pickling.")
644+
645+
loaded = pickle.loads(pickle.dumps(gp_fit))
646+
X, _ = data
647+
648+
y_pred_before = gp_fit.predict(X)
649+
y_pred_after = loaded.predict(X)
650+
assert np.allclose(y_pred_before, y_pred_after)
651+
652+
def test_deepcopy(self, gp_fit, data):
653+
# TODO: remove once Python 3.7 is no longer supported
654+
if version_gpytorch < Version('1.10'):
655+
pytest.skip("GPyTorch < 1.10 does not support deepcopy.")
656+
657+
copied = copy.deepcopy(gp_fit)
658+
X, _ = data
659+
660+
y_pred_before = gp_fit.predict(X)
661+
y_pred_after = copied.predict(X)
662+
assert np.allclose(y_pred_before, y_pred_after)
663+
652664

653665
class TestGPBinaryClassifier(BaseProbabilisticTests):
654666
"""Tests for GPBinaryClassifier."""
@@ -686,11 +698,40 @@ def gp(self, gp_cls, module_cls, data):
686698
gpc = gp_cls(
687699
module_cls,
688700
module__inducing_points=torch.from_numpy(X[:10]),
689-
701+
likelihood=MyBernoulliLikelihood,
690702
criterion=gpytorch.mlls.VariationalELBO,
691703
criterion__num_data=int(0.8 * len(y)),
692704
batch_size=24,
693705
)
694706
# we want to make sure batching is properly tested
695707
assert gpc.batch_size < self.n_samples
696708
return gpc
709+
710+
# Since GPyTorch v1.10, GPBinaryClassifier is the only estimator left that
711+
# still has issues with pickling/deepcopying.
712+
713+
@pytest.mark.xfail(strict=True)
714+
def test_pickling(self, gp_fit, data):
715+
# Currently fails because of issues outside of our control, this test
716+
# should alert us to when the issue has been fixed. Some issues have
717+
# been fixed in https://github.com/cornellius-gp/gpytorch/pull/1336 but
718+
# not all.
719+
pickle.dumps(gp_fit)
720+
721+
def test_pickle_error_msg(self, gp_fit, data):
722+
# Should eventually be replaced by a test that saves and loads the model
723+
# using pickle and checks that the predictions are identical
724+
msg = ("This GPyTorch model cannot be pickled. The reason is probably this:"
725+
" https://github.com/pytorch/pytorch/issues/38137. "
726+
"Try using 'dill' instead of 'pickle'.")
727+
with pytest.raises(pickle.PicklingError, match=msg):
728+
pickle.dumps(gp_fit)
729+
730+
def test_deepcopy(self, gp_fit, data):
731+
# Should eventually be replaced by a test that saves and loads the model
732+
# using deepcopy and checks that the predictions are identical
733+
msg = ("This GPyTorch model cannot be pickled. The reason is probably this:"
734+
" https://github.com/pytorch/pytorch/issues/38137. "
735+
"Try using 'dill' instead of 'pickle'.")
736+
with pytest.raises(pickle.PicklingError, match=msg):
737+
copy.deepcopy(gp_fit) # doesn't raise

0 commit comments

Comments
 (0)