|
| 1 | +# %% |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +import pandas as pd |
| 4 | +from pymatgen.core import Structure |
| 5 | +from pymatviz import plot_structure_2d, ptable_heatmap_plotly |
| 6 | + |
| 7 | +from matbench_discovery import ROOT |
| 8 | +from matbench_discovery.metrics import classify_stable |
| 9 | +from matbench_discovery.preds import df_each_err, df_each_pred, df_wbm, each_true_col |
| 10 | + |
| 11 | +__author__ = "Janosh Riebesell" |
| 12 | +__date__ = "2023-02-15" |
| 13 | + |
# Per-material mean absolute error across models.
# The average is taken BEFORE the DFT reference column is copied into df_each_err,
# and with any reference/mean columns left over from an earlier run of this cell
# dropped, so neither of them leaks into the "mean absolute error".
# NOTE(review): presumes every remaining df_each_err column is a per-model error
# - confirm against matbench_discovery.preds.
mean_ae_col = "All models mean absolute error (eV/atom)"
model_errs = df_each_err.drop(columns=[each_true_col, mean_ae_col], errors="ignore")
mean_abs_err = model_errs.abs().mean(axis=1)

df_each_err[each_true_col] = df_wbm[each_true_col]
df_each_err[mean_ae_col] = df_wbm[mean_ae_col] = mean_abs_err
| 17 | + |
| 18 | + |
# %% load the WBM computed structure entries and index them by material ID
cse_path = f"{ROOT}/data/wbm/2022-10-19-wbm-computed-structure-entries.json.bz2"
df_cse = pd.read_json(cse_path)
df_cse = df_cse.set_index("material_id")
| 22 | + |
| 23 | + |
# %% draw a grid of the structures with the smallest ("best") and the largest
# ("worst") mean absolute error across models and save one figure for each
n_rows, n_cols = 5, 4
for which in ("best", "worst"):
    # NOTE(review): matplotlib's figsize is (width, height); here the width scales
    # with n_rows and the height with n_cols, which looks transposed - confirm
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_rows, 4 * n_cols))
    n_axs = len(axs.flat)

    # the mean error is stored under the mean_ae_col label, so it has to be looked
    # up by that key (there is no column called "mean_ae" created in this file)
    mean_errs = df_each_err[mean_ae_col]
    errs = mean_errs.nsmallest(n_axs) if which == "best" else mean_errs.nlargest(n_axs)

    title = f"{which} {len(errs)} structures (across {len(list(df_each_pred))} models)"
    fig.suptitle(title, fontsize=16, fontweight="bold", y=0.95)

    # one structure per subplot, numbered from 1 in the order returned above
    for idx, (ax, (material_id, err)) in enumerate(zip(axs.flat, errs.items()), 1):
        struct = Structure.from_dict(
            df_cse.computed_structure_entry.loc[material_id]["structure"]
        )
        plot_structure_2d(struct, ax=ax)
        _, spg_num = struct.get_space_group_info()
        formula = struct.composition.reduced_formula
        ax.set_title(
            f"{idx}. {formula} (spg={spg_num})\n{material_id} {err=:.2f}",
            fontweight="bold",
        )

    fig.savefig(f"{ROOT}/tmp/figures/{which}-{len(errs)}-structures.webp", dpi=300)
| 50 | + |
| 51 | + |
# %% plotly scatter plot of the 200 materials with the largest mean model error:
# marker size follows the mean error, marker color the DFT hull distance
df_largest_errs = df_wbm.nlargest(200, mean_ae_col)
fig = df_largest_errs.plot.scatter(
    x=each_true_col,
    y=mean_ae_col,
    color=each_true_col,
    size=mean_ae_col,
    backend="plotly",
)

# horizontal color bar, centered above the plot area
colorbar_settings = dict(
    title="DFT distance to convex hull (eV/atom)",
    title_side="top",
    yanchor="bottom",
    y=1,
    xanchor="center",
    x=0.5,
    orientation="h",
    thickness=12,
)
fig.layout.coloraxis.colorbar.update(**colorbar_settings)
fig.show()
| 72 | + |
| 73 | + |
# %% find materials that every model classified the same way
# (all true positive, all false negative, all false positive or all true negative)
for model in df_each_pred:
    true_pos, false_neg, false_pos, true_neg = classify_stable(
        df_each_pred[model], df_wbm[each_true_col]
    )
    # one column per model and outcome, e.g. "<model>_true_pos"
    outcome_masks = {
        "true_pos": true_pos,
        "false_neg": false_neg,
        "false_pos": false_pos,
        "true_neg": true_neg,
    }
    for outcome, mask in outcome_masks.items():
        df_wbm[f"{model}_{outcome}"] = mask


# a material counts as "all_<outcome>" only if every "*_<outcome>" column is True
for outcome in ("true_pos", "false_neg", "false_pos", "true_neg"):
    df_wbm[f"all_{outcome}"] = df_wbm.filter(like=f"_{outcome}").all(axis=1)

# number of materials in each unanimous category
df_wbm.filter(like="all_").sum()
| 91 | + |
| 92 | + |
# %% periodic-table heatmaps built from the formulas of the materials that all
# models got wrong: unanimous false positives first, then unanimous false negatives
# NOTE(review): in an interactive cell only the last expression is usually rendered,
# so the false-positive figure may not be displayed - confirm
false_pos_formulas = df_wbm[df_wbm.all_false_pos].formula
false_neg_formulas = df_wbm[df_wbm.all_false_neg].formula

ptable_heatmap_plotly(false_pos_formulas, colorscale="Viridis")
ptable_heatmap_plotly(false_neg_formulas, colorscale="Viridis")
| 96 | + |
| 97 | + |
# %% summary statistics of the absolute values in df_each_err
# NOTE(review): df_each_err also holds the each_true_col and mean_ae_col columns
# added at the top of this file, so both are part of these statistics - confirm
# that is intended
abs_errs = df_each_err.abs()

# mean per column, ascending
abs_errs.mean().sort_values()

# the 25 rows with the largest mean across columns
abs_errs.mean(axis=1).nlargest(25)
| 101 | + |
| 102 | + |
# %% summary statistics (including the mean distance to the convex hull) of the
# materials that all models classified as true positives / as false positives
all_true_pos_mask = df_wbm["all_true_pos"]
all_false_pos_mask = df_wbm["all_false_pos"]

df_wbm[all_true_pos_mask].describe()
df_wbm[all_false_pos_mask].describe()
0 commit comments