Skip to content

Commit a2e3f46

Browse files
committed
add models/cgcnn/{slurm_train_cgcnn_ensemble,use_cgcnn_ensemble}.py
1 parent bf36e42 commit a2e3f46

6 files changed

+242
-43
lines changed

models/bowsr/slurm_array_bowsr_wbm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@
5757
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
5858
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
5959
out_path = f"{out_dir}/{slurm_array_task_id}.json.gz"
60+
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
6061

61-
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
62+
print(f"Job started running {timestamp}")
6263
print(f"{slurm_job_id = }")
6364
print(f"{slurm_array_task_id = }")
6465
print(f"{data_path = }")
+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# %%
2+
import os
3+
from datetime import datetime
4+
5+
import pandas as pd
6+
from aviary import ROOT
7+
from aviary.cgcnn.data import CrystalGraphData, collate_batch
8+
from aviary.cgcnn.model import CrystalGraphConvNet
9+
from aviary.core import TaskType
10+
from aviary.train import df_train_test_split, train_model
11+
from pymatgen.core import Structure
12+
from torch.utils.data import DataLoader
13+
from tqdm import tqdm
14+
15+
from matbench_discovery.slurm import slurm_submit_python
16+
17+
"""
18+
Train a CGCNN ensemble of size n_folds on target_col of data_path.
19+
"""
20+
21+
__author__ = "Janosh Riebesell"
22+
__date__ = "2022-06-13"
23+
24+
25+
# %%
26+
epochs = 300
27+
target_col = "formation_energy_per_atom"
28+
run_name = f"cgcnn-robust-{epochs=}-{target_col}"
29+
print(f"{run_name=}")
30+
robust = "robust" in run_name.lower()
31+
n_folds = 10
32+
today = f"{datetime.now():%Y-%m-%d}"
33+
log_dir = f"{os.path.dirname(__file__)}/{today}-{run_name}"
34+
35+
slurm_submit_python(
36+
job_name=run_name,
37+
partition="ampere",
38+
account="LEE-SL3-GPU",
39+
time="8:0:0",
40+
array=f"1-{n_folds}",
41+
log_dir=log_dir,
42+
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
43+
# prepend into sbatch script to source module command and load default env
44+
# for Ampere GPU partition before actual job command
45+
pre_cmd=". /etc/profile.d/modules.sh; module load rhel8/default-amp;",
46+
)
47+
48+
49+
# %%
50+
optimizer = "AdamW"
51+
learning_rate = 3e-4
52+
batch_size = 128
53+
swa_start = None
54+
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
55+
task_type: TaskType = "regression"
56+
57+
58+
# %%
59+
data_path = f"{ROOT}/datasets/2022-08-13-mp-energies.json.gz"
60+
# data_path = f"{ROOT}/datasets/2022-08-13-mp-energies-1k-samples.json.gz"
61+
print(f"{data_path=}")
62+
df = pd.read_json(data_path).set_index("material_id", drop=False)
63+
df["structure"] = [Structure.from_dict(s) for s in tqdm(df.structure, disable=None)]
64+
assert target_col in df
65+
66+
train_df, test_df = df_train_test_split(df, test_size=0.5)
67+
68+
train_data = CrystalGraphData(train_df, task_dict={target_col: task_type})
69+
train_loader = DataLoader(
70+
train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch
71+
)
72+
73+
test_data = CrystalGraphData(test_df, task_dict={target_col: task_type})
74+
test_loader = DataLoader(
75+
test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_batch
76+
)
77+
78+
# 1 for regression, n_classes for classification
79+
n_targets = [1 if task_type == "regression" else df[target_col].max() + 1]
80+
81+
model_params = dict(
82+
n_targets=n_targets,
83+
elem_emb_len=train_data.elem_emb_len,
84+
nbr_fea_len=train_data.nbr_fea_dim,
85+
task_dict={target_col: task_type}, # e.g. {'exfoliation_en': 'regression'}
86+
robust=robust,
87+
)
88+
model = CrystalGraphConvNet(**model_params)
89+
90+
run_params = dict(
91+
batch_size=batch_size,
92+
train_df=dict(shape=train_data.df.shape, columns=", ".join(train_df)),
93+
test_df=dict(shape=test_data.df.shape, columns=", ".join(test_df)),
94+
)
95+
96+
97+
# %%
98+
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
99+
print(f"Job started running {timestamp}")
100+
101+
train_model(
102+
checkpoint="wandb", # None | 'local' | 'wandb',
103+
epochs=epochs,
104+
learning_rate=learning_rate,
105+
model_params=model_params,
106+
model=model,
107+
optimizer=optimizer,
108+
run_name=run_name,
109+
swa_start=swa_start,
110+
target_col=target_col,
111+
task_type=task_type,
112+
test_loader=test_loader,
113+
timestamp=timestamp,
114+
train_loader=train_loader,
115+
wandb_path="janosh/matbench-discovery",
116+
)

