|
17 | 17 | plt.rc("savefig", bbox="tight", dpi=200)
|
18 | 18 | plt.rcParams["figure.constrained_layout.use"] = True
|
19 | 19 | plt.rc("figure", dpi=150)
|
20 |
| -plt.rc("font", size=18) |
| 20 | +plt.rc("font", size=16) |
21 | 21 |
|
22 | 22 |
|
23 | 23 | # %%
|
|
31 | 31 | f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv"
|
32 | 32 | ).set_index("material_id")
|
33 | 33 |
|
34 |
| -dfs["m3gnet"] = pd.read_json( |
35 |
| - f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results.json.gz" |
| 34 | +dfs["M3GNet"] = pd.read_json( |
| 35 | + f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz" |
36 | 36 | ).set_index("material_id")
|
37 | 37 |
|
| 38 | +dfs["Wrenformer"] = pd.read_csv( |
| 39 | + f"{ROOT}/data/2022-08-16-wrenformer-ensemble-predictions.csv.bz2" |
| 40 | +).set_index("material_id") |
| 41 | + |
| 42 | +# dfs["Wrenformer"]["e_form_target"] = dfs["Wren"]["e_form_target"] |
| 43 | +# dfs["M3GNet"]["e_form_target"] = dfs["Wren"]["e_form_target"] |
| 44 | + |
38 | 45 |
|
39 | 46 | # %%
|
40 | 47 | fig, ax = plt.subplots(1, 1, figsize=(10, 9))
|
41 | 48 |
|
42 | 49 | for model_name, color in zip(
|
43 |
| - ("Wren", "CGCNN", "Voronoi", "M3GNet"), |
44 |
| - ("tab:blue", "tab:orange", "tab:red", "tab:green"), |
| 50 | + ("Wren", "CGCNN", "Voronoi", "M3GNet", "Wrenformer"), |
| 51 | + ("tab:blue", "tab:orange", "teal", "tab:pink", "black"), |
| 52 | + strict=True, |
45 | 53 | ):
|
46 | 54 | df = dfs[model_name]
|
47 |
| - df = df.rename(columns={"e_form_wbm": "e_form_target"}) |
48 |
| - |
49 | 55 | df["e_above_mp_hull"] = df_hull.e_above_mp_hull
|
50 | 56 |
|
51 | 57 | assert df.e_above_mp_hull.isna().sum() == 0
|
|
62 | 68 |
|
63 | 69 | e_above_mp_hull = df.e_above_mp_hull
|
64 | 70 |
|
65 |
| - if df.filter(regex=r"_pred_\d").shape[1] > 1: |
66 |
| - assert df.filter(regex=r"_pred_\d").shape[1] == 10 |
67 |
| - |
68 |
| - model_preds = df.filter(regex=r"_pred_\d").mean(axis=1) |
69 |
| - |
70 |
| - elif model_name == "M3GNet": |
71 |
| - model_preds = df.e_form_m3gnet |
72 |
| - else: |
73 |
| - model_preds = df.e_form_pred |
74 |
| - |
75 |
| - residual = model_preds - df[target_col] + e_above_mp_hull |
| 71 | + try: |
| 72 | + if model_name == "M3GNet": |
| 73 | + model_preds = df.e_form_m3gnet |
| 74 | + targets = df.e_form_wbm |
| 75 | + elif model_name == "Wrenformer": |
| 76 | + model_preds = df.e_form_pred_ens |
| 77 | + targets = df.e_form |
| 78 | + elif df.filter(regex=r"_pred_\d").shape[1] > 1: |
| 79 | + assert df.filter(regex=r"_pred_\d").shape[1] == 10 |
| 80 | + model_preds = df.filter(regex=r"_pred_\d").mean(axis=1) |
| 81 | + targets = df.e_form_target |
| 82 | + elif "e_form_pred" in df and "e_form_target" in df: |
| 83 | + model_preds = df.e_form_pred |
| 84 | + targets = df.e_form_target |
| 85 | + else: |
| 86 | + raise ValueError(f"Unhandled {model_name = }") |
| 87 | + except AttributeError as exc: |
| 88 | + raise KeyError(f"{model_name = }") from exc |
| 89 | + |
| 90 | + df["residual"] = model_preds - targets + df.e_above_mp_hull |
| 91 | + df = df.sort_values(by="residual") |
76 | 92 |
|
77 | 93 | # epistemic_var = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
|
78 | 94 |
|
79 | 95 | # aleatoric_var = (df.filter(like="_ale_") ** 2).mean(axis=1)
|
80 | 96 |
|
81 |
| - # full_std = (epistemic_var + aleatoric_var) ** 0.5 |
| 97 | + # std_total = (epistemic_var + aleatoric_var) ** 0.5 |
82 | 98 |
|
83 | 99 | # criterion = "std"
|
84 |
| - # test = residual + full_std |
| 100 | + # test = df.residual + std_total |
85 | 101 |
|
86 | 102 | # criterion = "neg"
|
87 |
| - # test = residual - full_std |
| 103 | + # test = df.residual - std_total |
88 | 104 |
|
89 | 105 | criterion = "energy"
|
90 | 106 |
|
91 |
| - # thresh = 0.02 |
92 |
| - thresh = 0 |
93 |
| - # thresh = 0.10 |
| 107 | + # stability_thresh = 0.02 |
| 108 | + stability_thresh = 0 |
| 109 | + # stability_thresh = 0.10 |
94 | 110 |
|
95 |
| - n_true_pos = len( |
96 |
| - e_above_mp_hull[(e_above_mp_hull <= thresh) & (residual <= thresh)] |
| 111 | + true_pos_mask = (df.e_above_mp_hull <= stability_thresh) & ( |
| 112 | + df.residual <= stability_thresh |
97 | 113 | )
|
98 |
| - n_false_neg = len( |
99 |
| - e_above_mp_hull[(e_above_mp_hull <= thresh) & (residual > thresh)] |
| 114 | + false_neg_mask = (df.e_above_mp_hull <= stability_thresh) & ( |
| 115 | + df.residual > stability_thresh |
| 116 | + ) |
| 117 | + false_pos_mask = (df.e_above_mp_hull > stability_thresh) & ( |
| 118 | + df.residual <= stability_thresh |
100 | 119 | )
|
101 | 120 |
|
102 |
| - n_total_pos = n_true_pos + n_false_neg |
103 |
| - |
104 |
| - sort = np.argsort(residual) |
105 |
| - e_above_mp_hull = e_above_mp_hull[sort] |
106 |
| - residual = residual[sort] |
107 |
| - |
108 |
| - e_type = "pred" |
109 |
| - true_pos_cumsum = ((e_above_mp_hull <= thresh) & (residual <= thresh)).cumsum() |
110 |
| - false_neg_cumsum = ((e_above_mp_hull <= thresh) & (residual > thresh)).cumsum() |
111 |
| - false_pos_cumsum = ((e_above_mp_hull > thresh) & (residual <= thresh)).cumsum() |
112 |
| - true_neg_cumsum = ((e_above_mp_hull > thresh) & (residual > thresh)).cumsum() |
| 121 | + energy_type = "pred" |
| 122 | + true_pos_cumsum = true_pos_mask.cumsum() |
113 | 123 | xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
|
114 | 124 |
|
115 |
| - ppv = true_pos_cumsum / (true_pos_cumsum + false_pos_cumsum) * 100 |
| 125 | + ppv = true_pos_cumsum / (true_pos_cumsum + false_pos_mask.cumsum()) * 100 |
| 126 | + n_true_pos = sum(true_pos_mask) |
| 127 | + n_false_neg = sum(false_neg_mask) |
| 128 | + n_total_pos = n_true_pos + n_false_neg |
116 | 129 | tpr = true_pos_cumsum / n_total_pos * 100
|
117 | 130 |
|
118 |
| - end = np.argmax(tpr) |
| 131 | + end = int(np.argmax(tpr)) |
119 | 132 |
|
120 |
| - x = np.arange(len(ppv))[:end] |
| 133 | + xs = np.arange(end) |
121 | 134 |
|
122 |
| - precision_curve = interp1d(x, ppv[:end], kind="cubic") |
123 |
| - rolling_recall_curve = interp1d(x, tpr[:end], kind="cubic") |
| 135 | + precision_curve = interp1d(xs, ppv[:end], kind="cubic") |
| 136 | + rolling_recall_curve = interp1d(xs, tpr[:end], kind="cubic") |
124 | 137 |
|
125 | 138 | line_kwargs = dict(
|
126 |
| - linewidth=3, color=color, markevery=[-1], marker="x", markersize=14, mew=2.5 |
| 139 | + linewidth=3, |
| 140 | + color=color, |
| 141 | + markevery=[-1], |
| 142 | + marker="x", |
| 143 | + markersize=14, |
| 144 | + markeredgewidth=2.5, |
127 | 145 | )
|
128 |
| - ax.plot(x[::100], precision_curve(x[::100]), linestyle="-", **line_kwargs) |
129 |
| - ax.plot(x[::100], rolling_recall_curve(x[::100]), linestyle=":", **line_kwargs) |
| 146 | + ax.plot(xs, precision_curve(xs), linestyle="-", **line_kwargs) |
| 147 | + ax.plot(xs, rolling_recall_curve(xs), linestyle=":", **line_kwargs) |
130 | 148 | ax.plot((0, 0), (0, 0), label=model_name, **line_kwargs)
|
131 | 149 |
|
132 | 150 |
|
|
140 | 158 | [precision] = ax.plot((0, 0), (0, 0), "black", linestyle="-")
|
141 | 159 | [recall] = ax.plot((0, 0), (0, 0), "black", linestyle=":")
|
142 | 160 | ax.legend(
|
143 |
| - [precision, recall], ["Precision", "Recall"], frameon=False, loc="upper right" |
| 161 | + [precision, recall], ("Precision", "Recall"), frameon=False, loc="upper right" |
144 | 162 | )
|
145 | 163 |
|
146 | 164 | img_path = (
|
147 | 165 | f"{ROOT}/figures/{today}-precision-recall-vs-calc-count-"
|
148 |
| - f"{e_type=}-{criterion=}-{rare=}.pdf" |
| 166 | + f"{energy_type=}-{criterion=}-{rare=}.pdf" |
149 | 167 | )
|
150 | 168 | # plt.savefig(img_path)
|
0 commit comments