Skip to content

Commit 3a0173b

Browse files
committed
slurm_submit() now accepts a string for slurm_flags and returns timelimit, slurm_flags and pre_cmd as part of slurm_vars
1 parent b292e9b commit 3a0173b

File tree

10 files changed

+30
-23
lines changed

10 files changed

+30
-23
lines changed

matbench_discovery/slurm.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import os
24
import subprocess
35
import sys
@@ -29,7 +31,7 @@ def slurm_submit(
2931
partition: str,
3032
account: str,
3133
py_file_path: str = None,
32-
slurm_flags: Sequence[str] = (),
34+
slurm_flags: str | Sequence[str] = (),
3335
array: str = None,
3436
pre_cmd: str = "",
3537
) -> dict[str, str]:
@@ -48,7 +50,7 @@ def slurm_submit(
4850
Defaults to the path of the file calling slurm_submit().
4951
partition (str, optional): Slurm partition.
5052
account (str, optional): Account to charge for this job.
51-
slurm_flags (Sequence[str], optional): Extra slurm CLI flags. Defaults to ().
53+
slurm_flags (str | Sequence[str], optional): Extra slurm CLI flags. Defaults to ().
5254
Examples: ('--nodes 1', '--gpus-per-node 1') or ('--mem', '16000').
5355
array (str, optional): Slurm array specifier. Defaults to None. Example:
5456
'9' (for SLURM_ARRAY_TASK_ID from 0-9 inclusive), '1-10' or '1-10%2', etc.
@@ -79,7 +81,7 @@ def slurm_submit(
7981
*f"sbatch --{partition=} --{account=} --{time=}".replace("'", "").split(),
8082
*("--job-name", job_name),
8183
*("--output", f"{out_dir}/slurm-%A{'-%a' if array else ''}.log"),
82-
*slurm_flags,
84+
*(slurm_flags.split() if isinstance(slurm_flags, str) else slurm_flags),
8385
*("--wrap", f"{pre_cmd} python {py_file_path}".strip()),
8486
]
8587
if array:
@@ -93,6 +95,11 @@ def slurm_submit(
9395
for key in SLURM_KEYS
9496
if (val := os.environ.get(f"SLURM_{key}".upper()))
9597
}
98+
slurm_vars["slurm_timelimit"] = time
99+
if slurm_flags:
100+
slurm_vars["slurm_flags"] = str(slurm_flags)
101+
if pre_cmd:
102+
slurm_vars["pre_cmd"] = pre_cmd
96103

97104
if (is_slurm_job and is_log_file) or "slurm-submit" in sys.argv:
98105
# print sbatch command at submission time and into slurm log file

models/bowsr/test_bowsr.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
out_dir=out_dir,
4747
partition="icelake-himem",
4848
account="LEE-SL3-CPU",
49-
time=(slurm_max_job_time := "12:0:0"),
49+
time="12:0:0",
5050
# --time 2h is probably enough but best be safe.
5151
array=f"1-{slurm_array_task_count}%{slurm_max_parallel}",
5252
slurm_flags=("--mem", str(slurm_mem_per_node)),
@@ -86,9 +86,6 @@
8686
seed=42,
8787
)
8888
optimize_kwargs = dict(n_init=100, n_iter=100, alpha=0.026**2)
89-
slurm_dict = dict(
90-
slurm_max_parallel=slurm_max_parallel, slurm_max_job_time=slurm_max_job_time
91-
)
9289

9390
run_params = dict(
9491
bayes_optim_kwargs=bayes_optim_kwargs,
@@ -99,7 +96,7 @@
9996
energy_model_version=version(energy_model),
10097
optimize_kwargs=optimize_kwargs,
10198
task_type=task_type,
102-
slurm_vars=slurm_vars | slurm_dict,
99+
slurm_vars=slurm_vars,
103100
)
104101
if wandb.run is None:
105102
wandb.login()

models/cgcnn/test_cgcnn.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@
3838
job_name=job_name,
3939
partition="ampere",
4040
account="LEE-SL3-GPU",
41-
time=(slurm_max_job_time := "2:0:0"),
41+
time="2:0:0",
4242
out_dir=out_dir,
43-
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
43+
slurm_flags="--nodes 1 --gpus-per-node 1",
4444
)
4545

4646

@@ -90,7 +90,7 @@
9090
target_col=target_col,
9191
input_col=input_col,
9292
filters=filters,
93-
slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time),
93+
slurm_vars=slurm_vars,
9494
)
9595

