Skip to content

Commit 2f6fae3

Browse files
committed
add mdsvex custom img component
Used in paper via custom mdsvex layout site/src/routes/paper/Layout.svelte. Fix 2023-01-08-wbm-elements.svg (was identical to mp-elements.svg).
1 parent d1751a3 commit 2f6fae3

25 files changed

+229
-135
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ repos:
7272
- prettier
7373
- prettier-plugin-svelte
7474
- svelte
75-
exclude: ^(site/static/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yml|yaml|json))$
75+
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yml|yaml|json))$
7676

7777
- repo: https://github.com/pre-commit/mirrors-eslint
7878
rev: v8.31.0

data/wbm/analysis.py

+29-31
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from pymatviz import count_elements, ptable_heatmap_plotly
66
from pymatviz.utils import save_fig
77

8-
from matbench_discovery import ROOT, today
8+
from matbench_discovery import FIGS, today
9+
from matbench_discovery.data import df_wbm
910

1011
module_dir = os.path.dirname(__file__)
1112

@@ -15,71 +16,68 @@
1516

1617

1718
# %%
18-
df_summary = pd.read_csv(f"{module_dir}/2022-10-19-wbm-summary.csv").set_index(
19-
"material_id"
20-
)
21-
elem_counts = count_elements(df_summary.formula).astype(int)
19+
wbm_elem_counts = count_elements(df_wbm.formula).astype(int)
2220

23-
elem_counts.to_json(
24-
f"{ROOT}/site/src/routes/about-the-test-set/{today}-wbm-element-counts.json"
25-
)
21+
# wbm_elem_counts.to_json(
22+
# f"{ROOT}/site/src/routes/about-the-test-set/{today}-wbm-element-counts.json"
23+
# )
2624

2725

