Skip to content

Add MACE #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into main from mace
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c8b6092
add scripts/model_figs/update_all_model_figs.py
janosh Jul 14, 2023
4816fcc
prepare MACE submission (readme.md, metadata.yml, test_mace.py)
janosh Jul 14, 2023
9ab0df2
fix model name still set to chgnet in test_mace.py
janosh Jul 18, 2023
8c5cb34
wandb collect all dep versions in single dict
janosh Jul 23, 2023
aa7d281
delete matbench_discovery.DEBUG global
janosh Jul 23, 2023
8891906
roc_prc_curves_models.py fix n_rows x n_cols in out filename
janosh Jul 23, 2023
081a09e
test_mace.py add relax trajectory recording
janosh Jul 23, 2023
df50ce7
add MACE + ALIGNN checkpoints figshare urls to class DataFiles
janosh Jul 23, 2023
6911cf2
rename BOWSR + MEGnet -> BOWSR
janosh Jul 24, 2023
08c0934
ensure out_path matches glob_pattern in join scripts
janosh Jul 24, 2023
c90febb
load_df_wbm_with_preds() use 1st matching df column
janosh Jul 25, 2023
52f9940
refactor df_to_pdf() from wkhtmltopdf to weasyprint
janosh Jul 26, 2023
26cff30
improve Figshare description
janosh Jul 26, 2023
e3ddee2
update most figures with MACE results
janosh Jul 26, 2023
ff975ec
extract scripts/model_figs/per_element_errors.py out of scripts/analy…
janosh Jul 26, 2023
da4b617
Merge branch 'main' into mace
janosh Jul 26, 2023
c3d0fd3
add MACE to site/src/figs/model-run-times-bar.svelte
janosh Jul 26, 2023
80dcc7e
revert 'extract scripts/model_figs/per_element_errors.py out of scrip…
janosh Jul 26, 2023
da6e2a2
add weasyprint to df-pdf-export optional deps
janosh Jul 26, 2023
fde6ee6
fix test_df_metrics failing from MACE R^2 of -1.291 being below -0.9
janosh Jul 26, 2023
7f9cb74
fix df_to_pdf if weasyprint not installed
janosh Jul 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions data/figshare/1.0.0.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
{
"alignn_checkpoint": [
"https://figshare.com/ndownloader/files/41233560",
"2023-06-02-pbenner-best-alignn-model.pth.zip"
],
"mace_checkpoint": [
"https://figshare.com/ndownloader/files/41565618",
"2023-07-14-mace-universal-2-big-128-6.model"
],
"mp_computed_structure_entries": [
"https://figshare.com/ndownloader/files/40344436",
"2023-02-07-mp-computed-structure-entries.json.gz"
Expand Down
2 changes: 1 addition & 1 deletion data/mp/build_phase_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
df = pd.read_json(data_path).set_index("material_id")

# drop the structure, just load ComputedEntry, makes the PPD faster to build and load
mp_computed_entries = [ComputedEntry.from_dict(x) for x in tqdm(df.entry)]
mp_computed_entries = [ComputedEntry.from_dict(dct) for dct in tqdm(df.entry)]

print(f"{len(mp_computed_entries) = :,} on {today}")
# len(mp_computed_entries) = 146,323 on 2022-09-16
Expand Down
5 changes: 3 additions & 2 deletions data/wbm/compare_cse_vs_ce_mp_2020_corrections.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
)

cses = [
ComputedStructureEntry.from_dict(x) for x in tqdm(df_cse.computed_structure_entry)
ComputedStructureEntry.from_dict(dct)
for dct in tqdm(df_cse.computed_structure_entry)
]

ces = [ComputedEntry.from_dict(x) for x in tqdm(df_cse.computed_structure_entry)]
ces = [ComputedEntry.from_dict(dct) for dct in tqdm(df_cse.computed_structure_entry)]


warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")
Expand Down
7 changes: 4 additions & 3 deletions data/wbm/fetch_process_wbm_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,8 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
assert mat_id == cse["entry_id"], f"{mat_id} != {cse['entry_id']}"

df_wbm["cse"] = [
ComputedStructureEntry.from_dict(x) for x in tqdm(df_wbm.computed_structure_entry)
ComputedStructureEntry.from_dict(dct)
for dct in tqdm(df_wbm.computed_structure_entry)
]
# raw WBM ComputedStructureEntries have no energy corrections applied:
assert all(cse.uncorrected_energy == cse.energy for cse in df_wbm.cse)
Expand Down Expand Up @@ -640,6 +641,6 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
).set_index("material_id")

