Skip to content

Commit 97c6949

Browse files
committed
add slurm_submit_python() in new module mb_discovery/slurm.py
1 parent 4f6d853 commit 97c6949

File tree

6 files changed

+181
-67
lines changed

6 files changed

+181
-67
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ job-logs/
2121
# slurm logs
2222
slurm-*out
2323
models/**/*.csv
24+
mb_discovery/energy/**/*.csv
2425

2526
# temporary ignore rule
2627
paper
28+
meeting-notes
29+
models/voronoi

mb_discovery/slurm.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import os
2+
import subprocess
3+
import sys
4+
from collections.abc import Sequence
5+
6+
7+
def _get_calling_file_path(frame: int = 1) -> str:
    """Look up the source file of a function further up the call stack.

    Args:
        frame (int, optional): Number of stack frames to go up, counted from this
            helper itself (0 is this helper, 1 its direct caller). Defaults to 1.

    Returns:
        str: File path stored on the code object of the selected frame.
    """
    return sys._getframe(frame).f_code.co_filename
18+
19+
20+
def slurm_submit_python(
    job_name: str,
    log_dir: str,
    time: str,
    py_file_path: "str | None" = None,
    slurm_flags: Sequence[str] = (),
    partition: str = "icelake",
    account: str = "LEE-SL3-CPU",
    array: str = "",
    env_vars: str = "",
) -> None:
    """Slurm submit a python script using sbatch --wrap 'python path/to/file.py'.

    Call this function near the top of the script and invoke the script with
    `python path/to/file.py slurm-submit` to submit it. Without 'slurm-submit' in
    sys.argv this function returns immediately and the script runs as usual.

    Args:
        job_name (str): Slurm job name.
        log_dir (str): Directory to write slurm logs. Created if missing. Log file
            names follow the pattern slurm-%A-%a.out.
        time (str): 'HH:MM:SS' time limit for the job.
        py_file_path (str, optional): Path to the python script to be submitted.
            Defaults to the path of the file calling slurm_submit_python().
        slurm_flags (Sequence[str], optional): Extra sbatch CLI arguments, one list
            item per argument, e.g. ("--mem", "12000"). Defaults to ().
        partition (str, optional): Slurm partition. Defaults to "icelake".
        account (str, optional): Account to charge for this job.
            Defaults to "LEE-SL3-CPU".
        array (str, optional): Slurm array specifier, e.g. "1-100". The --array flag
            is omitted when empty. Defaults to "".
        env_vars (str, optional): Environment variables to set when running the python
            script, e.g. ENV_VAR=42 python path/to/file.py. Defaults to "".

    Raises:
        SystemExit: Raised after sbatch succeeded (exit code is sbatch's return code,
            i.e. 0), which stops the calling script right after submission.
        subprocess.CalledProcessError: If sbatch exits with a non-zero return code.
    """
    if "slurm-submit" not in sys.argv:
        return

    os.makedirs(log_dir, exist_ok=True)  # slurm fails if log_dir is missing

    if py_file_path is None:
        # frame=2: 0 is the helper, 1 is this function, 2 is the file that called us
        py_file_path = _get_calling_file_path(frame=2)

    # Every flag and value is its own list item: subprocess.run() does not go through
    # a shell, so quotes or '=' + repr() formatting would reach sbatch verbatim.
    cmd = [
        "sbatch",
        *("--partition", partition),
        *("--account", account),
        *("--time", time),
    ]
    if array:  # an empty array spec is invalid, so only pass the flag when set
        cmd += ["--array", array]
    cmd += [
        *("--job-name", job_name),
        *("--output", f"{log_dir}/slurm-%A-%a.out"),
        *slurm_flags,
        # strip() drops the leading space left behind by an empty env_vars
        *("--wrap", f"{env_vars} python {py_file_path}".strip()),
    ]
    result = subprocess.run(cmd, check=True)

    raise SystemExit(result.returncode)

models/bowsr/join_bowsr_results.py

-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030

3131

3232
# %%
33-
# 2022-08-16 tried multiprocessing.Pool() to load files in parallel but was somehow
34-
# slower than serial loading
3533
for file_path in tqdm(file_paths):
3634
if file_path in dfs:
3735
continue

