|
15 | 15 |
|
16 | 16 | from matbench_discovery import ROOT
|
17 | 17 | from matbench_discovery.plot_scripts import df_wbm
|
| 18 | +from matbench_discovery.slurm import slurm_submit_python |
18 | 19 |
|
19 | 20 | __author__ = "Janosh Riebesell"
|
20 | 21 | __date__ = "2022-08-15"
|
|
27 | 28 |
|
28 | 29 | module_dir = os.path.dirname(__file__)
|
29 | 30 | today = f"{datetime.now():%Y-%m-%d}"
|
| 31 | +ensemble_id = "cgcnn-e_form-ensemble-1" |
| 32 | +run_name = f"{today}-{ensemble_id}-IS2RE" |
| 33 | + |
| 34 | +slurm_submit_python( |
| 35 | + job_name=run_name, |
| 36 | + partition="ampere", |
| 37 | + account="LEE-SL3-GPU", |
| 38 | + time="1:0:0", |
| 39 | + log_dir=module_dir, |
| 40 | + slurm_flags=("--nodes", "1", "--gpus-per-node", "1"), |
| 41 | +) |
30 | 42 |
|
31 | 43 |
|
32 | 44 | # %%
|
33 | 45 | data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
|
34 | 46 | df = pd.read_json(data_path).set_index("material_id", drop=False)
|
35 | 47 | old_len = len(df)
|
| 48 | +no_init_structs = df.query("initial_structure.isnull()").index |
36 | 49 | df = df.dropna() # two missing initial structures
|
37 | 50 | assert len(df) == old_len - 2
|
38 | 51 |
|
| 52 | +assert all( |
| 53 | + df.index == df_wbm.drop(index=no_init_structs).index |
| 54 | +), "df and df_wbm must have same index" |
39 | 55 | df["e_form_per_atom_mp2020_corrected"] = df_wbm.e_form_per_atom_mp2020_corrected
|
40 | 56 |
|
41 | 57 | target_col = "e_form_per_atom_mp2020_corrected"
|
42 | 58 | input_col = "initial_structure"
|
43 | 59 | assert target_col in df, f"{target_col=} not in {list(df)}"
|
44 | 60 | assert input_col in df, f"{input_col=} not in {list(df)}"
|
45 | 61 |
|
46 |
| -df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col])] |
| 62 | +df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col], disable=None)] |
47 | 63 |
|
48 | 64 | wandb.login()
|
49 | 65 | wandb_api = wandb.Api()
|
50 |
| -ensemble_id = "cgcnn-e_form-ensemble-1" |
51 | 66 | runs = wandb_api.runs(
|
52 | 67 | "janosh/matbench-discovery", filters={"tags": {"$in": [ensemble_id]}}
|
53 | 68 | )
|
|
62 | 77 | )
|
63 | 78 | df, ensemble_metrics = predict_from_wandb_checkpoints(
|
64 | 79 | runs,
|
65 |
| - df=cg_data.df, # dropping isolated-atom structs means len(cg_data.df) < len(df) |
| 80 | + # dropping isolated-atom structs means len(cg_data.df) < len(df) |
| 81 | + df=cg_data.df.reset_index(drop=True).drop(columns=input_col), |
66 | 82 | target_col=target_col,
|
67 |
| - model_class=CrystalGraphConvNet, |
| 83 | + model_cls=CrystalGraphConvNet, |
68 | 84 | data_loader=data_loader,
|
69 | 85 | )
|
70 | 86 |
|
71 |
| -df.round(6).to_csv(f"{module_dir}/{today}-{ensemble_id}-preds-{target_col}.csv") |
| 87 | +df.round(6).to_csv(f"{module_dir}/{today}-{run_name}-preds.csv") |
0 commit comments