Skip to content

Commit 9aebb08

Browse files
committed
add kwarg show_dft_acc=False to rolling_mae_vs_hull_dist()
fix about-the-test-set page showing WBM element counts in MP heatmap's hover data
1 parent 75fc095 commit 9aebb08

File tree

5 files changed

+94
-59
lines changed

5 files changed

+94
-59
lines changed

data/wbm/analysis.py

-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@
113113
fig.update_layout(showlegend=False, paper_bgcolor="rgba(0,0,0,0)")
114114
fig.update_xaxes(title_text="WBM energy above MP convex hull (eV/atom)")
115115

116-
117116
for x_pos, label in zip(
118117
[mean, mean + std, mean - std],
119118
[f"{mean = :.2f}", f"{mean + std = :.2f}", f"{mean - std = :.2f}"],

data/wbm/readme.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ The WBM test set has an energy above the MP convex hull distribution with mean *
9494

9595
The dummy MAE of always predicting the test set mean is **0.17 eV/atom**.
9696

97-
The number of stable materials is **97k** out of 257k, resulting in a dummy stability hit rate of **37%**.
97+
The number of stable materials (according to the MP convex hull, which is spanned by the training data the models have access to) is **97k** out of **257k**, resulting in a dummy stability hit rate of **37%**.
98+
99+
> Incidentally, [according to the authors](https://www.nature.com/articles/s41524-020-00481-6#Sec2), a more accurate stability rate, based on the combined MP+WBM convex hull of the first 3 rounds of elemental substitution, is 18,479 out of 189,981 crystals ($\approx$ 9.7%).
98100
99101
<slot name="wbm-each-hist">
100102
<img src="./figs/2023-01-26-wbm-each-hist.svg" alt="WBM energy above MP convex hull distribution">

matbench_discovery/plots.py

+65-54
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ def rolling_mae_vs_hull_dist(
296296
y_label: str = "rolling MAE (eV/atom)",
297297
just_plot_lines: bool = False,
298298
with_sem: bool = True,
299+
show_dft_acc: bool = False,
299300
**kwargs: Any,
300301
) -> plt.Axes | go.Figure:
301302
"""Rolling mean absolute error as the energy to the convex hull is varied. A scale
@@ -325,6 +326,9 @@ def rolling_mae_vs_hull_dist(
325326
to False.
326327
with_sem (bool, optional): If True, plot the standard error of the mean as
327328
shaded area around the rolling MAE. Defaults to True.
329+
show_dft_acc (bool, optional): If True, change color of the cone of peril's tip
330+
and annotate it with 'Corrected GGA Accuracy' at rolling MAE of 25 meV/atom.
331+
Defaults to False.
328332
329333
Returns:
330334
tuple[plt.Axes | go.Figure, pd.DataFrame, pd.DataFrame]: matplotlib Axes or
@@ -363,8 +367,8 @@ def rolling_mae_vs_hull_dist(
363367
# previous call
364368
return ax, df_rolling_err, df_err_std
365369

366-
# DFT accuracy at 25 meV/atom for e_above_hull calculations of chemically similar
367-
# systems which is lower than formation energy error due to systematic error
370+
# DFT accuracy at 25 meV/atom for relative difference of e_above_hull for chemically
371+
# similar systems which is lower than formation energy error due to systematic error
368372
# cancellation among similar chemistries, supporting ref:
369373
href = "https://doi.org/10.1103/PhysRevB.85.155208"
370374
dft_acc = 0.025
@@ -397,32 +401,33 @@ def rolling_mae_vs_hull_dist(
397401
ax.add_artist(scale_bar)
398402

399403
ax.fill_between(
400-
(-1, -dft_acc, dft_acc, 1),
401-
(1, 1, 1, 1),
402-
(1, dft_acc, dft_acc, 1),
404+
(-1, -dft_acc, dft_acc, 1) if show_dft_acc else (-1, 0, 1),
405+
(1, 1, 1, 1) if show_dft_acc else (1, 1, 1),
406+
(1, dft_acc, dft_acc, 1) if show_dft_acc else (1, 0, 1),
403407
color="tab:red",
404408
alpha=0.2,
405409
)
406410

407-
ax.fill_between(
408-
(-dft_acc, 0, dft_acc),
409-
(dft_acc, dft_acc, dft_acc),
410-
(dft_acc, 0, dft_acc),
411-
color="tab:orange",
412-
alpha=0.2,
413-
)
414-
# shrink=0.1 means cut off 10% length from both sides of arrow line
415-
arrowprops = dict(
416-
facecolor="black", width=0.5, headwidth=5, headlength=5, shrink=0.1
417-
)
418-
ax.annotate(
419-
xy=(-dft_acc, dft_acc),
420-
xytext=(-2 * dft_acc, dft_acc),
421-
text="Corrected\nGGA DFT\nAccuracy",
422-
arrowprops=arrowprops,
423-
verticalalignment="center",
424-
horizontalalignment="right",
425-
)
411+
if show_dft_acc:
412+
ax.fill_between(
413+
(-dft_acc, 0, dft_acc),
414+
(dft_acc, dft_acc, dft_acc),
415+
(dft_acc, 0, dft_acc),
416+
color="tab:orange",
417+
alpha=0.2,
418+
)
419+
# shrink=0.1 means cut off 10% length from both sides of arrow line
420+
arrowprops = dict(
421+
facecolor="black", width=0.5, headwidth=5, headlength=5, shrink=0.1
422+
)
423+
ax.annotate(
424+
xy=(-dft_acc, dft_acc),
425+
xytext=(-2 * dft_acc, dft_acc),
426+
text="Corrected GGA\nAccuracy",
427+
arrowprops=arrowprops,
428+
verticalalignment="center",
429+
horizontalalignment="right",
430+
)
426431

427432
ax.text(
428433
0, 0.13, r"MAE > $|E_\mathrm{above\ hull}|$", horizontalalignment="center"
@@ -457,43 +462,49 @@ def rolling_mae_vs_hull_dist(
457462
yanchor="bottom",
458463
title_font=dict(size=13),
459464
)
460-
ax.update_layout(
461-
dict(
462-
xaxis_title="E<sub>above MP hull</sub> (eV/atom)",
463-
yaxis_title="rolling MAE (eV/atom)",
464-
),
465-
legend=legend,
466-
)
465+
ax.layout.legend.update(legend)
466+
ax.layout.xaxis.title.text = "E<sub>above MP hull</sub> (eV/atom)"
467+
ax.layout.yaxis.title.text = "rolling MAE (eV/atom)"
467468
ax.update_xaxes(range=x_lim)
468469
ax.update_yaxes(range=y_lim)
469-
scatter_kwds = dict(fill="toself", opacity=0.4)
470-
ax.add_scatter(
471-
x=(-1, -dft_acc, dft_acc, 1),
472-
y=(1, dft_acc, dft_acc, 1),
473-
name="MAE > |E<sub>above hull</sub>|",
474-
# fillcolor="yellow",
475-
**scatter_kwds,
476-
)
470+
scatter_kwds = dict(fill="toself", opacity=0.2)
471+
peril_cone_anno = "MAE > |E<sub>above hull</sub>|"
477472
ax.add_scatter(
478-
x=(-dft_acc, dft_acc, 0, -dft_acc),
479-
y=(dft_acc, dft_acc, 0, dft_acc),
480-
name="MAE < |DFT error|",
481-
# fillcolor="red",
473+
x=(-1, -dft_acc, dft_acc, 1) if show_dft_acc else (-1, 0, 1),
474+
y=(1, dft_acc, dft_acc, 1) if show_dft_acc else (1, 0, 1),
475+
name=peril_cone_anno,
476+
fillcolor="red",
477+
showlegend=False,
482478
**scatter_kwds,
483479
)
484480
ax.add_annotation(
485-
x=-dft_acc,
486-
y=dft_acc,
487-
text=f"<a {href=}>Corrected GGA Accuracy<br>for rel. Energy</a> "
488-
"[<a href='#hautier_accuracy_2012' target='_self'>ref</a>]",
489-
showarrow=True,
490-
xshift=-10,
491-
arrowhead=2,
492-
ax=-4 * dft_acc,
493-
ay=2 * dft_acc,
494-
axref="x",
495-
ayref="y",
481+
x=0,
482+
y=0.8,
483+
text=peril_cone_anno,
484+
showarrow=False,
485+
yref="paper",
496486
)
487+
if show_dft_acc:
488+
ax.add_scatter(
489+
x=(-dft_acc, dft_acc, 0, -dft_acc),
490+
y=(dft_acc, dft_acc, 0, dft_acc),
491+
name="MAE < |Corrected GGA error|",
492+
fillcolor="red",
493+
**scatter_kwds,
494+
)
495+
ax.add_annotation(
496+
x=-dft_acc,
497+
y=dft_acc,
498+
text=f"<a {href=}>Corrected GGA Accuracy<br>for rel. Energy</a> "
499+
"[<a href='#hautier_accuracy_2012' target='_self'>ref</a>]",
500+
showarrow=True,
501+
xshift=-10,
502+
arrowhead=2,
503+
ax=-4 * dft_acc,
504+
ay=2 * dft_acc,
505+
axref="x",
506+
ayref="y",
507+
)
497508

498509
ax.data = ax.data[::-1] # bring px.line() to front
499510
# plot rectangle to indicate MAE window size

matbench_discovery/structure.py

+23
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,26 @@ def perturb_structure(struct: Structure, gamma: float = 1.5) -> Structure:
2525
site.to_unit_cell(in_place=True)
2626

2727
return perturbed
28+
29+
30+
if __name__ == "__main__":
31+
import matplotlib.pyplot as plt
32+
33+
gamma = 1.5
34+
samples = np.array([np.random.weibull(gamma) for _ in range(10000)])
35+
mean = samples.mean()
36+
37+
# reproduces the dist in https://www.nature.com/articles/s41524-022-00891-8#Fig5
38+
ax = plt.hist(samples, bins=100)
39+
# add vertical line at the mean
40+
plt.axvline(mean, color="gray", linestyle="dashed", linewidth=1)
41+
# annotate the mean line
42+
plt.annotate(
43+
f"{mean = :.2f}",
44+
xy=(mean, 1),
45+
# use ax coords for y
46+
xycoords=("data", "axes fraction"),
47+
# add text offset
48+
xytext=(10, -20),
49+
textcoords="offset points",
50+
)

site/src/routes/about-the-test-set/+page.svelte

+3-3
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@
6969
<TableInset slot="inset" grid_row="3">
7070
{#if active_mp_elem?.name}
7171
<strong>
72-
{active_mp_elem?.name}: {pretty_num(wbm_elem_counts[active_mp_elem?.symbol])}
72+
{active_mp_elem?.name}: {pretty_num(mp_elem_counts[active_mp_elem?.symbol])}
7373
<!-- compute percent of total -->
74-
{#if wbm_elem_counts[active_mp_elem?.symbol] > 0}
74+
{#if mp_elem_counts[active_mp_elem?.symbol] > 0}
7575
{@const total = wbm_heat_vals.reduce((a, b) => a + b, 0)}
76-
({pretty_num((wbm_elem_counts[active_mp_elem?.symbol] / total) * 100)}%)
76+
({pretty_num((mp_elem_counts[active_mp_elem?.symbol] / total) * 100)}%)
7777
{/if}
7878
</strong>
7979
{/if}

0 commit comments

Comments
 (0)