models/cgcnn/use_cgcnn_ensemble.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# %%
2+
from __future__ import annotations
3+
4+
import os
5+
from datetime import datetime
6+
7+
import pandas as pd
8+
import wandb
9+
from aviary.cgcnn.data import CrystalGraphData, collate_batch
10+
from aviary.cgcnn.model import CrystalGraphConvNet
11+
from aviary.deploy import predict_from_wandb_checkpoints
12+
from pymatgen.core import Structure
13+
from torch.utils.data import DataLoader
14+
from tqdm import tqdm
15+
16+
from matbench_discovery import ROOT
17+
from matbench_discovery.plot_scripts import df_wbm
18+
19+
__author__ = "Janosh Riebesell"
20+
__date__ = "2022-08-15"
21+
22+
"""
23+
Script that downloads checkpoints for an ensemble of CGCNN models trained on the MP
24+
formation energies, then makes predictions on some dataset, prints ensemble metrics and
25+
stores predictions to CSV.
26+
"""
27+
28+
module_dir = os.path.dirname(__file__)
29+
today = f"{datetime.now():%Y-%m-%d}"
30+
31+
32+
# %%
33+
data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
34+
df = pd.read_json(data_path).set_index("material_id", drop=False)
35+
old_len = len(df)
36+
df = df.dropna() # two missing initial structures
37+
assert len(df) == old_len - 2
38+
39+
df["e_form_per_atom_mp2020_corrected"] = df_wbm.e_form_per_atom_mp2020_corrected
40+
41+
target_col = "e_form_per_atom_mp2020_corrected"
42+
input_col = "initial_structure"
43+
assert target_col in df, f"{target_col=} not in {list(df)}"
44+
assert input_col in df, f"{input_col=} not in {list(df)}"
45+
46+
df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col])]
47+
48+
wandb.login()
49+
wandb_api = wandb.Api()
50+
ensemble_id = "cgcnn-e_form-ensemble-1"
51+
runs = wandb_api.runs(
52+
"janosh/matbench-discovery", filters={"tags": {"$in": [ensemble_id]}}
53+
)
54+
55+
assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {ensemble_id=}"
56+
57+
cg_data = CrystalGraphData(
58+
df, task_dict={target_col: "regression"}, structure_col=input_col
59+
)
60+
data_loader = DataLoader(
61+
cg_data, batch_size=1024, shuffle=False, collate_fn=collate_batch
62+
)
63+
df, ensemble_metrics = predict_from_wandb_checkpoints(
64+
runs,
65+
df=df,
66+
target_col=target_col,
67+
model_class=CrystalGraphConvNet,
68+
data_loader=data_loader,
69+
)
70+
71+
df.round(6).to_csv(f"{module_dir}/{today}-{ensemble_id}-preds-{target_col}.csv")

models/m3gnet/slurm_array_m3gnet_wbm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@
5151
# %%
5252
slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
5353
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
54+
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
5455

55-
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
56+
print(f"Job started running {timestamp}")
5657
print(f"{slurm_job_id = }")
5758
print(f"{slurm_array_task_id = }")
5859
print(f"{version('m3gnet') = }")

models/wrenformer/mp/use_ensemble.py models/wrenformer/mp/use_wrenformer_ensemble.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
import pandas as pd
88
import wandb
9-
from aviary.wrenformer.deploy import deploy_wandb_checkpoints
9+
from aviary.deploy import predict_from_wandb_checkpoints
10+
from aviary.wrenformer.data import df_to_in_mem_dataloader
11+
from aviary.wrenformer.model import Wrenformer
1012

1113
__author__ = "Janosh Riebesell"
1214
__date__ = "2022-08-15"
@@ -26,8 +28,10 @@
2628
data_path = "https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
2729
df = pd.read_csv(data_path).set_index("material_id")
2830

29-
3031
target_col = "e_form_per_atom"
32+
input_col = "wyckoff"
33+
assert target_col in df, f"{target_col=} not in {list(df)}"
34+
assert input_col in df, f"{input_col=} not in {list(df)}"
3135

3236
wandb.login()
3337
wandb_api = wandb.Api()
@@ -38,8 +42,17 @@
3842

3943
assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {ensemble_id=}"
4044

