Skip to content

Commit d7f300b

Browse files
committed
fix bad column name in join_mace_results.py
update readme
1 parent 5df80ef commit d7f300b

File tree

4 files changed

+23
-18
lines changed

4 files changed

+23
-18
lines changed

models/m3gnet/join_m3gnet_results.py

+6-6
Original file line number | Diff line number | Diff line change
@@ -57,8 +57,8 @@
5757
df_cse = pd.read_json(DATA_FILES.wbm_computed_structure_entries).set_index(
5858
"material_id"
5959
)
60-
61-
df_cse["cse"] = [
60+
entry_col = "computed_structure_entry"
61+
df_cse[entry_col] = [
6262
ComputedStructureEntry.from_dict(dct)
6363
for dct in tqdm(df_cse.computed_structure_entry)
6464
]
@@ -71,22 +71,22 @@
7171
mat_id, struct_dict, m3gnet_energy, *_ = row
7272
mlip_struct = Structure.from_dict(struct_dict)
7373
df_m3gnet.at[mat_id, struct_col] = mlip_struct # noqa: PD008
74-
cse = df_cse.loc[mat_id, "cse"]
74+
cse = df_cse.loc[mat_id, entry_col]
7575
cse._energy = m3gnet_energy # cse._energy is the uncorrected energy
7676
cse._structure = mlip_struct
77-
df_m3gnet.loc[mat_id, "cse"] = cse
77+
df_m3gnet.loc[mat_id, entry_col] = cse
7878

7979

8080
# %% apply energy corrections
8181
out = MaterialsProject2020Compatibility().process_entries(
82-
df_m3gnet.cse, verbose=True, clean=True
82+
df_m3gnet[entry_col], verbose=True, clean=True
8383
)
8484
assert len(out) == len(df_m3gnet)
8585

8686

8787
# %% compute corrected formation energies
8888
df_m3gnet["e_form_per_atom_m3gnet"] = [
89-
get_e_form_per_atom(cse) for cse in tqdm(df_m3gnet.cse)
89+
get_e_form_per_atom(cse) for cse in tqdm(df_m3gnet[entry_col])
9090
]
9191

9292

models/mace/join_mace_results.py

+4-3
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@
3030
# %%
3131
module_dir = os.path.dirname(__file__)
3232
task_type = "IS2RE"
33-
date = "2023-08-14"
33+
date = "2023-09-02"
3434
glob_pattern = f"{date}-mace-wbm-{task_type}*/*.json.gz"
3535
file_paths = sorted(glob(f"{module_dir}/{glob_pattern}"))
3636
print(f"Found {len(file_paths):,} files for {glob_pattern = }")
@@ -77,7 +77,7 @@
7777

7878
# %% apply energy corrections
7979
out = MaterialsProject2020Compatibility().process_entries(
80-
df_mace.cse, verbose=True, clean=True
80+
df_mace[entry_col], verbose=True, clean=True
8181
)
8282
assert len(out) == len(df_mace)
8383

@@ -96,7 +96,8 @@
9696

9797
# %%
9898
bad_mask = (df_wbm[e_form_col] - df_wbm[e_form_mace_col]).abs() > 10
99-
ax = density_scatter(df=df_wbm[bad_mask], x=e_form_col, y=e_form_mace_col)
99+
print(f"{sum(bad_mask)=}")
100+
ax = density_scatter(df=df_wbm[~bad_mask], x=e_form_col, y=e_form_mace_col)
100101

101102

102103
# %%

models/mace/test_mace.py

+11-7
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import pandas as pd
10+
import torch
1011
import wandb
1112
from ase.constraints import ExpCellFilter
1213
from ase.optimize import FIRE, LBFGS
@@ -36,6 +37,7 @@
3637
# model_name = "2023-07-14-mace-ilyes-trained-MPF-2021-2-8-big-128-6"
3738
# MACE trained on CHGNet training set by Yuan Chiang
3839
model_name = "2023-08-14-mace-yuan-trained-mptrj-04"
40+
device = "cuda" if torch.cuda.is_available() else "cpu"
3941

4042
slurm_vars = slurm_submit(
4143
job_name=job_name,
@@ -64,12 +66,13 @@
6466
print(f"\nJob started running {timestamp}")
6567
print(f"{data_path=}")
6668
e_pred_col = "mace_energy"
69+
id_col = "material_id"
6770
max_steps = 500
6871
force_max = 0.05 # Run until the forces are smaller than this in eV/A
6972
checkpoint = f"{ROOT}/models/mace/{model_name}.model"
7073

7174
df_in: pd.DataFrame = np.array_split(
72-
pd.read_json(data_path).set_index("material_id"), slurm_array_task_count
75+
pd.read_json(data_path).set_index(id_col), slurm_array_task_count
7376
)[slurm_array_task_id - 1]
7477

7578
run_params = dict(
@@ -83,14 +86,15 @@
8386
relax_cell=relax_cell,
8487
force_max=force_max,
8588
ase_optimizer=ase_optimizer,
89+
device=device,
8690
)
8791

8892
run_name = f"{job_name}-{slurm_array_task_id}"
8993
wandb.init(project="matbench-discovery", name=run_name, config=run_params)
9094

9195

9296
# %%
93-
mace_calc = MACECalculator(checkpoint, device="cuda", default_dtype="float32")
97+
mace_calc = MACECalculator(checkpoint, device=device, default_dtype="float32")
9498
relax_results: dict[str, dict[str, Any]] = {}
9599
input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}[task_type]
96100

@@ -130,18 +134,18 @@
130134
)
131135

