|
3 | 3 |
|
4 | 4 | import numpy as np
|
5 | 5 | import pandas as pd
|
| 6 | +import plotly.express as px |
6 | 7 | from pymatgen.core import Composition
|
7 | 8 | from pymatviz import count_elements, ptable_heatmap_plotly
|
8 | 9 | from pymatviz.utils import save_fig
|
9 | 10 |
|
10 | 11 | from matbench_discovery import FIGS, ROOT, today
|
| 12 | +from matbench_discovery import plots as plots |
11 | 13 | from matbench_discovery.data import DATA_FILES, df_wbm
|
12 | 14 | from matbench_discovery.energy import mp_elem_reference_entries
|
13 |
| -from matbench_discovery.plots import pio |
| 15 | +from matbench_discovery.preds import df_each_err, each_true_col |
| 16 | + |
| 17 | +__author__ = "Janosh Riebesell" |
| 18 | +__date__ = "2023-03-30" |
14 | 19 |
|
15 | 20 | """
|
16 |
| -Compare MP and WBM elemental prevalence. Starting with WBM, MP below. |
| 21 | +WBM exploratory data analysis. |
| 22 | +Start with comparing MP and WBM elemental prevalence. |
17 | 23 | """
|
18 | 24 |
|
19 | 25 | module_dir = os.path.dirname(__file__)
|
20 |
| -print(f"{pio.templates.default=}") |
21 | 26 | about_data_page = f"{ROOT}/site/src/routes/about-the-data"
|
22 | 27 |
|
23 | 28 |
|
|
170 | 175 | fig.show()
|
171 | 176 |
|
172 | 177 | save_fig(fig, f"{FIGS}/mp-elemental-ref-energies.svelte")
|
| 178 | + |
| 179 | + |
# %% plot 2d and 3d t-SNE projections of one-hot encoded element vectors summed by
# weight in each WBM composition. TLDR: no obvious structure in the data.
# Was hoping to find certain clusters with higher or lower errors after seeing
# many models struggle on the halogens in per-element error periodic table heatmaps
# https://matbench-discovery.janosh.dev/models
df_2d_tsne = pd.read_csv(f"{module_dir}/tsne/one-hot-112-composition-2d.csv.gz")
df_2d_tsne = df_2d_tsne.set_index("material_id")

# model whose errors are part of the 3d projection file name and used for coloring
model = "Wrenformer"
# Only one 3d projection is loaded. A preceding read of
# one-hot-112-composition-3d.csv.gz was discarded unused, so it is not performed.
# NOTE(review): judging by the file name, this projection was fit on composition
# plus the model's EACH error with a Euclidean metric - confirm with the t-SNE script
df_3d_tsne = pd.read_csv(
    f"{module_dir}/tsne/one-hot-112-composition+{model}-each-err-3d-metric=eucl.csv.gz"
)
df_3d_tsne = df_3d_tsne.set_index("material_id")

# copy every projection column into df_wbm; pandas aligns rows on the index
# (presumably df_wbm is indexed by material_id as well - verify in data module)
df_wbm[list(df_2d_tsne)] = df_2d_tsne
df_wbm[list(df_3d_tsne)] = df_3d_tsne
# one "<column> abs EACH error" column per column of df_each_err
df_wbm[list(df_each_err.add_suffix(" abs EACH error"))] = df_each_err.abs()
| 198 | + |
| 199 | + |
# %% choose the error column that colors the scatter plots and cap the color scale
color_col = f"{model} abs EACH error"
abs_each_err = df_wbm[color_col]
# upper color limit: mean plus one standard deviation of the plotted error column
clr_range_max = abs_each_err.mean() + abs_each_err.std()
| 203 | + |
| 204 | + |
# %% 2d t-SNE projection, points colored by color_col
# extra columns shown in the hover tooltip besides the bold material_id title
hover_cols = ("formula", each_true_col)
fig = px.scatter(
    df_wbm,
    x="2d t-SNE 1",
    y="2d t-SNE 2",
    color=color_col,
    # color scale spans 0 to clr_range_max
    range_color=(0, clr_range_max),
    hover_name="material_id",
    hover_data=hover_cols,
)
fig.show()
| 216 | + |
| 217 | + |
# %% 3d t-SNE projection, points colored by color_col
# columns exposed to the hover template as customdata[0..3], in this order
custom_cols = ["material_id", "formula", each_true_col, color_col]
fig = px.scatter_3d(
    df_wbm,
    x="3d t-SNE 1",
    y="3d t-SNE 2",
    z="3d t-SNE 3",
    color=color_col,
    range_color=(0, clr_range_max),
    custom_data=custom_cols,
)

# hand-written tooltip; the customdata indices refer to positions in custom_cols
hover_template = (
    "<b>material_id: %{customdata[0]}</b><br><br>"
    "t-SNE: (%{x:.2f}, %{y:.2f}, %{z:.2f})<br>"
    "Formula: %{customdata[1]}<br>"
    "E<sub>above hull</sub>: %{customdata[2]:.2f}<br>"
    f"{color_col}: %{{customdata[3]:.2f}}<br>"
)
# only the first trace gets the custom tooltip
fig.data[0].hovertemplate = hover_template
fig.show()
0 commit comments