Skip to content

Commit 4f1e5b6

Browse files
committed
matbench_discovery/__init__.py: define CHECKPOINT_DIR = f"{ROOT}/wandb/checkpoints"
test_cgcnn: remove reset_index from cg_data.df.reset_index(drop=True)
1 parent 6030ef3 commit 4f1e5b6

File tree

4 files changed

+17
-14
lines changed

4 files changed

+17
-14
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,5 @@ models/**/*.csv
2323

2424
# temporary ignore rules
2525
paper
26-
meeting-notes
2726
models/voronoi/*.zip
27+
site

matbench_discovery/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
ROOT = os.path.dirname(os.path.dirname(__file__))
1010
DEBUG = "slurm-submit" not in sys.argv and "SLURM_JOB_ID" not in os.environ
11+
CHECKPOINT_DIR = f"{ROOT}/wandb/checkpoints"
1112

1213
timestamp = f"{datetime.now():%Y-%m-%d@%H-%M-%S}"
1314
today = timestamp.split("@")[0]

models/cgcnn/test_cgcnn.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torch.utils.data import DataLoader
1515
from tqdm import tqdm
1616

17-
from matbench_discovery import DEBUG, ROOT, today
17+
from matbench_discovery import CHECKPOINT_DIR, DEBUG, ROOT, today
1818
from matbench_discovery.load_preds import df_wbm
1919
from matbench_discovery.plots import wandb_scatter
2020
from matbench_discovery.slurm import slurm_submit
@@ -23,9 +23,9 @@
2323
__date__ = "2022-08-15"
2424

2525
"""
26-
Script that downloads checkpoints for an ensemble of Wrenformer models trained on the MP
26+
Script that downloads checkpoints for an ensemble of CGCNN models trained on all MP
2727
formation energies, then makes predictions on some dataset, prints ensemble metrics and
28-
stores predictions to CSV.
28+
saves predictions to CSV.
2929
"""
3030

3131
task_type = "RS2RE"
@@ -54,7 +54,7 @@
5454
else:
5555
raise ValueError(f"Unexpected {task_type=}")
5656

57-
df = pd.read_json(data_path).set_index("material_id", drop=False)
57+
df = pd.read_json(data_path).set_index("material_id")
5858

5959
target_col = "e_form_per_atom_mp2020_corrected"
6060
df[target_col] = df_wbm[target_col]
@@ -88,7 +88,7 @@
8888
task_type=task_type,
8989
target_col=target_col,
9090
input_col=input_col,
91-
filters=filters,
91+
wandb_run_filters=filters,
9292
slurm_vars=slurm_vars,
9393
)
9494

@@ -99,15 +99,16 @@
9999
df,
100100
task_dict={target_col: "regression"},
101101
structure_col=input_col,
102-
identifiers=("material_id", "formula_from_cse"),
102+
identifiers=["formula_from_cse"],
103103
)
104104
data_loader = DataLoader(
105105
cg_data, batch_size=1024, shuffle=False, collate_fn=collate_batch
106106
)
107107
df, ensemble_metrics = predict_from_wandb_checkpoints(
108108
runs,
109109
# dropping isolated-atom structs means len(cg_data.df) < len(df)
110-
df=cg_data.df.reset_index(drop=True).drop(columns=input_col),
110+
cache_dir=CHECKPOINT_DIR,
111+
df=cg_data.df.drop(columns=input_col),
111112
target_col=target_col,
112113
model_cls=CrystalGraphConvNet,
113114
data_loader=data_loader,
@@ -122,6 +123,6 @@
122123
MAE = ensemble_metrics.MAE.mean()
123124
R2 = ensemble_metrics.R2.mean()
124125

125-
title = rf"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
126+
title = f"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
126127

127128
wandb_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)

models/wrenformer/test_wrenformer.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@
1111
from aviary.wrenformer.data import df_to_in_mem_dataloader
1212
from aviary.wrenformer.model import Wrenformer
1313

14-
from matbench_discovery import DEBUG, ROOT, today
14+
from matbench_discovery import CHECKPOINT_DIR, DEBUG, ROOT, today
1515
from matbench_discovery.plots import wandb_scatter
1616
from matbench_discovery.slurm import slurm_submit
1717

1818
__author__ = "Janosh Riebesell"
1919
__date__ = "2022-08-15"
2020

2121
"""
22-
Download WandB checkpoints for an ensemble of Wrenformer models trained on MP
22+
Download WandB checkpoints for an ensemble of Wrenformer models trained on all MP
2323
formation energies, then makes predictions on some dataset, prints ensemble metrics and
24-
stores predictions to CSV.
24+
saves predictions to CSV.
2525
"""
2626

2727
task_type = "IS2RE"
@@ -74,7 +74,7 @@
7474
task_type=task_type,
7575
target_col=target_col,
7676
input_col=input_col,
77-
filters=filters,
77+
wandb_run_filters=filters,
7878
slurm_vars=slurm_vars,
7979
)
8080

@@ -84,6 +84,7 @@
8484
# %%
8585
data_loader = df_to_in_mem_dataloader(
8686
df=df,
87+
cache_dir=CHECKPOINT_DIR,
8788
target_col=target_col,
8889
batch_size=1024,
8990
input_col=input_col,
@@ -108,6 +109,6 @@
108109
MAE = ensemble_metrics.MAE.mean()
109110
R2 = ensemble_metrics.R2.mean()
110111

111-
title = rf"Wrenformer {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
112+
title = f"Wrenformer {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
112113

113114
wandb_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)

0 commit comments

Comments
 (0)