41-
df, ensemble_metrics = deploy_wandb_checkpoints(
42-
runs, df, input_col="wyckoff", target_col=target_col
45+
data_loader = df_to_in_mem_dataloader(
46+
df=df,
47+
target_col=target_col,
48+
batch_size=1024,
49+
input_col=input_col,
50+
embedding_type="wyckoff",
51+
shuffle=False, # False is default but best be explicit
52+
)
53+
54+
df, ensemble_metrics = predict_from_wandb_checkpoints(
55+
runs, data_loader, df=df, model_class=Wrenformer
4356
)
4457

4558
df.round(6).to_csv(f"{module_dir}/{today}-{ensemble_id}-preds-{target_col}.csv")

models/wrenformer/slurm_array_wrenformer.py models/wrenformer/slurm_train_wrenformer_ensemble.py

+34-37
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,37 @@
22
import os
33
from datetime import datetime
44

5-
from aviary.wrenformer.train import train_wrenformer_on_df
5+
import pandas as pd
6+
from aviary import ROOT
7+
from aviary.train import df_train_test_split, train_wrenformer
68

7-
from matbench_discovery import ROOT
89
from matbench_discovery.slurm import slurm_submit_python
910

1011
"""
11-
Train a Wrenformer
12-
ensemble of size n_folds on target_col of df_or_path.
12+
Train a Wrenformer ensemble of size n_folds on target_col of data_path.
1313
"""
1414

1515
__author__ = "Janosh Riebesell"
1616
__date__ = "2022-08-13"
1717

1818

1919
# %%
20-
df_or_path = f"{ROOT}/data/2022-08-13-mp-energies.json.gz"
21-
target_col = "energy_per_atom"
22-
# df_or_path = f"{ROOT}/data/2022-08-25-m3gnet-trainset-mp-2021-struct-energy.json.gz"
23-
# target_col = "mp_energy_per_atom"
24-
2520
epochs = 300
26-
job_name = f"wrenformer-robust-{epochs=}-{target_col}"
21+
target_col = "e_form"
22+
run_name = f"wrenformer-robust-mp+wbm-{epochs=}-{target_col}"
2723
n_folds = 10
2824
today = f"{datetime.now():%Y-%m-%d}"
2925
dataset = "mp"
30-
# dataset = 'm3gnet_train_set'
31-
log_dir = f"{os.path.dirname(__file__)}/{dataset}/{today}-{job_name}"
26+
log_dir = f"{os.path.dirname(__file__)}/{dataset}/{today}-{run_name}"
3227

3328
slurm_submit_python(
34-
job_name=job_name,
29+
job_name=run_name,
3530
partition="ampere",
31+
account="LEE-SL3-GPU",
3632
time="8:0:0",
3733
array=f"1-{n_folds}",
3834
log_dir=log_dir,
39-
account="LEE-SL3-GPU",
40-
slurm_flags=("--nodes 1", "--gpus-per-node 1"),
35+
slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
4136
# prepend into sbatch script to source module command and load default env
4237
# for Ampere GPU partition before actual job command
4338
pre_cmd=". /etc/profile.d/modules.sh; module load rhel8/default-amp;",
@@ -48,39 +43,41 @@
4843

4944

5045
# %%
51-
n_attn_layers = 3
52-
embedding_aggregations = ("mean",)
53-
optimizer = "AdamW"
5446
learning_rate = 3e-4
55-
task_type = "regression"
56-
checkpoint = "wandb" # None | 'local' | 'wandb'
47+
data_path = f"{ROOT}/data/2022-08-13-mp-energies.json.gz"
48+
target_col = "energy_per_atom"
49+
# data_path = f"{ROOT}/data/2022-08-25-m3gnet-trainset-mp-2021-struct-energy.json.gz"
50+
# target_col = "mp_energy_per_atom"
5751
batch_size = 128
58-
swa_start = None
52+
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
5953
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
6054

61-
print(f"Job started running {datetime.now():%Y-%m-%d@%H-%M}")
62-
slurm_job_id = os.environ.get("SLURM_JOB_ID")
63-
slurm_array_task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0))
55+
print(f"Job started running {timestamp}")
56+
print(f"{run_name=}")
57+
print(f"{data_path=}")
6458

65-
print(f"{slurm_job_id=}")
66-
print(f"{slurm_array_task_id=}")
67-
print(f"{job_name=}")
68-
print(f"{df_or_path=}")
59+
df = pd.read_json(data_path).set_index("material_id", drop=False)
60+
assert target_col in df
61+
train_df, test_df = df_train_test_split(df, test_size=0.3)
62+
63+
run_params = dict(
64+
batch_size=batch_size,
65+
train_df=dict(shape=train_df.shape, columns=", ".join(train_df)),
66+
test_df=dict(shape=test_df.shape, columns=", ".join(test_df)),
67+
)
6968

70-
train_wrenformer_on_df(
71-
run_name=job_name,
69+
train_wrenformer(
70+
run_name=run_name,
71+
train_df=train_df,
72+
test_df=test_df,
7273
target_col=target_col,
73-
df_or_path=df_or_path,
74+
task_type="regression",
7475
timestamp=timestamp,
75-
test_size=0.05,
7676
# folds=(n_folds, slurm_array_task_id),
7777
epochs=epochs,
78-
n_attn_layers=n_attn_layers,
79-
checkpoint=checkpoint,
80-
optimizer=optimizer,
78+
checkpoint="wandb", # None | 'local' | 'wandb',
8179
learning_rate=learning_rate,
82-
embedding_aggregations=embedding_aggregations,
8380
batch_size=batch_size,
84-
swa_start=swa_start,
8581
wandb_path="janosh/matbench-discovery",
82+
run_params=run_params,
8683
)

0 commit comments

Comments
 (0)