Skip to content

Commit 444fb7f

Browse files
committed
return slurm environment variables from slurm_submit_python()
1 parent 784ee96 commit 444fb7f

File tree

6 files changed

+87
-51
lines changed

6 files changed

+87
-51
lines changed

matbench_discovery/slurm.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from collections.abc import Sequence
55
from datetime import datetime
66

7+
SLURM_KEYS = (
8+
"job_id array_task_id array_task_count mem_per_node nodelist submit_host".split()
9+
)
10+
711

812
def _get_calling_file_path(frame: int = 1) -> str:
913
"""Return calling file's path.
@@ -28,7 +32,7 @@ def slurm_submit_python(
2832
slurm_flags: Sequence[str] = (),
2933
array: str = None,
3034
pre_cmd: str = "",
31-
) -> None:
35+
) -> dict[str, str]:
3236
"""Slurm submits a python script using `sbatch --wrap 'python path/to/file.py'`.
3337
3438
Usage: Call this function at the top of the script (before doing any real work) and
@@ -56,6 +60,10 @@ def slurm_submit_python(
5660
5761
Raises:
5862
SystemExit: Exit code will be subprocess.run(['sbatch', ...]).returncode.
63+
64+
Returns:
65+
dict[str, str]: Slurm variables like job ID, array task ID, compute node IDs,
66+
submission node ID and total job memory.
5967
"""
6068
if py_file_path is None:
6169
py_file_path = _get_calling_file_path(frame=2)
@@ -78,19 +86,26 @@ def slurm_submit_python(
7886

7987
is_log_file = not sys.stdout.isatty()
8088
is_slurm_job = "SLURM_JOB_ID" in os.environ
89+
90+
slurm_vars = {
91+
f"slurm_{key}": val
92+
for key in SLURM_KEYS
93+
if (val := os.environ.get(f"SLURM_{key}".upper()))
94+
}
95+
8196
if (is_slurm_job and is_log_file) or "slurm-submit" in sys.argv:
8297
# print sbatch command at submission time and into slurm log file
8398
# but not when running in command line or Jupyter
8499
print(f"\n{' '.join(cmd)}\n".replace(" --", "\n --"))
85-
for key in "JOB_ID ARRAY_TASK_ID MEM_PER_NODE NODELIST SUBMIT_HOST".split():
86-
if val := os.environ.get(f"SLURM_{key}"):
87-
print(f"SLURM_{key}={val}")
100+
for key, val in slurm_vars.items():
101+
print(f"{key}={val}")
88102

89103
if "slurm-submit" not in sys.argv:
90-
return
104+
return slurm_vars # if not submitting slurm job, resume outside code as normal
91105

92106
os.makedirs(log_dir, exist_ok=True) # slurm fails if log_dir is missing
93107

94108
result = subprocess.run(cmd, check=True)
95109

110+
# after sbatch submission, exit with slurm exit code
96111
raise SystemExit(result.returncode)

models/bowsr/slurm_array_bowsr_wbm.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@
3535
slurm_array_task_count = 500
3636
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
3737
today = timestamp.split("@")[0]
38-
job_name = f"bowsr-megnet-wbm-{task_type}"
38+
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
39+
job_name = f"bowsr-megnet-wbm-{task_type}-{slurm_job_id}"
3940
out_dir = f"{module_dir}/{today}-{job_name}"
4041

4142
data_path = f"{ROOT}/data/2022-10-19-wbm-init-structs.json.gz"
4243

43-
slurm_submit_python(
44+
slurm_vars = slurm_submit_python(
4445
job_name=job_name,
4546
log_dir=out_dir,
4647
partition="icelake-himem",
@@ -56,13 +57,10 @@
5657

5758

5859
# %%
59-
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
6060
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
6161
out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
6262

6363
print(f"Job started running {timestamp}")
64-
print(f"{slurm_job_id = }")
65-
print(f"{slurm_array_task_id = }")
6664
print(f"{data_path = }")
6765
print(f"{out_path = }")
6866
print(f"{version('maml') = }")
@@ -88,12 +86,10 @@
8886
maml_version=version("maml"),
8987
megnet_version=version("megnet"),
9088
optimize_kwargs=optimize_kwargs,
89+
task_type=task_type,
9190
slurm_array_task_count=slurm_array_task_count,
92-
slurm_array_task_id=slurm_array_task_id,
93-
slurm_job_id=slurm_job_id,
9491
slurm_max_job_time=slurm_max_job_time,
95-
slurm_mem_per_node=slurm_mem_per_node,
96-
task_type=task_type,
92+
**slurm_vars,
9793
)
9894
if wandb.run is None:
9995
wandb.login()
@@ -103,7 +99,7 @@
10399
wandb.init(
104100
entity="janosh",
105101
project="matbench-discovery",
106-
name=f"bowsr-megnet-wbm-{task_type}-{slurm_job_id}-{slurm_array_task_id}",
102+
name=f"{job_name}-{slurm_array_task_id}",
107103
config=run_params,
108104
)
109105

models/m3gnet/slurm_array_m3gnet_wbm.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@
3232
# set large job array size for fast testing/debugging
3333
slurm_array_task_count = 100
3434
slurm_mem_per_node = 12000
35-
job_name = f"m3gnet-wbm-{task_type}"
35+
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
36+
job_name = f"m3gnet-wbm-{task_type}-{slurm_job_id}"
3637
out_dir = f"{module_dir}/{today}-{job_name}"
3738

38-
slurm_submit_python(
39+
slurm_vars = slurm_submit_python(
3940
job_name=job_name,
4041
log_dir=out_dir,
4142
partition="icelake-himem",
@@ -50,12 +51,9 @@
5051

5152

5253
# %%
53-
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
5454
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
5555

5656
print(f"Job started running {timestamp}")
57-
print(f"{slurm_job_id = }")
58-
print(f"{slurm_array_task_id = }")
5957
print(f"{version('m3gnet') = }")
6058

6159
json_out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
@@ -80,11 +78,9 @@
8078
data_path=data_path,
8179
m3gnet_version=version("m3gnet"),
8280
slurm_array_task_count=slurm_array_task_count,
83-
slurm_array_task_id=slurm_array_task_id,
84-
slurm_job_id=slurm_job_id,
85-
slurm_max_job_time=slurm_max_job_time,
86-
slurm_mem_per_node=slurm_mem_per_node,
8781
task_type=task_type,
82+
slurm_max_job_time=slurm_max_job_time,
83+
**slurm_vars,
8884
)
8985
if wandb.run is None:
9086
wandb.login()

models/voronoi/featurize_mp_wbm.py

+49-26
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import warnings
44
from datetime import datetime
55

6+
import numpy as np
67
import pandas as pd
8+
import wandb
79
from matminer.featurizers.base import MultipleFeaturizer
810
from matminer.featurizers.composition import (
911
ElementProperty,
@@ -32,53 +34,74 @@
3234
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
3335
input_col = "structure"
3436
data_name = "wbm" if "wbm" in data_path else "mp"
37+
slurm_array_task_count = 100
38+
job_name = f"voronoi-featurize-{data_name}"
3539

36-
slurm_submit_python(
37-
job_name=f"voronoi-featurize-{data_name}",
40+
slurm_vars = slurm_submit_python(
41+
job_name=job_name,
3842
partition="icelake-himem",
3943
account="LEE-SL3-CPU",
40-
time="3:0:0",
44+
time=(slurm_max_job_time := "3:0:0"),
45+
array=f"1-{slurm_array_task_count}",
4146
log_dir=module_dir,
42-
slurm_flags=("--mem=40G",),
43-
)
44-
45-
46-
# %% Create the featurizer: Ward et al. use a variety of different featurizers
47-
# https://journals.aps.org/prb/abstract/10.1103/PhysRevB.96.024104
48-
featurizer = MultipleFeaturizer(
49-
[
50-
SiteStatsFingerprint.from_preset("CoordinationNumber_ward-prb-2017"),
51-
StructuralHeterogeneity(),
52-
ChemicalOrdering(),
53-
MaximumPackingEfficiency(),
54-
SiteStatsFingerprint.from_preset("LocalPropertyDifference_ward-prb-2017"),
55-
StructureComposition(Stoichiometry()),
56-
StructureComposition(ElementProperty.from_preset("magpie")),
57-
StructureComposition(ValenceOrbital(props=["frac"])),
58-
StructureComposition(IonProperty(fast=True)),
59-
],
6047
)
6148

6249

6350
# %%
6451
df = pd.read_json(data_path).set_index("material_id")
6552

53+
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
54+
df_this_job: pd.DataFrame = np.array_split(df, slurm_array_task_count)[
55+
slurm_array_task_id - 1
56+
]
57+
6658
if data_name == "mp":
67-
struct_dicts = [x["structure"] for x in df.entry]
59+
struct_dicts = [x["structure"] for x in df_this_job.entry]
6860
if data_name == "wbm":
69-
struct_dicts = df.initial_structure
61+
struct_dicts = df_this_job.initial_structure
62+
63+
df_this_job[input_col] = [
64+
Structure.from_dict(x) for x in tqdm(df_this_job.initial_structure, disable=None)
65+
]
7066

71-
df[input_col] = [
72-
Structure.from_dict(x) for x in tqdm(df.initial_structure, disable=None)
67+
68+
run_params = dict(
69+
data_path=data_path,
70+
slurm_max_job_time=slurm_max_job_time,
71+
**slurm_vars,
72+
)
73+
if wandb.run is None:
74+
wandb.login()
75+
76+
wandb.init(
77+
project="matbench-discovery",
78+
name=f"{job_name}-{slurm_array_task_id}",
79+
config=run_params,
80+
)
81+
82+
83+
# %% Create the featurizer: Ward et al. use a variety of different featurizers
84+
# https://journals.aps.org/prb/abstract/10.1103/PhysRevB.96.024104
85+
featurizers = [
86+
SiteStatsFingerprint.from_preset("CoordinationNumber_ward-prb-2017"),
87+
StructuralHeterogeneity(),
88+
ChemicalOrdering(),
89+
MaximumPackingEfficiency(),
90+
SiteStatsFingerprint.from_preset("LocalPropertyDifference_ward-prb-2017"),
91+
StructureComposition(Stoichiometry()),
92+
StructureComposition(ElementProperty.from_preset("magpie")),
93+
StructureComposition(ValenceOrbital(props=["frac"])),
94+
StructureComposition(IonProperty(fast=True)),
7395
]
96+
featurizer = MultipleFeaturizer(featurizers)
7497

7598

7699
# %% prints lots of pymatgen warnings
77100
# > No electronegativity for Ne. Setting to NaN. This has no physical meaning, ...
78101
warnings.filterwarnings(action="ignore", category=UserWarning, module="pymatgen")
79102

80103
df_features = featurizer.featurize_dataframe(
81-
df, input_col, ignore_errors=True, pbar=True
104+
df_this_job, input_col, ignore_errors=True, pbar=True
82105
)
83106

84107

models/wrenformer/slurm_train_wrenformer_ensemble.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# %%
2020
epochs = 300
2121
data_path = f"{ROOT}/data/mp/2022-08-13-mp-energies.json.gz"
22+
# data_path = f"{ROOT}/data/mp/2022-08-13-mp-energies-1k-samples.json.gz"
2223
target_col = "formation_energy_per_atom"
2324
# data_path = f"{ROOT}/data/2022-08-25-m3gnet-trainset-mp-2021-struct-energy.json.gz"
2425
# target_col = "mp_energy_per_atom"
@@ -52,6 +53,7 @@
5253
print(f"{data_path=}")
5354

5455
df = pd.read_json(data_path).set_index("material_id", drop=False)
56+
5557
assert target_col in df, f"{target_col=} not in {list(df)}"
5658
assert input_col in df, f"{input_col=} not in {list(df)}"
5759
train_df, test_df = df_train_test_split(df, test_size=0.05)

tests/test_slurm.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
from datetime import datetime
45
from unittest.mock import patch
56

@@ -11,6 +12,7 @@
1112
today = f"{datetime.now():%Y-%m-%d}"
1213

1314

15+
@patch.dict(os.environ, {"SLURM_JOB_ID": "1234"}, clear=True)
1416
@pytest.mark.parametrize("py_file_path", [None, "path/to/file.py"])
1517
def test_slurm_submit(capsys: CaptureFixture[str], py_file_path: str | None) -> None:
1618
job_name = "test_job"
@@ -29,7 +31,9 @@ def test_slurm_submit(capsys: CaptureFixture[str], py_file_path: str | None) ->
2931
slurm_flags=("--test-flag",),
3032
)
3133

32-
func_call()
34+
slurm_vars = func_call()
35+
36+
assert slurm_vars == {"slurm_job_id": "1234"}
3337
stdout, stderr = capsys.readouterr()
3438
# check slurm_submit_python() did nothing in normal mode
3539
assert stderr == stderr == ""

0 commit comments

Comments
 (0)