|
15 | 15 | import numpy as np
|
16 | 16 | import numpy.testing as npt
|
17 | 17 | import pytest
|
| 18 | +import xarray as xr |
18 | 19 |
|
19 |
| -from pymc import Data, Model, Normal, sample |
| 20 | +from pymc import Data, Deterministic, HalfNormal, Model, Normal, sample |
20 | 21 |
|
21 | 22 |
|
22 | 23 | @pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
|
@@ -86,3 +87,55 @@ def test_step_args():
|
86 | 87 | )
|
87 | 88 |
|
88 | 89 | npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
|
| 90 | + |
| 91 | + |
| 92 | +@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) |
| 93 | +def test_sample_var_names(nuts_sampler): |
| 94 | + seed = 1234 |
| 95 | + kwargs = { |
| 96 | + "chains": 1, |
| 97 | + "tune": 100, |
| 98 | + "draws": 100, |
| 99 | + "random_seed": seed, |
| 100 | + "progressbar": False, |
| 101 | + "compute_convergence_checks": False, |
| 102 | + } |
| 103 | + |
| 104 | + # Generate data |
| 105 | + rng = np.random.default_rng(seed) |
| 106 | + |
| 107 | + group = rng.choice(list("ABCD"), size=100) |
| 108 | + x = rng.normal(size=100) |
| 109 | + y = rng.normal(size=100) |
| 110 | + |
| 111 | + group_values, group_idx = np.unique(group, return_inverse=True) |
| 112 | + |
| 113 | + coords = {"group": group_values} |
| 114 | + |
| 115 | + # Create model |
| 116 | + with Model(coords=coords) as model: |
| 117 | + b_group = Normal("b_group", dims="group") |
| 118 | + b_x = Normal("b_x") |
| 119 | + mu = Deterministic("mu", b_group[group_idx] + b_x * x) |
| 120 | + sigma = HalfNormal("sigma") |
| 121 | + Normal("y", mu=mu, sigma=sigma, observed=y) |
| 122 | + |
| 123 | + free_RVs = [var.name for var in model.free_RVs] |
| 124 | + |
| 125 | + with model: |
| 126 | + # Sample with and without var_names, but always with the same seed |
| 127 | + idata_1 = sample(nuts_sampler=nuts_sampler, **kwargs) |
| 128 | + # Remove the last free RV from the sampling |
| 129 | + idata_2 = sample(nuts_sampler=nuts_sampler, var_names=free_RVs[:-1], **kwargs) |
| 130 | + |
| 131 | + assert "mu" in idata_1.posterior |
| 132 | + assert "mu" not in idata_2.posterior |
| 133 | + |
| 134 | + assert free_RVs[-1] in idata_1.posterior |
| 135 | + assert free_RVs[-1] not in idata_2.posterior |
| 136 | + |
| 137 | + for var in free_RVs[:-1]: |
| 138 | + assert var in idata_1.posterior |
| 139 | + assert var in idata_2.posterior |
| 140 | + |
| 141 | + xr.testing.assert_allclose(idata_1.posterior[var], idata_2.posterior[var]) |
0 commit comments