Skip to content

Commit de61b89

Browse files
committed
add figures/2022-11-17-m3gnet-wbm-hull-dist-hist-batches.pdf
add 2022-11-17-{wren,wrenformer}-rolling-mae-vs-hull-dist-wbm-batches
update plot script file paths
1 parent 06911dc commit de61b89

6 files changed

+35
-32
lines changed

matbench_discovery/plot_scripts/hist_classified_stable_vs_hull_dist_batches.py

+6-8
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@
3636
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
3737
).set_index("material_id")
3838
dfs["m3gnet"] = pd.read_json(
39-
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE.json.gz"
39+
f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
4040
).set_index("material_id")
4141
dfs["wrenformer"] = pd.read_csv(
4242
f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv"
@@ -55,7 +55,7 @@
5555
df["e_form_per_atom_pred"] = df[pred_cols].mean(axis=1)
5656
if "m3gnet" in dfs:
5757
df = dfs["m3gnet"]
58-
df["e_form_per_atom_pred"] = df.e_form_m3gnet
58+
df["e_form_per_atom_pred"] = df.e_form_per_atom_m3gnet
5959
if "bowsr_megnet" in dfs:
6060
df = dfs["bowsr_megnet"]
6161
df["e_form_per_atom_pred"] = df.e_form_per_atom_bowsr
@@ -106,7 +106,7 @@
106106
axs.flat[-1].set(title=f"Combined ({len(df.filter(like='e_').dropna()):,})")
107107
axs.flat[0].legend(frameon=False, loc="upper left")
108108

109-
img_name = f"{today}-{model_name}-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}"
109+
img_name = f"{today}-{model_name}-wbm-hull-dist-hist-batches"
110110
suptitle = img_name.replace("-", "/", 2).replace("-", " ")
111111
fig.suptitle(suptitle, y=1.07, fontsize=16)
112112

@@ -117,9 +117,7 @@
117117

118118
# %%
119119
pymatviz.density_scatter(
120-
dfs["wren"].dropna().e_form_per_atom_pred, dfs["wren"].dropna().e_form_per_atom
121-
)
122-
123-
pymatviz.density_scatter(
124-
dfs["m3gnet"].dropna().e_form_per_atom_pred, dfs["m3gnet"].dropna().e_form_per_atom
120+
df=dfs["m3gnet"].query("e_form_per_atom < 5"),
121+
x="e_form_per_atom",
122+
y="e_form_per_atom_pred",
125123
)

matbench_discovery/plot_scripts/precision_recall.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
dfs[model_name] = df
2525

2626
dfs["m3gnet"] = pd.read_json(
27-
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE.json.gz"
27+
f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
2828
).set_index("material_id")
2929

3030
dfs["wrenformer"] = pd.read_csv(

matbench_discovery/plot_scripts/rolling_mae_vs_hull_dist.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525

2626

2727
# %%
28-
rare = "all"
28+
# rare = "all"
2929
# from pymatgen.core import Composition
3030
# rare = "no-lanthanides"
3131
# df["contains_rare_earths"] = df.composition.map(
@@ -62,5 +62,5 @@
6262
fig.set_size_inches(10, 9)
6363
ax.legend(loc="lower right", frameon=False)
6464

65-
img_path = f"{ROOT}/figures/{today}-rolling-mae-vs-hull-dist-{rare=}.pdf"
65+
img_path = f"{ROOT}/figures/{today}-rolling-mae-vs-hull-dist.pdf"
6666
# fig.savefig(img_path)

matbench_discovery/plot_scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+23-18
Original file line number | Diff line number | Diff line change
@@ -14,29 +14,31 @@
1414

1515

1616
# %%
17-
rare = "all"
18-
1917
df_wren = pd.read_csv(
2018
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
2119
).set_index("material_id")
2220

2321
df_wrenformer = pd.read_csv(
24-
f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv"
22+
f"{ROOT}/models/wrenformer/mp/2022-11-15-wrenformer-IS2RE-preds.csv"
2523
).set_index("material_id")
2624

2725

28-
df_wrenformer["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
29-
assert df_wrenformer.e_above_hull_mp.isna().sum() == 0
26+
# %%
27+
model_name = "wren"
28+
df = {"wren": df_wren, "wrenformer": df_wrenformer}[model_name]
29+
30+
df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
31+
assert df.e_above_hull_mp.isna().sum() == 0
3032

31-
target_col = "e_form_per_atom"
32-
# target_col = "e_form_target"
33+
possible_targets = (
34+
"e_form_per_atom_mp2020_corrected e_form_per_atom e_form_target".split()
35+
)
36+
target_col = next(filter(lambda x: x in df, possible_targets))
3337

3438
# make sure we average the expected number of ensemble member predictions
35-
assert df_wrenformer.filter(regex=r"_pred_\d").shape[1] == 10
39+
assert df.filter(regex=r"_pred_\d").shape[1] == 10
3640

37-
df_wrenformer["e_above_hull_pred"] = (
38-
df_wrenformer.filter(regex=r"_pred_\d").mean(axis=1) - df_wrenformer[target_col]
39-
)
41+
df["e_above_hull_pred"] = df.filter(regex=r"_pred_\d").mean(axis=1) - df[target_col]
4042

4143

4244
# %%
@@ -45,13 +47,15 @@
4547
assert len(markers) == 5 # number of WBM rounds of element substitution
4648

4749
for idx, marker in enumerate(markers, 1):
48-
df = df_wrenformer[df_wrenformer.index.str.startswith(f"wbm-step-{idx}")]
49-
title = f"Batch {idx} ({len(df.filter(like='e_').dropna()):,})"
50-
assert 1e4 < len(df) < 1e5, print(f"{len(df) = :,}")
50+
# select all rows from WBM step=idx
51+
df_step = df[df.index.str.startswith(f"wbm-step-{idx}")]
52+
53+
title = f"Batch {idx} ({len(df_step.filter(like='e_').dropna()):,})"
54+
assert 1e4 < len(df_step) < 1e5, print(f"{len(df_step) = :,}")
5155

5256
rolling_mae_vs_hull_dist(
53-
e_above_hull_pred=df.e_above_hull_pred,
54-
e_above_hull_true=df.e_above_hull_mp,
57+
e_above_hull_pred=df_step.e_above_hull_pred,
58+
e_above_hull_true=df_step.e_above_hull_mp,
5559
ax=ax,
5660
label=title,
5761
marker=marker,
@@ -62,7 +66,8 @@
6266

6367

6468
ax.legend(loc="lower right", frameon=False)
69+
ax.set(title=f"{today} model={model_name}")
6570

6671

67-
img_path = f"{ROOT}/figures/{today}-rolling-mae-vs-hull-dist-wbm-batches-{rare=}.pdf"
68-
# fig.savefig(img_path)
72+
img_name = f"{today}-{model_name}-rolling-mae-vs-hull-dist-wbm-batches"
73+
fig.savefig(f"{ROOT}/figures/{img_name}.pdf")

models/m3gnet/eda_wbm_pre_vs_post_m3gnet_relaxation.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@
3030

3131

3232
# %%
33-
is2re_path_old = f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE.json.gz"
33+
is2re_path_old = f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
3434
df_m3gnet_is2re_old = pd.read_json(is2re_path_old).set_index("material_id")
3535

3636
is2re_path = f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
@@ -226,5 +226,5 @@
226226
# %% write df back to compressed JSON
227227
# filter out columns containing 'rs2re'
228228
# df_m3gnet_is2re.reset_index().filter(regex="^((?!rs2re).)*$").to_json(
229-
# f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE-2.json.gz"
229+
# f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE-2.json.gz"
230230
# ).set_index("material_id")

models/m3gnet/join_m3gnet_results.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -74,5 +74,5 @@
7474
out_path = f"{ROOT}/models/m3gnet/{today}-m3gnet-wbm-{task_type}.json.gz"
7575
df_m3gnet.reset_index().to_json(out_path, default_handler=as_dict_handler)
7676

77-
# out_path = f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE.json.gz"
77+
# out_path = f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
7878
# df_m3gnet = pd.read_json(out_path).set_index("material_id")

0 commit comments

Comments
 (0)