Skip to content

Commit e6bf955

Browse files
committed
add models/bowsr/slurm_array_bowsr_megnet_relax_wbm.py
1 parent 89976c4 commit e6bf955

File tree

3 files changed

+191
-32
lines changed

3 files changed

+191
-32
lines changed

models/bowsr/slurm_array_bowsr_wbm.py

+156
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# %%
2+
from __future__ import annotations
3+
4+
import contextlib
5+
import os
6+
from datetime import datetime
7+
from importlib.metadata import version
8+
from typing import Any
9+
10+
import numpy as np
11+
import pandas as pd
12+
import wandb
13+
from maml.apps.bowsr.model.megnet import MEGNet
14+
from maml.apps.bowsr.optimizer import BayesianOptimizer
15+
from tqdm import tqdm
16+
17+
from mb_discovery import ROOT, as_dict_handler
18+
19+
"""
20+
To slurm submit this file, use
21+
22+
```sh
23+
# slurm will not create logdir automatically and fail if missing
24+
mkdir -p models/bowsr/slurm_logs
25+
sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-500 \
26+
--time 12:0:0 --job-name bowsr-megnet-wbm-IS2RE --mem 12000 \
27+
--output models/bowsr/slurm_logs/slurm-%A-%a.out \
28+
--wrap "TF_CPP_MIN_LOG_LEVEL=2 python models/bowsr/slurm_array_bowsr_wbm.py"
29+
```
30+
31+
--time 2h is probably enough but missing indices are annoying so best be safe.
32+
--mem 12000 avoids slurmstepd: error: Detected 1 oom-kill event(s)
33+
Some of your processes may have been killed by the cgroup out-of-memory handler.
34+
35+
TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
36+
https://stackoverflow.com/a/40982782
37+
38+
Requires MEGNet and MAML installation: pip install megnet maml
39+
"""
40+
41+
__author__ = "Janosh Riebesell"
42+
__date__ = "2022-08-15"
43+
44+
45+
task_type = "IS2RE"
46+
# task_type = "RS2RE"
47+
data_path = f"{ROOT}/data/2022-06-26-wbm-cses-and-initial-structures.json.gz"
48+
49+
module_dir = os.path.dirname(__file__)
50+
job_id = os.environ.get("SLURM_JOB_ID", "debug")
51+
job_array_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
52+
# set large fallback job array size for fast testing/debugging
53+
job_array_size = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", 10_000))
54+
55+
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
56+
print(f"{job_id = }")
57+
print(f"{job_array_id = }")
58+
print(f"{version('maml') = }")
59+
print(f"{version('megnet') = }")
60+
61+
today = f"{datetime.now():%Y-%m-%d}"
62+
out_dir = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}"
63+
os.makedirs(out_dir, exist_ok=True)
64+
json_out_path = f"{out_dir}/{job_array_id}.json.gz"
65+
66+
if os.path.isfile(json_out_path):
67+
raise SystemExit(f"{json_out_path = } already exists, exciting early")
68+
69+
70+
# %%
71+
bayes_optim_kwargs = dict(
72+
relax_coords=True,
73+
relax_lattice=True,
74+
use_symmetry=True,
75+
seed=42,
76+
)
77+
optimize_kwargs = dict(n_init=100, n_iter=100, alpha=0.026**2)
78+
79+
run_params = dict(
80+
megnet_version=version("megnet"),
81+
maml_version=version("maml"),
82+
job_id=job_id,
83+
job_array_id=job_array_id,
84+
data_path=data_path,
85+
bayes_optim_kwargs=bayes_optim_kwargs,
86+
optimize_kwargs=optimize_kwargs,
87+
)
88+
if wandb.run is None:
89+
wandb.login()
90+
91+
# getting wandb: 429 encountered ({"error":"rate limit exceeded"}), retrying request
92+
# https://community.wandb.ai/t/753/14
93+
wandb.init(
94+
entity="janosh",
95+
project="matbench-discovery",
96+
name=f"bowsr-megnet-wbm-{task_type}-{job_id}-{job_array_id}",
97+
config=run_params,
98+
)
99+
100+
101+
# %%
102+
print(f"Loading from {data_path=}")
103+
df_wbm = pd.read_json(data_path).set_index("material_id")
104+
105+
df_this_job = np.array_split(df_wbm, job_array_size + 1)[job_array_id]
106+
107+
108+
# %%
109+
model = MEGNet()
110+
relax_results: dict[str, dict[str, Any]] = {}
111+
112+
if task_type == "IS2RE":
113+
from pymatgen.core import Structure
114+
115+
structures = df_this_job.initial_structure.map(Structure.from_dict)
116+
elif task_type == "RS2RE":
117+
from pymatgen.entries.computed_entries import ComputedStructureEntry
118+
119+
structures = df_this_job.cse.map(
120+
lambda x: ComputedStructureEntry.from_dict(x).structure
121+
)
122+
else:
123+
raise ValueError(f"Unknown {task_type = }")
124+
125+
126+
for material_id, structure in tqdm(
127+
structures.items(), desc="Main loop", total=len(structures)
128+
):
129+
if material_id in relax_results:
130+
continue
131+
bayes_optimizer = BayesianOptimizer(
132+
model=model, structure=structure, **bayes_optim_kwargs
133+
)
134+
bayes_optimizer.set_bounds()
135+
# reason for devnull here: https://github.com/materialsvirtuallab/maml/issues/469
136+
with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
137+
bayes_optimizer.optimize(**optimize_kwargs)
138+
139+
structure_pred, energy_pred = bayes_optimizer.get_optimized_structure_and_energy()
140+
141+
results = dict(
142+
e_form_per_atom_pred=model.predict_energy(structure),
143+
structure_pred=structure_pred,
144+
energy_pred=energy_pred,
145+
)
146+
147+
relax_results[material_id] = results
148+
149+
150+
# %%
151+
df_output = pd.DataFrame(relax_results).T
152+
df_output.index.name = "material_id"
153+
154+
df_output.reset_index().to_json(json_out_path, default_handler=as_dict_handler)
155+
156+
wandb.log_artifact(json_out_path, type=f"bowsr-megnet-wbm-{task_type}")

