|
9 | 9 | from pymatgen.core import Structure
|
10 | 10 | from pymatviz import density_scatter, plot_structure_2d, ptable_heatmap_plotly
|
11 | 11 |
|
12 |
| -from matbench_discovery import plots |
| 12 | +from matbench_discovery import plots as plots |
13 | 13 | from matbench_discovery.data import DATA_FILES, df_wbm
|
14 | 14 | from matbench_discovery.preds import PRED_FILES
|
15 | 15 |
|
16 | 16 | __author__ = "Janosh Riebesell"
|
17 | 17 | __date__ = "2023-03-06"
|
18 | 18 |
|
19 | 19 | module_dir = os.path.dirname(__file__)
|
20 |
| -del plots # https://github.com/PyCQA/pyflakes/issues/366 |
| 20 | +id_col = "material_id" |
21 | 21 |
|
22 | 22 |
|
23 | 23 | # %%
|
24 |
| -df_chgnet = pd.read_csv(PRED_FILES.CHGNet) |
25 |
| -df_chgnet = df_chgnet.set_index("material_id").add_suffix("_2000") |
26 |
| -df_chgnet_500 = pd.read_csv(PRED_FILES.CHGNet.replace("-06", "-04")) |
27 |
| -df_chgnet_500 = df_chgnet_500.set_index("material_id").add_suffix("_500") |
| 24 | +df_chgnet = pd.read_csv(PRED_FILES.__dict__["CHGNet"]) |
| 25 | +df_chgnet = df_chgnet.set_index(id_col).add_suffix("_2000") |
| 26 | +df_chgnet_500 = pd.read_csv(PRED_FILES.__dict__["CHGNet"].replace("-06", "-04")) |
| 27 | +df_chgnet_500 = df_chgnet_500.set_index(id_col).add_suffix("_500") |
28 | 28 | df_chgnet[list(df_chgnet_500)] = df_chgnet_500
|
29 | 29 | df_chgnet["formula"] = df_wbm.formula
|
30 | 30 |
|
31 | 31 | e_form_2000 = "e_form_per_atom_chgnet_2000"
|
32 | 32 | e_form_500 = "e_form_per_atom_chgnet_500"
|
33 | 33 |
|
34 |
| -min_e_diff = 0.35 |
| 34 | +min_e_diff = 0.1 |
| 35 | +# structures with smaller energy after longer relaxation need many steps |
| 36 | +df_long = df_chgnet.query(f"{e_form_2000} - {e_form_500} < -{min_e_diff}") |
| 37 | +# structures with larger energy after longer relaxation are problematic |
35 | 38 | df_bad = df_chgnet.query(f"{e_form_2000} - {e_form_500} > {min_e_diff}")
|
| 39 | +# both combined |
| 40 | +df_diff = df_chgnet.query(f"abs({e_form_2000} - {e_form_500}) > {min_e_diff}") |
| 41 | + |
| 42 | +assert len(df_long) + len(df_bad) == len(df_diff) |
| 43 | + |
| 44 | + |
| 45 | +# %% |
| 46 | +density_scatter(df=df_chgnet, x=e_form_500, y=e_form_2000) |
36 | 47 |
|
37 | 48 |
|
38 | 49 | # %%
|
39 |
| -density_scatter(df=df_chgnet, x=e_form_2000, y=e_form_500) |
| 50 | +df_diff.reset_index().plot.scatter( |
| 51 | + x=e_form_500, |
| 52 | + y=e_form_2000, |
| 53 | + hover_name=id_col, |
| 54 | + hover_data=["formula"], |
| 55 | + backend="plotly", |
| 56 | + title=f"{len(df_diff)} structures have > {min_e_diff} eV/atom energy diff after " |
| 57 | + "longer relaxation", |
| 58 | +) |
40 | 59 |
|
41 | 60 |
|
42 | 61 | # %%
|
43 | 62 | fig = ptable_heatmap_plotly(df_bad.formula)
|
44 |
| -title = "structures with larger error after longer relaxation" |
45 |
| -fig.layout.title.update(text=f"{len(df_bad)} {title}") |
| 63 | +title = "structures with larger error<br>after longer relaxation" |
| 64 | +fig.layout.title.update(text=f"{len(df_diff)} {title}", x=0.4, y=0.9) |
| 65 | +fig.show() |
46 | 66 |
|
47 | 67 |
|
48 | 68 | # %%
|
49 |
| -df_cse = pd.read_json(DATA_FILES.wbm_initial_structures).set_index("material_id") |
| 69 | +df_cse = pd.read_json(DATA_FILES.wbm_cses_plus_init_structs).set_index(id_col) |
| 70 | +df_cse.loc[df_diff.index].reset_index().to_json( |
| 71 | + f"{module_dir}/wbm-chgnet-bad-relax.json.gz" |
| 72 | +) |
50 | 73 |
|
51 | 74 |
|
52 | 75 | # %%
|
53 | 76 | n_rows, n_cols = 3, 4
|
54 | 77 | fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 4 * n_rows))
|
55 |
| -n_struct = min(n_rows * n_cols, len(df_bad)) |
| 78 | +n_struct = min(n_rows * n_cols, len(df_diff)) |
56 | 79 | struct_col = "initial_structure"
|
57 | 80 |
|
58 | 81 | fig.suptitle(f"{n_struct} {struct_col} {title}", fontsize=16, fontweight="bold", y=1.05)
|
59 | 82 | for idx, (ax, row) in enumerate(
|
60 |
| - zip(axs.flat, df_cse.loc[df_bad.index].itertuples()), 1 |
| 83 | + zip(axs.flat, df_cse.loc[df_diff.index].itertuples()), 1 |
61 | 84 | ):
|
62 | 85 | struct = Structure.from_dict(getattr(row, struct_col))
|
63 | 86 | plot_structure_2d(struct, ax=ax)
|
|
0 commit comments