Skip to content

Commit fbf6a02

Browse files
committed
support plotly backend in rolling_mae_vs_hull_dist()
test all plot funcs with all backends
1 parent 811f581 commit fbf6a02

File tree

4 files changed

+209
-107
lines changed

4 files changed

+209
-107
lines changed

matbench_discovery/plots.py

+180-86
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def hist_classified_stable_vs_hull_dist(
8484
x_lim: tuple[float | None, float | None] = (-0.4, 0.4),
8585
rolling_accuracy: float | None = 0.02,
8686
backend: Backend = "plotly",
87-
ylabel: str = "Number of materials",
87+
y_label: str = "Number of materials",
8888
**kwargs: Any,
8989
) -> tuple[plt.Axes | go.Figure, dict[str, float]]:
9090
"""
@@ -112,8 +112,9 @@ def hist_classified_stable_vs_hull_dist(
112112
x_lim (tuple[float | None, float | None]): x-axis limits.
113113
rolling_accuracy (float): Rolling accuracy window size in eV / atom. Set to None
114114
or 0 to disable. Defaults to 0.02, meaning 20 meV / atom.
115-
backend ('matplotlib' | 'plotly'], optional): Which plotting backend to use.
116-
Changes the return type.
115+
backend ('matplotlib' | 'plotly', optional): Which plotting engine to use.
116+
Changes the return type. Defaults to 'plotly'.
117+
y_label (str, optional): y-axis label. Defaults to "Number of materials".
117118
kwargs: Additional keyword arguments passed to the ax.hist() or px.histogram()
118119
depending on backend.
119120
@@ -162,7 +163,7 @@ def hist_classified_stable_vs_hull_dist(
162163
stacked=True,
163164
**kwargs,
164165
)
165-
ax.set(xlabel=xlabel, ylabel=ylabel, xlim=x_lim)
166+
ax.set(xlabel=xlabel, ylabel=y_label, xlim=x_lim)
166167

167168
if stability_threshold is not None:
168169
ax.axvline(
@@ -221,7 +222,7 @@ def hist_classified_stable_vs_hull_dist(
221222
**kwargs,
222223
)
223224
ax.update_layout(
224-
dict(xaxis_title=xlabel, yaxis_title=ylabel),
225+
dict(xaxis_title=xlabel, yaxis_title=y_label),
225226
legend=dict(title=None, yanchor="top", y=1, xanchor="right", x=1),
226227
)
227228

@@ -251,27 +252,46 @@ def hist_classified_stable_vs_hull_dist(
251252
def rolling_mae_vs_hull_dist(
252253
e_above_hull_true: pd.Series,
253254
e_above_hull_error: pd.Series,
254-
window: float = 0.04,
255-
bin_width: float = 0.002,
255+
window: float = 0.02,
256+
bin_width: float = 0.001,
256257
x_lim: tuple[float, float] = (-0.2, 0.3),
258+
y_lim: tuple[float, float] = (0.0, 0.14),
257259
ax: plt.Axes = None,
260+
backend: Backend = "plotly",
261+
y_label: str = "rolling MAE (eV/atom)",
258262
**kwargs: Any,
259263
) -> plt.Axes:
260264
"""Rolling mean absolute error as the energy to the convex hull is varied. A scale
261-
bar is shown for the windowing period of 40 meV per atom used when calculating
262-
the rolling MAE. The standard error in the mean is shaded
263-
around each curve. The highlighted V-shaped region shows the area in which the
264-
average absolute error is greater than the energy to the known convex hull. This is
265-
where models are most at risk of misclassifying structures.
266-
"""
267-
ax = ax or plt.gca()
265+
bar is shown for the windowing period of 40 meV per atom used when calculating the
266+
rolling MAE. The standard error in the mean is shaded around each curve. The
267+
highlighted V-shaped region shows the area in which the average absolute error is
268+
greater than the energy to the known convex hull. This is where models are most at
269+
risk of misclassifying structures.
268270
269-
is_fresh_ax = len(ax.lines) == 0
271+
Args:
272+
e_above_hull_true (pd.Series): Distance to convex hull according to DFT
273+
ground truth (in eV / atom).
274+
e_above_hull_error (pd.Series): Error in model-predicted distance to convex
275+
hull, i.e. actual hull distance minus predicted hull distance (in eV / atom).
276+
window (float, optional): Rolling MAE averaging window. Defaults to 0.02 (20 meV/atom)
277+
bin_width (float, optional): Density of line points (more points the smaller).
278+
Defaults to 0.001.
279+
x_lim (tuple[float, float], optional): x-axis range. Defaults to (-0.2, 0.3).
280+
y_lim (tuple[float, float], optional): y-axis range. Defaults to (0.0, 0.14).
281+
ax (plt.Axes, optional): matplotlib Axes object. Defaults to None.
282+
backend ('matplotlib' | 'plotly', optional): Which plotting engine to use.
283+
Changes the return type. Defaults to 'plotly'.
284+
y_label (str, optional): y-axis label. Defaults to "rolling MAE (eV/atom)".
285+
286+
Returns:
287+
plt.Axes | go.Figure: The matplotlib Axes or plotly Figure holding the plot, depending on backend.
288+
"""
270289

