|
5 | 5 |
|
6 | 6 |
|
7 | 7 | # %%
|
| 8 | +import itertools |
| 9 | + |
8 | 10 | import matplotlib.pyplot as plt
|
| 11 | +import numpy as np |
9 | 12 | import pandas as pd
|
10 |
| -from pymatgen.core import Structure |
11 |
| -from pymatviz import plot_structure_2d, ptable_heatmap_plotly |
| 13 | +from matminer.featurizers.site import CrystalNNFingerprint |
| 14 | +from matminer.featurizers.structure import SiteStatsFingerprint |
| 15 | +from pymatgen.core import Composition, Element, Structure |
| 16 | +from pymatviz import count_elements, plot_structure_2d, ptable_heatmap_plotly |
| 17 | +from tqdm import tqdm |
12 | 18 |
|
13 |
| -from matbench_discovery import ROOT |
| 19 | +from matbench_discovery import MODELS, ROOT |
14 | 20 | from matbench_discovery.data import DATA_FILES
|
| 21 | +from matbench_discovery.data import df_wbm as df_summary |
15 | 22 | from matbench_discovery.metrics import classify_stable
|
16 |
| -from matbench_discovery.preds import df_each_err, df_each_pred, df_preds, each_true_col |
| 23 | +from matbench_discovery.preds import ( |
| 24 | + df_each_err, |
| 25 | + df_each_pred, |
| 26 | + df_metrics, |
| 27 | + df_preds, |
| 28 | + each_true_col, |
| 29 | +) |
17 | 30 |
|
__author__ = "Janosh Riebesell"
__date__ = "2023-02-15"

mean_ae_col = "All models MAE (eV/atom)"

# Average the absolute errors over the error columns only. The true-value column
# (and the mean column itself, if this cell is re-run) must not enter the average,
# otherwise |target| gets mixed into the "All models MAE".
# NOTE(review): assumes every remaining column of df_each_err is a per-model
# error column - confirm against matbench_discovery.preds.
each_model_abs_err = df_each_err.drop(
    columns=[each_true_col, mean_ae_col], errors="ignore"
).abs()
df_each_err[mean_ae_col] = df_preds[mean_ae_col] = each_model_abs_err.mean(axis=1)

# keep the target next to the errors for use in later cells
df_each_err[each_true_col] = df_preds[each_true_col]
|
24 | 37 |
|
25 | 38 |
|
# %% load WBM computed structure entries together with their initial structures
df_wbm = pd.read_json(DATA_FILES.wbm_cses_plus_init_structs)
df_wbm = df_wbm.set_index("material_id")
30 | 41 |
|
31 | 42 |
|
# %% plot grids of the structures with lowest/highest mean error across models,
# once for the initial and once for the final structures
n_rows, n_cols = 5, 4
# df_wbm column holding the structure for each stage
struct_cols = {"initial": "initial_structure", "final": "computed_structure_entry"}

for good_bad, init_final in itertools.product(("best", "worst"), ("initial", "final")):
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
    n_structs = len(axs.flat)  # one structure per subplot
    struct_col = struct_cols[init_final]

    mean_errs = df_each_err[mean_ae_col]
    if good_bad == "best":
        errs = mean_errs.nsmallest(n_structs)
    else:
        errs = mean_errs.nlargest(n_structs)

    title = (
        f"{good_bad.title()} {len(errs)} {init_final} structures (across "
        f"{len(list(df_each_pred))} models)\nErrors in (eV/atom)"
    )
    fig.suptitle(title, fontsize=20, fontweight="bold", y=1.05)

    # mat_id rather than id to avoid shadowing the builtin at module level
    for idx, (ax, (mat_id, error)) in enumerate(zip(axs.flat, errs.items()), 1):
        struct = df_wbm[struct_col].loc[mat_id]
        if init_final == "final":
            # final structures are nested inside the computed structure entry
            # dict under its "structure" key; initial structures are stored as
            # plain structure dicts
            struct = struct["structure"]
        struct = Structure.from_dict(struct)
        plot_structure_2d(struct, ax=ax)
        _, spg_num = struct.get_space_group_info()
        formula = struct.composition.reduced_formula
        ax.set_title(
            f"{idx}. {formula} (spg={spg_num})\n{mat_id} {error=:.2f}",
            fontweight="bold",
        )

    out_path = f"{ROOT}/tmp/figures/{good_bad}-{len(errs)}-structures-{init_final}.webp"
    fig.savefig(out_path, dpi=300)