2826
# %%
29-
fig = ptable_heatmap_plotly(
30-
elem_counts,
27+
wbm_fig = ptable_heatmap_plotly(
28+
wbm_elem_counts.drop("Xe"),
3129
log=True,
32-
colorscale="YlGnBu",
30+
colorscale="RdBu",
3331
hover_props=dict(atomic_number="atomic number"),
34-
hover_data=elem_counts,
35-
font_size="1vw",
32+
hover_data=wbm_elem_counts,
3633
)
3734

3835
title = "WBM Elements"
39-
fig.update_layout(
36+
wbm_fig.update_layout(
4037
title=dict(text=title, x=0.35, y=0.9, font_size=20),
4138
xaxis=dict(fixedrange=True),
4239
yaxis=dict(fixedrange=True),
4340
paper_bgcolor="rgba(0,0,0,0)",
4441
)
45-
fig.show()
42+
wbm_fig.show()
4643

4744

4845
# %%
49-
fig.write_image(f"{module_dir}/{today}-wbm-elements.svg", width=1000, height=500)
50-
save_fig(fig, f"{module_dir}/{today}-wbm-elements.svelte")
46+
wbm_fig.write_image(
47+
f"{module_dir}/figs/{today}-wbm-elements.svg", width=1000, height=500
48+
)
49+
save_fig(wbm_fig, f"{FIGS}/{today}-wbm-elements.svelte")
5150

5251

5352
# %% load MP training set
5453
df = pd.read_json(f"{module_dir}/../mp/2022-08-13-mp-energies.json.gz")
55-
elem_counts = count_elements(df.formula_pretty).astype(int)
54+
mp_elem_counts = count_elements(df.formula_pretty).astype(int)
5655

57-
elem_counts.to_json(
58-
f"{ROOT}/site/src/routes/about-the-test-set/{today}-mp-element-counts.json"
59-
)
60-
elem_counts.describe()
56+
# mp_elem_counts.to_json(
57+
# f"{ROOT}/site/src/routes/about-the-test-set/{today}-mp-element-counts.json"
58+
# )
59+
mp_elem_counts.describe()
6160

6261

6362
# %%
64-
fig = ptable_heatmap_plotly(
65-
elem_counts[elem_counts > 1],
63+
mp_fig = ptable_heatmap_plotly(
64+
mp_elem_counts[mp_elem_counts > 1],
6665
log=True,
67-
colorscale="YlGnBu",
66+
colorscale="RdBu",
6867
hover_props=dict(atomic_number="atomic number"),
69-
hover_data=elem_counts,
70-
font_size="1vw",
68+
hover_data=mp_elem_counts,
7169
)
7270

7371
title = "MP Elements"
74-
fig.update_layout(
72+
mp_fig.update_layout(
7573
title=dict(text=title, x=0.35, y=0.9, font_size=20),
7674
xaxis=dict(fixedrange=True),
7775
yaxis=dict(fixedrange=True),
7876
paper_bgcolor="rgba(0,0,0,0)",
7977
)
80-
fig.show()
78+
mp_fig.show()
8179

8280

8381
# %%
84-
fig.write_image(f"{module_dir}/{today}-mp-elements.svg", width=1000, height=500)
85-
save_fig(fig, f"{module_dir}/{today}-mp-elements.svelte")
82+
mp_fig.write_image(f"{module_dir}/figs/{today}-mp-elements.svg", width=1000, height=500)
83+
# save_fig(mp_fig, f"{FIGS}/{today}-mp-elements.svelte")

data/wbm/fetch_process_wbm_dataset.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,9 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
433433
n_too_unstable = sum(df_summary.e_form_per_atom_wbm > e_form_cutoff)
434434
print(f"{n_too_unstable = }") # n_too_unstable = 22
435435

436-
fig = df_summary.hist(x="e_form_per_atom_wbm", backend="plotly", log_y=True)
436+
fig = df_summary.hist(
437+
x="e_form_per_atom_wbm", backend="plotly", log_y=True, range_x=[-5.5, 5.5]
438+
)
437439
fig.add_vline(x=e_form_cutoff, line=dict(width=2, dash="dash", color="green"))
438440
fig.add_vline(x=-e_form_cutoff, line=dict(width=2, dash="dash", color="green"))
439441
fig.add_annotation(
@@ -443,15 +445,27 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
443445
)
444446
x_axis_title = "WBM uncorrected formation energy (eV/atom)"
445447
fig.update_layout(xaxis_title=x_axis_title, margin=dict(l=10, r=10, t=40, b=10))
448+
# disable zooming on the y-axis
449+
fig.update_yaxes(fixedrange=True)
450+
fig.show(
451+
config=dict(
452+
modeBarButtonsToRemove=["lasso2d", "select2d", "autoScale2d", "toImage"],
453+
displaylogo=False,
454+
)
455+
)
446456

447457

448458
# %%
449459
# no need to store all 250k x values in the plot (leads to a 1.7 MB file); subsampling every 10th
450460
# point is enough to see the distribution
451-
fig.data[0].x = fig.data[0].x[::10]
461+
if not fig.data[0].compressed:
462+
fig.data[0].compressed = True
463+
# keep only every 10th data point, round to 3 decimal places to reduce file size
464+
fig.data[0].x = [round(x, 3) for x in fig.data[0].x[::10]]
465+
452466
# recommended to upload SVG to vecta.io/nano afterwards for compression
453-
img_path = f"{module_dir}/{today}-hist-e-form-per-atom"
454-
save_fig(fig, f"{img_path}.svg", width=800, height=300)
467+
img_path = f"{module_dir}/2022-12-07-hist-e-form-per-atom"
468+
# save_fig(fig, f"{img_path}.svg", width=800, height=300)
455469
save_fig(fig, f"{img_path}.svelte")
456470

457471

data/wbm/figs/2023-01-08-mp-elements.svg

+1
Loading

data/wbm/figs/2023-01-08-wbm-elements.svg

+1
Loading

data/wbm/readme.md

+6-6
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ The full set of processing steps used to curate the WBM test set from the raw da
2323
- remove 6 pathological structures (with 0 volume)
2424
- remove formation energy outliers below -5 and above 5 eV/atom (502 and 22 crystals respectively out of 257,487 total, including an anomaly of 500 structures at exactly -10 eV/atom)
2525

26-
<caption>WBM Formation energy distribution. 524 materials outside green dashed lines were discarded.</caption>
26+
<caption>WBM Formation energy distribution. 524 materials outside green dashed lines were discarded.<br />(zoom out on this plot to see discarded samples)</caption>
2727
<slot name="hist-e-form-per-atom">
28-
<img src="./2022-12-07-hist-e-form-per-atom.svg" alt="WBM formation energy histogram indicating outlier cutoffs">
28+
<img src="./figs/2022-12-07-hist-e-form-per-atom.svg" alt="WBM formation energy histogram indicating outlier cutoffs">
2929
</slot>
3030

3131
- apply the [`MaterialsProject2020Compatibility`](https://pymatgen.org/pymatgen.entries.compatibility.html#pymatgen.entries.compatibility.MaterialsProject2020Compatibility) energy correction scheme to the formation energies
@@ -75,13 +75,13 @@ materialscloud:2021.68 includes a readme file with a description of the dataset,
7575
## 📊 &thinsp; Plots
7676

7777
<slot name="wbm-elements-heatmap">
78-
<img src="./2023-01-08-wbm-elements.svg" alt="Periodic table log heatmap of WBM elements">
78+
<img src="./figs/2023-01-08-wbm-elements.svg" alt="Periodic table log heatmap of WBM elements">
7979
</slot>
80-
<caption>Test set element counts consisting of 256,963 WBM <code>ComputedStructureEntries</code></caption>
80+
<caption>Element counts for test set consisting of 256,963 WBM <code>ComputedStructureEntries</code></caption>
8181

8282
By comparison, the training set of MP ComputedStructureEntries has this element distribution.
8383

8484
<slot name="mp-elements-heatmap">
85-
<img src="./2023-01-08-mp-elements.svg" alt="Periodic table log heatmap of MP elements">
85+
<img src="./figs/2023-01-08-mp-elements.svg" alt="Periodic table log heatmap of MP elements">
8686
</slot>
87-
<caption>Training set element counts consisting of 146,323 MP <code>ComputedStructureEntries</code></caption>
87+
<caption>Element counts for training set consisting of 146,323 MP <code>ComputedStructureEntries</code></caption>

matbench_discovery/__init__.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
"""Global variables used all across the matbench_discovery package."""
22

3-
from __future__ import annotations
4-
53
import os
64
import sys
75
from datetime import datetime
86

97
ROOT = os.path.dirname(os.path.dirname(__file__)) # repository root
10-
FIGS = f"{ROOT}/site/static/figs" # directory to store figures
11-
PAPER = f"{ROOT}/site/src/routes/paper/figs" # directory to store figures
8+
FIGS = f"{ROOT}/site/src/figs" # directory to store interactive figures
9+
STATIC = f"{ROOT}/site/static/figs" # directory to store static figures
1210
# whether a currently running slurm job is in debug mode
1311
DEBUG = "DEBUG" in os.environ or (
1412
"slurm-submit" not in sys.argv and "SLURM_JOB_ID" not in os.environ

matbench_discovery/data.py

+5
Original file line numberDiff line numberDiff line change
@@ -224,19 +224,24 @@ def load_df_wbm_with_preds(
224224
model_key = model_name.lower().replace(" ", "_")
225225
if f"e_form_per_atom_{model_key}" in df:
226226
df_out[model_name] = df[f"e_form_per_atom_{model_key}"]
227+
227228
elif len(pred_cols := df.filter(like="_pred_ens").columns) > 0:
228229
assert len(pred_cols) == 1
229230
df_out[model_name] = df[pred_cols[0]]
230231
if len(std_cols := df.filter(like="_std_ens").columns) > 0:
231232
df_out[f"{model_name}_std"] = df[std_cols[0]]
233+
232234
elif len(pred_cols := df.filter(like=r"_pred_").columns) > 1:
233235
# make sure we average the expected number of ensemble member predictions
234236
assert len(pred_cols) == 10, f"{len(pred_cols) = }, expected 10"
235237
df_out[model_name] = df[pred_cols].mean(axis=1)
238+
236239
elif "e_form_per_atom_voronoi_rf" in df: # new voronoi
237240
df_out[model_name] = df.e_form_per_atom_voronoi_rf
241+
238242
elif "e_form_pred" in df: # old voronoi
239243
df_out[model_name] = df.e_form_pred
244+
240245
else:
241246
raise ValueError(
242247
f"No pred col for {model_name=}, available cols={list(df)}"

matbench_discovery/plots.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
margin=dict(l=30, r=20, t=60, b=20),
6262
paper_bgcolor="rgba(0,0,0,0)",
6363
# plot_bgcolor="rgba(0,0,0,0)",
64-
font_size=15,
64+
font_size=13,
6565
)
6666
pio.templates["global"] = dict(layout=global_layout)
6767
pio.templates.default = "plotly_dark+global"
@@ -181,7 +181,7 @@ def hist_classified_stable_vs_hull_dist(
181181
# add moving average of the accuracy computed within given window
182182
# as a function of e_above_hull shown as blue line (right axis)
183183
ax_acc = ax.twinx()
184-
ax_acc.set_ylabel("Accuracy", color="darkblue")
184+
ax_acc.set_ylabel("Rolling Accuracy", color="darkblue")
185185
ax_acc.tick_params(labelcolor="darkblue")
186186
ax_acc.set(ylim=(0, 1))
187187

models/bowsr/join_bowsr_results.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
# %%
1818
module_dir = os.path.dirname(__file__)
1919
task_type = "IS2RE"
20-
date = "2022-11-22"
21-
glob_pattern = f"{date}-bowsr-megnet-wbm-{task_type}/*.json.gz"
20+
date = "2023-01-20"
21+
energy_model = "megnet"
22+
glob_pattern = f"{date}-bowsr-{energy_model}-wbm-{task_type}/*.json.gz"
2223
file_paths = sorted(glob(f"{module_dir}/{glob_pattern}"))
2324
print(f"Found {len(file_paths):,} files for {glob_pattern = }")
2425

@@ -43,7 +44,17 @@
4344
df_wbm = pd.read_csv(data_path).set_index("material_id")
4445

4546

46-
print(f"{len(df_bowsr):,} - {len(df_wbm):,} = {len(df_bowsr) - len(df_wbm) = :,}")
47+
print(
48+
f"{len(df_bowsr) - len(df_wbm) = :,} missing ({len(df_bowsr):,} - {len(df_wbm):,})"
49+
)
50+
51+
52+
# %% sanity check: since BOWSR uses MEGNet as energy model, final BOWSR energy and MEGNet
53+
# formation energy should be the same
54+
pymatviz.density_scatter(
55+
x=df_bowsr.e_form_per_atom_bowsr_megnet,
56+
y=df_bowsr[f"energy_bowsr_{energy_model}"],
57+
)
4758

4859

4960
# %%

models/bowsr/test_bowsr.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@
2828

2929
task_type = "IS2RE" # "RS2RE"
3030
module_dir = os.path.dirname(__file__)
31-
# --mem 12000 avoids slurmstepd: error: Detected 1 oom-kill event(s)
32-
# Some of your processes may have been killed by the cgroup out-of-memory handler.
33-
slurm_mem_per_node = 12000
3431
# set large job array size for fast testing/debugging
3532
slurm_array_task_count = 500
3633
# see https://stackoverflow.com/a/55431306 for how to change array throttling
@@ -45,12 +42,14 @@
4542
slurm_vars = slurm_submit(
4643
job_name=job_name,
4744
out_dir=out_dir,
48-
partition="icelake-himem",
45+
partition="skylake",
4946
account="LEE-SL3-CPU",
5047
time="12:0:0",
5148
# --time 2h is probably enough but best be safe.
5249
array=f"1-{slurm_array_task_count}%{slurm_max_parallel}",
53-
slurm_flags=("--mem", str(slurm_mem_per_node)),
50+
# --mem 12000 avoids slurmstepd: error: Detected 1 oom-kill event(s)
51+
# Some of your processes may have been killed by the cgroup out-of-memory handler.
52+
slurm_flags=("--mem", str(12_000)),
5453
# TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
5554
# https://stackoverflow.com/a/40982782
5655
pre_cmd="TF_CPP_MIN_LOG_LEVEL=2",
@@ -141,7 +140,7 @@
141140
structure_bowsr
142141
),
143142
"structure_bowsr": structure_bowsr,
144-
"energy_bowsr": energy_bowsr,
143+
f"energy_bowsr_{energy_model}": energy_bowsr,
145144
}
146145

147146
relax_results[material_id] = results

models/m3gnet/test_m3gnet.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
module_dir = os.path.dirname(__file__)
3030
# set large job array size for fast testing/debugging
3131
slurm_array_task_count = 100
32-
slurm_mem_per_node = 12000
3332
job_name = f"m3gnet-wbm-{task_type}{'-debug' if DEBUG else ''}"
3433
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3534

@@ -40,7 +39,7 @@
4039
account="LEE-SL3-CPU",
4140
time="3:0:0",
4241
array=f"1-{slurm_array_task_count}",
43-
slurm_flags=("--mem", str(slurm_mem_per_node)),
42+
slurm_flags=("--mem", str(12_000)),
4443
# TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
4544
# https://stackoverflow.com/a/40982782
4645
pre_cmd="TF_CPP_MIN_LOG_LEVEL=2",

readme.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ Matbench Discovery
1313

1414
</h4>
1515

16-
Matbench Discovery is an [interactive leaderboard](https://matbench-discovery.janosh.dev/figures) and associated [PyPI package](https://pypi.org/project/matbench-discovery) for benchmarking ML energy models on a task designed to closely emulate a real-world computational materials discovery workflow. In it, these models take on the role of a triaging step prior to DFT to determine how to allocate limited compute budget for structure relaxations.
16+
Matbench Discovery is an [interactive leaderboard](https://matbench-discovery.janosh.dev/figures) and associated [PyPI package](https://pypi.org/project/matbench-discovery) for benchmarking ML energy models on a task designed to closely emulate a real-world computational materials discovery workflow. In it, these models take on the role of a triaging step prior to DFT to decide how to allocate limited compute budget for structure relaxations.
1717

18-
We welcome contributions that add new models to the leaderboard through [GitHub PRs](https://github.com/janosh/matbench-discovery/pulls). See the [usage and contributing guide](https://janosh.github.io/matbench-discovery/how-to-contribute).
18+
We welcome contributions that add new models to the leaderboard through [GitHub PRs](https://github.com/janosh/matbench-discovery/pulls). See the [usage and contributing guide](https://janosh.github.io/matbench-discovery/how-to-contribute) for details.
1919

2020
Several new energy models specifically designed to handle unrelaxed structures were published in 2021/22
2121

scripts/cumulative_clf_metrics.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pandas as pd
33
from pymatviz.utils import save_fig
44

5-
from matbench_discovery import FIGS, today
5+
from matbench_discovery import STATIC, today
66
from matbench_discovery.data import load_df_wbm_with_preds
77
from matbench_discovery.plots import cumulative_precision_recall
88

@@ -12,8 +12,8 @@
1212

1313
# %%
1414
models = (
15-
# Wren, CGCNN IS2RE, CGCNN RS2RE
16-
"Voronoi RF, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet, CGCNN, CGCNN debug"
15+
# Wren, CGCNN IS2RE, CGCNN RS2RE, CGCNN
16+
"Voronoi RF, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet"
1717
).split(", ")
1818

1919
df_wbm = load_df_wbm_with_preds(models=models).round(3)
@@ -37,17 +37,21 @@
3737
show_optimal=True,
3838
)
3939

40-
title = f"{today} - Cumulative Precision, Recall and F1 Score for Stable Materials"
40+
title = f"{today} - Cumulative Precision, Recall, F1 scores for classifying stable materials"
4141
# xlabel_cumulative = "Materials predicted stable sorted by hull distance"
4242
if backend == "matplotlib":
4343
fig.suptitle(title)
4444
# fig.text(0.5, -0.08, xlabel_cumulative, ha="center", fontdict={"size": 16})
4545
elif backend == "plotly":
46-
fig.update_layout(title=title)
46+
# place legend in lower right corner
47+
fig.update_layout(
48+
title=title,
49+
legend=dict(yanchor="bottom", y=0.02, xanchor="right", x=1),
50+
)
4751
fig.update_xaxes(matches=None, showticklabels=True)
4852
fig.update_yaxes(matches=None, showticklabels=True)
4953

50-
fig.show(config=dict(responsive=True))
54+
fig.show()
5155

5256

5357
# %%
@@ -57,6 +61,7 @@
5761
assert isinstance(trace.y[0], float)
5862
trace.y = [round(y, 3) for y in trace.y]
5963

60-
img_path = f"{FIGS}/{today}-cumulative-clf-metrics"
64+
img_path = f"{STATIC}/{today}-cumulative-clf-metrics"
6165
# save_fig(fig, f"{img_path}.pdf")
6266
save_fig(fig, f"{img_path}.svelte")
67+
# save_fig(fig, f"{img_path}.png", scale=3)

0 commit comments

Comments
 (0)