|
1 | 1 | """Kernel to compute the Gelman-Rubin convergence diagnostic (Rhat) online.
|
2 | 2 | """
|
3 |
| -from typing import NamedTuple |
| 3 | +from typing import Callable, NamedTuple, Tuple |
4 | 4 |
|
| 5 | +import jax |
5 | 6 | import jax.numpy as jnp
|
6 | 7 |
|
7 |
| -from mcx.inference.warmup.mass_matrix_adaptation import ( |
8 |
| - WelfordAlgorithmState, |
9 |
| - welford_algorithm, |
10 |
| -) |
| 8 | + |
class WelfordAlgorithmState(NamedTuple):
    """State carried through the Welford algorithm.

    mean
        The running sample mean; `init` allocates it with shape
        ``(n_chains, n_dims)``, one row per chain.
    m2
        The running sum of squared differences from the current mean. See
        documentation of the `welford_algorithm` function for an explanation.
    sample_size
        The number of successive states the previous values have been computed on;
        also the current number of iterations of the algorithm.
    """

    # NOTE: `mean` and `m2` were previously annotated `float`, but the
    # init/update functions store jax arrays here; annotate accordingly.
    mean: jnp.ndarray
    m2: jnp.ndarray
    sample_size: int
11 | 25 |
|
12 | 26 |
|
class GelmanRubinState(NamedTuple):
    """State of the online Gelman-Rubin (Rhat) estimator.

    w_state
        State of the underlying Welford (within-chain variance) estimator.
    rhat
        The per-dimension Rhat values computed at the last update.
    metric
        Scalar summary of `rhat` (the value furthest from 1 — see `update`).
    metric_name
        Name of the summary stored in `metric` (set to "worst_rhat" by `init`).
    """

    w_state: WelfordAlgorithmState
    # `jnp.DeviceArray` was deprecated and removed in JAX 0.4+; these
    # annotations are evaluated at class creation and would crash on import.
    rhat: jnp.ndarray
    metric: jnp.ndarray
    metric_name: str
| 32 | + |
| 33 | + |
def welford_algorithm(is_diagonal_matrix: bool) -> Tuple[Callable, Callable, Callable]:
    """Welford's online estimator of covariance.

    It is possible to compute the variance of a population of values in an
    on-line fashion to avoid storing intermediate results. The naive recurrence
    relations between the sample mean and variance at a step and the next are
    however not numerically stable.

    Welford's algorithm uses the sum of square of differences
    :math:`M_{2,n} = \\sum_{i=1}^n \\left(x_i-\\overline{x_n}\\right)^2`
    where :math:`\\overline{x}_n` is the current mean, and the following
    recurrence relationships

    .. math:

        M_{2,n} = M_{2, n-1} + (x_n-\\overline{x}_{n-1})(x_n-\\overline{x}_n)
        \\sigma_n^2 = \\frac{M_{2,n}}{n}

    Parameters
    ----------
    is_diagonal_matrix
        When True the algorithm adapts and returns a diagonal mass matrix
        (default), otherwise adapts and returns a dense mass matrix.

    Returns
    -------
    A tuple of `init`, `update` and `covariance` functions operating on a
    `WelfordAlgorithmState`.
    """

    def init(n_chains: int, n_dims: int) -> WelfordAlgorithmState:
        """Initialize the covariance estimation.

        When the matrix is diagonal it is sufficient to work with an array that contains
        the diagonal value. Otherwise we need to work with the matrix in full.

        Parameters
        ----------
        n_chains: int
            The number of chains being run
        n_dims: int
            The number of variables
        """
        sample_size = 0
        mean = jnp.zeros((n_chains, n_dims))
        if is_diagonal_matrix:
            m2 = jnp.zeros((n_chains, n_dims))
        else:
            # One dense (n_dims, n_dims) accumulator per chain. The previous
            # shape (n_chains, n_chains, n_dims) could never match the outer
            # product computed in `update`.
            m2 = jnp.zeros((n_chains, n_dims, n_dims))
        return WelfordAlgorithmState(mean, m2, sample_size)

    @jax.jit
    def update(
        state: WelfordAlgorithmState, value: jnp.ndarray
    ) -> WelfordAlgorithmState:
        """Update the M2 matrix using the new value.

        Parameters
        ----------
        state: WelfordAlgorithmState
            The current state of the Welford Algorithm
        value: jax.numpy.ndarray, shape (n_chains, n_dims)
            The new sample (typically position of the chain) used to update m2
        """
        mean, m2, sample_size = state
        sample_size = sample_size + 1

        delta = value - mean
        mean = mean + delta / sample_size
        updated_delta = value - mean
        if is_diagonal_matrix:
            new_m2 = m2 + delta * updated_delta
        else:
            # Batched outer product, one (n_dims, n_dims) update per chain;
            # `jnp.outer` flattens its inputs and is wrong for batched input.
            new_m2 = m2 + jnp.einsum("ci,cj->cij", updated_delta, delta)

        return WelfordAlgorithmState(mean, new_m2, sample_size)

    def covariance(
        state: WelfordAlgorithmState,
    ) -> Tuple[jnp.ndarray, int, jnp.ndarray]:
        """Return the current covariance estimate, sample size and mean.

        Uses the unbiased (ddof=1) estimator, so the result is only finite
        once at least two samples have been observed.
        """
        mean, m2, sample_size = state
        covariance = m2 / (sample_size - 1)
        return covariance, sample_size, mean

    return init, update, covariance
16 | 113 |
|
17 | 114 |
|
def online_gelman_rubin():
    """Online estimation of the Gelman-Rubin diagnostic."""

    # Only per-dimension (diagonal) variances are needed for Rhat.
    w_init, w_update, w_covariance = welford_algorithm(True)

    def init(init_state):
        """Initialise the online Gelman-Rubin estimator.

        Parameters
        ----------
        init_state
            The initial state of the chains; its `position` attribute is
            expected to be an array of shape (n_chains, n_dims).

        Returns
        -------
        GelmanRubinState with all values set to zeros.

        """
        n_chains, n_dims = init_state.position.shape
        w_state = w_init(n_chains, n_dims)
        return GelmanRubinState(w_state, 0, jnp.nan, "worst_rhat")

    def update(chain_state, rhat_state):
        """Update rhat estimates

        Parameters
        ----------
        chain_state: HMCState
            The chain state; `position` has shape (n_chains, n_dims).
        rhat_state: GelmanRubinState
            The GelmanRubinState from the previous draw

        Returns
        -------
        An updated GelmanRubinState object

        Notes
        -----
        The within-chain variance uses the unbiased (ddof=1) estimator, so
        the estimate is only meaningful from the second draw onwards.
        """
        within_state, _, _, metric_name = rhat_state

        positions = chain_state.position
        within_state = w_update(within_state, positions)
        covariance, step, mean = w_covariance(within_state)
        # W: mean of the per-chain variances; B: variance of the per-chain
        # means — both computed per dimension (axis 0 indexes chains).
        within_var = jnp.mean(covariance, axis=0)
        between_var = jnp.var(mean, axis=0, ddof=1)
        estimator = ((step - 1) / step) * within_var + between_var
        rhat = jnp.sqrt(estimator / within_var)
        # Summarize by the dimension whose rhat is furthest from 1.
        worst_rhat = rhat[jnp.argmax(jnp.abs(rhat - 1.0))]

        return GelmanRubinState(within_state, rhat, worst_rhat, metric_name)

    return init, update
|
42 | 165 |
|
|
0 commit comments