| 76 | + |
| 77 | + |
# %% collect initial and final structures of the 100 best/worst-predicted materials
n_structs = 100
worst_ids = df_each_err[mean_ae_col].nlargest(n_structs).index.tolist()
best_ids = df_each_err[mean_ae_col].nsmallest(n_structs).index.tolist()


def _struct_from_cse(cse):
    """Build a Structure from the "structure" entry of a computed structure entry dict."""
    return Structure.from_dict(cse["structure"])


init_struct_dicts = df_wbm.initial_structure
cse_dicts = df_wbm.computed_structure_entry

best_init_structs = init_struct_dicts.loc[best_ids].map(Structure.from_dict)
worst_init_structs = init_struct_dicts.loc[worst_ids].map(Structure.from_dict)
best_final_structs = cse_dicts.loc[best_ids].map(_struct_from_cse)
worst_final_structs = cse_dicts.loc[worst_ids].map(_struct_from_cse)
| 91 | + |
| 92 | + |
# %% quantify how much each structure changed between initial and final state as
# the norm of the difference between its site-stats fingerprints
cnn_fp = CrystalNNFingerprint.from_preset("ops")
site_stats_fp = SiteStatsFingerprint(
    cnn_fp, stats=("mean", "std_dev", "minimum", "maximum")
)


def _fingerprint_diff_norms(final_structs, init_structs):
    """Return per-material norm of (final fingerprint - initial fingerprint).

    Both arguments are Series of structures; the subtraction aligns them on
    their index.
    """
    final_fps = final_structs.map(site_stats_fp.featurize).map(np.array)
    init_fps = init_structs.map(site_stats_fp.featurize).map(np.array)
    return (final_fps - init_fps).map(np.linalg.norm)


worst_fp_diff_norms = _fingerprint_diff_norms(worst_final_structs, worst_init_structs)
best_fp_diff_norms = _fingerprint_diff_norms(best_final_structs, best_init_structs)

# one column per group; rows are positional since the two groups hold different
# materials
fp_rows = [worst_fp_diff_norms.values, best_fp_diff_norms.values]
fp_labels = ["highest-error structures", "lowest-error structures"]
df_fp = pd.DataFrame(fp_rows, index=fp_labels).T
| 113 | + |
| 114 | + |
# %% overlaid histograms of the fingerprint diff norms for both groups
n_models = len(list(df_each_pred))
title = (
    f"SiteStatsFingerprint norm-diff between initial/final {n_structs}<br>"
    f"highest/lowest-error structures (mean over {n_models} models)"
)
legend_style = dict(
    title="", yanchor="top", y=0.98, xanchor="right", x=0.98, font_size=16
)

fig = df_fp.plot.hist(backend="plotly", nbins=50, barmode="overlay", opacity=0.8)
fig.layout.title.update(text=title, font_size=20, xanchor="center", x=0.5)
fig.layout.legend.update(**legend_style)
fig.layout.xaxis.title = "|SSFP<sub>initial</sub> - SSFP<sub>final</sub>|"
fig.show()

img_path = f"{ROOT}/tmp/figures/init-final-fp-diff-norms.webp"
fig.write_image(img_path, width=1000, scale=2)
58 | 130 |
|
59 | 131 |
|
60 | 132 | # %% plotly scatter plot of largest model errors with points sized by mean error and
|
|
99 | 171 |
|
100 | 172 |
|
# %% periodic table heatmaps of element counts among materials selected by the
# all_false_neg / all_false_pos columns of df_preds
elem_counts: dict[str, pd.Series] = {}
for col in ("all_false_neg", "all_false_pos"):
    # Assign directly: the dict was just emptied, and dict.get() evaluates its
    # default eagerly anyway, so the earlier get-with-default never saved a
    # count_elements call.
    elem_counts[col] = count_elements(df_preds.query(col).formula)
    fig = ptable_heatmap_plotly(elem_counts[col], font_size=10)
    fig.layout.title = col
    fig.show()
