Skip to content

Commit c487910

Browse files
committed
make model switching easy in hist_classified_stable_as_func_of_hull_dist_batches.py
1 parent 42a7909 commit c487910

6 files changed

+75
-64
lines changed

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
4141

4242
# download wbm-steps-summary.csv (23.31 MB)
43-
df_summary = pd.read_csv(
43+
df_wbm = pd.read_csv(
4444
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
4545
).set_index("material_id")
4646

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py

+51-13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from datetime import datetime
33

44
import pandas as pd
5+
import pymatviz
56

67
from mb_discovery import ROOT
78
from mb_discovery.plots import (
@@ -29,48 +30,73 @@
2930

3031

3132
# %%
32-
df = pd.read_csv(
33+
dfs = {}
34+
dfs["wren"] = pd.read_csv(
3335
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
3436
).set_index("material_id")
37+
dfs["m3gnet"] = pd.read_json(
38+
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
39+
).set_index("material_id")
40+
dfs["Wrenformer"] = pd.read_csv(
41+
f"{ROOT}/models/wrenformer/mp/"
42+
"2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
43+
).set_index("material_id")
44+
3545

3646
df_hull = pd.read_csv(
3747
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
3848
).set_index("material_id")
3949

40-
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
41-
4250
# download wbm-steps-summary.csv (23.31 MB)
43-
df_summary = pd.read_csv(
51+
df_wbm = pd.read_csv(
4452
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
4553
).set_index("material_id")
4654

4755

56+
dfs["m3gnet"] = dfs.pop("M3Gnet")
57+
58+
59+
# %%
60+
if "wren" in dfs:
61+
df = dfs["wren"]
62+
pred_cols = df.filter(regex=r"_pred_\d").columns
63+
# make sure we average the expected number of ensemble member predictions
64+
assert len(pred_cols) == 10
65+
df["e_form_per_atom_pred"] = df[pred_cols].mean(axis=1)
66+
if "m3gnet" in dfs:
67+
df = dfs["m3gnet"]
68+
df["e_form_per_atom_pred"] = df.e_form_ppd_2022_01_25
69+
70+
4871
# %%
4972
which_energy: WhichEnergy = "true"
5073
stability_crit: StabilityCriterion = "energy"
51-
df["wbm_batch"] = df.index.str.split("-").str[2]
5274
fig, axs = plt.subplots(2, 3, figsize=(18, 9))
5375

54-
# make sure we average the expected number of ensemble member predictions
55-
pred_cols = df.filter(regex=r"_pred_\d").columns
56-
assert len(pred_cols) == 10
76+
df = dfs[(model_name := "wren")]
5777

78+
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
79+
df["e_form_per_atom"] = df_wbm.e_form_per_atom
80+
81+
82+
for batch_idx, ax in zip(range(1, 6), axs.flat):
83+
batch_df = df[df.index.str.startswith(f"wbm-step-{batch_idx}-")]
84+
assert 1e4 < len(batch_df) < 1e5, print(f"{len(batch_df) = :,}")
5885

59-
for (batch_idx, batch_df), ax in zip(df.groupby("wbm_batch"), axs.flat):
6086
hist_classified_stable_as_func_of_hull_dist(
61-
e_above_hull_pred=batch_df[pred_cols].mean(axis=1) - batch_df.e_form_target,
87+
e_above_hull_pred=batch_df.e_form_per_atom_pred - batch_df.e_form_per_atom,
6288
e_above_hull_true=batch_df.e_above_mp_hull,
6389
which_energy=which_energy,
6490
stability_crit=stability_crit,
6591
ax=ax,
6692
)
6793

68-
title = f"Batch {batch_idx} ({len(df):,})"
94+
title = f"Batch {batch_idx} ({len(batch_df):,})"
6995
ax.set(title=title)
7096

7197

7298
hist_classified_stable_as_func_of_hull_dist(
73-
e_above_hull_pred=df[pred_cols].mean(axis=1),
99+
e_above_hull_pred=df.e_form_per_atom_pred - df.e_form_per_atom,
74100
e_above_hull_true=df.e_above_mp_hull,
75101
which_energy=which_energy,
76102
stability_crit=stability_crit,
@@ -80,5 +106,17 @@
80106
axs.flat[-1].set(title=f"Combined {batch_idx} ({len(df):,})")
81107
axs.flat[0].legend(frameon=False, loc="upper left")
82108

83-
img_name = f"{today}-wren-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}.pdf"
109+
img_name = (
110+
f"{today}-{model_name}-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}.pdf"
111+
)
84112
# plt.savefig(f"{ROOT}/figures/{img_name}")
113+
114+
115+
# %%
116+
pymatviz.density_scatter(
117+
dfs["wren"].dropna().e_form_per_atom_pred, dfs["wren"].dropna().e_form_per_atom
118+
)
119+
120+
pymatviz.density_scatter(
121+
dfs["m3gnet"].dropna().e_form_per_atom_pred, dfs["m3gnet"].dropna().e_form_per_atom
122+
)

mb_discovery/plot_scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
for idx, marker in enumerate(markers, 1):
4545
title = f"Batch {idx}"
4646
df = df_wbm[df_wbm.index.str.startswith(f"wbm-step-{idx}")]
47+
assert 1e4 < len(df) < 1e5, print(f"{len(df) = :,}")
4748

4849
rolling_mae_vs_hull_dist(
4950
e_above_hull_pred=df.e_above_hull_pred,

mb_discovery/plots.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -129,16 +129,15 @@ def hist_classified_stable_as_func_of_hull_dist(
129129
stacked=True,
130130
)
131131

132-
n_true_pos, n_false_pos, n_true_neg, n_false_neg = (
133-
len(true_pos),
134-
len(false_pos),
135-
len(true_neg),
136-
len(false_neg),
132+
n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
133+
len, (true_pos, false_pos, true_neg, false_neg)
137134
)
138135
# null = (tp + fn) / (tp + tn + fp + fn)
139136
precision = n_true_pos / (n_true_pos + n_false_pos)
140137

141-
assert n_true_pos + n_false_pos + n_true_neg + n_false_neg == len(e_above_hull_true)
138+
# assert (n_all := n_true_pos + n_false_pos + n_true_neg + n_false_neg) == len(
139+
# e_above_hull_true
140+
# ), f"{n_all} != {len(e_above_hull_true)}"
142141

143142
# recall = n_true_pos / n_total_pos
144143
# f"Prevalence = {null:.2f}\n{precision = :.2f}\n{recall = :.2f}",

models/bowsr/join_bowsr_results.py

+14-27
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from glob import glob
77

88
import pandas as pd
9+
import pymatviz
910
from pymatgen.core import Structure
1011
from tqdm import tqdm
1112

1213
from mb_discovery import ROOT, as_dict_handler
13-
from mb_discovery.plots import hist_classified_stable_as_func_of_hull_dist
1414

1515
__author__ = "Janosh Riebesell"
1616
__date__ = "2022-09-22"
@@ -22,7 +22,7 @@
2222
module_dir = os.path.dirname(__file__)
2323
task_type = "IS2RE"
2424
date = "2022-09-22"
25-
glob_pattern = f"{date}-bowsr-wbm-{task_type}/*.json.gz"
25+
glob_pattern = f"{date}-bowsr-megnet-wbm-{task_type}/*.json.gz"
2626
file_paths = sorted(glob(f"{module_dir}/{glob_pattern}"))
2727
print(f"Found {len(file_paths):,} files for {glob_pattern = }")
2828

@@ -37,12 +37,7 @@
3737
continue
3838
# keep whole dataframe in memory
3939
df = pd.read_json(file_path).set_index("material_id")
40-
col_map = dict(
41-
structure_pred="structure_bowsr",
42-
energy_pred="energy_bowsr",
43-
e_form_per_atom_pred="e_form_per_atom_bowsr",
44-
)
45-
df = df.rename(columns=col_map)
40+
4641
df["structure_bowsr"] = df.structure_bowsr.map(Structure.from_dict)
4742
df["formula"] = df.structure_bowsr.map(lambda x: x.formula)
4843
df["volume"] = df.structure_bowsr.map(lambda x: x.volume)
@@ -54,39 +49,31 @@
5449
df_bowsr = pd.concat(dfs.values())
5550

5651

57-
# %%
52+
# %% compare against WBM formation energy targets to make sure we got sensible results
5853
df_wbm = pd.read_csv( # download wbm-steps-summary.csv (23.31 MB)
5954
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
6055
).set_index("material_id")
6156

6257
df_bowsr["e_form_wbm"] = df_wbm.e_form_per_atom
6358

59+
print(f"{len(df_bowsr) - len(df_wbm) = :,} = {len(df_bowsr):,} - {len(df_wbm):,}")
60+
6461

6562
# %%
6663
df_bowsr.hist(bins=200, figsize=(18, 12))
6764
df_bowsr.isna().sum()
6865

6966

7067
# %%
71-
out_path = f"{ROOT}/models/bowsr/{today}-bowsr-wbm-{task_type}.json.gz"
72-
df_bowsr.reset_index().to_json(out_path, default_handler=as_dict_handler)
73-
74-
out_path = f"{ROOT}/models/bowsr/2022-08-16-bowsr-wbm-IS2RE.json.gz"
75-
df_bowsr = pd.read_json(out_path).set_index("material_id")
68+
pymatviz.density_scatter(
69+
df_bowsr.dropna().e_form_per_atom_bowsr,
70+
df_bowsr.dropna().e_form_wbm,
71+
)
7672

7773

7874
# %%
79-
df_hull = pd.read_csv(
80-
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
81-
).set_index("material_id")
82-
df_bowsr["e_above_mp_hull"] = df_hull.e_above_mp_hull
83-
df_bowsr["e_above_hull_pred"] = ( # TODO fix this incorrect e_above_hull_pred
84-
df_bowsr["e_form_per_atom_bowsr"] - df_bowsr["e_above_mp_hull"]
85-
)
86-
87-
ax_hull_dist_hist = hist_classified_stable_as_func_of_hull_dist(
88-
e_above_hull_pred=df_bowsr.e_above_hull_pred,
89-
e_above_hull_true=df_bowsr.e_above_mp_hull,
90-
)
75+
out_path = f"{ROOT}/models/bowsr/{today}-bowsr-megnet-wbm-{task_type}.json.gz"
76+
df_bowsr.reset_index().to_json(out_path, default_handler=as_dict_handler)
9177

92-
# ax_hull_dist_hist.figure.savefig(f"{ROOT}/plots/{today}-bowsr-wbm-hull-dist-hist.pdf")
78+
# out_path = f"{ROOT}/models/bowsr/2022-08-16-bowsr-megnet-wbm-IS2RE.json.gz"
79+
# df_bowsr = pd.read_json(out_path).set_index("material_id")

models/m3gnet/join_m3gnet_relax_results.py

+3-17
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from tqdm import tqdm
1616

1717
from mb_discovery import ROOT, as_dict_handler
18-
from mb_discovery.plots import hist_classified_stable_as_func_of_hull_dist
1918

2019
__author__ = "Janosh Riebesell"
2120
__date__ = "2022-08-16"
@@ -86,7 +85,7 @@
8685
]
8786

8887

89-
# %%
88+
# %% compare against WBM formation energy targets to make sure we got sensible results
9089
df_hull = pd.read_csv(
9190
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
9291
).set_index("material_id")
@@ -119,18 +118,5 @@
119118
out_path = f"{ROOT}/models/m3gnet/{today}-m3gnet-wbm-relax-{task_type}.json.gz"
120119
df_m3gnet.reset_index().to_json(out_path, default_handler=as_dict_handler)
121120

122-
out_path = f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
123-
df_m3gnet = pd.read_json(out_path).set_index("material_id")
124-
125-
126-
# %%
127-
df_m3gnet["e_above_hull_pred"] = ( # TODO fix this incorrect e_above_hull_pred
128-
df_m3gnet["e_form_m3gnet_from_ppd"] - df_m3gnet["e_above_mp_hull"]
129-
)
130-
131-
ax_hull_dist_hist = hist_classified_stable_as_func_of_hull_dist(
132-
e_above_hull_pred=df_m3gnet.e_above_hull_pred,
133-
e_above_hull_true=df_m3gnet.e_above_mp_hull,
134-
)
135-
136-
# ax_hull_dist_hist.figure.savefig(f"{ROOT}/plots/{today}-m3gnet-wbm-hull-dist-hist.pdf")
121+
# out_path = f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-relax-results-IS2RE.json.gz"
122+
# df_m3gnet = pd.read_json(out_path).set_index("material_id")

0 commit comments

Comments
 (0)