Skip to content

Commit 690af90

Browse files
committed
Extend giotto-ai#428 to Amplitude and PairwiseDistance, add regression test, improve edge case handling in _metrics
1 parent 4d5ec8c commit 690af90

File tree

5 files changed

+66
-29
lines changed

5 files changed

+66
-29
lines changed

gtda/diagrams/_metrics.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ def persistence_images(diagrams, sampling, step_size, sigma, weights):
120120
# WARNING: modifies `diagrams` in place
121121
persistence_images_ = \
122122
np.zeros((len(diagrams), len(sampling), len(sampling)), dtype=float)
123-
# If both step sizes are zero, we return a trivial image
124-
if (step_size == 0).all():
123+
# If either step size is zero, we return a trivial image
124+
if (step_size == 0).any():
125125
return persistence_images_
126126

127127
# Transform diagrams from (birth, death, dim) to (birth, persistence, dim)
@@ -310,9 +310,13 @@ def _parallel_pairwise(
310310
none_dict = {dim: None for dim in homology_dimensions}
311311
samplings = effective_metric_params.pop("samplings", none_dict)
312312
step_sizes = effective_metric_params.pop("step_sizes", none_dict)
313+
if metric in ["heat", "persistence_image"]:
314+
parallel_kwargs = {"mmap_mode": "c"}
315+
else:
316+
parallel_kwargs = {}
313317

314318
n_columns = len(X2)
315-
distance_matrices = Parallel(n_jobs=n_jobs)(
319+
distance_matrices = Parallel(n_jobs=n_jobs, **parallel_kwargs)(
316320
delayed(metric_func)(
317321
_subdiagrams(X1, [dim], remove_dim=True),
318322
_subdiagrams(X2[s], [dim], remove_dim=True),
@@ -416,8 +420,12 @@ def _parallel_amplitude(X, metric, metric_params, homology_dimensions, n_jobs):
416420
none_dict = {dim: None for dim in homology_dimensions}
417421
samplings = effective_metric_params.pop("samplings", none_dict)
418422
step_sizes = effective_metric_params.pop("step_sizes", none_dict)
423+
if metric in ["heat", "persistence_image"]:
424+
parallel_kwargs = {"mmap_mode": "c"}
425+
else:
426+
parallel_kwargs = {}
419427

420-
amplitude_arrays = Parallel(n_jobs=n_jobs)(
428+
amplitude_arrays = Parallel(n_jobs=n_jobs, **parallel_kwargs)(
421429
delayed(amplitude_func)(
422430
_subdiagrams(X[s], [dim], remove_dim=True),
423431
sampling=samplings[dim],

gtda/diagrams/distance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,9 @@ def transform(self, X, y=None):
229229
230230
"""
231231
check_is_fitted(self)
232-
X = check_diagrams(X, copy=True)
232+
Xt = check_diagrams(X, copy=True)
233233

234-
Xt = _parallel_pairwise(X, self._X, self.metric,
234+
Xt = _parallel_pairwise(Xt, self._X, self.metric,
235235
self.effective_metric_params_,
236236
self.homology_dimensions_,
237237
self.n_jobs)

gtda/diagrams/features.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def transform(self, X, y=None):
386386
self.effective_metric_params_,
387387
self.homology_dimensions_,
388388
self.n_jobs)
389-
if self.order is None:
390-
return Xt
391-
Xt = np.linalg.norm(Xt, axis=1, ord=self.order).reshape(-1, 1)
389+
if self.order is not None:
390+
Xt = np.linalg.norm(Xt, axis=1, ord=self.order).reshape(-1, 1)
391+
392392
return Xt

gtda/diagrams/tests/test_distance.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def test_not_fitted(transformer):
255255

256256
@pytest.mark.parametrize(('metric', 'metric_params'), parameters_distance)
257257
@pytest.mark.parametrize('order', [2., None])
258-
@pytest.mark.parametrize('n_jobs', [1, 2, 4])
258+
@pytest.mark.parametrize('n_jobs', [1, 2, -1])
259259
def test_dd_transform(metric, metric_params, order, n_jobs):
260260
# X_fit == X_transform
261261
dd = PairwiseDistance(metric=metric, metric_params=metric_params,
@@ -297,7 +297,7 @@ def test_dd_transform(metric, metric_params, order, n_jobs):
297297

298298
@pytest.mark.parametrize(('metric', 'metric_params'), parameters_amplitude)
299299
@pytest.mark.parametrize('order', [None, 2.])
300-
@pytest.mark.parametrize('n_jobs', [1, 2, 4])
300+
@pytest.mark.parametrize('n_jobs', [1, 2, -1])
301301
def test_da_transform(metric, metric_params, order, n_jobs):
302302
n_expected_columns = n_homology_dimensions if order is None else 1
303303

@@ -315,7 +315,7 @@ def test_da_transform(metric, metric_params, order, n_jobs):
315315

316316
@pytest.mark.parametrize(('metric', 'metric_params', 'order'),
317317
[('bottleneck', None, None)])
318-
@pytest.mark.parametrize('n_jobs', [1, 2, 4])
318+
@pytest.mark.parametrize('n_jobs', [1, 2, -1])
319319
def test_da_transform_bottleneck(metric, metric_params, order, n_jobs):
320320
da = Amplitude(metric=metric, metric_params=metric_params,
321321
order=order, n_jobs=n_jobs)
@@ -340,3 +340,24 @@ def test_pi_zero_weight_function(transformer_cls, order, Xnew):
340340
X_res = transformer.fit(X1).transform(Xnew)
341341

342342
assert np.array_equal(X_res, np.zeros_like(X_res))
343+
344+
345+
@pytest.mark.parametrize('metric', ['heat', 'persistence_image'])
346+
@pytest.mark.parametrize('transformer_cls', [Amplitude, PairwiseDistance])
347+
def test_large_hk_pi_parallel(metric, transformer_cls):
348+
"""Test that Amplitude and PairwiseDistance do not break with a read-only
349+
error when the input array is at least 1MB, the metric is either 'heat'
350+
or 'persistence_image', and more than 1 process is used (triggering
351+
joblib's use of memmaps)."""
352+
X = np.linspace(0, 100, 300000)
353+
n_bins = 10
354+
diagrams = np.expand_dims(
355+
np.stack([X, X, np.zeros(len(X))]).transpose(), axis=0
356+
)
357+
358+
transformer = transformer_cls(
359+
metric=metric, metric_params={'sigma': 1, 'n_bins': n_bins}, n_jobs=2
360+
)
361+
Xt = transformer.fit_transform(diagrams)
362+
363+
assert_almost_equal(Xt, np.zeros_like(Xt))

gtda/diagrams/tests/test_features_representations.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,9 @@ def test_fit_transform_plot_wrong_hom_dims(transformer):
7070
transformer.fit_transform_plot(X, sample=0, homology_dimensions=(2,))
7171

7272

73-
def test_pe_transform():
74-
pe = PersistenceEntropy()
73+
@pytest.mark.parametrize('n_jobs', [1, 2, -1])
74+
def test_pe_transform(n_jobs):
75+
pe = PersistenceEntropy(n_jobs=n_jobs)
7576
diagram_res = np.array([[1., 0.91829583405]])
7677

7778
assert_almost_equal(pe.fit_transform(X), diagram_res)
@@ -82,22 +83,25 @@ def test_pe_transform():
8283

8384

8485
@pytest.mark.parametrize('n_bins', list(range(10, 51, 10)))
85-
def test_bc_transform_shape(n_bins):
86-
bc = BettiCurve(n_bins=n_bins)
86+
@pytest.mark.parametrize('n_jobs', [1, 2, -1])
87+
def test_bc_transform_shape(n_bins, n_jobs):
88+
bc = BettiCurve(n_bins=n_bins, n_jobs=n_jobs)
8789
X_res = bc.fit_transform(X)
8890
assert X_res.shape == (1, bc._n_dimensions, n_bins)
8991

9092

9193
@pytest.mark.parametrize('n_bins', list(range(10, 51, 10)))
9294
@pytest.mark.parametrize('n_layers', list(range(1, 10)))
93-
def test_pl_transform_shape(n_bins, n_layers):
94-
pl = PersistenceLandscape(n_bins=n_bins, n_layers=n_layers)
95+
@pytest.mark.parametrize('n_jobs', [1, 2, -1])
96+
def test_pl_transform_shape(n_bins, n_layers, n_jobs):
97+
pl = PersistenceLandscape(n_bins=n_bins, n_layers=n_layers, n_jobs=n_jobs)
9598
X_res = pl.fit_transform(X)
9699
assert X_res.shape == (1, pl._n_dimensions, n_layers, n_bins)
97100

98101

99-
def test_pi_zero_weight_function():
100-
pi = PersistenceImage(weight_function=lambda x: x * 0.)
102+
@pytest.mark.parametrize('n_jobs', [1, 2, -1])
103+
def test_pi_zero_weight_function(n_jobs):
104+
pi = PersistenceImage(weight_function=lambda x: x * 0., n_jobs=n_jobs)
101105
X_res = pi.fit_transform(X)
102106
assert np.array_equal(X_res, np.zeros_like(X_res))
103107

@@ -153,18 +157,20 @@ def test_large_pi_null_parallel():
153157
assert_almost_equal(pi.fit_transform(diagrams)[0], 0)
154158

155159

156-
def test_silhouette_transform():
157-
sht = Silhouette(n_bins=31, power=1.)
160+
@pytest.mark.parametrize('n_jobs', [1, 2, -1])
161+
def test_silhouette_transform(n_jobs):
162+
sht = Silhouette(n_bins=31, power=1., n_jobs=n_jobs)
158163
X_sht_res = np.array([0., 0.05, 0.1, 0.15, 0.2, 0.25, 0.2, 0.15, 0.1,
159164
0.05, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0., 0.05,
160165
0.1, 0.15, 0.2, 0.25, 0.2, 0.15, 0.1, 0.05, 0.])
161166

162167
assert_almost_equal(sht.fit_transform(X)[0][0], X_sht_res)
163168

164169

165-
def test_silhouette_big_order():
170+
@pytest.mark.parametrize('n_jobs', [1, 2, -1])
171+
def test_silhouette_big_order(n_jobs):
166172
diagrams = np.array([[[0, 2, 0], [1, 4, 0]]])
167-
sht_10 = Silhouette(n_bins=41, power=10.)
173+
sht_10 = Silhouette(n_bins=41, power=10., n_jobs=n_jobs)
168174
X_sht_res = np.array([0., 0.00170459, 0.00340919, 0.00511378, 0.00681837,
169175
0.00852296, 0.01022756, 0.01193215, 0.01363674,
170176
0.01534133, 0.01704593, 0.11363674, 0.21022756,
@@ -179,10 +185,11 @@ def test_silhouette_big_order():
179185
assert_almost_equal(sht_10.fit_transform(diagrams)[0][0], X_sht_res)
180186

181187

182-
@pytest.mark.parametrize('transformer', [HeatKernel(), PersistenceImage()])
183-
def test_all_pts_the_same(transformer):
188+
@pytest.mark.parametrize('transformer_cls', [HeatKernel, PersistenceImage])
189+
@pytest.mark.parametrize('n_jobs', [1, 2, -1])
190+
def test_all_pts_the_same(transformer_cls, n_jobs):
184191
X = np.zeros((1, 4, 3))
185-
X_res = transformer.fit_transform(X)
192+
X_res = transformer_cls(n_jobs=n_jobs).fit_transform(X)
186193
assert np.array_equal(X_res, np.zeros_like(X_res))
187194

188195

@@ -222,13 +229,14 @@ def get_input(pts, dims):
222229
return X
223230

224231

232+
@pytest.mark.parametrize('n_jobs', [1, 2])
225233
@given(pts_gen, dims_gen)
226-
def test_hk_shape(pts, dims):
234+
def test_hk_shape(n_jobs, pts, dims):
227235
n_bins = 10
228236
X = get_input(pts, dims)
229237
sigma = (np.max(X[:, :, :2]) - np.min(X[:, :, :2])) / 2
230238

231-
hk = HeatKernel(sigma=sigma, n_bins=n_bins)
239+
hk = HeatKernel(sigma=sigma, n_bins=n_bins, n_jobs=n_jobs)
232240
num_dimensions = len(np.unique(dims))
233241
X_t = hk.fit_transform(X)
234242

0 commit comments

Comments
 (0)