|
39 | 39 | f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
|
40 | 40 | ).set_index("material_id")
|
41 | 41 | dfs["wrenformer"] = pd.read_csv(
|
42 |
| - f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv" |
| 42 | + f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv" |
43 | 43 | ).set_index("material_id")
|
44 | 44 | dfs["bowsr_megnet"] = pd.read_json(
|
45 | 45 | f"{ROOT}/models/bowsr/2022-09-22-bowsr-megnet-wbm-IS2RE.json.gz"
|
46 | 46 | ).set_index("material_id")
|
47 | 47 |
|
48 | 48 |
|
49 | 49 | # %%
|
| 50 | +pred_col = "e_form_per_atom_pred" |
| 51 | +target_col = "e_form_per_atom" |
50 | 52 | if "wren" in dfs:
|
51 | 53 | df = dfs["wren"]
|
52 | 54 | pred_cols = df.filter(regex=r"_pred_\d").columns
|
53 | 55 | # make sure we average the expected number of ensemble member predictions
|
54 | 56 | assert len(pred_cols) == 10
|
55 |
| - df["e_form_per_atom_pred"] = df[pred_cols].mean(axis=1) |
| 57 | + df[pred_col] = df[pred_cols].mean(axis=1) |
56 | 58 | if "m3gnet" in dfs:
|
57 | 59 | df = dfs["m3gnet"]
|
58 |
| - df["e_form_per_atom_pred"] = df.e_form_per_atom_m3gnet |
| 60 | + df[pred_col] = df.e_form_per_atom_m3gnet |
59 | 61 | if "bowsr_megnet" in dfs:
|
60 | 62 | df = dfs["bowsr_megnet"]
|
61 |
| - df["e_form_per_atom_pred"] = df.e_form_per_atom_bowsr |
| 63 | + df[pred_col] = df.e_form_per_atom_bowsr |
| 64 | +if "wrenformer" in dfs: |
| 65 | + pred_col = "e_form_per_atom_mp2020_corrected_pred_ens" |
62 | 66 |
|
63 | 67 |
|
64 | 68 | # %%
|
65 | 69 | which_energy: WhichEnergy = "true"
|
66 | 70 | stability_crit: StabilityCriterion = "energy"
|
67 | 71 | fig, axs = plt.subplots(2, 3, figsize=(18, 9))
|
68 | 72 |
|
69 |
| -model_name = "m3gnet" |
| 73 | +model_name = "wrenformer" |
70 | 74 | df = dfs[model_name]
|
71 | 75 |
|
72 | 76 | df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
|
73 |
| -df["e_form_per_atom"] = df_wbm.e_form_per_atom_mp2020_corrected |
| 77 | +df[target_col] = df_wbm.e_form_per_atom_mp2020_corrected # e_form targets |
74 | 78 |
|
75 | 79 |
|
76 | 80 | for batch_idx, ax in zip(range(1, 6), axs.flat):
|
77 | 81 | batch_df = df[df.index.str.startswith(f"wbm-step-{batch_idx}-")]
|
78 | 82 | assert 1e4 < len(batch_df) < 1e5, print(f"{len(batch_df) = :,}")
|
79 | 83 |
|
80 | 84 | ax, metrics = hist_classified_stable_vs_hull_dist(
|
81 |
| - e_above_hull_pred=batch_df.e_form_per_atom_pred - batch_df.e_form_per_atom, |
| 85 | + e_above_hull_pred=batch_df[pred_col] - batch_df.e_form_per_atom, |
82 | 86 | e_above_hull_true=batch_df.e_above_hull_mp,
|
83 | 87 | which_energy=which_energy,
|
84 | 88 | stability_crit=stability_crit,
|
|
93 | 97 |
|
94 | 98 |
|
95 | 99 | ax, metrics = hist_classified_stable_vs_hull_dist(
|
96 |
| - e_above_hull_pred=df.e_form_per_atom_pred - df.e_form_per_atom, |
| 100 | + e_above_hull_pred=df[pred_col] - df.e_form_per_atom, |
97 | 101 | e_above_hull_true=df.e_above_hull_mp,
|
98 | 102 | which_energy=which_energy,
|
99 | 103 | stability_crit=stability_crit,
|
100 | 104 | ax=axs.flat[-1],
|
101 | 105 | )
|
102 | 106 |
|
103 | 107 | text = f"Enrichment\nFactor = {metrics['enrichment']:.3}"
|
104 |
| -ax.text(0.02, 0.25, text, fontsize=16, transform=ax.transAxes) |
| 108 | +ax.text(0.02, 0.3, text, fontsize=16, transform=ax.transAxes) |
105 | 109 |
|
106 |
| -axs.flat[-1].set(title=f"Combined ({len(df.filter(like='e_').dropna()):,})") |
| 110 | +axs.flat[-1].set(title=f"All batches ({len(df.filter(like='e_').dropna()):,})") |
107 | 111 | axs.flat[0].legend(frameon=False, loc="upper left")
|
108 | 112 |
|
109 |
| -img_name = f"{today}-{model_name}-wbm-hull-dist-hist-batches" |
110 |
| -suptitle = img_name.replace("-", "/", 2).replace("-", " ") |
111 |
| -fig.suptitle(suptitle, y=1.07, fontsize=16) |
| 113 | +fig.suptitle(f"{today} {model_name}", y=1.07, fontsize=16) |
112 | 114 |
|
113 | 115 |
|
114 | 116 | # %%
|
| 117 | +img_name = f"{today}-{model_name}-wbm-hull-dist-hist-batches" |
115 | 118 | ax.figure.savefig(f"{ROOT}/figures/{img_name}.pdf")
|
116 | 119 |
|
117 | 120 |
|
118 | 121 | # %%
|
119 | 122 | pymatviz.density_scatter(
|
120 |
| - df=dfs["m3gnet"].query("e_form_per_atom < 5"), |
121 |
| - x="e_form_per_atom", |
122 |
| - y="e_form_per_atom_pred", |
| 123 | + df=dfs[model_name].query(f"{target_col} < 5"), x=target_col, y=pred_col |
123 | 124 | )
|
0 commit comments