# %%
import os
from datetime import datetime

import pandas as pd
from aviary import ROOT
from aviary.cgcnn.data import CrystalGraphData, collate_batch
from aviary.cgcnn.model import CrystalGraphConvNet
from aviary.core import TaskType
from aviary.train import df_train_test_split, train_model
from pymatgen.core import Structure
from torch.utils.data import DataLoader
from tqdm import tqdm

from matbench_discovery.slurm import slurm_submit_python

"""
Train a CGCNN ensemble of size n_folds on target_col of data_path.
"""

__author__ = "Janosh Riebesell"
__date__ = "2022-06-13"


# %%
epochs = 300
target_col = "formation_energy_per_atom"
run_name = f"cgcnn-robust-{epochs=}-{target_col}"
print(f"{run_name=}")
# the robust flag passed to the model is derived from the run name
robust = "robust" in run_name.lower()
# one SLURM array task is requested per ensemble member
n_folds = 10
today = f"{datetime.now():%Y-%m-%d}"
# logs are written to a dated directory next to this script
log_dir = f"{os.path.dirname(__file__)}/{today}-{run_name}"

# presumably submits this script as a SLURM array job when invoked for
# submission and is a no-op inside the job -- verify against
# matbench_discovery.slurm
slurm_submit_python(
    job_name=run_name,
    partition="ampere",
    account="LEE-SL3-GPU",
    time="8:0:0",
    array=f"1-{n_folds}",
    log_dir=log_dir,
    slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
    # prepend into sbatch script to source module command and load default env
    # for Ampere GPU partition before actual job command
    pre_cmd=". /etc/profile.d/modules.sh; module load rhel8/default-amp;",
)


# %%
optimizer = "AdamW"
learning_rate = 3e-4
batch_size = 128
swa_start = None
# falls back to 0 when not running inside a SLURM array job
# NOTE(review): this value is not used anywhere else in the visible script
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
task_type: TaskType = "regression"


# %%
data_path = f"{ROOT}/datasets/2022-08-13-mp-energies.json.gz"
# data_path = f"{ROOT}/datasets/2022-08-13-mp-energies-1k-samples.json.gz"
print(f"{data_path=}")
# keep material_id as a regular column in addition to using it as the index
df = pd.read_json(data_path).set_index("material_id", drop=False)
# deserialize the structure dicts into pymatgen Structure objects
df["structure"] = [Structure.from_dict(s) for s in tqdm(df.structure, disable=None)]
# fail early (also under `python -O`, which strips asserts) if the target is missing
if target_col not in df:
    raise KeyError(f"{target_col=} not found in columns of {data_path=}")

# NOTE(review): test_size=0.5 holds out half of the data -- confirm this is intended
train_df, test_df = df_train_test_split(df, test_size=0.5)

train_data = CrystalGraphData(train_df, task_dict={target_col: task_type})
train_loader = DataLoader(
    train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch
)

test_data = CrystalGraphData(test_df, task_dict={target_col: task_type})
test_loader = DataLoader(
    test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_batch
)

# 1 for regression, n_classes for classification
n_targets = [1 if task_type == "regression" else df[target_col].max() + 1]

model_params = dict(
    n_targets=n_targets,
    elem_emb_len=train_data.elem_emb_len,
    nbr_fea_len=train_data.nbr_fea_dim,
    task_dict={target_col: task_type},  # e.g. {'exfoliation_en': 'regression'}
    robust=robust,
)
model = CrystalGraphConvNet(**model_params)

# summary of the run configuration
# NOTE(review): run_params is not passed to train_model below -- confirm whether
# it should be logged
run_params = dict(
    batch_size=batch_size,
    train_df=dict(shape=train_data.df.shape, columns=", ".join(train_df)),
    test_df=dict(shape=test_data.df.shape, columns=", ".join(test_df)),
)


# %%
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
print(f"Job started running {timestamp}")

train_model(
    checkpoint="wandb",  # None | 'local' | 'wandb',
    epochs=epochs,
    learning_rate=learning_rate,
    model_params=model_params,
    model=model,
    optimizer=optimizer,
    run_name=run_name,
    swa_start=swa_start,
    target_col=target_col,
    task_type=task_type,
    test_loader=test_loader,
    timestamp=timestamp,
    train_loader=train_loader,
    wandb_path="janosh/matbench-discovery",
)