Skip to content

Commit 509ecb6

Browse files
Merge pull request #638 from guillaume-vignal/fix/correlation_plot
Ensure finite values in distance matrix for clustering
2 parents e8749e2 + ff8d339 commit 509ecb6

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

shapash/plots/plot_correlations.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,23 @@ def cluster_corr(corr, degree, inplace=False):
100100
if corr.shape[0] < 2:
101101
return corr
102102

103-
pairwise_distances = sch.distance.pdist(corr**degree)
103+
# Compute pairwise distances based on transformed correlation matrix
104+
pairwise_distances = sch.distance.pdist(np.abs(corr) ** degree)
105+
106+
# Replace non-finite values (NaN, inf) with the maximum valid distance or 0 if all are invalid
107+
finite_mask = np.isfinite(pairwise_distances)
108+
if not np.all(finite_mask):
109+
# Use the maximum of the valid distances as a fallback value
110+
max_valid = pairwise_distances[finite_mask].max() if np.any(finite_mask) else 0.0
111+
pairwise_distances[~finite_mask] = max_valid
112+
113+
# Perform hierarchical clustering
104114
linkage = sch.linkage(pairwise_distances, method="complete")
115+
116+
# Define threshold for cluster cutting (half of the maximum distance)
105117
cluster_distance_threshold = pairwise_distances.max() / 2
118+
119+
# Assign cluster labels
106120
idx_to_cluster_array = sch.fcluster(linkage, cluster_distance_threshold, criterion="distance")
107121
idx = np.argsort(idx_to_cluster_array)
108122

@@ -220,8 +234,13 @@ def prepare_corr_matrix(df_subset):
220234
title += f"<span style='font-size: 12px;'><br />{subtitle}</span>"
221235
dict_t = style_dict_default["dict_title"] | {"text": title, "y": adjust_title_height(height)}
222236

237+
if corr.min().min() >= 0:
238+
colorscale = ["rgb(255, 255, 255)"] + style_dict_default["init_contrib_colorscale"][5:-1]
239+
else:
240+
colorscale = style_dict_default["init_contrib_colorscale"]
241+
223242
fig.update_layout(
224-
coloraxis=dict(colorscale=["rgb(255, 255, 255)"] + style_dict_default["init_contrib_colorscale"][5:-1]),
243+
coloraxis=dict(colorscale=colorscale),
225244
showlegend=True,
226245
title=dict_t,
227246
width=width,

0 commit comments

Comments
 (0)