Skip to content

Commit a549532

Browse files
authored
Add MACE (#48)
* add scripts/model_figs/update_all_model_figs.py — goal is to be able to run a single script that updates all the model figures (both the interactive version for the leaderboard and the PDF version for the paper)
* prepare MACE submission (readme.md, metadata.yml, test_mace.py)
* fix model name still set to chgnet in test_mace.py
* wandb: collect all dep versions in a single dict
* delete matbench_discovery.DEBUG global
* roc_prc_curves_models.py: fix n_rows x n_cols in out filename
* test_mace.py: add relax trajectory recording
* add MACE + ALIGNN checkpoint figshare URLs to class DataFiles
* rename "BOWSR + MEGNet" -> "BOWSR"
* ensure out_path matches glob_pattern in join scripts
* load_df_wbm_with_preds(): use 1st matching df column
* refactor df_to_pdf() from wkhtmltopdf to weasyprint
* improve Figshare description
* update most figures with MACE results
* extract scripts/model_figs/per_element_errors.py out of scripts/analyze_element_errors.py to run the former as part of update_model_figs.py
* add MACE to site/src/figs/model-run-times-bar.svelte
* revert "extract scripts/model_figs/per_element_errors.py out of scripts/analyze_element_errors.py"
* add weasyprint to df-pdf-export optional deps
* fix test_df_metrics failing from MACE R^2 of -1.291 being below -0.9
* fix df_to_pdf if weasyprint not installed
1 parent 6696d22 commit a549532

File tree

78 files changed

+1264
-1004
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

78 files changed

+1264
-1004
lines changed

data/figshare/1.0.0.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
11
{
2+
"alignn_checkpoint": [
3+
"https://figshare.com/ndownloader/files/41233560",
4+
"2023-06-02-pbenner-best-alignn-model.pth.zip"
5+
],
6+
"mace_checkpoint": [
7+
"https://figshare.com/ndownloader/files/41565618",
8+
"2023-07-14-mace-universal-2-big-128-6.model"
9+
],
210
"mp_computed_structure_entries": [
311
"https://figshare.com/ndownloader/files/40344436",
412
"2023-02-07-mp-computed-structure-entries.json.gz"

data/mp/build_phase_diagram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
df = pd.read_json(data_path).set_index("material_id")
4444

4545
# drop the structure, just load ComputedEntry, makes the PPD faster to build and load
46-
mp_computed_entries = [ComputedEntry.from_dict(x) for x in tqdm(df.entry)]
46+
mp_computed_entries = [ComputedEntry.from_dict(dct) for dct in tqdm(df.entry)]
4747

4848
print(f"{len(mp_computed_entries) = :,} on {today}")
4949
# len(mp_computed_entries) = 146,323 on 2022-09-16

data/wbm/compare_cse_vs_ce_mp_2020_corrections.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@
2828
)
2929

3030
cses = [
31-
ComputedStructureEntry.from_dict(x) for x in tqdm(df_cse.computed_structure_entry)
31+
ComputedStructureEntry.from_dict(dct)
32+
for dct in tqdm(df_cse.computed_structure_entry)
3233
]
3334

34-
ces = [ComputedEntry.from_dict(x) for x in tqdm(df_cse.computed_structure_entry)]
35+
ces = [ComputedEntry.from_dict(dct) for dct in tqdm(df_cse.computed_structure_entry)]
3536

3637

3738
warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")

data/wbm/fetch_process_wbm_dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,8 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
502502
assert mat_id == cse["entry_id"], f"{mat_id} != {cse['entry_id']}"
503503

504504
df_wbm["cse"] = [
505-
ComputedStructureEntry.from_dict(x) for x in tqdm(df_wbm.computed_structure_entry)
505+
ComputedStructureEntry.from_dict(dct)
506+
for dct in tqdm(df_wbm.computed_structure_entry)
506507
]
507508
# raw WBM ComputedStructureEntries have no energy corrections applied:
508509
assert all(cse.uncorrected_energy == cse.energy for cse in df_wbm.cse)
@@ -640,6 +641,6 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
640641
).set_index("material_id")
641642

642643
df_wbm["cse"] = [
643-
ComputedStructureEntry.from_dict(x)
644-
for x in tqdm(df_wbm.computed_structure_entry)
644+
ComputedStructureEntry.from_dict(dct)
645+
for dct in tqdm(df_wbm.computed_structure_entry)
645646
]

matbench_discovery/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Global variables used all across the matbench_discovery package."""
22

33
import os
4-
import sys
54
from datetime import datetime
65

76
ROOT = os.path.dirname(os.path.dirname(__file__)) # repo root directory
@@ -13,10 +12,6 @@
1312
for directory in [FIGS, MODELS, FIGSHARE, PDF_FIGS]:
1413
os.makedirs(directory, exist_ok=True)
1514

16-
# whether a currently running slurm job is in debug mode
17-
DEBUG = "DEBUG" in os.environ or (
18-
"slurm-submit" not in sys.argv and "SLURM_JOB_ID" not in os.environ
19-
)
2015
# directory to store model checkpoints downloaded from wandb cloud storage
2116
CHECKPOINT_DIR = f"{ROOT}/wandb/checkpoints"
2217
# wandb <entity>/<project name> to record new runs to

matbench_discovery/data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override]
232232
if answer == "y":
233233
load(key) # download and cache data file
234234

235+
# TODO maybe set attrs to None and load file names from Figshare json
235236
mp_computed_structure_entries = (
236237
"mp/2023-02-07-mp-computed-structure-entries.json.gz"
237238
)
@@ -246,6 +247,8 @@ def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override]
246247
"wbm/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
247248
)
248249
wbm_summary = "wbm/2022-10-19-wbm-summary.csv.gz"
250+
alignn_checkpoint = "2023-06-02-pbenner-best-alignn-model.pth.zip"
251+
mace_checkpoint = "2023-07-14-mace-universal-2-big-128-6.model"
249252

250253

251254
# data files can be downloaded and cached with matbench_discovery.data.load()

matbench_discovery/plots.py

Lines changed: 68 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import math
66
import os
7+
import subprocess
78
from collections import defaultdict
89
from collections.abc import Sequence
910
from pathlib import Path
@@ -65,7 +66,7 @@ def unit(text: str) -> str:
6566
model_labels = dict(
6667
alignn="ALIGNN",
6768
alignn_pretrained="ALIGNN Pretrained",
68-
bowsr_megnet="BOWSR + MEGNet",
69+
bowsr_megnet="BOWSR",
6970
chgnet="CHGNet",
7071
chgnet_megnet="CHGNet + MEGNet",
7172
cgcnn_p="CGCNN+P",
@@ -74,6 +75,7 @@ def unit(text: str) -> str:
7475
m3gnet="M3GNet",
7576
m3gnet_direct="M3GNet DIRECT",
7677
m3gnet_ms="M3GNet MS",
78+
mace="MACE",
7779
megnet="MEGNet",
7880
voronoi_rf="Voronoi RF",
7981
wrenformer="Wrenformer",
@@ -874,38 +876,81 @@ def df_to_svelte_table(
874876
def df_to_pdf(
875877
styler: Styler, file_path: str | Path, crop: bool = True, **kwargs: Any
876878
) -> None:
877-
"""Export a pandas Styler to PDF.
879+
"""Export a pandas Styler to PDF with WeasyPrint.
878880
879881
Args:
880882
styler (Styler): Styler object to export.
881-
file_path (str): Path to save the PDF to. Requires pdfkit.
882-
crop (bool): Whether to crop the PDF margins. Requires pdfCropMargins. Defaults
883-
to True.
883+
file_path (str): Path to save the PDF to. Requires WeasyPrint.
884+
crop (bool): Whether to crop the PDF margins. Requires pdfCropMargins.
885+
Defaults to True.
884886
**kwargs: Keyword arguments passed to Styler.to_html().
885887
"""
886888
try:
887-
# pdfkit used to export pandas Styler to PDF, requires:
888-
# pip install pdfkit && brew install homebrew/cask/wkhtmltopdf
889-
import pdfkit
889+
from weasyprint import HTML
890890
except ImportError as exc:
891-
raise ImportError(
892-
"pdfkit not installed\nrun pip install pdfkit && brew install "
893-
"homebrew/cask/wkhtmltopdf\n(brew is macOS only, use apt on linux)"
894-
) from exc
895-
896-
pdfkit.from_string(styler.to_html(**kwargs), file_path)
897-
if not crop:
898-
return
891+
msg = "weasyprint not installed\nrun pip install weasyprint"
892+
raise ImportError(msg) from exc
893+
894+
html_str = styler.to_html(**kwargs)
895+
896+
# CSS to adjust layout and margins
897+
html_str = f"""
898+
<style>
899+
@page {{ size: landscape; margin: 1cm; }}
900+
body {{ margin: 0; padding: 1em; }}
901+
</style>
902+
{html_str}
903+
"""
904+
905+
html = HTML(string=html_str)
906+
907+
html.write_pdf(file_path)
908+
909+
if crop:
910+
normalize_and_crop_pdf(file_path)
911+
912+
913+
def normalize_and_crop_pdf(file_path: str | Path) -> None:
914+
"""Normalize a PDF using Ghostscript and then crop it.
915+
Without gs normalization, pdfCropMargins sometimes corrupts the PDF.
916+
917+
Args:
918+
file_path (str | Path): Path to the PDF file.
919+
"""
899920
try:
900-
# needed to auto-crop large white margins from PDF
901-
# pip install pdfCropMargins
902-
from pdfCropMargins import crop as crop_pdf
921+
normalized_file_path = f"{file_path}_normalized.pdf"
922+
from pdfCropMargins import crop
923+
924+
# Normalize the PDF with Ghostscript
925+
subprocess.run(
926+
[
927+
"gs",
928+
"-sDEVICE=pdfwrite",
929+
"-dCompatibilityLevel=1.4",
930+
"-dPDFSETTINGS=/default",
931+
"-dNOPAUSE",
932+
"-dQUIET",
933+
"-dBATCH",
934+
f"-sOutputFile={normalized_file_path}",
935+
str(file_path),
936+
]
937+
)
903938

904-
# Remove PDF margins
905-
cropped_file_path, _exit_code, _stdout, _stderr = crop_pdf(
906-
["--percentRetain", "0", file_path]
939+
# Crop the normalized PDF
940+
cropped_file_path, exit_code, stdout, stderr = crop(
941+
["--percentRetain", "0", normalized_file_path]
907942
)
908-
os.replace(cropped_file_path, file_path)
943+
944+
if stderr:
945+
print(f"pdfCropMargins {stderr=}")
946+
# something went wrong, remove the cropped PDF
947+
os.remove(cropped_file_path)
948+
else:
949+
# replace the original PDF with the cropped one
950+
os.replace(cropped_file_path, str(file_path))
951+
952+
os.remove(normalized_file_path)
953+
909954
except ImportError as exc:
910955
msg = "pdfCropMargins not installed\nrun pip install pdfCropMargins"
911956
raise ImportError(msg) from exc

matbench_discovery/preds.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class PredFiles(Files):
4747
# m3gnet_direct = "m3gnet/2023-05-30-m3gnet-direct-wbm-IS2RE.csv.gz"
4848
# m3gnet_ms = "m3gnet/2023-06-01-m3gnet-manual-sampling-wbm-IS2RE.csv.gz"
4949

50+
# MACE trained on original M3GNet training set
51+
mace = "mace/2023-07-23-mace-wbm-IS2RE-FIRE.csv.gz"
52+
5053
# original MEGNet straight from publication, not re-trained
5154
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE.csv.gz"
5255
# CHGNet-relaxed structures fed into MEGNet for formation energy prediction
@@ -106,8 +109,15 @@ def load_df_wbm_with_preds(
106109
df_out = df_wbm.copy()
107110
for model_name, df in dfs.items():
108111
model_key = model_name.lower().replace(" + ", "_").replace(" ", "_")
109-
if (col := f"e_form_per_atom_{model_key}") in df:
110-
df_out[model_name] = df[col]
112+
113+
cols = [col for col in df if col.startswith(f"e_form_per_atom_{model_key}")]
114+
if cols:
115+
if len(cols) > 1:
116+
print(
117+
f"Warning: multiple pred cols for {model_name=}, using {cols[0]!r} "
118+
f"out of {cols=}"
119+
)
120+
df_out[model_name] = df[cols[0]]
111121

112122
elif pred_cols := list(df.filter(like="_pred_ens")):
113123
assert len(pred_cols) == 1

models/alignn/test_alignn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from sklearn.metrics import r2_score
1818
from tqdm import tqdm
1919

20-
from matbench_discovery import DEBUG, today
20+
from matbench_discovery import today
2121
from matbench_discovery.data import DATA_FILES, df_wbm
2222
from matbench_discovery.plots import wandb_scatter
2323
from matbench_discovery.slurm import slurm_submit
@@ -36,7 +36,7 @@
3636
input_col = "initial_structure"
3737
id_col = "material_id"
3838
device = "cuda" if torch.cuda.is_available() else "cpu"
39-
job_name = f"{model_name}-wbm-{task_type}{'-debug' if DEBUG else ''}"
39+
job_name = f"{model_name}-wbm-{task_type}"
4040
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
4141

4242

@@ -85,15 +85,15 @@
8585
assert input_col in df_in, f"{input_col=} not in {list(df_in)}"
8686

8787
df_in[input_col] = [
88-
JarvisAtomsAdaptor.get_atoms(Structure.from_dict(x))
89-
for x in tqdm(df_in[input_col], leave=False, desc="Converting to JARVIS atoms")
88+
JarvisAtomsAdaptor.get_atoms(Structure.from_dict(dct))
89+
for dct in tqdm(df_in[input_col], leave=False, desc="Converting to JARVIS atoms")
9090
]
9191

9292

9393
# %%
9494
run_params = dict(
9595
data_path=data_path,
96-
**{f"{dep}_version": version(dep) for dep in ("megnet", "numpy")},
96+
versions={dep: version(dep) for dep in ("megnet", "numpy")},
9797
model_name=model_name,
9898
task_type=task_type,
9999
target_col=target_col,

models/alignn/train_alignn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torch.utils.data import DataLoader
1919
from tqdm import tqdm
2020

21-
from matbench_discovery import DEBUG, today
21+
from matbench_discovery import today
2222
from matbench_discovery.data import DATA_FILES
2323
from matbench_discovery.slurm import slurm_submit
2424

@@ -35,7 +35,7 @@
3535
input_col = "atoms"
3636
id_col = "material_id"
3737
device = "cuda" if torch.cuda.is_available() else "cpu"
38-
job_name = f"train-{model_name}{'-debug' if DEBUG else ''}"
38+
job_name = f"train-{model_name}"
3939

4040

4141
pred_col = "e_form_per_atom_alignn"
@@ -49,7 +49,7 @@
4949
slurm_vars = slurm_submit(
5050
job_name=job_name,
5151
# partition="perlmuttter",
52-
account="matgen_g",
52+
account="matgen",
5353
time="4:0:0",
5454
out_dir=out_dir,
5555
slurm_flags="--qos regular --constraint gpu --gpus 1",
@@ -79,7 +79,7 @@
7979
# %%
8080
run_params = dict(
8181
data_path=DATA_FILES.mp_energies,
82-
**{f"{dep}_version": version(dep) for dep in ("alignn", "numpy", "torch", "dgl")},
82+
versions={dep: version(dep) for dep in ("alignn", "numpy", "torch", "dgl")},
8383
model_name=model_name,
8484
target_col=target_col,
8585
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),

models/bowsr/join_bowsr_results.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pymatviz
99
from tqdm import tqdm
1010

11-
from matbench_discovery import today
1211
from matbench_discovery.data import DATA_FILES
1312

1413
__author__ = "Janosh Riebesell"
@@ -66,7 +65,7 @@
6665

6766

6867
# %%
69-
out_path = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}"
68+
out_path = f"{module_dir}/{glob_pattern.split('/*')[0]}"
7069
df_bowsr = df_bowsr.round(4)
7170
# save energy and formation energy as fast-loading CSV
7271
df_bowsr.select_dtypes("number").to_csv(f"{out_path}.csv")

models/bowsr/metadata.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
model_name: BOWSR + MEGNet
1+
model_name: BOWSR
22
model_version: 2022.9.20
33
matbench_discovery_version: 1.0
44
date_added: "2022-11-17"

0 commit comments

Comments (0)