Skip to content

Commit fc20ccf

Browse files
committed
make wide tables horizontally scrollable on mobile screens
add flex-wrap: wrap; to side-by-side figures to get line breaks on mobile; add explanation for CHGNet/M3GNet difference in caption of fig:cumulative-mae-rmse
1 parent 8798786 commit fc20ccf

File tree

14 files changed

+58
-38
lines changed

14 files changed

+58
-38
lines changed

matbench_discovery/plots.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def unit(text: str) -> str:
7070
cgcnn="CGCNN",
7171
m3gnet_megnet="M3GNet + MEGNet",
7272
m3gnet="M3GNet",
73+
m3gnet_directs="M3GNet DIRECTS",
7374
megnet="MEGNet",
7475
voronoi_rf="Voronoi RF",
7576
wrenformer="Wrenformer",
@@ -841,7 +842,7 @@ def df_to_svelte_table(
841842
styler: Styler,
842843
file_path: str | Path,
843844
inline_props: str = "",
844-
styles: str = "",
845+
styles: str = "table { overflow: scroll; max-width: 100%; display: block; }",
845846
**kwargs: Any,
846847
) -> None:
847848
"""Convert a pandas Styler to a svelte table.

matbench_discovery/preds.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class PredFiles(Files):
4444

4545
# original M3GNet straight from publication, not re-trained
4646
m3gnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
47+
# m3gnet_directs = "m3gnet/2023-05-30-m3gnet-directs-wbm-IS2RE.csv"
4748

4849
# original MEGNet straight from publication, not re-trained
4950
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"

models/bowsr/join_bowsr_results.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,14 @@
6666

6767

6868
# %%
69-
out_path = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}.json.gz"
70-
df_bowsr.reset_index().to_json(out_path, default_handler=lambda x: x.as_dict())
69+
out_path = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}"
70+
df_bowsr = df_bowsr.round(4)
71+
# save energy and formation energy as fast-loading CSV
72+
df_bowsr.select_dtypes("number").to_csv(f"{out_path}.csv")
73+
df_bowsr.reset_index().to_json(
74+
f"{out_path}.json.gz", default_handler=lambda x: x.as_dict()
75+
)
7176

72-
# save energy and formation energy as CSV for fast loading
73-
df_bowsr.select_dtypes("number").to_csv(out_path.replace(".json.gz", ".csv"))
7477

7578
# in_path = f"{module_dir}/2023-01-23-bowsr-megnet-wbm-IS2RE.json.gz"
7679
# df_bowsr = pd.read_json(in_path).set_index("material_id")

models/chgnet/join_chgnet_results.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@
6464

6565

6666
# %%
67-
out_path = f"{module_dir}/{today}-chgnet-wbm-{task_type}.json.gz"
67+
out_path = f"{module_dir}/{today}-chgnet-wbm-{task_type}"
6868
df_chgnet = df_chgnet.round(4)
69-
df_chgnet.select_dtypes("number").to_csv(out_path.replace(".json.gz", ".csv"))
70-
df_chgnet.reset_index().to_json(out_path, default_handler=as_dict_handler)
69+
df_chgnet.select_dtypes("number").to_csv(f"{out_path}.csv")
70+
df_chgnet.reset_index().to_json(f"{out_path}.json.gz", default_handler=as_dict_handler)
7171

7272
# in_path = f"{module_dir}/2023-03-04-chgnet-wbm-IS2RE.json.gz"
7373
# df_chgnet = pd.read_csv(in_path.replace(".json.gz", ".csv")).set_index("material_id")

models/m3gnet/join_m3gnet_results.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,16 @@
3030
# %%
3131
module_dir = os.path.dirname(__file__)
3232
task_type = "IS2RE"
33-
date = "2022-10-31"
34-
glob_pattern = f"{date}-m3gnet-wbm-{task_type}/*.json.gz"
33+
date = "2023-05-30"
34+
model_type = "directs"
35+
glob_pattern = f"{date}-m3gnet-{model_type}-wbm-{task_type}/*.json.gz"
3536
file_paths = sorted(glob(f"{module_dir}/{glob_pattern}"))
3637
struct_col = "m3gnet_structure"
3738
print(f"Found {len(file_paths):,} files for {glob_pattern = }")
3839

39-
dfs: dict[str, pd.DataFrame] = {}
40+
# prevent accidental overwrites
41+
if "dfs" not in locals():
42+
dfs: dict[str, pd.DataFrame] = {}
4043

4144

4245
# %%
@@ -66,7 +69,7 @@
6669
for row in tqdm(df_m3gnet.itertuples(), total=len(df_m3gnet)):
6770
mat_id, struct_dict, m3gnet_energy, *_ = row
6871
m3gnet_struct = Structure.from_dict(struct_dict)
69-
df_m3gnet.loc[mat_id, struct_col] = m3gnet_struct
72+
df_m3gnet.at[mat_id, struct_col] = m3gnet_struct # noqa: PD008
7073
cse = df_cse.loc[mat_id, "cse"]
7174
cse._energy = m3gnet_energy # cse._energy is the uncorrected energy
7275
cse._structure = m3gnet_struct
@@ -81,7 +84,7 @@
8184

8285

8386
# %% compute corrected formation energies
84-
df_m3gnet["e_form_per_atom_m3gnet"] = [
87+
df_m3gnet[f"e_form_per_atom_m3gnet_{model_type}"] = [
8588
get_e_form_per_atom(cse) for cse in tqdm(df_m3gnet.cse)
8689
]
8790

@@ -93,11 +96,11 @@
9396

9497

9598
# %%
96-
out_path = f"{module_dir}/{today}-m3gnet-wbm-{task_type}.json.gz"
99+
out_path = f"{module_dir}/{today}-m3gnet-{model_type}-wbm-{task_type}"
97100
df_m3gnet = df_m3gnet.round(4)
98-
df_m3gnet.reset_index().to_json(out_path, default_handler=as_dict_handler)
101+
df_m3gnet.select_dtypes("number").to_csv(f"{out_path}.csv")
102+
df_m3gnet.reset_index().to_json(f"{out_path}.json.gz", default_handler=as_dict_handler)
99103

100-
df_m3gnet.select_dtypes("number").to_csv(out_path.replace(".json.gz", ".csv"))
101104

102105
# in_path = f"{module_dir}/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
103106
# df_m3gnet = pd.read_csv(in_path.replace(".json.gz", ".csv")).set_index("material_id")

models/m3gnet/test_m3gnet.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import os
1212
import warnings
1313
from importlib.metadata import version
14-
from typing import Any
14+
from typing import Any, Literal
1515

1616
import numpy as np
1717
import pandas as pd
@@ -20,7 +20,7 @@
2020
from pymatgen.core import Structure
2121
from tqdm import tqdm
2222

23-
from matbench_discovery import DEBUG, timestamp, today
23+
from matbench_discovery import DEBUG, ROOT, timestamp, today
2424
from matbench_discovery.data import DATA_FILES, as_dict_handler
2525
from matbench_discovery.slurm import slurm_submit
2626

@@ -29,9 +29,10 @@
2929

3030
task_type = "IS2RE" # "RS2RE"
3131
module_dir = os.path.dirname(__file__)
32+
model_type: Literal["orig", "direct", "manual-sampling"] = "manual-sampling"
3233
# set large job array size for smaller data splits and faster testing/debugging
3334
slurm_array_task_count = 100
34-
job_name = f"m3gnet-wbm-{task_type}{'-debug' if DEBUG else ''}"
35+
job_name = f"m3gnet-{model_type}-wbm-{task_type}{'-debug' if DEBUG else ''}"
3536
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3637

3738
slurm_vars = slurm_submit(
@@ -85,7 +86,12 @@
8586

8687

8788
# %%
88-
megnet = Relaxer() # load default pre-trained M3GNet model
89+
checkpoint = None
90+
if model_type == "direct":
91+
checkpoint = f"{ROOT}/models/m3gnet/2023-05-26-DI-DFTstrictF10-TTRS-128U-442E"
92+
if model_type == "manual-sampling":
93+
checkpoint = f"{ROOT}/models/m3gnet/2023-05-26-MS-DFTstrictF10-128U-154E"
94+
megnet = Relaxer(potential=checkpoint) # load pre-trained M3GNet model
8995
relax_results: dict[str, dict[str, Any]] = {}
9096
input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}[task_type]
9197

scripts/cumulative_metrics.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@
2727

2828

2929
# %%
30-
# metrics = ("Precision", "Recall")
31-
metrics = ("MAE", "RMSE")
30+
metrics = ("Precision", "Recall")
31+
# metrics = ("MAE", "RMSE")
3232
fig, df_metric = cumulative_metrics(
3333
e_above_hull_true=df_preds[each_true_col],
3434
df_preds=df_each_pred[models],
3535
project_end_point="xy",
3636
backend=(backend := "plotly"),
37-
range_y=(0, 0.4),
37+
range_y=(0, 1),
3838
metrics=metrics,
3939
# facet_col_wrap=2,
4040
# increase facet col gap

site/src/app.css

+3-1
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,10 @@ img {
122122
}
123123

124124
table {
125+
display: block;
126+
max-width: 100%;
127+
overflow: scroll;
125128
border-collapse: collapse;
126-
width: 100%;
127129
}
128130
table :is(td, th) {
129131
border: 1px solid gray;

site/src/routes/+layout.svelte

+3-3
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,16 @@
4646
document.documentElement.style.setProperty(`--main-max-width`, `50em`)
4747
}
4848
49-
for (const node of document.querySelectorAll('pre > code')) {
49+
for (const node of document.querySelectorAll(`pre > code`)) {
5050
// skip if <pre> already contains a button (presumably for copy)
5151
const pre = node.parentElement
5252
if (!pre || pre.querySelector(`button`)) continue
5353
5454
new CopyButton({
5555
target: pre,
5656
props: {
57-
content: node.textContent ?? '',
58-
style: 'position: absolute; top: 1ex; right: 1ex;',
57+
content: node.textContent ?? ``,
58+
style: `position: absolute; top: 1ex; right: 1ex;`,
5959
},
6060
})
6161
}

site/src/routes/about-the-data/+page.svelte

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@
107107
{/if}
108108
</svelte:fragment>
109109
<div
110-
style="display: flex; gap: 1em; justify-content: space-around;"
110+
style="display: flex; gap: 1em; justify-content: space-around; flex-wrap: wrap;"
111111
slot="spacegroup-sunbursts"
112112
>
113113
{#if browser}

site/src/routes/preprint/+page.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,13 @@ We welcome further model submissions at
241241

242242
## Acknowledgments
243243

244-
Janosh Riebesell acknowledges support from the German Academic Scholarship Foundation ([Studienstiftung](https://wikipedia.org/wiki/Studienstiftung)) and gracious hosting as a visiting affiliate in the groups of Kristin Persson and Anubhav Jain.
244+
Janosh Riebesell acknowledges support from the German Academic Scholarship Foundation ([Studienstiftung](https://wikipedia.org/wiki/Studienstiftung)).
245245

246-
We would like to thank Jason Blake Gibson, Shyue Ping Ong, Chi Chen, Tian Xie, Bowen Deng, Peichen Zhong, Ekin Dogus Cubuk for helpful discussions. We also thank Hai-Chen Wang and co-authors for creating and freely providing the WBM data set @wang_predicting_2021.
246+
A big thank you to
247247

248-
Thanks also to [@pbenner](https://github.com/pbenner) for [finding and reporting many bugs]({repo}/issues?q=is%3Aissue+author%3Apbenner+) in the data loading and caching routines prior to the v1 release.
248+
- Hai-Chen Wang and co-authors for creating and freely providing the WBM data set
249+
- Jason Blake Gibson, Shyue Ping Ong, Chi Chen, Tian Xie, Bowen Deng, Peichen Zhong, and Ekin Dogus Cubuk for helpful discussions
250+
- Philipp Benner ([@pbenner](https://github.com/pbenner)) for [finding and reporting many bugs]({repo}/issues?q=is%3Aissue+author%3Apbenner+) in the data loading routines before the v1 release.
249251

250252
## Author Contributions
251253

site/src/routes/preprint/iclr-ml4mat/+page.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,13 @@ We welcome further model submissions as well as data contributions for version 2
140140

141141
## Acknowledgments
142142

143-
Janosh Riebesell acknowledges support from the German Academic Scholarship Foundation ([Studienstiftung](https://wikipedia.org/wiki/Studienstiftung)) and gracious hosting as a visiting affiliate in the groups of Kristin Persson and Anubhav Jain.
143+
Janosh Riebesell acknowledges support from the German Academic Scholarship Foundation ([Studienstiftung](https://wikipedia.org/wiki/Studienstiftung)).
144144

145-
We would like to thank Jason Blake Gibson, Shyue Ping Ong, Chi Chen, Tian Xie, Bowen Deng, Peichen Zhong, Ekin Dogus Cubuk for helpful discussions. We also thank Hai-Chen Wang and co-authors for creating and freely providing the WBM data set.
145+
A big thank you to
146146

147-
Thanks also to [@pbenner](https://github.com/pbenner) for [finding and reporting many bugs]({repo}/issues?q=is%3Aissue+author%3Apbenner+) in the data loading and caching routines prior to the v1 release.
147+
- Hai-Chen Wang and co-authors for creating and freely providing the WBM data set
148+
- Jason Blake Gibson, Shyue Ping Ong, Chi Chen, Tian Xie, Bowen Deng, Peichen Zhong, and Ekin Dogus Cubuk for helpful discussions
149+
- Philipp Benner ([@pbenner](https://github.com/pbenner)) for [finding and reporting many bugs]({repo}/issues?q=is%3Aissue+author%3Apbenner+) in the data loading routines before the v1 release.
148150

149151
## Author Contributions
150152

site/src/routes/si/+page.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
{/if}
5151

5252
> @label:fig:cumulative-mae-rmse Cumulative mean absolute error (MAE) and root mean square error (RMSE) during a simulated discovery campaign. This figure expands on the [precision-recall figure](/preprint#fig:cumulative-precision-recall). The $x$-axis again shows the number of materials sorted by model-predicted stability, or 'campaign length'. This allows the reader to choose a cutoff point given their discovery campaign's resource constraints for validating model predictions and then read off the optimal model given those constraints.
53-
> CHGNet achieves the lowest regression error profile, with a larger gap to the runner-up model M3GNet than in the precision-recall plots.
53+
> CHGNet achieves the lowest regression error profile, with a larger gap to the runner-up model M3GNet than in the precision-recall plots. This is likely due to the difference in the trade-off between true positive rate (TPR) and true negative rate (TNR) for CHGNet and M3GNet: M3GNet has TNR = 0.80 vs. CHGNet's TNR = 0.87. A higher TNR means a lower false positive rate (FPR = 1 - TNR), and a lower FPR means lower cumulative MAE and RMSE. Lines end when models stop predicting materials as stable, so these cumulative plots only contain model-predicted positive (stable) materials. Besides the high opportunity cost of false positives, this highlights another reason to prioritize low FPR in discovery models: lower error on the predictions of highest relevance.
5454
5555
## Model Run Times
5656

@@ -134,11 +134,11 @@ We highlight this here to refute the suggestion that training on raw DFT energie
134134

135135
{#if mounted}
136136

137-
<div style="display: flex; gap: 1em; justify-content: space-around;">
137+
<div style="display: flex; gap: 1em; justify-content: space-around; flex-wrap: wrap;">
138138
<SpacegroupSunburstWrenformerFailures />
139139
<SpacegroupSunburstWbm />
140140
</div>
141-
<EAboveHullScatterWrenformerFailures style="height: 300; width: 300;" />
141+
<EAboveHullScatterWrenformerFailures />
142142
{/if}
143143

144144
> @label:fig:spacegroup-prevalence-wrenformer-failures The left spacegroup sunburst shows spacegroup 71 is by far the dominant lattice symmetry among the 941 Wrenformer failure cases where $E_\text{above hull,DFT} < 1$ and $E_\text{above hull,Wrenformer} > 1$ (points inside the shaded rectangle). On the right side for comparison is the spacegroup sunburst for the entire WBM test set.

site/src/routes/si/largest-error-scatter-select.svelte

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
const figs = import.meta.glob(
77
`$figs/scatter-largest-errors-models-mean-vs-true-hull-dist-*.svelte`,
8-
{ eager: true, import: 'default' }
8+
{ eager: true, import: `default` }
99
)
1010
1111
let selected: string[] = [Object.keys(figs)[0]]

0 commit comments

Comments
 (0)