Commit 500e670

add matbench_discovery/metrics.py to centralize computing metrics for plotting scripts
- remove np.random.seed(0) from test_stable_metrics for increased randomness
- test regression metrics against sklearn in test_stable_metrics
1 parent 248a79b · commit 500e670

17 files changed: +309 -318 lines

matbench_discovery/__init__.py (-2)

```diff
@@ -23,8 +23,6 @@
 today = timestamp.split("@")[0]

 # load docs, repo, package URLs from package.json
-print(f"{ROOT=}")
-
 with open(f"{ROOT}/site/package.json") as file:
     pkg = json.load(file)
 pypi_keys_to_npm = dict(Docs="homepage", Repo="repository", Package="package")
```

matbench_discovery/data.py (+3 -2)

```diff
@@ -35,8 +35,8 @@

 def as_dict_handler(obj: Any) -> dict[str, Any] | None:
     """Pass this to json.dump(default=) or as pandas.to_json(default_handler=) to
-    convert Python classes with an as_dict() method to dictionaries on serialization.
-    Objects without an as_dict() method are replaced with None in the serialized data.
+    serialize Python classes with as_dict(). Warning: objects without an as_dict()
+    method are replaced with None in the serialized data.
     """
     try:
         return obj.as_dict()  # all MSONable objects implement as_dict()
@@ -144,6 +144,7 @@ def load_train_test(
         "Wrenformer": "wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv",
         "MEGNet": "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv",
         "M3GNet": "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv",
+        "M3GNet MEGNet": "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv",
         "BOWSR MEGNet": "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv",
     }
```
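
A quick, hedged illustration of the reworded `as_dict_handler` docstring (the `Demo` class is hypothetical, and this assumes the handler's `except` branch returns `None`, as the docstring states):

```py
import json

from matbench_discovery.data import as_dict_handler


class Demo:  # hypothetical class, only for illustration
    def as_dict(self) -> dict[str, str]:
        return {"foo": "bar"}


# objects exposing as_dict() are serialized via that method;
# anything else is assumed to fall back to None, i.e. JSON null
print(json.dumps({"demo": Demo(), "other": object()}, default=as_dict_handler))
# expected: {"demo": {"foo": "bar"}, "other": null}
```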

matbench_discovery/energy.py (-91)

```diff
@@ -1,12 +1,10 @@
 import itertools
 from collections.abc import Sequence

-import numpy as np
 import pandas as pd
 from pymatgen.analysis.phase_diagram import Entry, PDEntry
 from pymatgen.core import Composition
 from pymatgen.util.typing import EntryLike
-from sklearn.metrics import r2_score
 from tqdm import tqdm

 from matbench_discovery import ROOT
@@ -120,92 +118,3 @@ def get_e_form_per_atom(
     form_energy = energy - sum(comp[el] * e_refs[str(el)] for el in comp)

     return form_energy / comp.num_atoms
-
-
-def classify_stable(
-    e_above_hull_true: pd.Series,
-    e_above_hull_pred: pd.Series,
-    stability_threshold: float = 0,
-) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]:
-    """Classify model stability predictions as true/false positive/negatives (usually
-    w.r.t DFT-ground truth labels). All energies are assumed to be in eV/atom
-    (but shouldn't really matter as long as they're consistent).
-
-    Args:
-        e_above_hull_true (pd.Series): Ground truth energy above convex hull values.
-        e_above_hull_pred (pd.Series): Model predicted energy above convex hull values.
-        stability_threshold (float, optional): Maximum energy above convex hull for a
-            material to still be considered stable. Usually 0, 0.05 or 0.1. Defaults to
-            0, meaning a material has to be directly on the hull to be called stable.
-            Negative values mean a material has to pull the known hull down by that
-            amount to count as stable. Few materials lie below the known hull, so only
-            negative values very close to 0 make sense.
-
-    Returns:
-        tuple[TP, FN, FP, TN]: Indices as pd.Series for true positives,
-            false negatives, false positives and true negatives (in this order).
-    """
-    actual_pos = e_above_hull_true <= stability_threshold
-    actual_neg = e_above_hull_true > stability_threshold
-    model_pos = e_above_hull_pred <= stability_threshold
-    model_neg = e_above_hull_pred > stability_threshold
-
-    true_pos = actual_pos & model_pos
-    false_neg = actual_pos & model_neg
-    false_pos = actual_neg & model_pos
-    true_neg = actual_neg & model_neg
-
-    return true_pos, false_neg, false_pos, true_neg
-
-
-def stable_metrics(
-    true: Sequence[float], pred: Sequence[float], stability_threshold: float = 0
-) -> dict[str, float]:
-    """
-    Get a dictionary of stability prediction metrics. Mostly binary classification
-    metrics, but also MAE, RMSE and R2.
-
-    Args:
-        true (list[float]): true energy values
-        pred (list[float]): predicted energy values
-        stability_threshold (float): Where to place stability threshold relative to
-            convex hull in eV/atom, usually 0 or 0.1 eV. Defaults to 0.
-
-    Note: Could be replaced by sklearn.metrics.classification_report() which takes
-        binary labels. I.e. classification_report(true > 0, pred > 0, output_dict=True)
-        should give equivalent results.
-
-    Returns:
-        dict[str, float]: dictionary of classification metrics with keys DAF, Precision,
-            Recall, Accuracy, F1, TPR, FPR, TNR, FNR, MAE, RMSE, R2.
-    """
-    true_pos, false_neg, false_pos, true_neg = classify_stable(
-        true, pred, stability_threshold
-    )
-
-    n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
-        sum, (true_pos, false_pos, true_neg, false_neg)
-    )
-
-    n_total_pos = n_true_pos + n_false_neg
-    prevalence = n_total_pos / len(true)  # null rate
-    precision = n_true_pos / (n_true_pos + n_false_pos)
-    recall = n_true_pos / n_total_pos
-
-    is_nan = np.isnan(true) | np.isnan(pred)
-    true, pred = np.array(true)[~is_nan], np.array(pred)[~is_nan]
-
-    return dict(
-        DAF=precision / prevalence,
-        Precision=precision,
-        Recall=recall,
-        Accuracy=(n_true_pos + n_true_neg) / len(true),
-        F1=2 * (precision * recall) / (precision + recall),
-        TPR=n_true_pos / (n_true_pos + n_false_neg),
-        FPR=n_false_pos / (n_true_neg + n_false_pos),
-        TNR=n_true_neg / (n_true_neg + n_false_pos),
-        FNR=n_false_neg / (n_true_pos + n_false_neg),
-        MAE=np.abs(true - pred).mean(),
-        RMSE=((true - pred) ** 2).mean() ** 0.5,
-        R2=r2_score(true, pred),
-    )
```

matbench_discovery/metrics.py (new file, +125)

```py
"""Centralize data-loading and computing metrics for plotting scripts."""

from collections.abc import Sequence

import numpy as np
import pandas as pd
from sklearn.metrics import r2_score

from matbench_discovery.data import load_df_wbm_preds


def classify_stable(
    e_above_hull_true: pd.Series,
    e_above_hull_pred: pd.Series,
    stability_threshold: float | None = 0,
) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]:
    """Classify model stability predictions as true/false positive/negatives (usually
    w.r.t DFT-ground truth labels). All energies are assumed to be in eV/atom
    (but shouldn't really matter as long as they're consistent).

    Args:
        e_above_hull_true (pd.Series): Ground truth energy above convex hull values.
        e_above_hull_pred (pd.Series): Model predicted energy above convex hull values.
        stability_threshold (float | None, optional): Maximum energy above convex hull
            for a material to still be considered stable. Usually 0, 0.05 or 0.1.
            Defaults to 0, meaning a material has to be directly on the hull to be
            called stable. Negative values mean a material has to pull the known hull
            down by that amount to count as stable. Few materials lie below the known
            hull, so only negative values very close to 0 make sense.

    Returns:
        tuple[TP, FN, FP, TN]: Indices as pd.Series for true positives,
            false negatives, false positives and true negatives (in this order).
    """
    actual_pos = e_above_hull_true <= (stability_threshold or 0)  # guard against None
    actual_neg = e_above_hull_true > (stability_threshold or 0)
    model_pos = e_above_hull_pred <= (stability_threshold or 0)
    model_neg = e_above_hull_pred > (stability_threshold or 0)

    true_pos = actual_pos & model_pos
    false_neg = actual_pos & model_neg
    false_pos = actual_neg & model_pos
    true_neg = actual_neg & model_neg

    return true_pos, false_neg, false_pos, true_neg


def stable_metrics(
    true: Sequence[float], pred: Sequence[float], stability_threshold: float = 0
) -> dict[str, float]:
    """
    Get a dictionary of stability prediction metrics. Mostly binary classification
    metrics, but also MAE, RMSE and R2.

    Args:
        true (list[float]): true energy values
        pred (list[float]): predicted energy values
        stability_threshold (float): Where to place stability threshold relative to
            convex hull in eV/atom, usually 0 or 0.1 eV. Defaults to 0.

    Note: Could be replaced by sklearn.metrics.classification_report() which takes
        binary labels. I.e. classification_report(true > 0, pred > 0, output_dict=True)
        should give equivalent results.

    Returns:
        dict[str, float]: dictionary of classification metrics with keys DAF, Precision,
            Recall, Accuracy, F1, TPR, FPR, TNR, FNR, MAE, RMSE, R2.
    """
    true_pos, false_neg, false_pos, true_neg = classify_stable(
        true, pred, stability_threshold
    )

    n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
        sum, (true_pos, false_pos, true_neg, false_neg)
    )

    n_total_pos = n_true_pos + n_false_neg
    prevalence = n_total_pos / len(true)  # null rate
    precision = n_true_pos / (n_true_pos + n_false_pos)
    recall = n_true_pos / n_total_pos

    is_nan = np.isnan(true) | np.isnan(pred)
    true, pred = np.array(true)[~is_nan], np.array(pred)[~is_nan]

    return dict(
        DAF=precision / prevalence,
        Precision=precision,
        Recall=recall,
        Accuracy=(n_true_pos + n_true_neg) / len(true),
        F1=2 * (precision * recall) / (precision + recall),
        TPR=n_true_pos / n_total_pos,
        FPR=n_false_pos / (n_true_neg + n_false_pos),
        TNR=n_true_neg / (n_true_neg + n_false_pos),
        FNR=n_false_neg / n_total_pos,
        MAE=np.abs(true - pred).mean(),
        RMSE=((true - pred) ** 2).mean() ** 0.5,
        R2=r2_score(true, pred),
    )


# NB: implicit string concatenation makes "M3GNet MEGNet" a single model name,
# matching the new key added to PRED_FILENAMES in data.py
models = sorted(
    "Wrenformer, CGCNN, Voronoi Random Forest, MEGNet, M3GNet "
    "MEGNet, BOWSR MEGNet".split(", ")
)
e_form_col = "e_form_per_atom_mp2020_corrected"
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
each_pred_col = "e_above_hull_pred"

df_wbm = load_df_wbm_preds(models).round(3)

for col in [e_form_col, each_true_col]:
    assert col in df_wbm, f"{col=} not in {list(df_wbm)=}"


df_metrics = pd.DataFrame()
for model in models:
    df_metrics[model] = stable_metrics(
        df_wbm[each_true_col],
        df_wbm[each_true_col] + df_wbm[e_form_col] - df_wbm[model],
    )

assert df_metrics.T.MAE.between(0, 0.2).all(), "MAE not in range"
assert df_metrics.T.R2.between(0.1, 1).all(), "R2 not in range"
assert df_metrics.T.RMSE.between(0, 0.25).all(), "RMSE not in range"
assert df_metrics.isna().sum().sum() == 0, "NaNs in metrics"
```
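
A quick sanity check of the new module on toy data (values invented for illustration; note that importing `matbench_discovery.metrics` also runs its module-level data loading):

```py
import pandas as pd

from matbench_discovery.metrics import stable_metrics

# toy hull distances in eV/atom: the first 2 of 4 materials are truly stable (<= 0)
e_above_hull_true = pd.Series([-0.1, -0.05, 0.2, 0.3])
# model calls materials 1 and 3 stable, giving 1 TP, 1 FN, 1 FP, 1 TN
e_above_hull_pred = pd.Series([-0.02, 0.1, -0.1, 0.4])

metrics = stable_metrics(e_above_hull_true, e_above_hull_pred)
# precision = TP / (TP + FP) = 0.5 and prevalence = 2/4 = 0.5, so DAF = 1
print(f"{metrics['DAF']=} {metrics['Precision']=} {metrics['Recall']=}")
```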

matbench_discovery/plots.py (+2 -2)

```diff
@@ -15,7 +15,7 @@
 import wandb
 from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

-from matbench_discovery.energy import classify_stable
+from matbench_discovery.metrics import classify_stable

 __author__ = "Janosh Riebesell"
 __date__ = "2022-08-05"
@@ -102,7 +102,7 @@ def hist_classified_stable_vs_hull_dist(
     each_pred_col: str,
     ax: plt.Axes = None,
     which_energy: WhichEnergy = "true",
-    stability_threshold: float = 0,
+    stability_threshold: float | None = 0,
     x_lim: tuple[float | None, float | None] = (-0.7, 0.7),
     rolling_acc: float | None = 0.02,
     backend: Backend = "plotly",
```

scripts/compile_metrics.py (+1 -1)

```diff
@@ -13,7 +13,7 @@

 from matbench_discovery import FIGS, MODELS, WANDB_PATH, today
 from matbench_discovery.data import PRED_FILENAMES, load_df_wbm_preds
-from matbench_discovery.energy import stable_metrics
+from matbench_discovery.metrics import stable_metrics
 from matbench_discovery.plots import px

 __author__ = "Janosh Riebesell"
```

scripts/cumulative_clf_metrics.py (+9 -25)

```diff
@@ -2,34 +2,22 @@
 import pandas as pd
 from pymatviz.utils import save_fig

-from matbench_discovery import FIGS, today
-from matbench_discovery.data import load_df_wbm_preds
+from matbench_discovery import FIGS, STATIC, today
+from matbench_discovery.metrics import df_wbm, e_form_col, each_true_col, models
 from matbench_discovery.plots import cumulative_precision_recall

 __author__ = "Janosh Riebesell, Rhys Goodall"
 __date__ = "2022-12-04"


-# %%
-models = (
-    "CGCNN, Voronoi Random Forest, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet"
-).split(", ")
-
-df_wbm = load_df_wbm_preds(models).round(3)
-
-# df_wbm.columns = [f"{col}_e_form" if col in models else col for col in df_wbm]
-e_form_col = "e_form_per_atom_mp2020_corrected"
-e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
-
-
 # %%
 df_e_above_hull_pred = pd.DataFrame()
 for model in models:
-    e_above_hul_pred = df_wbm[e_above_hull_col] + df_wbm[model] - df_wbm[e_form_col]
+    e_above_hul_pred = df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col]
     df_e_above_hull_pred[model] = e_above_hul_pred

 fig, df_metric = cumulative_precision_recall(
-    e_above_hull_true=df_wbm[e_above_hull_col],
+    e_above_hull_true=df_wbm[each_true_col],
     df_preds=df_e_above_hull_pred,
     project_end_point="xy",
     backend=(backend := "plotly"),
@@ -42,11 +30,7 @@
 # fig.suptitle(title)
 fig.text(0.5, -0.08, xlabel, ha="center", fontdict={"size": 16})
 if backend == "plotly":
-    # place legend in lower right corner
-    fig.update_layout(
-        # title=title,
-        legend=dict(yanchor="bottom", y=0.02, xanchor="right", x=0.9),
-    )
+    fig.layout.legend.update(x=0.01, y=0)  # , title=title
     fig.layout.height = 500
     fig.add_annotation(
         x=0.5,
@@ -69,7 +53,7 @@
     assert isinstance(trace.y[0], float)
     trace.y = [round(y, 3) for y in trace.y]

-img_path = f"{FIGS}/{today}-cumulative-clf-metrics"
-# save_fig(fig, f"{img_path}.pdf")
-save_fig(fig, f"{img_path}.svelte")
-# save_fig(fig, f"{img_path}.webp", scale=3)
+img_path = f"{today}-cumulative-clf-metrics"
+# save_fig(fig, f"{STATIC}/{img_path}.pdf")
+save_fig(fig, f"{FIGS}/{img_path}.svelte")
+save_fig(fig, f"{STATIC}/{img_path}.webp", scale=3)
```
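
Worth spelling out the identity this loop relies on: holding the convex hull fixed at the DFT reference, a model's predicted hull distance is the DFT hull distance shifted by the model's formation-energy error. A minimal sketch using the names exported by `matbench_discovery.metrics` (importing that module triggers its data loading):

```py
from matbench_discovery.metrics import df_wbm, e_form_col, each_true_col, models

model = models[0]  # any of the model prediction columns works here
# e_above_hull_pred = e_above_hull_true + (e_form_pred - e_form_true),
# valid only under the assumption that the hull itself stays pinned to DFT
e_above_hull_pred = df_wbm[each_true_col] + (df_wbm[model] - df_wbm[e_form_col])
```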

scripts/hist_classified_stable_vs_hull_dist.py (+9 -11)

```diff
@@ -2,8 +2,13 @@
 from pymatviz.utils import save_fig

 from matbench_discovery import FIGS, today
-from matbench_discovery.data import load_df_wbm_preds
-from matbench_discovery.energy import stable_metrics
+from matbench_discovery.metrics import (
+    df_wbm,
+    e_form_col,
+    each_pred_col,
+    each_true_col,
+    stable_metrics,
+)
 from matbench_discovery.plots import WhichEnergy, hist_classified_stable_vs_hull_dist

 __author__ = "Rhys Goodall, Janosh Riebesell"
@@ -20,13 +25,6 @@


 # %%
 model_name = "Wrenformer"
-df_wbm = load_df_wbm_preds(models=[model_name]).round(3)
-
-
-# %%
-e_form_col = "e_form_per_atom_mp2020_corrected"
-each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
-each_pred_col = "e_above_hull_pred"
 which_energy: WhichEnergy = "true"
 # std_factor=0,+/-1,+/-2,... changes the criterion for material stability to
 # energy+std_factor*std. energy+std means predicted energy plus the model's uncertainty
@@ -68,5 +66,5 @@

 # %%
 img_path = f"{FIGS}/{today}-wren-wbm-hull-dist-hist-{which_energy=}"
-# save_fig(ax, f"{img_path}.pdf")
-save_fig(fig, f"{img_path}.html")
+# save_fig(fig, f"{img_path}.svelte")
+save_fig(fig, f"{img_path}.webp")
```
