Disconnected node in model graph after deterministic operations #7722

Open
williambdean opened this issue Mar 14, 2025 · 0 comments

The models that are now allowed following #7656 have a disconnected node in the model graph.

The sampling is as expected. It is just the graphviz representation that is incorrect.

import numpy as np
import pymc as pm
from pymc.model_graph import ModelGraph

seed = sum(map(ord, "Observed disconnected node"))
rng = np.random.default_rng(seed)

true_mu = 100
true_sigma = 30

n_obs = 10
coords = {
    "date": np.arange(n_obs),
}

dist = pm.Normal.dist(mu=true_mu, sigma=true_sigma, shape=n_obs)
data = pm.draw(dist, random_seed=rng)

scaling = data.max()

with pm.Model(coords=coords) as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma")

    target = pm.Data("target", data, dims="date")
    scaled_target = target / scaling

    pm.Normal("observed", mu=mu, sigma=sigma, observed=scaled_target, dims="date")

pm.model_to_graphviz(model).render("scaled_target")

ModelGraph(model).make_compute_graph()

The "observed" entry should have "target" in the compute graph. Instead, the result is:

defaultdict(set,
            {'mu': set(),
             'sigma': set(),
             'target': set(),
             'observed': {'mu', 'sigma'}})

Seems like it needs a fix here:

pymc/pymc/model_graph.py

Lines 322 to 343 in af81955

if var in self.model.observed_RVs:
    obs_node = self.model.rvs_to_values[var]
    # loop created so that the elif block can go through this again
    # and remove any intermediate ops, notably dtype casting, to observations
    while True:
        obs_name = obs_node.name
        if obs_name and obs_name != var_name:
            input_map[var_name] = input_map[var_name].difference({obs_name})
            input_map[obs_name] = input_map[obs_name].union({var_name})
            break
        elif (
            # for cases where observations are cast to a certain dtype
            # see issue 5795: https://github.com/pymc-devs/pymc/issues/5795
            obs_node.owner
            and isinstance(obs_node.owner.op, Elemwise)
            and isinstance(obs_node.owner.op.scalar_op, Cast)
        ):
            # we can retrieve the observation node by going up the graph
            obs_node = obs_node.owner.inputs[0]
        else:
            break
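
The loop above only walks up through Cast ops, so an unnamed node such as target / scaling hits the else branch and the named data node is never reached. One possible direction, sketched here with a toy graph rather than PyTensor, is to keep walking up through any unnamed intermediate node until a named one is found. The Node class and named_ancestors helper are hypothetical stand-ins, not part of PyMC, and how the result should be recorded in input_map is left open.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a graph variable: an optional name plus its inputs."""

    name: str | None = None
    inputs: list[Node] = field(default_factory=list)


def named_ancestors(node: Node) -> set[str]:
    """Return the nearest named nodes reachable by walking up through unnamed ones."""
    found: set[str] = set()
    stack = [node]
    while stack:
        current = stack.pop()
        if current.name:
            # Stop at the first named node on this path.
            found.add(current.name)
        else:
            # Unnamed intermediate (cast, division, ...): keep going up.
            stack.extend(current.inputs)
    return found


# Toy version of the model in this issue: scaled_target = target / scaling
target = Node(name="target")
scaling = Node()  # unnamed constant
scaled_target = Node(inputs=[target, scaling])  # unnamed result of the division

print(named_ancestors(scaled_target))
```

In the real graph the equivalent walk would presumably go through obs_node.owner.inputs, and it may need to be restricted to certain ops so that unrelated named variables are not picked up.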
