|
1 | 1 | # %%
|
2 | 2 | import os
|
3 |
| -from datetime import datetime |
4 | 3 | from importlib.metadata import version
|
5 | 4 |
|
6 | 5 | import pandas as pd
|
7 | 6 | from aviary.train import df_train_test_split, train_wrenformer
|
8 | 7 |
|
9 |
| -from matbench_discovery import ROOT |
| 8 | +from matbench_discovery import DEBUG, ROOT, timestamp, today |
10 | 9 | from matbench_discovery.slurm import slurm_submit
|
11 | 10 |
|
12 | 11 | """
|
13 |
| -Train a Wrenformer ensemble of size n_ens on target_col of data_path. |
| 12 | +Train a Wrenformer ensemble on target_col of data_path. |
14 | 13 | """
|
15 | 14 |
|
16 | 15 | __author__ = "Janosh Riebesell"
|
|
25 | 24 | # data_path = f"{ROOT}/data/2022-08-25-m3gnet-trainset-mp-2021-struct-energy.json.gz"
|
26 | 25 | # target_col = "mp_energy_per_atom"
|
27 | 26 | data_name = "m3gnet-trainset" if "m3gnet" in data_path else "mp"
|
28 |
| -job_name = f"train-wrenformer-robust-{data_name}" |
29 |
| -n_ens = 10 |
30 |
| -timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}" |
31 |
| -today = timestamp.split("@")[0] |
| 27 | +job_name = f"train-wrenformer-robust-{data_name}{'-debug' if DEBUG else ''}" |
| 28 | +ensemble_size = 10 |
32 | 29 | dataset = "mp"
|
33 | 30 | module_dir = os.path.dirname(__file__)
|
34 | 31 | out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}")
|
|
39 | 36 | partition="ampere",
|
40 | 37 | account="LEE-SL3-GPU",
|
41 | 38 | time="8:0:0",
|
42 |
| - array=f"1-{n_ens}", |
| 39 | + array=f"1-{ensemble_size}", |
43 | 40 | out_dir=out_dir,
|
44 | 41 | slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
|
45 | 42 | )
|
|
70 | 67 | slurm_vars=slurm_vars,
|
71 | 68 | )
|
72 | 69 |
|
73 |
| -slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug") |
74 | 70 | train_wrenformer(
|
75 |
| - run_name=f"{job_name}-{slurm_job_id}-{slurm_array_task_id}", |
| 71 | + run_name=f"{job_name}-{slurm_array_task_id}", |
76 | 72 | train_df=train_df,
|
77 | 73 | test_df=test_df,
|
78 | 74 | target_col=target_col,
|
79 | 75 | task_type="regression",
|
80 | 76 | timestamp=timestamp,
|
81 |
| - # folds=(n_ens, slurm_array_task_id), |
| 77 | + # folds=(ensemble_size, slurm_array_task_id), |
82 | 78 | epochs=epochs,
|
83 | 79 | checkpoint="wandb", # None | 'local' | 'wandb',
|
84 | 80 | input_col=input_col,
|
|
0 commit comments