Skip to content

Commit 612d308

Browse files
committed
update wrenformer preds CSV path
remove models/**/__init__.py files; add figures/2022-11-18-wrenformer-wbm-hull-dist-hist-batches.pdf
1 parent b3c3aba commit 612d308

5 files changed

+21
-20
lines changed

matbench_discovery/plot_scripts/hist_classified_stable_vs_hull_dist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
# %%
3232
df = pd.read_csv(
3333
# f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
34-
f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv"
34+
f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
3535
).set_index("material_id")
3636

3737
df["e_above_hull"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp

matbench_discovery/plot_scripts/hist_classified_stable_vs_hull_dist_batches.py

+17-16
Original file line numberDiff line numberDiff line change
@@ -39,46 +39,50 @@
3939
f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
4040
).set_index("material_id")
4141
dfs["wrenformer"] = pd.read_csv(
42-
f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv"
42+
f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
4343
).set_index("material_id")
4444
dfs["bowsr_megnet"] = pd.read_json(
4545
f"{ROOT}/models/bowsr/2022-09-22-bowsr-megnet-wbm-IS2RE.json.gz"
4646
).set_index("material_id")
4747

4848

4949
# %%
50+
pred_col = "e_form_per_atom_pred"
51+
target_col = "e_form_per_atom"
5052
if "wren" in dfs:
5153
df = dfs["wren"]
5254
pred_cols = df.filter(regex=r"_pred_\d").columns
5355
# make sure we average the expected number of ensemble member predictions
5456
assert len(pred_cols) == 10
55-
df["e_form_per_atom_pred"] = df[pred_cols].mean(axis=1)
57+
df[pred_col] = df[pred_cols].mean(axis=1)
5658
if "m3gnet" in dfs:
5759
df = dfs["m3gnet"]
58-
df["e_form_per_atom_pred"] = df.e_form_per_atom_m3gnet
60+
df[pred_col] = df.e_form_per_atom_m3gnet
5961
if "bowsr_megnet" in dfs:
6062
df = dfs["bowsr_megnet"]
61-
df["e_form_per_atom_pred"] = df.e_form_per_atom_bowsr
63+
df[pred_col] = df.e_form_per_atom_bowsr
64+
if "wrenformer" in dfs:
65+
pred_col = "e_form_per_atom_mp2020_corrected_pred_ens"
6266

6367

6468
# %%
6569
which_energy: WhichEnergy = "true"
6670
stability_crit: StabilityCriterion = "energy"
6771
fig, axs = plt.subplots(2, 3, figsize=(18, 9))
6872

69-
model_name = "m3gnet"
73+
model_name = "wrenformer"
7074
df = dfs[model_name]
7175

7276
df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
73-
df["e_form_per_atom"] = df_wbm.e_form_per_atom_mp2020_corrected
77+
df[target_col] = df_wbm.e_form_per_atom_mp2020_corrected # e_form targets
7478

7579

7680
for batch_idx, ax in zip(range(1, 6), axs.flat):
7781
batch_df = df[df.index.str.startswith(f"wbm-step-{batch_idx}-")]
7882
assert 1e4 < len(batch_df) < 1e5, print(f"{len(batch_df) = :,}")
7983

8084
ax, metrics = hist_classified_stable_vs_hull_dist(
81-
e_above_hull_pred=batch_df.e_form_per_atom_pred - batch_df.e_form_per_atom,
85+
e_above_hull_pred=batch_df[pred_col] - batch_df.e_form_per_atom,
8286
e_above_hull_true=batch_df.e_above_hull_mp,
8387
which_energy=which_energy,
8488
stability_crit=stability_crit,
@@ -93,31 +97,28 @@
9397

9498

9599
ax, metrics = hist_classified_stable_vs_hull_dist(
96-
e_above_hull_pred=df.e_form_per_atom_pred - df.e_form_per_atom,
100+
e_above_hull_pred=df[pred_col] - df.e_form_per_atom,
97101
e_above_hull_true=df.e_above_hull_mp,
98102
which_energy=which_energy,
99103
stability_crit=stability_crit,
100104
ax=axs.flat[-1],
101105
)
102106

103107
text = f"Enrichment\nFactor = {metrics['enrichment']:.3}"
104-
ax.text(0.02, 0.25, text, fontsize=16, transform=ax.transAxes)
108+
ax.text(0.02, 0.3, text, fontsize=16, transform=ax.transAxes)
105109

106-
axs.flat[-1].set(title=f"Combined ({len(df.filter(like='e_').dropna()):,})")
110+
axs.flat[-1].set(title=f"All batches ({len(df.filter(like='e_').dropna()):,})")
107111
axs.flat[0].legend(frameon=False, loc="upper left")
108112

109-
img_name = f"{today}-{model_name}-wbm-hull-dist-hist-batches"
110-
suptitle = img_name.replace("-", "/", 2).replace("-", " ")
111-
fig.suptitle(suptitle, y=1.07, fontsize=16)
113+
fig.suptitle(f"{today} {model_name}", y=1.07, fontsize=16)
112114

113115

114116
# %%
117+
img_name = f"{today}-{model_name}-wbm-hull-dist-hist-batches"
115118
ax.figure.savefig(f"{ROOT}/figures/{img_name}.pdf")
116119

117120

118121
# %%
119122
pymatviz.density_scatter(
120-
df=dfs["m3gnet"].query("e_form_per_atom < 5"),
121-
x="e_form_per_atom",
122-
y="e_form_per_atom_pred",
123+
df=dfs[model_name].query(f"{target_col} < 5"), x=target_col, y=pred_col
123124
)

matbench_discovery/plot_scripts/precision_recall.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
).set_index("material_id")
2929

3030
dfs["wrenformer"] = pd.read_csv(
31-
f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv"
31+
f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
3232
).set_index("material_id")
3333

3434
dfs["bowsr_megnet"] = pd.read_json(

matbench_discovery/plot_scripts/rolling_mae_vs_hull_dist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
data_path = (
2020
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
21-
# f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv"
21+
# f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
2222
)
2323
df = pd.read_csv(data_path).set_index("material_id")
2424
legend_label = "Wren"

matbench_discovery/plot_scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
).set_index("material_id")
2020

2121
df_wrenformer = pd.read_csv(
22-
f"{ROOT}/models/wrenformer/mp/2022-11-15-wrenformer-IS2RE-preds.csv"
22+
f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
2323
).set_index("material_id")
2424

2525

0 commit comments

Comments
 (0)