|
2 | 2 | from datetime import datetime
|
3 | 3 |
|
4 | 4 | import pandas as pd
|
| 5 | +import pymatviz |
5 | 6 |
|
6 | 7 | from mb_discovery import ROOT
|
7 | 8 | from mb_discovery.plots import (
|
|
29 | 30 |
|
30 | 31 |
|
31 | 32 | # %%
|
32 |
# Collect each model's prediction table in one dict keyed by model name.
# Every frame is re-indexed by its "material_id" column so the tables can be
# aligned with one another later on.
dfs = {
    "wren": pd.read_csv(
        f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
    ).set_index("material_id"),
    "m3gnet": pd.read_json(
        f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
    ).set_index("material_id"),
    "Wrenformer": pd.read_csv(
        f"{ROOT}/models/wrenformer/mp/"
        "2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
    ).set_index("material_id"),
}
| 44 | + |
35 | 45 |
|
# Table that supplies the "e_above_mp_hull" column, keyed by material_id.
df_hull = pd.read_csv(f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv")
df_hull = df_hull.set_index("material_id")
|
39 | 49 |
|
40 |
| -df["e_above_mp_hull"] = df_hull.e_above_mp_hull |
41 |
| - |
# download wbm-steps-summary.csv (23.31 MB)
# The summary is fetched straight from figshare on every run and then keyed
# by material_id like the other tables.
df_wbm = pd.read_csv(
    "https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
)
df_wbm = df_wbm.set_index("material_id")
|
46 | 54 |
|
47 | 55 |
|
# Rename an "M3Gnet" entry, if one exists, to the lowercase "m3gnet" key that
# the rest of this script looks up. The rename is guarded because an
# unconditional dfs.pop("M3Gnet") raises KeyError whenever that key is absent.
if "M3Gnet" in dfs:
    dfs["m3gnet"] = dfs.pop("M3Gnet")
| 57 | + |
| 58 | + |
| 59 | +# %% |
# Give each loaded model frame a common "e_form_per_atom_pred" column so the
# plotting code below can treat all models the same way.
if "wren" in dfs:
    df = dfs["wren"]
    # individual ensemble member predictions live in columns matching _pred_<digit>
    pred_cols = df.filter(regex=r"_pred_\d").columns
    n_ensemble_members = len(pred_cols)
    # make sure we average the expected number of ensemble member predictions
    assert n_ensemble_members == 10
    ensemble_mean = df[pred_cols].mean(axis=1)
    df["e_form_per_atom_pred"] = ensemble_mean

if "m3gnet" in dfs:
    df = dfs["m3gnet"]
    # m3gnet has a single prediction column; expose it under the shared name
    m3gnet_pred = df.e_form_ppd_2022_01_25
    df["e_form_per_atom_pred"] = m3gnet_pred
| 69 | + |
| 70 | + |
48 | 71 | # %%
|
which_energy: WhichEnergy = "true"
stability_crit: StabilityCriterion = "energy"
fig, axs = plt.subplots(2, 3, figsize=(18, 9))

# model_name is bound here so it can be reused when naming the output figure
df = dfs[(model_name := "wren")]

# pandas aligns these column assignments on the shared material_id index
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
df["e_form_per_atom"] = df_wbm.e_form_per_atom


# one panel per batch 1-5; batches are selected by their material_id prefix
for batch_idx, ax in zip(range(1, 6), axs.flat):
    batch_df = df[df.index.str.startswith(f"wbm-step-{batch_idx}-")]
    # Sanity-check the batch size. The message must be a string: passing
    # print(...) here would evaluate to None and leave the AssertionError empty.
    assert 1e4 < len(batch_df) < 1e5, f"{len(batch_df) = :,}"

    hist_classified_stable_as_func_of_hull_dist(
        e_above_hull_pred=batch_df.e_form_per_atom_pred - batch_df.e_form_per_atom,
        e_above_hull_true=batch_df.e_above_mp_hull,
        which_energy=which_energy,
        stability_crit=stability_crit,
        ax=ax,
    )

    title = f"Batch {batch_idx} ({len(batch_df):,})"
    ax.set(title=title)
|
70 | 96 |
|
71 | 97 |
|
72 | 98 | hist_classified_stable_as_func_of_hull_dist(
|
73 |
| - e_above_hull_pred=df[pred_cols].mean(axis=1), |
| 99 | + e_above_hull_pred=df.e_form_per_atom_pred - df.e_form_per_atom, |
74 | 100 | e_above_hull_true=df.e_above_mp_hull,
|
75 | 101 | which_energy=which_energy,
|
76 | 102 | stability_crit=stability_crit,
|
|
# The last panel shows all batches combined. batch_idx would only be the
# value left over from the final loop iteration, so it is kept out of the title.
axs.flat[-1].set(title=f"Combined ({len(df):,})")
axs.flat[0].legend(frameon=False, loc="upper left")

# file name records the date, model and plot settings used for this figure
img_name = (
    f"{today}-{model_name}-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}.pdf"
)
# plt.savefig(f"{ROOT}/figures/{img_name}")
|
| 113 | + |
| 114 | + |
| 115 | +# %% |
# Density scatter of predicted vs. reference formation energy per model.
# dropna() is computed once per frame and reused for both arguments instead
# of being re-run for each of them.
df_wren = dfs["wren"].dropna()
pymatviz.density_scatter(df_wren.e_form_per_atom_pred, df_wren.e_form_per_atom)

# NOTE(review): the visible part of this script only assigns "e_form_per_atom"
# on the wren frame - confirm the m3gnet results already carry that column.
df_m3gnet = dfs["m3gnet"].dropna()
pymatviz.density_scatter(df_m3gnet.e_form_per_atom_pred, df_m3gnet.e_form_per_atom)
0 commit comments