| 180 | + |
| 181 | + |
# %% scatter plot error by element against prevalence in training set
# NOTE(review): na_filter=False presumably keeps formula strings such as "NaN"
# from being parsed as missing values - confirm
df_mp = pd.read_csv(DATA_FILES.mp_energies, na_filter=False)
df_mp = df_mp.set_index("material_id")

# Number of training-set samples per element. Occurrences are counted without
# weighting by composition, on the assumption that models don't learn much more
# about iron and oxygen from Fe2O3 than from FeO.
count_col = "MP Occurrences"
mp_elem_counts = count_elements(df_mp.formula_pretty, count_mode="occurrence")
df_elem_err = mp_elem_counts.to_frame(name=count_col)

title = "Number of MP structures containing each element"

# same counts shown first as a bar chart ...
fig = df_elem_err[count_col].plot.bar(backend="plotly", title=title)
fig.update_layout(showlegend=False)
fig.show()

# ... then as a periodic table heatmap
fig = ptable_heatmap_plotly(df_elem_err[count_col], font_size=10)
fig.layout.title.update(text=title, x=0.35, y=0.9, font_size=20)
fig.show()
| 201 | + |
| 202 | + |
# %% map average model error onto elements
frac_comps = [
    Composition(formula).fractional_composition for formula in tqdm(df_summary.formula)
]
df_summary["fractional_composition"] = frac_comps

# one row per material, one column per element, NaN where the element is absent
frac_comp_dicts = [comp.as_dict() for comp in df_summary["fractional_composition"]]
df_frac_comp = pd.json_normalize(frac_comp_dicts).set_index(df_summary.index)

frac_sums = df_frac_comp.sum(axis=1).round(6)
assert all(frac_sums == 1), "composition fractions don't sum to 1"

# number of materials containing each element, sorted ascending
n_materials_per_elem = len(df_frac_comp) - df_frac_comp.isna().sum()
n_materials_per_elem.sort_values().plot.bar(backend="plotly")

# df_frac_comp = df_frac_comp.dropna(axis=1, thresh=100)  # remove Xe with only 1 entry
| 218 | + |
| 219 | + |
# %% per-element error for each model: weight every material's absolute error by
# its element fractions and average over materials
model_cols = [*df_metrics, mean_ae_col]
for model in model_cols:
    # column vector so the errors broadcast across the element columns
    # NOTE(review): positional multiply - assumes df_frac_comp and df_each_err
    # rows are in the same order; confirm
    abs_errs = df_each_err[model].abs().values[:, np.newaxis]
    weighted_errs = df_frac_comp * abs_errs
    df_elem_err[model] = weighted_errs.mean()

    heatmap_kwargs = dict(
        precision=".2f",
        fill_value=None,
        cbar_max=0.2,
        colorscale="Turbo",
    )
    fig = ptable_heatmap_plotly(df_elem_err[model], **heatmap_kwargs)
    fig.layout.title.update(text=model, x=0.35, y=0.9, font_size=20)
    fig.show()


# %% save the per-element errors of all models
elem_err_path = f"{MODELS}/per-element/per-element-model-each-errors.json"
df_elem_err.to_json(elem_err_path)
| 238 | + |
| 239 | + |
# %% scatter per-element mean error against element occurrence in the training set
df_elem_err["elem_name"] = [Element(el).long_name for el in df_elem_err.index]

# only label elements with high error or high occurrence
is_labeled = (df_elem_err[mean_ae_col] > 0.04) | (df_elem_err[count_col] > 10_000)
point_labels = df_elem_err.index.where(is_labeled)

err_count_corr = df_elem_err[mean_ae_col].corr(df_elem_err[count_col])
scatter_title = (
    "Correlation between element-error and element-occurrence in<br>training "
    f"set: {err_count_corr:.2f}"
)

fig = df_elem_err.plot.scatter(
    x=count_col,
    y=mean_ae_col,
    backend="plotly",
    hover_name="elem_name",
    text=point_labels,
    title=scatter_title,
    hover_data={mean_ae_col: ":.2f", count_col: ":,.0f"},
)

fig.update_traces(textposition="top center")
fig.show()

# save_fig(fig, f"{ROOT}/tmp/figures/element-occu-vs-err.webp", scale=2)
# save_fig(fig, f"{ROOT}/tmp/figures/element-occu-vs-err.pdf")
104 | 260 |
|
105 | 261 |
|
106 | 262 | # %%
|
|
0 commit comments