Skip to content

Commit 91c35c5

Browse files
Extend BhattacharyyaDistance to multivariate
1 parent: 96a562f · commit: 91c35c5

File tree

7 files changed: 93 additions and 31 deletions.

frouros/detectors/data_drift/batch/distance_based/base.py

Lines changed: 32 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -120,6 +120,7 @@ class BaseDistanceBasedBins(BaseDistanceBased):
120120

121121
def __init__(
122122
self,
123+
statistical_type: BaseStatisticalType,
123124
statistical_method: Callable, # type: ignore
124125
statistical_kwargs: dict[str, Any],
125126
callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
@@ -137,9 +138,12 @@ def __init__(
137138
:type num_bins: int
138139
"""
139140
super().__init__(
140-
statistical_type=UnivariateData(),
141+
statistical_type=statistical_type,
141142
statistical_method=statistical_method,
142-
statistical_kwargs={**statistical_kwargs, "num_bins": num_bins},
143+
statistical_kwargs={
144+
**statistical_kwargs,
145+
"num_bins": num_bins,
146+
},
143147
callbacks=callbacks,
144148
)
145149
self.num_bins = num_bins
@@ -171,23 +175,40 @@ def _distance_measure(
171175
X: np.ndarray, # noqa: N803
172176
**kwargs: Any,
173177
) -> DistanceResult:
174-
distance_bins = self._distance_measure_bins(X_ref=X_ref, X=X)
175-
distance = DistanceResult(distance=distance_bins)
178+
distance_bins = self._distance_measure_bins(
179+
X_ref=X_ref,
180+
X=X,
181+
)
182+
distance = DistanceResult(
183+
distance=distance_bins,
184+
)
176185
return distance
177186

178187
@staticmethod
179188
def _calculate_bins_values(
180189
X_ref: np.ndarray, # noqa: N803
181190
X: np.ndarray,
182191
num_bins: int = 10,
183-
) -> np.ndarray:
184-
bins = np.histogram(np.hstack((X_ref, X)), bins=num_bins)[ # get the bin edges
185-
1
192+
) -> Tuple[np.ndarray, np.ndarray]:
193+
# Add a new axis if X_ref and X are 1D
194+
if X_ref.ndim == 1:
195+
X_ref = X_ref[:, np.newaxis]
196+
X = X[:, np.newaxis]
197+
198+
min_edge = np.min(np.vstack((X_ref, X)), axis=0)
199+
max_edge = np.max(np.vstack((X_ref, X)), axis=0)
200+
bins = [
201+
np.linspace(min_edge[i], max_edge[i], num_bins + 1)
202+
for i in range(X_ref.shape[1])
186203
]
187-
X_ref_percents = ( # noqa: N806
188-
np.histogram(a=X_ref, bins=bins)[0] / X_ref.shape[0]
189-
) # noqa: N806
190-
X_percents = np.histogram(a=X, bins=bins)[0] / X.shape[0] # noqa: N806
204+
205+
X_ref_hist, _ = np.histogramdd(X_ref, bins=bins)
206+
X_hist, _ = np.histogramdd(X, bins=bins)
207+
208+
# Normalize histograms
209+
X_ref_percents = X_ref_hist / X_ref.shape[0]
210+
X_percents = X_hist / X.shape[0]
211+
191212
return X_ref_percents, X_percents
192213

193214
@abc.abstractmethod

frouros/detectors/data_drift/batch/distance_based/bhattacharyya_distance.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66

77
from frouros.callbacks.batch.base import BaseCallbackBatch
8+
from frouros.detectors.data_drift.base import MultivariateData
89
from frouros.detectors.data_drift.batch.distance_based.base import (
910
BaseDistanceBasedBins,
1011
)
@@ -13,7 +14,8 @@
1314
class BhattacharyyaDistance(BaseDistanceBasedBins):
1415
"""Bhattacharyya distance [bhattacharyya1946measure]_ detector.
1516
16-
:param num_bins: number of bins in which to divide probabilities, defaults to 10
17+
:param num_bins: number of bins per dimension in which to
18+
divide probabilities, defaults to 10
1719
:type num_bins: int
1820
:param callbacks: callbacks, defaults to None
1921
:type callbacks: Optional[Union[BaseCallback, list[Callback]]]
@@ -29,12 +31,12 @@ class BhattacharyyaDistance(BaseDistanceBasedBins):
2931
>>> from frouros.detectors.data_drift import BhattacharyyaDistance
3032
>>> import numpy as np
3133
>>> np.random.seed(seed=31)
32-
>>> X = np.random.normal(loc=0, scale=1, size=100)
33-
>>> Y = np.random.normal(loc=1, scale=1, size=100)
34-
>>> detector = BhattacharyyaDistance(num_bins=20)
34+
>>> X = np.random.multivariate_normal(mean=[1, 1], cov=[[2, 0], [0, 2]], size=100)
35+
>>> Y = np.random.multivariate_normal(mean=[0, 0], cov=[[2, 1], [1, 2]], size=100)
36+
>>> detector = BhattacharyyaDistance(num_bins=10)
3537
>>> _ = detector.fit(X=X)
3638
>>> detector.compare(X=Y)
37-
DistanceResult(distance=0.2182101059622703)
39+
DistanceResult(distance=0.3413868461814531)
3840
"""
3941

4042
def __init__( # noqa: D107
@@ -43,6 +45,7 @@ def __init__( # noqa: D107
4345
callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
4446
) -> None:
4547
super().__init__(
48+
statistical_type=MultivariateData(),
4649
statistical_method=self._bhattacharyya,
4750
statistical_kwargs={
4851
"num_bins": num_bins,
@@ -56,7 +59,11 @@ def _distance_measure_bins(
5659
X_ref: np.ndarray, # noqa: N803
5760
X: np.ndarray, # noqa: N803
5861
) -> float:
59-
bhattacharyya = self._bhattacharyya(X=X_ref, Y=X, num_bins=self.num_bins)
62+
bhattacharyya = self._bhattacharyya(
63+
X=X_ref,
64+
Y=X,
65+
num_bins=self.num_bins,
66+
)
6067
return bhattacharyya
6168

6269
@staticmethod
@@ -70,7 +77,23 @@ def _bhattacharyya(
7077
X_percents,
7178
Y_percents,
7279
) = BaseDistanceBasedBins._calculate_bins_values(
73-
X_ref=X, X=Y, num_bins=num_bins
80+
X_ref=X,
81+
X=Y,
82+
num_bins=num_bins,
7483
)
75-
bhattacharyya = 1 - np.sum(np.sqrt(X_percents * Y_percents))
84+
85+
# Add small epsilon to avoid log(0)
86+
epsilon = np.finfo(float).eps
87+
X_percents = X_percents + epsilon
88+
Y_percents = Y_percents + epsilon
89+
90+
# Compute Bhattacharyya coefficient
91+
bc = np.sum(np.sqrt(X_percents * Y_percents))
92+
# Clip between [0,1] to avoid numerical errors
93+
bc = np.clip(bc, a_min=0, a_max=1)
94+
95+
# Compute Bhattacharyya distance
96+
# Use absolute value to avoid negative zero values
97+
bhattacharyya = np.abs(-np.log(bc))
98+
7699
return bhattacharyya

frouros/detectors/data_drift/batch/distance_based/hellinger_distance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66

77
from frouros.callbacks.batch.base import BaseCallbackBatch
8+
from frouros.detectors.data_drift.base import UnivariateData
89
from frouros.detectors.data_drift.batch.distance_based.base import (
910
BaseDistanceBasedBins,
1011
)
@@ -45,6 +46,7 @@ def __init__( # noqa: D107
4546
) -> None:
4647
sqrt_div = np.sqrt(2)
4748
super().__init__(
49+
statistical_type=UnivariateData(),
4850
statistical_method=self._hellinger,
4951
statistical_kwargs={
5052
"num_bins": num_bins,

frouros/detectors/data_drift/batch/distance_based/hi_normalized_complement.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66

77
from frouros.callbacks.batch.base import BaseCallbackBatch
8+
from frouros.detectors.data_drift.base import UnivariateData
89
from frouros.detectors.data_drift.batch.distance_based.base import (
910
BaseDistanceBasedBins,
1011
)
@@ -43,6 +44,7 @@ def __init__( # noqa: D107
4344
callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
4445
) -> None:
4546
super().__init__(
47+
statistical_type=UnivariateData(),
4648
statistical_method=self._hi_normalized_complement,
4749
statistical_kwargs={
4850
"num_bins": num_bins,

frouros/detectors/data_drift/batch/distance_based/psi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77

88
from frouros.callbacks.batch.base import BaseCallbackBatch
9+
from frouros.detectors.data_drift.base import UnivariateData
910
from frouros.detectors.data_drift.batch.distance_based.base import (
1011
BaseDistanceBasedBins,
1112
DistanceResult,
@@ -45,6 +46,7 @@ def __init__( # noqa: D107
4546
callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
4647
) -> None:
4748
super().__init__(
49+
statistical_type=UnivariateData(),
4850
statistical_method=self._psi,
4951
statistical_kwargs={
5052
"num_bins": num_bins,

frouros/tests/integration/test_callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
@pytest.mark.parametrize(
5252
"detector_class, expected_distance, expected_p_value",
5353
[
54-
(BhattacharyyaDistance, 0.55516059, 0.0),
54+
(BhattacharyyaDistance, 0.81004188, 0.0),
5555
(EMD, 3.85346006, 0.0),
5656
(EnergyDistance, 2.11059982, 0.0),
5757
(HellingerDistance, 0.74509099, 0.0),

frouros/tests/integration/test_data_drift.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Test data drift detectors."""
22

3-
from typing import Any, Tuple, Union
3+
from typing import (
4+
Any,
5+
Tuple,
6+
Union,
7+
)
48

59
import numpy as np
610
import pytest
@@ -26,12 +30,8 @@
2630
WelchTTest,
2731
)
2832
from frouros.detectors.data_drift.batch.base import BaseDataDriftBatch
29-
from frouros.detectors.data_drift.streaming import (
30-
MMD as MMDStreaming,
31-
)
32-
from frouros.detectors.data_drift.streaming import ( # noqa: N811
33-
IncrementalKSTest,
34-
)
33+
from frouros.detectors.data_drift.streaming import MMD as MMDStreaming
34+
from frouros.detectors.data_drift.streaming import IncrementalKSTest
3535

3636

3737
@pytest.mark.parametrize(
@@ -102,7 +102,7 @@ def test_batch_distance_based_univariate(
102102
[
103103
(PSI(), 461.20379435),
104104
(HellingerDistance(), 0.74509099),
105-
(BhattacharyyaDistance(), 0.55516059),
105+
(BhattacharyyaDistance(), 0.810041883),
106106
],
107107
)
108108
def test_batch_distance_bins_based_univariate_different_distribution(
@@ -133,7 +133,7 @@ def test_batch_distance_bins_based_univariate_different_distribution(
133133
[
134134
(PSI(), 0.01840072),
135135
(HellingerDistance(), 0.04792538),
136-
(BhattacharyyaDistance(), 0.00229684),
136+
(BhattacharyyaDistance(), 0.00229948),
137137
],
138138
)
139139
def test_batch_distance_bins_based_univariate_same_distribution(
@@ -214,7 +214,13 @@ def test_batch_statistical_univariate(
214214
assert np.isclose(p_value, expected_p_value)
215215

216216

217-
@pytest.mark.parametrize("detector, expected_distance", [(MMD(), 0.10163633)])
217+
@pytest.mark.parametrize(
218+
"detector, expected_distance",
219+
[
220+
(BhattacharyyaDistance(), 0.39327743),
221+
(MMD(), 0.10163633),
222+
],
223+
)
218224
def test_batch_distance_based_multivariate_different_distribution(
219225
X_ref_multivariate: np.ndarray, # noqa: N803
220226
X_test_multivariate: np.ndarray, # noqa: N803
@@ -238,7 +244,13 @@ def test_batch_distance_based_multivariate_different_distribution(
238244
assert np.isclose(statistic, expected_distance)
239245

240246

241-
@pytest.mark.parametrize("detector, expected_distance", [(MMD(), 0.01570397)])
247+
@pytest.mark.parametrize(
248+
"detector, expected_distance",
249+
[
250+
(BhattacharyyaDistance(), 0.39772951),
251+
(MMD(), 0.01570397),
252+
],
253+
)
242254
def test_batch_distance_based_multivariate_same_distribution(
243255
multivariate_distribution_p: Tuple[np.ndarray, np.ndarray],
244256
detector: BaseDataDriftBatch,

0 commit comments

Comments
 (0)