Skip to content

Commit 0236c36

Browse files
committed
add plot_scripts/all_models_scatter.py
centralize model preds loading into matbench_discovery/plot_scripts/__init__.py, now used in most plot scripts add tests/test_plots_scripts.py
1 parent bac8551 commit 0236c36

File tree

6 files changed

+37
-143
lines changed

6 files changed

+37
-143
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
# %%
2-
import pandas as pd
3-
import pymatviz
4-
52
from matbench_discovery import ROOT, today
6-
from matbench_discovery.plot_scripts import df_wbm
3+
from matbench_discovery.plot_scripts import load_df_wbm_with_preds
74
from matbench_discovery.plots import (
85
StabilityCriterion,
96
WhichEnergy,
@@ -27,59 +24,25 @@
2724

2825

2926
# %%
30-
dfs = {}
31-
dfs["wren"] = pd.read_csv(
32-
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
33-
).set_index("material_id")
34-
dfs["m3gnet"] = pd.read_json(
35-
f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
36-
).set_index("material_id")
37-
dfs["wrenformer"] = pd.read_csv(
38-
f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
39-
).set_index("material_id")
40-
dfs["bowsr_megnet"] = pd.read_json(
41-
f"{ROOT}/models/bowsr/2022-11-22-bowsr-megnet-wbm-IS2RE.json.gz"
42-
).set_index("material_id")
43-
44-
45-
# %%
46-
pred_col = "e_form_per_atom_pred"
47-
target_col = "e_form_per_atom"
48-
if "wren" in dfs:
49-
df = dfs["wren"]
50-
pred_cols = df.filter(regex=r"_pred_\d").columns
51-
# make sure we average the expected number of ensemble member predictions
52-
assert len(pred_cols) == 10
53-
df[pred_col] = df[pred_cols].mean(axis=1)
54-
if "m3gnet" in dfs:
55-
df = dfs["m3gnet"]
56-
df[pred_col] = df.e_form_per_atom_m3gnet
57-
if "bowsr_megnet" in dfs:
58-
df = dfs["bowsr_megnet"]
59-
df[pred_col] = df.e_form_per_atom_bowsr_megnet
60-
if "wrenformer" in dfs:
61-
pred_col = "e_form_per_atom_mp2020_corrected_pred_ens"
27+
df_wbm = load_df_wbm_with_preds(models="Wren Wrenformer".split()).round(3)
28+
target_col = "e_form_per_atom_mp2020_corrected"
29+
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
6230

6331

6432
# %%
6533
which_energy: WhichEnergy = "true"
6634
stability_crit: StabilityCriterion = "energy"
6735
fig, axs = plt.subplots(2, 3, figsize=(18, 9))
6836

69-
model_name = "wrenformer"
70-
df = dfs[model_name]
71-
72-
df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
73-
df[target_col] = df_wbm.e_form_per_atom_mp2020_corrected # e_form targets
74-
37+
model_name = "Wrenformer"
7538

7639
for batch_idx, ax in zip(range(1, 6), axs.flat):
77-
batch_df = df[df.index.str.startswith(f"wbm-step-{batch_idx}-")]
40+
batch_df = df_wbm[df_wbm.index.str.startswith(f"wbm-step-{batch_idx}-")]
7841
assert 1e4 < len(batch_df) < 1e5, print(f"{len(batch_df) = :,}")
7942

8043
ax, metrics = hist_classified_stable_vs_hull_dist(
81-
e_above_hull_pred=batch_df[pred_col] - batch_df.e_form_per_atom,
82-
e_above_hull_true=batch_df.e_above_hull_mp,
44+
e_above_hull_pred=batch_df[model_name] - batch_df[target_col],
45+
e_above_hull_true=batch_df[e_above_hull_col],
8346
which_energy=which_energy,
8447
stability_crit=stability_crit,
8548
ax=ax,
@@ -93,8 +56,8 @@
9356

9457

9558
ax, metrics = hist_classified_stable_vs_hull_dist(
96-
e_above_hull_pred=df[pred_col] - df.e_form_per_atom,
97-
e_above_hull_true=df.e_above_hull_mp,
59+
e_above_hull_pred=df_wbm[model_name] - df_wbm[target_col],
60+
e_above_hull_true=df_wbm[e_above_hull_col],
9861
which_energy=which_energy,
9962
stability_crit=stability_crit,
10063
ax=axs.flat[-1],
@@ -103,7 +66,7 @@
10366
text = f"Enrichment\nFactor = {metrics['enrichment']:.3}"
10467
ax.text(0.02, 0.3, text, fontsize=16, transform=ax.transAxes)
10568

106-
axs.flat[-1].set(title=f"All batches ({len(df.filter(like='e_').dropna()):,})")
69+
axs.flat[-1].set(title=f"All batches ({len(df_wbm[model_name].dropna()):,})")
10770
axs.flat[0].legend(frameon=False, loc="upper left")
10871

10972
fig.suptitle(f"{today} {model_name}", y=1.07, fontsize=16)
@@ -112,9 +75,3 @@
11275
# %%
11376
img_name = f"{today}-{model_name}-wbm-hull-dist-hist-batches"
11477
ax.figure.savefig(f"{ROOT}/figures/{img_name}.pdf")
115-
116-
117-
# %%
118-
pymatviz.density_scatter(
119-
df=dfs[model_name].query(f"{target_col} < 5"), x=target_col, y=pred_col
120-
)

matbench_discovery/plot_scripts/precision_recall.py

+19-67
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,45 @@
11
# %%
2-
import pandas as pd
32
from sklearn.metrics import f1_score
43

54
from matbench_discovery import ROOT, today
6-
from matbench_discovery.plot_scripts import df_wbm
5+
from matbench_discovery.plot_scripts import load_df_wbm_with_preds
76
from matbench_discovery.plots import StabilityCriterion, cumulative_clf_metric, plt
87

98
__author__ = "Rhys Goodall, Janosh Riebesell"
109

1110

1211
# %%
13-
dfs: dict[str, pd.DataFrame] = {}
14-
for model_name in ("wren", "cgcnn", "voronoi"):
15-
csv_path = (
16-
f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv"
17-
)
18-
df = pd.read_csv(csv_path).set_index("material_id")
19-
dfs[model_name] = df
20-
21-
dfs["m3gnet"] = pd.read_json(
22-
f"{ROOT}/models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
23-
).set_index("material_id")
24-
25-
dfs["wrenformer"] = pd.read_csv(
26-
f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
27-
).set_index("material_id")
12+
models = (
13+
"Wren, CGCNN IS2RE, CGCNN RS2RE, Voronoi IS2RE, Voronoi RS2RE, "
14+
"Wrenformer, MEGNet"
15+
).split(", ")
2816

29-
dfs["bowsr_megnet"] = pd.read_json(
30-
f"{ROOT}/models/bowsr/2022-11-22-bowsr-megnet-wbm-IS2RE.json.gz"
31-
).set_index("material_id")
17+
df_wbm = load_df_wbm_with_preds(models=models).round(3)
3218

33-
print(f"loaded models: {list(dfs)}")
19+
target_col = "e_form_per_atom_mp2020_corrected"
20+
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
3421

3522

3623
# %%
3724
stability_crit: StabilityCriterion = "energy"
3825
colors = "tab:blue tab:orange teal tab:pink black red turquoise tab:purple".split()
39-
F1s: dict[str, float] = {}
40-
41-
for model_name, df in sorted(dfs.items()):
42-
if "std" in stability_crit:
43-
# TODO column names to compute standard deviation from are currently hardcoded
44-
# needs to be updated when adding non-aviary models with uncertainty estimation
45-
var_aleatoric = (df.filter(regex=r"_ale_\d") ** 2).mean(axis=1)
46-
var_epistemic = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
47-
std_total = (var_epistemic + var_aleatoric) ** 0.5
48-
else:
49-
std_total = None
50-
51-
try:
52-
if model_name == "m3gnet":
53-
model_preds = df.e_form_m3gnet
54-
elif "wrenformer" in model_name:
55-
model_preds = df.e_form_per_atom_pred_ens
56-
elif len(pred_cols := df.filter(like="e_form_pred").columns) >= 1:
57-
# Voronoi+RF has single prediction column, Wren and CGCNN each have 10
58-
# other cases are unexpected
59-
assert len(pred_cols) in (1, 10), f"{model_name=} has {len(pred_cols)=}"
60-
model_preds = df[pred_cols].mean(axis=1)
61-
elif model_name == "bowsr_megnet":
62-
model_preds = df.e_form_per_atom_bowsr_megnet
63-
else:
64-
raise ValueError(f"Unhandled {model_name = }")
65-
except AttributeError as exc:
66-
raise KeyError(f"{model_name = }") from exc
67-
68-
df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
69-
df["e_form_per_atom"] = df_wbm.e_form_per_atom_mp2020_corrected
70-
df["e_above_hull_pred"] = model_preds - df.e_form_per_atom
71-
if n_nans := df.isna().values.sum() > 0:
72-
assert n_nans < 10, f"{model_name=} has {n_nans=}"
73-
df = df.dropna()
74-
75-
F1 = f1_score(df.e_above_hull_mp < 0, df.e_above_hull_pred < 0)
76-
F1s[model_name] = F1
7726

7827

7928
# %%
8029
fig, (ax_prec, ax_recall) = plt.subplots(1, 2, figsize=(15, 7), sharey=True)
8130

82-
for (model_name, F1), color in zip(sorted(F1s.items(), key=lambda x: x[1]), colors):
83-
df = dfs[model_name]
84-
e_above_hull_error = df.e_above_hull_pred + df.e_above_hull_mp
85-
e_above_hull_true = df.e_above_hull_mp
31+
for model_name, color in zip(models, colors):
32+
33+
e_above_hull_pred = df_wbm[model_name] - df_wbm[target_col]
34+
35+
F1 = f1_score(df_wbm[e_above_hull_col] < 0, e_above_hull_pred < 0)
36+
37+
e_above_hull_error = e_above_hull_pred + df_wbm[e_above_hull_col]
8638
cumulative_clf_metric(
8739
e_above_hull_error,
88-
e_above_hull_true,
40+
df_wbm[e_above_hull_col],
8941
color=color,
90-
label=f"{model_name}\n{F1=:.2}",
42+
label=f"{model_name}\n{F1=:.3}",
9143
project_end_point="xy",
9244
stability_crit=stability_crit,
9345
ax=ax_prec,
@@ -96,9 +48,9 @@
9648

9749
cumulative_clf_metric(
9850
e_above_hull_error,
99-
e_above_hull_true,
51+
df_wbm[e_above_hull_col],
10052
color=color,
101-
label=f"{model_name}\n{F1=:.2}",
53+
label=f"{model_name}\n{F1=:.3}",
10254
project_end_point="xy",
10355
stability_crit=stability_crit,
10456
ax=ax_recall,

matbench_discovery/plot_scripts/rolling_mae_vs_hull_dist.py

-11
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111

1212
# %%
13-
markers = ["o", "v", "^", "H", "D", ""]
14-
1513
data_path = (
1614
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
1715
# f"{ROOT}/models/wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
@@ -21,15 +19,6 @@
2119

2220

2321
# %%
24-
# rare = "all"
25-
# from pymatgen.core import Composition
26-
# rare = "no-lanthanides"
27-
# df["contains_rare_earths"] = df.composition.map(
28-
# lambda x: any(el.is_rare_earth_metal for el in Composition(x))
29-
# )
30-
# df = df.query("~contains_rare_earths")
31-
32-
3322
df["e_above_hull_mp"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
3423

3524
assert all(n_nans := df.isna().sum() == 0), f"Found {n_nans} NaNs"

matbench_discovery/plots.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@
3535
e_above_hull_mp="Energy above MP convex hull (eV/atom)",
3636
e_above_hull_error="Error in energy above convex hull (eV/atom)",
3737
vol_diff="Volume difference (A^3)",
38+
e_form_per_atom_mp2020_corrected="Formation energy (eV/atom)",
39+
e_form_per_atom_pred="Predicted formation energy (eV/atom)",
40+
material_id="Material ID",
41+
band_gap="Band gap (eV)",
42+
formula="Formula",
3843
)
3944
model_labels = dict(
4045
wren="Wren",
@@ -254,10 +259,6 @@ def rolling_mae_vs_hull_dist(
254259
"""
255260
ax = ax or plt.gca()
256261

257-
for series in (e_above_hull_pred, e_above_hull_true):
258-
n_nans = series.isna().sum()
259-
assert n_nans == 0, f"{n_nans:,} NaNs in {series.name}"
260-
261262
is_fresh_ax = len(ax.lines) == 0
262263

263264
bins = np.arange(*x_lim, bin_width)
@@ -387,10 +388,6 @@ def cumulative_clf_metric(
387388
"""
388389
ax = ax or plt.gca()
389390

390-
for series in (e_above_hull_error, e_above_hull_true):
391-
n_nans = series.isna().sum()
392-
assert n_nans == 0, f"{n_nans:,} NaNs in {series.name}"
393-
394391
e_above_hull_error = e_above_hull_error.sort_values()
395392
e_above_hull_true = e_above_hull_true.loc[e_above_hull_error.index]
396393

models/m3gnet/eda_wbm_pre_vs_post_m3gnet_relaxation.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,7 @@
9999
df["m3gnet_vol_diff"] = df.m3gnet_volume - df.final_wbm_volume
100100
df["dft_vol_diff"] = df.initial_wbm_volume - df.final_wbm_volume
101101
fig = px.histogram(
102-
pd.melt(
103-
df,
102+
df.melt(
104103
value_vars=["m3gnet", "dft"],
105104
value_name="vol_diff",
106105
var_name="method",

models/megnet/test_megnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
"""
1919
To slurm submit this file: python path/to/file.py slurm-submit
20-
Requires Megnet installation: pip install megnet
20+
Requires MEGNet installation: pip install megnet
2121
https://github.com/materialsvirtuallab/megnet
2222
"""
2323

0 commit comments

Comments
 (0)