|
2 | 2 | from datetime import datetime
|
3 | 3 |
|
4 | 4 | import matplotlib.pyplot as plt
|
5 |
| -import numpy as np |
6 | 5 | import pandas as pd
|
7 |
| -from scipy.interpolate import interp1d |
8 | 6 |
|
9 | 7 | from mb_discovery import ROOT
|
| 8 | +from mb_discovery.plot_scripts.plot_funcs import precision_recall_vs_calc_count |
10 | 9 |
|
11 | 10 |
|
12 | 11 | __author__ = "Rhys Goodall, Janosh Riebesell"
|
|
31 | 30 | f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv"
|
32 | 31 | ).set_index("material_id")
|
33 | 32 |
|
34 |
| -dfs["M3GNet"] = pd.read_json( |
35 |
| - f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz" |
36 |
| -).set_index("material_id") |
| 33 | +# dfs["M3GNet"] = pd.read_json( |
| 34 | +# f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz" |
| 35 | +# ).set_index("material_id") |
37 | 36 |
|
38 |
| -dfs["Wrenformer"] = pd.read_csv( |
39 |
| - f"{ROOT}/data/2022-08-16-wrenformer-ensemble-predictions.csv.bz2" |
40 |
| -).set_index("material_id") |
| 37 | +# dfs["Wrenformer"] = pd.read_csv( |
| 38 | +# f"{ROOT}/data/2022-08-16-wrenformer-ensemble-predictions.csv.bz2" |
| 39 | +# ).set_index("material_id") |
41 | 40 |
|
42 | 41 | # dfs["Wrenformer"]["e_form_target"] = dfs["Wren"]["e_form_target"]
|
43 | 42 | # dfs["M3GNet"]["e_form_target"] = dfs["Wren"]["e_form_target"]
|
44 | 43 |
|
45 | 44 |
|
46 | 45 | # %%
|
47 |
| -fig, ax = plt.subplots(1, 1, figsize=(10, 9)) |
48 |
| - |
49 |
| -for model_name, color in zip( |
50 |
| - ("Wren", "CGCNN", "Voronoi", "M3GNet", "Wrenformer"), |
51 |
| - ("tab:blue", "tab:orange", "teal", "tab:pink", "black"), |
52 |
| - strict=True, |
| 46 | +for (model_name, df), color in zip( |
| 47 | + dfs.items(), ("tab:blue", "tab:orange", "teal", "tab:pink", "black") |
53 | 48 | ):
|
54 |
| - df = dfs[model_name] |
55 | 49 | df["e_above_mp_hull"] = df_hull.e_above_mp_hull
|
56 | 50 |
|
57 | 51 | assert df.e_above_mp_hull.isna().sum() == 0
|
|
88 | 82 | raise KeyError(f"{model_name = }") from exc
|
89 | 83 |
|
90 | 84 | df["residual"] = model_preds - targets + df.e_above_mp_hull
|
91 |
| - df = df.sort_values(by="residual") |
92 |
| - |
93 |
| - # epistemic_var = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0) |
94 |
| - |
95 |
| - # aleatoric_var = (df.filter(like="_ale_") ** 2).mean(axis=1) |
96 |
| - |
97 |
| - # std_total = (epistemic_var + aleatoric_var) ** 0.5 |
98 |
| - |
99 |
| - # criterion = "std" |
100 |
| - # test = df.residual + std_total |
101 |
| - |
102 |
| - # criterion = "neg" |
103 |
| - # test = df.residual - std_total |
104 |
| - |
105 |
| - criterion = "energy" |
106 |
| - |
107 |
| - # stability_thresh = 0.02 |
108 |
| - stability_thresh = 0 |
109 |
| - # stability_thresh = 0.10 |
110 |
| - |
111 |
| - true_pos_mask = (df.e_above_mp_hull <= stability_thresh) & ( |
112 |
| - df.residual <= stability_thresh |
113 |
| - ) |
114 |
| - false_neg_mask = (df.e_above_mp_hull <= stability_thresh) & ( |
115 |
| - df.residual > stability_thresh |
116 |
| - ) |
117 |
| - false_pos_mask = (df.e_above_mp_hull > stability_thresh) & ( |
118 |
| - df.residual <= stability_thresh |
119 |
| - ) |
120 |
| - |
121 |
| - energy_type = "pred" |
122 |
| - true_pos_cumsum = true_pos_mask.cumsum() |
123 |
| - xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom" |
124 |
| - |
125 |
| - ppv = true_pos_cumsum / (true_pos_cumsum + false_pos_mask.cumsum()) * 100 |
126 |
| - n_true_pos = sum(true_pos_mask) |
127 |
| - n_false_neg = sum(false_neg_mask) |
128 |
| - n_total_pos = n_true_pos + n_false_neg |
129 |
| - tpr = true_pos_cumsum / n_total_pos * 100 |
130 |
| - |
131 |
| - end = int(np.argmax(tpr)) |
132 | 85 |
|
133 |
| - xs = np.arange(end) |
134 |
| - |
135 |
| - precision_curve = interp1d(xs, ppv[:end], kind="cubic") |
136 |
| - rolling_recall_curve = interp1d(xs, tpr[:end], kind="cubic") |
137 |
| - |
138 |
| - line_kwargs = dict( |
139 |
| - linewidth=3, |
| 86 | + ax = precision_recall_vs_calc_count( |
| 87 | + df, |
| 88 | + residual_col="residual", |
| 89 | + e_above_hull_col="e_above_mp_hull", |
140 | 90 | color=color,
|
141 |
| - markevery=[-1], |
142 |
| - marker="x", |
143 |
| - markersize=14, |
144 |
| - markeredgewidth=2.5, |
| 91 | + label=model_name, |
145 | 92 | )
|
146 |
| - ax.plot(xs, precision_curve(xs), linestyle="-", **line_kwargs) |
147 |
| - ax.plot(xs, rolling_recall_curve(xs), linestyle=":", **line_kwargs) |
148 |
| - ax.plot((0, 0), (0, 0), label=model_name, **line_kwargs) |
149 |
| - |
150 |
| - |
151 |
| -ax.set(xlabel="Number of Calculations", ylabel="Percentage") |
152 |
| - |
153 |
| -ax.set(xlim=(0, 8e4), ylim=(0, 100)) |
154 | 93 |
|
155 | 94 | model_legend = ax.legend(frameon=False, loc="lower right")
|
156 | 95 | ax.add_artist(model_legend)
|
157 | 96 |
|
158 |
| -[precision] = ax.plot((0, 0), (0, 0), "black", linestyle="-") |
159 |
| -[recall] = ax.plot((0, 0), (0, 0), "black", linestyle=":") |
160 |
| -ax.legend( |
161 |
| - [precision, recall], ("Precision", "Recall"), frameon=False, loc="upper right" |
162 |
| -) |
| 97 | +ax.figure.set_size_inches(10, 9) |
163 | 98 |
|
164 |
| -img_path = ( |
165 |
| - f"{ROOT}/figures/{today}-precision-recall-vs-calc-count-" |
166 |
| - f"{energy_type=}-{criterion=}-{rare=}.pdf" |
167 |
| -) |
| 99 | +img_path = f"{ROOT}/figures/{today}-precision-recall-vs-calc-count-{rare=}.pdf" |
168 | 100 | # plt.savefig(img_path)
|
0 commit comments