Skip to content

Commit 8d65a7a

Browse files
committed
extract plot func rolling_mae_vs_hull_dist() from rolling MAE plot scripts
1 parent 967b482 commit 8d65a7a

6 files changed

+266
-5
lines changed

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
plt.rc("savefig", bbox="tight", dpi=200)
3131
plt.rcParams["figure.constrained_layout.use"] = True
32-
plt.rc("figure", dpi=150)
32+
plt.rc("figure", dpi=200)
3333
plt.rc("font", size=16)
3434

3535

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
plt.rc("savefig", bbox="tight", dpi=200)
3030
plt.rcParams["figure.constrained_layout.use"] = True
31-
plt.rc("figure", dpi=150)
31+
plt.rc("figure", dpi=200)
3232
plt.rc("font", size=16)
3333

3434

mb_discovery/plot_scripts/plot_funcs.py

+114-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from __future__ import annotations
22

3-
from typing import Literal, Sequence
3+
from typing import Any, Literal, Sequence
44

55
import matplotlib.pyplot as plt
6+
import numpy as np
67
import pandas as pd
8+
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
9+
from scipy.stats import sem as std_err_of_mean
710

811

912
__author__ = "Janosh Riebesell"
@@ -12,7 +15,7 @@
1215

1316
plt.rc("savefig", bbox="tight", dpi=200)
1417
plt.rcParams["figure.constrained_layout.use"] = True
15-
plt.rc("figure", dpi=150)
18+
plt.rc("figure", dpi=200)
1619
plt.rc("font", size=16)
1720

1821