271290
bins = np.arange(*x_lim, bin_width)
272291

273292
rolling_maes = np.zeros_like(bins)
274293
rolling_stds = np.zeros_like(bins)
294+
275295
for idx, bin_center in enumerate(bins):
276296
low = bin_center - window
277297
high = bin_center + window
@@ -280,79 +300,152 @@ def rolling_mae_vs_hull_dist(
280300
rolling_maes[idx] = e_above_hull_error.loc[mask].abs().mean()
281301
rolling_stds[idx] = scipy.stats.sem(e_above_hull_error.loc[mask].abs())
282302

283-
kwargs = dict(linewidth=3) | kwargs
284-
ax.plot(bins, rolling_maes, **kwargs)
285-
286-
ax.fill_between(
287-
bins, rolling_maes + rolling_stds, rolling_maes - rolling_stds, alpha=0.3
288-
)
289-
# alternative implementation using pandas.rolling(). drawback: window size can only
290-
# be set as number of observations, not fixed-size energy above hull interval.
291-
# e_above_hull_error.index = e_above_hull_true # warning: in-place change
292-
# e_above_hull_error.sort_index().abs().rolling(window=8000).mean().plot(
293-
# ax=ax, **kwargs
294-
# )
295-
296-
if not is_fresh_ax:
297-
# return earlier if all plot objects besides the line were already drawn by a
298-
# previous call
299-
return ax
300-
301-
scale_bar = AnchoredSizeBar(
302-
ax.transData,
303-
window,
304-
"40 meV",
305-
"lower left",
306-
pad=0.5,
307-
frameon=False,
308-
size_vertical=0.002,
309-
)
310-
# indicate size of MAE averaging window
311-
ax.add_artist(scale_bar)
312-
313-
# DFT accuracy at 25 meV/atom for relative e_above_hull which is lower than
314-
# formation energy error due to systematic error cancellation among
315-
# similar chemistries, supporting ref:
303+
# DFT accuracy at 25 meV/atom for e_above_hull calculations of chemically similar
304+
# systems which is lower than formation energy error due to systematic error
305+
# cancellation among similar chemistries, supporting ref:
316306
# https://journals.aps.org/prb/abstract/10.1103/PhysRevB.85.155208
317307
dft_acc = 0.025
318-
ax.plot((dft_acc, 1), (dft_acc, 1), color="grey", linestyle="--", alpha=0.3)
319-
ax.plot((-1, -dft_acc), (1, dft_acc), color="grey", linestyle="--", alpha=0.3)
320-
ax.plot(
321-
(-dft_acc, dft_acc), (dft_acc, dft_acc), color="grey", linestyle="--", alpha=0.3
322-
)
323-
ax.fill_between(
324-
(-1, -dft_acc, dft_acc, 1),
325-
(1, 1, 1, 1),
326-
(1, dft_acc, dft_acc, 1),
327-
color="tab:red",
328-
alpha=0.2,
329-
)
330308

331-
ax.plot((0, dft_acc), (0, dft_acc), color="grey", linestyle="--", alpha=0.3)
332-
ax.plot((-dft_acc, 0), (dft_acc, 0), color="grey", linestyle="--", alpha=0.3)
333-
ax.fill_between(
334-
(-dft_acc, 0, dft_acc),
335-
(dft_acc, dft_acc, dft_acc),
336-
(dft_acc, 0, dft_acc),
337-
color="tab:orange",
338-
alpha=0.2,
339-
)
340-
# shrink=0.1 means cut off 10% length from both sides of arrow line
341-
arrowprops = dict(
342-
facecolor="black", width=0.5, headwidth=5, headlength=5, shrink=0.1
343-
)
344-
ax.annotate(
345-
xy=(-dft_acc, dft_acc),
346-
xytext=(-2 * dft_acc, dft_acc),
347-
text="Corrected\nGGA DFT\nAccuracy",
348-
arrowprops=arrowprops,
349-
verticalalignment="center",
350-
horizontalalignment="right",
351-
)
309+
if backend == "matplotlib":
310+
ax = ax or plt.gca()
311+
is_fresh_ax = len(ax.lines) == 0
312+
kwargs = dict(linewidth=3) | kwargs
313+
ax.plot(bins, rolling_maes, **kwargs)
314+
315+
ax.fill_between(
316+
bins, rolling_maes + rolling_stds, rolling_maes - rolling_stds, alpha=0.3
317+
)
318+
# alternative implementation using pandas.rolling(). drawback: window size can only
319+
# be set as number of observations, not fixed-size energy above hull interval.
320+
# e_above_hull_error.index = e_above_hull_true # warning: in-place change
321+
# e_above_hull_error.sort_index().abs().rolling(window=8000).mean().plot(
322+
# ax=ax, **kwargs
323+
# )
324+
if not is_fresh_ax:
325+
# return earlier if all plot objects besides the line were already drawn by a
326+
# previous call
327+
return ax
328+
329+
scale_bar = AnchoredSizeBar(
330+
ax.transData,
331+
window,
332+
"40 meV",
333+
"lower left",
334+
pad=0.5,
335+
frameon=False,
336+
size_vertical=0.002,
337+
)
338+
# indicate size of MAE averaging window
339+
ax.add_artist(scale_bar)
340+
341+
ax.fill_between(
342+
(-1, -dft_acc, dft_acc, 1),
343+
(1, 1, 1, 1),
344+
(1, dft_acc, dft_acc, 1),
345+
color="tab:red",
346+
alpha=0.2,
347+
)
352348

353-
ax.text(0, 0.13, r"$|E_\mathrm{above\ hull}| > $MAE", horizontalalignment="center")
354-
ax.set(xlabel=r"$E_\mathrm{above\ hull}$ (eV / atom)", ylabel="MAE (eV / atom)")
355-
ax.set(xlim=x_lim, ylim=(0.0, 0.14))
349+
ax.fill_between(
350+
(-dft_acc, 0, dft_acc),
351+
(dft_acc, dft_acc, dft_acc),
352+
(dft_acc, 0, dft_acc),
353+
color="tab:orange",
354+
alpha=0.2,
355+
)
356+
# shrink=0.1 means cut off 10% length from both sides of arrow line
357+
arrowprops = dict(
358+
facecolor="black", width=0.5, headwidth=5, headlength=5, shrink=0.1
359+
)
360+
ax.annotate(
361+
xy=(-dft_acc, dft_acc),
362+
xytext=(-2 * dft_acc, dft_acc),
363+
text="Corrected\nGGA DFT\nAccuracy",
364+
arrowprops=arrowprops,
365+
verticalalignment="center",
366+
horizontalalignment="right",
367+
)
368+
369+
ax.text(
370+
0, 0.13, r"MAE > $|E_\mathrm{above\ hull}|$", horizontalalignment="center"
371+
)
372+
ax.set(xlabel=r"$E_\mathrm{above\ hull}$ (eV/atom)", ylabel=y_label)
373+
ax.set(xlim=x_lim, ylim=y_lim)
374+
elif backend == "plotly":
375+
title = kwargs.pop("label", None)
376+
ax = px.line(
377+
x=bins,
378+
y=rolling_maes,
379+
# error_y=rolling_stds,
380+
markers=False,
381+
title=title,
382+
**kwargs,
383+
)
384+
ax_std = go.Scatter(
385+
x=list(bins) + list(bins)[::-1], # bins, then bins reversed
386+
y=list(rolling_maes + 2 * rolling_stds)
387+
+ list(rolling_maes - 2 * rolling_stds)[::-1], # upper, then lower reversed
388+
fill="toself",
389+
line_color="white",
390+
fillcolor=ax.data[0].line.color,
391+
opacity=0.3,
392+
hoverinfo="skip",
393+
showlegend=False,
394+
)
395+
ax.add_trace(ax_std)
396+
397+
ax.update_layout(
398+
dict(
399+
xaxis_title="E<sub>above hull</sub> (eV/atom)",
400+
yaxis_title="rolling MAE (eV/atom)",
401+
),
402+
legend=dict(title=None, xanchor="right", x=1, yanchor="bottom", y=0),
403+
)
404+
ax.update_xaxes(range=x_lim)
405+
ax.update_yaxes(range=y_lim)
406+
scatter_kwds = dict(fill="toself", opacity=0.5)
407+
err_gt_each_region = go.Scatter(
408+
x=(-1, -dft_acc, dft_acc, 1),
409+
y=(1, dft_acc, dft_acc, 1),
410+
name="MAE > |E<sub>above hull</sub>|",
411+
# fillcolor="yellow",
412+
**scatter_kwds,
413+
)
414+
ml_err_lt_dft_err_region = go.Scatter(
415+
x=(-dft_acc, dft_acc, 0, -dft_acc),
416+
y=(dft_acc, dft_acc, 0, dft_acc),
417+
name="MAE < |DFT error|",
418+
# fillcolor="red",
419+
**scatter_kwds,
420+
)
421+
ax.add_traces([err_gt_each_region, ml_err_lt_dft_err_region])
422+
ax.add_annotation(
423+
x=4 * dft_acc,
424+
y=dft_acc,
425+
text="Corrected GGA DFT Accuracy",
426+
showarrow=True,
427+
# arrowhead=1,
428+
ax=-dft_acc,
429+
ay=dft_acc,
430+
)
431+
432+
ax.data = ax.data[::-1] # bring px.line() to front
433+
# show MAE window size
434+
x0, y0 = x_lim[0] + 0.01, y_lim[0] + 0.01
435+
ax.add_annotation(
436+
x=x0 + 0.05,
437+
y=y0 + 0.01,
438+
text=f"rolling MAE window<br>{window} eV/atom",
439+
showarrow=False,
440+
)
441+
ax.add_shape(
442+
type="rect",
443+
x0=x0,
444+
y0=y0,
445+
x1=x0 + window,
446+
y1=y0 + window / 5,
447+
fillcolor="black",
448+
)
356449

357450
return ax
358451

@@ -388,8 +481,9 @@ def cumulative_precision_recall(
388481
axis projection lines.
389482
show_optimal (bool, optional): Whether to plot the optimal recall line. Defaults
390483
to False.
391-
backend ('plotly' | 'matplotlib', optional): Defaults to 'plotly'. **kwargs:
392-
Keyword arguments passed to df.plot().
484+
backend ('matplotlib' | 'plotly', optional): Which plotting engine to use.
485+
Changes the return type. Defaults to 'plotly'.
486+
**kwargs: Keyword arguments passed to df.plot().
393487
394488
Returns:
395489
tuple[plt.Figure | go.Figure, pd.DataFrame]: The matplotlib/plotly figure and

scripts/rolling_mae_vs_hull_dist.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,18 @@
2020
e_above_hull_true=df_wbm[e_above_hull_col],
2121
e_above_hull_error=df_wbm[target_col] - df_wbm[model_name],
2222
label=model_name,
23+
backend=(backend := "plotly"),
2324
)
2425

25-
fig = ax.figure
26-
fig.set_size_inches(10, 9)
27-
ax.legend(loc="lower right", frameon=False)
26+
title = f"{today} {model_name}"
27+
if backend == "matplotlib":
28+
fig = ax.figure
29+
fig.set_size_inches(10, 9)
30+
ax.legend(loc="lower right", frameon=False)
31+
ax.set(title=title)
32+
elif backend == "plotly":
33+
ax.update_layout(title=dict(text=title, x=0.5))
34+
ax.show()
2835

2936
img_path = f"{ROOT}/figures/{today}-rolling-mae-vs-hull-dist.pdf"
3037
# fig.savefig(img_path)

scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
assert 1e4 < len(df_step) < 1e5, print(f"{len(df_step) = :,}")
2929

3030
rolling_mae_vs_hull_dist(
31-
e_above_hull_error=df_step[target_col] - df_step[model_name],
3231
e_above_hull_true=df_step[e_above_hull_col],
32+
e_above_hull_error=df_step[target_col] - df_step[model_name],
3333
ax=ax,
3434
label=title,
3535
marker=marker,

0 commit comments

Comments
 (0)