|
14 | 14 |
|
15 | 15 |
|
16 | 16 | # %%
|
17 |
| -rare = "all" |
18 |
| - |
19 | 17 | df_wren = pd.read_csv(
|
20 | 18 | f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
|
21 | 19 | ).set_index("material_id")
|
22 | 20 |
|
23 | 21 | df_wrenformer = pd.read_csv(
|
24 |
| - f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv" |
| 22 | + f"{ROOT}/models/wrenformer/mp/2022-11-15-wrenformer-IS2RE-preds.csv" |
25 | 23 | ).set_index("material_id")
|
26 | 24 |
|
27 | 25 |
|
28 |
| -df_wrenformer["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp |
29 |
| -assert df_wrenformer.e_above_hull_mp.isna().sum() == 0 |
| 26 | +# %% |
| 27 | +model_name = "wren" |
| 28 | +df = {"wren": df_wren, "wrenformer": df_wrenformer}[model_name] |
| 29 | + |
| 30 | +df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp |
| 31 | +assert df.e_above_hull_mp.isna().sum() == 0 |
30 | 32 |
|
31 |
| -target_col = "e_form_per_atom" |
32 |
| -# target_col = "e_form_target" |
| 33 | +possible_targets = ( |
| 34 | + "e_form_per_atom_mp2020_corrected e_form_per_atom e_form_target".split() |
| 35 | +) |
| 36 | +target_col = next(filter(lambda x: x in df, possible_targets)) |
33 | 37 |
|
34 | 38 | # make sure we average the expected number of ensemble member predictions
|
35 |
| -assert df_wrenformer.filter(regex=r"_pred_\d").shape[1] == 10 |
| 39 | +assert df.filter(regex=r"_pred_\d").shape[1] == 10 |
36 | 40 |
|
37 |
| -df_wrenformer["e_above_hull_pred"] = ( |
38 |
| - df_wrenformer.filter(regex=r"_pred_\d").mean(axis=1) - df_wrenformer[target_col] |
39 |
| -) |
| 41 | +df["e_above_hull_pred"] = df.filter(regex=r"_pred_\d").mean(axis=1) - df[target_col] |
40 | 42 |
|
41 | 43 |
|
42 | 44 | # %%
|
|
45 | 47 | assert len(markers) == 5 # number of WBM rounds of element substitution
|
46 | 48 |
|
47 | 49 | for idx, marker in enumerate(markers, 1):
|
48 |
| - df = df_wrenformer[df_wrenformer.index.str.startswith(f"wbm-step-{idx}")] |
49 |
| - title = f"Batch {idx} ({len(df.filter(like='e_').dropna()):,})" |
50 |
| - assert 1e4 < len(df) < 1e5, print(f"{len(df) = :,}") |
| 50 | + # select all rows from WBM step=idx |
| 51 | + df_step = df[df.index.str.startswith(f"wbm-step-{idx}")] |
| 52 | + |
| 53 | + title = f"Batch {idx} ({len(df_step.filter(like='e_').dropna()):,})" |
| 54 | + assert 1e4 < len(df_step) < 1e5, print(f"{len(df_step) = :,}") |
51 | 55 |
|
52 | 56 | rolling_mae_vs_hull_dist(
|
53 |
| - e_above_hull_pred=df.e_above_hull_pred, |
54 |
| - e_above_hull_true=df.e_above_hull_mp, |
| 57 | + e_above_hull_pred=df_step.e_above_hull_pred, |
| 58 | + e_above_hull_true=df_step.e_above_hull_mp, |
55 | 59 | ax=ax,
|
56 | 60 | label=title,
|
57 | 61 | marker=marker,
|
|
62 | 66 |
|
63 | 67 |
|
64 | 68 | ax.legend(loc="lower right", frameon=False)
|
| 69 | +ax.set(title=f"{today} model={model_name}") |
65 | 70 |
|
66 | 71 |
|
67 |
| -img_path = f"{ROOT}/figures/{today}-rolling-mae-vs-hull-dist-wbm-batches-{rare=}.pdf" |
68 |
| -# fig.savefig(img_path) |
| 72 | +img_name = f"{today}-{model_name}-rolling-mae-vs-hull-dist-wbm-batches" |
| 73 | +fig.savefig(f"{ROOT}/figures/{img_name}.pdf") |
0 commit comments