Should branching logps accept constants #7711

Open · ricardoV94 opened this issue Mar 4, 2025 · 0 comments

ricardoV94 commented Mar 4, 2025

Description

The following example illustrates a restriction in the current logp derivations when a branch includes constants:

import pytensor.tensor as pt
import pymc as pm

t = pt.arange(10)
cat = pm.Categorical.dist(p=[0.5, 0.5], shape=(10,))
# cat_fixed = pt.where(t > 5, cat, -1)  # Not accepted because -1 is not measurable
cat_fixed = pt.where(t > 5, cat, pm.DiracDelta.dist(-1, shape=cat.shape))  # fine
pm.logp(cat_fixed, cat_fixed.type())

Should we allow it? This also applies to operations like join and make_vector, where one may combine measurable and constant inputs (sketched below).
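For concreteness, a minimal sketch of the same question for join, reusing the pattern from the example above; cat_part and joined are hypothetical names, and whether the DiracDelta form is derivable today is an assumption here rather than something verified in this issue:

import pytensor.tensor as pt
import pymc as pm

cat_part = pm.Categorical.dist(p=[0.5, 0.5], shape=(5,))
# joined = pt.join(0, cat_part, pt.full((5,), -1))  # constant input: presumably not accepted today
joined = pt.join(0, cat_part, pm.DiracDelta.dist(-1, shape=(5,)))  # DiracDelta wrapper instead
pm.logp(joined, joined.type())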

If we allow it, should we also allow broadcasting? Broadcasting is currently not allowed (hence the need for shape=cat.shape) because the logp of broadcast operations can be tricky to handle systematically, but for constants it may be fine?
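For illustration, the broadcast forms this question would allow look like the lines below; both are hypothetical and neither is accepted by the current derivations:

# Hypothetical if constant broadcasting were allowed (not accepted today):
# cat_fixed = pt.where(t > 5, cat, pm.DiracDelta.dist(-1))  # scalar DiracDelta broadcast against cat
# cat_fixed = pt.where(t > 5, cat, -1)                       # bare constant broadcast against cat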
