Skip to content

Commit c067d88

Browse files
committed
aeon elastic_barycenter_average was updated
1 parent 8310733 commit c067d88

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

pycvi/cluster.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
- :func:`pycvi.cluster.get_clustering`, that converts an array of predicted label for each datapoint (sklearn type of clustering encoding) to a list of datapoints for each cluster (PyCVI type of clustering encoding)
1111
1212
"""
13-
13+
import inspect
1414
import numpy as np
1515
from sklearn.preprocessing import StandardScaler
1616
from aeon.clustering.averaging._barycenter_averaging import (
@@ -65,7 +65,25 @@ def compute_center(
6565
if dims[0] == 1:
6666
center = cluster[0]
6767
else:
68-
center = elastic_barycenter_average(np.swapaxes(cluster, 1, 2))
68+
69+
# Check args of elastic_barycenter_average
70+
f_args = inspect.getfullargspec(elastic_barycenter_average)
71+
72+
# aeon version > 0.8.1 (#1339)
73+
if (
74+
"init_barycenter" in f_args[0] and "method" in f_args[0]
75+
and "distance" in f_args[0]
76+
):
77+
center = elastic_barycenter_average(
78+
np.swapaxes(cluster, 1, 2),
79+
distance="dtw",
80+
init_barycenter="medoids",
81+
method="petitjean"
82+
)
83+
# aeon version <= 0.8.1
84+
else:
85+
center = elastic_barycenter_average(np.swapaxes(cluster, 1, 2))
86+
6987
center = np.swapaxes(center, 0, 1)
7088

7189
else:

0 commit comments

Comments
 (0)