Serializing an easyguide corrupts it #3430

Closed
BenZickel opened this issue Apr 9, 2025 · 0 comments
Issue Description

Serializing an easyguide corrupts the original guide object: the guide that was passed to torch.save fails when it is run afterwards, while the deserialized copy works.

Environment

  • OS: WSL
  • Python version: 3.11.7
  • PyTorch version: 2.6.0+cu124
  • Pyro version: 1.9.1+455f7b3b

Code Snippet

The issue can be replicated by running the easyguide serialization test:

pyro-ppl$ pytest tests/contrib/easyguide/test_easyguide.py::test_serialize
=============================================================================== test session starts ===============================================================================
platform linux -- Python 3.11.7, pytest-8.3.5, pluggy-1.5.0
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /mnt/c/sw/pyro-ppl
configfile: setup.cfg
plugins: nbval-0.11.0, benchmark-4.0.0, cov-6.1.1, xdist-3.6.1, anyio-4.9.0
collected 1 item

tests/contrib/easyguide/test_easyguide.py F                                                                                                                                 [100%]

==================================================================================== FAILURES =====================================================================================
_________________________________________________________________________________ test_serialize __________________________________________________________________________________

    def test_serialize():
        guide = PickleGuide(model)
        check_guide(guide)

        # Work around https://github.com/pytorch/pytorch/issues/27972
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            f = io.BytesIO()
            torch.save(guide, f)
            f.seek(0)
            actual = torch.load(f, weights_only=False)

        assert type(actual) == type(guide)
        assert dir(actual) == dir(guide)
>       check_guide(guide)

tests/contrib/easyguide/test_easyguide.py:96:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/contrib/easyguide/test_easyguide.py:54: in check_guide
    svi.step(batch, subsample, full_size=full_size)
pyro/infer/svi.py:145: in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
pyro/infer/trace_elbo.py:140: in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
pyro/infer/elbo.py:239: in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
pyro/infer/trace_elbo.py:57: in _get_trace
    model_trace, guide_trace = get_importance_trace(
pyro/infer/enum.py:60: in get_importance_trace
    guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(
pyro/poutine/trace_messenger.py:216: in get_trace
    self(*args, **kwargs)
pyro/poutine/trace_messenger.py:191: in __call__
    ret = self.fn(*args, **kwargs)
pyro/nn/module.py:527: in __call__
    result = super().__call__(*args, **kwargs)
/home/ben/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1739: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/home/ben/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1750: in _call_impl
    return forward_call(*args, **kwargs)
pyro/contrib/easyguide/easyguide.py:104: in forward
    result = self.guide(*args, **kwargs)
tests/contrib/easyguide/test_easyguide.py:79: in guide
    self.group(match="state_[0-9]*").map_estimate()
pyro/contrib/easyguide/easyguide.py:312: in map_estimate
    return {
pyro/contrib/easyguide/easyguide.py:313: in <dictcomp>
    site["name"]: self.guide.map_estimate(site["name"])
pyro/contrib/easyguide/easyguide.py:243: in guide
    return self._guide()
pyro/nn/module.py:527: in __call__
    result = super().__call__(*args, **kwargs)
/home/ben/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1739: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/home/ben/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1750: in _call_impl
    return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = PickleGuide(), args = (), kwargs = {}

    def forward(self, *args, **kwargs):
        """
        Runs the guide. This is typically used by inference algorithms.

        .. note:: This method is used internally by :class:`~torch.nn.Module`.
            Users should instead use :meth:`~torch.nn.Module.__call__`.
        """
        if self.prototype_trace is None:
            self._setup_prototype(*args, **kwargs)
>       result = self.guide(*args, **kwargs)
E       TypeError: PickleGuide.guide() missing 3 required positional arguments: 'batch', 'subsample', and 'full_size'

pyro/contrib/easyguide/easyguide.py:104: TypeError
============================================================================= short test summary info =============================================================================
FAILED tests/contrib/easyguide/test_easyguide.py::test_serialize - TypeError: PickleGuide.guide() missing 3 required positional arguments: 'batch', 'subsample', and 'full_size'
=============================================================================== 1 failed in 19.37s ================================================================================
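
In the traceback above, `self._guide()` at easyguide.py:243 ends up invoking the guide's `forward` with no arguments. That is what one would expect if a weak reference to the guide had been swapped for the guide object itself. One way this can happen (an assumption; the issue does not state the root cause) is a `__getstate__` that edits the live `__dict__` instead of a copy. The sketch below is pure Python with hypothetical names (`Owner`, `AliasingHelper`, `CopyingHelper`, `survives_pickling`); it is not Pyro's code.

```python
import pickle
import weakref


class AliasingHelper:
    """Hypothetical helper that keeps a weak reference to its owner."""

    def __init__(self, owner):
        self._owner = weakref.ref(owner)

    @property
    def owner(self):
        # Dereference the weakref; this breaks if _owner is no longer a weakref.
        return self._owner()

    def __getstate__(self):
        state = self.__dict__  # aliases the live dict: edits hit the original
        state["_owner"] = state["_owner"]()  # weakref -> strong ref for pickling
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._owner = weakref.ref(self._owner)  # strong ref -> weakref


class CopyingHelper(AliasingHelper):
    """Same helper, but __getstate__ edits a shallow copy of the state."""

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_owner"] = state["_owner"]()
        return state


class Owner:
    """Hypothetical stand-in for a callable guide that owns a helper."""

    def __init__(self, helper_cls):
        self.helper = helper_cls(self)

    def __call__(self, batch):
        return batch


def survives_pickling(helper_cls):
    """Return True if the original object still works after being pickled."""
    original = Owner(helper_cls)
    restored = pickle.loads(pickle.dumps(original))
    assert restored.helper.owner is restored  # the copy is always healthy
    try:
        return original.helper.owner is original
    except TypeError:
        # _owner now holds the Owner itself, so "dereferencing" calls it
        # without its required argument.
        return False
```

Under these assumptions, `survives_pickling(AliasingHelper)` returns False and `survives_pickling(CopyingHelper)` returns True, mirroring the reported symptom: the pickled original breaks with a missing-argument TypeError while the restored copy keeps working.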
@fritzo fritzo added the bug label Apr 9, 2025
BenZickel added a commit to BenZickel/pyro that referenced this issue Apr 9, 2025
fritzo pushed a commit that referenced this issue Apr 10, 2025
* Update CI Python version to 3.12.

* Try Python 3.11 as there is some problem with ruff.

* Go back to Python version 3.12 after fixing formatting to comply with latest versions of black and ruff.

* Try going back to Python 3.11.

* Fix to comply with PEP 612 (see microsoft/pyright#5844 (comment)).

* Fix typo

* Go back to previous runtime.py and ignore mypy PEP 612 check.

* Back to Python 3.12.

* Back to Python 3.11.

* Update Sphinx version to 0.5.0.

* Update Sphinx version to 8.2.3.

* Try removing Sphinx version in order to solve dependency conflict.

* List all Python modules for CI debug.

* Try with Python 3.12 due to jaraco/path#231.

* Unpin setuptools version.

* Updated Sphinx version to 7.3.7.

* Fix Sphinx issue (works on local machine).

* Remove unused import.

* Try different Sphinx versions.

* Use Sphinx upgrade path.

* Convert random seed to int data type.

* Update levy-stable kstest.

* Mark easyguide/test_easyguide.py::test_serialize as an expected failure according to the bug mentioned at #3430.
BenZickel added a commit to BenZickel/pyro that referenced this issue Apr 14, 2025
@fritzo fritzo closed this as completed in 147b357 Apr 14, 2025