Skip to content

Commit 1b7f056

Browse files
committed
refactor rolling_mae_vs_hull_dist() to support plotting multiple lines at once
corresponding to models or WBM batches update tests and plotting scripts
1 parent f84171b commit 1b7f056

6 files changed

+135
-132
lines changed

matbench_discovery/plots.py

+78-79
Original file line numberDiff line numberDiff line change
@@ -254,14 +254,14 @@ def hist_classified_stable_vs_hull_dist(
254254

255255
def rolling_mae_vs_hull_dist(
256256
e_above_hull_true: pd.Series,
257-
e_above_hull_error: pd.Series,
257+
e_above_hull_errors: pd.DataFrame | dict[str, pd.Series],
258258
window: float = 0.02,
259259
bin_width: float = 0.001,
260260
x_lim: tuple[float, float] = (-0.2, 0.2),
261-
y_lim: tuple[float, float] = (0, 0.15),
262-
ax: plt.Axes = None,
261+
y_lim: tuple[float, float] = (0, 0.2),
263262
backend: Backend = "plotly",
264263
y_label: str = "rolling MAE (eV/atom)",
264+
just_plot_lines: bool = False,
265265
**kwargs: Any,
266266
) -> plt.Axes | go.Figure:
267267
"""Rolling mean absolute error as the energy to the convex hull is varied. A scale
@@ -274,61 +274,75 @@ def rolling_mae_vs_hull_dist(
274274
Args:
275275
e_above_hull_true (pd.Series): Distance to convex hull according to DFT
276276
ground truth (in eV / atom).
277-
e_above_hull_error (pd.Series): Error in model-predicted distance to convex
278-
hull, i.e. actual hull distance minus predicted hull distance (in eV / atom).
279-
window (float, optional): Rolling MAE averaging window. Defaults to 0.02 (20 meV/atom)
280-
bin_width (float, optional): Density of line points (more points the smaller).
277+
e_above_hull_errors (pd.DataFrame | dict[str, pd.Series]): Error in
278+
model-predicted distance to convex hull, i.e. actual hull distance minus
279+
predicted hull distance (in eV / atom).
280+
window (float, optional): Rolling MAE averaging window. Defaults to 0.02 (20
281+
meV/atom) bin_width (float, optional): Density of line points (more points the
282+
smaller).
281283
Defaults to 0.002.
282284
x_lim (tuple[float, float], optional): x-axis range. Defaults to (-0.2, 0.3).
283285
y_lim (tuple[float, float], optional): y-axis range. Defaults to (0.0, 0.14).
284-
ax (plt.Axes, optional): matplotlib Axes object. Defaults to None.
285286
backend ('matplotlib' | 'plotly'], optional): Which plotting engine to use.
286287
Changes the return type. Defaults to 'plotly'.
287288
y_label (str, optional): y-axis label. Defaults to "rolling MAE (eV/atom)".
289+
just_plot_line (bool, optional): If True, plot only the rolling MAE, no shapes
290+
and annotations. Also won't plot the standard error in the mean. Defaults
291+
to False.
288292
289293
Returns:
290-
plt.Axes | go.Figure: matplotlib Axes or plotly Figure depending on backend.
294+
tuple[plt.Axes | go.Figure, pd.DataFrame, pd.DataFrame]: matplotlib Axes or
295+
plotly
296+
Figure depending on backend, followed by two dataframes containing the
297+
rolling error for each column in e_above_hull_errors and the rolling
298+
standard error in the mean.
291299
"""
292300
bins = np.arange(*x_lim, bin_width)
301+
models = list(e_above_hull_errors)
302+
303+
df_rolling_err = pd.DataFrame(columns=models, index=bins)
304+
df_err_std = df_rolling_err.copy()
305+
306+
for model in models:
307+
for idx, bin_center in enumerate(bins):
308+
low = bin_center - window
309+
high = bin_center + window
293310

294-
rolling_maes = np.zeros_like(bins)
295-
rolling_stds = np.zeros_like(bins)
311+
mask = (e_above_hull_true <= high) & (e_above_hull_true > low)
296312

297-
for idx, bin_center in enumerate(bins):
298-
low = bin_center - window
299-
high = bin_center + window
313+
each_mae = e_above_hull_errors[model].loc[mask].abs().mean()
314+
df_rolling_err[model].iloc[idx] = each_mae
300315

301-
mask = (e_above_hull_true <= high) & (e_above_hull_true > low)
302-
rolling_maes[idx] = e_above_hull_error.loc[mask].abs().mean()
303-
rolling_stds[idx] = scipy.stats.sem(e_above_hull_error.loc[mask].abs())
316+
# drop NaNs to avoid error, scipy doesn't ignore NaNs
317+
each_std = scipy.stats.sem(
318+
e_above_hull_errors[model].loc[mask].dropna().abs()
319+
)
320+
df_err_std[model].iloc[idx] = each_std
321+
322+
# increase line width
323+
ax = df_rolling_err.plot(backend=backend, **kwargs)
324+
325+
if just_plot_lines:
326+
# return earlier if all plot objects besides the line were already drawn by a
327+
# previous call
328+
return ax, df_rolling_err, df_err_std
304329

305330
# DFT accuracy at 25 meV/atom for e_above_hull calculations of chemically similar
306331
# systems which is lower than formation energy error due to systematic error
307332
# cancellation among similar chemistries, supporting ref:
308-
# https://journals.aps.org/prb/abstract/10.1103/PhysRevB.85.155208
333+
href = "https://doi.org/10.1103/PhysRevB.85.155208"
309334
dft_acc = 0.025
310-
# used by plotly branch of this function, unrecognized by matplotlib
311-
fig = kwargs.pop("fig", None)
312335

313336
if backend == "matplotlib":
314-
ax = ax or plt.gca()
315-
is_fresh_ax = len(ax.lines) == 0
316-
kwargs = dict(linewidth=3) | kwargs
317-
ax.plot(bins, rolling_maes, **kwargs)
318-
319-
ax.fill_between(
320-
bins, rolling_maes + rolling_stds, rolling_maes - rolling_stds, alpha=0.3
321-
)
322-
# alternative implementation using pandas.rolling(). drawback: window size can only
323-
# be set as number of observations, not fixed-size energy above hull interval.
324-
# e_above_hull_error.index = e_above_hull_true # warning: in-place change
325-
# e_above_hull_error.sort_index().abs().rolling(window=8000).mean().plot(
326-
# ax=ax, **kwargs
327-
# )
328-
if not is_fresh_ax:
329-
# return earlier if all plot objects besides the line were already drawn by a
330-
# previous call
331-
return ax
337+
# assert df_rolling_err.isna().sum().sum() == 0, "NaNs in df_rolling_err"
338+
# assert df_err_std.isna().sum().sum() == 0, "NaNs in df_err_std"
339+
# for model in df_rolling_err:
340+
# ax.fill_between(
341+
# bins,
342+
# df_rolling_err[model] + df_err_std[model],
343+
# df_rolling_err[model] - df_err_std[model],
344+
# alpha=0.3,
345+
# )
332346

333347
scale_bar = AnchoredSizeBar(
334348
ax.transData,
@@ -376,34 +390,22 @@ def rolling_mae_vs_hull_dist(
376390
ax.set(xlabel=r"$E_\mathrm{above\ hull}$ (eV/atom)", ylabel=y_label)
377391
ax.set(xlim=x_lim, ylim=y_lim)
378392
elif backend == "plotly":
379-
title = kwargs.pop("label", None)
380-
ax = px.line(
381-
x=bins,
382-
y=rolling_maes,
383-
# error_y=rolling_stds,
384-
markers=False,
385-
title=title,
386-
**kwargs,
387-
)
388-
line_color = ax.data[0].line.color
389-
ax_std = go.Scatter(
390-
x=list(bins) + list(bins)[::-1], # bins, then bins reversed
391-
y=list(rolling_maes + 2 * rolling_stds)
392-
+ list(rolling_maes - 2 * rolling_stds)[::-1], # upper, then lower reversed
393-
fill="toself",
394-
line_color="white",
395-
fillcolor=line_color,
396-
opacity=0.3,
397-
hoverinfo="skip",
398-
showlegend=False,
399-
)
400-
ax.add_trace(ax_std)
401-
402-
if isinstance(fig, go.Figure):
403-
# if passed existing plotly figure, add traces to it
404-
# return without changing layout and adding annotations
405-
fig.add_traces(ax.data)
406-
return fig
393+
for idx, model in enumerate(df_rolling_err):
394+
ax.data[idx].legendgroup = model
395+
ax.add_scatter(
396+
x=list(bins) + list(bins)[::-1], # bins, then bins reversed
397+
y=list(df_rolling_err[model] + 3 * df_err_std[model])
398+
+ list(df_rolling_err[model] - 3 * df_err_std[model])[
399+
::-1
400+
], # upper, then lower reversed
401+
mode="lines",
402+
line=dict(color="white", width=0),
403+
fill="toself",
404+
legendgroup=model,
405+
fillcolor=ax.data[0].line.color,
406+
opacity=0.3,
407+
showlegend=False,
408+
)
407409

408410
legend = dict(title=None, xanchor="right", x=1, yanchor="bottom", y=0)
409411
ax.update_layout(
@@ -415,32 +417,30 @@ def rolling_mae_vs_hull_dist(
415417
)
416418
ax.update_xaxes(range=x_lim)
417419
ax.update_yaxes(range=y_lim)
418-
scatter_kwds = dict(fill="toself", opacity=0.5)
419-
err_gt_each_region = go.Scatter(
420+
scatter_kwds = dict(fill="toself", opacity=0.3)
421+
ax.add_scatter(
420422
x=(-1, -dft_acc, dft_acc, 1),
421423
y=(1, dft_acc, dft_acc, 1),
422424
name="MAE > |E<sub>above hull</sub>|",
423425
# fillcolor="yellow",
424426
**scatter_kwds,
425427
)
426-
ml_err_lt_dft_err_region = go.Scatter(
428+
ax.add_scatter(
427429
x=(-dft_acc, dft_acc, 0, -dft_acc),
428430
y=(dft_acc, dft_acc, 0, dft_acc),
429431
name="MAE < |DFT error|",
430432
# fillcolor="red",
431433
**scatter_kwds,
432434
)
433-
ax.add_traces([err_gt_each_region, ml_err_lt_dft_err_region])
434435
ax.add_annotation(
435-
x=dft_acc,
436+
x=-dft_acc,
436437
y=dft_acc,
437-
text="<a href='https://doi.org/10.1103/PhysRevB.85.155208'>Corrected GGA DFT "
438-
"Accuracy</a>",
438+
text=f"<a {href=}>Corrected GGA Accuracy</a>",
439439
showarrow=True,
440-
xshift=10,
441-
arrowhead=1,
442-
ax=4 * dft_acc,
443-
ay=dft_acc,
440+
xshift=-10,
441+
arrowhead=2,
442+
ax=-4 * dft_acc,
443+
ay=2 * dft_acc,
444444
axref="x",
445445
ayref="y",
446446
)
@@ -464,10 +464,9 @@ def rolling_mae_vs_hull_dist(
464464
y0=y0,
465465
x1=x0 + window,
466466
y1=y0 + window / 5,
467-
fillcolor=line_color,
468467
)
469468

470-
return ax
469+
return ax, df_rolling_err, df_err_std
471470

472471

473472
def cumulative_precision_recall(

scripts/rolling_mae_vs_hull_dist.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,30 @@
88

99

1010
# %%
11-
df_wbm = load_df_wbm_with_preds(models=["Wrenformer"]).round(3)
11+
model = "Wrenformer"
12+
df_wbm = load_df_wbm_with_preds([model]).round(3)
1213

1314
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
1415
e_form_col = "e_form_per_atom_mp2020_corrected"
1516

1617

1718
# %%
18-
model_name = "Wrenformer"
19-
ax = rolling_mae_vs_hull_dist(
19+
ax, df_err, df_std = rolling_mae_vs_hull_dist(
2020
e_above_hull_true=df_wbm[e_above_hull_col],
21-
e_above_hull_error=df_wbm[e_form_col] - df_wbm[model_name],
22-
label=model_name,
21+
e_above_hull_errors={model: df_wbm[e_form_col] - df_wbm[model]},
22+
# label=model,
2323
backend=(backend := "plotly"),
24-
# template="plotly_white+global",
24+
# template="plotly_white",
2525
)
2626

27-
title = f"{today} {model_name}"
27+
title = f"{today} {model}"
2828
if backend == "matplotlib":
2929
fig = ax.figure
30-
fig.set_size_inches(10, 9)
30+
fig.set_size_inches(6, 5)
3131
ax.legend(loc="lower right", frameon=False)
3232
ax.set(title=title)
33+
for line in ax.lines:
34+
line._linewidth *= 2
3335
elif backend == "plotly":
3436
ax.update_layout(title=dict(text=title, x=0.5))
3537
ax.show()

scripts/rolling_mae_vs_hull_dist_all_models.py

+20-28
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# %%
2-
from plotly.subplots import make_subplots
32
from pymatviz.utils import save_fig
43

54
from matbench_discovery import FIGS, today
@@ -14,45 +13,38 @@
1413
models = sorted(
1514
"Wrenformer, CGCNN, Voronoi RF, MEGNet, M3GNet, BOWSR MEGNet".split(", ")
1615
)
16+
e_form_col = "e_form_per_atom_mp2020_corrected"
17+
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
1718

1819
df_wbm = load_df_wbm_with_preds(models=models).round(3)
1920

2021

2122
# %%
22-
target_col = "e_form_per_atom_mp2020_corrected"
23-
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
2423
backend: Backend = "plotly"
2524

26-
rows, cols = len(models) // 3, 3
27-
if backend == "plotly":
28-
fig = make_subplots(rows=rows, cols=cols)
29-
30-
31-
for idx, model_name in enumerate(models):
32-
row, col = idx % rows + 1, idx // rows + 1
25+
for model in models:
26+
model_error = df_wbm[e_form_col] - df_wbm[model]
27+
MAE = (df_wbm[e_above_hull_col] - model_error).abs().mean()
28+
df_wbm[f"{model} {MAE=:.2f}"] = df_wbm[e_form_col] - df_wbm[model]
3329

34-
# assert df_wbm[model_name].isna().sum() < 100
35-
preds = df_wbm[target_col] - df_wbm[model_name]
36-
MAE = (df_wbm[e_above_hull_col] - preds).abs().mean()
30+
fig, df_err, df_std = rolling_mae_vs_hull_dist(
31+
e_above_hull_true=df_wbm[e_above_hull_col],
32+
e_above_hull_errors=df_wbm.filter(like=" MAE="),
33+
backend=backend,
34+
# template="plotly_white",
35+
)
3736

38-
ax = rolling_mae_vs_hull_dist(
39-
e_above_hull_true=df_wbm[e_above_hull_col],
40-
e_above_hull_error=preds,
41-
label=f"{model_name} · {MAE=:.2f}",
42-
backend=backend,
43-
)
44-
if backend == "plotly":
45-
fig.add_traces(ax.data, row=row, col=col)
4637

47-
if hasattr(ax, "legend"):
38+
if backend == "matplotlib":
4839
# increase line width in legend
49-
legend = ax.legend(frameon=False, loc="lower right")
50-
ax.figure.set_size_inches(10, 9)
51-
for line in legend.get_lines():
40+
legend = fig.legend(frameon=False, loc="lower right")
41+
fig.figure.set_size_inches(10, 9)
42+
for handle in legend.get_lines():
43+
handle._linewidth *= 6
44+
for line in fig.lines:
5245
line._linewidth *= 3
53-
54-
55-
fig.show()
46+
else:
47+
fig.show()
5648

5749

5850
# %%

scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
# %%
11-
df_wbm = load_df_wbm_with_preds(models=["Wren", "Wrenformer"]).round(3)
11+
df_wbm = load_df_wbm_with_preds(models=["Wrenformer"]).round(3)
1212

1313
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
1414
e_form_col = "e_form_per_atom_mp2020_corrected"
@@ -18,29 +18,34 @@
1818
model_name = "Wrenformer"
1919
fig, ax = plt.subplots(1, figsize=(10, 9))
2020
markers = ("o", "v", "^", "H", "D")
21-
assert len(markers) == 5 # number of WBM rounds of element substitution
21+
assert len(markers) == 5 # number of iterations of element substitution in WBM data set
2222

2323
for idx, marker in enumerate(markers, 1):
2424
# select all rows from WBM step=idx
25-
df_step = df_wbm[df_wbm.index.str.startswith(f"wbm-{idx}")]
25+
df_step = df_wbm[df_wbm.index.str.startswith(f"wbm-{idx}-")]
2626

2727
title = f"Batch {idx} ({len(df_step.filter(like='e_').dropna()):,})"
2828
assert 1e4 < len(df_step) < 1e5, print(f"{len(df_step) = :,}")
2929

30-
rolling_mae_vs_hull_dist(
30+
ax, df_err, df_std = rolling_mae_vs_hull_dist(
3131
e_above_hull_true=df_step[e_above_hull_col],
32-
e_above_hull_error=df_step[e_form_col] - df_step[model_name],
33-
ax=ax,
32+
e_above_hull_errors={title: df_step[e_form_col] - df_step[model_name]},
3433
label=title,
3534
marker=marker,
3635
markevery=20,
3736
markerfacecolor="white",
3837
markeredgewidth=2.5,
38+
backend="matplotlib",
39+
ax=ax,
40+
just_plot_lines=idx > 1,
3941
)
4042

4143

4244
ax.legend(loc="lower right", frameon=False)
43-
ax.set(title=f"{today} model={model_name}")
45+
ax.set(title=f"{today} {model_name}")
46+
for line in ax.lines:
47+
line._linewidth *= 3
48+
line.set_markersize(10)
4449

4550

4651
img_path = f"{FIGS}/{today}-{model_name}-rolling-mae-vs-hull-dist-wbm-batches"

0 commit comments

Comments
 (0)