132136
relax_results[material_id] = {
133-
"mace_structure": mace_struct,
134-
"mace_energy": mace_energy,
135-
"mace_trajectory": mace_traj, # Add the trajectory to the results
137+
"structure": mace_struct,
138+
"energy": mace_energy,
139+
"trajectory": mace_traj,
136140
}
137141
except Exception as exc:
138142
print(f"Failed to relax {material_id}: {exc!r}")
139143
continue
140144

141145

142146
# %%
143-
df_out = pd.DataFrame(relax_results).T
144-
df_out.index.name = "material_id"
147+
df_out = pd.DataFrame(relax_results).T.add_prefix("mace_")
148+
df_out.index.name = id_col
145149

146150
df_out.reset_index().to_json(out_path, default_handler=as_dict_handler)
147151

readme.md

+2-2
Original file line number | Diff line number | Diff line change
@@ -5,15 +5,15 @@
55

66
<h4 align="center" class="toc-exclude">
77

8+
[![arXiv](https://img.shields.io/badge/arXiv-2308.14920-blue)](https://arxiv.org/abs/2308.14920)
89
[![Tests](https://github.com/janosh/matbench-discovery/actions/workflows/test.yml/badge.svg)](https://github.com/janosh/matbench-discovery/actions/workflows/test.yml)
910
[![GitHub Pages](https://github.com/janosh/matbench-discovery/actions/workflows/gh-pages.yml/badge.svg)](https://github.com/janosh/matbench-discovery/actions/workflows/gh-pages.yml)
10-
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/janosh/matbench-discovery/main.svg?badge_token=Qza33izjRxSbegTqeSyDvA)](https://results.pre-commit.ci/latest/github/janosh/matbench-discovery/main?badge_token=Qza33izjRxSbegTqeSyDvA)
1111
[![Requires Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)
1212
[![PyPI](https://img.shields.io/pypi/v/matbench-discovery?logo=pypi&logoColor=white)](https://pypi.org/project/matbench-discovery?logo=pypi&logoColor=white)
1313

1414
</h4>
1515

16-
> TL;DR: We benchmark ML models on crystal stability prediction from unrelaxed structures finding universal interatomic potentials (UIP) like [M3GNet](https://github.com/materialsvirtuallab/m3gnet) and [CHGNet](https://github.com/CederGroupHub/chgnet) to be highly accurate, robust across chemistries and ready for production use in high-throughput discovery pipelines.
16+
> TL;DR: We benchmark ML models on crystal stability prediction from unrelaxed structures, finding universal interatomic potentials (UIP) like [CHGNet](https://github.com/CederGroupHub/chgnet), [M3GNet](https://github.com/materialsvirtuallab/m3gnet) and [MACE](https://github.com/ACEsuit/mace) to be highly accurate, robust across chemistries and ready for production use in high-throughput discovery pipelines.
1717
1818
Matbench Discovery is an [interactive leaderboard](https://janosh.github.io/matbench-discovery/models) and associated [PyPI package](https://pypi.org/project/matbench-discovery) which together make it easy to rank ML energy models on a task designed to closely simulate a high-throughput discovery campaign for new stable inorganic crystals.
1919

0 commit comments

Comments
 (0)