Skip to content

Commit 6c15546

Browse files
committed
Merge branch 'main' into feature/greedy_kernel_points_analytic_test
2 parents 38d6389 + c9dcab7 commit 6c15546

File tree

66 files changed

+2358
-1388
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+2358
-1388
lines changed

.pre-commit-config.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ repos:
4242
- --ignore-case
4343
- --unique
4444
- repo: https://github.com/astral-sh/uv-pre-commit
45-
rev: 0.6.3
45+
rev: 0.6.6
4646
hooks:
4747
# Keep lock file up to date
4848
- id: uv-lock
@@ -105,7 +105,7 @@ repos:
105105
# Enforce that type annotations are used instead of type comments
106106
- id: python-use-type-annotations
107107
- repo: https://github.com/astral-sh/ruff-pre-commit
108-
rev: v0.9.9
108+
rev: v0.11.0
109109
hooks:
110110
# Run the linter.
111111
- id: ruff

CHANGELOG.md

+8
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
- Added Compress++ coreset reduction algorithm.
1313
(https://github.com/gchq/coreax/issues/934)
14+
- Added `reduce_iterative()` method to Kernel Herding. (https://github.com/gchq/coreax/pull/983)
15+
- Added probabilistic iterative Kernel Herding benchmarking results. (https://github.com/gchq/coreax/pull/983)
1416
- Analytic example with integration test for `GreedyKernelPoints` plus an analytic unit
1517
test for the loss function. (https://github.com/gchq/coreax/pull/1004)
1618

@@ -20,6 +22,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2022

2123
### Changed
2224

25+
- Changed the score function used by Stein Thinning in benchmarking.
26+
(https://github.com/gchq/coreax/pull/1000)
27+
- Fixed the random state for UMAP in benchmarking for reproducibility.
28+
(https://github.com/gchq/coreax/pull/1000)
29+
- Reduced the number of dimensions when applying UMAP in `pounce_benchmark.py`.
30+
(https://github.com/gchq/coreax/pull/1000)
2331
- Refactored `GreedyKernelPoints` and associated functions to make more extensible in
2432
future. (https://github.com/gchq/coreax/pull/1004)
2533

benchmark/blobs_benchmark.py

+43-13
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,17 @@
3131
import json
3232
import os
3333
import time
34+
from typing import Union
3435

3536
import jax
3637
import jax.numpy as jnp
38+
import jax.scipy as jsp
3739
import numpy as np
40+
from jaxtyping import Array, Shaped
3841
from sklearn.datasets import make_blobs
3942

40-
from coreax import Data, SlicedScoreMatching
43+
from coreax import Data
44+
from coreax.benchmark_util import IterativeKernelHerding
4145
from coreax.kernels import (
4246
SquaredExponentialKernel,
4347
SteinKernel,
@@ -46,7 +50,6 @@
4650
from coreax.metrics import KSD, MMD
4751
from coreax.solvers import (
4852
CompressPlusPlus,
49-
IterativeKernelHerding,
5053
KernelHerding,
5154
KernelThinning,
5255
RandomSample,
@@ -84,17 +87,32 @@ def setup_stein_kernel(
8487
:param random_seed: An integer seed for the random number generator.
8588
:return: A SteinKernel object.
8689
"""
87-
sliced_score_matcher = SlicedScoreMatching(
88-
jax.random.PRNGKey(random_seed),
89-
jax.random.rademacher,
90-
use_analytic=True,
91-
num_random_vectors=100,
92-
learning_rate=0.001,
93-
num_epochs=50,
94-
)
90+
# Fit a Gaussian kernel density estimator on a subset of points for efficiency
91+
num_data_points = len(dataset)
92+
num_samples_length_scale = min(num_data_points, 1000)
93+
generator = np.random.default_rng(random_seed)
94+
idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
95+
kde = jsp.stats.gaussian_kde(dataset.data[idx].T)
96+
97+
# Define the score function as the gradient of log density given by the KDE
98+
def score_function(
99+
x: Union[Shaped[Array, " n d"], Shaped[Array, ""], float, int],
100+
) -> Union[Shaped[Array, " n d"], Shaped[Array, " 1 1"]]:
101+
"""
102+
Compute the score function (gradient of log density) for a single point.
103+
104+
:param x: Input point represented as array
105+
:return: Gradient of log probability density at the given point
106+
"""
107+
108+
def logpdf_single(x: Shaped[Array, " d"]) -> Shaped[Array, ""]:
109+
return kde.logpdf(x.reshape(1, -1))[0]
110+
111+
return jax.grad(logpdf_single)(x)
112+
95113
return SteinKernel(
96114
base_kernel=sq_exp_kernel,
97-
score_function=sliced_score_matcher.match(jnp.asarray(dataset.data)),
115+
score_function=score_function,
98116
)
99117

100118

@@ -142,7 +160,7 @@ def setup_solvers(
142160
SteinThinning(
143161
coreset_size=coreset_size,
144162
kernel=stein_kernel,
145-
regularise=False,
163+
regularise=True,
146164
),
147165
),
148166
(
@@ -188,6 +206,18 @@ def setup_solvers(
188206
num_iterations=5,
189207
),
190208
),
209+
(
210+
"CubicProbIterativeHerding",
211+
IterativeKernelHerding(
212+
coreset_size=coreset_size,
213+
kernel=sq_exp_kernel,
214+
probabilistic=True,
215+
temperature=0.001,
216+
random_key=random_key,
217+
num_iterations=10,
218+
t_schedule=1 / jnp.linspace(10, 100, 10) ** 3,
219+
),
220+
),
191221
]
192222

193223

@@ -296,7 +326,7 @@ def main() -> None: # pylint: disable=too-many-locals
296326

297327
# Set up metrics
298328
mmd_metric = MMD(kernel=sq_exp_kernel)
299-
ksd_metric = KSD(kernel=sq_exp_kernel)
329+
ksd_metric = KSD(kernel=stein_kernel) # KSD needs a Stein kernel
300330

301331
# Set up weights optimiser
302332
weights_optimiser = MMDWeightsOptimiser(kernel=sq_exp_kernel)

0 commit comments

Comments
 (0)