@@ -29,8 +29,9 @@ def _get_calling_file_path(frame: int = 1) -> str:
29
29
def slurm_submit (
30
30
job_name : str ,
31
31
out_dir : str ,
32
- time : str ,
33
- account : str ,
32
+ * ,
33
+ time : str | None = None ,
34
+ account : str | None = None ,
34
35
partition : str | None = None ,
35
36
py_file_path : str | None = None ,
36
37
slurm_flags : str | Sequence [str ] = (),
@@ -72,30 +73,34 @@ def slurm_submit(
72
73
73
74
os .makedirs (out_dir , exist_ok = True ) # slurm fails if out_dir is missing
74
75
76
+ # ensure pre_cmd ends with a semicolon
77
+ if pre_cmd and not pre_cmd .strip ().endswith (";" ):
78
+ pre_cmd += ";"
79
+
75
80
cmd = [
76
- * f"sbatch --{ account = } --{ time = } " .replace ("'" , "" ).split (),
77
- * ("--job-name" , job_name ),
81
+ * ("sbatch" , "--job-name" , job_name ),
78
82
* ("--output" , f"{ out_dir } /slurm-%A{ '-%a' if array else '' } .log" ),
79
83
* (slurm_flags .split () if isinstance (slurm_flags , str ) else slurm_flags ),
80
- * ("--wrap" , f"{ pre_cmd } python { py_file_path } " .strip ()),
84
+ * ("--wrap" , f"{ pre_cmd or '' } python { py_file_path } " .strip ()),
81
85
]
82
- if partition :
83
- cmd += [ "--partition" , partition ]
84
- if array :
85
- cmd += [ "--array " , array ]
86
+ for flag in ( f" { time = } " , f" { account = } " , f" { partition = } " , f" { array = } " ) :
87
+ key , val = flag . split ( "=" )
88
+ if val != "None" :
89
+ cmd += ( f "--{ key } " , val )
86
90
87
91
is_log_file = not sys .stdout .isatty ()
88
92
is_slurm_job = "SLURM_JOB_ID" in os .environ
89
93
90
94
slurm_vars = {
91
- f"slurm_{ key } " : val
95
+ f"slurm_{ key } " : os . environ [ f"SLURM_ { key } " . upper ()]
92
96
for key in SLURM_KEYS
93
- if ( val := os . getenv ( f"SLURM_{ key } " .upper ()))
97
+ if f"SLURM_{ key } " .upper () in os . environ
94
98
}
95
- slurm_vars ["slurm_timelimit" ] = time
96
- if slurm_flags :
99
+ if time is not None :
100
+ slurm_vars ["slurm_timelimit" ] = time
101
+ if slurm_flags != ():
97
102
slurm_vars ["slurm_flags" ] = str (slurm_flags )
98
- if pre_cmd :
103
+ if pre_cmd not in ( "" , None ) :
99
104
slurm_vars ["pre_cmd" ] = pre_cmd
100
105
101
106
# print sbatch command into slurm log file and at job submission time
0 commit comments