df_wbm["cse"] = [
ComputedStructureEntry.from_dict(x)
for x in tqdm(df_wbm.computed_structure_entry)
ComputedStructureEntry.from_dict(dct)
for dct in tqdm(df_wbm.computed_structure_entry)
]
5 changes: 0 additions & 5 deletions matbench_discovery/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Global variables used all across the matbench_discovery package."""

import os
import sys
from datetime import datetime

ROOT = os.path.dirname(os.path.dirname(__file__)) # repo root directory
Expand All @@ -13,10 +12,6 @@
for directory in [FIGS, MODELS, FIGSHARE, PDF_FIGS]:
os.makedirs(directory, exist_ok=True)

# whether a currently running slurm job is in debug mode
DEBUG = "DEBUG" in os.environ or (
"slurm-submit" not in sys.argv and "SLURM_JOB_ID" not in os.environ
)
# directory to store model checkpoints downloaded from wandb cloud storage
CHECKPOINT_DIR = f"{ROOT}/wandb/checkpoints"
# wandb <entity>/<project name> to record new runs to
Expand Down
3 changes: 3 additions & 0 deletions matbench_discovery/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override]
if answer == "y":
load(key) # download and cache data file

# TODO maybe set attrs to None and load file names from Figshare json
mp_computed_structure_entries = (
"mp/2023-02-07-mp-computed-structure-entries.json.gz"
)
Expand All @@ -246,6 +247,8 @@ def _on_not_found(self, key: str, msg: str) -> None: # type: ignore[override]
"wbm/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
)
wbm_summary = "wbm/2022-10-19-wbm-summary.csv.gz"
alignn_checkpoint = "2023-06-02-pbenner-best-alignn-model.pth.zip"
mace_checkpoint = "2023-07-14-mace-universal-2-big-128-6.model"


# data files can be downloaded and cached with matbench_discovery.data.load()
Expand Down
94 changes: 69 additions & 25 deletions matbench_discovery/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import math
import os
import subprocess
from collections import defaultdict
from collections.abc import Sequence
from pathlib import Path
Expand Down Expand Up @@ -65,7 +66,7 @@ def unit(text: str) -> str:
model_labels = dict(
alignn="ALIGNN",
alignn_pretrained="ALIGNN Pretrained",
bowsr_megnet="BOWSR + MEGNet",
bowsr_megnet="BOWSR",
chgnet="CHGNet",
chgnet_megnet="CHGNet + MEGNet",
cgcnn_p="CGCNN+P",
Expand All @@ -74,6 +75,7 @@ def unit(text: str) -> str:
m3gnet="M3GNet",
m3gnet_direct="M3GNet DIRECT",
m3gnet_ms="M3GNet MS",
mace="MACE",
megnet="MEGNet",
voronoi_rf="Voronoi RF",
wrenformer="Wrenformer",
Expand Down Expand Up @@ -874,38 +876,80 @@ def df_to_svelte_table(
def df_to_pdf(
styler: Styler, file_path: str | Path, crop: bool = True, **kwargs: Any
) -> None:
"""Export a pandas Styler to PDF.
"""Export a pandas Styler to PDF with WeasyPrint.

Args:
styler (Styler): Styler object to export.
file_path (str): Path to save the PDF to. Requires pdfkit.
crop (bool): Whether to crop the PDF margins. Requires pdfCropMargins. Defaults
to True.
file_path (str): Path to save the PDF to. Requires WeasyPrint.
crop (bool): Whether to crop the PDF margins. Requires pdfCropMargins.
Defaults to True.
**kwargs: Keyword arguments passed to Styler.to_html().
"""
# Convert Styler to HTML
from weasyprint import HTML

