Skip to content

Commit 8daf748

Browse files
committed
more refactors in precision_recall_vs_calc_count.py and moving_hull_dist_mae_compare_models.py
1 parent 90e1b89 commit 8daf748

File tree

2 files changed

+68
-50
lines changed

2 files changed

+68
-50
lines changed

mb_discovery/plot_scripts/plot_funcs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
plt.rc("savefig", bbox="tight", dpi=200)
2525
plt.rcParams["figure.constrained_layout.use"] = True
2626
plt.rc("figure", dpi=150)
27-
plt.rc("font", size=14)
27+
plt.rc("font", size=16)
2828

2929

3030
def hist_classify_stable_as_func_of_hull_dist(

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+67-49
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
plt.rc("savefig", bbox="tight", dpi=200)
1818
plt.rcParams["figure.constrained_layout.use"] = True
1919
plt.rc("figure", dpi=150)
20-
plt.rc("font", size=18)
20+
plt.rc("font", size=16)
2121

2222

2323
# %%
@@ -31,21 +31,27 @@
3131
f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv"
3232
).set_index("material_id")
3333

34-
dfs["m3gnet"] = pd.read_json(
35-
f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results.json.gz"
34+
dfs["M3GNet"] = pd.read_json(
35+
f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
3636
).set_index("material_id")
3737

38+
dfs["Wrenformer"] = pd.read_csv(
39+
f"{ROOT}/data/2022-08-16-wrenformer-ensemble-predictions.csv.bz2"
40+
).set_index("material_id")
41+
42+
# dfs["Wrenformer"]["e_form_target"] = dfs["Wren"]["e_form_target"]
43+
# dfs["M3GNet"]["e_form_target"] = dfs["Wren"]["e_form_target"]
44+
3845

3946
# %%
4047
fig, ax = plt.subplots(1, 1, figsize=(10, 9))
4148

4249
for model_name, color in zip(
43-
("Wren", "CGCNN", "Voronoi", "M3GNet"),
44-
("tab:blue", "tab:orange", "tab:red", "tab:green"),
50+
("Wren", "CGCNN", "Voronoi", "M3GNet", "Wrenformer"),
51+
("tab:blue", "tab:orange", "teal", "tab:pink", "black"),
52+
strict=True,
4553
):
4654
df = dfs[model_name]
47-
df = df.rename(columns={"e_form_wbm": "e_form_target"})
48-
4955
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
5056

5157
assert df.e_above_mp_hull.isna().sum() == 0
@@ -62,71 +68,83 @@
6268

6369
e_above_mp_hull = df.e_above_mp_hull
6470

65-
if df.filter(regex=r"_pred_\d").shape[1] > 1:
66-
assert df.filter(regex=r"_pred_\d").shape[1] == 10
67-
68-
model_preds = df.filter(regex=r"_pred_\d").mean(axis=1)
69-
70-
elif model_name == "M3GNet":
71-
model_preds = df.e_form_m3gnet
72-
else:
73-
model_preds = df.e_form_pred
74-
75-
residual = model_preds - df[target_col] + e_above_mp_hull
71+
try:
72+
if model_name == "M3GNet":
73+
model_preds = df.e_form_m3gnet
74+
targets = df.e_form_wbm
75+
elif model_name == "Wrenformer":
76+
model_preds = df.e_form_pred_ens
77+
targets = df.e_form
78+
elif df.filter(regex=r"_pred_\d").shape[1] > 1:
79+
assert df.filter(regex=r"_pred_\d").shape[1] == 10
80+
model_preds = df.filter(regex=r"_pred_\d").mean(axis=1)
81+
targets = df.e_form_target
82+
elif "e_form_pred" in df and "e_form_target" in df:
83+
model_preds = df.e_form_pred
84+
targets = df.e_form_target
85+
else:
86+
raise ValueError(f"Unhandled {model_name = }")
87+
except AttributeError as exc:
88+
raise KeyError(f"{model_name = }") from exc
89+
90+
df["residual"] = model_preds - targets + df.e_above_mp_hull
91+
df = df.sort_values(by="residual")
7692

7793
# epistemic_var = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
7894

7995
# aleatoric_var = (df.filter(like="_ale_") ** 2).mean(axis=1)
8096

81-
# full_std = (epistemic_var + aleatoric_var) ** 0.5
97+
# std_total = (epistemic_var + aleatoric_var) ** 0.5
8298

8399
# criterion = "std"
84-
# test = residual + full_std
100+
# test = df.residual + std_total
85101

86102
# criterion = "neg"
87-
# test = residual - full_std
103+
# test = df.residual - std_total
88104

89105
criterion = "energy"
90106

91-
# thresh = 0.02
92-
thresh = 0
93-
# thresh = 0.10
107+
# stability_thresh = 0.02
108+
stability_thresh = 0
109+
# stability_thresh = 0.10
94110

95-
n_true_pos = len(
96-
e_above_mp_hull[(e_above_mp_hull <= thresh) & (residual <= thresh)]
111+
true_pos_mask = (df.e_above_mp_hull <= stability_thresh) & (
112+
df.residual <= stability_thresh
97113
)
98-
n_false_neg = len(
99-
e_above_mp_hull[(e_above_mp_hull <= thresh) & (residual > thresh)]
114+
false_neg_mask = (df.e_above_mp_hull <= stability_thresh) & (
115+
df.residual > stability_thresh
116+
)
117+
false_pos_mask = (df.e_above_mp_hull > stability_thresh) & (
118+
df.residual <= stability_thresh
100119
)
101120

102-
n_total_pos = n_true_pos + n_false_neg
103-
104-
sort = np.argsort(residual)
105-
e_above_mp_hull = e_above_mp_hull[sort]
106-
residual = residual[sort]
107-
108-
e_type = "pred"
109-
true_pos_cumsum = ((e_above_mp_hull <= thresh) & (residual <= thresh)).cumsum()
110-
false_neg_cumsum = ((e_above_mp_hull <= thresh) & (residual > thresh)).cumsum()
111-
false_pos_cumsum = ((e_above_mp_hull > thresh) & (residual <= thresh)).cumsum()
112-
true_neg_cumsum = ((e_above_mp_hull > thresh) & (residual > thresh)).cumsum()
121+
energy_type = "pred"
122+
true_pos_cumsum = true_pos_mask.cumsum()
113123
xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
114124

115-
ppv = true_pos_cumsum / (true_pos_cumsum + false_pos_cumsum) * 100
125+
ppv = true_pos_cumsum / (true_pos_cumsum + false_pos_mask.cumsum()) * 100
126+
n_true_pos = sum(true_pos_mask)
127+
n_false_neg = sum(false_neg_mask)
128+
n_total_pos = n_true_pos + n_false_neg
116129
tpr = true_pos_cumsum / n_total_pos * 100
117130

118-
end = np.argmax(tpr)
131+
end = int(np.argmax(tpr))
119132

120-
x = np.arange(len(ppv))[:end]
133+
xs = np.arange(end)
121134

122-
precision_curve = interp1d(x, ppv[:end], kind="cubic")
123-
rolling_recall_curve = interp1d(x, tpr[:end], kind="cubic")
135+
precision_curve = interp1d(xs, ppv[:end], kind="cubic")
136+
rolling_recall_curve = interp1d(xs, tpr[:end], kind="cubic")
124137

125138
line_kwargs = dict(
126-
linewidth=3, color=color, markevery=[-1], marker="x", markersize=14, mew=2.5
139+
linewidth=3,
140+
color=color,
141+
markevery=[-1],
142+
marker="x",
143+
markersize=14,
144+
markeredgewidth=2.5,
127145
)
128-
ax.plot(x[::100], precision_curve(x[::100]), linestyle="-", **line_kwargs)
129-
ax.plot(x[::100], rolling_recall_curve(x[::100]), linestyle=":", **line_kwargs)
146+
ax.plot(xs, precision_curve(xs), linestyle="-", **line_kwargs)
147+
ax.plot(xs, rolling_recall_curve(xs), linestyle=":", **line_kwargs)
130148
ax.plot((0, 0), (0, 0), label=model_name, **line_kwargs)
131149

132150

@@ -140,11 +158,11 @@
140158
[precision] = ax.plot((0, 0), (0, 0), "black", linestyle="-")
141159
[recall] = ax.plot((0, 0), (0, 0), "black", linestyle=":")
142160
ax.legend(
143-
[precision, recall], ["Precision", "Recall"], frameon=False, loc="upper right"
161+
[precision, recall], ("Precision", "Recall"), frameon=False, loc="upper right"
144162
)
145163

146164
img_path = (
147165
f"{ROOT}/figures/{today}-precision-recall-vs-calc-count-"
148-
f"{e_type=}-{criterion=}-{rare=}.pdf"
166+
f"{energy_type=}-{criterion=}-{rare=}.pdf"
149167
)
150168
# plt.savefig(img_path)

0 commit comments

Comments
 (0)