
Commit f79a34e

zoj613 authored and rlouf committed

Add InverseGamma distribution

1 parent 6487d25 · commit f79a34e

3 files changed: +171 −0 lines changed

mcx/distributions/__init__.py (+2)

@@ -9,6 +9,7 @@
 from .exponential import Exponential
 from .gamma import Gamma
 from .halfnormal import HalfNormal
+from .inverse_gamma import InverseGamma
 from .lognormal import LogNormal
 from .mvnormal import MvNormal
 from .normal import Normal
@@ -26,6 +27,7 @@
     "DiscreteUniform",
     "Exponential",
     "Gamma",
+    "InverseGamma",
     "LogNormal",
     "MvNormal",
     "HalfNormal",

mcx/distributions/inverse_gamma.py (+37, new file)

from jax import lax
from jax import numpy as jnp
from jax import random
from jax.scipy.special import gammaln

from mcx.distributions import constraints
from mcx.distributions.distribution import Distribution
from mcx.distributions.shapes import promote_shapes


class InverseGamma(Distribution):
    parameters = {
        "a": constraints.strictly_positive,
        "b": constraints.strictly_positive,
    }
    support = constraints.strictly_positive

    def __init__(self, a, b):
        self.event_shape = ()
        a, b = promote_shapes(a, b)
        batch_shape = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
        self.batch_shape = batch_shape
        self.a = jnp.broadcast_to(a, batch_shape)
        self.b = jnp.broadcast_to(b, batch_shape)

    def sample(self, rng_key, sample_shape=()):
        shape = sample_shape + self.batch_shape + self.event_shape
        # If X ~ Gamma(a, scale=1/b), then 1/X ~ Inverse-Gamma(a, scale=b)
        return self.b / random.gamma(rng_key, self.a, shape)

    @constraints.limit_to_support
    def logpdf(self, x):
        # We use the fact that f(x; a, b) = f(x/b; a, 1) / b to improve
        # numerical stability for small values of ``x`` that would otherwise
        # blow up the logp value if not rescaled.
        y = x / self.b
        return -(self.a + 1) * jnp.log(y) - gammaln(self.a) - 1.0 / y - jnp.log(self.b)
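
For reference, the rescaling identity used in logpdf follows directly from the inverse-gamma density with shape a and scale b:

\[
f(x; a, b) = \frac{b^{a}}{\Gamma(a)}\, x^{-(a+1)} e^{-b/x}, \qquad x > 0.
\]

Substituting y = x/b yields exactly the expression the method returns:

\[
\log f(x; a, b) = -(a + 1)\log y - \log\Gamma(a) - \frac{1}{y} - \log b, \qquad y = \frac{x}{b}.
\]

The moments checked by the tests below are the standard ones:

\[
\mathbb{E}[X] = \frac{b}{a-1} \;\; (a > 1), \qquad
\operatorname{Var}[X] = \frac{b^{2}}{(a-1)^{2}(a-2)} \;\; (a > 2).
\]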
tests for InverseGamma (+132, new file)

import numpy as np
import pytest
from jax import numpy as jnp
from jax import random

from mcx.distributions import InverseGamma


@pytest.fixture
def rng_key():
    return random.PRNGKey(123)


#
# SAMPLING CORRECTNESS
#


def invgamma_mean(a, b):
    # only defined for a > 1
    return b / (a - 1)


def invgamma_variance(a, b):
    # only defined for a > 2
    return (b ** 2) / ((a - 1) ** 2 * (a - 2))


sample_mean_cases = [
    {"a": 5.5, "b": 5, "expected": invgamma_mean(5.5, 5)},
    {"a": 15, "b": 2.0, "expected": invgamma_mean(15, 2.0)},
    {"a": 20, "b": 1.5, "expected": invgamma_mean(20, 1.5)},
]


@pytest.mark.parametrize("case", sample_mean_cases)
def test_sample_mean(rng_key, case):
    samples = InverseGamma(case["a"], case["b"]).sample(rng_key, (100_000,))
    avg = jnp.mean(samples, axis=0).item()
    np.testing.assert_almost_equal(avg, case["expected"], decimal=2)


sample_variance_cases = [
    {"a": 5.5, "b": 5, "expected": invgamma_variance(5.5, 5)},
    {"a": 15, "b": 2.0, "expected": invgamma_variance(15, 2.0)},
    {"a": 20, "b": 1.5, "expected": invgamma_variance(20, 1.5)},
]


@pytest.mark.parametrize("case", sample_variance_cases)
def test_sample_variance(rng_key, case):
    samples = InverseGamma(case["a"], case["b"]).sample(rng_key, (100_000,))
    var = jnp.var(samples, axis=0).item()
    np.testing.assert_almost_equal(var, case["expected"], decimal=2)


#
# LOGPDF SHAPES
#

expected_logpdf_shapes = [
    {
        "x": jnp.array([1]),
        "a": jnp.array([0]),
        "b": jnp.array([1]),
        "expected_shape": (1,),
    },
    {
        "x": jnp.array(1),
        "a": jnp.array(0),
        "b": jnp.array(1),
        "expected_shape": (),
    },
    {
        "x": jnp.ones((5,)),
        "a": jnp.array(0),
        "b": jnp.array(1),
        "expected_shape": (5,),
    },
    {
        "x": jnp.ones((8, 1)),
        "a": jnp.array([1, 1]),
        "b": jnp.array([2, 3]),
        "expected_shape": (8, 2),
    },
    {
        "x": jnp.array([1, 2, 3, 4]).reshape(4, 1),
        "a": jnp.array([1, 4, 10]),
        "b": jnp.array([3, 2, 1]),
        "expected_shape": (4, 3),
    },
    {
        "x": jnp.array(1),
        "a": jnp.array([1, 2]),
        "b": jnp.array([5]),
        "expected_shape": (2,),
    },
]


@pytest.mark.parametrize("case", expected_logpdf_shapes)
def test_logpdf_shape(case):
    log_prob = InverseGamma(a=case["a"], b=case["b"]).logpdf(case["x"])
    assert log_prob.shape == case["expected_shape"]


#
# SAMPLING SHAPES
#


@pytest.mark.parametrize(
    ["a", "b", "sample_shape", "expected_shape"],
    [
        # 5 1-d samples
        [jnp.array(1), jnp.array(1), (5,), (5,)],
        # 5 samples from 2 inverse-gamma distributions
        [jnp.array([1, 2]), jnp.array([1, 1.5]), (5,), (5, 2)],
        [jnp.array([1, 2]), jnp.array([1, 2]), (5, 2), (5, 2, 2)],
        # 10 samples from 4 inverse-gamma distributions
        [jnp.array([1, 2, 3, 4]), jnp.array([1, 2, 5, 10]), (10,), (10, 4)],
        # a (5, 2) array of samples from a 2x2 batch of inverse-gammas
        [
            jnp.array([[1, 2], [5, 10]]),
            jnp.array([[1, 2], [4, 6]]),
            (5, 2),
            (5, 2, 2, 2),
        ],
    ],
)
def test_sampling_shape(a, b, sample_shape, expected_shape, rng_key):
    assert InverseGamma(a=a, b=b).sample(rng_key, sample_shape).shape == expected_shape
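
The tests above cover sampled moments and broadcasting shapes. A further check one could add (not part of this commit) is to compare logpdf against scipy.stats.invgamma, which uses the same shape/scale parameterization; a hypothetical sketch:

import numpy as np
from jax import numpy as jnp
from scipy import stats

from mcx.distributions import InverseGamma


def test_logpdf_matches_scipy():
    # Hypothetical extra test: SciPy's invgamma(a, scale=b) has the same
    # density as InverseGamma(a, b), so the log-densities should agree up
    # to float32 precision.
    a, b = 3.0, 2.0
    x = jnp.linspace(0.1, 10.0, 50)
    expected = stats.invgamma.logpdf(np.asarray(x), a, scale=b)
    actual = InverseGamma(a, b).logpdf(x)
    np.testing.assert_allclose(np.asarray(actual), expected, rtol=1e-4, atol=1e-5)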