html_str = styler.to_html(**kwargs)

# Add CSS to adjust layout and margins
html_str = f"""
<style>
@page {{ size: landscape; margin: 1cm; }}
body {{ margin: 0; padding: 1em; }}
</style>
{html_str}
"""

# Create an HTML object from the HTML string
html = HTML(string=html_str)

# Write the HTML object to a PDF
html.write_pdf(file_path)

if crop:
normalize_and_crop_pdf(file_path)


def normalize_and_crop_pdf(file_path: str | Path) -> None:
"""Normalize a PDF using Ghostscript and then crop it.
Without gs normalization, pdfCropMargins sometimes corrupts the PDF.

Args:
file_path (str | Path): Path to the PDF file.
"""
try:
# pdfkit used to export pandas Styler to PDF, requires:
# pip install pdfkit && brew install homebrew/cask/wkhtmltopdf
import pdfkit
except ImportError as exc:
raise ImportError(
"pdfkit not installed\nrun pip install pdfkit && brew install "
"homebrew/cask/wkhtmltopdf\n(brew is macOS only, use apt on linux)"
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reason: wkhtmltopdf requires Rosetta, i.e. it does not run natively on M1 Macs.

) from exc

pdfkit.from_string(styler.to_html(**kwargs), file_path)
if not crop:
return
try:
# needed to auto-crop large white margins from PDF
# pip install pdfCropMargins
from pdfCropMargins import crop as crop_pdf
normalized_file_path = f"{file_path}_normalized.pdf"
from pdfCropMargins import crop

# Normalize the PDF with Ghostscript
subprocess.run(
[
"gs",
"-sDEVICE=pdfwrite",
"-dCompatibilityLevel=1.4",
"-dPDFSETTINGS=/default",
"-dNOPAUSE",
"-dQUIET",
"-dBATCH",
f"-sOutputFile={normalized_file_path}",
str(file_path),
]
)

