Commit af1adc9

sidravi1 authored and rlouf committed
added real-time Rhat
allows online metrics to be passed to sample_loop
1 parent f79a34e commit af1adc9
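
For reference, the diagnostic added in mcx/diagnostics/gelman_rubin.py below works as follows: after n draws, with W_n the within-chain variance (the mean over chains of Welford's running per-chain variances) and B_n the between-chain variance (the variance of the per-chain running means, ddof=1), the update computes, per dimension,

    \hat{V}_n = \frac{n - 1}{n} W_n + B_n, \qquad \hat{R}_n = \sqrt{\hat{V}_n / W_n},

and the progress bar reports the component whose \hat{R}_n lies farthest from 1 as `worst_rhat`.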

6 files changed, +197 -27 lines changed

mcx/diagnostics/__init__.py (+5)

@@ -0,0 +1,5 @@
+from .gelman_rubin import online_gelman_rubin
+
+__all__ = [
+    "online_gelman_rubin",
+]

mcx/diagnostics/gelman_rubin.py (+138 -15)

@@ -1,42 +1,165 @@
 """Kernel to compute the Gelman-Rubin convergence diagnostic (Rhat) online.
 """
-from typing import NamedTuple
+from typing import Callable, NamedTuple, Tuple

+import jax
 import jax.numpy as jnp

-from mcx.inference.warmup.mass_matrix_adaptation import (
-    WelfordAlgorithmState,
-    welford_algorithm,
-)
+
+class WelfordAlgorithmState(NamedTuple):
+    """State carried through the Welford algorithm.
+
+    mean
+        The running sample mean.
+    m2
+        The running value of the sum of difference of squares. See documentation
+        of the `welford_algorithm` function for an explanation.
+    sample_size
+        The number of successive states the previous values have been computed on;
+        also the current number of iterations of the algorithm.
+    """
+
+    mean: float
+    m2: float
+    sample_size: int


 class GelmanRubinState(NamedTuple):
     w_state: WelfordAlgorithmState
-    rhat: float
+    rhat: jnp.DeviceArray
+    metric: jnp.DeviceArray
+    metric_name: str
+
+
+def welford_algorithm(is_diagonal_matrix: bool) -> Tuple[Callable, Callable, Callable]:
+    """Welford's online estimator of covariance.
+
+    It is possible to compute the variance of a population of values in an
+    online fashion to avoid storing intermediate results. The naive recurrence
+    relations between the sample mean and variance at one step and the next are
+    however not numerically stable.
+
+    Welford's algorithm instead updates the sum of squared differences
+    :math:`M_{2,n} = \\sum_{i=1}^n \\left(x_i-\\overline{x}_n\\right)^2`,
+    where :math:`\\overline{x}_n` is the sample mean after :math:`n` values,
+    using the recurrence relations
+
+    .. math::
+        M_{2,n} = M_{2, n-1} + (x_n-\\overline{x}_{n-1})(x_n-\\overline{x}_n)
+
+        \\sigma_n^2 = \\frac{M_{2,n}}{n}
+
+    Parameters
+    ----------
+    is_diagonal_matrix
+        When True the algorithm adapts and returns a diagonal mass matrix
+        (default), otherwise adapts and returns a dense mass matrix.
+    """
+
+    def init(n_chains: int, n_dims: int) -> WelfordAlgorithmState:
+        """Initialize the covariance estimation.
+
+        When the matrix is diagonal it is sufficient to work with an array that
+        contains the diagonal values. Otherwise we need to work with the matrix
+        in full.
+
+        Parameters
+        ----------
+        n_chains: int
+            The number of chains being run
+        n_dims: int
+            The number of variables
+        """
+        sample_size = 0
+        mean = jnp.zeros((n_chains, n_dims))
+        if is_diagonal_matrix:
+            m2 = jnp.zeros((n_chains, n_dims))
+        else:
+            m2 = jnp.zeros((n_chains, n_chains, n_dims))
+        return WelfordAlgorithmState(mean, m2, sample_size)
+
+    @jax.jit
+    def update(
+        state: WelfordAlgorithmState, value: jnp.DeviceArray
+    ) -> WelfordAlgorithmState:
+        """Update the M2 matrix using the new value.
+
+        Parameters
+        ----------
+        state: WelfordAlgorithmState
+            The current state of the Welford Algorithm
+        value: jax.numpy.DeviceArray
+            The new sample (typically the positions of the chains) used to update m2
+        """
+        mean, m2, sample_size = state
+        sample_size = sample_size + 1
+
+        delta = value - mean
+        mean = mean + delta / sample_size
+        updated_delta = value - mean
+        if is_diagonal_matrix:
+            new_m2 = m2 + delta * updated_delta
+        else:
+            new_m2 = m2 + jnp.outer(updated_delta, delta)
+
+        return WelfordAlgorithmState(mean, new_m2, sample_size)
+
+    def covariance(
+        state: WelfordAlgorithmState,
+    ) -> Tuple[jnp.DeviceArray, int, jnp.DeviceArray]:
+        mean, m2, sample_size = state
+        covariance = m2 / (sample_size - 1)
+        return covariance, sample_size, mean
+
+    return init, update, covariance


 def online_gelman_rubin():
     """Online estimation of the Gelman-Rubin diagnostic."""

     w_init, w_update, w_covariance = welford_algorithm(True)

-    def init(num_chains):
-        w_state = w_init(num_chains)
-        return GelmanRubinState(w_state, 0)
+    def init(init_state):
+        """Initialise the online Gelman-Rubin estimator.
+
+        Parameters
+        ----------
+        init_state
+            The initial state of the chains; the shape of its position
+            determines the number of chains and of variables.
+
+        Returns
+        -------
+        An initial GelmanRubinState (rhat set to zero, metric to NaN).
+
+        """
+        n_chains, n_dims = init_state.position.shape
+        w_state = w_init(n_chains, n_dims)
+        return GelmanRubinState(w_state, 0, jnp.nan, "worst_rhat")

     def update(chain_state, rhat_state):
-        within_state, step, num_chains, _, _, _ = rhat_state
+        """Update the Rhat estimates.
+
+        Parameters
+        ----------
+        chain_state: HMCState
+            The chain state
+        rhat_state: GelmanRubinState
+            The GelmanRubinState from the previous draw
+
+        Returns
+        -------
+        An updated GelmanRubinState object
+        """
+        within_state, _, _, metric_name = rhat_state

         positions = chain_state.position
         within_state = w_update(within_state, positions)
-
-        covariance, step, mean = w_covariance(rhat_state)
-        within_var = jnp.mean(covariance)
-        between_var = jnp.var(mean, ddof=1)
+        covariance, step, mean = w_covariance(within_state)
+        within_var = jnp.mean(covariance, axis=0)
+        between_var = jnp.var(mean, axis=0, ddof=1)
         estimator = ((step - 1) / step) * within_var + between_var
         rhat = jnp.sqrt(estimator / within_var)
+        worst_rhat = rhat[jnp.argmax(jnp.abs(rhat - 1.0))]

-        return GelmanRubinState(within_state, rhat)
+        return GelmanRubinState(within_state, rhat, worst_rhat, metric_name)

     return init, update
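A minimal sketch (not part of the commit) of how the returned `(init, update)` pair can be driven by hand. `FakeState` is a hypothetical stand-in for the sampler's chain state: `init` and `update` only read its `.position` attribute, an array of shape (n_chains, n_dims).

    from typing import NamedTuple

    import jax
    import jax.numpy as jnp

    from mcx.diagnostics import online_gelman_rubin


    class FakeState(NamedTuple):
        position: jnp.DeviceArray


    n_chains, n_dims, n_draws = 4, 2, 500
    r_init, r_update = online_gelman_rubin()

    # Initialise from the starting positions of the chains.
    rhat_state = r_init(FakeState(jnp.zeros((n_chains, n_dims))))

    # Feed i.i.d. normal "draws"; since all chains sample the same
    # distribution, worst_rhat should settle close to 1.
    for key in jax.random.split(jax.random.PRNGKey(0), n_draws):
        state = FakeState(jax.random.normal(key, (n_chains, n_dims)))
        rhat_state = r_update(state, rhat_state)

    print(rhat_state.metric_name, float(rhat_state.metric))

Note that the very first update divides by `sample_size - 1 = 0`, so the estimate is NaN until a second draw arrives.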

mcx/distributions/mvnormal.py (+3 -3)

@@ -39,9 +39,9 @@ def __init__(self, mu, covariance_matrix):
         if (mu_event_shape, mu_event_shape) != covariance_event_shape:
             raise ValueError(
                 (
-                    f"The number of dimensions implied by `mu` ({mu_event_shape}),"
-                    "does not match the dimensions implied by `covariance_matrix` "
-                    f"({covariance_event_shape})"
+                    f"The number of dimensions implied by `mu` (dims = {mu_event_shape}) "
+                    "does not match the dimensions implied by `covariance_matrix` "
+                    f"(dims = {covariance_event_shape})"
                 )
             )

mcx/sample.py (+47 -5)

@@ -1,5 +1,5 @@
 """Sample from the multivariate distribution defined by the model."""
-from typing import Any, Callable, Dict, Iterator, Optional, Tuple
+from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple

 import jax
 import jax.numpy as jnp
@@ -8,6 +8,7 @@
 from tqdm import tqdm

 import mcx
+from mcx.diagnostics import online_gelman_rubin
 from mcx.jax import progress_bar_factory
 from mcx.jax import ravel_pytree as mcx_ravel_pytree
 from mcx.trace import Trace
@@ -321,6 +322,9 @@ def run(
         num_samples: int = 1000,
         num_warmup_steps: int = 1000,
         compile: bool = False,
+        metrics: Sequence[Callable[..., Tuple[Callable, Callable]]] = [
+            online_gelman_rubin
+        ],
         **warmup_kwargs,
     ) -> Trace:
         """Run the posterior inference.
@@ -336,9 +340,13 @@ def run(
         num_warmup_steps
             The number of warmup_steps to perform.
         compile
-            If False the progress of the warmup and samplig will be displayed.
+            If False the progress of the warmup and sampling will be displayed.
             Otherwise it will use `lax.scan` to iterate (which is potentially
             faster).
+        metrics
+            A list of functions to generate online metrics when sampling. Only
+            used when `compile` is False. Each function must return two
+            functions: an `init` function and an `update` function.
         warmup_kwargs
             Parameters to pass to the evaluator's warmup.
@@ -349,6 +357,10 @@
         the inference process (e.g. divergences for evaluators in the
         HMC family).

+        Notes
+        -----
+        Passing functions to `metrics` may slow down sampling. It can nevertheless
+        be useful to have online metrics when building or diagnosing a model.
         """
         if not self.is_warmed_up:
             self.warmup(num_warmup_steps, compile, **warmup_kwargs)
@@ -373,8 +385,15 @@ def update_one_chain(rng_key, parameters, chain_state):
                 update_one_chain, self.state, self.parameters, rng_keys, self.num_chains
             )
         else:
+            if metrics is None:
+                metrics = ()
             last_state, chain = sample_loop(
-                update_one_chain, self.state, self.parameters, rng_keys, self.num_chains
+                update_one_chain,
+                self.state,
+                self.parameters,
+                rng_keys,
+                self.num_chains,
+                metrics,
             )

         samples, sampling_info = self.evaluator.make_trace(
@@ -466,6 +485,7 @@ def sample_loop(
     parameters: jnp.DeviceArray,
     rng_keys: jnp.DeviceArray,
     num_chains: int,
+    metrics: Sequence[Callable[..., Tuple[Callable, Callable]]],
 ) -> Tuple:
     """Sample using a Python loop.
@@ -507,8 +527,12 @@
        The parameters of the evaluator.
    rng_keys: array (n_samples,)
        JAX PRNGKeys used for each sampling step.
-   num_chains
+   num_chains : int
        The number of chains
+   metrics
+       A list of functions to generate real-time metrics when sampling.
+       Each function must return two functions: an `init` function and
+       an `update` function.

    Returns
    -------
@@ -531,7 +555,15 @@ def get_unravel_fn():

     _, unravel_fn = get_unravel_fn()

-    with tqdm(rng_keys, unit="samples") as progress:
+    metrics_init, metrics_update = [], []
+    for metric_func in metrics:
+        init_func, update_func = metric_func()
+        metrics_init.append(init_func)
+        metrics_update.append(update_func)
+
+    metrics_state = [init_func(init_state) for init_func in metrics_init]
+
+    with tqdm(rng_keys, unit="samples", mininterval=0.1) as progress:
         progress.set_description(
             f"Collecting {num_samples:,} samples across {num_chains:,} chains",
             refresh=False,
@@ -540,7 +572,17 @@
         state = init_state
         try:
             for _, key in enumerate(progress):
+                metrics_state = [
+                    update_func(state, m_state)
+                    for update_func, m_state in zip(metrics_update, metrics_state)
+                ]
                 state, _, ravelled_state = update_loop(state, key)
+                postfix_dict = {
+                    m_state.metric_name: f"{m_state.metric:0.2f}"
+                    for m_state in metrics_state
+                }
+                if postfix_dict:
+                    progress.set_postfix(postfix_dict)
                 chain.append(ravelled_state)
         except KeyboardInterrupt:
             pass
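
The hook is deliberately generic: anything passed to `metrics` just has to be a zero-argument callable returning an `(init, update)` pair, and the state that `update` carries must expose `metric` and `metric_name` attributes, which `sample_loop` reads to build the tqdm postfix. Below is a minimal sketch of a hypothetical custom metric (not part of the commit) tracking the running mean of the absolute position; it relies only on `chain_state.position`, the same attribute the Rhat estimator uses.

    from typing import NamedTuple

    import jax.numpy as jnp


    class MeanAbsPositionState(NamedTuple):
        total: jnp.DeviceArray
        count: int
        metric: jnp.DeviceArray
        metric_name: str


    def online_mean_abs_position():
        """Running mean of |position|, averaged over chains and dimensions."""

        def init(init_state):
            return MeanAbsPositionState(jnp.zeros(()), 0, jnp.nan, "mean_abs_pos")

        def update(chain_state, metric_state):
            total, count, _, name = metric_state
            total = total + jnp.mean(jnp.abs(chain_state.position))
            count = count + 1
            return MeanAbsPositionState(total, count, total / count, name)

        return init, update

It would then be passed alongside the default, e.g. `sampler.run(num_samples=1000, metrics=[online_gelman_rubin, online_mean_abs_position])`.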

mcx/trace.py (+1 -1)

@@ -85,7 +85,7 @@ class Trace(InferenceData):
     The integration with ArviZ is seamless: MCX traces can be passed to ArviZ's
     diagnostics, statistics and plotting functions.

-    >>> import arvix as az
+    >>> import arviz as az
     >>> az.plot_trace(trace)

tests/hmc_test.py (+3 -3)

@@ -24,8 +24,8 @@ def linear_regression(x, lmbda=1.0):
 def linear_regression_mvn(x, lmbda=1.):
     sigma <~ dist.Exponential(lmbda)
     sigma2 <~ dist.Exponential(lmbda)
-    rho <~ dist.Uniform(0, 1)
-    cov = jnp.array([[sigma, rho*sigma*sigma2],[rho*sigma*sigma2, sigma2]])
+    rho <~ dist.Uniform(-1, 1)
+    cov = jnp.array([[sigma**2, rho*sigma*sigma2],[rho*sigma*sigma2, sigma2**2]])
     coeffs <~ dist.MvNormal(jnp.ones(x.shape[-1]), cov)
     y = jnp.dot(x, coeffs)
     predictions <~ dist.Normal(y, sigma)
@@ -70,7 +70,7 @@ def test_linear_regression_mvn():
     y_data = x_data @ np.array([3, 1]) + np.random.normal(size=x_data.shape[0])

     kernel = HMC(
-        num_integration_steps=90,
+        num_integration_steps=10,
     )

     rng_key = jax.random.PRNGKey(2)
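
The corrected test model uses the mathematically sound parameterisation: a 2x2 covariance matrix carries the variances sigma**2 and sigma2**2 on its diagonal, and with a correlation rho drawn from (-1, 1) rather than (0, 1) it stays symmetric positive definite for all admissible values. A quick standalone check (not part of the commit):

    import jax.numpy as jnp

    sigma, sigma2, rho = 1.3, 0.7, -0.4
    cov = jnp.array(
        [
            [sigma**2, rho * sigma * sigma2],
            [rho * sigma * sigma2, sigma2**2],
        ]
    )

    # The determinant is (sigma * sigma2)**2 * (1 - rho**2) > 0 for |rho| < 1,
    # so both eigenvalues are positive.
    print(jnp.linalg.eigvalsh(cov))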
