Skip to content

Commit 2e6a63d

Browse files
committed
Broadcast shapes of alpha and beta in Weibull rng
1 parent a74c03f commit 2e6a63d

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2499,6 +2499,8 @@ def __call__(self, alpha, beta, size=None, **kwargs):
24992499

25002500
@classmethod
25012501
def rng_fn(cls, rng, alpha, beta, size) -> np.ndarray:
2502+
if size is None:
2503+
size = np.broadcast_shapes(alpha.shape, beta.shape)
25022504
return np.asarray(beta * rng.weibull(alpha, size=size))
25032505

25042506

tests/distributions/test_continuous.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2369,6 +2369,25 @@ def seeded_weibul_rng_fn(self):
23692369
"check_rv_size",
23702370
]
23712371

2372+
# See issue #7220
2373+
def test_rng_different_shapes(self):
2374+
rng = np.random.default_rng(123)
2375+
2376+
# Simulate mean from an only-intercept model. 2 chains, 100 draws, 5 observations.
2377+
# So 'mu' is the same for all the observations (because it's intercept-only)
2378+
mu_draws = np.abs(150 + np.dstack([rng.normal(size=(2, 100, 1))] * 5))
2379+
2380+
# Simulate some alpha values
2381+
alpha_draws = np.abs(rng.normal(size=(2, 100, 1)))
2382+
2383+
# With 'mu' and 'alpha' get 'beta', which is what pm.Weibull needs
2384+
beta_draws = mu_draws / sp.gamma(1 + 1 / alpha_draws)
2385+
2386+
# See the draws, for a given chain and draw, they look all the same!
2387+
weibull_draws = pm.draw(pm.Weibull.dist(alpha=alpha_draws, beta=beta_draws))
2388+
2389+
assert not (weibull_draws == weibull_draws[:, :, 0][..., None]).all()
2390+
23722391

23732392
@pytest.mark.skipif(
23742393
condition=_polyagamma_not_installed,

0 commit comments

Comments
 (0)