1
+ from __future__ import annotations
2
+
1
3
import os
2
4
import subprocess
3
5
import sys
@@ -29,7 +31,7 @@ def slurm_submit(
29
31
partition : str ,
30
32
account : str ,
31
33
py_file_path : str = None ,
32
- slurm_flags : Sequence [str ] = (),
34
+ slurm_flags : str | Sequence [str ] = (),
33
35
array : str = None ,
34
36
pre_cmd : str = "" ,
35
37
) -> dict [str , str ]:
@@ -48,7 +50,7 @@ def slurm_submit(
48
50
Defaults to the path of the file calling slurm_submit().
49
51
partition (str, optional): Slurm partition.
50
52
account (str, optional): Account to charge for this job.
51
- slurm_flags (Sequence [str], optional): Extra slurm CLI flags. Defaults to ().
53
+ slurm_flags (str | list [str], optional): Extra slurm CLI flags. Defaults to ().
52
54
Examples: ('--nodes 1', '--gpus-per-node 1') or ('--mem', '16000').
53
55
array (str, optional): Slurm array specifier. Defaults to None. Example:
54
56
'9' (for SLURM_ARRAY_TASK_ID from 0-9 inclusive), '1-10' or '1-10%2', etc.
@@ -79,7 +81,7 @@ def slurm_submit(
79
81
* f"sbatch --{ partition = } --{ account = } --{ time = } " .replace ("'" , "" ).split (),
80
82
* ("--job-name" , job_name ),
81
83
* ("--output" , f"{ out_dir } /slurm-%A{ '-%a' if array else '' } .log" ),
82
- * slurm_flags ,
84
+ * ( slurm_flags . split () if isinstance ( slurm_flags , str ) else slurm_flags ) ,
83
85
* ("--wrap" , f"{ pre_cmd } python { py_file_path } " .strip ()),
84
86
]
85
87
if array :
@@ -93,6 +95,11 @@ def slurm_submit(
93
95
for key in SLURM_KEYS
94
96
if (val := os .environ .get (f"SLURM_{ key } " .upper ()))
95
97
}
98
+ slurm_vars ["slurm_timelimit" ] = time
99
+ if slurm_flags :
100
+ slurm_vars ["slurm_flags" ] = str (slurm_flags )
101
+ if pre_cmd :
102
+ slurm_vars ["pre_cmd" ] = pre_cmd
96
103
97
104
if (is_slurm_job and is_log_file ) or "slurm-submit" in sys .argv :
98
105
# print sbatch command at submission time and into slurm log file
0 commit comments