# Remove PDF margins
cropped_file_path, _exit_code, _stdout, _stderr = crop_pdf(
["--percentRetain", "0", file_path]
# Crop the normalized PDF
cropped_file_path, exit_code, stdout, stderr = crop(
["--percentRetain", "0", normalized_file_path]
)
os.replace(cropped_file_path, file_path)

if stderr:
print(f"pdfCropMargins {stderr=}")
# something went wrong, remove the cropped PDF
os.remove(cropped_file_path)
else:
# replace the original PDF with the cropped one
os.replace(cropped_file_path, str(file_path))

os.remove(normalized_file_path)

except ImportError as exc:
msg = "pdfCropMargins not installed\nrun pip install pdfCropMargins"
raise ImportError(msg) from exc
Expand Down
14 changes: 12 additions & 2 deletions matbench_discovery/preds.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class PredFiles(Files):
# m3gnet_direct = "m3gnet/2023-05-30-m3gnet-direct-wbm-IS2RE.csv.gz"
# m3gnet_ms = "m3gnet/2023-06-01-m3gnet-manual-sampling-wbm-IS2RE.csv.gz"

# MACE trained on original M3GNet training set
mace = "mace/2023-07-23-mace-wbm-IS2RE-FIRE.csv.gz"

# original MEGNet straight from publication, not re-trained
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE.csv.gz"
# CHGNet-relaxed structures fed into MEGNet for formation energy prediction
Expand Down Expand Up @@ -106,8 +109,15 @@ def load_df_wbm_with_preds(
df_out = df_wbm.copy()
for model_name, df in dfs.items():
model_key = model_name.lower().replace(" + ", "_").replace(" ", "_")
if (col := f"e_form_per_atom_{model_key}") in df:
df_out[model_name] = df[col]

cols = [col for col in df if col.startswith(f"e_form_per_atom_{model_key}")]
if cols:
if len(cols) > 1:
print(
f"Warning: multiple pred cols for {model_name=}, using {cols[0]!r} "
f"out of {cols=}"
)
df_out[model_name] = df[cols[0]]

elif pred_cols := list(df.filter(like="_pred_ens")):
assert len(pred_cols) == 1
Expand Down
10 changes: 5 additions & 5 deletions models/alignn/test_alignn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from sklearn.metrics import r2_score
from tqdm import tqdm

from matbench_discovery import DEBUG, today
from matbench_discovery import today
from matbench_discovery.data import DATA_FILES, df_wbm
from matbench_discovery.plots import wandb_scatter
from matbench_discovery.slurm import slurm_submit
Expand All @@ -36,7 +36,7 @@
input_col = "initial_structure"
id_col = "material_id"
device = "cuda" if torch.cuda.is_available() else "cpu"
job_name = f"{model_name}-wbm-{task_type}{'-debug' if DEBUG else ''}"
job_name = f"{model_name}-wbm-{task_type}"
out_dir = os.getenv("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")


Expand Down Expand Up @@ -85,15 +85,15 @@
assert input_col in df_in, f"{input_col=} not in {list(df_in)}"

df_in[input_col] = [
JarvisAtomsAdaptor.get_atoms(Structure.from_dict(x))
for x in tqdm(df_in[input_col], leave=False, desc="Converting to JARVIS atoms")
JarvisAtomsAdaptor.get_atoms(Structure.from_dict(dct))
for dct in tqdm(df_in[input_col], leave=False, desc="Converting to JARVIS atoms")
]


# %%
run_params = dict(
data_path=data_path,
**{f"{dep}_version": version(dep) for dep in ("megnet", "numpy")},
versions={dep: version(dep) for dep in ("megnet", "numpy")},
model_name=model_name,
task_type=task_type,
target_col=target_col,
Expand Down
8 changes: 4 additions & 4 deletions models/alignn/train_alignn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm

from matbench_discovery import DEBUG, today
from matbench_discovery import today
from matbench_discovery.data import DATA_FILES
from matbench_discovery.slurm import slurm_submit

Expand All @@ -35,7 +35,7 @@
input_col = "atoms"
id_col = "material_id"
device = "cuda" if torch.cuda.is_available() else "cpu"
job_name = f"train-{model_name}{'-debug' if DEBUG else ''}"
job_name = f"train-{model_name}"


pred_col = "e_form_per_atom_alignn"
Expand All @@ -49,7 +49,7 @@
slurm_vars = slurm_submit(
job_name=job_name,
# partition="perlmutter",
account="matgen_g",
account="matgen",
time="4:0:0",
out_dir=out_dir,
slurm_flags="--qos regular --constraint gpu --gpus 1",
Expand Down Expand Up @@ -79,7 +79,7 @@
# %%
run_params = dict(
data_path=DATA_FILES.mp_energies,
**{f"{dep}_version": version(dep) for dep in ("alignn", "numpy", "torch", "dgl")},
versions={dep: version(dep) for dep in ("alignn", "numpy", "torch", "dgl")},
model_name=model_name,
target_col=target_col,
df=dict(shape=str(df_in.shape), columns=", ".join(df_in)),
Expand Down
3 changes: 1 addition & 2 deletions models/bowsr/join_bowsr_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pymatviz
from tqdm import tqdm

from matbench_discovery import today
from matbench_discovery.data import DATA_FILES

__author__ = "Janosh Riebesell"
Expand Down Expand Up @@ -66,7 +65,7 @@


# %%
out_path = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}"
out_path = f"{module_dir}/{glob_pattern.split('/*')[0]}"
df_bowsr = df_bowsr.round(4)
# save energy and formation energy as fast-loading CSV
df_bowsr.select_dtypes("number").to_csv(f"{out_path}.csv")
Expand Down
2 changes: 1 addition & 1 deletion models/bowsr/metadata.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
model_name: BOWSR + MEGNet
model_name: BOWSR
model_version: 2022.9.20
matbench_discovery_version: 1.0
date_added: "2022-11-17"
Expand Down
Loading