Skip to content

Commit 5804d13

Browse files
committed
fix predict_from_wandb_checkpoints() not returning ensemble_metrics due to missing target_col
change slurm log file extension from .out to .log for better syntax highlighting
1 parent f127a9e commit 5804d13

File tree

6 files changed

+17
-20
lines changed

6 files changed

+17
-20
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ wandb/
1818
job-logs/
1919

2020
# slurm logs
21-
slurm-*out
21+
*slurm-*.log
2222
models/**/*.csv
2323

2424
# temporary ignore rule

matbench_discovery/slurm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def slurm_submit_python(
7777
cmd = [
7878
*f"sbatch --{partition=} --{account=} --{time=}".replace("'", "").split(),
7979
*("--job-name", job_name),
80-
*("--output", f"{log_dir}/slurm-%A{'-%a' if array else ''}-{today}.out"),
80+
*("--output", f"{log_dir}/{today}-slurm-%A{'-%a' if array else ''}.log"),
8181
*slurm_flags,
8282
*("--wrap", f"{pre_cmd} python {py_file_path}".strip()),
8383
]

models/cgcnn/use_cgcnn_ensemble.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@
6060
df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col], disable=None)]
6161

6262
wandb.login()
63-
wandb_api = wandb.Api()
64-
runs = wandb_api.runs(
63+
runs = wandb.Api().runs(
6564
"janosh/matbench-discovery", filters={"tags": {"$in": [ensemble_id]}}
6665
)
6766

models/voronoi/featurize_mp_wbm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
account="LEE-SL3-CPU",
4444
time=(slurm_max_job_time := "3:0:0"),
4545
array=f"1-{slurm_array_task_count}",
46-
log_dir=module_dir,
46+
log_dir=f"{module_dir}/{job_name}",
4747
)
4848

4949

@@ -68,6 +68,7 @@
6868
run_params = dict(
6969
data_path=data_path,
7070
slurm_max_job_time=slurm_max_job_time,
71+
df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
7172
**slurm_vars,
7273
)
7374
if wandb.run is None:

models/wrenformer/mp/use_wrenformer_ensemble.py

+11-14
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424

2525
module_dir = os.path.dirname(__file__)
2626
today = f"{datetime.now():%Y-%m-%d}"
27-
ensemble_id = "wrenformer-e_form-ensemble-1"
28-
run_name = f"{today}-{ensemble_id}-IS2RE"
27+
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-summary.csv"
28+
assert "wbm" in data_path
29+
run_name = "wrenformer-wbm-IS2RE"
2930

3031
slurm_submit_python(
3132
job_name=run_name,
@@ -38,7 +39,6 @@
3839

3940

4041
# %%
41-
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-summary.csv"
4242
target_col = "e_form_per_atom_mp2020_corrected"
4343
input_col = "wyckoff_spglib"
4444
df = pd.read_csv(data_path).dropna(subset=input_col).set_index("material_id")
@@ -58,21 +58,18 @@
5858

5959
# %%
6060
wandb.login()
61-
wandb_api = wandb.Api()
62-
runs = wandb_api.runs(
63-
"janosh/matbench-discovery",
64-
filters={
65-
"$and": [{"created_at": {"$gt": "2022-11-10", "$lt": "2022-11-11"}}],
66-
"display_name": "wrenformer-robust-mp-formation_energy_per_atom-epochs=300",
67-
},
68-
)
61+
filters = {
62+
"$and": [{"created_at": {"$gt": "2022-11-10", "$lt": "2022-11-11"}}],
63+
"display_name": "wrenformer-robust-mp-formation_energy_per_atom-epochs=300",
64+
}
65+
runs = wandb.Api().runs("janosh/matbench-discovery", filters=filters)
6966

70-
assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {ensemble_id=}"
67+
assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {filters=}"
7168

7269

7370
# %%
74-
df, ensemble_metrics = predict_from_wandb_checkpoints(
75-
runs, data_loader=data_loader, df=df, model_cls=Wrenformer
71+
df, _ensemble_metrics = predict_from_wandb_checkpoints(
72+
runs, data_loader=data_loader, df=df, model_cls=Wrenformer, target_col=target_col
7673
)
7774

7875
df.round(6).to_csv(f"{module_dir}/{today}-{run_name}-preds.csv")

tests/test_slurm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_slurm_submit(capsys: CaptureFixture[str], py_file_path: str | None) ->
4848

4949
sbatch_cmd = (
5050
f"sbatch --partition={partition} --account={account} --time={time} "
51-
f"--job-name {job_name} --output {log_dir}/slurm-%A-{today}.out --test-flag "
51+
f"--job-name {job_name} --output {log_dir}/{today}-slurm-%A.log --test-flag "
5252
f"--wrap python {py_file_path or __file__}"
5353
).replace(" --", "\n --")
5454
stdout, stderr = capsys.readouterr()

0 commit comments

Comments
 (0)