Skip to content

Commit ae43026

Browse files
authored
Allow var_names to be propagated to nutpie sampler (pymc-devs#7850)
1 parent 751f9a8 commit ae43026

File tree

3 files changed

+57
-7
lines changed

3 files changed

+57
-7
lines changed

conda-envs/environment-alternative-backends.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies:
1111
- cloudpickle
1212
- zarr>=2.5.0,<3
1313
- numba
14-
- nutpie >= 0.13.4
14+
- nutpie >= 0.15.1
1515
# Jaxlib version must not be greater than jax version!
1616
- blackjax>=1.2.2
1717
- jax>=0.4.28

pymc/sampling/mcmc.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -331,18 +331,15 @@ def _sample_external_nuts(
331331
"`idata_kwargs` are currently ignored by the nutpie sampler",
332332
UserWarning,
333333
)
334-
if var_names is not None:
335-
warnings.warn(
336-
"`var_names` are currently ignored by the nutpie sampler",
337-
UserWarning,
338-
)
334+
339335
compile_kwargs = {}
340336
nuts_sampler_kwargs = nuts_sampler_kwargs.copy()
341337
for kwarg in ("backend", "gradient_backend"):
342338
if kwarg in nuts_sampler_kwargs:
343339
compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)
344340
compiled_model = nutpie.compile_pymc_model(
345341
model,
342+
var_names=var_names,
346343
**compile_kwargs,
347344
)
348345
t_start = time.time()

tests/sampling/test_mcmc_external.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
import numpy as np
1616
import numpy.testing as npt
1717
import pytest
18+
import xarray as xr
1819

19-
from pymc import Data, Model, Normal, sample
20+
from pymc import Data, Deterministic, HalfNormal, Model, Normal, sample
2021

2122

2223
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
@@ -86,3 +87,55 @@ def test_step_args():
8687
)
8788

8889
npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
90+
91+
92+
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
93+
def test_sample_var_names(nuts_sampler):
94+
seed = 1234
95+
kwargs = {
96+
"chains": 1,
97+
"tune": 100,
98+
"draws": 100,
99+
"random_seed": seed,
100+
"progressbar": False,
101+
"compute_convergence_checks": False,
102+
}
103+
104+
# Generate data
105+
rng = np.random.default_rng(seed)
106+
107+
group = rng.choice(list("ABCD"), size=100)
108+
x = rng.normal(size=100)
109+
y = rng.normal(size=100)
110+
111+
group_values, group_idx = np.unique(group, return_inverse=True)
112+
113+
coords = {"group": group_values}
114+
115+
# Create model
116+
with Model(coords=coords) as model:
117+
b_group = Normal("b_group", dims="group")
118+
b_x = Normal("b_x")
119+
mu = Deterministic("mu", b_group[group_idx] + b_x * x)
120+
sigma = HalfNormal("sigma")
121+
Normal("y", mu=mu, sigma=sigma, observed=y)
122+
123+
free_RVs = [var.name for var in model.free_RVs]
124+
125+
with model:
126+
# Sample with and without var_names, but always with the same seed
127+
idata_1 = sample(nuts_sampler=nuts_sampler, **kwargs)
128+
# Remove the last free RV from the sampling
129+
idata_2 = sample(nuts_sampler=nuts_sampler, var_names=free_RVs[:-1], **kwargs)
130+
131+
assert "mu" in idata_1.posterior
132+
assert "mu" not in idata_2.posterior
133+
134+
assert free_RVs[-1] in idata_1.posterior
135+
assert free_RVs[-1] not in idata_2.posterior
136+
137+
for var in free_RVs[:-1]:
138+
assert var in idata_1.posterior
139+
assert var in idata_2.posterior
140+
141+
xr.testing.assert_allclose(idata_1.posterior[var], idata_2.posterior[var])

0 commit comments

Comments
 (0)