|
29 | 29 | """
|
30 | 30 |
|
31 | 31 | today = f"{datetime.now():%Y-%m-%d}"
|
32 |
| -log_dir = f"{os.path.dirname(__file__)}/{today}-test" |
33 |
| -job_name = "test-cgcnn-ensemble" |
| 32 | +task_type = "RS2RE" |
| 33 | +job_name = f"test-cgcnn-wbm-{task_type}" |
| 34 | +module_dir = os.path.dirname(__file__) |
| 35 | +out_dir = os.environ.get("SBATCH_OUTPUT", f"{module_dir}/{today}-{job_name}") |
34 | 36 |
|
35 | 37 | slurm_vars = slurm_submit(
|
36 | 38 | job_name=job_name,
|
37 | 39 | partition="ampere",
|
38 | 40 | account="LEE-SL3-GPU",
|
39 | 41 | time=(slurm_max_job_time := "2:0:0"),
|
40 |
| - log_dir=log_dir, |
| 42 | + out_dir=out_dir, |
41 | 43 | slurm_flags=("--nodes", "1", "--gpus-per-node", "1"),
|
42 | 44 | )
|
43 | 45 |
|
44 | 46 |
|
45 | 47 | # %%
|
46 |
| -task_type = "IS2RE" |
47 | 48 | if task_type == "IS2RE":
|
48 | 49 | data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-init-structs.json.bz2"
|
| 50 | + input_col = "initial_structure" |
49 | 51 | elif task_type == "RS2RE":
|
50 | 52 | data_path = f"{ROOT}/data/wbm/2022-10-19-wbm-cses.json.bz2"
|
| 53 | + input_col = "relaxed_structure" |
| 54 | +else: |
| 55 | + raise ValueError(f"Unexpected {task_type=}") |
| 56 | + |
51 | 57 | df = pd.read_json(data_path).set_index("material_id", drop=False)
|
52 | 58 |
|
53 | 59 | target_col = "e_form_per_atom_mp2020_corrected"
|
54 | 60 | df[target_col] = df_wbm[target_col]
|
55 |
| -input_col = "initial_structure" |
56 | 61 | assert target_col in df, f"{target_col=} not in {list(df)}"
|
| 62 | +if task_type == "RS2RE": |
| 63 | + df[input_col] = [x["structure"] for x in df.computed_structure_entry] |
57 | 64 | assert input_col in df, f"{input_col=} not in {list(df)}"
|
58 | 65 |
|
59 | 66 | df[input_col] = [Structure.from_dict(x) for x in tqdm(df[input_col], disable=None)]
|
60 | 67 |
|
61 | 68 | filters = {
|
62 |
| - "$and": [{"created_at": {"$gt": "2022-11-22", "$lt": "2022-11-23"}}], |
| 69 | + "created_at": {"$gt": "2022-11-22", "$lt": "2022-11-23"}, |
63 | 70 | "display_name": {"$regex": "^cgcnn-robust"},
|
64 | 71 | }
|
65 | 72 | wandb.login()
|
|
87 | 94 | )
|
88 | 95 |
|
89 | 96 | slurm_job_id = os.environ.get("SLURM_JOB_ID", "debug")
|
90 |
| -wandb.init( |
91 |
| - project="matbench-discovery", name=f"{job_name}-{slurm_job_id}", config=run_params |
92 |
| -) |
| 97 | +run_name = f"{job_name}-{slurm_job_id}" |
| 98 | +wandb.init(project="matbench-discovery", name=run_name, config=run_params) |
93 | 99 |
|
94 | 100 | cg_data = CrystalGraphData(
|
95 | 101 | df, task_dict={target_col: "regression"}, structure_col=input_col
|
|
106 | 112 | data_loader=data_loader,
|
107 | 113 | )
|
108 | 114 |
|
109 |
| -df.to_csv(f"{log_dir}/{today}-{job_name}-preds.csv", index=False) |
| 115 | +df.to_csv(f"{out_dir}/{job_name}-preds.csv", index=False) |
110 | 116 | pred_col = f"{target_col}_pred_ens"
|
111 | 117 | table = wandb.Table(dataframe=df[[target_col, pred_col]].reset_index())
|
112 | 118 |
|
|
0 commit comments