|
6 | 6 |
|
7 | 7 | from mb_discovery import ROOT
|
8 | 8 | from mb_discovery.plot_scripts.plot_funcs import (
|
| 9 | + StabilityCriterion, |
| 10 | + WhichEnergy, |
9 | 11 | hist_classified_stable_as_func_of_hull_dist,
|
10 | 12 | )
|
11 | 13 |
|
|
50 | 52 |
|
51 | 53 |
|
52 | 54 | # %%
|
53 |
| -energy_type = "true" |
54 |
| -stability_crit = "energy" |
| 55 | +which_energy: WhichEnergy = "true" |
| 56 | +stability_crit: StabilityCriterion = "energy" |
55 | 57 | df["wbm_batch"] = df.index.str.split("-").str[2]
|
56 | 58 | fig, axs = plt.subplots(2, 3, figsize=(18, 9))
|
57 | 59 |
|
58 | 60 | # make sure we average the expected number of ensemble member predictions
|
59 | 61 | pred_cols = df.filter(regex=r"_pred_\d").columns
|
60 | 62 | assert len(pred_cols) == 10
|
61 | 63 |
|
62 |
| -common_kwargs = dict( |
63 |
| - target_col="e_form_target", |
64 |
| - pred_cols=pred_cols, |
65 |
| - energy_type=energy_type, |
66 |
| - stability_crit=stability_crit, |
67 |
| - e_above_hull_col="e_above_mp_hull", |
68 |
| -) |
69 | 64 |
|
70 | 65 | for (batch_idx, batch_df), ax in zip(df.groupby("wbm_batch"), axs.flat):
|
71 |
| - hist_classified_stable_as_func_of_hull_dist(batch_df, ax=ax, **common_kwargs) |
| 66 | + hist_classified_stable_as_func_of_hull_dist( |
| 67 | + e_above_hull_pred=batch_df[pred_cols].mean(axis=1) - batch_df.e_form_target, |
| 68 | + e_above_hull_true=batch_df.e_above_mp_hull, |
| 69 | + which_energy=which_energy, |
| 70 | + stability_crit=stability_crit, |
| 71 | + ax=ax, |
| 72 | + ) |
72 | 73 |
|
73 | 74 | title = f"Batch {batch_idx} ({len(df):,})"
|
74 | 75 | ax.set(title=title)
|
75 | 76 |
|
76 | 77 |
|
77 |
| -hist_classified_stable_as_func_of_hull_dist(df, ax=axs.flat[-1], **common_kwargs) |
| 78 | +hist_classified_stable_as_func_of_hull_dist( |
| 79 | + e_above_hull_pred=df[pred_cols].mean(axis=1), |
| 80 | + e_above_hull_true=df.e_above_mp_hull, |
| 81 | + which_energy=which_energy, |
| 82 | + stability_crit=stability_crit, |
| 83 | + ax=axs.flat[-1], |
| 84 | +) |
78 | 85 |
|
79 | 86 | axs.flat[-1].set(title=f"Combined {batch_idx} ({len(df):,})")
|
80 | 87 | axs.flat[0].legend(frameon=False, loc="upper left")
|
81 | 88 |
|
82 |
| -img_name = f"{today}-wren-wbm-hull-dist-hist-{energy_type=}-{stability_crit=}.pdf" |
| 89 | +img_name = f"{today}-wren-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}.pdf" |
83 | 90 | # plt.savefig(f"{ROOT}/figures/{img_name}")
|
0 commit comments