@@ -114,6 +114,21 @@ def forward(self, x):
114
114
return latent_pred
115
115
116
116
117
class MyBernoulliLikelihood(gpytorch.likelihoods.BernoulliLikelihood):
    """A ``BernoulliLikelihood`` subclass whose sole purpose is to carry a parameter.

    Since gpytorch v1.10, ``BernoulliLikelihood`` no longer exposes any
    parameters of its own. That is fine for normal use, but some tests need a
    likelihood parameter to exercise (e.g. checking that ``likelihood__*``
    parameters are routed correctly by grid search). This subclass therefore
    adds one (pointless) parameter, ``some_parameter``.

    """
    def __init__(self, *args, some_parameter=1, **kwargs):
        # Stored before delegating so the attribute exists regardless of what
        # the parent __init__ does; the value itself is never used.
        self.some_parameter = some_parameter
        super().__init__(*args, **kwargs)
117
132
class BaseProbabilisticTests :
118
133
"""Base class for all GP estimators.
119
134
@@ -220,34 +235,24 @@ def pipe(self, gp):
220
235
# saving and loading #
221
236
######################
222
237
223
- @pytest .mark .xfail (strict = True )
224
- def test_pickling (self , gp_fit ):
225
- # Currently fails because of issues outside of our control, this test
226
- # should alert us to when the issue has been fixed. Some issues have
227
- # been fixed in https://github.com/cornellius-gp/gpytorch/pull/1336 but
228
- # not all.
229
- pickle .dumps (gp_fit )
238
+ def test_pickling (self , gp_fit , data ):
239
+ loaded = pickle .loads (pickle .dumps (gp_fit ))
240
+ X , _ = data
230
241
231
- def test_pickle_error_msg (self , gp_fit ):
232
- # Should eventually be replaced by a test that saves and loads the model
233
- # using pickle and checks that the predictions are identical
234
- msg = ("This GPyTorch model cannot be pickled. The reason is probably this:"
235
- " https://github.com/pytorch/pytorch/issues/38137. "
236
- "Try using 'dill' instead of 'pickle'." )
237
- with pytest .raises (pickle .PicklingError , match = msg ):
238
- pickle .dumps (gp_fit )
242
+ y_pred_before = gp_fit .predict (X )
243
+ y_pred_after = loaded .predict (X )
244
+ assert np .allclose (y_pred_before , y_pred_after )
239
245
240
- def test_deepcopy (self , gp_fit ):
241
- # Should eventually be replaced by a test that saves and loads the model
242
- # using deepcopy and checks that the predictions are identical
243
- msg = ("This GPyTorch model cannot be pickled. The reason is probably this:"
244
- " https://github.com/pytorch/pytorch/issues/38137. "
245
- "Try using 'dill' instead of 'pickle'." )
246
- with pytest .raises (pickle .PicklingError , match = msg ):
247
- copy .deepcopy (gp_fit ) # doesn't raise
246
+ def test_deepcopy (self , gp_fit , data ):
247
+ copied = copy .deepcopy (gp_fit )
248
+ X , _ = data
249
+
250
+ y_pred_before = gp_fit .predict (X )
251
+ y_pred_after = copied .predict (X )
252
+ assert np .allclose (y_pred_before , y_pred_after )
248
253
249
- def test_clone (self , gp_fit ):
250
- clone (gp_fit ) # doesn't raise
254
+ def test_clone (self , gp_fit , data ):
255
+ clone (gp_fit ) # does not raise
251
256
252
257
def test_save_load_params (self , gp_fit , tmpdir ):
253
258
gp2 = clone (gp_fit ).initialize ()
@@ -335,7 +340,8 @@ def test_grid_search_works(self, gp, data, recwarn):
335
340
params = {
336
341
'lr' : [0.01 , 0.02 ],
337
342
'max_epochs' : [10 , 20 ],
338
- 'likelihood__max_plate_nesting' : [1 , 2 ],
343
+ # this parameter does not exist but that's okay
344
+ 'likelihood__some_parameter' : [1 , 2 ],
339
345
}
340
346
gp .set_params (verbose = 0 )
341
347
gs = GridSearchCV (gp , params , refit = True , cv = 3 , scoring = self .scoring )
@@ -419,32 +425,29 @@ def test_multioutput_predict_proba(self, gp_multioutput, data):
419
425
])
420
426
def test_set_params_uninitialized_net_correct_message (
421
427
self , gp , kwargs , expected , capsys ):
422
- # When gp is initialized, if module or optimizer need to be
423
- # re-initialized, alert the user to the fact what parameters
424
- # were responsible for re-initialization. Note that when the
425
- # module parameters but not optimizer parameters were changed,
426
- # the optimizer is re-initialized but not because the
427
- # optimizer parameters changed.
428
+ # When gp is uninitialized, there is nothing to alert the user to
428
429
gp .set_params (** kwargs )
429
430
msg = capsys .readouterr ()[0 ].strip ()
430
431
assert msg == expected
431
432
432
433
@pytest .mark .parametrize ('kwargs,expected' , [
433
434
({}, "" ),
434
435
(
435
- {'likelihood__max_plate_nesting' : 2 },
436
+ # this parameter does not exist but that's okay
437
+ {'likelihood__some_parameter' : 2 },
436
438
("Re-initializing module because the following "
437
- "parameters were re-set: likelihood__max_plate_nesting .\n "
439
+ "parameters were re-set: likelihood__some_parameter .\n "
438
440
"Re-initializing criterion.\n "
439
441
"Re-initializing optimizer." )
440
442
),
441
443
(
442
444
{
443
- 'likelihood__max_plate_nesting' : 2 ,
445
+ # this parameter does not exist but that's okay
446
+ 'likelihood__some_parameter' : 2 ,
444
447
'optimizer__momentum' : 0.567 ,
445
448
},
446
449
("Re-initializing module because the following "
447
- "parameters were re-set: likelihood__max_plate_nesting .\n "
450
+ "parameters were re-set: likelihood__some_parameter .\n "
448
451
"Re-initializing criterion.\n "
449
452
"Re-initializing optimizer." )
450
453
),
@@ -570,23 +573,6 @@ def gp(self, gp_cls, module_cls):
570
573
)
571
574
return gpr
572
575
573
- # pickling and deepcopy work for ExactGPRegressor but not for the others, so
574
- # override the expected failures here.
575
-
576
- def test_pickling (self , gp_fit ):
577
- # does not raise
578
- pickle .dumps (gp_fit )
579
-
580
- def test_pickle_error_msg (self , gp_fit ):
581
- # Should eventually be replaced by a test that saves and loads the model
582
- # using pickle and checks that the predictions are identical
583
- # FIXME
584
- pickle .dumps (gp_fit )
585
-
586
- def test_deepcopy (self , gp_fit ):
587
- # FIXME
588
- copy .deepcopy (gp_fit ) # doesn't raise
589
-
590
576
def test_wrong_module_type_raises (self , gp_cls ):
591
577
# ExactGPRegressor requires the module to be an ExactGP, if it's not,
592
578
# raise an appropriate error message to the user.
@@ -649,6 +635,32 @@ def gp(self, gp_cls, module_cls, data):
649
635
assert gpr .batch_size < self .n_samples
650
636
return gpr
651
637
638
+ # Since GPyTorch v1.10, GPRegressor works with pickle/deepcopy.
639
+
640
+ def test_pickling (self , gp_fit , data ):
641
+ # TODO: remove once Python 3.7 is no longer supported
642
+ if version_gpytorch < Version ('1.10' ):
643
+ pytest .skip ("GPyTorch < 1.10 does not support pickling." )
644
+
645
+ loaded = pickle .loads (pickle .dumps (gp_fit ))
646
+ X , _ = data
647
+
648
+ y_pred_before = gp_fit .predict (X )
649
+ y_pred_after = loaded .predict (X )
650
+ assert np .allclose (y_pred_before , y_pred_after )
651
+
652
+ def test_deepcopy (self , gp_fit , data ):
653
+ # TODO: remove once Python 3.7 is no longer supported
654
+ if version_gpytorch < Version ('1.10' ):
655
+ pytest .skip ("GPyTorch < 1.10 does not support deepcopy." )
656
+
657
+ copied = copy .deepcopy (gp_fit )
658
+ X , _ = data
659
+
660
+ y_pred_before = gp_fit .predict (X )
661
+ y_pred_after = copied .predict (X )
662
+ assert np .allclose (y_pred_before , y_pred_after )
663
+
652
664
653
665
class TestGPBinaryClassifier (BaseProbabilisticTests ):
654
666
"""Tests for GPBinaryClassifier."""
@@ -686,11 +698,40 @@ def gp(self, gp_cls, module_cls, data):
686
698
gpc = gp_cls (
687
699
module_cls ,
688
700
module__inducing_points = torch .from_numpy (X [:10 ]),
689
-
701
+ likelihood = MyBernoulliLikelihood ,
690
702
criterion = gpytorch .mlls .VariationalELBO ,
691
703
criterion__num_data = int (0.8 * len (y )),
692
704
batch_size = 24 ,
693
705
)
694
706
# we want to make sure batching is properly tested
695
707
assert gpc .batch_size < self .n_samples
696
708
return gpc
709
+
710
+ # Since GPyTorch v1.10, GPBinaryClassifier is the only estimator left that
711
+ # still has issues with pickling/deepcopying.
712
+
713
+ @pytest .mark .xfail (strict = True )
714
+ def test_pickling (self , gp_fit , data ):
715
+ # Currently fails because of issues outside of our control, this test
716
+ # should alert us to when the issue has been fixed. Some issues have
717
+ # been fixed in https://github.com/cornellius-gp/gpytorch/pull/1336 but
718
+ # not all.
719
+ pickle .dumps (gp_fit )
720
+
721
+ def test_pickle_error_msg (self , gp_fit , data ):
722
+ # Should eventually be replaced by a test that saves and loads the model
723
+ # using pickle and checks that the predictions are identical
724
+ msg = ("This GPyTorch model cannot be pickled. The reason is probably this:"
725
+ " https://github.com/pytorch/pytorch/issues/38137. "
726
+ "Try using 'dill' instead of 'pickle'." )
727
+ with pytest .raises (pickle .PicklingError , match = msg ):
728
+ pickle .dumps (gp_fit )
729
+
730
+ def test_deepcopy (self , gp_fit , data ):
731
+ # Should eventually be replaced by a test that saves and loads the model
732
+ # using deepcopy and checks that the predictions are identical
733
+ msg = ("This GPyTorch model cannot be pickled. The reason is probably this:"
734
+ " https://github.com/pytorch/pytorch/issues/38137. "
735
+ "Try using 'dill' instead of 'pickle'." )
736
+ with pytest .raises (pickle .PicklingError , match = msg ):
737
+ copy .deepcopy (gp_fit ) # doesn't raise
0 commit comments