Skip to content

Commit 29eecf2

Browse files
committed
add module doc string with installation instructions to train_mace.py
fix pyright possibly unbound variable errors
1 parent 71a5edc commit 29eecf2

File tree

11 files changed

+58
-50
lines changed

11 files changed

+58
-50
lines changed

.pre-commit-config.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
77

88
repos:
99
- repo: https://github.com/astral-sh/ruff-pre-commit
10-
rev: v0.3.0
10+
rev: v0.3.3
1111
hooks:
1212
- id: ruff
1313
args: [--fix]
@@ -30,7 +30,7 @@ repos:
3030
- id: trailing-whitespace
3131

3232
- repo: https://github.com/pre-commit/mirrors-mypy
33-
rev: v1.8.0
33+
rev: v1.9.0
3434
hooks:
3535
- id: mypy
3636
additional_dependencies: [types-pyyaml, types-requests]
@@ -56,7 +56,7 @@ repos:
5656
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/(routes|figs).+\.(yaml|json)|changelog.md)$
5757

5858
- repo: https://github.com/pre-commit/mirrors-eslint
59-
rev: v9.0.0-beta.1
59+
rev: v9.0.0-beta.2
6060
hooks:
6161
- id: eslint
6262
types: [file]
@@ -80,7 +80,7 @@ repos:
8080
- id: check-github-actions
8181

8282
- repo: https://github.com/RobertCraigie/pyright-python
83-
rev: v1.1.352
83+
rev: v1.1.355
8484
hooks:
8585
- id: pyright
8686
args: [--level, error]

data/mp/build_phase_diagram.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,14 @@
105105

106106

