@@ -28,7 +28,6 @@ class GelmanRubinState(NamedTuple):
28
28
w_state : WelfordAlgorithmState
29
29
rhat : jnp .DeviceArray
30
30
metric : jnp .DeviceArray
31
- metric_name : str
32
31
33
32
34
33
def welford_algorithm (is_diagonal_matrix : bool ) -> Tuple [Callable , Callable , Callable ]:
@@ -112,12 +111,12 @@ def covariance(
112
111
return init , update , covariance
113
112
114
113
115
- def online_gelman_rubin ():
114
+ def online_gelman_rubin () -> Tuple [ str , Callable , Callable ] :
116
115
"""Online estimation of the Gelman-Rubin diagnostic."""
117
-
116
+ metric_name = "worst_rhat"
118
117
w_init , w_update , w_covariance = welford_algorithm (True )
119
118
120
- def init (init_state ):
119
+ def init (init_state ) -> GelmanRubinState :
121
120
"""Initialise the online gelman/rubin estimator
122
121
123
122
Parameters
@@ -132,9 +131,10 @@ def init(init_state):
132
131
"""
133
132
n_chains , n_dims = init_state .position .shape
134
133
w_state = w_init (n_chains , n_dims )
135
- return GelmanRubinState (w_state , 0 , jnp .nan , "worst_rhat" )
134
+ return GelmanRubinState (w_state , 0 , jnp .nan )
136
135
137
- def update (chain_state , rhat_state ):
136
+ @jax .jit
137
+ def update (chain_state , _ , rhat_state : GelmanRubinState ) -> GelmanRubinState :
138
138
"""Update rhat estimates
139
139
140
140
Parameters
@@ -148,7 +148,7 @@ def update(chain_state, rhat_state):
148
148
-------
149
149
An updated GelmanRubinState object
150
150
"""
151
- within_state , _ , _ , metric_name = rhat_state
151
+ within_state , * _ = rhat_state
152
152
153
153
positions = chain_state .position
154
154
within_state = w_update (within_state , positions )
@@ -159,9 +159,9 @@ def update(chain_state, rhat_state):
159
159
rhat = jnp .sqrt (estimator / within_var )
160
160
worst_rhat = rhat [jnp .argmax (jnp .abs (rhat - 1.0 ))]
161
161
162
- return GelmanRubinState (within_state , rhat , worst_rhat , metric_name )
162
+ return GelmanRubinState (within_state , rhat , worst_rhat )
163
163
164
- return init , update
164
+ return metric_name , init , update
165
165
166
166
167
167
def split_gelman_rubin ():
0 commit comments