Skip to content

Commit 7ed1b09

Browse files
committed
add precision_recall_vs_calc_count() to plot_funcs.py extracted from plot_scripts/precision_recall_vs_calc_count.py
1 parent 8d65a7a commit 7ed1b09

File tree

2 files changed

+117
-88
lines changed

2 files changed

+117
-88
lines changed

mb_discovery/plot_scripts/plot_funcs.py

+101-4
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
import matplotlib.pyplot as plt
66
import numpy as np
77
import pandas as pd
8+
import scipy.interpolate
9+
import scipy.stats
810
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
9-
from scipy.stats import sem as std_err_of_mean
1011

1112

1213
__author__ = "Janosh Riebesell"
@@ -162,7 +163,7 @@ def rolling_mae_vs_hull_dist(
162163
if ax is None:
163164
ax = plt.gca()
164165

165-
ax_is_fresh = len(ax.lines) == 0
166+
is_fresh_ax = len(ax.lines) == 0
166167

167168
bins = np.arange(*x_lim, increment)
168169

@@ -175,15 +176,15 @@ def rolling_mae_vs_hull_dist(
175176

176177
mask = (df[e_above_hull_col] <= high) & (df[e_above_hull_col] > low)
177178
rolling_maes[idx] = df[residual_col].loc[mask].abs().mean()
178-
rolling_stds[idx] = std_err_of_mean(df[residual_col].loc[mask].abs())
179+
rolling_stds[idx] = scipy.stats.sem(df[residual_col].loc[mask].abs())
179180

180181
ax.plot(bins, rolling_maes, **kwargs)
181182

182183
ax.fill_between(
183184
bins, rolling_maes + rolling_stds, rolling_maes - rolling_stds, alpha=0.3
184185
)
185186

186-
if not ax_is_fresh:
187+
if not is_fresh_ax:
187188
# return earlier if all plot objects besides the line were already drawn by a
188189
# previous call
189190
return ax
@@ -249,3 +250,99 @@ def rolling_mae_vs_hull_dist(
249250
ax.set(xlim=x_lim, ylim=(0.0, 0.14))
250251

251252
return ax
253+
254+
255+
def precision_recall_vs_calc_count(
    df: pd.DataFrame,
    residual_col: str = "residual",
    e_above_hull_col: str = "e_above_hull",
    criterion: Literal["energy", "std", "neg_std"] = "energy",
    stability_thresh: float = 0,  # set stability threshold as distance to convex hull
    # in eV / atom, usually 0 or 0.1 eV
    ax: "plt.Axes | None" = None,
    label: "str | None" = None,
    **kwargs: Any,
) -> "plt.Axes":
    """Plot precision and recall as a function of the number of calculations performed.

    Rows of ``df`` are ranked by ``residual_col`` in ascending order and precision
    (solid line) and recall (dotted line) are accumulated along that ranking.

    Args:
        df: Data frame holding at least ``residual_col`` and ``e_above_hull_col``.
            For the ``"std"`` and ``"neg_std"`` criteria it must also contain the
            ensemble columns matched by ``like="_ale_"`` and ``regex=r"_pred_\\d"``.
        residual_col: Column used both for ranking and as the model's stability
            score. Defaults to "residual".
        e_above_hull_col: Column with the reference energy above the convex hull.
            Defaults to "e_above_hull".
        criterion: ``"energy"`` classifies on ``residual_col`` alone; ``"std"`` adds
            and ``"neg_std"`` subtracts the total ensemble standard deviation before
            comparing against ``stability_thresh``. Defaults to "energy".
        stability_thresh: Rows with a value ``<= stability_thresh`` count as stable,
            for the reference column and the model score alike. Defaults to 0.
        ax: Axes to draw on. Defaults to ``plt.gca()``.
        label: Legend label for this pair of curves. Defaults to None.
        **kwargs: Forwarded to every ``ax.plot()`` call; entries override the
            default line style used here.

    Raises:
        ValueError: If ``criterion`` is not one of the three supported values.

    Returns:
        The axes that were drawn on.
    """
    if criterion not in ("energy", "std", "neg_std"):
        raise ValueError(
            f"Unknown criterion {criterion!r}, expected 'energy', 'std' or 'neg_std'"
        )

    if ax is None:
        ax = plt.gca()

    # axis labels, limits and the precision/recall legend only need to be drawn once
    is_fresh_ax = len(ax.lines) == 0

    df = df.sort_values(by=residual_col)

    # score compared against the threshold to decide predicted stability; built as a
    # new Series so the caller's data frame is never modified in place
    test = df[residual_col]
    if criterion != "energy":
        # TODO column names to compute standard deviation from are currently hardcoded
        # needs to be updated when adding non-aviary models with uncertainty estimation
        var_aleatoric = (df.filter(like="_ale_") ** 2).mean(axis=1)
        var_epistemic = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
        std_total = (var_epistemic + var_aleatoric) ** 0.5

        test = test + std_total if criterion == "std" else test - std_total

    actually_stable = df[e_above_hull_col] <= stability_thresh
    true_pos_mask = actually_stable & (test <= stability_thresh)
    false_neg_mask = actually_stable & (test > stability_thresh)
    false_pos_mask = (df[e_above_hull_col] > stability_thresh) & (
        test <= stability_thresh
    )

    true_pos_cumsum = true_pos_mask.cumsum()

    # precision (positive predictive value) and recall (true positive rate) in percent
    ppv = true_pos_cumsum / (true_pos_cumsum + false_pos_mask.cumsum()) * 100
    n_total_pos = true_pos_mask.sum() + false_neg_mask.sum()
    tpr = true_pos_cumsum / n_total_pos * 100

    precision_vals = ppv.to_numpy()
    recall_vals = tpr.to_numpy()

    # NOTE(review): the curves stop one row before recall first reaches its maximum
    # (slice is [:end], not [:end + 1]); kept as is to leave existing figures unchanged
    end = int(np.argmax(recall_vals)) if len(recall_vals) > 0 else 0

    xs = np.arange(end)

    # caller-supplied kwargs take precedence over the default line style
    line_kwargs = {
        "linewidth": 3,
        "markevery": [-1],  # only mark the last point of each curve
        "marker": "x",
        "markersize": 14,
        "markeredgewidth": 2.5,
        **kwargs,
    }
    # values are plotted directly: evaluating an interpolant at its own sample
    # points would reproduce them
    ax.plot(xs, precision_vals[:end], linestyle="-", **line_kwargs)
    ax.plot(xs, recall_vals[:end], linestyle=":", **line_kwargs)
    # zero-length line that only contributes the legend entry for this call
    ax.plot((0, 0), (0, 0), label=label, **line_kwargs)

    if not is_fresh_ax:
        # return earlier if all plot objects besides the line were already drawn by a
        # previous call
        return ax

    ax.set(xlabel="Number of Calculations", ylabel="Percentage")

    ax.set(xlim=(0, 8e4), ylim=(0, 100))

    # black proxy lines explaining the two line styles
    [precision_handle] = ax.plot((0, 0), (0, 0), "black", linestyle="-")
    [recall_handle] = ax.plot((0, 0), (0, 0), "black", linestyle=":")
    legend = ax.legend(
        [precision_handle, recall_handle],
        ("Precision", "Recall"),
        frameon=False,
        loc="upper right",
    )
    # register as a separate artist so a later ax.legend() call does not replace it
    ax.add_artist(legend)

    return ax

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+16-84
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
from datetime import datetime
33

44
import matplotlib.pyplot as plt
5-
import numpy as np
65
import pandas as pd
7-
from scipy.interpolate import interp1d
86

97
from mb_discovery import ROOT
8+
from mb_discovery.plot_scripts.plot_funcs import precision_recall_vs_calc_count
109

1110

1211
__author__ = "Rhys Goodall, Janosh Riebesell"
@@ -31,27 +30,22 @@
3130
f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv"
3231
).set_index("material_id")
3332

34-
dfs["M3GNet"] = pd.read_json(
35-
f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
36-
).set_index("material_id")
33+
# dfs["M3GNet"] = pd.read_json(
34+
# f"{ROOT}/data/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
35+
# ).set_index("material_id")
3736

38-
dfs["Wrenformer"] = pd.read_csv(
39-
f"{ROOT}/data/2022-08-16-wrenformer-ensemble-predictions.csv.bz2"
40-
).set_index("material_id")
37+
# dfs["Wrenformer"] = pd.read_csv(
38+
# f"{ROOT}/data/2022-08-16-wrenformer-ensemble-predictions.csv.bz2"
39+
# ).set_index("material_id")
4140

4241
# dfs["Wrenformer"]["e_form_target"] = dfs["Wren"]["e_form_target"]
4342
# dfs["M3GNet"]["e_form_target"] = dfs["Wren"]["e_form_target"]
4443

4544

4645
# %%
47-
fig, ax = plt.subplots(1, 1, figsize=(10, 9))
48-
49-
for model_name, color in zip(
50-
("Wren", "CGCNN", "Voronoi", "M3GNet", "Wrenformer"),
51-
("tab:blue", "tab:orange", "teal", "tab:pink", "black"),
52-
strict=True,
46+
for (model_name, df), color in zip(
47+
dfs.items(), ("tab:blue", "tab:orange", "teal", "tab:pink", "black")
5348
):
54-
df = dfs[model_name]
5549
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
5650

5751
assert df.e_above_mp_hull.isna().sum() == 0
@@ -88,81 +82,19 @@
8882
raise KeyError(f"{model_name = }") from exc
8983

9084
df["residual"] = model_preds - targets + df.e_above_mp_hull
91-
df = df.sort_values(by="residual")
92-
93-
# epistemic_var = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
94-
95-
# aleatoric_var = (df.filter(like="_ale_") ** 2).mean(axis=1)
96-
97-
# std_total = (epistemic_var + aleatoric_var) ** 0.5
98-
99-
# criterion = "std"
100-
# test = df.residual + std_total
101-
102-
# criterion = "neg"
103-
# test = df.residual - std_total
104-
105-
criterion = "energy"
106-
107-
# stability_thresh = 0.02
108-
stability_thresh = 0
109-
# stability_thresh = 0.10
110-
111-
true_pos_mask = (df.e_above_mp_hull <= stability_thresh) & (
112-
df.residual <= stability_thresh
113-
)
114-
false_neg_mask = (df.e_above_mp_hull <= stability_thresh) & (
115-
df.residual > stability_thresh
116-
)
117-
false_pos_mask = (df.e_above_mp_hull > stability_thresh) & (
118-
df.residual <= stability_thresh
119-
)
120-
121-
energy_type = "pred"
122-
true_pos_cumsum = true_pos_mask.cumsum()
123-
xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
124-
125-
ppv = true_pos_cumsum / (true_pos_cumsum + false_pos_mask.cumsum()) * 100
126-
n_true_pos = sum(true_pos_mask)
127-
n_false_neg = sum(false_neg_mask)
128-
n_total_pos = n_true_pos + n_false_neg
129-
tpr = true_pos_cumsum / n_total_pos * 100
130-
131-
end = int(np.argmax(tpr))
13285

133-
xs = np.arange(end)
134-
135-
precision_curve = interp1d(xs, ppv[:end], kind="cubic")
136-
rolling_recall_curve = interp1d(xs, tpr[:end], kind="cubic")
137-
138-
line_kwargs = dict(
139-
linewidth=3,
86+
ax = precision_recall_vs_calc_count(
87+
df,
88+
residual_col="residual",
89+
e_above_hull_col="e_above_mp_hull",
14090
color=color,
141-
markevery=[-1],
142-
marker="x",
143-
markersize=14,
144-
markeredgewidth=2.5,
91+
label=model_name,
14592
)
146-
ax.plot(xs, precision_curve(xs), linestyle="-", **line_kwargs)
147-
ax.plot(xs, rolling_recall_curve(xs), linestyle=":", **line_kwargs)
148-
ax.plot((0, 0), (0, 0), label=model_name, **line_kwargs)
149-
150-
151-
ax.set(xlabel="Number of Calculations", ylabel="Percentage")
152-
153-
ax.set(xlim=(0, 8e4), ylim=(0, 100))
15493

15594
model_legend = ax.legend(frameon=False, loc="lower right")
15695
ax.add_artist(model_legend)
15796

158-
[precision] = ax.plot((0, 0), (0, 0), "black", linestyle="-")
159-
[recall] = ax.plot((0, 0), (0, 0), "black", linestyle=":")
160-
ax.legend(
161-
[precision, recall], ("Precision", "Recall"), frameon=False, loc="upper right"
162-
)
97+
ax.figure.set_size_inches(10, 9)
16398

164-
img_path = (
165-
f"{ROOT}/figures/{today}-precision-recall-vs-calc-count-"
166-
f"{energy_type=}-{criterion=}-{rare=}.pdf"
167-
)
99+
img_path = f"{ROOT}/figures/{today}-precision-recall-vs-calc-count-{rare=}.pdf"
168100
# plt.savefig(img_path)

0 commit comments

Comments
 (0)