models/m3gnet/slurm_array_m3gnet_relax_wbm.py

+32-29
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import os
55
import warnings
66
from datetime import datetime
7+
from importlib.metadata import version
78
from typing import Any
89

9-
import m3gnet
1010
import numpy as np
1111
import pandas as pd
1212
import wandb
@@ -18,33 +18,37 @@
1818
To slurm submit this file, use
1919
2020
```sh
21-
sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-101 \
22-
--time 3:0:0 --job-name m3gnet-wbm-relax-RS2RE --mem 12000 \
21+
# slurm will not create logdir automatically and fail if missing
22+
mkdir -p models/m3gnet/slurm_logs
23+
sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-100 \
24+
--time 3:0:0 --job-name m3gnet-wbm-relax-IS2RE --mem 12000 \
2325
--output models/m3gnet/slurm_logs/slurm-%A-%a.out \
24-
--wrap "python models/m3gnet/slurm_array_m3gnet_relax_wbm.py"
26+
--wrap "TF_CPP_MIN_LOG_LEVEL=2 python models/m3gnet/slurm_array_m3gnet_relax_wbm.py"
2527
```
2628
2729
--time 2h is probably enough but missing indices are annoying so best be safe.
2830
31+
TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
32+
https://stackoverflow.com/a/40982782
33+
2934
Requires M3GNet installation: pip install m3gnet
3035
"""
3136

3237
__author__ = "Janosh Riebesell"
3338
__date__ = "2022-08-15"
3439

35-
# task_type = "IS2RE"
36-
task_type = "RS2RE"
40+
task_type = "IS2RE"
41+
# task_type = "RS2RE"
3742

38-
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
3943
job_id = os.environ.get("SLURM_JOB_ID", "debug")
40-
print(f"{job_id=}")
41-
m3gnet_version = m3gnet.__version__
42-
print(f"{m3gnet_version=}")
43-
4444
job_array_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
45-
# set default job array size to 1000 for fast testing
45+
# set large fallback job array size for fast testing/debugging
4646
job_array_size = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", 10_000))
47-
print(f"{job_array_id=}")
47+
48+
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
49+
print(f"{job_id = }")
50+
print(f"{job_array_id = }")
51+
print(f"{version('m3gnet') = }")
4852

4953
today = f"{datetime.now():%Y-%m-%d}"
5054
out_dir = f"{ROOT}/data/{today}-m3gnet-wbm-relax-{task_type}"
@@ -57,50 +61,51 @@
5761
warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")
5862
warnings.filterwarnings(action="ignore", category=UserWarning, module="tensorflow")
5963

60-
relax_results: dict[str, dict[str, Any]] = {}
61-
6264

6365
# %%
6466
data_path = f"{ROOT}/data/2022-06-26-wbm-cses-and-initial-structures.json.gz"
67+
print(f"Loading from {data_path=}")
6568
df_wbm = pd.read_json(data_path).set_index("material_id")
6669

67-
df_to_relax = np.array_split(df_wbm, job_array_size)[job_array_id]
70+
df_this_job = np.array_split(df_wbm, job_array_size)[job_array_id]
6871

6972
run_params = dict(
70-
m3gnet_version=m3gnet_version,
73+
m3gnet_version=version("m3gnet"),
7174
job_id=job_id,
7275
job_array_id=job_array_id,
7376
data_path=data_path,
7477
)
7578
if wandb.run is None:
7679
wandb.login()
80+
7781
wandb.init(
78-
project="m3gnet", # run will be added to this project
82+
project="m3gnet",
7983
name=f"m3gnet-wbm-relax-{task_type}-{job_id}-{job_array_id}",
8084
config=run_params,
8185
)
8286

8387

8488
# %%
8589
relaxer = Relaxer() # This loads the default pre-trained M3GNet model
90+
relax_results: dict[str, dict[str, Any]] = {}
8691

8792
if task_type == "IS2RE":
8893
from pymatgen.core import Structure
8994

90-
structures = df_to_relax.initial_structure.map(Structure.from_dict)
95+
structures = df_this_job.initial_structure.map(Structure.from_dict)
9196
elif task_type == "RS2RE":
9297
from pymatgen.entries.computed_entries import ComputedStructureEntry
9398

94-
df_to_relax.cse = df_to_relax.cse.map(ComputedStructureEntry.from_dict)
95-
structures = df_to_relax.cse.map(lambda x: x.structure)
99+
df_this_job.cse = df_this_job.cse.map(ComputedStructureEntry.from_dict)
100+
structures = df_this_job.cse.map(lambda x: x.structure)
96101
else:
97102
raise ValueError(f"Unknown {task_type = }")
98103

99104

100-
for material_id, struct in structures.items():
105+
for material_id, structure in structures.items():
101106
if material_id in relax_results:
102107
continue
103-
relax_result = relaxer.relax(struct)
108+
relax_result = relaxer.relax(structure)
104109
relax_dict = {
105110
"m3gnet_structure": relax_result["final_structure"],
106111
"m3gnet_trajectory": relax_result["trajectory"].__dict__,
@@ -110,11 +115,9 @@
110115

111116

112117
# %%
113-
df_m3gnet = pd.DataFrame(relax_results).T
114-
df_m3gnet.index.name = "material_id"
115-
116-
117-
df_m3gnet.to_json(json_out_path, default_handler=as_dict_handler)
118+
df_output = pd.DataFrame(relax_results).T
119+
df_output.index.name = "material_id"
118120

121+
df_output.reset_index().to_json(json_out_path, default_handler=as_dict_handler)
119122

120-
wandb.log_artifact(json_out_path, type="m3gnet-relaxed-wbm-initial-structures")
123+
wandb.log_artifact(json_out_path, type=f"m3gnet-relax-wbm-{task_type}")

tests/test_plots.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_precision_recall_vs_calc_count(
4949
stability_threshold: float,
5050
expected_line_count: int,
5151
) -> None:
52-
ax = plt.figure().gca() # ensure test functions use different axes
52+
ax = plt.figure().gca() # new figure ensures test functions use different axes
5353

5454
for (model_name, df), color in zip(
5555
test_dfs.items(), ("tab:blue", "tab:orange", "tab:pink")
@@ -106,7 +106,7 @@ def test_precision_recall_vs_calc_count_raises(
106106
def test_rolling_mae_vs_hull_dist(
107107
half_window: float, bin_width: float, x_lim: tuple[float, float]
108108
) -> None:
109-
ax = plt.figure().gca() # ensure test functions use different axes
109+
ax = plt.figure().gca() # new figure ensures test functions use different axes
110110

111111
for (model_name, df), color in zip(
112112
test_dfs.items(), ("tab:blue", "tab:orange", "tab:pink")
@@ -136,7 +136,7 @@ def test_hist_classified_stable_as_func_of_hull_dist(
136136
stability_crit: StabilityCriterion,
137137
x_lim: tuple[float, float],
138138
) -> None:
139-
ax = plt.figure().gca() # ensure test functions use different axes
139+
ax = plt.figure().gca() # new figure ensures test functions use different axes
140140

141141
df = test_dfs["Wren"]
142142

0 commit comments

Comments
 (0)