|
1 | 1 | # %%
|
2 |
| -import pandas as pd |
3 | 2 | from sklearn.metrics import f1_score
|
4 | 3 |
|
5 | 4 | from matbench_discovery import ROOT, today
|
6 |
| -from matbench_discovery.plot_scripts import df_wbm |
| 5 | +from matbench_discovery.plot_scripts import load_df_wbm_with_preds |
7 | 6 | from matbench_discovery.plots import StabilityCriterion, cumulative_clf_metric, plt
|
8 | 7 |
|
9 | 8 | __author__ = "Rhys Goodall, Janosh Riebesell"
|
10 | 9 |
|
11 | 10 |
|
12 | 11 | # %%
|
13 |
| -dfs: dict[str, pd.DataFrame] = {} |
14 |
| -for model_name in ("wren", "cgcnn", "voronoi"): |
15 |
| - csv_path = ( |
16 |
| - f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv" |
17 |
| - ) |
18 |
| - df = pd.read_csv(csv_path).set_index("material_id") |
19 |
| - dfs[model_name] = df |
20 |
| - |
21 |
| -dfs["m3gnet"] = pd.read_json( |
22 |
| - f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz" |
23 |
| -).set_index("material_id") |
24 |
| - |
25 |
| -dfs["wrenformer"] = pd.read_csv( |
26 |
| - f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv" |
27 |
| -).set_index("material_id") |
| 12 | +models = ( |
| 13 | + "Wren, CGCNN IS2RE, CGCNN RS2RE, Voronoi IS2RE, Voronoi RS2RE, " |
| 14 | + "Wrenformer, MEGNet" |
| 15 | +).split(", ") |
28 | 16 |
|
29 |
| -dfs["bowsr_megnet"] = pd.read_json( |
30 |
| - f"{ROOT}/models/bowsr/2022-11-22-bowsr-megnet-wbm-IS2RE.json.gz" |
31 |
| -).set_index("material_id") |
| 17 | +df_wbm = load_df_wbm_with_preds(models=models).round(3) |
32 | 18 |
|
33 |
| -print(f"loaded models: {list(dfs)}") |
| 19 | +target_col = "e_form_per_atom_mp2020_corrected" |
| 20 | +e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp" |
34 | 21 |
|
35 | 22 |
|
36 | 23 | # %%
|
37 | 24 | stability_crit: StabilityCriterion = "energy"
|
38 | 25 | colors = "tab:blue tab:orange teal tab:pink black red turquoise tab:purple".split()
|
39 |
| -F1s: dict[str, float] = {} |
40 |
| - |
41 |
| -for model_name, df in sorted(dfs.items()): |
42 |
| - if "std" in stability_crit: |
43 |
| - # TODO column names to compute standard deviation from are currently hardcoded |
44 |
| - # needs to be updated when adding non-aviary models with uncertainty estimation |
45 |
| - var_aleatoric = (df.filter(regex=r"_ale_\d") ** 2).mean(axis=1) |
46 |
| - var_epistemic = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0) |
47 |
| - std_total = (var_epistemic + var_aleatoric) ** 0.5 |
48 |
| - else: |
49 |
| - std_total = None |
50 |
| - |
51 |
| - try: |
52 |
| - if model_name == "m3gnet": |
53 |
| - model_preds = df.e_form_m3gnet |
54 |
| - elif "wrenformer" in model_name: |
55 |
| - model_preds = df.e_form_per_atom_pred_ens |
56 |
| - elif len(pred_cols := df.filter(like="e_form_pred").columns) >= 1: |
57 |
| - # Voronoi+RF has single prediction column, Wren and CGCNN each have 10 |
58 |
| - # other cases are unexpected |
59 |
| - assert len(pred_cols) in (1, 10), f"{model_name=} has {len(pred_cols)=}" |
60 |
| - model_preds = df[pred_cols].mean(axis=1) |
61 |
| - elif model_name == "bowsr_megnet": |
62 |
| - model_preds = df.e_form_per_atom_bowsr_megnet |
63 |
| - else: |
64 |
| - raise ValueError(f"Unhandled {model_name = }") |
65 |
| - except AttributeError as exc: |
66 |
| - raise KeyError(f"{model_name = }") from exc |
67 |
| - |
68 |
| - df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp |
69 |
| - df["e_form_per_atom"] = df_wbm.e_form_per_atom_mp2020_corrected |
70 |
| - df["e_above_hull_pred"] = model_preds - df.e_form_per_atom |
71 |
| - if n_nans := df.isna().values.sum() > 0: |
72 |
| - assert n_nans < 10, f"{model_name=} has {n_nans=}" |
73 |
| - df = df.dropna() |
74 |
| - |
75 |
| - F1 = f1_score(df.e_above_hull_mp < 0, df.e_above_hull_pred < 0) |
76 |
| - F1s[model_name] = F1 |
77 | 26 |
|
78 | 27 |
|
79 | 28 | # %%
|
80 | 29 | fig, (ax_prec, ax_recall) = plt.subplots(1, 2, figsize=(15, 7), sharey=True)
|
81 | 30 |
|
82 |
| -for (model_name, F1), color in zip(sorted(F1s.items(), key=lambda x: x[1]), colors): |
83 |
| - df = dfs[model_name] |
84 |
| - e_above_hull_error = df.e_above_hull_pred + df.e_above_hull_mp |
85 |
| - e_above_hull_true = df.e_above_hull_mp |
| 31 | +for model_name, color in zip(models, colors): |
| 32 | + |
| 33 | + e_above_hull_pred = df_wbm[model_name] - df_wbm[target_col] |
| 34 | + |
| 35 | + F1 = f1_score(df_wbm[e_above_hull_col] < 0, e_above_hull_pred < 0) |
| 36 | + |
| 37 | + e_above_hull_error = e_above_hull_pred + df_wbm[e_above_hull_col] |
86 | 38 | cumulative_clf_metric(
|
87 | 39 | e_above_hull_error,
|
88 |
| - e_above_hull_true, |
| 40 | + df_wbm[e_above_hull_col], |
89 | 41 | color=color,
|
90 |
| - label=f"{model_name}\n{F1=:.2}", |
| 42 | + label=f"{model_name}\n{F1=:.3}", |
91 | 43 | project_end_point="xy",
|
92 | 44 | stability_crit=stability_crit,
|
93 | 45 | ax=ax_prec,
|
|
96 | 48 |
|
97 | 49 | cumulative_clf_metric(
|
98 | 50 | e_above_hull_error,
|
99 |
| - e_above_hull_true, |
| 51 | + df_wbm[e_above_hull_col], |
100 | 52 | color=color,
|
101 |
| - label=f"{model_name}\n{F1=:.2}", |
| 53 | + label=f"{model_name}\n{F1=:.3}", |
102 | 54 | project_end_point="xy",
|
103 | 55 | stability_crit=stability_crit,
|
104 | 56 | ax=ax_recall,
|
|
0 commit comments