|
23 | 23 | __date__ = "2022-08-15"
|
24 | 24 |
|
25 | 25 | """
|
26 |
| -Script that downloads checkpoints for an ensemble of CGCNN models trained on all MP |
| 26 | +Download WandB checkpoints for an ensemble of CGCNN models trained on all MP |
27 | 27 | formation energies, then make predictions on some dataset, print ensemble metrics and
|
28 | 28 | save predictions to CSV.
|
29 | 29 | """
|
30 | 30 |
|
31 |
| -task_type = "RS2RE" |
| 31 | +task_type = "IS2RE" |
32 | 32 | DEBUG = "slurm-submit" in sys.argv
|
33 | 33 | job_name = f"test-cgcnn-wbm-{task_type}{'-debug' if DEBUG else ''}"
|
34 | 34 | module_dir = os.path.dirname(__file__)
|
|
58 | 58 |
|
59 | 59 | target_col = "e_form_per_atom_mp2020_corrected"
|
60 | 60 | df[target_col] = df_wbm[target_col]
|
61 |
| -assert target_col in df, f"{target_col=} not in {list(df)}" |
62 | 61 | if task_type == "RS2RE":
|
63 | 62 | df[input_col] = [x["structure"] for x in df.computed_structure_entry]
|
64 | 63 | assert input_col in df, f"{input_col=} not in {list(df)}"
|
65 | 64 |
|
66 | 65 | df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col], disable=None)]
|
67 | 66 |
|
68 | 67 | filters = {
|
69 |
| - "created_at": {"$gt": "2022-11-22", "$lt": "2022-11-23"}, |
70 |
| - "display_name": {"$regex": "^cgcnn-robust"}, |
| 68 | + "created_at": {"$gt": "2022-12-03", "$lt": "2022-12-04"}, |
| 69 | + "display_name": {"$regex": "^train-cgcnn-robust-augment=3-"}, |
71 | 70 | }
|
72 | 71 | runs = wandb.Api().runs("janosh/matbench-discovery", filters=filters)
|
73 | 72 |
|
|
92 | 91 | slurm_vars=slurm_vars,
|
93 | 92 | )
|
94 | 93 |
|
95 |
| - |
96 | 94 | wandb.init(project="matbench-discovery", name=job_name, config=run_params)
|
97 | 95 |
|
98 | 96 | cg_data = CrystalGraphData(
|
99 |
| - df, |
100 |
| - task_dict={target_col: "regression"}, |
101 |
| - structure_col=input_col, |
102 |
| - identifiers=["formula_from_cse"], |
| 97 | + df, task_dict={target_col: "regression"}, structure_col=input_col |
103 | 98 | )
|
104 | 99 | data_loader = DataLoader(
|
105 | 100 | cg_data, batch_size=1024, shuffle=False, collate_fn=collate_batch
|
106 | 101 | )
|
107 |
| -df, ensemble_metrics = predict_from_wandb_checkpoints( |
| 102 | +df_preds, ensemble_metrics = predict_from_wandb_checkpoints( |
108 | 103 | runs,
|
109 | 104 | # dropping isolated-atom structs means len(cg_data.df) < len(df)
|
110 | 105 | cache_dir=CHECKPOINT_DIR,
|
|
114 | 109 | data_loader=data_loader,
|
115 | 110 | )
|
116 | 111 |
|
117 |
| -df.to_csv(f"{out_dir}/{job_name}-preds.csv", index=False) |
| 112 | +df_preds.to_csv(f"{out_dir}/{job_name}-preds.csv", index=False) |
118 | 113 | pred_col = f"{target_col}_pred_ens"
|
119 |
| -table = wandb.Table(dataframe=df[[target_col, pred_col]].reset_index()) |
| 114 | +assert pred_col in df, f"{pred_col=} not in {list(df)}" |
| 115 | +table = wandb.Table(dataframe=df_preds[[target_col, pred_col]].reset_index()) |
120 | 116 |
|
121 | 117 |
|
122 | 118 | # %%
|
|
0 commit comments