Skip to content

Commit a2a99f2

Browse files
committed
add code to run test_cgcnn.py with task_type=RS2RE
default all slurm out_dirs to os.environ.get("SBATCH_OUTPUT") to make sure jobs in the same array all write to the same dir; previously, jobs queueing across day boundaries wrote to different dirs
1 parent 65172ff commit a2a99f2

File tree

10 files changed

+55
-48
lines changed

10 files changed

+55
-48
lines changed

matbench_discovery/slurm.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def _get_calling_file_path(frame: int = 1) -> str:
2424

2525
def slurm_submit(
2626
job_name: str,
27-
log_dir: str,
27+
out_dir: str,
2828
time: str,
2929
partition: str,
3030
account: str,
@@ -41,7 +41,7 @@ def slurm_submit(
4141
4242
Args:
4343
job_name (str): Slurm job name.
44-
log_dir (str): Directory to write slurm logs. Log file will include slurm job
44+
out_dir (str): Directory to write slurm logs. Log file will include slurm job
4545
ID and array task ID.
4646
time (str): 'HH:MM:SS' time limit for the job.
4747
py_file_path (str, optional): Path to the python script to be submitted.
@@ -73,12 +73,12 @@ def slurm_submit(
7373
# before actual job command
7474
pre_cmd += ". /etc/profile.d/modules.sh; module load rhel8/default-amp;"
7575

76-
os.makedirs(log_dir, exist_ok=True) # slurm fails if log_dir is missing
76+
os.makedirs(out_dir, exist_ok=True) # slurm fails if out_dir is missing
7777

7878
cmd = [
7979
*f"sbatch --{partition=} --{account=} --{time=}".replace("'", "").split(),
8080
*("--job-name", job_name),
81-
*("--output", f"{log_dir}/slurm-%A{'-%a' if array else ''}.log"),
81+
*("--output", f"{out_dir}/slurm-%A{'-%a' if array else ''}.log"),
8282
*slurm_flags,
8383
*("--wrap", f"{pre_cmd} python {py_file_path}".strip()),
8484
]

models/bowsr/test_bowsr.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@
4040
today = timestamp.split("@")[0]
4141
energy_model = "megnet"
4242
job_name = f"bowsr-{energy_model}-wbm-{task_type}"
43-
out_dir = f"{module_dir}/{today}-{job_name}"
43+
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
4444

4545
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
4646

4747
slurm_vars = slurm_submit(
4848
job_name=job_name,
49-
log_dir=out_dir,
49+
out_dir=out_dir,
5050
partition="icelake-himem",
5151
account="LEE-SL3-CPU",
5252
time=(slurm_max_job_time := "12:0:0"),
@@ -109,12 +109,8 @@
109109
if wandb.run is None:
110110
wandb.login()
111111

112-
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
113-
wandb.init(
114-
project="matbench-discovery",
115-
name=f"{job_name}-{slurm_job_id}-{slurm_array_task_id}",
116-
config=run_params,
117-
)
112+
run_name = f"{job_name}-{slurm_array_task_id}"
113+
wandb.init(project="matbench-discovery", name=run_name, config=run_params)
118114

119115

120116
# %%

models/cgcnn/test_cgcnn.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -29,37 +29,44 @@
2929
"""
3030

3131
today = f"{datetime.now():%Y-%m-%d}"
32-
log_dir = f"{os.path.dirname(__file__)}/{today}-test"
33-
job_name = "test-cgcnn-ensemble"
32+
task_type = "RS2RE"
33+
job_name = f"test-cgcnn-wbm-{task_type}"
34+
module_dir = os.path.dirname(__file__)
35+
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3436

3537
slurm_vars = slurm_submit(
3638
job_name=job_name,
3739
partition="ampere",
3840
account="LEE-SL3-GPU",
3941
time=(slurm_max_job_time := "2:0:0"),
40-
log_dir=log_dir,
42+
out_dir=out_dir,
4143
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
4244
)
4345

4446

4547
# %%
46-
task_type = "IS2RE"
4748
if task_type == "IS2RE":
4849
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
50+
input_col = "initial_structure"
4951
elif task_type == "RS2RE":
5052
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-cses.json.bz2"
53+
input_col = "relaxed_structure"
54+
else:
55+
raise ValueError(f"Unexpected {task_type=}")
56+
5157
df = pd.read_json(data_path).set_index("material_id", drop=False)
5258

5359
target_col = "e_form_per_atom_mp2020_corrected"
5460
df[target_col] = df_wbm[target_col]
55-
input_col = "initial_structure"
5661
assert target_col in df, f"{target_col=} not in {list(df)}"
62+
if task_type == "RS2RE":
63+
df[input_col] = [x["structure"] for x in df.computed_structure_entry]
5764
assert input_col in df, f"{input_col=} not in {list(df)}"
5865

5966
df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col], disable=None)]
6067

6168
filters = {
62-
"$and": [{"created_at": {"$gt": "2022-11-22", "$lt": "2022-11-23"}}],
69+
"created_at": {"$gt": "2022-11-22", "$lt": "2022-11-23"},
6370
"display_name": {"$regex": "^cgcnn-robust"},
6471
}
6572
wandb.login()
@@ -87,9 +94,8 @@
8794
)
8895

8996
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
90-
wandb.init(
91-
project="matbench-discovery", name=f"{job_name}-{slurm_job_id}", config=run_params
92-
)
97+
run_name = f"{job_name}-{slurm_job_id}"
98+
wandb.init(project="matbench-discovery", name=run_name, config=run_params)
9399

94100
cg_data = CrystalGraphData(
95101
df, task_dict={target_col: "regression"}, structure_col=input_col
@@ -106,7 +112,7 @@
106112
data_loader=data_loader,
107113
)
108114

109-
df.to_csv(f"{log_dir}/{today}-{job_name}-preds.csv", index=False)
115+
df.to_csv(f"{out_dir}/{job_name}-preds.csv", index=False)
110116
pred_col = f"{target_col}_pred_ens"
111117
table = wandb.Table(dataframe=df[[target_col, pred_col]].reset_index())
112118

models/cgcnn/train_cgcnn.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,22 @@
2525
# %%
2626
epochs = 300
2727
target_col = "formation_energy_per_atom"
28-
run_name = f"train-cgcnn-robust-{target_col}"
29-
print(f"{run_name=}")
30-
robust = "robust" in run_name.lower()
28+
job_name = f"train-cgcnn-robust-{target_col}"
29+
print(f"{job_name=}")
30+
robust = "robust" in job_name.lower()
3131
n_ens = 10
3232
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
3333
today = timestamp.split("@")[0]
34-
log_dir = f"{os.path.dirname(__file__)}/{today}-{run_name}"
34+
module_dir = os.path.dirname(__file__)
35+
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3536

3637
slurm_vars = slurm_submit(
37-
job_name=run_name,
38+
job_name=job_name,
3839
partition="ampere",
3940
account="LEE-SL3-GPU",
4041
time="8:0:0",
4142
array=f"1-{n_ens}",
42-
log_dir=log_dir,
43+
out_dir=out_dir,
4344
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
4445
)
4546

@@ -106,7 +107,7 @@
106107
model_params=model_params,
107108
model=model,
108109
optimizer=optimizer,
109-
run_name=run_name,
110+
run_name=job_name,
110111
swa_start=swa_start,
111112
target_col=target_col,
112113
task_type=task_type,

models/m3gnet/test_m3gnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@
3333
slurm_array_task_count = 100
3434
slurm_mem_per_node = 12000
3535
job_name = f"m3gnet-wbm-{task_type}"
36-
out_dir = f"{module_dir}/{today}-{job_name}"
36+
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3737

3838
slurm_vars = slurm_submit(
3939
job_name=job_name,
40-
log_dir=out_dir,
40+
out_dir=out_dir,
4141
partition="icelake-himem",
4242
account="LEE-SL3-CPU",
4343
time=(slurm_max_job_time := "3:0:0"),

models/megnet/test_megnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030
today = timestamp.split("@")[0]
3131
module_dir = os.path.dirname(__file__)
3232
job_name = f"megnet-wbm-{task_type}"
33-
out_dir = f"{module_dir}/{today}-{job_name}"
33+
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3434

3535
slurm_vars = slurm_submit(
3636
job_name=job_name,
37-
log_dir=out_dir,
37+
out_dir=out_dir,
3838
partition="icelake-himem",
3939
account="LEE-SL3-CPU",
4040
time=(slurm_max_job_time := "12:0:0"),

models/voronoi/voronoi_featurize_dataset.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626

2727
slurm_array_task_count = 30
2828
job_name = f"voronoi-features-{data_name}"
29-
log_dir = f"{module_dir}/{today}-{job_name}"
29+
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
30+
3031

3132
slurm_vars = slurm_submit(
3233
job_name=job_name,
@@ -35,13 +36,13 @@
3536
time=(slurm_max_job_time := "12:0:0"),
3637
array=f"1-{slurm_array_task_count}",
3738
slurm_flags=("--mem", "15G") if data_name == "mp" else (),
38-
log_dir=log_dir,
39+
out_dir=out_dir,
3940
)
4041

4142

4243
# %%
4344
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
44-
out_path = f"{log_dir}/{job_name}.csv.bz2"
45+
out_path = f"{out_dir}/{job_name}.csv.bz2"
4546

4647
if os.path.isfile(out_path):
4748
raise SystemExit(f"{out_path = } already exists, exciting early")

models/wrenformer/test_wrenformer.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,15 @@
2828
task_type = "IS2RE"
2929
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-summary.csv"
3030
job_name = "test-wrenformer-wbm-IS2RE"
31-
log_dir = f"{os.path.dirname(__file__)}/{today}-{job_name}"
31+
module_dir = os.path.dirname(__file__)
32+
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
3233

3334
slurm_vars = slurm_submit(
3435
job_name=job_name,
3536
partition="ampere",
3637
account="LEE-SL3-GPU",
3738
time="2:0:0",
38-
log_dir=log_dir,
39+
out_dir=out_dir,
3940
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
4041
)
4142

@@ -99,7 +100,7 @@
99100
runs, data_loader=data_loader, df=df, model_cls=Wrenformer, target_col=target_col
100101
)
101102

102-
df.to_csv(f"{log_dir}/{job_name}-preds.csv")
103+
df.to_csv(f"{out_dir}/{job_name}-preds.csv")
103104

104105

105106
# %%

models/wrenformer/train_wrenformer.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,22 @@
2525
# data_path = f"{ROOT}/data/2022-08-25-m3gnet-trainset-mp-2021-struct-energy.json.gz"
2626
# target_col = "mp_energy_per_atom"
2727
data_name = "m3gnet-trainset" if "m3gnet" in data_path else "mp"
28-
run_name = f"train-wrenformer-robust-{data_name}"
28+
job_name = f"train-wrenformer-robust-{data_name}"
2929
n_ens = 10
3030
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
3131
today = timestamp.split("@")[0]
3232
dataset = "mp"
33-
log_dir = f"{os.path.dirname(__file__)}/{dataset}/{today}-{run_name}"
33+
module_dir = os.path.dirname(__file__)
34+
out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
35+
3436

3537
slurm_vars = slurm_submit(
36-
job_name=run_name,
38+
job_name=job_name,
3739
partition="ampere",
3840
account="LEE-SL3-GPU",
3941
time="8:0:0",
4042
array=f"1-{n_ens}",
41-
log_dir=log_dir,
43+
out_dir=out_dir,
4244
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
4345
)
4446

@@ -50,7 +52,7 @@
5052
input_col = "wyckoff_spglib"
5153

5254
print(f"\nJob started running {timestamp}")
53-
print(f"{run_name=}")
55+
print(f"{job_name=}")
5456
print(f"{data_path=}")
5557

5658
df = pd.read_json(data_path).set_index("material_id", drop=False)
@@ -70,7 +72,7 @@
7072

7173
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
7274
train_wrenformer(
73-
run_name=f"{run_name}-{slurm_job_id}-{slurm_array_task_id}",
75+
run_name=f"{job_name}-{slurm_job_id}-{slurm_array_task_id}",
7476
train_df=train_df,
7577
test_df=test_df,
7678
target_col=target_col,

tests/test_slurm.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
@pytest.mark.parametrize("py_file_path", [None, "path/to/file.py"])
1414
def test_slurm_submit(capsys: CaptureFixture[str], py_file_path: str | None) -> None:
1515
job_name = "test_job"
16-
log_dir = "tmp"
16+
out_dir = "tmp"
1717
time = "0:0:1"
1818
partition = "fake-partition"
1919
account = "fake-account"
2020

2121
func_call = lambda: slurm_submit(
2222
job_name=job_name,
23-
log_dir=log_dir,
23+
out_dir=out_dir,
2424
time=time,
2525
partition=partition,
2626
account=account,
@@ -45,7 +45,7 @@ def test_slurm_submit(capsys: CaptureFixture[str], py_file_path: str | None) ->
4545

4646
sbatch_cmd = (
4747
f"sbatch --partition={partition} --account={account} --time={time} "
48-
f"--job-name {job_name} --output {log_dir}/slurm-%A.log --test-flag "
48+
f"--job-name {job_name} --output {out_dir}/slurm-%A.log --test-flag "
4949
f"--wrap python {py_file_path or __file__}"
5050
).replace(" --", "\n --")
5151
stdout, stderr = capsys.readouterr()

0 commit comments

Comments
 (0)