Skip to content

Commit d1751a3

Browse files
committed
add 2023-01-18-e-form-scatter-models.png to paper
pnpm add -D svelte-preprocess-import-assets (used in svelte.config.js) decrease katex font-size to 10pt rename scripts/scatter_e_above_hull_models.py use subscripts in plotly template quantity_labels
1 parent 0a2284a commit d1751a3

File tree

8 files changed

+185
-15
lines changed

8 files changed

+185
-15
lines changed

matbench_discovery/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
ROOT = os.path.dirname(os.path.dirname(__file__)) # repository root
1010
FIGS = f"{ROOT}/site/static/figs" # directory to store figures
11+
PAPER = f"{ROOT}/site/src/routes/paper/figs" # directory to store figures
1112
# whether a currently running slurm job is in debug mode
1213
DEBUG = "DEBUG" in os.environ or (
1314
"slurm-submit" not in sys.argv and "SLURM_JOB_ID" not in os.environ

matbench_discovery/plots.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,15 @@
3232
n_wyckoff="Number of Wyckoff positions",
3333
n_sites="Lattice site count",
3434
energy_per_atom="Energy (eV/atom)",
35-
e_form="Formation energy (eV/atom)",
36-
e_above_hull="Energy above convex hull (eV/atom)",
37-
e_above_hull_pred="Predicted energy above convex hull (eV/atom)",
38-
e_above_hull_mp="Energy above MP convex hull (eV/atom)",
39-
e_above_hull_error="Error in energy above convex hull (eV/atom)",
35+
e_form="Actual E<sub>form</sub> (eV/atom)",
36+
e_above_hull="E<sub>above hull</sub> (eV/atom)",
37+
e_above_hull_mp2020_corrected_ppd_mp="Actual E<sub>above hull</sub> (eV/atom)",
38+
e_above_hull_pred="Predicted E<sub>above hull</sub> (eV/atom)",
39+
e_above_hull_mp="E<sub>above MP hull</sub> (eV/atom)",
40+
e_above_hull_error="Error in E<sub>above hull</sub> (eV/atom)",
4041
vol_diff="Volume difference (A^3)",
41-
e_form_per_atom_mp2020_corrected="Formation energy (eV/atom)",
42-
e_form_per_atom_pred="Predicted formation energy (eV/atom)",
42+
e_form_per_atom_mp2020_corrected="Actual E<sub>form</sub> (eV/atom)",
43+
e_form_per_atom_pred="Predicted E<sub>form</sub> (eV/atom)",
4344
material_id="Material ID",
4445
band_gap="Band gap (eV)",
4546
formula="Formula",
@@ -60,6 +61,7 @@
6061
margin=dict(l=30, r=20, t=60, b=20),
6162
paper_bgcolor="rgba(0,0,0,0)",
6263
# plot_bgcolor="rgba(0,0,0,0)",
64+
font_size=15,
6365
)
6466
pio.templates["global"] = dict(layout=global_layout)
6567
pio.templates.default = "plotly_dark+global"
+156
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# %%
2+
import numpy as np
3+
import plotly.graph_objects as go
4+
from pymatviz.utils import add_identity_line, save_fig
5+
from sklearn.metrics import r2_score
6+
7+
from matbench_discovery import FIGS, PAPER, today
8+
from matbench_discovery.data import PRED_FILENAMES, load_df_wbm_with_preds
9+
from matbench_discovery.energy import classify_stable
10+
from matbench_discovery.plots import px
11+
12+
__author__ = "Janosh Riebesell"
13+
__date__ = "2022-11-28"
14+
15+
16+
# %%
17+
print(f"loadable models: {list(PRED_FILENAMES)}")
18+
models = sorted(
19+
"CGCNN, Voronoi RF, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet".split(", ")
20+
)
21+
df_wbm = load_df_wbm_with_preds(models=models).round(3)
22+
23+
e_form_col = "e_form_per_atom_mp2020_corrected"
24+
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
25+
id_col = "material_id"
26+
legend = dict(x=1, y=0, xanchor="right", yanchor="bottom", title=None)
27+
28+
29+
# %%
30+
e_form_preds = "e_form_per_atom_pred"
31+
e_above_hull_preds = "e_above_hull_pred"
32+
var_name = "Model"
33+
hover_cols = (id_col, e_form_col, e_above_hull_col, "formula")
34+
35+
df_melt = df_wbm.melt(
36+
id_vars=hover_cols,
37+
value_vars=models,
38+
var_name=var_name,
39+
value_name=e_form_preds,
40+
)
41+
42+
df_melt[e_above_hull_preds] = (
43+
df_melt[e_above_hull_col] - df_melt[e_form_col] + df_melt[e_form_preds]
44+
)
45+
46+
47+
# %%
48+
def _metric_str(model_name: str) -> str:
49+
MAE = (df_wbm[e_form_col] - df_wbm[model_name]).abs().mean()
50+
R2 = r2_score(*df_wbm[[e_form_col, model_name]].dropna().to_numpy().T)
51+
return f"{model_name} · {MAE=:.2} · R<sup>2</sup>={R2:.2}"
52+
53+
54+
def _add_metrics_to_legend(fig: go.Figure) -> None:
55+
for trace in fig.data:
56+
# initially hide all traces, let users select which models to compare
57+
trace.visible = "legendonly"
58+
# add MAE and R2 to legend
59+
model = trace.name
60+
trace.name = _metric_str(model)
61+
62+
63+
# %% scatter plot of actual vs predicted e_form_per_atom
64+
fig = px.scatter(
65+
df_melt.iloc[::10],
66+
x=e_form_col,
67+
y=e_form_preds,
68+
color=var_name,
69+
hover_data=hover_cols,
70+
hover_name=id_col,
71+
)
72+
73+
_add_metrics_to_legend(fig)
74+
fig.update_layout(legend=legend)
75+
add_identity_line(fig)
76+
fig.show()
77+
78+
79+
# %%
80+
img_path = f"{FIGS}/{today}-e-form-scatter-models"
81+
# fig.write_image(f"{img_path}.pdf")
82+
save_fig(fig, f"{img_path}.svelte")
83+
84+
85+
# %% scatter plot of actual vs predicted e_above_hull
86+
fig = px.scatter(
87+
df_melt.iloc[::10],
88+
x=e_above_hull_col,
89+
y=e_above_hull_preds,
90+
color=var_name,
91+
hover_data=hover_cols,
92+
hover_name=id_col,
93+
)
94+
95+
_add_metrics_to_legend(fig)
96+
fig.update_layout(legend=legend)
97+
add_identity_line(fig)
98+
fig.show()
99+
100+
101+
# %%
102+
img_path = f"{FIGS}/{today}-e-above-hull-scatter-models"
103+
# fig.write_image(f"{img_path}.pdf")
104+
save_fig(fig, f"{img_path}.svelte")
105+
106+
107+
# %% plot all models in separate subplots
108+
true_pos, false_neg, false_pos, true_neg = classify_stable(
109+
df_melt[e_above_hull_col], df_melt[e_above_hull_preds]
110+
)
111+
112+
df_melt["clf"] = np.array(
113+
classes := ["true positive", "false negative", "false positive", "true negative"]
114+
)[true_pos * 0 + false_neg * 1 + false_pos * 2 + true_neg * 3]
115+
116+
fig = px.scatter(
117+
df_melt.iloc[::10],
118+
x=e_above_hull_col,
119+
y=e_above_hull_preds,
120+
facet_col=var_name,
121+
facet_col_wrap=3,
122+
hover_data=hover_cols,
123+
hover_name=id_col,
124+
color="clf",
125+
color_discrete_map=dict(zip(classes, ("green", "yellow", "red", "blue"))),
126+
opacity=0.4,
127+
)
128+
129+
# iterate over subplots and set new title
130+
for idx, model in enumerate(models, 1):
131+
132+
# add MAE and R2 to subplot title
133+
MAE = (df_wbm[e_form_col] - df_wbm[model]).abs().mean()
134+
R2 = r2_score(*df_wbm[[e_form_col, model]].dropna().to_numpy().T)
135+
# find index of annotation belonging to model
136+
anno_idx = [a.text for a in fig.layout.annotations].index(f"Model={model}")
137+
assert anno_idx >= 0, f"could not find annotation for {model}"
138+
# set new title
139+
fig.layout.annotations[anno_idx].text = _metric_str(model)
140+
# remove x and y axis titles if not on center row or center column
141+
if idx != 2:
142+
fig.layout[f"xaxis{idx}"].title.text = ""
143+
if idx > 1:
144+
fig.layout[f"yaxis{idx}"].title.text = ""
145+
# add vertical and horizontal lines at 0
146+
fig.add_vline(x=0, line=dict(width=1, dash="dash", color="gray"))
147+
fig.add_hline(y=0, line=dict(width=1, dash="dash", color="gray"))
148+
149+
id_line = add_identity_line(fig, ret_shape=True)
150+
fig.update_layout(showlegend=False)
151+
fig.update_xaxes(nticks=5)
152+
fig.update_yaxes(nticks=5)
153+
154+
fig.show()
155+
img_path = f"{PAPER}/{today}-e-form-scatter-models.png"
156+
save_fig(fig, img_path, scale=4, width=1000, height=500)

