Skip to content

Commit 3dd2cc9

Browse files
committed
Removed use of scipy's mahalanobis function
It creates deprecation warning if you dont' pass in 1D arrays. Filterpy's mahalanobis function is more forgiving, so switched to using it instead.
1 parent e164fc1 commit 3dd2cc9

File tree

3 files changed

+28
-10
lines changed

3 files changed

+28
-10
lines changed

filterpy/kalman/tests/test_fm.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import numpy as np
2222
import matplotlib.pyplot as plt
2323
from filterpy.kalman import FadingKalmanFilter
24+
from filterpy.stats import mahalanobis
2425
from pytest import approx
25-
from scipy.spatial.distance import mahalanobis as scipy_mahalanobis
2626

2727
DO_PLOT = False
2828
def test_noisy_1d():
@@ -59,7 +59,7 @@ def test_noisy_1d():
5959

6060
# test mahalanobis
6161
a = np.zeros(f.y.shape)
62-
maha = scipy_mahalanobis(a, f.y, f.SI)
62+
maha = mahalanobis(a, f.y, f.S)
6363
assert f.mahalanobis == approx(maha)
6464
print(z, maha, f.y, f.S)
6565
assert maha < 4
@@ -87,4 +87,4 @@ def test_noisy_1d():
8787

8888
if __name__ == "__main__":
8989
DO_PLOT = True
90-
test_noisy_1d()
90+
test_noisy_1d()

filterpy/kalman/tests/test_kf.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
from pytest import approx
2525
from filterpy.kalman import KalmanFilter, update, predict, batch_filter
2626
from filterpy.common import Q_discrete_white_noise, kinematic_kf, Saver
27+
from filterpy.stats import mahalanobis
2728
from scipy.linalg import block_diag, norm
28-
from scipy.spatial.distance import mahalanobis as scipy_mahalanobis
2929

3030
DO_PLOT = False
3131

@@ -121,7 +121,7 @@ def test_noisy_1d():
121121

122122
# test mahalanobis
123123
a = np.zeros(f.y.shape)
124-
maha = scipy_mahalanobis(a, f.y, f.SI)
124+
maha = mahalanobis(a, f.y, f.S)
125125
assert f.mahalanobis == approx(maha)
126126

127127

@@ -233,7 +233,7 @@ def test_noisy_11d():
233233

234234
# test mahalanobis
235235
a = np.zeros(f.y.shape)
236-
maha = scipy_mahalanobis(a, f.y, f.SI)
236+
maha = mahalanobis(a, f.y, f.S)
237237
assert f.mahalanobis == approx(maha)
238238

239239
# now do a batch run with the stored z values so we can test that
@@ -443,7 +443,7 @@ def test_steadystate():
443443
cv.update_steadystate([i])
444444
# test mahalanobis
445445
a = np.zeros(cv.y.shape)
446-
maha = scipy_mahalanobis(a, cv.y, cv.SI)
446+
maha = mahalanobis(a, cv.y, cv.S)
447447
assert cv.mahalanobis == approx(maha)
448448

449449

filterpy/stats/tests/test_stats.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,30 @@
2323
import numpy as np
2424
from numpy.linalg import inv
2525
import scipy
26-
from scipy.spatial.distance import mahalanobis as scipy_mahalanobis
27-
from filterpy.stats import norm_cdf, multivariate_gaussian, logpdf, mahalanobis
26+
from scipy.spatial.distance import mahalanobis as _scipy_mahalanobis
27+
from filterpy.stats import (norm_cdf, multivariate_gaussian, logpdf,
28+
mahalanobis)
2829
from scipy import linalg
2930

3031

3132
ITERS = 10000
3233

34+
35+
def scipy_mahalanobis(x, mean, cov):
36+
# scipy 1.9 will not accept scalars as input, so force the correct
37+
# behavior so we don't get deprecation warnings or exceptions
38+
39+
def validate_vector(u):
40+
u = np.asarray(u).squeeze()
41+
# Ensure values such as u=1 and u=[1] still return 1-D arrays.
42+
u = np.atleast_1d(u)
43+
return u
44+
45+
x = validate_vector(x)
46+
mean = validate_vector(mean)
47+
return _scipy_mahalanobis(x, mean, cov)
48+
49+
3350
def test_mahalanobis():
3451
global a, b, S
3552
# int test
@@ -92,7 +109,6 @@ def test_multivariate_gaussian():
92109
with warnings.catch_warnings():
93110
warnings.simplefilter("ignore")
94111

95-
96112
# test that we treat lists and arrays the same
97113
mean= (0, 0)
98114
cov=[[1, .5], [.5, 1]]
@@ -294,6 +310,8 @@ def covariance_3d_plot_test():
294310
plot_3d_covariance(mu, C, alpha=.4, std=3, limit_xyz=True, ax=ax)
295311

296312
if __name__ == "__main__":
313+
test_multivariate_gaussian()
314+
test_mahalanobis()
297315
test_logpdf2()
298316
covariance_3d_plot_test()
299317
plt.figure()

0 commit comments

Comments
 (0)