@@ -528,19 +528,26 @@ def test_hparams_pickle(tmpdir):
528
528
class UnpickleableArgsBoringModel (BoringModel ):
529
529
"""A model that has an attribute that cannot be pickled."""
530
530
531
- def __init__ (self , foo = "bar" , pickle_me = (lambda x : x + 1 ), ** kwargs ):
531
+ def __init__ (self , foo = "bar" , pickle_me = (lambda x : x + 1 ), ignore = False , ** kwargs ):
532
532
super ().__init__ (** kwargs )
533
533
assert not is_picklable (pickle_me )
534
- self .save_hyperparameters ()
534
+ if ignore :
535
+ self .save_hyperparameters (ignore = ["pickle_me" ])
536
+ else :
537
+ self .save_hyperparameters ()
535
538
536
539
537
540
def test_hparams_pickle_warning (tmpdir ):
538
541
model = UnpickleableArgsBoringModel ()
539
542
trainer = Trainer (default_root_dir = tmpdir , max_steps = 1 )
540
- with pytest .warns (UserWarning , match = "attribute 'pickle_me' removed from hparams because it cannot be pickled" ):
543
+ with pytest .warns (UserWarning , match = "Attribute 'pickle_me' removed from hparams because it cannot be pickled" ):
541
544
trainer .fit (model )
542
545
assert "pickle_me" not in model .hparams
543
546
547
+ model = UnpickleableArgsBoringModel (ignore = True )
548
+ with no_warning_call (UserWarning , match = "Attribute 'pickle_me' removed from hparams because it cannot be pickled" ):
549
+ trainer .fit (model )
550
+
544
551
545
552
def test_hparams_save_yaml (tmpdir ):
546
553
class Options (str , Enum ):
0 commit comments