Skip to content

Commit 32fe6b3

Browse files
Add MMD kernel initialization tests
1 parent 762c17b commit 32fe6b3

File tree

1 file changed

+57
-1
lines changed
  • frouros/tests/unit/detectors/data_drift/batch/distance_based

1 file changed

+57
-1
lines changed

frouros/tests/unit/detectors/data_drift/batch/distance_based/test_mmd.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from functools import partial
44
from typing import (
55
Any,
6+
Callable,
67
Optional,
78
Tuple,
89
)
@@ -174,7 +175,7 @@ def test_mmd_chunk_size_equivalence(
174175
2,
175176
],
176177
)
177-
def test_mmd_chunk_size_initialization_valid(
178+
def test_mmd_chunk_size_valid(
178179
chunk_size: Optional[int],
179180
) -> None:
180181
"""Test MMD initialization with valid chunk sizes.
@@ -230,3 +231,58 @@ def test_mmd_chunk_size_invalid(
230231
kernel=kernel,
231232
chunk_size=chunk_size,
232233
)
234+
235+
236+
@pytest.mark.parametrize(
237+
"kernel",
238+
[
239+
partial(
240+
rbf_kernel,
241+
sigma=DEFAULT_SIGMA,
242+
),
243+
lambda X, Y: X + Y, # simple kernel
244+
],
245+
)
246+
def test_mmd_kernel_valid(
247+
kernel: Callable, # type: ignore
248+
) -> None:
249+
"""Test MMD initialization with valid kernels.
250+
251+
:param kernel: kernel to test
252+
:type kernel: Callable
253+
"""
254+
np.random.seed(seed=RANDOM_SEED)
255+
X_ref = np.random.normal(0, 1, 100)
256+
X_test = np.random.normal(0, 1, 100)
257+
258+
detector = MMD(
259+
kernel=kernel,
260+
)
261+
_ = detector.fit(X=X_ref)
262+
result = detector.compare(X=X_test)[0]
263+
264+
assert result is not None
265+
266+
267+
@pytest.mark.parametrize(
268+
"kernel",
269+
[
270+
None,
271+
"invalid",
272+
123,
273+
[1, 2],
274+
{1: 2},
275+
],
276+
)
277+
def test_mmd_kernel_invalid(
278+
kernel: Any,
279+
) -> None:
280+
"""Test MMD initialization with invalid kernels.
281+
282+
:param kernel: kernel to test
283+
:type kernel: Any
284+
"""
285+
with pytest.raises((TypeError, ValueError)):
286+
MMD(
287+
kernel=kernel,
288+
)

0 commit comments

Comments
 (0)