|
3 | 3 |
|
4 | 4 | import os
|
5 | 5 | from datetime import datetime
|
| 6 | +from importlib.metadata import version |
6 | 7 |
|
7 | 8 | import pandas as pd
|
8 | 9 | import wandb
|
9 | 10 | from aviary.cgcnn.data import CrystalGraphData, collate_batch
|
10 | 11 | from aviary.cgcnn.model import CrystalGraphConvNet
|
11 | 12 | from aviary.deploy import predict_from_wandb_checkpoints
|
12 | 13 | from pymatgen.core import Structure
|
13 |
| -from pymatviz import density_scatter |
14 | 14 | from torch.utils.data import DataLoader
|
15 | 15 | from tqdm import tqdm
|
16 | 16 |
|
|
29 | 29 |
|
30 | 30 | today = f"{datetime.now():%Y-%m-%d}"
|
31 | 31 | log_dir = f"{os.path.dirname(__file__)}/{today}-test"
|
32 |
| -ensemble_id = "cgcnn-e_form-ensemble-1" |
33 |
| -run_name = f"{ensemble_id}-IS2RE" |
| 32 | +job_name = "test-cgcnn-ensemble" |
34 | 33 |
|
35 |
| -slurm_submit( |
36 |
| - job_name=run_name, |
| 34 | +slurm_vars = slurm_submit( |
| 35 | + job_name=job_name, |
37 | 36 | partition="ampere",
|
38 | 37 | account="LEE-SL3-GPU",
|
39 |
| - time="1:0:0", |
| 38 | + time=(slurm_max_job_time := "2:0:0"), |
40 | 39 | log_dir=log_dir,
|
41 | 40 | slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
|
42 | 41 | )
|
43 | 42 |
|
44 | 43 |
|
45 | 44 | # %%
|
46 |
| -data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2" |
| 45 | +task_type = "IS2RE" |
| 46 | +if task_type == "IS2RE": |
| 47 | + data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2" |
| 48 | +elif task_type == "RS2RE": |
| 49 | + data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-cses.json.bz2" |
47 | 50 | df = pd.read_json(data_path).set_index("material_id", drop=False)
|
48 |
| -old_len = len(df) |
49 |
| -no_init_structs = df.query("initial_structure.isnull()").index |
50 |
| -df = df.dropna() # two missing initial structures |
51 |
| -assert len(df) == old_len - 2 |
52 |
| - |
53 |
| -assert all(df.index == df_wbm.drop(index=no_init_structs).index) |
54 | 51 |
|
55 | 52 | target_col = "e_form_per_atom_mp2020_corrected"
|
56 | 53 | df[target_col] = df_wbm[target_col]
|
|
60 | 57 |
|
61 | 58 | df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col], disable=None)]
|
62 | 59 |
|
| 60 | +filters = { |
| 61 | + "$and": [{"created_at": {"$gt": "2022-11-22", "$lt": "2022-11-23"}}], |
| 62 | + "display_name": {"$regex": "^cgcnn-robust"}, |
| 63 | +} |
63 | 64 | wandb.login()
|
64 |
| -runs = wandb.Api().runs( |
65 |
| - "janosh/matbench-discovery", filters={"tags": {"$in": [ensemble_id]}} |
| 65 | +runs = wandb.Api().runs("janosh/matbench-discovery", filters=filters) |
| 66 | + |
| 67 | +assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {filters=}" |
| 68 | +for idx, run in enumerate(runs): |
| 69 | + for key, val in run.config.items(): |
| 70 | + if val == runs[0][key] or key.startswith(("slurm_", "timestamp")): |
| 71 | + continue |
| 72 | + raise ValueError( |
| 73 | + f"Configs not identical: runs[{idx}][{key}]={val}, {runs[0][key]=}" |
| 74 | + ) |
| 75 | + |
| 76 | +run_params = dict( |
| 77 | + data_path=data_path, |
| 78 | + df=dict(shape=str(df.shape), columns=", ".join(df)), |
| 79 | + aviary_version=version("aviary"), |
| 80 | + ensemble_size=len(runs), |
| 81 | + task_type=task_type, |
| 82 | + target_col=target_col, |
| 83 | + input_col=input_col, |
| 84 | + filters=filters, |
| 85 | + slurm_vars=slurm_vars | dict(slurm_max_job_time=slurm_max_job_time), |
66 | 86 | )
|
67 | 87 |
|
68 |
| -assert len(runs) == 10, f"Expected 10 runs, got {len(runs)} for {ensemble_id=}" |
| 88 | +slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug") |
| 89 | +wandb.init( |
| 90 | + project="matbench-discovery", name=f"{job_name}-{slurm_job_id}", config=run_params |
| 91 | +) |
69 | 92 |
|
70 | 93 | cg_data = CrystalGraphData(
|
71 | 94 | df, task_dict={target_col: "regression"}, structure_col=input_col
|
|
82 | 105 | data_loader=data_loader,
|
83 | 106 | )
|
84 | 107 |
|
85 |
| -df.round(6).to_csv(f"{log_dir}/{today}-{run_name}-preds.csv", index=False) |
| 108 | +df.to_csv(f"{log_dir}/{today}-{job_name}-preds.csv", index=False) |
| 109 | +table = wandb.Table(dataframe=df) |
86 | 110 |
|
87 | 111 |
|
88 | 112 | # %%
|
89 |
| -print(f"{runs[0].url=}") |
90 |
| -ax = density_scatter( |
91 |
| - df=df.query("e_form_per_atom_mp2020_corrected < 10"), |
92 |
| - x="e_form_per_atom_mp2020_corrected", |
93 |
| - y="e_form_per_atom_mp2020_corrected_pred_1", |
| 113 | +pred_col = f"{target_col}_pred_ens" |
| 114 | +MAE = ensemble_metrics["MAE"] |
| 115 | +R2 = ensemble_metrics["R2"] |
| 116 | + |
| 117 | +title = rf"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}" |
| 118 | +print(title) |
| 119 | + |
| 120 | +scatter_plot = wandb.plot_table( |
| 121 | + vega_spec_name="janosh/scatter-parity", |
| 122 | + data_table=table, |
| 123 | + fields=dict(x=target_col, y=pred_col, title=title), |
94 | 124 | )
|
95 |
| -# ax.figure.savefig(f"{ROOT}/tmp/{today}-{run_name}-scatter-preds.png", dpi=300) |
| 125 | + |
| 126 | +wandb.log({"true_pred_scatter": scatter_plot}) |
0 commit comments