|
14 | 14 | from torch.utils.data import DataLoader
|
15 | 15 | from tqdm import tqdm
|
16 | 16 |
|
17 |
| -from matbench_discovery import DEBUG, ROOT, today |
| 17 | +from matbench_discovery import CHECKPOINT_DIR, DEBUG, ROOT, today |
18 | 18 | from matbench_discovery.load_preds import df_wbm
|
19 | 19 | from matbench_discovery.plots import wandb_scatter
|
20 | 20 | from matbench_discovery.slurm import slurm_submit
|
|
23 | 23 | __date__ = "2022-08-15"
|
24 | 24 |
|
25 | 25 | """
|
26 |
| -Script that downloads checkpoints for an ensemble of Wrenformer models trained on the MP |
| 26 | +Script that downloads checkpoints for an ensemble of CGCNN models trained on all MP |
27 | 27 | formation energies, then makes predictions on some dataset, prints ensemble metrics and
|
28 |
| -stores predictions to CSV. |
| 28 | +saves predictions to CSV. |
29 | 29 | """
|
30 | 30 |
|
31 | 31 | task_type = "RS2RE"
|
|
54 | 54 | else:
|
55 | 55 | raise ValueError(f"Unexpected {task_type=}")
|
56 | 56 |
|
57 |
| -df = pd.read_json(data_path).set_index("material_id", drop=False) |
| 57 | +df = pd.read_json(data_path).set_index("material_id") |
58 | 58 |
|
59 | 59 | target_col = "e_form_per_atom_mp2020_corrected"
|
60 | 60 | df[target_col] = df_wbm[target_col]
|
|
88 | 88 | task_type=task_type,
|
89 | 89 | target_col=target_col,
|
90 | 90 | input_col=input_col,
|
91 |
| - filters=filters, |
| 91 | + wandb_run_filters=filters, |
92 | 92 | slurm_vars=slurm_vars,
|
93 | 93 | )
|
94 | 94 |
|
|
99 | 99 | df,
|
100 | 100 | task_dict={target_col: "regression"},
|
101 | 101 | structure_col=input_col,
|
102 |
| - identifiers=("material_id", "formula_from_cse"), |
| 102 | + identifiers=["formula_from_cse"], |
103 | 103 | )
|
104 | 104 | data_loader = DataLoader(
|
105 | 105 | cg_data, batch_size=1024, shuffle=False, collate_fn=collate_batch
|
106 | 106 | )
|
107 | 107 | df, ensemble_metrics = predict_from_wandb_checkpoints(
|
108 | 108 | runs,
|
109 | 109 | # dropping isolated-atom structs means len(cg_data.df) < len(df)
|
110 |
| - df=cg_data.df.reset_index(drop=True).drop(columns=input_col), |
| 110 | + cache_dir=CHECKPOINT_DIR, |
| 111 | + df=cg_data.df.drop(columns=input_col), |
111 | 112 | target_col=target_col,
|
112 | 113 | model_cls=CrystalGraphConvNet,
|
113 | 114 | data_loader=data_loader,
|
|
122 | 123 | MAE = ensemble_metrics.MAE.mean()
|
123 | 124 | R2 = ensemble_metrics.R2.mean()
|
124 | 125 |
|
125 |
| -title = rf"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}" |
| 126 | +title = f"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}" |
126 | 127 |
|
127 | 128 | wandb_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
|
0 commit comments