
Commit 255429a

add wandb.log scatter-parity plot in test_{cgcnn,wrenformer}.py
1 parent 8219c43 commit 255429a

File tree

12 files changed: +169 -83 lines changed

matbench_discovery/slurm.py (+3 -5)

@@ -2,7 +2,6 @@
 import subprocess
 import sys
 from collections.abc import Sequence
-from datetime import datetime

 SLURM_KEYS = (
     "job_id array_task_id array_task_count mem_per_node nodelist submit_host"
@@ -74,11 +73,12 @@ def slurm_submit(
         # before actual job command
         pre_cmd += ". /etc/profile.d/modules.sh; module load rhel8/default-amp;"

-    today = f"{datetime.now():%Y-%m-%d}"
+    os.makedirs(log_dir, exist_ok=True)  # slurm fails if log_dir is missing
+
     cmd = [
         *f"sbatch --{partition=} --{account=} --{time=}".replace("'", "").split(),
         *("--job-name", job_name),
-        *("--output", f"{log_dir}/{today}-slurm-%A{'-%a' if array else ''}.log"),
+        *("--output", f"{log_dir}/slurm-%A{'-%a' if array else ''}.log"),
         *slurm_flags,
         *("--wrap", f"{pre_cmd} python {py_file_path}".strip()),
     ]
@@ -104,8 +104,6 @@
     if "slurm-submit" not in sys.argv:
         return slurm_vars  # if not submitting slurm job, resume outside code as normal

-    os.makedirs(log_dir, exist_ok=True)  # slurm fails if log_dir is missing
-
     result = subprocess.run(cmd, check=True)

     # after sbatch submission, exit with slurm exit code
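
A note on the new log path: sbatch expands %A to the parent job ID and %a to the array task index, so dropping the date prefix still yields one log file per job and task. A minimal sketch of how the scripts in this commit call the helper, with illustrative parameter values and an assumed import path:

from matbench_discovery.slurm import slurm_submit  # assumed import path

# illustrative values; %A/%a in the log name are filled in by sbatch itself
slurm_vars = slurm_submit(
    job_name="demo-job",
    partition="icelake-himem",
    account="LEE-SL3-CPU",
    time="1:0:0",
    array="1-10",
    log_dir="logs",  # created via os.makedirs(log_dir, exist_ok=True) before sbatch runs
    slurm_flags=("--mem", "20G"),
)
# e.g. logs/slurm-12345678-3.log for array task 3 of job 12345678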

models/bowsr/test_bowsr.py (+7 -3)

@@ -33,7 +33,9 @@
 slurm_mem_per_node = 12000
 # set large job array size for fast testing/debugging
 slurm_array_task_count = 1000
-slurm_max_parallel = 100
+# see https://stackoverflow.com/a/55431306 for how to change array throttling
+# post submission
+slurm_max_parallel = 50
 timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
 today = timestamp.split("@")[0]
 energy_model = "megnet"
@@ -89,6 +91,9 @@
     seed=42,
 )
 optimize_kwargs = dict(n_init=100, n_iter=100, alpha=0.026**2)
+slurm_dict = dict(
+    slurm_max_parallel=slurm_max_parallel, slurm_max_job_time=slurm_max_job_time
+)

 run_params = dict(
     bayes_optim_kwargs=bayes_optim_kwargs,
@@ -99,8 +104,7 @@
     energy_model_version=version(energy_model),
     optimize_kwargs=optimize_kwargs,
     task_type=task_type,
-    slurm_max_job_time=slurm_max_job_time,
-    slurm_vars=slurm_vars | dict(slurm_max_parallel=slurm_max_parallel),
+    slurm_vars=slurm_vars | slurm_dict,
 )
 if wandb.run is None:
     wandb.login()
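
The new comment points to a Stack Overflow answer on changing array throttling after submission. For reference, a small sketch of doing the same from Python via scontrol (the job ID below is a placeholder, not one from this commit):

import subprocess

job_id = "12345678"  # placeholder for the array's parent job ID
# raise or lower how many array tasks of an already-submitted job may run concurrently
subprocess.run(["scontrol", "update", f"JobId={job_id}", "ArrayTaskThrottle=50"], check=True)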

models/cgcnn/test_cgcnn.py (+54 -23)

@@ -3,14 +3,14 @@

 import os
 from datetime import datetime
+from importlib.metadata import version

 import pandas as pd
 import wandb
 from aviary.cgcnn.data import CrystalGraphData, collate_batch
 from aviary.cgcnn.model import CrystalGraphConvNet
 from aviary.deploy import predict_from_wandb_checkpoints
 from pymatgen.core import Structure
-from pymatviz import density_scatter
 from torch.utils.data import DataLoader
 from tqdm import tqdm

@@ -29,28 +29,25 @@

 today = f"{datetime.now():%Y-%m-%d}"
 log_dir = f"{os.path.dirname(__file__)}/{today}-test"
-ensemble_id = "cgcnn-e_form-ensemble-1"
-run_name = f"{ensemble_id}-IS2RE"
+job_name = "test-cgcnn-ensemble"

-slurm_submit(
-    job_name=run_name,
+slurm_vars = slurm_submit(
+    job_name=job_name,
     partition="ampere",
     account="LEE-SL3-GPU",
-    time="1:0:0",
+    time=(slurm_max_job_time := "2:0:0"),
     log_dir=log_dir,
     slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
 )


 # %%
-data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
+task_type = "IS2RE"
+if task_type == "IS2RE":
+    data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
+elif task_type == "RS2RE":
+    data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-cses.json.bz2"
 df = pd.read_json(data_path).set_index("material_id", drop=False)
-old_len = len(df)
-no_init_structs = df.query("initial_structure.isnull()").index
-df = df.dropna()  # two missing initial structures
-assert len(df) == old_len - 2
-
-assert all(df.index == df_wbm.drop(index=no_init_structs).index)

 target_col = "e_form_per_atom_mp2020_corrected"
 df[target_col] = df_wbm[target_col]
@@ -60,12 +57,38 @@

 df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col], disable=None)]

+filters = {
+    "$and": [{"created_at": {"$gt": "2022-11-22", "$lt": "2022-11-23"}}],
+    "display_name": {"$regex": "^cgcnn-robust"},
+}
 wandb.login()
-runs = wandb.Api().runs(
-    "janosh/matbench-discovery", filters={"tags": {"$in": [ensemble_id]}}
+runs = wandb.Api().runs("janosh/matbench-discovery", filters=filters)
+
+assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {filters=}"
+for idx, run in enumerate(runs):
+    for key, val in run.config.items():
+        if val == runs[0][key] or key.startswith(("slurm_", "timestamp")):
+            continue
+        raise ValueError(
+            f"Configs not identical: runs[{idx}][{key}]={val}, {runs[0][key]=}"
+        )
+
+run_params = dict(
+    data_path=data_path,
+    df=dict(shape=str(df.shape), columns=", ".join(df)),
+    aviary_version=version("aviary"),
+    ensemble_size=len(runs),
+    task_type=task_type,
+    target_col=target_col,
+    input_col=input_col,
+    filters=filters,
+    slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time),
 )

-assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {ensemble_id=}"
+slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
+wandb.init(
+    project="matbench-discovery", name=f"{job_name}-{slurm_job_id}", config=run_params
+)

 cg_data = CrystalGraphData(
     df, task_dict={target_col: "regression"}, structure_col=input_col
@@ -82,14 +105,22 @@
     data_loader=data_loader,
 )

-df.round(6).to_csv(f"{log_dir}/{today}-{run_name}-preds.csv", index=False)
+df.to_csv(f"{log_dir}/{today}-{job_name}-preds.csv", index=False)
+table = wandb.Table(dataframe=df)


 # %%
-print(f"{runs[0].url=}")
-ax = density_scatter(
-    df=df.query("e_form_per_atom_mp2020_corrected < 10"),
-    x="e_form_per_atom_mp2020_corrected",
-    y="e_form_per_atom_mp2020_corrected_pred_1",
+pred_col = f"{target_col}_pred_ens"
+MAE = ensemble_metrics["MAE"]
+R2 = ensemble_metrics["R2"]
+
+title = rf"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
+print(title)
+
+scatter_plot = wandb.plot_table(
+    vega_spec_name="janosh/scatter-parity",
+    data_table=table,
+    fields=dict(x=target_col, y=pred_col, title=title),
 )
-# ax.figure.savefig(f"{ROOT}/tmp/{today}-{run_name}-scatter-preds.png", dpi=300)
+
+wandb.log({"true_pred_scatter": scatter_plot})
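
The pattern this commit adds to the test scripts is: wrap the predictions dataframe in a wandb.Table and render it with the custom janosh/scatter-parity Vega spec. A self-contained toy sketch of the same logging call, using random data and assuming that Vega spec is available to the logged-in wandb account:

import numpy as np
import pandas as pd
import wandb

# toy true vs. predicted values standing in for the WBM formation energies
df = pd.DataFrame({"e_form_true": np.random.normal(0, 1, 100)})
df["e_form_pred"] = df.e_form_true + np.random.normal(0, 0.1, 100)

wandb.init(project="matbench-discovery", name="scatter-parity-demo")
table = wandb.Table(dataframe=df)
scatter_plot = wandb.plot_table(
    vega_spec_name="janosh/scatter-parity",  # custom Vega spec referenced in the diff
    data_table=table,
    fields=dict(x="e_form_true", y="e_form_pred", title="toy parity plot"),
)
wandb.log({"true_pred_scatter": scatter_plot})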

models/cgcnn/train_cgcnn.py (+6 -2)

@@ -25,15 +25,15 @@
 # %%
 epochs = 300
 target_col = "formation_energy_per_atom"
-run_name = f"cgcnn-robust-{target_col}"
+run_name = f"train-cgcnn-robust-{target_col}"
 print(f"{run_name=}")
 robust = "robust" in run_name.lower()
 n_ens = 10
 timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
 today = timestamp.split("@")[0]
 log_dir = f"{os.path.dirname(__file__)}/{today}-{run_name}"

-slurm_submit(
+slurm_vars = slurm_submit(
     job_name=run_name,
     partition="ampere",
     account="LEE-SL3-GPU",
@@ -63,11 +63,13 @@

 train_df, test_df = df_train_test_split(df, test_size=0.05)

+print(f"{train_df.shape=}")
 train_data = CrystalGraphData(train_df, task_dict={target_col: task_type})
 train_loader = DataLoader(
     train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch
 )

+print(f"{test_df.shape=}")
 test_data = CrystalGraphData(test_df, task_dict={target_col: task_type})
 test_loader = DataLoader(
     test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_batch
@@ -90,6 +92,7 @@
     batch_size=batch_size,
     train_df=dict(shape=str(train_data.df.shape), columns=", ".join(train_df)),
     test_df=dict(shape=str(test_data.df.shape), columns=", ".join(test_df)),
+    slurm_vars=slurm_vars,
 )


@@ -111,4 +114,5 @@
     timestamp=timestamp,
     train_loader=train_loader,
     wandb_path="janosh/matbench-discovery",
+    run_params=run_params,
 )

models/m3gnet/test_m3gnet.py (+1 -2)

@@ -77,9 +77,8 @@
     data_path=data_path,
     m3gnet_version=version("m3gnet"),
     task_type=task_type,
-    slurm_max_job_time=slurm_max_job_time,
     df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
-    slurm_vars=slurm_vars,
+    slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time),
 )
 if wandb.run is None:
     wandb.login()
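
The slurm_vars | dict(...) construct repeated across these files is Python 3.9+ dict union (PEP 584): keys from the right operand win on collision and neither input dict is mutated. A tiny example:

defaults = {"mem_per_node": "12G", "slurm_max_job_time": "1:0:0"}
extra = {"slurm_max_job_time": "2:0:0"}
merged = defaults | extra
print(merged)  # {'mem_per_node': '12G', 'slurm_max_job_time': '2:0:0'}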

models/megnet/test_megnet.py (+15 -14)

@@ -8,6 +8,7 @@
 import pandas as pd
 import wandb
 from megnet.utils.models import load_model
+from sklearn.metrics import r2_score
 from tqdm import tqdm

 from matbench_discovery import ROOT
@@ -54,8 +55,10 @@
 # %%
 data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
 print(f"{data_path=}")
-df_wbm_structs = pd.read_json(data_path).set_index("material_id")
+target_col = "e_form_per_atom_mp2020_corrected"
+assert target_col in df_wbm, f"{target_col=} not in {list(df_wbm)=}"

+df_wbm_structs = pd.read_json(data_path).set_index("material_id")
 megnet_mp_e_form = load_model(model_name := "Eform_MP_2019")


@@ -65,9 +68,9 @@
     megnet_version=version("megnet"),
     model_name=model_name,
     task_type=task_type,
-    slurm_max_job_time=slurm_max_job_time,
+    target_col=target_col,
     df=dict(shape=str(df_wbm_structs.shape), columns=", ".join(df_wbm_structs)),
-    slurm_vars=slurm_vars,
+    slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time),
 )
 if wandb.run is None:
     wandb.login()
@@ -105,26 +108,24 @@
 print(f"{len(megnet_e_form_preds)=:,}")
 print(f"{len(structures)=:,}")
 print(f"missing: {len(structures) - len(megnet_e_form_preds):,}")
-out_col = "e_form_per_atom_megnet"
-df_wbm[out_col] = pd.Series(megnet_e_form_preds)
+pred_col = "e_form_per_atom_megnet"
+df_wbm[pred_col] = pd.Series(megnet_e_form_preds)

-df_wbm[out_col].reset_index().to_csv(out_path, index=False)
+df_wbm[pred_col].reset_index().to_csv(out_path, index=False)


 # %%
-fields = {"x": "e_form_per_atom_mp2020_corrected", "y": out_col}
-cols = list(fields.values())
-assert all(col in df_wbm for col in cols)
-
-table = wandb.Table(dataframe=df_wbm[cols].reset_index())
+table = wandb.Table(dataframe=df_wbm[[target_col, pred_col]].reset_index())

-MAE = (df_wbm[fields["x"]] - df_wbm[fields["y"]]).abs().mean()
+MAE = (df_wbm[target_col] - df_wbm[pred_col]).abs().mean()
+R2 = r2_score(df_wbm[target_col], df_wbm[pred_col])
+title = f"{model_name} {task_type} {MAE=:.4} {R2=:.4}"
+print(title)

 scatter_plot = wandb.plot_table(
     vega_spec_name="janosh/scatter-parity",
     data_table=table,
-    fields=fields,
-    string_fields={"title": f"{model_name} {task_type} {MAE=:.4}"},
+    fields=dict(x=target_col, y=pred_col, title=title),
 )

 wandb.log({"true_pred_scatter": scatter_plot})

models/voronoi/__init__.py (+5)

@@ -19,6 +19,11 @@
 ]
 featurizer = MultipleFeaturizer(featurizers)

+
 # multiprocessing seems to be the cause of OOM errors on large structures even when
 # taking only small slice of the data and launching slurm jobs with --mem 100G
+# Alex Dunn has been aware of this problem for a while. presumed cause: chunk of data
+# (eg 50 structures) is sent to a single process, but sometimes one of those structures
+# might be huge causing that process to stall. Other processes in pool can't synchronize
+# at the end, effectively freezing the job
 featurizer.set_n_jobs(1)
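
The expanded comment explains why featurization is forced onto a single process. A minimal sketch of that serial matminer workflow, with a hypothetical featurizer choice and toy structure rather than the project's actual feature set:

import pandas as pd
from matminer.featurizers.base import MultipleFeaturizer
from matminer.featurizers.structure import DensityFeatures
from pymatgen.core import Lattice, Structure

toy_struct = Structure(Lattice.cubic(4.0), ["Na", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])
df = pd.DataFrame({"structure": [toy_struct]})

featurizer = MultipleFeaturizer([DensityFeatures()])
featurizer.set_n_jobs(1)  # single process to avoid the multiprocessing stalls/OOMs described above
df_feats = featurizer.featurize_dataframe(df, "structure", ignore_errors=True)
print(df_feats[featurizer.feature_labels()])  # keep only the generated feature columns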

models/voronoi/readme.md (+2 -2)

@@ -1,4 +1,4 @@
-# Voronoi Tessellation with matminer featurezation piped into `scikit-learn` Random Forest
+# Voronoi Tessellation with `matminer` featurization piped into `scikit-learn` `RandomForestRegressor`

 ## OOM errors during featurization

@@ -14,4 +14,4 @@ Saving tip came from [Alex Dunn via Slack](https://berkeleytheory.slack.com/arch

 ## Archive

-Files in `2022-10-04-rhys-voronoi.zip` received from Rhys via [Slack](https://ml-physics.slack.com/archives/DD8GBBRLN/p1664929946687049). All originals before making any changes for this project.
+Files in `2022-10-04-rhys-voronoi.zip` received from Rhys via [Slack](https://ml-physics.slack.com/archives/DD8GBBRLN/p1664929946687049). They are unchanged originals.

models/voronoi/voronoi_featurize_dataset.py (+8 -8)

@@ -15,7 +15,6 @@

 today = f"{datetime.now():%Y-%m-%d}"
 module_dir = os.path.dirname(__file__)
-assert featurizer._n_jobs == 1, "set n_jobs=1 to avoid OOM errors"

 data_name = "mp"  # "mp"
 if data_name == "wbm":
@@ -25,17 +24,17 @@
     data_path = f"{ROOT}/data/mp/2022-09-16-mp-computed-structure-entries.json.gz"
     input_col = "structure"

-slurm_array_task_count = 10
+slurm_array_task_count = 30
 job_name = f"voronoi-features-{data_name}"
 log_dir = f"{module_dir}/{today}-{job_name}"

 slurm_vars = slurm_submit(
     job_name=job_name,
     partition="icelake-himem",
     account="LEE-SL3-CPU",
-    time=(slurm_max_job_time := "8:0:0"),
+    time=(slurm_max_job_time := "12:0:0"),
     array=f"1-{slurm_array_task_count}",
-    slurm_flags=("--mem", "30G") if data_name == "mp" else (),
+    slurm_flags=("--mem", "20G") if data_name == "mp" else (),
     log_dir=log_dir,
 )

@@ -66,10 +65,9 @@
 # %%
 run_params = dict(
     data_path=data_path,
-    slurm_max_job_time=slurm_max_job_time,
     df=dict(shape=str(df_this_job.shape), columns=", ".join(df_this_job)),
     input_col=input_col,
-    slurm_vars=slurm_vars,
+    slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time),
 )
 if wandb.run is None:
     wandb.login()
@@ -88,10 +86,12 @@

 df_features = featurizer.featurize_dataframe(
     df_this_job, input_col, ignore_errors=True, pbar=dict(position=0, leave=True)
-).drop(columns=input_col)
+)


 # %%
-df_features.to_csv(out_path, default_handler=as_dict_handler)
+df_features[featurizer.feature_labels()].to_csv(
+    out_path, default_handler=as_dict_handler
+)

 wandb.log({"voronoi_features": wandb.Table(dataframe=df_features)})

0 commit comments
