Skip to content

Commit 26c316f

Browse files
committed
display # divergences while sampling
We added a callback that allows to display the largest number of divergences on a single chain while sampling. In the spirit of creating a fully interactive sampling experience.
1 parent af1adc9 commit 26c316f

13 files changed

+93
-199
lines changed

mcx/diagnostics/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from .gelman_rubin import online_gelman_rubin
2+
from .mcmc import divergences
23

34
__all__ = [
5+
"divergences",
46
"online_gelman_rubin",
57
]

mcx/diagnostics/gelman_rubin.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ class GelmanRubinState(NamedTuple):
2828
w_state: WelfordAlgorithmState
2929
rhat: jnp.DeviceArray
3030
metric: jnp.DeviceArray
31-
metric_name: str
3231

3332

3433
def welford_algorithm(is_diagonal_matrix: bool) -> Tuple[Callable, Callable, Callable]:
@@ -112,12 +111,12 @@ def covariance(
112111
return init, update, covariance
113112

114113

115-
def online_gelman_rubin():
114+
def online_gelman_rubin() -> Tuple[str, Callable, Callable]:
116115
"""Online estimation of the Gelman-Rubin diagnostic."""
117-
116+
metric_name = "worst_rhat"
118117
w_init, w_update, w_covariance = welford_algorithm(True)
119118

120-
def init(init_state):
119+
def init(init_state) -> GelmanRubinState:
121120
"""Initialise the online gelman/rubin estimator
122121
123122
Parameters
@@ -132,9 +131,10 @@ def init(init_state):
132131
"""
133132
n_chains, n_dims = init_state.position.shape
134133
w_state = w_init(n_chains, n_dims)
135-
return GelmanRubinState(w_state, 0, jnp.nan, "worst_rhat")
134+
return GelmanRubinState(w_state, 0, jnp.nan)
136135

137-
def update(chain_state, rhat_state):
136+
@jax.jit
137+
def update(chain_state, _, rhat_state: GelmanRubinState) -> GelmanRubinState:
138138
"""Update rhat estimates
139139
140140
Parameters
@@ -148,7 +148,7 @@ def update(chain_state, rhat_state):
148148
-------
149149
An updated GelmanRubinState object
150150
"""
151-
within_state, _, _, metric_name = rhat_state
151+
within_state, *_ = rhat_state
152152

153153
positions = chain_state.position
154154
within_state = w_update(within_state, positions)
@@ -159,9 +159,9 @@ def update(chain_state, rhat_state):
159159
rhat = jnp.sqrt(estimator / within_var)
160160
worst_rhat = rhat[jnp.argmax(jnp.abs(rhat - 1.0))]
161161

162-
return GelmanRubinState(within_state, rhat, worst_rhat, metric_name)
162+
return GelmanRubinState(within_state, rhat, worst_rhat)
163163

164-
return init, update
164+
return metric_name, init, update
165165

166166

167167
def split_gelman_rubin():

mcx/diagnostics/mcmc.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Diagnostics that are specific to MCMC algorithms."""
2+
from typing import Callable, NamedTuple, Tuple
3+
4+
import jax
5+
import jax.numpy as jnp
6+
7+
8+
class DivergencesState(NamedTuple):
9+
num_divergences: jnp.ndarray
10+
metric: int # maximum number of divergences
11+
12+
13+
def divergences() -> Tuple[str, Callable, Callable]:
14+
"""Count the number of divergences.
15+
16+
We keep a count of the current number of divergences for each chain and
17+
return the maximum number of divergences to be displayed.
18+
19+
"""
20+
metric_name = "divergences"
21+
22+
def init(init_state) -> DivergencesState:
23+
"""Initialize the divergence counters."""
24+
num_chains, _ = init_state.position.shape
25+
num_divergences = jnp.zeros(num_chains)
26+
return DivergencesState(num_divergences, 0)
27+
28+
@jax.jit
29+
def update(_, info, divergence_state: DivergencesState) -> DivergencesState:
30+
"""Update the number of divergences."""
31+
num_divergences, *_ = divergence_state
32+
is_divergent = info.is_divergent.astype(int)
33+
num_divergences = num_divergences + is_divergent
34+
max_num_divergences = jnp.max(num_divergences)
35+
return DivergencesState(num_divergences, max_num_divergences)
36+
37+
return metric_name, init, update

mcx/distributions/distribution.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ def __init__(self, *args) -> None:
6363

6464
@abstractmethod
6565
def sample(
66-
self, rng_key: jax.random.PRNGKey, sample_shape: Union[Tuple[()], Tuple[int]]
66+
self, rng_key: jnp.ndarray, sample_shape: Union[Tuple[()], Tuple[int]]
6767
) -> jax.numpy.DeviceArray:
6868
"""Obtain samples from the distribution.
6969
7070
Parameters
7171
----------
72-
rng_key: jax.random.PRNGKey
72+
rng_key: jnp.ndarray
7373
The pseudo random number generator key to use to draw samples.
7474
sample_shape: Tuple[int]
7575
The number of independant, identically distributed samples to draw
@@ -84,7 +84,7 @@ def sample(
8484

8585
def forward(
8686
self,
87-
rng_key: jax.random.PRNGKey,
87+
rng_key: jnp.ndarray,
8888
sample_shape: Union[Tuple[()], Tuple[int]] = (),
8989
) -> jax.numpy.DeviceArray:
9090
"""Generate forward samples from the distribution. Defined for compatibility with

mcx/inference/adaptation/num_steps_adaptation.py

-80
This file was deleted.

mcx/inference/adaptation/stan.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def stan_hmc_warmup(
8383
)
8484

8585
def init(
86-
rng_key: jax.random.PRNGKey, initial_state: HMCState, initial_step_size: int
86+
rng_key: jnp.ndarray, initial_state: HMCState, initial_step_size: int
8787
) -> StanWarmupState:
8888
"""Initialize the warmup.
8989
@@ -109,7 +109,7 @@ def init(
109109

110110
@jax.jit
111111
def update(
112-
rng_key: jax.random.PRNGKey,
112+
rng_key: jnp.ndarray,
113113
stage: int,
114114
is_middle_window_end: bool,
115115
chain_state: HMCState,
@@ -197,7 +197,7 @@ def init(initial_step_size: float) -> DualAveragingState:
197197

198198
@jax.jit
199199
def update(
200-
state: Tuple[jax.random.PRNGKey, HMCState, HMCInfo, StanWarmupState]
200+
state: Tuple[jnp.ndarray, HMCState, HMCInfo, StanWarmupState]
201201
) -> StanWarmupState:
202202
rng_key, chain_state, chain_info, warmup_state = state
203203

@@ -244,7 +244,7 @@ def init(chain_state: HMCState) -> MassMatrixAdaptationState:
244244

245245
@jax.jit
246246
def update(
247-
state: Tuple[jax.random.PRNGKey, HMCState, HMCInfo, StanWarmupState]
247+
state: Tuple[jnp.ndarray, HMCState, HMCInfo, StanWarmupState]
248248
) -> StanWarmupState:
249249
"""Move the warmup by one state when in a slow adaptation interval.
250250

mcx/inference/adaptation/step_size_adaptation.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -171,15 +171,15 @@ class ReasonableStepSizeState(NamedTuple):
171171
The current step size in the search.
172172
"""
173173

174-
rng_key: jax.random.PRNGKey
174+
rng_key: jnp.ndarray
175175
direction: int
176176
previous_direction: int
177177
step_size: float
178178

179179

180180
@partial(jax.jit, static_argnums=(1,))
181181
def find_reasonable_step_size(
182-
rng_key: jax.random.PRNGKey,
182+
rng_key: jnp.ndarray,
183183
kernel_generator: Callable[[float, jnp.DeviceArray], Callable],
184184
reference_hmc_state: HMCState,
185185
inverse_mass_matrix: jnp.DeviceArray,

mcx/inference/hmc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def make_state(position):
9292

9393
def warmup(
9494
self,
95-
rng_key: jax.random.PRNGKey,
95+
rng_key: jax.numpy.ndarray,
9696
initial_state: HMCState,
9797
kernel_factory: Callable,
9898
num_chains,

mcx/inference/kernels.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,7 @@ def hmc_kernel(
105105
"""
106106

107107
@jax.jit
108-
def kernel(
109-
rng_key: jax.random.PRNGKey, state: HMCState
110-
) -> Tuple[HMCState, HMCInfo]:
108+
def kernel(rng_key: jnp.ndarray, state: HMCState) -> Tuple[HMCState, HMCInfo]:
111109
"""Moves the chain by one step using the Hamiltonian dynamics.
112110
113111
Parameters
@@ -219,9 +217,7 @@ def rwm_kernel(logpdf: Callable, proposal_generator: Callable) -> Callable:
219217
"""
220218

221219
@jax.jit
222-
def kernel(
223-
rng_key: jax.random.PRNGKey, state: RWMState
224-
) -> Tuple[RWMState, RWMInfo]:
220+
def kernel(rng_key: jnp.ndarray, state: RWMState) -> Tuple[RWMState, RWMInfo]:
225221
"""Moves the chain by one step using the Random Walk Metropolis algorithm.
226222
227223
Parameters

mcx/inference/metrics.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
KineticEnergy = Callable[[jnp.DeviceArray], float]
13-
MomentumGenerator = Callable[[jax.random.PRNGKey], jnp.DeviceArray]
13+
MomentumGenerator = Callable[[jnp.ndarray], jnp.DeviceArray]
1414

1515

1616
def gaussian_euclidean_metric(
@@ -35,7 +35,7 @@ def gaussian_euclidean_metric(
3535
mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix))
3636

3737
@jax.jit
38-
def momentum_generator(rng_key: jax.random.PRNGKey) -> jnp.DeviceArray:
38+
def momentum_generator(rng_key: jnp.ndarray) -> jnp.DeviceArray:
3939
std = jax.random.normal(rng_key, shape)
4040
p = jnp.multiply(std, mass_matrix_sqrt)
4141
return p
@@ -52,7 +52,7 @@ def kinetic_energy(momentum: jnp.DeviceArray) -> float:
5252
mass_matrix_sqrt = cholesky_of_inverse(inverse_mass_matrix)
5353

5454
@jax.jit
55-
def momentum_generator(rng_key: jax.random.PRNGKey) -> jnp.DeviceArray:
55+
def momentum_generator(rng_key: jnp.ndarray) -> jnp.DeviceArray:
5656
std = jax.random.normal(rng_key, shape)
5757
p = jnp.dot(std, mass_matrix_sqrt)
5858
return p
@@ -67,7 +67,7 @@ def kinetic_energy(momentum: jnp.DeviceArray) -> float:
6767
else:
6868
raise ValueError(
6969
"The mass matrix has the wrong number of dimensions:"
70-
f" expected 1 or 2, got {jnp.dim(inverse_mass_matrix)}."
70+
f" expected 1 or 2, got {jnp.ndim(inverse_mass_matrix)}."
7171
)
7272

7373

0 commit comments

Comments
 (0)