|
30 | 30 | # %%
|
31 | 31 | module_dir = os.path.dirname(__file__)
|
32 | 32 | task_type = "IS2RE"
|
33 |
| -date = "2022-10-31" |
34 |
| -glob_pattern = f"{date}-m3gnet-wbm-{task_type}/*.json.gz" |
| 33 | +date = "2023-05-30" |
| 34 | +model_type = "directs" |
| 35 | +glob_pattern = f"{date}-m3gnet-{model_type}-wbm-{task_type}/*.json.gz" |
35 | 36 | file_paths = sorted(glob(f"{module_dir}/{glob_pattern}"))
|
36 | 37 | struct_col = "m3gnet_structure"
|
37 | 38 | print(f"Found {len(file_paths):,} files for {glob_pattern = }")
|
38 | 39 |
|
39 |
| -dfs: dict[str, pd.DataFrame] = {} |
| 40 | +# prevent accidental overwrites |
| 41 | +if "dfs" not in locals(): |
| 42 | + dfs: dict[str, pd.DataFrame] = {} |
40 | 43 |
|
41 | 44 |
|
42 | 45 | # %%
|
|
66 | 69 | for row in tqdm(df_m3gnet.itertuples(), total=len(df_m3gnet)):
|
67 | 70 | mat_id, struct_dict, m3gnet_energy, *_ = row
|
68 | 71 | m3gnet_struct = Structure.from_dict(struct_dict)
|
69 |
| - df_m3gnet.loc[mat_id, struct_col] = m3gnet_struct |
| 72 | + df_m3gnet.at[mat_id, struct_col] = m3gnet_struct # noqa: PD008 |
70 | 73 | cse = df_cse.loc[mat_id, "cse"]
|
71 | 74 | cse._energy = m3gnet_energy # cse._energy is the uncorrected energy
|
72 | 75 | cse._structure = m3gnet_struct
|
|
81 | 84 |
|
82 | 85 |
|
83 | 86 | # %% compute corrected formation energies
|
84 |
| -df_m3gnet["e_form_per_atom_m3gnet"] = [ |
| 87 | +df_m3gnet[f"e_form_per_atom_m3gnet_{model_type}"] = [ |
85 | 88 | get_e_form_per_atom(cse) for cse in tqdm(df_m3gnet.cse)
|
86 | 89 | ]
|
87 | 90 |
|
|
93 | 96 |
|
94 | 97 |
|
95 | 98 | # %%
|
96 |
| -out_path = f"{module_dir}/{today}-m3gnet-wbm-{task_type}.json.gz" |
| 99 | +out_path = f"{module_dir}/{today}-m3gnet-{model_type}-wbm-{task_type}" |
97 | 100 | df_m3gnet = df_m3gnet.round(4)
|
98 |
| -df_m3gnet.reset_index().to_json(out_path, default_handler=as_dict_handler) |
| 101 | +df_m3gnet.select_dtypes("number").to_csv(f"{out_path}.csv") |
| 102 | +df_m3gnet.reset_index().to_json(f"{out_path}.json.gz", default_handler=as_dict_handler) |
99 | 103 |
|
100 |
| -df_m3gnet.select_dtypes("number").to_csv(out_path.replace(".json.gz", ".csv")) |
101 | 104 |
|
102 | 105 | # in_path = f"{module_dir}/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
|
103 | 106 | # df_m3gnet = pd.read_csv(in_path.replace(".json.gz", ".csv")).set_index("material_id")
|
|
0 commit comments