Skip to content

Commit f23b3b1

Browse files
authored
Improve warning for unpickable hyperparameter (#19581)
1 parent b871f7a commit f23b3b1

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

src/lightning/pytorch/utilities/parsing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ def clean_namespace(hparams: MutableMapping) -> None:
4141
del_attrs = [k for k, v in hparams.items() if not is_picklable(v)]
4242

4343
for k in del_attrs:
44-
rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
44+
rank_zero_warn(
45+
f"Attribute '{k}' removed from hparams because it cannot be pickled. You can suppress this warning by"
46+
f" setting `self.save_hyperparameters(ignore=['{k}'])`.",
47+
)
4548
del hparams[k]
4649

4750

tests/tests_pytorch/models/test_hparams.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -528,19 +528,26 @@ def test_hparams_pickle(tmpdir):
528528
class UnpickleableArgsBoringModel(BoringModel):
529529
"""A model that has an attribute that cannot be pickled."""
530530

531-
def __init__(self, foo="bar", pickle_me=(lambda x: x + 1), **kwargs):
531+
def __init__(self, foo="bar", pickle_me=(lambda x: x + 1), ignore=False, **kwargs):
532532
super().__init__(**kwargs)
533533
assert not is_picklable(pickle_me)
534-
self.save_hyperparameters()
534+
if ignore:
535+
self.save_hyperparameters(ignore=["pickle_me"])
536+
else:
537+
self.save_hyperparameters()
535538

536539

537540
def test_hparams_pickle_warning(tmpdir):
538541
model = UnpickleableArgsBoringModel()
539542
trainer = Trainer(default_root_dir=tmpdir, max_steps=1)
540-
with pytest.warns(UserWarning, match="attribute 'pickle_me' removed from hparams because it cannot be pickled"):
543+
with pytest.warns(UserWarning, match="Attribute 'pickle_me' removed from hparams because it cannot be pickled"):
541544
trainer.fit(model)
542545
assert "pickle_me" not in model.hparams
543546

547+
model = UnpickleableArgsBoringModel(ignore=True)
548+
with no_warning_call(UserWarning, match="Attribute 'pickle_me' removed from hparams because it cannot be pickled"):
549+
trainer.fit(model)
550+
544551

545552
def test_hparams_save_yaml(tmpdir):
546553
class Options(str, Enum):

0 commit comments

Comments
 (0)