@@ -572,6 +572,44 @@ def test_save_load_state_dict_str(
572
572
score_after = accuracy_score (y , net .predict (X ))
573
573
assert np .isclose (score_after , score_before )
574
574
575
+ def test_save_load_state_dict_no_duplicate_registration_after_initialize (
576
+ self , net_cls , module_cls , net_fit , tmpdir ):
577
+ # #781
578
+ net = net_cls (module_cls ).initialize ()
579
+
580
+ p = tmpdir .mkdir ('skorch' ).join ('testmodel.pkl' )
581
+ with open (str (p ), 'wb' ) as f :
582
+ net_fit .save_params (f_params = f )
583
+ del net_fit
584
+
585
+ with open (str (p ), 'rb' ) as f :
586
+ net .load_params (f_params = f )
587
+
588
+ # check that there are no duplicates in _modules, _criteria, _optimizers
589
+ # pylint: disable=protected-access
590
+ assert net ._modules == ['module' ]
591
+ assert net ._criteria == ['criterion' ]
592
+ assert net ._optimizers == ['optimizer' ]
593
+
594
+ def test_save_load_state_dict_no_duplicate_registration_after_clone (
595
+ self , net_fit , tmpdir ):
596
+ # #781
597
+ net = clone (net_fit ).initialize ()
598
+
599
+ p = tmpdir .mkdir ('skorch' ).join ('testmodel.pkl' )
600
+ with open (str (p ), 'wb' ) as f :
601
+ net_fit .save_params (f_params = f )
602
+ del net_fit
603
+
604
+ with open (str (p ), 'rb' ) as f :
605
+ net .load_params (f_params = f )
606
+
607
+ # check that there are no duplicates in _modules, _criteria, _optimizers
608
+ # pylint: disable=protected-access
609
+ assert net ._modules == ['module' ]
610
+ assert net ._criteria == ['criterion' ]
611
+ assert net ._optimizers == ['optimizer' ]
612
+
575
613
@pytest .fixture (scope = 'module' )
576
614
def net_fit_adam (self , net_cls , module_cls , data ):
577
615
net = net_cls (
@@ -1426,6 +1464,15 @@ def test_get_params_works(self, net_cls, module_cls):
1426
1464
# now initialized
1427
1465
assert 'callbacks__myscore__scoring' in params
1428
1466
1467
+ def test_get_params_no_unwanted_params (self , net , net_fit ):
1468
+ # #781
1469
+ # make sure certain keys are not returned
1470
+ keys_unwanted = {'_modules' , '_criteria' , '_optimizers' }
1471
+ for net_ in (net , net_fit ):
1472
+ keys_found = set (net_ .get_params ())
1473
+ overlap = keys_found & keys_unwanted
1474
+ assert not overlap
1475
+
1429
1476
def test_get_params_with_uninit_callbacks (self , net_cls , module_cls ):
1430
1477
from skorch .callbacks import EpochTimer
1431
1478
0 commit comments