|
| 1 | +"""This example script clusters the smallest MatBench datasets |
| 2 | +(matbench_steels and matbench_jdft2d) using different embedding and projection methods. |
| 3 | +Resulting plots are colored by target property of each dataset. |
| 4 | +""" |
| 5 | + |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +import gzip |
| 9 | +import json |
| 10 | +import os |
| 11 | +from typing import TYPE_CHECKING, Any |
| 12 | + |
| 13 | +import pandas as pd |
| 14 | +from matminer.datasets import load_dataset |
| 15 | +from pymatgen.core import Composition |
| 16 | + |
| 17 | +import pymatviz as pmv |
| 18 | +from pymatviz.cluster.composition import ( |
| 19 | + EmbeddingMethod, |
| 20 | + matminer_featurize, |
| 21 | + one_hot_encode, |
| 22 | +) |
| 23 | +from pymatviz.enums import Key |
| 24 | + |
| 25 | + |
| 26 | +if TYPE_CHECKING: |
| 27 | + import plotly.graph_objects as go |
| 28 | + |
| 29 | + from pymatviz.cluster.composition import ProjectionMethod |
| 30 | + |
| 31 | + |
| 32 | +pmv.set_plotly_template("pymatviz_white") |
| 33 | +module_dir = os.path.dirname(__file__) |
| 34 | +plot_dir = f"{module_dir}/tmp/figs/composition_clustering" |
| 35 | +cache_dir = f"{module_dir}/tmp/embeddings" |
| 36 | +os.makedirs(plot_dir, exist_ok=True) |
| 37 | +os.makedirs(cache_dir, exist_ok=True) |
| 38 | + |
| 39 | + |
| 40 | +def format_composition(formula: str) -> str: |
| 41 | + """Format long steel compositions into 2-column layout, sorted by amount.""" |
| 42 | + comp = Composition(formula) |
| 43 | + # Sort elements by amount in descending order |
| 44 | + element_pairs = [] |
| 45 | + for idx, (elem, amt) in enumerate( |
| 46 | + sorted(comp.items(), key=lambda x: x[1], reverse=True) |
| 47 | + ): |
| 48 | + suffix = "<br>" if idx % 2 == 1 else "" |
| 49 | + element_pairs.append(f"{elem}: {amt:.4}{suffix}") |
| 50 | + return "\t\t".join(element_pairs).replace("<br>\t\t", "<br>") |
| 51 | + |
| 52 | + |
| 53 | +def process_dataset( |
| 54 | + dataset_name: str, |
| 55 | + target_col: str, |
| 56 | + target_label: str, |
| 57 | + embed_method: EmbeddingMethod, |
| 58 | + projection_method: ProjectionMethod, |
| 59 | + n_components: int, |
| 60 | +) -> go.Figure: |
| 61 | + """Process a single dataset and create clustering visualizations. |
| 62 | +
|
| 63 | + Args: |
| 64 | + dataset_name (str): Name of the MatBench dataset to load |
| 65 | + target_col (str): Name of the target property column |
| 66 | + target_label (str): Display label for the property |
| 67 | + embed_method (EmbeddingMethod): Method to convert compositions to vectors |
| 68 | + projection_method (ProjectionMethod): Method to reduce dimensionality |
| 69 | + n_components (int): Number of dimensions for projection (2 or 3) |
| 70 | +
|
| 71 | + Returns: |
| 72 | + fig: Plotly figure |
| 73 | + """ |
| 74 | + # Load dataset |
| 75 | + df_data = load_dataset(dataset_name) |
| 76 | + |
| 77 | + # Extract compositions and target values |
| 78 | + if Key.composition in df_data: |
| 79 | + compositions = df_data[Key.composition].tolist() |
| 80 | + else: |
| 81 | + # Extract formula from structure |
| 82 | + compositions = [struct.formula for struct in df_data[Key.structure]] |
| 83 | + |
| 84 | + properties = df_data[target_col].tolist() |
| 85 | + |
| 86 | + # Create a DataFrame to align compositions and properties |
| 87 | + df_with_prop = pd.DataFrame( |
| 88 | + {"composition": compositions, "property": properties} |
| 89 | + ).dropna() |
| 90 | + compositions = df_with_prop["composition"].tolist() |
| 91 | + properties = df_with_prop["property"].tolist() |
| 92 | + |
| 93 | + # Try to load cached embeddings |
| 94 | + cache_file = f"{cache_dir}/{dataset_name}_{embed_method}.json.gz" |
| 95 | + embeddings_dict = None |
| 96 | + if os.path.isfile(cache_file): |
| 97 | + with gzip.open(cache_file, mode="rt") as file: |
| 98 | + embeddings_dict = json.load(file) |
| 99 | + |
| 100 | + if embeddings_dict is None: |
| 101 | + # Create embeddings |
| 102 | + if embed_method == "one-hot": |
| 103 | + embeddings = one_hot_encode(compositions) |
| 104 | + elif embed_method in ["magpie", "matscholar_el"]: |
| 105 | + embeddings = matminer_featurize(compositions, preset=embed_method) |
| 106 | + else: |
| 107 | + raise ValueError(f"Unknown {embed_method=}") |
| 108 | + |
| 109 | + # Convert to dictionary mapping compositions to their embeddings |
| 110 | + embeddings_dict = dict(zip(compositions, embeddings, strict=True)) |
| 111 | + |
| 112 | + # Cache the embeddings |
| 113 | + with gzip.open(cache_file, mode="wt") as file: |
| 114 | + default_handler = lambda x: x.tolist() if hasattr(x, "tolist") else x |
| 115 | + json.dump(embeddings_dict, file, default=default_handler) |
| 116 | + |
| 117 | + # Create plot with pre-computed embeddings |
| 118 | + fig = pmv.cluster_compositions( |
| 119 | + compositions=embeddings_dict, |
| 120 | + properties=dict(zip(compositions, properties, strict=True)), |
| 121 | + prop_name=target_label, |
| 122 | + projection_method=projection_method, |
| 123 | + n_components=n_components, |
| 124 | + marker_size=8, |
| 125 | + opacity=0.8, |
| 126 | + width=1000, |
| 127 | + height=600, |
| 128 | + show_chem_sys="shape" if len(compositions) < 1000 else None, |
| 129 | + ) |
| 130 | + |
| 131 | + # Update title and margins |
| 132 | + title = f"{dataset_name} - {embed_method} + {projection_method} ({n_components}D)" |
| 133 | + fig.layout.update(title=dict(text=title, x=0.5), margin_t=50) |
| 134 | + # format compositions and coordinates in hover tooltip |
| 135 | + custom_data = [ |
| 136 | + [format_composition(comp) if dataset_name == "matbench_steels" else comp] |
| 137 | + for comp in compositions |
| 138 | + ] |
| 139 | + fig.update_traces( |
| 140 | + hovertemplate=( |
| 141 | + "%{customdata[0]}<br>" # Formatted composition |
| 142 | + f"{projection_method} 1: %{{x:.2f}}<br>" # First projection coordinate |
| 143 | + f"{projection_method} 2: %{{y:.2f}}<br>" # Second projection coordinate |
| 144 | + + (f"{projection_method} 3: %{{z:.2f}}<br>" if n_components == 3 else "") |
| 145 | + + f"{target_label}: %{{marker.color:.2f}}" # Property value |
| 146 | + ), |
| 147 | + customdata=custom_data, |
| 148 | + ) |
| 149 | + |
| 150 | + return fig |
| 151 | + |
| 152 | + |
| 153 | +mb_jdft2d = ("matbench_jdft2d", "exfoliation_en", "Exfoliation Energy (meV/atom)") |
| 154 | +mb_steels = ("matbench_steels", "yield strength", "Yield Strength (MPa)") |
| 155 | +mb_dielectric = ("matbench_dielectric", "n", "Refractive index") |
| 156 | +mb_perovskites = ("matbench_perovskites", "e_form", "Formation energy (eV/atom)") |
| 157 | +mb_phonons = ("matbench_phonons", "last phdos peak", "Max Phonon Peak (cm⁻¹)") |
| 158 | +mb_bulk_modulus = ("matbench_log_kvrh", "log10(K_VRH)", "Bulk Modulus (GPa)") |
| 159 | +plot_combinations: list[ |
| 160 | + tuple[str, str, str, EmbeddingMethod, ProjectionMethod, int, dict[str, Any]] |
| 161 | +] = [ |
| 162 | + # 1. Steels with PCA (2D) - shows clear linear trends |
| 163 | + (*mb_steels, "magpie", "pca", 2, dict(x=0.01, xanchor="left")), |
| 164 | + # 2. Steels with t-SNE (2D) - shows non-linear clustering |
| 165 | + (*mb_steels, "magpie", "tsne", 2, dict(x=0.01, xanchor="left")), |
| 166 | + # 3. JDFT2D with UMAP (2D) - shows modern non-linear projection |
| 167 | + (*mb_jdft2d, "magpie", "umap", 2, dict(x=0.01, xanchor="left")), |
| 168 | + # 4. JDFT2D with one-hot encoding and PCA (3D) - shows raw element relationships |
| 169 | + (*mb_jdft2d, "one-hot", "pca", 3, dict()), |
| 170 | + # 5. Steels with Matscholar embedding and t-SNE (3D) - shows advanced embedding |
| 171 | + (*mb_steels, "matscholar_el", "tsne", 3, dict(x=0.5, y=0.8)), |
| 172 | + # 6. Dielectric with PCA (2D) - shows clear linear trends |
| 173 | + (*mb_dielectric, "magpie", "pca", 2, dict(x=0.01, xanchor="left")), |
| 174 | + # 7. Perovskites with PCA (2D) - shows clear linear trends |
| 175 | + (*mb_perovskites, "magpie", "pca", 2, dict(x=0.01, xanchor="left")), |
| 176 | + # 8. Phonons with PCA (2D) - shows clear linear trends |
| 177 | + (*mb_phonons, "magpie", "pca", 2, dict(x=0.01, xanchor="left")), |
| 178 | + # 9. Bulk Modulus with PCA (2D) - shows clear linear trends |
| 179 | + (*mb_bulk_modulus, "magpie", "pca", 2, dict(x=0.99, y=0.96, yanchor="top")), |
| 180 | + # 10. Perovskites with t-SNE (3D) - shows raw element relationships |
| 181 | + (*mb_perovskites, "magpie", "tsne", 3, dict()), |
| 182 | +] |
| 183 | + |
| 184 | +for ( |
| 185 | + data_name, |
| 186 | + target_col, |
| 187 | + target_label, |
| 188 | + embed_method, |
| 189 | + proj_method, |
| 190 | + n_components, |
| 191 | + cbar_args, |
| 192 | +) in plot_combinations: |
| 193 | + fig = process_dataset( |
| 194 | + dataset_name=data_name, |
| 195 | + target_col=target_col, |
| 196 | + target_label=target_label, |
| 197 | + embed_method=embed_method, |
| 198 | + projection_method=proj_method, |
| 199 | + n_components=n_components, |
| 200 | + ) |
| 201 | + fig.update_layout(coloraxis_colorbar=cbar_args) |
| 202 | + |
| 203 | + # Save as HTML and SVG |
| 204 | + img_name = f"{data_name}-{embed_method}-{proj_method}-{n_components}d".replace( |
| 205 | + "_", "-" |
| 206 | + ) |
| 207 | + fig.write_html(f"{plot_dir}/{img_name}.html", include_plotlyjs="cdn") |
| 208 | + pmv.io.save_and_compress_svg(fig, img_name) |
| 209 | + |
| 210 | + fig.show() |
0 commit comments