Skip to content

Commit 1c7dc1f

Browse files
committed
refactor cluster_compositions to take DataFrame as 1st arg, not composition list
cluster_compositions is still unreleased, so breaking changes are fine.
- allow passing in pre-computed plot coordinates via a DataFrame column
- breaking: rename the `projection_method` parameter to `projection` for uniformity
- hover tooltips use dynamic method labels
- many more unit tests for `cluster_compositions`
1 parent 65ccb7c commit 1c7dc1f

File tree

7 files changed

+1557
-768
lines changed

7 files changed

+1557
-768
lines changed

assets/scripts/cluster/composition/cluster_compositions_matbench.py

+17-11
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ def process_dataset(
6262
target_col: str,
6363
target_label: str,
6464
embed_method: EmbeddingMethod,
65-
projection_method: ProjectionMethod,
65+
projection: ProjectionMethod,
6666
n_components: int,
6767
) -> go.Figure:
6868
"""Process a single dataset and create clustering visualizations.
@@ -72,7 +72,7 @@ def process_dataset(
7272
target_col (str): Name of the target property column
7373
target_label (str): Display label for the property
7474
embed_method (EmbeddingMethod): Method to convert compositions to vectors
75-
projection_method (ProjectionMethod): Method to reduce dimensionality
75+
projection (ProjectionMethod): Method to reduce dimensionality
7676
n_components (int): Number of dimensions for projection (2 or 3)
7777
7878
Returns:
@@ -121,12 +121,18 @@ def process_dataset(
121121
default_handler = lambda x: x.tolist() if hasattr(x, "tolist") else x
122122
json.dump(embeddings_dict, file, default=default_handler)
123123

124-
# Create plot with pre-computed embeddings
124+
df_plot = pd.DataFrame({"composition": compositions})
125+
df_plot[target_label] = properties
126+
127+
if "embeddings" not in df_plot:
128+
df_plot["embeddings"] = [embeddings_dict.get(comp) for comp in compositions]
129+
125130
fig = pmv.cluster_compositions(
126-
compositions=embeddings_dict,
127-
properties=dict(zip(compositions, properties, strict=True)),
131+
df=df_plot,
132+
composition_col="composition",
128133
prop_name=target_label,
129-
projection_method=projection_method,
134+
embedding_method="embeddings",
135+
projection=projection,
130136
n_components=n_components,
131137
marker_size=8,
132138
opacity=0.8,
@@ -136,7 +142,7 @@ def process_dataset(
136142
)
137143

138144
# Update title and margins
139-
title = f"{dataset_name} - {embed_method} + {projection_method} ({n_components}D)"
145+
title = f"{dataset_name} - {embed_method} + {projection} ({n_components}D)"
140146
fig.layout.update(title=dict(text=title, x=0.5), margin_t=50)
141147
# format compositions and coordinates in hover tooltip
142148
custom_data = [
@@ -146,9 +152,9 @@ def process_dataset(
146152
fig.update_traces(
147153
hovertemplate=(
148154
"%{customdata[0]}<br>" # Formatted composition
149-
f"{projection_method} 1: %{{x:.2f}}<br>" # First projection coordinate
150-
f"{projection_method} 2: %{{y:.2f}}<br>" # Second projection coordinate
151-
+ (f"{projection_method} 3: %{{z:.2f}}<br>" if n_components == 3 else "")
155+
f"{projection} 1: %{{x:.2f}}<br>" # First projection coordinate
156+
f"{projection} 2: %{{y:.2f}}<br>" # Second projection coordinate
157+
+ (f"{projection} 3: %{{z:.2f}}<br>" if n_components == 3 else "")
152158
+ f"{target_label}: %{{marker.color:.2f}}" # Property value
153159
),
154160
customdata=custom_data,
@@ -203,7 +209,7 @@ def process_dataset(
203209
target_col=target_col,
204210
target_label=target_label,
205211
embed_method=embed_method,
206-
projection_method=proj_method,
212+
projection=proj_method,
207213
n_components=n_components,
208214
)
209215
fig.update_layout(coloraxis_colorbar=cbar_args)

assets/scripts/phonons/phonon_bands.py

+3
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,9 @@
1818
raise SystemExit(0) from None # need atomate2 for MontyDecoder to load PhononDBDoc
1919

2020

21+
pmv.set_plotly_template("pymatviz_white")
22+
23+
2124
# %% Plot phonon bands and DOS
2225
for mp_id, formula in (
2326
("mp-2758", "Sr4Se4"),

pymatviz/cluster/__init__.py

+4
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,4 @@
1+
"""Cluster analysis tools."""
2+
3+
from pymatviz.cluster import composition
4+
from pymatviz.cluster.composition import embed, plot, project

0 commit comments

Comments
 (0)