9696

models/cgcnn/train_cgcnn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
time="8:0:0",
3939
array=f"1-{ensemble_size}",
4040
out_dir=out_dir,
41-
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
41+
slurm_flags="--nodes 1 --gpus-per-node 1",
4242
)
4343

4444

models/m3gnet/test_m3gnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
out_dir=out_dir,
3838
partition="icelake-himem",
3939
account="LEE-SL3-CPU",
40-
time=(slurm_max_job_time := "3:0:0"),
40+
time="3:0:0",
4141
array=f"1-{slurm_array_task_count}",
4242
slurm_flags=("--mem", str(slurm_mem_per_node)),
4343
# TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
@@ -73,7 +73,7 @@
7373
m3gnet_version=version("m3gnet"),
7474
task_type=task_type,
7575
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
76-
slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time),
76+
slurm_vars=slurm_vars,
7777
)
7878
if wandb.run is None:
7979
wandb.login()

models/megnet/test_megnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
out_dir=out_dir,
3535
partition="icelake-himem",
3636
account="LEE-SL3-CPU",
37-
time=(slurm_max_job_time := "12:0:0"),
37+
time="12:0:0",
3838
slurm_flags=("--mem", "30G"),
3939
# TF_CPP_MIN_LOG_LEVEL=2 means INFO and WARNING logs are not printed
4040
# https://stackoverflow.com/a/40982782
@@ -65,7 +65,7 @@
6565
task_type=task_type,
6666
target_col=target_col,
6767
df=dict(shape=str(df_wbm_structs.shape), columns=", ".join(df_wbm_structs)),
68-
slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time),
68+
slurm_vars=slurm_vars,
6969
)
7070
if wandb.run is None:
7171
wandb.login()

models/voronoi/voronoi_featurize_dataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
job_name=job_name,
3434
partition="icelake-himem",
3535
account="LEE-SL3-CPU",
36-
time=(slurm_max_job_time := "12:0:0"),
36+
time="12:0:0",
3737
array=f"1-{slurm_array_task_count}",
3838
slurm_flags=("--mem", "15G") if data_name == "mp" else (),
3939
out_dir=out_dir,
@@ -69,7 +69,7 @@
6969
data_path=data_path,
7070
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
7171
input_col=input_col,
72-
slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time),
72+
slurm_vars=slurm_vars,
7373
)
7474
if wandb.run is None:
7575
wandb.login()

models/wrenformer/test_wrenformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
account="LEE-SL3-GPU",
3838
time="2:0:0",
3939
out_dir=out_dir,
40-
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
40+
slurm_flags="--nodes 1 --gpus-per-node 1",
4141
)
4242

4343

models/wrenformer/train_wrenformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
time="8:0:0",
3939
array=f"1-{ensemble_size}",
4040
out_dir=out_dir,
41-
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
41+
slurm_flags="--nodes 1 --gpus-per-node 1",
4242
)
4343

4444

tests/test_slurm.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,15 @@ def test_slurm_submit(capsys: CaptureFixture[str], py_file_path: str | None) ->
2525
partition=partition,
2626
account=account,
2727
py_file_path=py_file_path,
28-
slurm_flags=("--test-flag",),
28+
slurm_flags="--foo",
2929
)
3030

3131
slurm_vars = func_call()
3232

33-
assert slurm_vars == {"slurm_job_id": "1234"}
33+
assert slurm_vars == dict(
34+
slurm_job_id="1234", slurm_timelimit="0:0:1", slurm_flags="--foo"
35+
)
36+
3437
stdout, stderr = capsys.readouterr()
3538
# check slurm_submit() did nothing in normal mode
3639
assert stderr == stderr == ""
@@ -45,7 +48,7 @@ def test_slurm_submit(capsys: CaptureFixture[str], py_file_path: str | None) ->
4548

4649
sbatch_cmd = (
4750
f"sbatch --partition={partition} --account={account} --time={time} "
48-
f"--job-name {job_name} --output {out_dir}/slurm-%A.log --test-flag "
51+
f"--job-name {job_name} --output {out_dir}/slurm-%A.log --foo "
4952
f"--wrap python {py_file_path or __file__}"
5053
).replace(" --", "\n --")
5154
stdout, stderr = capsys.readouterr()

0 commit comments

Comments
 (0)