@@ -137,3 +140,112 @@ def hist_classified_stable_as_func_of_hull_dist(
137140
ax.set(xlabel=xlabel, ylabel="Number of compounds")
138141

139142
return ax
143+
144+
145+
def rolling_mae_vs_hull_dist(
    df: pd.DataFrame,
    e_above_hull_col: str,
    residual_col: str = "residual",
    half_window: float = 0.02,
    increment: float = 0.002,
    x_lim: tuple[float, float] = (-0.2, 0.3),
    ax: plt.Axes | None = None,
    **kwargs: Any,
) -> plt.Axes:
    """Plot the rolling mean absolute error (MAE) versus energy above the convex hull.

    For every bin center in ``np.arange(*x_lim, increment)`` the mean of the absolute
    values in ``residual_col`` is computed over all rows whose ``e_above_hull_col``
    lies in the half-open window ``(center - half_window, center + half_window]``.
    The standard error of that mean is shaded around the curve.

    On the first call for a given axes (i.e. while it holds no lines yet) the static
    decorations are drawn as well: a scale bar showing the window width
    ``2 * half_window``, dashed guide lines and arrows annotated as the (corrected)
    GGA DFT accuracy, and the highlighted V-shaped region in which the average
    absolute error exceeds the energy to the known convex hull. This is where models
    are most at risk of misclassifying structures. Later calls on the same axes only
    add another curve plus its error band.

    Args:
        df: Data frame holding the hull distance and residual columns.
        e_above_hull_col: Name of the column with the energy above the convex hull
            (x-axis).
        residual_col: Name of the column with the model residuals. Absolute values
            are taken internally, so signed and unsigned errors both work.
        half_window: Half-width of the rolling window in x-axis units.
        increment: Spacing between consecutive bin centers in x-axis units.
        x_lim: Range of bin centers (upper end exclusive); also used as x-limits.
        ax: Axes to draw on. Defaults to the current axes (``plt.gca()``).
        **kwargs: Passed on to ``ax.plot()`` for the rolling MAE curve.

    Returns:
        The matplotlib axes holding the plot.
    """
    if ax is None:
        ax = plt.gca()

    # the static decorations further down are only needed once per axes
    ax_is_fresh = len(ax.lines) == 0

    bins = np.arange(*x_lim, increment)

    # sort once so every window becomes a contiguous slice that can be located by
    # binary search instead of building a boolean mask over the whole frame per bin
    df = df.sort_values(by=e_above_hull_col)
    hull_dists = df[e_above_hull_col].to_numpy()
    abs_errors = df[residual_col].abs()

    # side="right" on both edges selects the half-open interval low < x <= high
    starts = np.searchsorted(hull_dists, bins - half_window, side="right")
    stops = np.searchsorted(hull_dists, bins + half_window, side="right")

    # NaN marks windows with too few samples for the respective statistic
    rolling_maes = np.full(len(bins), np.nan)
    rolling_stds = np.full(len(bins), np.nan)
    for idx, (start, stop) in enumerate(zip(starts, stops)):
        window = abs_errors.iloc[start:stop]
        if len(window) > 0:
            rolling_maes[idx] = window.mean()
        if len(window) > 1:  # standard error is undefined for a single sample
            rolling_stds[idx] = std_err_of_mean(window)

    (line,) = ax.plot(bins, rolling_maes, **kwargs)

    # tie the error band to the curve's color so both match even if the caller
    # passes an explicit color via kwargs
    ax.fill_between(
        bins,
        rolling_maes + rolling_stds,
        rolling_maes - rolling_stds,
        facecolor=line.get_color(),
        alpha=0.3,
    )

    if not ax_is_fresh:
        # return earlier if all plot objects besides the line were already drawn by a
        # previous call
        return ax

    # label the scale bar with the actual window width (eV -> meV)
    scale_bar = AnchoredSizeBar(
        ax.transData,
        2 * half_window,
        f"{2000 * half_window:g} meV",
        "lower left",
        pad=0.5,
        frameon=False,
        size_vertical=0.002,
    )

    ax.add_artist(scale_bar)

    # dashed outline of the outer V-shaped region and the two horizontal guide lines
    ax.plot((0.05, 0.5), (0.05, 0.5), color="grey", linestyle="--", alpha=0.3)
    ax.plot((-0.5, -0.05), (0.5, 0.05), color="grey", linestyle="--", alpha=0.3)
    ax.plot((-0.05, 0.05), (0.05, 0.05), color="grey", linestyle="--", alpha=0.3)
    ax.plot((-0.1, 0.1), (0.1, 0.1), color="grey", linestyle="--", alpha=0.3)

    ax.fill_between(
        (-0.5, -0.05, 0.05, 0.5),
        (0.5, 0.5, 0.5, 0.5),
        (0.5, 0.05, 0.05, 0.5),
        color="tab:red",
        alpha=0.2,
    )

    # tip of the V below the lower horizontal guide line
    ax.plot((0, 0.05), (0, 0.05), color="grey", linestyle="--", alpha=0.3)
    ax.plot((-0.05, 0), (0.05, 0), color="grey", linestyle="--", alpha=0.3)

    ax.fill_between(
        (-0.05, 0, 0.05),
        (0.05, 0.05, 0.05),
        (0.05, 0, 0.05),
        color="tab:orange",
        alpha=0.2,
    )

    arrowprops = dict(facecolor="black", width=0.5, headwidth=5, headlength=5)
    ax.annotate(
        xy=(0.055, 0.05),
        xytext=(0.12, 0.05),
        arrowprops=arrowprops,
        text="Corrected\nGGA DFT\nAccuracy",
        verticalalignment="center",
        horizontalalignment="left",
    )
    ax.annotate(
        xy=(0.105, 0.1),
        xytext=(0.16, 0.1),
        arrowprops=arrowprops,
        text="GGA DFT\nAccuracy",
        verticalalignment="center",
        horizontalalignment="left",
    )

    ax.text(0, 0.13, r"$|\Delta E_{Hull-MP}| > $MAE", horizontalalignment="center")

    ax.set(xlabel=r"$\Delta E_{Hull-MP}$ / eV per atom", ylabel="MAE / eV per atom")

    ax.set(xlim=x_lim, ylim=(0.0, 0.14))

    return ax

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
plt.rc("savefig", bbox="tight", dpi=200)
1818
plt.rcParams["figure.constrained_layout.use"] = True
19-
plt.rc("figure", dpi=150)
19+
plt.rc("figure", dpi=200)
2020
plt.rc("font", size=16)
2121

2222

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# %%
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd

from mb_discovery import ROOT
from mb_discovery.plot_scripts.plot_funcs import rolling_mae_vs_hull_dist


__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-06-18"

# date stamp used as prefix in the figure file name
today = datetime.now().strftime("%Y-%m-%d")

# figure styling applied to everything drawn by this script
plt.rcParams["figure.constrained_layout.use"] = True
plt.rc("figure", dpi=200)
plt.rc("font", size=16)
plt.rc("savefig", bbox="tight", dpi=200)


# %%
markers = ["o", "v", "^", "H", "D", ""]

# model predictions, one row per material
preds_csv = (
    f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
    # f"{ROOT}/data/2022-08-16-wrenformer-ensemble-predictions.csv.bz2"
)
df = pd.read_csv(preds_csv).set_index("material_id")


# %%
rare = "all"
# from pymatgen.core import Composition
# rare = "no-lanthanides"
# df["contains_rare_earths"] = df.composition.map(
#     lambda x: any(el.is_rare_earth_metal for el in Composition(x))
# )
# df = df.query("~contains_rare_earths")


# energies above the MP convex hull, joined onto the predictions via material_id
hull_csv = f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
df_hull = pd.read_csv(hull_csv).set_index("material_id")

df["e_above_mp_hull"] = df_hull["e_above_mp_hull"]

n_nans = df.isna().sum().sum()
assert n_nans == 0, f"Found {n_nans} NaNs"

target_col = "e_form_target"
# --- or ---
# target_col = "e_form_per_atom_target"
# df["e_form_per_atom_target"] = df.e_form / df.n_sites

# the ensemble average below is expected to run over exactly 10 member predictions
n_ensemble_cols = len(df.filter(regex=r"_pred_\d").columns)
assert n_ensemble_cols == 10

df["e_form_pres_ens"] = df.filter(regex=r"_pred_\d+").mean(axis="columns")

# shift the known hull distance by the formation energy error of the ensemble mean
df["e_above_mp_hull_pred"] = (
    df["e_form_pres_ens"] - df[target_col] + df["e_above_mp_hull"]
)

df["residual"] = df["e_above_mp_hull_pred"] - df["e_above_mp_hull"]


# %%
ax = rolling_mae_vs_hull_dist(
    df, e_above_hull_col="e_above_mp_hull", residual_col="residual"
)

ax.figure.set_size_inches(10, 9)

img_path = f"{ROOT}/figures/{today}-rolling-mae-vs-hull-dist-{rare=}.pdf"
# plt.savefig(img_path)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# %%
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd

from mb_discovery import ROOT
from mb_discovery.plot_scripts.plot_funcs import rolling_mae_vs_hull_dist


__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-06-18"

# date stamp used as prefix in the figure file name
today = datetime.now().strftime("%Y-%m-%d")


# figure styling applied to everything drawn by this script
plt.rcParams["figure.constrained_layout.use"] = True
plt.rc("figure", dpi=200)
plt.rc("font", size=16)
plt.rc("savefig", bbox="tight", dpi=200)


# %%
rare = "all"

# model predictions, one row per material
preds_csv = f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
df_wbm = pd.read_csv(preds_csv).set_index("material_id")

# energies above the MP convex hull, joined onto the predictions via material_id
hull_csv = f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
df_hull = pd.read_csv(hull_csv).set_index("material_id")

df_wbm["e_above_mp_hull"] = df_hull["e_above_mp_hull"]
assert df_wbm["e_above_mp_hull"].isna().sum() == 0

target_col = "e_form_target"

# the ensemble average below is expected to run over exactly 10 member predictions
assert len(df_wbm.filter(regex=r"_pred_\d").columns) == 10

ensemble_mean = df_wbm.filter(regex=r"_pred_\d").mean(axis="columns")

# shift the known hull distance by the formation energy error of the ensemble mean
df_wbm["e_above_mp_hull_pred"] = (
    ensemble_mean - df_wbm[target_col] + df_wbm["e_above_mp_hull"]
)
df_wbm["error"] = (df_wbm["e_above_mp_hull_pred"] - df_wbm["e_above_mp_hull"]).abs()


# %%
fig, ax = plt.subplots(nrows=1, figsize=(10, 9))
markers = ("o", "v", "^", "H", "D")
assert len(markers) == 5  # number of WBM rounds of element substitution

# line styling shared by the curves of all batches
marker_style = dict(markevery=20, markerfacecolor="white", markeredgewidth=2.5)

# one rolling MAE curve per batch, selected by the material ID prefix
for idx, marker in enumerate(markers, start=1):
    title = f"Batch {idx}"
    in_batch = df_wbm.index.str.startswith(f"wbm-step-{idx}")
    df = df_wbm[in_batch]

    rolling_mae_vs_hull_dist(
        df,
        residual_col="error",
        e_above_hull_col="e_above_mp_hull",
        ax=ax,
        label=title,
        marker=marker,
        **marker_style,
    )


ax.legend(loc="lower right", frameon=False)


img_path = f"{ROOT}/figures/{today}-rolling-mae-vs-hull-dist-wbm-batches-{rare=}.pdf"
# plt.savefig(img_path)

0 commit comments

Comments
 (0)