|
3 | 3 | from functools import partial
|
4 | 4 | from typing import (
|
5 | 5 | Any,
|
| 6 | + Callable, |
6 | 7 | Optional,
|
7 | 8 | Tuple,
|
8 | 9 | )
|
@@ -174,7 +175,7 @@ def test_mmd_chunk_size_equivalence(
|
174 | 175 | 2,
|
175 | 176 | ],
|
176 | 177 | )
|
177 |
| -def test_mmd_chunk_size_initialization_valid( |
| 178 | +def test_mmd_chunk_size_valid( |
178 | 179 | chunk_size: Optional[int],
|
179 | 180 | ) -> None:
|
180 | 181 | """Test MMD initialization with valid chunk sizes.
|
@@ -230,3 +231,58 @@ def test_mmd_chunk_size_invalid(
|
230 | 231 | kernel=kernel,
|
231 | 232 | chunk_size=chunk_size,
|
232 | 233 | )
|
| 234 | + |
| 235 | + |
| 236 | +@pytest.mark.parametrize( |
| 237 | + "kernel", |
| 238 | + [ |
| 239 | + partial( |
| 240 | + rbf_kernel, |
| 241 | + sigma=DEFAULT_SIGMA, |
| 242 | + ), |
| 243 | + lambda X, Y: X + Y, # simple kernel |
| 244 | + ], |
| 245 | +) |
| 246 | +def test_mmd_kernel_valid( |
| 247 | + kernel: Callable, # type: ignore |
| 248 | +) -> None: |
| 249 | + """Test MMD initialization with valid kernels. |
| 250 | +
|
| 251 | + :param kernel: kernel to test |
| 252 | + :type kernel: Callable |
| 253 | + """ |
| 254 | + np.random.seed(seed=RANDOM_SEED) |
| 255 | + X_ref = np.random.normal(0, 1, 100) |
| 256 | + X_test = np.random.normal(0, 1, 100) |
| 257 | + |
| 258 | + detector = MMD( |
| 259 | + kernel=kernel, |
| 260 | + ) |
| 261 | + _ = detector.fit(X=X_ref) |
| 262 | + result = detector.compare(X=X_test)[0] |
| 263 | + |
| 264 | + assert result is not None |
| 265 | + |
| 266 | + |
| 267 | +@pytest.mark.parametrize( |
| 268 | + "kernel", |
| 269 | + [ |
| 270 | + None, |
| 271 | + "invalid", |
| 272 | + 123, |
| 273 | + [1, 2], |
| 274 | + {1: 2}, |
| 275 | + ], |
| 276 | +) |
| 277 | +def test_mmd_kernel_invalid( |
| 278 | + kernel: Any, |
| 279 | +) -> None: |
| 280 | + """Test MMD initialization with invalid kernels. |
| 281 | +
|
| 282 | + :param kernel: kernel to test |
| 283 | + :type kernel: Any |
| 284 | + """ |
| 285 | + with pytest.raises((TypeError, ValueError)): |
| 286 | + MMD( |
| 287 | + kernel=kernel, |
| 288 | + ) |
0 commit comments