 import json
 import os
 import time
+from typing import Union

 import jax
 import jax.numpy as jnp
+import jax.scipy as jsp
 import numpy as np
+from jaxtyping import Array, Shaped
 from sklearn.datasets import make_blobs

-from coreax import Data, SlicedScoreMatching
+from coreax import Data
+from coreax.benchmark_util import IterativeKernelHerding
 from coreax.kernels import (
     SquaredExponentialKernel,
     SteinKernel,
     median_heuristic,
 )
 from coreax.metrics import KSD, MMD
 from coreax.solvers import (
     CompressPlusPlus,
-    IterativeKernelHerding,
     KernelHerding,
     KernelThinning,
     RandomSample,
@@ -84,17 +87,32 @@ def setup_stein_kernel(
     :param random_seed: An integer seed for the random number generator.
     :return: A SteinKernel object.
     """
-    sliced_score_matcher = SlicedScoreMatching(
-        jax.random.PRNGKey(random_seed),
-        jax.random.rademacher,
-        use_analytic=True,
-        num_random_vectors=100,
-        learning_rate=0.001,
-        num_epochs=50,
-    )
+    # Fit a Gaussian kernel density estimator on a subset of points for efficiency
+    num_data_points = len(dataset)
+    num_samples_length_scale = min(num_data_points, 1000)
+    generator = np.random.default_rng(random_seed)
+    idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
+    kde = jsp.stats.gaussian_kde(dataset.data[idx].T)
+
+    # Define the score function as the gradient of log density given by the KDE
+    def score_function(
+        x: Union[Shaped[Array, " n d"], Shaped[Array, ""], float, int],
+    ) -> Union[Shaped[Array, " n d"], Shaped[Array, " 1 1"]]:
+        """
+        Compute the score function (gradient of log density) for a single point.
+
+        :param x: Input point represented as array
+        :return: Gradient of log probability density at the given point
+        """
+
+        def logpdf_single(x: Shaped[Array, " d"]) -> Shaped[Array, ""]:
+            return kde.logpdf(x.reshape(1, -1))[0]
+
+        return jax.grad(logpdf_single)(x)
+
     return SteinKernel(
         base_kernel=sq_exp_kernel,
-        score_function=sliced_score_matcher.match(jnp.asarray(dataset.data)),
+        score_function=score_function,
     )

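Note: the replacement KDE-based score function above can be sanity-checked in isolation. The sketch below is illustrative only (not part of the change); it reuses the same `jsp.stats.gaussian_kde` / `jax.grad` pattern from the diff on toy standard-Gaussian data, whose exact score is `-x`:

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np

# Toy data shaped (n, d), standing in for dataset.data in the diff above.
rng = np.random.default_rng(0)
data = jnp.asarray(rng.normal(size=(1000, 2)))

# gaussian_kde expects (d, n), hence the transpose, exactly as in the diff.
kde = jsp.stats.gaussian_kde(data.T)

def score(x):
    """Gradient of the KDE log-density at a single point of shape (d,)."""
    def logpdf_single(xi):
        return kde.logpdf(xi.reshape(1, -1))[0]
    return jax.grad(logpdf_single)(x)

# For a standard Gaussian the exact score is -x; the KDE estimate should be close.
x = jnp.array([0.5, -0.5])
print(score(x), -x)
```

This is the same construction the diff passes to `SteinKernel(score_function=...)`, swapped in for the previous `SlicedScoreMatching` score estimator.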
@@ -142,7 +160,7 @@ def setup_solvers(
             SteinThinning(
                 coreset_size=coreset_size,
                 kernel=stein_kernel,
-                regularise=False,
+                regularise=True,
             ),
         ),
         (
@@ -188,6 +206,18 @@ def setup_solvers(
                 num_iterations=5,
             ),
         ),
+        (
+            "CubicProbIterativeHerding",
+            IterativeKernelHerding(
+                coreset_size=coreset_size,
+                kernel=sq_exp_kernel,
+                probabilistic=True,
+                temperature=0.001,
+                random_key=random_key,
+                num_iterations=10,
+                t_schedule=1 / jnp.linspace(10, 100, 10) ** 3,
+            ),
+        ),
     ]

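Note on the new solver's `t_schedule` (an illustrative check, not part of the change): the array supplies one temperature per iteration, decaying cubically, so early iterations select coreset points probabilistically while later ones become nearly deterministic:

```python
import jax.numpy as jnp

# One temperature per iteration (10 values for num_iterations=10),
# decaying from 1/10**3 = 1e-3 down to 1/100**3 = 1e-6.
t_schedule = 1 / jnp.linspace(10, 100, 10) ** 3
print(t_schedule[0], t_schedule[-1])  # ~0.001 ... ~1e-06
```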
@@ -296,7 +326,7 @@ def main() -> None: # pylint: disable=too-many-locals

     # Set up metrics
     mmd_metric = MMD(kernel=sq_exp_kernel)
-    ksd_metric = KSD(kernel=sq_exp_kernel)
+    ksd_metric = KSD(kernel=stein_kernel)  # KSD needs a Stein kernel

     # Set up weights optimiser
     weights_optimiser = MMDWeightsOptimiser(kernel=sq_exp_kernel)
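Background on the metric fix (not part of the diff): the squared kernel Stein discrepancy of a coreset {x_i} is the mean of the Stein kernel's Gram matrix, KSD² = (1/n²) Σᵢⱼ k_P(xᵢ, xⱼ), so `KSD` must be built from `stein_kernel`; the plain squared-exponential kernel never touches the target's score, so it cannot measure discrepancy to the target density. A minimal sketch of the V-statistic form, assuming `gram` holds the n-by-n Stein-kernel Gram matrix over the coreset points:

```python
import jax.numpy as jnp

def ksd_v_statistic(gram: jnp.ndarray) -> jnp.ndarray:
    """Squared KSD (V-statistic): the mean of the Stein-kernel Gram matrix."""
    return jnp.mean(gram)
```

With coreax this Gram matrix would come from evaluating the Stein kernel pairwise on the coreset (e.g. via the kernel's Gram-matrix method, if one is exposed); the hypothetical `ksd_v_statistic` helper above is only for intuition.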