Skip to content

Commit 5df80ef

Browse files
committed
use bin counts directly (no KDE) in hull dist density scatter plot
add rolling-mae-vs-hull-dist-wbm-batches-{alignn,bowsr,mace,voronoi-rf}.svelte fix hull-dist-scatter-wrenformer-failures.pdf (used wrong input dataframe df_preds vs df_each_pred) add ref wang_framework_2021 (MP2020 correction scheme)
1 parent d9bb043 commit 5df80ef

21 files changed

+171
-90
lines changed

data/wbm/eda.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
save_fig(ptable, f"{PDF_FIGS}/{dataset}-element-{count_mode}-counts.pdf")
9494

9595

96-
# %% histogram of energy above MP convex hull for WBM
96+
# %% histogram of energy distance to MP convex hull for WBM
9797
col = each_true_col # or e_form_col
9898
mean, std = df_wbm[col].mean(), df_wbm[col].std()
9999

matbench_discovery/plots.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -521,12 +521,14 @@ def rolling_mae_vs_hull_dist(
521521
fig.set(xlim=x_lim, ylim=y_lim)
522522
line_styles = "- -- -. :".split()
523523
markers = "o s ^ v D * p X".split()
524-
combinations = [(ls, mark) for mark in markers for ls in line_styles]
525524

526525
for idx, line in enumerate(fig.lines):
527-
ls, marker = combinations[idx % len(combinations)]
526+
line_label = line.get_label()
527+
if line_label.startswith("_"):
528+
continue
529+
ls, marker = line_styles[idx], markers[idx]
528530
line.set(ls=ls, marker=marker, markeredgewidth=0.5, markeredgecolor="black")
529-
line.set_markevery(4)
531+
line.set_markevery(8)
530532

531533
elif backend == "plotly":
532534
for idx, model in enumerate(df_rolling_err if with_sem else []):
@@ -614,14 +616,16 @@ def rolling_mae_vs_hull_dist(
614616
)
615617
fig.add_shape(type="rect", x0=x0, y0=y0, x1=x0 - window, y1=y0 + window / 5)
616618

617-
line_styles = "solid dash dot dashdot".split()
618-
markers = "circle square triangle-up triangle-down diamond cross star x".split()
619619
from matbench_discovery.preds import model_styles
620620

621-
for trace in fig.data:
621+
for idx, trace in enumerate(fig.data):
622622
if style := model_styles.get(trace.name):
623623
ls, _marker, color = style
624624
trace.line = dict(color=color, dash=ls, width=2)
625+
else:
626+
trace.line = dict(
627+
color=plotly_colors[idx], dash=plotly_line_styles[idx], width=3
628+
)
625629
# marker_spacing = 2
626630
# trace = go.Scatter(
627631
# x=trace.x[::marker_spacing],

models/wrenformer/analyze_wrenformer.py

+51-3
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33

44
# %%
5+
import numpy as np
56
import pandas as pd
67
from aviary.wren.utils import get_isopointal_proto_from_aflow
78
from pymatviz import spacegroup_hist, spacegroup_sunburst
89
from pymatviz.ptable import ptable_heatmap_plotly
9-
from pymatviz.utils import save_fig
10+
from pymatviz.utils import add_identity_line, bin_df_cols, save_fig
1011

1112
from matbench_discovery import PDF_FIGS, SITE_FIGS
1213
from matbench_discovery.data import DATA_FILES, df_wbm
@@ -80,10 +81,14 @@
8081

8182
# %%
8283
fig = spacegroup_sunburst(df_bad[spg_col], width=350, height=350)
83-
fig.layout.title.update(text=f"Spacegroup sunburst for {title}", x=0.5, font_size=14)
84+
# fig.layout.title.update(text=f"Spacegroup sunburst for {title}", x=0.5, font_size=14)
85+
fig.layout.margin.update(l=1, r=1, t=1, b=1)
8486
fig.show()
87+
88+
89+
# %%
8590
save_fig(fig, f"{PDF_FIGS}/spacegroup-sunburst-{model.lower()}-failures.pdf")
86-
# save_fig(fig, f"{FIGS}/spacegroup-sunburst-{model}-failures.svelte")
91+
save_fig(fig, f"{SITE_FIGS}/spacegroup-sunburst-{model}-failures.svelte")
8792

8893

8994
# %%
@@ -92,3 +97,46 @@
9297
fig.layout.margin = dict(l=0, r=0, t=50, b=0)
9398
fig.show()
9499
save_fig(fig, f"{PDF_FIGS}/elements-{model.lower()}-failures.pdf")
100+
101+
102+
# %%
103+
model = "Wrenformer"
104+
cols = [model, each_true_col]
105+
bin_cnt_col = "bin counts"
106+
df_bin = bin_df_cols(
107+
df_each_pred, [each_true_col, model], n_bins=200, bin_counts_col=bin_cnt_col
108+
)
109+
log_cnt_col = f"log {bin_cnt_col}"
110+
df_bin[log_cnt_col] = np.log1p(df_bin[bin_cnt_col]).round(2)
111+
112+
113+
# %%
114+
fig = df_bin.reset_index().plot.scatter(
115+
x=each_true_col,
116+
y=model,
117+
hover_data=cols,
118+
hover_name=df_preds.index.name,
119+
backend="plotly",
120+
color=log_cnt_col,
121+
color_continuous_scale="turbo",
122+
)
123+
124+
# title = "Analysis of Wrenformer failure cases in the highlighted rectangle"
125+
# fig.layout.title.update(text=title, x=0.5)
126+
fig.layout.margin.update(l=0, r=0, t=0, b=0)
127+
fig.layout.legend.update(title="", x=1, y=0, xanchor="right")
128+
add_identity_line(fig)
129+
fig.layout.coloraxis.colorbar.update(
130+
x=1, y=0.5, xanchor="right", thickness=12, title=""
131+
)
132+
# add shape shaded rectangle at x < 1, y > 1
133+
fig.add_shape(
134+
type="rect", **dict(x0=1, y0=1, x1=-1, y1=6), fillcolor="gray", opacity=0.2
135+
)
136+
fig.show()
137+
138+
139+
# %%
140+
img_name = "hull-dist-scatter-wrenformer-failures"
141+
save_fig(fig, f"{SITE_FIGS}/{img_name}.svelte")
142+
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf", width=600, height=300)

scripts/model_figs/scatter_hull_dist_models.py

+19-58
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,10 @@
99

1010
import numpy as np
1111
import plotly.express as px
12-
import scipy.stats
1312
from pymatviz.utils import add_identity_line, bin_df_cols, save_fig
14-
from tqdm import tqdm
1513

1614
from matbench_discovery import PDF_FIGS, SITE_FIGS
17-
from matbench_discovery.metrics import classify_stable
18-
from matbench_discovery.plots import clf_color_map, clf_colors, clf_labels
15+
from matbench_discovery.plots import clf_colors
1916
from matbench_discovery.preds import (
2017
df_metrics,
2118
df_preds,
@@ -47,22 +44,28 @@
4744
df_melt[each_pred_col] = (
4845
df_melt[each_true_col] + df_melt[e_form_pred_col] - df_melt[e_form_col]
4946
)
50-
df_bin = bin_df_cols(df_melt, [each_true_col, each_pred_col], [facet_col], n_bins=200)
47+
df_bin = bin_df_cols(
48+
df_melt,
49+
bin_by_cols=[each_true_col, each_pred_col],
50+
group_by_cols=[facet_col],
51+
n_bins=200,
52+
bin_counts_col=(bin_cnt_col := "bin counts"),
53+
)
5154
df_bin = df_bin.reset_index()
5255

5356
# sort legend and facet plots by MAE
5457
legend_order = list(df_metrics.T.MAE.sort_values().index)
5558

5659

5760
# determine each point's classification to color them by
58-
true_pos, false_neg, false_pos, true_neg = classify_stable(
59-
df_bin[each_true_col], df_bin[each_pred_col]
60-
)
61-
62-
clf_col = "classified"
63-
df_bin[clf_col] = np.array(clf_labels)[
64-
true_pos * 0 + false_neg * 1 + false_pos * 2 + true_neg * 3
65-
]
61+
# now unused, can be used to color points by TP/FP/TN/FN
62+
# true_pos, false_neg, false_pos, true_neg = classify_stable(
63+
# df_bin[each_true_col], df_bin[each_pred_col]
64+
# )
65+
# clf_col = "classified"
66+
# df_bin[clf_col] = np.array(clf_labels)[
67+
# true_pos * 0 + false_neg * 1 + false_pos * 2 + true_neg * 3
68+
# ]
6669

6770

6871
# %% scatter plot of actual vs predicted e_form_per_atom
@@ -121,21 +124,8 @@
121124

122125

123126
# %%
124-
clr_col, cnt_col = "density", "counts"
125-
# compute KDE for each model's predictions separately
126-
for model in (pbar := tqdm(models)):
127-
pbar.set_description(f"KDE for {model=}")
128-
129-
xy = df_preds[[each_true_col, model]].dropna().T
130-
model_kde = scipy.stats.gaussian_kde(xy)
131-
132-
model_rows = df_bin[df_bin[facet_col] == model]
133-
xy_binned = model_rows[[each_true_col, each_pred_col]].T
134-
density = model_kde(xy_binned)
135-
n_preds = len(df_preds[model].dropna())
136-
df_bin.loc[model_rows.index, cnt_col] = density / density.sum() * n_preds
137-
138-
df_bin[clr_col] = np.log1p(df_bin[cnt_col]).round(2)
127+
log_bin_cnt_col = f"log {bin_cnt_col}"
128+
df_bin[log_bin_cnt_col] = np.log1p(df_bin[bin_cnt_col]).round(2)
139129

140130

141131
# %% scatter plot of DFT vs predicted hull distance with each model in separate subplot
@@ -148,7 +138,7 @@
148138
y=each_pred_col,
149139
facet_col=facet_col,
150140
facet_col_wrap=n_cols,
151-
color=clr_col,
141+
color=log_bin_cnt_col,
152142
facet_col_spacing=0.02,
153143
facet_row_spacing=0.04,
154144
hover_data=hover_cols,
@@ -259,32 +249,3 @@
259249
fig_name = f"each-scatter-models-{n_rows}x{n_cols}"
260250
save_fig(fig, f"{SITE_FIGS}/{fig_name}.svelte")
261251
save_fig(fig, f"{PDF_FIGS}/{fig_name}.pdf")
262-
263-
264-
# %%
265-
model = "Wrenformer"
266-
fig = px.scatter(
267-
df_bin.query(f"{facet_col} == {model!r}"),
268-
x=each_true_col,
269-
y=each_pred_col,
270-
hover_data=hover_cols,
271-
color=clf_col,
272-
color_discrete_map=clf_color_map,
273-
hover_name=df_preds.index.name,
274-
opacity=0.7,
275-
)
276-
277-
title = "Analysis of Wrenformer failure cases in the highlighted rectangle"
278-
fig.layout.title.update(text=title, x=0.5)
279-
fig.layout.legend.update(title="", x=1, y=0, xanchor="right")
280-
add_identity_line(fig)
281-
282-
# add shape shaded rectangle at x < 1, y > 1
283-
fig.add_shape(
284-
type="rect", **dict(x0=1, y0=1, x1=-1, y1=6), fillcolor="gray", opacity=0.2
285-
)
286-
fig.show()
287-
288-
img_name = "hull-dist-scatter-wrenformer-failures"
289-
# save_fig(fig, f"{FIGS}/{img_name}.svelte")
290-
save_fig(fig, f"{PDF_FIGS}/{img_name}.pdf")

scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88

99
from matbench_discovery import PDF_FIGS, SITE_FIGS, today
1010
from matbench_discovery.plots import plt, rolling_mae_vs_hull_dist
11-
from matbench_discovery.preds import df_each_pred, df_preds, e_form_col, each_true_col
11+
from matbench_discovery.preds import (
12+
df_each_pred,
13+
df_preds,
14+
e_form_col,
15+
each_true_col,
16+
models,
17+
)
1218

1319
__author__ = "Rhys Goodall, Janosh Riebesell"
1420
__date__ = "2022-06-18"
@@ -39,7 +45,7 @@
3945
markevery=20,
4046
markerfacecolor="white",
4147
markeredgewidth=2.5,
42-
backend="matplotlib",
48+
backend="matplotlib", # don't change, code here not plotly compatible
4349
ax=ax,
4450
just_plot_lines=idx > 1,
4551
pbar=False,
@@ -54,7 +60,7 @@
5460

5561

5662
# %% plotly
57-
for model in list(df_each_pred)[:-2]:
63+
for model in models:
5864
df_pivot = df_each_pred.pivot(columns=batch_col, values=model)
5965

6066
fig, df_err, df_std = rolling_mae_vs_hull_dist(
@@ -66,9 +72,11 @@
6672
show_dummy_mae=False,
6773
with_sem=False,
6874
)
69-
fig.layout.legend.update(title=f"<b>{model}</b>", x=0.02, y=0.02)
75+
fig.layout.legend.update(
76+
title=f"<b>{model}</b>", x=0.02, y=0.02, bgcolor="rgba(0,0,0,0)"
77+
)
7078
fig.layout.margin.update(l=10, r=10, b=10, t=10)
71-
fig.update_layout(hovermode="x unified", hoverlabel_bgcolor="black")
79+
fig.layout.update(hovermode="x unified", hoverlabel_bgcolor="black")
7280
fig.update_traces(
7381
hovertemplate="y=%{y:.3f} eV",
7482
selector=lambda trace: trace.name.startswith("Batch"),
@@ -78,4 +86,4 @@
7886
model_snake_case = model.lower().replace(" + ", "-").replace(" ", "-")
7987
img_path = f"rolling-mae-vs-hull-dist-wbm-batches-{model_snake_case}"
8088
save_fig(fig, f"{SITE_FIGS}/{img_path}.svelte")
81-
save_fig(fig, f"{PDF_FIGS}/{img_path}.pdf")
89+
save_fig(fig, f"{PDF_FIGS}/{img_path}.pdf", width=500, height=330)

site/src/figs/each-scatter-models-5x2.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/hull-dist-scatter-wrenformer-failures.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/rolling-mae-vs-hull-dist-wbm-batches-alignn.svelte

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/rolling-mae-vs-hull-dist-wbm-batches-bowsr.svelte

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/rolling-mae-vs-hull-dist-wbm-batches-cgcnn+p.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/rolling-mae-vs-hull-dist-wbm-batches-cgcnn.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/rolling-mae-vs-hull-dist-wbm-batches-chgnet.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/rolling-mae-vs-hull-dist-wbm-batches-m3gnet.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/rolling-mae-vs-hull-dist-wbm-batches-mace.svelte

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/rolling-mae-vs-hull-dist-wbm-batches-mean-prediction-all-models.svelte

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/rolling-mae-vs-hull-dist-wbm-batches-megnet.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/rolling-mae-vs-hull-dist-wbm-batches-voronoi-rf.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/rolling-mae-vs-hull-dist-wbm-batches-wrenformer.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/figs/spacegroup-sunburst-wrenformer-failures.svelte

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

site/src/routes/preprint/references.yaml

+56
Original file line numberDiff line numberDiff line change
@@ -2648,6 +2648,62 @@ references:
26482648
URL: https://www.nature.com/articles/s41467-020-18556-9
26492649
volume: '11'
26502650

2651+
- id: wang_framework_2021
2652+
abstract: >-
2653+
In this work, we demonstrate a method to quantify uncertainty in corrections
2654+
to density functional theory (DFT) energies based on empirical results. Such
2655+
corrections are commonly used to improve the accuracy of computational
2656+
enthalpies of formation, phase stability predictions, and other
2657+
energy-derived properties, for example. We incorporate this method into a
2658+
new DFT energy correction scheme comprising a mixture of oxidation-state and
2659+
composition-dependent corrections and show that many chemical systems
2660+
contain unstable polymorphs that may actually be predicted stable when
2661+
uncertainty is taken into account. We then illustrate how these
2662+
uncertainties can be used to estimate the probability that a compound is
2663+
stable on a compositional phase diagram, thus enabling better-informed
2664+
assessments of compound stability.
2665+
accessed:
2666+
- year: 2023
2667+
month: 8
2668+
day: 28
2669+
author:
2670+
- family: Wang
2671+
given: Amanda
2672+
- family: Kingsbury
2673+
given: Ryan
2674+
- family: McDermott
2675+
given: Matthew
2676+
- family: Horton
2677+
given: Matthew
2678+
- family: Jain
2679+
given: Anubhav
2680+
- family: Ong
2681+
given: Shyue Ping
2682+
- family: Dwaraknath
2683+
given: Shyam
2684+
- family: Persson
2685+
given: Kristin A.
2686+
citation-key: wang_framework_2021
2687+
container-title: Scientific Reports
2688+
container-title-short: Sci Rep
2689+
DOI: 10.1038/s41598-021-94550-5
2690+
ISSN: 2045-2322
2691+
issue: '1'
2692+
issued:
2693+
- year: 2021
2694+
month: 7
2695+
day: 29
2696+
language: en
2697+
license: 2021 The Author(s)
2698+
number: '1'
2699+
page: '15496'
2700+
publisher: Nature Publishing Group
2701+
source: www.nature.com
2702+
title: A framework for quantifying uncertainty in DFT energy corrections
2703+
type: article-journal
2704+
URL: https://www.nature.com/articles/s41598-021-94550-5
2705+
volume: '11'
2706+
26512707
- id: wang_predicting_2021
26522708
abstract: >-
26532709
We propose an efficient high-throughput scheme for the discovery of stable

site/src/routes/si/+page.md

+6-6
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,12 @@ The figures below show the rolling MAE as a function of distance to the convex h
8181
{#if mounted}
8282

8383
<div style="display: grid; grid-template-columns: 1fr 1fr; margin: 0 -1em 0 -4em;">
84-
<M3gnetRollingMaeBatches style="margin: -2em 0 0; height: 400px;" />
85-
<CHGNetRollingMaeBatches style="margin: -2em 0 0; height: 400px;" />
86-
<WrenformerRollingMaeBatches style="margin: -2em 0 0; height: 400px;" />
87-
<MegnetRollingMaeBatches style="margin: -2em 0 0; height: 400px;" />
88-
<VoronoiRfRollingMaeBatches style="margin: -2em 0 0; height: 400px;" />
89-
<CgcnnRollingMaeBatches style="margin: -2em 0 0; height: 400px;" />
84+
<M3gnetRollingMaeBatches style="aspect-ratio: 1.2;" />
85+
<CHGNetRollingMaeBatches style="aspect-ratio: 1.2;" />
86+
<WrenformerRollingMaeBatches style="aspect-ratio: 1.2;" />
87+
<MegnetRollingMaeBatches style="aspect-ratio: 1.2;" />
88+
<VoronoiRfRollingMaeBatches style="aspect-ratio: 1.2;" />
89+
<CgcnnRollingMaeBatches style="aspect-ratio: 1.2;" />
9090
</div>
9191
{/if}
9292

0 commit comments

Comments
 (0)