|
2 | 2 | from datetime import datetime
|
3 | 3 |
|
4 | 4 | import pandas as pd
|
| 5 | +from sklearn.metrics import f1_score |
5 | 6 |
|
6 | 7 | from mb_discovery import ROOT
|
7 | 8 | from mb_discovery.plots import StabilityCriterion, precision_recall_vs_calc_count
|
|
15 | 16 | # %%
|
16 | 17 | DATA_DIR = f"{ROOT}/data/2022-06-11-from-rhys"
|
17 | 18 | df_hull = pd.read_csv(f"{DATA_DIR}/wbm-e-above-mp-hull.csv").set_index("material_id")
|
| 19 | +rare = "all" |
18 | 20 |
|
19 | 21 | dfs: dict[str, pd.DataFrame] = {}
|
20 | 22 | for model_name in ("wren", "cgcnn", "voronoi"):
|
|
47 | 49 | # %%
|
48 | 50 | stability_crit: StabilityCriterion = "energy"
|
49 | 51 | colors = "tab:blue tab:orange teal tab:pink black red turquoise tab:purple".split()
|
| 52 | +F1s: dict[str, float] = {} |
50 | 53 |
|
51 |
| -for (model_name, df), color in zip(dfs.items(), colors): |
52 |
| - rare = "all" |
53 |
| - |
| 54 | +for model_name, df in dfs.items(): |
54 | 55 | # from pymatgen.core import Composition
|
55 | 56 | # rare = "no-lanthanides"
|
56 | 57 | # df["contains_rare_earths"] = df.composition.map(
|
|
91 | 92 | assert n_nans < 10, f"{model_name=} has {n_nans=}"
|
92 | 93 | df = df.dropna()
|
93 | 94 |
|
| 95 | + F1 = f1_score(df.e_above_mp_hull < 0, df.e_above_hull_pred < 0) |
| 96 | + F1s[model_name] = F1 |
| 97 | + |
| 98 | + |
| 99 | +# %% |
| 100 | +for (model_name, F1), color in zip(sorted(F1s.items(), key=lambda x: x[1]), colors): |
| 101 | + df = dfs[model_name] |
| 102 | + |
94 | 103 | ax = precision_recall_vs_calc_count(
|
95 | 104 | e_above_hull_error=df.e_above_hull_pred + df.e_above_mp_hull,
|
96 | 105 | e_above_hull_true=df.e_above_mp_hull,
|
97 | 106 | color=color,
|
98 |
| - label=model_name, |
| 107 | + label=f"{model_name} {F1=:.2}", |
99 | 108 | intersect_lines="recall_xy", # or "precision_xy", None, 'all'
|
100 | 109 | stability_crit=stability_crit,
|
101 | 110 | std_pred=std_total,
|
102 | 111 | )
|
103 | 112 |
|
| 113 | +# optimal recall line finds all stable materials without any false positives |
| 114 | +# can be included to confirm all models start out of with near optimal recall |
| 115 | +# and to see how much each model overshoots total n_stable |
| 116 | +n_below_hull = sum(df_hull.e_above_mp_hull < 0) |
| 117 | +ax.plot( |
| 118 | + [0, n_below_hull], |
| 119 | + [0, 100], |
| 120 | + color="green", |
| 121 | + linestyle="dashed", |
| 122 | + linewidth=1, |
| 123 | + label="Optimal Recall", |
| 124 | +) |
| 125 | + |
104 | 126 | ax.figure.set_size_inches(10, 9)
|
105 | 127 | ax.set(xlim=(0, None))
|
106 | 128 | # keep this outside loop so all model names appear in legend
|
107 | 129 | ax.legend(frameon=False, loc="lower right")
|
108 | 130 |
|
109 | 131 | img_name = f"{today}-precision-recall-vs-calc-count-{rare=}"
|
110 | 132 | ax.set(title=img_name.replace("-", "/", 2).replace("-", " ").title())
|
| 133 | +# x-ticks every 10k materials |
| 134 | +ax.set(xticks=range(0, int(ax.get_xlim()[1]), 10_000)) |
111 | 135 |
|
112 | 136 |
|
113 | 137 | # %%
|
|
0 commit comments