models/bowsr/slurm_array_bowsr_wbm.py

+46-40
Original file line numberDiff line numberDiff line change
@@ -15,57 +15,61 @@
1515
from tqdm import tqdm
1616

1717
from mb_discovery import ROOT, as_dict_handler
18+
from mb_discovery.slurm import slurm_submit_python
1819

19-
"""
20-
To slurm submit this file, use
21-
22-
```sh
23-
log_dir=models/bowsr/$(date +"%Y-%m-%d")-bowsr-megnet-wbm
24-
job_name=bowsr-megnet-wbm-IS2RE
25-
mkdir -p $log_dir # slurm fails if log_dir is missing
26-
27-
sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-500 \
28-
--time 12:0:0 --job-name $job_name --mem 12000 \
29-
--output $log_dir/slurm-%A-%a.out \
30-
--wrap "TF_CPP_MIN_LOG_LEVEL=2 python models/bowsr/slurm_array_bowsr_wbm.py"
31-
```
20+
__author__ = "Janosh Riebesell"
21+
__date__ = "2022-08-15"
3222

33-
--time 2h is probably enough but missing indices are annoying so best be safe.
34-
--mem 12000 avoids slurmstepd: error: Detected 1 oom-kill event(s)
35-
Some of your processes may have been killed by the cgroup out-of-memory handler.
23+
"""
24+
To slurm submit this file, run:
3625
37-
TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
38-
https://stackoverflow.com/a/40982782
26+
python path/to/file.py slurm-submit
3927
4028
Requires MEGNet and MAML installation: pip install megnet maml
4129
"""
4230

43-
__author__ = "Janosh Riebesell"
44-
__date__ = "2022-08-15"
45-
31+
task_type = "IS2RE" # "RS2RE"
32+
today = f"{datetime.now():%Y-%m-%d}"
33+
module_dir = os.path.dirname(__file__)
34+
# --mem 12000 avoids slurmstepd: error: Detected 1 oom-kill event(s)
35+
# Some of your processes may have been killed by the cgroup out-of-memory handler.
36+
slurm_mem_per_node = 12000
37+
# set large job array size for fast testing/debugging
38+
slurm_array_task_count = 500
39+
out_dir = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}"
4640

47-
task_type = "IS2RE"
48-
# task_type = "RS2RE"
4941
data_path = f"{ROOT}/data/2022-06-26-wbm-cses-and-initial-structures.json.gz"
5042

51-
module_dir = os.path.dirname(__file__)
43+
slurm_submit_python(
44+
job_name=f"bowsr-megnet-wbm-{task_type}",
45+
log_dir=out_dir,
46+
time=(slurm_max_job_time := "3:0:0"),
47+
# --time 2h is probably enough but best be safe.
48+
array=f"1-{slurm_array_task_count}",
49+
slurm_flags=("--mem", str(slurm_mem_per_node)),
50+
partition="icelake-himem",
51+
# TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
52+
# https://stackoverflow.com/a/40982782
53+
env_vars="TF_CPP_MIN_LOG_LEVEL=2",
54+
)
55+
56+
57+
# %%
5258
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
5359
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
54-
# set large fallback job array size for fast testing/debugging
55-
slurm_array_task_count = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", 10_000))
60+
out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
5661

5762
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
5863
print(f"{slurm_job_id = }")
5964
print(f"{slurm_array_task_id = }")
65+
print(f"{data_path = }")
66+
print(f"{out_path = }")
6067
print(f"{version('maml') = }")
6168
print(f"{version('megnet') = }")
6269

63-
today = f"{datetime.now():%Y-%m-%d}"
64-
out_dir = f"{module_dir}/{today}-bowsr-megnet-wbm-{task_type}"
65-
json_out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
6670

67-
if os.path.isfile(json_out_path):
68-
raise SystemExit(f"{json_out_path = } already exists, exciting early")
71+
if os.path.isfile(out_path):
72+
raise SystemExit(f"{out_path = } already exists, exciting early")
6973

7074

7175
# %%
@@ -78,14 +82,16 @@
7882
optimize_kwargs = dict(n_init=100, n_iter=100, alpha=0.026**2)
7983

8084
run_params = dict(
81-
megnet_version=version("megnet"),
82-
maml_version=version("maml"),
83-
slurm_job_id=slurm_job_id,
84-
slurm_array_task_id=slurm_array_task_id,
85-
slurm_array_task_count=slurm_array_task_count,
86-
data_path=data_path,
8785
bayes_optim_kwargs=bayes_optim_kwargs,
86+
data_path=data_path,
87+
maml_version=version("maml"),
88+
megnet_version=version("megnet"),
8889
optimize_kwargs=optimize_kwargs,
90+
slurm_array_task_count=slurm_array_task_count,
91+
slurm_array_task_id=slurm_array_task_id,
92+
slurm_job_id=slurm_job_id,
93+
slurm_max_job_time=slurm_max_job_time,
94+
slurm_mem_per_node=slurm_mem_per_node,
8995
task_type=task_type,
9096
)
9197
if wandb.run is None:
@@ -127,7 +133,7 @@
127133

128134

129135
for material_id, structure in tqdm(
130-
structures.items(), desc="Main loop", total=len(structures)
136+
structures.items(), desc="Main loop", total=len(structures), disable=None
131137
):
132138
if material_id in relax_results:
133139
continue
@@ -154,6 +160,6 @@
154160
df_output = pd.DataFrame(relax_results).T
155161
df_output.index.name = "material_id"
156162

157-
df_output.reset_index().to_json(json_out_path, default_handler=as_dict_handler)
163+
df_output.reset_index().to_json(out_path, default_handler=as_dict_handler)
158164

159-
wandb.log_artifact(json_out_path, type=f"bowsr-megnet-wbm-{task_type}")
165+
wandb.log_artifact(out_path, type=f"bowsr-megnet-wbm-{task_type}")

models/m3gnet/slurm_array_m3gnet_relax_wbm.py

+31-25
Original file line numberDiff line numberDiff line change
@@ -11,37 +11,43 @@
1111
import pandas as pd
1212
import wandb
1313
from m3gnet.models import Relaxer
14+
from tqdm import tqdm
1415

1516
from mb_discovery import ROOT, as_dict_handler
17+
from mb_discovery.slurm import slurm_submit_python
1618

1719
"""
18-
To slurm submit this file, use
20+
To slurm submit this file, run:
1921
20-
```sh
21-
job_name=m3gnet-wbm-relax-IS2RE
22-
log_dir=models/m3gnet/$(date +"%Y-%m-%d")-$job_name
23-
mkdir -p $log_dir # slurm fails if log_dir is missing
24-
25-
sbatch --partition icelake-himem --account LEE-SL3-CPU --array 1-100 \
26-
--time 3:0:0 --job-name $job_name --mem 12000 \
27-
--output $log_dir/slurm-%A-%a.out \
28-
--wrap "TF_CPP_MIN_LOG_LEVEL=2 python models/m3gnet/slurm_array_m3gnet_relax_wbm.py"
29-
```
30-
31-
--time 2h is probably enough but missing indices are annoying so best be safe.
32-
33-
TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
34-
https://stackoverflow.com/a/40982782
22+
python path/to/file.py slurm-submit
3523
3624
Requires M3GNet installation: pip install m3gnet
3725
"""
3826

3927
__author__ = "Janosh Riebesell"
4028
__date__ = "2022-08-15"
4129

42-
task_type = "IS2RE"
43-
# task_type = "RS2RE"
30+
task_type = "IS2RE" # "RS2RE"
31+
today = f"{datetime.now():%Y-%m-%d}"
32+
module_dir = os.path.dirname(__file__)
33+
slurm_array_task_count = 100
34+
slurm_mem_per_node = 12000
35+
job_name = f"m3gnet-wbm-relax-{task_type}"
36+
out_dir = f"{module_dir}/{today}-{job_name}"
37+
38+
slurm_submit_python(
39+
job_name=job_name,
40+
log_dir=out_dir,
41+
time=(slurm_max_job_time := "3:0:0"),
42+
array=f"1-{slurm_array_task_count}",
43+
slurm_flags=("--mem", str(slurm_mem_per_node)),
44+
# TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
45+
# https://stackoverflow.com/a/40982782
46+
env_vars="TF_CPP_MIN_LOG_LEVEL=2",
47+
)
48+
4449

50+
# %%
4551
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
4652
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
4753
# set large fallback job array size for fast testing/debugging
@@ -52,8 +58,6 @@
5258
print(f"{slurm_array_task_id = }")
5359
print(f"{version('m3gnet') = }")
5460

55-
today = f"{datetime.now():%Y-%m-%d}"
56-
out_dir = f"{ROOT}/data/{today}-m3gnet-wbm-{task_type}"
5761
json_out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
5862

5963
if os.path.isfile(json_out_path):
@@ -71,19 +75,21 @@
7175
df_this_job = np.array_split(df_wbm, slurm_array_task_count)[slurm_array_task_id - 1]
7276

7377
run_params = dict(
78+
data_path=data_path,
7479
m3gnet_version=version("m3gnet"),
75-
slurm_job_id=slurm_job_id,
76-
slurm_array_task_id=slurm_array_task_id,
7780
slurm_array_task_count=slurm_array_task_count,
78-
data_path=data_path,
81+
slurm_array_task_id=slurm_array_task_id,
82+
slurm_job_id=slurm_job_id,
83+
slurm_max_job_time=slurm_max_job_time,
84+
slurm_mem_per_node=slurm_mem_per_node,
7985
task_type=task_type,
8086
)
8187
if wandb.run is None:
8288
wandb.login()
8389

8490
wandb.init(
8591
project="m3gnet",
86-
name=f"m3gnet-wbm-relax-{task_type}-{slurm_job_id}-{slurm_array_task_id}",
92+
name=f"{job_name}-{slurm_job_id}-{slurm_array_task_id}",
8793
config=run_params,
8894
)
8995

@@ -105,7 +111,7 @@
105111
raise ValueError(f"Unknown {task_type = }")
106112

107113

108-
for material_id, structure in structures.items():
114+
for material_id, structure in tqdm(structures.items(), disable=None):
109115
if material_id in relax_results:
110116
continue
111117
relax_result = relaxer.relax(structure)

tests/test_slurm.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import sys
2+
from unittest.mock import patch
3+
4+
import pytest
5+
6+
from mb_discovery.slurm import _get_calling_file_path, slurm_submit_python
7+
8+
9+
def test_slurm_submit(tmp_path) -> None:
    """slurm_submit_python() should call sbatch exactly once and then exit with 0.

    Args:
        tmp_path: pytest's built-in temporary directory fixture. Keeps the log
            directory created by slurm_submit_python() out of the working directory.
    """
    log_dir = tmp_path / "test_log_dir"

    # patch.object restores sys.argv afterwards so other tests don't see the
    # 'slurm-submit' flag
    with pytest.raises(SystemExit) as exc_info, patch.object(
        sys, "argv", [*sys.argv, "slurm-submit"]
    ), patch("mb_discovery.slurm.subprocess.run") as mock_run:
        # a bare MagicMock return code would never compare equal to 0
        mock_run.return_value.returncode = 0
        slurm_submit_python(
            job_name="test_job",
            log_dir=str(log_dir),
            time="0:0:1",
            slurm_flags=("--test_flag",),
        )

    assert exc_info.value.code == 0
    assert mock_run.call_count == 1
    assert log_dir.is_dir()

    sbatch_cmd = mock_run.call_args[0][0]  # first positional arg of subprocess.run
    assert sbatch_cmd[0] == "sbatch"
    assert "test_job" in sbatch_cmd
    assert "--test_flag" in sbatch_cmd
23+
24+
25+
def test_get_calling_file_path() -> None:
    """The helper should resolve to this test file, called directly or via a wrapper."""
    this_file = __file__

    def one_level_deeper(frame: int) -> str:
        # adds one extra stack frame between this test and the helper
        return _get_calling_file_path(frame)

    assert _get_calling_file_path(frame=1) == this_file
    assert one_level_deeper(frame=2) == this_file

0 commit comments

Comments
 (0)