Skip to content

Commit 5d9a3b2

Browse files
committed
add slurm_submit_python() to use_cgcnn_ensemble.py and use_wrenformer_ensemble.py
auto-load default ampere environment (rhel8/default-amp module) in slurm_submit_python() if partition contains 'GPU'
1 parent e5b099c commit 5d9a3b2

File tree

5 files changed

+44
-18
lines changed

5 files changed

+44
-18
lines changed

matbench_discovery/slurm.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,20 @@ def slurm_submit_python(
5050
pre_cmd (str, optional): Things like `module load` commands and environment
5151
variables to set when running the python script go here. Example:
5252
pre_cmd='ENV_VAR=42' or 'module load rhel8/default-amp;'. Defaults to "".
53+
If running on CPU, pre_cmd="unset OMP_NUM_THREADS" allows PyTorch to use
54+
all cores https://docs.hpc.cam.ac.uk/hpc/software-packages/pytorch.html
5355
5456
Raises:
5557
SystemExit: Exit code will be subprocess.run(['sbatch', ...]).returncode.
5658
"""
5759
if py_file_path is None:
5860
py_file_path = _get_calling_file_path(frame=2)
5961

62+
if "GPU" in partition:
63+
# on Ampere GPU partition, source module CLI and load default Ampere env
64+
# before actual job command
65+
pre_cmd += ". /etc/profile.d/modules.sh; module load rhel8/default-amp;"
66+
6067
cmd = [
6168
*f"sbatch --{partition=} --{account=} --{time=}".replace("'", "").split(),
6269
*("--job-name", job_name),
@@ -72,7 +79,7 @@ def slurm_submit_python(
7279
if (is_slurm_job and is_log_file) or "slurm-submit" in sys.argv:
7380
# print sbatch command at submission time and into slurm log file
7481
# but not when running in command line or Jupyter
75-
print(" ".join(cmd))
82+
print(f"\n{' '.join(cmd)}\n")
7683

7784
if "slurm-submit" not in sys.argv:
7885
return

models/cgcnn/slurm_train_cgcnn_ensemble.py

-3
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@
4040
array=f"1-{n_folds}",
4141
log_dir=log_dir,
4242
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
43-
# prepend into sbatch script to source module command and load default env
44-
# for Ampere GPU partition before actual job command
45-
pre_cmd=". /etc/profile.d/modules.sh; module load rhel8/default-amp;",
4643
)
4744

4845

models/cgcnn/use_cgcnn_ensemble.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from matbench_discovery import ROOT
1717
from matbench_discovery.plot_scripts import df_wbm
18+
from matbench_discovery.slurm import slurm_submit_python
1819

1920
__author__ = "Janosh Riebesell"
2021
__date__ = "2022-08-15"
@@ -27,27 +28,41 @@
2728

2829
module_dir = os.path.dirname(__file__)
2930
today = f"{datetime.now():%Y-%m-%d}"
31+
ensemble_id = "cgcnn-e_form-ensemble-1"
32+
run_name = f"{today}-{ensemble_id}-IS2RE"
33+
34+
slurm_submit_python(
35+
job_name=run_name,
36+
partition="ampere",
37+
account="LEE-SL3-GPU",
38+
time="1:0:0",
39+
log_dir=module_dir,
40+
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
41+
)
3042

3143

3244
# %%
3345
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
3446
df = pd.read_json(data_path).set_index("material_id", drop=False)
3547
old_len = len(df)
48+
no_init_structs = df.query("initial_structure.isnull()").index
3649
df = df.dropna() # two missing initial structures
3750
assert len(df) == old_len - 2
3851

52+
assert all(
53+
df.index == df_wbm.drop(index=no_init_structs).index
54+
), "df and df_wbm must have same index"
3955
df["e_form_per_atom_mp2020_corrected"] = df_wbm.e_form_per_atom_mp2020_corrected
4056

4157
target_col = "e_form_per_atom_mp2020_corrected"
4258
input_col = "initial_structure"
4359
assert target_col in df, f"{target_col=} not in {list(df)}"
4460
assert input_col in df, f"{input_col=} not in {list(df)}"
4561

46-
df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col])]
62+
df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col], disable=None)]
4763

4864
wandb.login()
4965
wandb_api = wandb.Api()
50-
ensemble_id = "cgcnn-e_form-ensemble-1"
5166
runs = wandb_api.runs(
5267
"janosh/matbench-discovery", filters={"tags": {"$in": [ensemble_id]}}
5368
)
@@ -62,10 +77,11 @@
6277
)
6378
df, ensemble_metrics = predict_from_wandb_checkpoints(
6479
runs,
65-
df=cg_data.df, # dropping isolated-atom structs means len(cg_data.df) < len(df)
80+
# dropping isolated-atom structs means len(cg_data.df) < len(df)
81+
df=cg_data.df.reset_index(drop=True).drop(columns=input_col),
6682
target_col=target_col,
67-
model_class=CrystalGraphConvNet,
83+
model_cls=CrystalGraphConvNet,
6884
data_loader=data_loader,
6985
)
7086

71-
df.round(6).to_csv(f"{module_dir}/{today}-{ensemble_id}-preds-{target_col}.csv")
87+
df.round(6).to_csv(f"{module_dir}/{today}-{run_name}-preds.csv")

models/wrenformer/mp/use_wrenformer_ensemble.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from aviary.wrenformer.data import df_to_in_mem_dataloader
1111
from aviary.wrenformer.model import Wrenformer
1212

13+
from matbench_discovery.slurm import slurm_submit_python
14+
1315
__author__ = "Janosh Riebesell"
1416
__date__ = "2022-08-15"
1517

@@ -21,6 +23,17 @@
2123

2224
module_dir = os.path.dirname(__file__)
2325
today = f"{datetime.now():%Y-%m-%d}"
26+
ensemble_id = "wrenformer-e_form-ensemble-1"
27+
run_name = f"{today}-{ensemble_id}-IS2RE"
28+
29+
slurm_submit_python(
30+
job_name=run_name,
31+
partition="ampere",
32+
account="LEE-SL3-GPU",
33+
time="1:0:0",
34+
log_dir=module_dir,
35+
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
36+
)
2437

2538

2639
# %%
@@ -35,7 +48,6 @@
3548

3649
wandb.login()
3750
wandb_api = wandb.Api()
38-
ensemble_id = "wrenformer-e_form-ensemble-1"
3951
runs = wandb_api.runs(
4052
"janosh/matbench-discovery", filters={"tags": {"$in": [ensemble_id]}}
4153
)
@@ -52,7 +64,7 @@
5264
)
5365

5466
df, ensemble_metrics = predict_from_wandb_checkpoints(
55-
runs, data_loader, df=df, model_class=Wrenformer
67+
runs, data_loader=data_loader, df=df, model_cls=Wrenformer
5668
)
5769

58-
df.round(6).to_csv(f"{module_dir}/{today}-{ensemble_id}-preds-{target_col}.csv")
70+
df.round(6).to_csv(f"{module_dir}/{today}-{run_name}-preds.csv")

models/wrenformer/slurm_train_wrenformer_ensemble.py

-6
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,6 @@
3333
array=f"1-{n_folds}",
3434
log_dir=log_dir,
3535
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
36-
# prepend into sbatch script to source module command and load default env
37-
# for Ampere GPU partition before actual job command
38-
pre_cmd=". /etc/profile.d/modules.sh; module load rhel8/default-amp;",
39-
# if running on CPU, unsetting OMP threads allows PyTorch to use all cores
40-
# https://docs.hpc.cam.ac.uk/hpc/software-packages/pytorch.html
41-
# pre_cmd="unset OMP_NUM_THREADS",
4236
)
4337

4438

0 commit comments

Comments
 (0)