|
16 | 16 | from matbench_discovery import CHECKPOINT_DIR, ROOT, WANDB_PATH, today
|
17 | 17 | from matbench_discovery.data import DATA_FILES, df_wbm
|
18 | 18 | from matbench_discovery.plots import wandb_scatter
|
| 19 | +from matbench_discovery.preds import e_form_col as target_col |
19 | 20 | from matbench_discovery.slurm import slurm_submit
|
20 | 21 |
|
21 | 22 | __author__ = "Janosh Riebesell"
|
|
53 | 54 |
|
54 | 55 | df = pd.read_json(data_path).set_index("material_id")
|
55 | 56 |
|
56 |
| -e_form_col = "e_form_per_atom_mp2020_corrected" |
57 |
| -df[e_form_col] = df_wbm[e_form_col] |
| 57 | +df[target_col] = df_wbm[target_col] |
58 | 58 | if task_type == "RS2RE":
|
59 | 59 | df[input_col] = [x["structure"] for x in df.computed_structure_entry]
|
60 | 60 | assert input_col in df, f"{input_col=} not in {list(df)}"
|
|
87 | 87 | versions={dep: version(dep) for dep in ("aviary", "numpy", "torch")},
|
88 | 88 | ensemble_size=len(runs),
|
89 | 89 | task_type=task_type,
|
90 |
| - target_col=e_form_col, |
| 90 | + target_col=target_col, |
91 | 91 | input_col=input_col,
|
92 | 92 | wandb_run_filters=filters,
|
93 | 93 | slurm_vars=slurm_vars,
|
|
97 | 97 | wandb.init(project="matbench-discovery", name=job_name, config=run_params)
|
98 | 98 |
|
99 | 99 | cg_data = CrystalGraphData(
|
100 |
| - df, task_dict={e_form_col: "regression"}, structure_col=input_col |
| 100 | + df, task_dict={target_col: "regression"}, structure_col=input_col |
101 | 101 | )
|
102 | 102 | data_loader = DataLoader(
|
103 | 103 | cg_data, batch_size=1024, shuffle=False, collate_fn=collate_batch
|
|
110 | 110 | # dropping isolated-atom structs means len(cg_data.df) < len(df)
|
111 | 111 | cache_dir=CHECKPOINT_DIR,
|
112 | 112 | df=cg_data.df.drop(columns=input_col),
|
113 |
| - target_col=e_form_col, |
| 113 | + target_col=target_col, |
114 | 114 | model_cls=CrystalGraphConvNet,
|
115 | 115 | data_loader=data_loader,
|
116 | 116 | )
|
117 | 117 |
|
118 | 118 | slurm_job_id = os.getenv("SLURM_JOB_ID", "debug")
|
119 | 119 | df.round(4).to_csv(f"{out_dir}/{job_name}-preds-{slurm_job_id}.csv.gz")
|
120 |
| -pred_col = f"{e_form_col}_pred_ens" |
| 120 | +pred_col = f"{target_col}_pred_ens" |
121 | 121 | assert pred_col in df, f"{pred_col=} not in {list(df)}"
|
122 |
| -table = wandb.Table(dataframe=df[[e_form_col, pred_col]].reset_index()) |
| 122 | +table = wandb.Table(dataframe=df[[target_col, pred_col]].reset_index()) |
123 | 123 |
|
124 | 124 |
|
125 | 125 | # %%
|
|
128 | 128 |
|
129 | 129 | title = f"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
|
130 | 130 |
|
131 |
| -wandb_scatter(table, fields=dict(x=e_form_col, y=pred_col), title=title) |
| 131 | +wandb_scatter(table, fields=dict(x=target_col, y=pred_col), title=title) |
0 commit comments