site/package.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"preview": "vite preview",
1414
"serve": "vite build && vite preview",
1515
"check": "svelte-check",
16-
"make-api-docs": "cd .. && python ../scripts/make_api_docs.py"
16+
"make-api-docs": "cd .. && python scripts/make_api_docs.py"
1717
},
1818
"devDependencies": {
1919
"@iconify/svelte": "^3.0.1",
@@ -37,6 +37,7 @@
3737
"svelte": "^3.55.1",
3838
"svelte-check": "^3.0.2",
3939
"svelte-preprocess": "^5.0.0",
40+
"svelte-preprocess-import-assets": "^0.2.5",
4041
"svelte-toc": "^0.5.2",
4142
"svelte-zoo": "^0.2.1",
4243
"svelte2tsx": "^0.6.0",

site/src/app.css

+4
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,7 @@ caption {
150150
display: block;
151151
margin: 1em auto 2em;
152152
}
153+
154+
.math {
155+
font-size: 10pt;
156+
}

site/src/routes/+error.svelte

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<script lang="ts">
22
import { page } from '$app/stores'
3+
import { homepage, name } from '$site/package.json'
34
import Icon from '@iconify/svelte'
4-
import { homepage, name } from '../../package.json'
55
66
let online: boolean
77
</script>
@@ -29,7 +29,7 @@
2929

3030
<p>
3131
Back to <a href=".">
32-
<img src="/favicon.svg" alt={name} height="30" />
32+
<img src="$static/favicon.svg" alt={name} height="30" />
3333
landing page
3434
</a>.
3535
</p>

site/src/routes/paper/+page.svx

+4-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ geometry: margin=3cm # https://stackoverflow.com/a/13516042
5252
import MetricsTable from '$figs/2022-11-28-metrics-table.svelte'
5353
import { references } from './references.yaml'
5454
import { References } from '$lib'
55-
import './heading-number.css' // uncomment to remove heading numbers
55+
import './heading-number.css' // CSS to auto-number headings
5656
</script>
5757

5858
# {title}<br><small>{subtitle}</small>
@@ -180,13 +180,15 @@ Our benchmark is designed to make [adding future models easy](/how-to-contribute
180180
<MetricsTable />
181181
</div>
182182

183+
![Scatter plots for each model's energy above hull predictions vs DFT ground truth](./figs/2023-01-18-e-form-scatter-models.png)
184+
183185
## Analysis
184186

185187
## Conclusion
186188

187189
## Acknowledgements
188190

189-
JR acknowledges support from the German Academic Scholarship Foundation (Studienstiftung) and gracious hosting as a visiting affiliate in the group of KP.
191+
JR acknowledges support from the German Academic Scholarship Foundation (Studienstiftung) and gracious hosting as a visiting affiliate in the groups of [KP](https://perssongroup.lbl.gov/people) and [AJ](https://hackingmaterials.lbl.gov).
190192

191193
## References
192194

site/svelte.config.js

+7-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import katex from 'rehype-katex-svelte'
66
import heading_slugs from 'rehype-slug'
77
import math from 'remark-math'
88
import preprocess from 'svelte-preprocess'
9+
import assets from 'svelte-preprocess-import-assets'
910

1011
const rehypePlugins = [
1112
katex,
@@ -35,13 +36,14 @@ export default {
3536

3637
preprocess: [
3738
{
38-
// preprocess markdown citations @auth_1st-word-title_yyyy into superscript
39-
// links to bibliography items, href must match References.svelte
4039
markup: (file) => {
4140
if (file.filename.endsWith(`paper/+page.svx`)) {
41+
// preprocess markdown citations @auth_1st-word-title_yyyy into superscript
42+
// links to bibliography items, href must match id format in References.svelte
4243
const code = file.content.replace(
4344
/@((.+?)_.+?_(\d{4}))/g,
44-
`<sup><a href="#$1">$2 $3</a></sup>`
45+
(_full_str, bib_id, author, year) =>
46+
`<sup><a href="#${bib_id}">${author} ${year}</a></sup>`
4547
)
4648
return { code }
4749
}
@@ -57,6 +59,7 @@ export default {
5759
remarkPlugins: [math],
5860
extensions: [`.svx`, `.md`],
5961
}),
62+
assets(),
6063
],
6164

6265
kit: {
@@ -65,6 +68,7 @@ export default {
6568
alias: {
6669
$site: `.`,
6770
$root: `..`,
71+
$static: `./static`,
6872
$figs: `./static/figs`,
6973
},
7074
},

0 commit comments

Comments
 (0)