Skip to content

Commit 3406455

Browse files
authored
New cluster module with functions for embedding, projecting and scattering compositions (#285)
* new `cluster` module with functions for embedding, plotting, and projecting compositions, structures to be added later - `matbench.py` example script for visualizing clustering across MatBench datasets using various embedding and projection methods. - `embed.py` for chemical embedding functions and `project.py` for dimensionality reduction techniques. - `plot.py` for generating clustering visualizations with Plotly. - unit tests for embedding, plotting, and projection functions to ensure functionality and robustness. - update `__init__.py` files to include new modules and functions for easier access. * fix tests and add 'cluster' optional deps set * unit tests for axis labels for pre-computed embeddings when projecting to 3D * make `project_vectors` return projector objects and add more tests - `project_vectors` now returns both projected data and the fitted projection object (e.g., PCA, TSNE). - `cluster_compositions` now uses the returned PCA object to get explained variance - unit tests check type and attributes of projection objects for various methods - remove `return_explained_variance` parameter from `project_vectors` - add missing type annotations in docstrings * cleanup * add composition clustering section to readme.md and rename matbench.py to cluster_compositions_matbench.py with more polished examples * don't install umap-learn in pytest CI due to numba clash * skip tests that require UMAP * remove one_hot_encode log_transform keyword and install move umap to separate optional deps set to install cluster deps with matminer in CI * enhance cluster_compositions chemical system visualization options - Renamed `point_size` to `marker_size` for clarity. - Updated `show_chem_sys` parameter to accept more options: "color", "shape", "color+shape", or None. - Added warnings for potential symbol duplication when using shape visualization with many unique chemical systems. 
- new unit tests to cover various visualization mode combinations * swap readme 3d composition cluster example to matbench-perovskites - cleanup `cluster_compositions` colorbar config
1 parent 0d942a7 commit 3406455

29 files changed

+2756
-42
lines changed

.github/workflows/test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ jobs:
6161
python-version: "3.10"
6262

6363
- name: Install package and dependencies
64-
run: pip install -e .[make-assets]
64+
run: pip install -e .[make-assets,cluster]
6565

6666
- name: Run script
6767
run: python ${{ matrix.script }}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
"""This example script clusters the smallest MatBench datasets
2+
(matbench_steels and matbench_jdft2d) plus several larger ones, using different embedding and projection methods.
3+
Resulting plots are colored by target property of each dataset.
4+
"""
5+
6+
from __future__ import annotations
7+
8+
import gzip
9+
import json
10+
import os
11+
from typing import TYPE_CHECKING, Any
12+
13+
import pandas as pd
14+
from matminer.datasets import load_dataset
15+
from pymatgen.core import Composition
16+
17+
import pymatviz as pmv
18+
from pymatviz.cluster.composition import (
19+
EmbeddingMethod,
20+
matminer_featurize,
21+
one_hot_encode,
22+
)
23+
from pymatviz.enums import Key
24+
25+
26+
if TYPE_CHECKING:
27+
import plotly.graph_objects as go
28+
29+
from pymatviz.cluster.composition import ProjectionMethod
30+
31+
32+
# Use the white pymatviz Plotly template for every figure made by this script.
pmv.set_plotly_template("pymatviz_white")

# All outputs live in a tmp/ folder next to this script.
module_dir = os.path.dirname(__file__)

# Destination for the HTML figures.
plot_dir = f"{module_dir}/tmp/figs/composition_clustering"
os.makedirs(plot_dir, exist_ok=True)

# Destination for the gzipped JSON embedding caches.
cache_dir = f"{module_dir}/tmp/embeddings"
os.makedirs(cache_dir, exist_ok=True)
38+
39+
40+
def format_composition(formula: str) -> str:
    """Lay out a formula as "element: amount" entries, two per line.

    Entries are ordered from largest to smallest amount. Every second entry
    is followed by an HTML line break, so the text shows up as two columns
    wherever ``<br>`` is rendered (e.g. Plotly hover tooltips).

    Args:
        formula (str): Chemical formula parsable by pymatgen's Composition.

    Returns:
        str: Tab-separated entries with ``<br>`` after each pair.
    """
    by_amount = sorted(
        Composition(formula).items(), key=lambda pair: pair[1], reverse=True
    )
    entries = [
        f"{elem}: {amt:.4}" + ("<br>" if pos % 2 else "")
        for pos, (elem, amt) in enumerate(by_amount)
    ]
    # Drop the tab separator that would otherwise lead each new line
    return "\t\t".join(entries).replace("<br>\t\t", "<br>")
51+
52+
53+
def process_dataset(
    dataset_name: str,
    target_col: str,
    target_label: str,
    embed_method: EmbeddingMethod,
    projection_method: ProjectionMethod,
    n_components: int,
) -> go.Figure:
    """Embed, project and plot the compositions of one MatBench dataset.

    Embeddings are cached as gzipped JSON in ``cache_dir`` (one file per
    dataset/embedding-method pair) and reused when the file already exists.

    Args:
        dataset_name (str): Name of the MatBench dataset to load
        target_col (str): Column holding the target property
        target_label (str): Human-readable property label used in the plot
        embed_method (EmbeddingMethod): How compositions are turned into vectors
        projection_method (ProjectionMethod): How vectors are reduced in dimension
        n_components (int): Number of projected dimensions (2 or 3)

    Returns:
        go.Figure: Plotly figure with title and hover tooltips configured

    Raises:
        ValueError: If no cache file exists and embed_method is not "one-hot",
            "magpie" or "matscholar_el".
    """
    df_data = load_dataset(dataset_name)

    # Fall back to structure formulas when there is no composition column
    formulas = (
        df_data[Key.composition].tolist()
        if Key.composition in df_data
        else [struct.formula for struct in df_data[Key.structure]]
    )
    targets = df_data[target_col].tolist()

    # Going through a DataFrame drops missing rows while keeping both lists aligned
    df_clean = pd.DataFrame({"composition": formulas, "property": targets}).dropna()
    compositions = df_clean["composition"].tolist()
    properties = df_clean["property"].tolist()

    # Reuse embeddings from an earlier run if a cache file is present
    cache_file = f"{cache_dir}/{dataset_name}_{embed_method}.json.gz"
    embeddings_dict = None
    if os.path.isfile(cache_file):
        with gzip.open(cache_file, mode="rt") as cache_in:
            embeddings_dict = json.load(cache_in)

    if embeddings_dict is None:
        if embed_method == "one-hot":
            vectors = one_hot_encode(compositions)
        elif embed_method in ("magpie", "matscholar_el"):
            vectors = matminer_featurize(compositions, preset=embed_method)
        else:
            raise ValueError(f"Unknown {embed_method=}")

        # Map each composition to its embedding vector
        embeddings_dict = dict(zip(compositions, vectors, strict=True))

        def to_jsonable(obj):
            """JSON fallback: use .tolist() on objects that provide it."""
            return obj.tolist() if hasattr(obj, "tolist") else obj

        # Store the embeddings for the next run
        with gzip.open(cache_file, mode="wt") as cache_out:
            json.dump(embeddings_dict, cache_out, default=to_jsonable)

    # Plot from the pre-computed embeddings
    fig = pmv.cluster_compositions(
        compositions=embeddings_dict,
        properties=dict(zip(compositions, properties, strict=True)),
        prop_name=target_label,
        projection_method=projection_method,
        n_components=n_components,
        marker_size=8,
        opacity=0.8,
        width=1000,
        height=600,
        # marker shapes per chemical system are only requested below 1000 points
        show_chem_sys="shape" if len(compositions) < 1000 else None,
    )

    # Centered title with some extra room at the top
    title = f"{dataset_name} - {embed_method} + {projection_method} ({n_components}D)"
    fig.layout.update(title={"text": title, "x": 0.5}, margin_t=50)

    # Only the steels formulas get the two-column tooltip layout
    is_steels = dataset_name == "matbench_steels"
    custom_data = [
        [format_composition(formula) if is_steels else formula]
        for formula in compositions
    ]

    # Tooltip: composition, projected coordinates, then the property value
    hover_parts = [
        "%{customdata[0]}<br>",
        f"{projection_method} 1: %{{x:.2f}}<br>",
        f"{projection_method} 2: %{{y:.2f}}<br>",
    ]
    if n_components == 3:
        hover_parts.append(f"{projection_method} 3: %{{z:.2f}}<br>")
    hover_parts.append(f"{target_label}: %{{marker.color:.2f}}")

    fig.update_traces(hovertemplate="".join(hover_parts), customdata=custom_data)

    return fig
151+
152+
153+
# (dataset name, target column, display label) for each MatBench dataset used below
mb_jdft2d = ("matbench_jdft2d", "exfoliation_en", "Exfoliation Energy (meV/atom)")
mb_steels = ("matbench_steels", "yield strength", "Yield Strength (MPa)")
mb_dielectric = ("matbench_dielectric", "n", "Refractive index")
mb_perovskites = ("matbench_perovskites", "e_form", "Formation energy (eV/atom)")
mb_phonons = ("matbench_phonons", "last phdos peak", "Max Phonon Peak (cm⁻¹)")
mb_bulk_modulus = ("matbench_log_kvrh", "log10(K_VRH)", "Bulk Modulus (GPa)")

# Each entry: dataset triple, embedding method, projection method, number of
# projected dimensions, and a dict of colorbar layout options.
plot_combinations: list[
    tuple[str, str, str, EmbeddingMethod, ProjectionMethod, int, dict[str, Any]]
] = [
    # 1. steels, Magpie features, PCA, 2D
    (*mb_steels, "magpie", "pca", 2, {"x": 0.01, "xanchor": "left"}),
    # 2. steels, Magpie features, t-SNE, 2D
    (*mb_steels, "magpie", "tsne", 2, {"x": 0.01, "xanchor": "left"}),
    # 3. JDFT2D, Magpie features, UMAP, 2D
    (*mb_jdft2d, "magpie", "umap", 2, {"x": 0.01, "xanchor": "left"}),
    # 4. JDFT2D, one-hot encoding, PCA, 3D
    (*mb_jdft2d, "one-hot", "pca", 3, {}),
    # 5. steels, Matscholar element embedding, t-SNE, 3D
    (*mb_steels, "matscholar_el", "tsne", 3, {"x": 0.5, "y": 0.8}),
    # 6. dielectric, Magpie features, PCA, 2D
    (*mb_dielectric, "magpie", "pca", 2, {"x": 0.01, "xanchor": "left"}),
    # 7. perovskites, Magpie features, PCA, 2D
    (*mb_perovskites, "magpie", "pca", 2, {"x": 0.01, "xanchor": "left"}),
    # 8. phonons, Magpie features, PCA, 2D
    (*mb_phonons, "magpie", "pca", 2, {"x": 0.01, "xanchor": "left"}),
    # 9. bulk modulus, Magpie features, PCA, 2D
    (*mb_bulk_modulus, "magpie", "pca", 2, {"x": 0.99, "y": 0.96, "yanchor": "top"}),
    # 10. perovskites, Magpie features, t-SNE, 3D
    (*mb_perovskites, "magpie", "tsne", 3, {}),
]
183+
184+
# Build, save and display one figure per configured combination.
for *dataset_args, cbar_args in plot_combinations:
    # dataset_args is ordered exactly like process_dataset's parameters
    (
        data_name,
        target_col,
        target_label,
        embed_method,
        proj_method,
        n_components,
    ) = dataset_args
    fig = process_dataset(*dataset_args)

    # Place the colorbar according to this combination's options
    fig.update_layout(coloraxis_colorbar=cbar_args)

    # File stem shared by the HTML and SVG outputs (underscores become hyphens)
    img_name = f"{data_name}-{embed_method}-{proj_method}-{n_components}d".replace(
        "_", "-"
    )
    fig.write_html(f"{plot_dir}/{img_name}.html", include_plotlyjs="cdn")
    pmv.io.save_and_compress_svg(fig, img_name)

    fig.show()

assets/svg/matbench-jdft2d-magpie-umap-2d.svg

+1
Loading

assets/svg/matbench-log-kvrh-magpie-pca-2d.svg

+1
Loading

assets/svg/matbench-perovskites-magpie-pca-2d.svg

+1
Loading

assets/svg/matbench-perovskites-magpie-tsne-3d.svg

+1
Loading

assets/svg/matbench-phonons-magpie-pca-2d.svg

+1
Loading

assets/svg/matbench-steels-magpie-tsne-2d.svg

+1
Loading

assets/svg/matbench-steels-matscholar-el-tsne-3d.svg

+1
Loading

examples/diatomics/calc_mlip_diatomic_curves.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def calc_diatomic_curve(
131131
atomic_numbers = [*range(1, 85)]
132132
# atomic_numbers = [*range(1, 85), *range(89, 95)]
133133
else:
134-
raise ValueError(f"Unknown model: {model_name}")
134+
raise ValueError(f"Unknown {model_name=}")
135135

136136
kwargs = dict(
137137
calculator=calculator,

pymatviz/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
bar,
2121
brillouin,
2222
classify,
23+
cluster,
2324
colors,
2425
coordination,
2526
data,
@@ -44,6 +45,7 @@
4445
from pymatviz.brillouin import brillouin_zone_3d
4546
from pymatviz.classify import precision_recall_curve_plotly, roc_curve_plotly
4647
from pymatviz.classify.confusion_matrix import confusion_matrix
48+
from pymatviz.cluster.composition import cluster_compositions
4749
from pymatviz.coordination import coordination_hist, coordination_vs_cutoff_line
4850
from pymatviz.enums import Key, angstrom_per_atom, cubic_angstrom, eV
4951
from pymatviz.histogram import elements_hist, histogram, spacegroup_bar
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Chemical clustering module for material composition analysis.
2+
3+
This module provides utilities for clustering and visualizing materials based on their
4+
chemical composition.
5+
"""
6+
7+
from pymatviz.cluster.composition.embed import matminer_featurize, one_hot_encode
8+
from pymatviz.cluster.composition.plot import (
9+
EmbeddingMethod,
10+
ProjectionMethod,
11+
cluster_compositions,
12+
)
13+
from pymatviz.cluster.composition.project import project_vectors

0 commit comments

Comments
 (0)