Skip to content

Commit 6866bf5

Browse files
committed
Fixes pyro-ppl#3430 (serializing an easyguide corrupts it).
1 parent e1293e6 commit 6866bf5

File tree

2 files changed

+1
-2
lines changed

2 files changed

+1
-2
lines changed

pyro/contrib/easyguide/easyguide.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def __init__(self, guide, sites):
230230
self.event_shape = torch.Size([sum(self._site_sizes.values())])
231231

232232
def __getstate__(self):
233-
state = getattr(super(), "__getstate__", self.__dict__.copy)()
233+
state = getattr(super(), "__getstate__", lambda: self.__dict__)().copy()
234234
state["_guide"] = state["_guide"]() # weakref -> ref
235235
return state
236236

tests/contrib/easyguide/test_easyguide.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ def guide(self, batch, subsample, full_size):
7979
self.group(match="state_[0-9]*").map_estimate()
8080

8181

82-
@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3430")
8382
def test_serialize():
8483
guide = PickleGuide(model)
8584
check_guide(guide)

0 commit comments

Comments
 (0)