Skip to content

Commit f5057ac

Browse files
committed
fix misnamed variable in rolling_mae_vs_hull_dist(): e_above_hull_pred -> e_above_hull_error
clean up scripts/rolling_mae_vs_hull_dist{,batches}.py
1 parent 0773112 commit f5057ac

File tree

4 files changed

+35
-72
lines changed

4 files changed

+35
-72
lines changed

matbench_discovery/plots.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ def hist_classified_stable_vs_hull_dist(
250250

251251
def rolling_mae_vs_hull_dist(
252252
e_above_hull_true: pd.Series,
253-
e_above_hull_pred: pd.Series,
254-
half_window: float = 0.02,
253+
e_above_hull_error: pd.Series,
254+
window: float = 0.04,
255255
bin_width: float = 0.002,
256256
x_lim: tuple[float, float] = (-0.2, 0.3),
257257
ax: plt.Axes = None,
@@ -273,19 +273,25 @@ def rolling_mae_vs_hull_dist(
273273
rolling_maes = np.zeros_like(bins)
274274
rolling_stds = np.zeros_like(bins)
275275
for idx, bin_center in enumerate(bins):
276-
low = bin_center - half_window
277-
high = bin_center + half_window
276+
low = bin_center - window
277+
high = bin_center + window
278278

279279
mask = (e_above_hull_true <= high) & (e_above_hull_true > low)
280-
rolling_maes[idx] = e_above_hull_pred.loc[mask].abs().mean()
281-
rolling_stds[idx] = scipy.stats.sem(e_above_hull_pred.loc[mask].abs())
280+
rolling_maes[idx] = e_above_hull_error.loc[mask].abs().mean()
281+
rolling_stds[idx] = scipy.stats.sem(e_above_hull_error.loc[mask].abs())
282282

283283
kwargs = dict(linewidth=3) | kwargs
284284
ax.plot(bins, rolling_maes, **kwargs)
285285

286286
ax.fill_between(
287287
bins, rolling_maes + rolling_stds, rolling_maes - rolling_stds, alpha=0.3
288288
)
289+
# alternative implementation using pandas.rolling(). drawback: window size can only
290+
# be set as number of observations, not fixed-size energy above hull interval.
291+
# e_above_hull_error.index = e_above_hull_true # warning: in-place change
292+
# e_above_hull_error.sort_index().abs().rolling(window=8000).mean().plot(
293+
# ax=ax, **kwargs
294+
# )
289295

290296
if not is_fresh_ax:
291297
# return earlier if all plot objects besides the line were already drawn by a
@@ -294,7 +300,7 @@ def rolling_mae_vs_hull_dist(
294300

295301
scale_bar = AnchoredSizeBar(
296302
ax.transData,
297-
2 * half_window,
303+
window,
298304
"40 meV",
299305
"lower left",
300306
pad=0.5,

scripts/rolling_mae_vs_hull_dist.py

+8-29
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,25 @@
11
# %%
2-
import pandas as pd
3-
42
from matbench_discovery import ROOT, today
5-
from matbench_discovery.load_preds import df_wbm
3+
from matbench_discovery.load_preds import load_df_wbm_with_preds
64
from matbench_discovery.plots import rolling_mae_vs_hull_dist
75

86
__author__ = "Rhys Goodall, Janosh Riebesell"
97
__date__ = "2022-06-18"
108

119

1210
# %%
13-
data_path = (
14-
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
15-
# f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
16-
)
17-
df = pd.read_csv(data_path).set_index("material_id")
18-
legend_label = "Wren"
19-
20-
21-
# %%
22-
df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
23-
24-
assert all(n_nans := df.isna().sum() == 0), f"Found {n_nans} NaNs"
25-
26-
target_col = "e_form_target"
27-
# target_col = "e_form_per_atom"
28-
# --- or ---
29-
# target_col = "e_form_per_atom_target"
30-
# df["e_form_per_atom_target"] = df.e_form / df.n_sites
31-
32-
# make sure we average the expected number of ensemble member predictions
33-
assert df.filter(regex=r"_pred_\d").shape[1] == 10
11+
df_wbm = load_df_wbm_with_preds(models=["Wren", "Wrenformer"]).round(3)
3412

35-
df["e_form_pres_ens"] = df.filter(regex=r"_pred_\d+").mean(axis=1)
36-
df["e_above_hull_pred"] = df.e_form_pres_ens - df[target_col]
13+
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
14+
target_col = "e_form_per_atom_mp2020_corrected"
3715

3816

3917
# %%
18+
model_name = "Wrenformer"
4019
ax = rolling_mae_vs_hull_dist(
41-
e_above_hull_pred=df.e_above_hull_pred,
42-
e_above_hull_true=df.e_above_hull_mp,
43-
label=legend_label,
20+
e_above_hull_true=df_wbm[e_above_hull_col],
21+
e_above_hull_error=df_wbm[target_col] - df_wbm[model_name],
22+
label=model_name,
4423
)
4524

4625
fig = ax.figure
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,35 @@
11
# %%
2-
import pandas as pd
3-
42
from matbench_discovery import ROOT, today
5-
from matbench_discovery.load_preds import df_wbm
3+
from matbench_discovery.load_preds import load_df_wbm_with_preds
64
from matbench_discovery.plots import plt, rolling_mae_vs_hull_dist
75

86
__author__ = "Rhys Goodall, Janosh Riebesell"
97
__date__ = "2022-06-18"
108

119

1210
# %%
13-
df_wren = pd.read_csv(
14-
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
15-
).set_index("material_id")
16-
17-
df_wrenformer = pd.read_csv(
18-
f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
19-
).set_index("material_id")
20-
21-
22-
# %%
23-
model_name = "wren"
24-
df = {"wren": df_wren, "wrenformer": df_wrenformer}[model_name]
25-
26-
df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
27-
assert df.e_above_hull_mp.isna().sum() == 0
28-
29-
possible_targets = (
30-
"e_form_per_atom_mp2020_corrected e_form_per_atom e_form_target".split()
31-
)
32-
target_col = next(filter(lambda x: x in df, possible_targets))
33-
34-
# make sure we average the expected number of ensemble member predictions
35-
assert df.filter(regex=r"_pred_\d").shape[1] == 10
11+
df_wbm = load_df_wbm_with_preds(models=["Wren", "Wrenformer"]).round(3)
3612

37-
df["e_above_hull_pred"] = df.filter(regex=r"_pred_\d").mean(axis=1) - df[target_col]
13+
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
14+
target_col = "e_form_per_atom_mp2020_corrected"
3815

3916

4017
# %%
18+
model_name = "Wrenformer"
4119
fig, ax = plt.subplots(1, figsize=(10, 9))
4220
markers = ("o", "v", "^", "H", "D")
4321
assert len(markers) == 5 # number of WBM rounds of element substitution
4422

4523
for idx, marker in enumerate(markers, 1):
4624
# select all rows from WBM step=idx
47-
df_step = df[df.index.str.startswith(f"wbm-step-{idx}")]
25+
df_step = df_wbm[df_wbm.index.str.startswith(f"wbm-step-{idx}")]
4826

4927
title = f"Batch {idx} ({len(df_step.filter(like='e_').dropna()):,})"
5028
assert 1e4 < len(df_step) < 1e5, print(f"{len(df_step) = :,}")
5129

5230
rolling_mae_vs_hull_dist(
53-
e_above_hull_pred=df_step.e_above_hull_pred,
54-
e_above_hull_true=df_step.e_above_hull_mp,
31+
e_above_hull_error=df_step[target_col] - df_step[model_name],
32+
e_above_hull_true=df_step[e_above_hull_col],
5533
ax=ax,
5634
label=title,
5735
marker=marker,
@@ -65,5 +43,5 @@
6543
ax.set(title=f"{today} model={model_name}")
6644

6745

68-
img_name = f"{today}-{model_name}-rolling-mae-vs-hull-dist-wbm-batches"
69-
# fig.savefig(f"{ROOT}/figures/{img_name}.pdf")
46+
img_path = f"{ROOT}/figures/{today}-{model_name}-rolling-mae-vs-hull-dist-wbm-batches"
47+
# fig.savefig(f"{img_path}.pdf")

tests/test_plots.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -71,25 +71,25 @@ def test_cumulative_precision_recall(
7171
assert fig.layout.yaxis2.title.text == "Recall"
7272

7373

74-
@pytest.mark.parametrize("half_window", (0.02, 0.002))
74+
@pytest.mark.parametrize("window", (0.02, 0.002))
7575
@pytest.mark.parametrize("bin_width", (0.1, 0.001))
7676
@pytest.mark.parametrize("x_lim", ((0, 0.6), (-0.2, 0.8)))
7777
def test_rolling_mae_vs_hull_dist(
78-
half_window: float, bin_width: float, x_lim: tuple[float, float]
78+
window: float, bin_width: float, x_lim: tuple[float, float]
7979
) -> None:
8080
ax = plt.figure().gca() # new figure ensures test functions use different axes
8181

8282
for (model_name, df), color in zip(
8383
test_dfs.items(), ("tab:blue", "tab:orange", "tab:pink")
8484
):
8585
ax = rolling_mae_vs_hull_dist(
86-
e_above_hull_pred=df.e_above_hull_pred,
86+
e_above_hull_error=df.e_above_hull_pred,
8787
e_above_hull_true=df.e_above_hull_mp,
8888
color=color,
8989
label=model_name,
9090
ax=ax,
9191
x_lim=x_lim,
92-
half_window=half_window,
92+
window=window,
9393
bin_width=bin_width,
9494
)
9595

0 commit comments

Comments
 (0)