107107
# %%
108-
df_mp["our_mp_e_form"] = [
108+
e_form_us = "e_form_us"
109+
df_mp[e_form_us] = [
109110
get_e_form_per_atom(mp_computed_entries[mp_id]) for mp_id in df_mp.index
110111
]
111112

112113

113114
# make sure get_form_energy_per_atom() reproduces MP formation energies
114-
ax = pymatviz.density_scatter(df_mp[Key.form_energy], df_mp["our_mp_e_form"])
115+
ax = pymatviz.density_scatter(df_mp[Key.form_energy], df_mp[e_form_us])
115116
ax.set(
116117
title="MP Formation Energy Comparison",
117118
xlabel="MP Formation Energy (eV/atom)",

data/wbm/compare_cse_vs_ce_mp_2020_corrections.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
"""
2020
NOTE MaterialsProject2020Compatibility takes structural information into account when
21-
correcting energies (only applies to certain oxides and sulfides). Always use
21+
correcting energies (for certain oxides and sulfides). Always use
2222
ComputedStructureEntry, not ComputedEntry when applying corrections.
2323
"""
2424

data/wbm/compile_wbm_test_set.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ def increment_wbm_material_id(wbm_id: str) -> str:
161161
assert df_wbm.index[-1] == "wbm-5-23308"
162162

163163
df_wbm[Key.init_struct] = df_wbm.pop("org")
164-
df_wbm["final_structure"] = df_wbm.pop("opt")
165-
assert list(df_wbm.columns) == [Key.init_struct, "final_structure"]
164+
df_wbm[Key.final_struct] = df_wbm.pop("opt")
165+
assert list(df_wbm.columns) == [Key.init_struct, Key.final_struct]
166166

167167

168168
# %% download WBM ComputedStructureEntries from
@@ -247,7 +247,7 @@ def increment_wbm_material_id(wbm_id: str) -> str:
247247
]
248248

249249
df_wbm["composition_from_final_struct"] = [
250-
Structure.from_dict(struct).composition for struct in tqdm(df_wbm.final_structure)
250+
Structure.from_dict(struct).composition for struct in tqdm(df_wbm[Key.final_struct])
251251
]
252252

253253
# all but 1 composition matches between CSE and final structure
@@ -499,7 +499,9 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
499499
for mat_id, cse in df_wbm[Key.cse].items():
500500
assert mat_id == cse["entry_id"], f"{mat_id} != {cse['entry_id']}"
501501

502-
df_wbm["cse"] = [ComputedStructureEntry.from_dict(dct) for dct in tqdm(df_wbm[Key.cse])]
502+
df_wbm[Key.cse] = [
503+
ComputedStructureEntry.from_dict(dct) for dct in tqdm(df_wbm[Key.cse])
504+
]
503505
# raw WBM ComputedStructureEntries have no energy corrections applied:
504506
assert all(cse.uncorrected_energy == cse.energy for cse in df_wbm.cse)
505507
# summary and CSE n_sites match
@@ -548,7 +550,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
548550
# takes ~20 min at 200 it/s for 250k entries in WBM
549551
assert Key.each_true not in df_summary
550552

551-
for mat_id, cse in tqdm(df_wbm.cse.items(), total=len(df_wbm)):
553+
for mat_id, cse in tqdm(df_wbm[Key.cse].items(), total=len(df_wbm)):
552554
assert mat_id == cse.entry_id, f"{mat_id=} != {cse.entry_id=}"
553555
assert cse.entry_id in df_summary.index, f"{cse.entry_id=} not in df_summary"
554556

@@ -562,7 +564,7 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
562564
assert sum(df_wbm.index != df_summary.index) == 0
563565

564566
for row in tqdm(df_wbm.itertuples(), total=len(df_wbm), desc="ML energies to CSEs"):
565-
mat_id, cse, formula = row.Index, row.cse, row.formula_from_cse
567+
mat_id, cse, formula = row.Index, row[Key.cse], row.formula_from_cse
566568
assert mat_id == cse.entry_id, f"{mat_id=} != {cse.entry_id=}"
567569
assert mat_id in df_summary.index, f"{mat_id=} not in df_summary"
568570

@@ -665,12 +667,9 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
665667

666668
# %% only here to load data for later inspection
667669
if False:
668-
wbm_summary_path = f"{WBM_DIR}/2022-10-19-wbm-summary.csv.gz"
669-
df_summary = pd.read_csv(wbm_summary_path).set_index(Key.mat_id)
670-
df_wbm = pd.read_json(
671-
f"{WBM_DIR}/2022-10-19-wbm-computed-structure-entries+init-structs.json.bz2"
672-
).set_index(Key.mat_id)
670+
df_summary = pd.read_csv(DATA_FILES.wbm_summary).set_index(Key.mat_id)
671+
df_wbm = pd.read_json(DATA_FILES.wbm_cses_plus_init_structs).set_index(Key.mat_id)
673672

674-
df_wbm["cse"] = [
673+
df_wbm[Key.cse] = [
675674
ComputedStructureEntry.from_dict(dct) for dct in tqdm(df_wbm[Key.cse])
676675
]

matbench_discovery/plots.py

+1
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def hist_classified_stable_vs_hull_dist(
106106
clf_col, value_name = "classified", "count"
107107

108108
df_plot = pd.DataFrame()
109+
each_true_pos = each_true_neg = each_false_neg = each_false_pos = None
109110

110111
for facet, df_group in (
111112
df.groupby(kwargs["facet_col"]) if "facet_col" in kwargs else [(None, df)]

models/alignn_ff/readme.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ This effort was aborted for the following reasons:
88
1. **Training difficulties**: ALIGNN-FF proved to be very resource-hungry. [12 GB of MPtrj training data](https://figshare.com/articles/dataset/23713842) turned into 600 GB of ALIGNN graph data. This forces small batch size even on nodes with large GPU memory, which slowed down training.
99
1. **Ineffectiveness of fine-tuning**: Efforts to fine-tune the ALIGNN-FF WT10 model on the CHGNet data suffered high initial loss, even worse than the untrained model, indicating significant dataset incompatibility.
1010

11-
The decision to abort adding ALIGNN FF to Matbench Discovery v1 was made after weeks of work due to ongoing technical challenges and resource limitations. See the [PR discussion](https://github.com/janosh/matbench-discovery/pull/47) for further details.
11+
The decision to abort testing ALIGNN FF was made after weeks of work due to ongoing technical challenges and resource limitations. See the [PR discussion](https://github.com/janosh/matbench-discovery/pull/47) for further details.
1212

1313
## Fine-tuning
1414

models/mace/readme.md

+4-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
## MACE formation energy predictions on WBM test set
22

33
The original MACE submission used the 2M parameter checkpoint [`2023-08-14-mace-yuan-trained-mptrj-04.model`](https://figshare.com/ndownloader/files/42374049) trained by Yuan Chiang on the [MPtrj dataset](https://figshare.com/articles/dataset/23713842).
4-
We initially tested the `2023-07-14-mace-universal-2-big-128-6.model` checkpoint trained on the much smaller [original M3GNet training set](https://figshare.com/articles/dataset/MPF_2021_2_8/19470599) which we received directly from Ilyes Batatia. MPtrj-trained MACE performed better and was used for the Matbench Discovery v1 submission.
4+
We initially tested the `2023-07-14-mace-universal-2-big-128-6.model` checkpoint trained on the much smaller [original M3GNet training set](https://figshare.com/articles/dataset/MPF_2021_2_8/19470599) which we received directly from Ilyes Batatia. MPtrj-trained MACE performed better and was used for the Matbench Discovery submission.
55

66
In late October (received 2023-10-29), Philipp Benner trained a much larger 16M parameter MACE for over 100 epochs in MPtrj which achieved an (at the time SOTA) F1 score of 0.64 and DAF of 3.13.
77

@@ -21,12 +21,7 @@ MACE relaxed each test set structure until the maximum force in the training set
2121

2222
#### Training
2323

24-
- `loss="uip"`
25-
- `energy_weight=1`
26-
- `forces_weight=1`
27-
- `stress_weight=0.01`
28-
- `r_max=6.0`
29-
- `lr=0.005`
30-
- `batch_size=10`
24+
See the module doc string in `train_mace.py` for how to install MACE for multi-GPU training.
25+
A single-GPU training script that works with the current [MACE PyPI release](https://pypi.org/project/mace-torch) (v0.3.4 as of 2024-03-21) could be provided if there's interest.
3126

32-
We used conditional loss weighting. We did _not_ use MACE's newest attention block feature which in our testing performed significantly worse than `RealAgnosticResidualInteractionBlock`.
27+
Our training used conditional loss weighting. We did _not_ use MACE's newest attention block feature which in our testing performed significantly worse than `RealAgnosticResidualInteractionBlock`.

models/mace/train_mace.py

+23-12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
"""
2+
This script requires installing the as-yet unmerged multi-GPU branch
3+
in the MACE repo.
4+
pip install git+https://github.com/ACEsuit/mace@multi-GPU
5+
Plan is to merge it into main and then release to PyPI. At that point,
6+
the install command will be:
7+
pip install mace-torch
8+
9+
If you want to fine-tune an existing MACE checkpoint rather than train a
10+
model from scratch, install the foundations branch instead which has an interface
11+
just for that.
12+
pip install git+https://github.com/ACEsuit/mace@foundations
13+
"""
14+
115
from __future__ import annotations
216

317
import ast
@@ -36,9 +50,6 @@
3650
__date__ = "2023-09-18"
3751

3852

39-
# This script requires installing MACE.
40-
# pip install git+https://github.com/ACEsuit/mace
41-
4253
module_dir = os.path.dirname(__file__)
4354

4455
slurm_vars = slurm_submit(
@@ -77,8 +88,8 @@ def main(**kwargs: Any) -> None:
7788
if args.distributed:
7889
try:
7990
distr_env = DistributedEnvironment()
80-
except Exception as e:
81-
print(f"Error specifying environment for distributed training: {e}")
91+
except Exception as exc:
92+
print(f"Error specifying environment for distributed training: {exc}")
8293
return
8394
world_size = distr_env.world_size
8495
local_rank = distr_env.local_rank
@@ -122,10 +133,10 @@ def main(**kwargs: Any) -> None:
122133

123134
# Data preparation
124135
if args.train_file.endswith(".xyz"):
125-
if args.valid_file is not None:
126-
assert args.valid_file.endswith(
127-
".xyz"
128-
), "valid_file if given must be same format as train_file"
136+
if args.valid_file is not None and not args.valid_file.endswith(".xyz"):
137+
raise RuntimeError(
138+
f"valid_file must be .xyz if train_file is .xyz, got {args.valid_file}"
139+
)
129140
config_type_weights = get_config_type_weights(args.config_type_weights)
130141
collections, atomic_energies_dict = get_dataset_from_xyz(
131142
train_path=args.train_file,
@@ -150,7 +161,7 @@ def main(**kwargs: Any) -> None:
150161
f"{len(collections.valid)}, tests=[{test_config_lens}]"
151162
)
152163
elif args.train_file.endswith(".h5"):
153-
atomic_energies_dict = None
164+
atomic_energies_dict = collections = None
154165
else:
155166
raise RuntimeError(
156167
f"train_file must be either .xyz or .h5, got {args.train_file}"
@@ -485,8 +496,8 @@ def main(**kwargs: Any) -> None:
485496
f"{args.swa_forces_weight}, learning rate : {args.swa_lr}"
486497
)
487498
if args.loss == "forces_only":
488-
print("Can not select swa with forces only loss.")
489-
elif args.loss == "virials":
499+
raise RuntimeError("Can not select SWA with forces-only loss.")
500+
if args.loss == "virials":
490501
loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss(
491502
energy_weight=args.swa_energy_weight,
492503
forces_weight=args.swa_forces_weight,

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ running-models = [
5555
# when attempting PyPI publish
5656
# "aviary@git+https://github.com/CompRhys/aviary",
5757
"alignn",
58-
"chgnet",
58+
"chgnet>=0.3.0",
5959
"jarvis-tools",
6060
"m3gnet",
6161
"mace-torch",
@@ -93,6 +93,7 @@ ignore = [
9393
"D205", # blank-line-after-summary
9494
"DTZ005",
9595
"E731", # lambda-assignment
96+
"EM101",
9697
"EM102",
9798
"FBT001",
9899
"FBT002",

scripts/update_wandb_runs.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222

2323

2424
# %%
25-
df = pd.DataFrame([run.config | dict(run.summary) for run in runs])
26-
df[["display_name", "id"]] = [(run.display_name, run.id) for run in runs]
25+
df_runs = pd.DataFrame([run.config | dict(run.summary) for run in runs])
26+
df_runs[["display_name", "id"]] = [(run.display_name, run.id) for run in runs]
2727

2828

2929
# %%
30-
df.isna().sum()
30+
df_runs.isna().sum()
3131

3232

3333
# %% --- Update run metadata ---
@@ -41,9 +41,9 @@
4141
"mace-wbm-IS2RE-debug-", "mace-wbm-IS2RE-"
4242
)
4343

44-
for x in (Task.IS2RE, "ES2RE"):
45-
if x in run.display_name:
46-
new_config["task_type"] = x
44+
for key in (Task.IS2RE, Task.RS2RE):
45+
if key in run.display_name:
46+
new_config["task_type"] = key
4747

4848
if "SLURM_JOB_ID" in new_config:
4949
new_config["slurm_job_id"] = new_config.pop("SLURM_JOB_ID")

site/src/routes/preprint/references.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1426,7 +1426,7 @@ references:
14261426
URL: https://www.nature.com/articles/s41467-021-23339-x
14271427
volume: '12'
14281428

1429-
- id: mok_direction-based_2022
1429+
- id: mok_directionbased_2022
14301430
accessed:
14311431
- year: 2022
14321432
month: 10
@@ -1438,7 +1438,7 @@ references:
14381438
given: Jongseung
14391439
- family: Back
14401440
given: Seoin
1441-
citation-key: mok_direction-based_2022
1441+
citation-key: mok_directionbased_2022
14421442
DOI: 10.26434/chemrxiv-2022-dp58c
14431443
genre: preprint
14441444
issued:

0 commit comments

Comments
 (0)