Skip to content

Commit d593ae2

Browse files
committed
add classify_stable() in matbench_discovery/energy.py
used by the plotting functions cumulative_clf_metric() and hist_classified_stable_vs_hull_dist(); also add test_classify_stable()
1 parent c5d3496 commit d593ae2

11 files changed

+152
-102
lines changed

matbench_discovery/energy.py

+36
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,39 @@ def get_e_form_per_atom(
116116
form_energy = energy - sum(comp[el] * refs[str(el)].energy_per_atom for el in comp)
117117

118118
return form_energy / comp.num_atoms
119+
120+
121+
def classify_stable(
    e_above_hull_true: pd.Series,
    e_above_hull_pred: pd.Series,
    stability_threshold: float = 0,
) -> tuple[pd.Series, pd.Series, pd.Series, pd.Series]:
    """Classify model stability predictions as true/false positives/negatives.

    A material counts as (predicted) stable when its (predicted) energy above the
    convex hull is less than or equal to ``stability_threshold``. All energies are
    assumed to be in eV/atom (but the unit doesn't really matter as long as both
    inputs and the threshold are consistent).

    Note:
        NaN compares as False under both ``<=`` and ``>``, so an entry that is NaN
        in either input is False in all four returned masks, i.e. it is assigned
        to none of the four classes.

    Args:
        e_above_hull_true (pd.Series): Ground truth energy above convex hull values.
        e_above_hull_pred (pd.Series): Model predicted energy above convex hull
            values. Must have the same length as e_above_hull_true (and should
            share its index, since pandas aligns boolean operations on index
            labels).
        stability_threshold (float, optional): Maximum energy above convex hull for a
            material to still be considered stable. Usually 0, 0.05 or 0.1. Defaults
            to 0. 0 means a material has to be directly on the hull to be called
            stable. Negative values mean a material has to pull the known hull down
            by that amount to count as stable. Few materials lie below the known
            hull, so only negative values close to 0 make sense.

    Raises:
        ValueError: If e_above_hull_true and e_above_hull_pred differ in length.

    Returns:
        tuple[pd.Series, pd.Series, pd.Series, pd.Series]: Boolean masks (not
            integer indices) marking true positives, false negatives, false
            positives and true negatives (in this order).
    """
    n_true, n_pred = len(e_above_hull_true), len(e_above_hull_pred)
    if n_true != n_pred:
        # Without this check pandas would silently align on index labels and
        # entries missing from one input would end up in none of the classes.
        raise ValueError(
            f"{n_true=} != {n_pred=}, e_above_hull_true and e_above_hull_pred "
            "must have the same length"
        )

    # Explicit <= / > pairs (instead of negating one mask) keep NaN entries out of
    # both the positive and the negative class.
    actual_pos = e_above_hull_true <= stability_threshold
    actual_neg = e_above_hull_true > stability_threshold
    model_pos = e_above_hull_pred <= stability_threshold
    model_neg = e_above_hull_pred > stability_threshold

    true_pos = actual_pos & model_pos
    false_neg = actual_pos & model_neg
    false_pos = actual_neg & model_pos
    true_neg = actual_neg & model_neg

    return true_pos, false_neg, false_pos, true_neg

matbench_discovery/plots.py

+48-62
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import wandb
1313
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
1414

15+
from matbench_discovery.energy import classify_stable
16+
1517
__author__ = "Janosh Riebesell"
1618
__date__ = "2022-08-05"
1719

@@ -69,8 +71,8 @@
6971

7072

7173
def hist_classified_stable_vs_hull_dist(
72-
e_above_hull_pred: pd.Series,
7374
e_above_hull_true: pd.Series,
75+
e_above_hull_pred: pd.Series,
7476
ax: plt.Axes = None,
7577
which_energy: WhichEnergy = "true",
7678
stability_threshold: float = 0,
@@ -90,14 +92,14 @@ def hist_classified_stable_vs_hull_dist(
9092
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
9193
9294
Args:
93-
e_above_hull_pred (pd.Series): energy difference to convex hull predicted by
94-
model, i.e. difference between the model's predicted and true formation
95-
energy.
96-
e_above_hull_true (pd.Series): energy diff to convex hull according to DFT
97-
ground truth.
95+
e_above_hull_true (pd.Series): Distance to convex hull according to DFT
96+
ground truth (in eV / atom).
97+
e_above_hull_pred (pd.Series): Distance to convex hull predicted by model
98+
(in eV / atom). Same as true energy to convex hull plus predicted minus true
99+
formation energy.
98100
ax (plt.Axes, optional): matplotlib axes to plot on.
99-
which_energy (WhichEnergy, optional): Whether to use the true formation energy
100-
or the model's predicted formation energy for the histogram.
101+
which_energy (WhichEnergy, optional): Whether to use the true (DFT) hull
102+
distance or the model's predicted hull distance for the histogram.
101103
stability_threshold (float, optional): set stability threshold as distance to
102104
convex hull in eV/atom, usually 0 or 0.1 eV.
103105
show_threshold (bool, optional): Whether to plot stability threshold as dashed
@@ -114,36 +116,28 @@ def hist_classified_stable_vs_hull_dist(
114116
"""
115117
ax = ax or plt.gca()
116118

117-
test = e_above_hull_pred + e_above_hull_true
118-
# --- histogram of DFT-computed distance to convex hull
119-
if which_energy == "true":
120-
actual_pos = e_above_hull_true <= stability_threshold
121-
actual_neg = e_above_hull_true > stability_threshold
122-
model_pos = test <= stability_threshold
123-
model_neg = test > stability_threshold
124-
125-
n_true_pos = len(e_above_hull_true[actual_pos & model_pos])
126-
n_false_neg = len(e_above_hull_true[actual_pos & model_neg])
127-
128-
n_total_pos = n_true_pos + n_false_neg
129-
null = n_total_pos / len(e_above_hull_true)
130-
131-
true_pos = e_above_hull_true[actual_pos & model_pos]
132-
false_neg = e_above_hull_true[actual_pos & model_neg]
133-
false_pos = e_above_hull_true[actual_neg & model_pos]
134-
true_neg = e_above_hull_true[actual_neg & model_neg]
135-
xlabel = r"$E_\mathrm{above\ hull}$ (eV / atom)"
136-
137-
# --- histogram of model-predicted distance to convex hull
138-
if which_energy == "pred":
139-
true_pos = e_above_hull_pred[actual_pos & model_pos]
140-
false_neg = e_above_hull_pred[actual_pos & model_neg]
141-
false_pos = e_above_hull_pred[actual_neg & model_pos]
142-
true_neg = e_above_hull_pred[actual_neg & model_neg]
143-
xlabel = r"$\Delta E_{Hull-Pred}$ (eV / atom)"
119+
true_pos, false_neg, false_pos, true_neg = classify_stable(
120+
e_above_hull_true, e_above_hull_pred, stability_threshold
121+
)
122+
n_true_pos = sum(true_pos)
123+
n_false_neg = sum(false_neg)
124+
125+
n_total_pos = n_true_pos + n_false_neg
126+
null = n_total_pos / len(e_above_hull_true)
127+
128+
# toggle between histogram of DFT-computed/model-predicted distance to convex hull
129+
e_above_hull = e_above_hull_true if which_energy == "true" else e_above_hull_pred
130+
eah_true_pos = e_above_hull[true_pos]
131+
eah_false_neg = e_above_hull[false_neg]
132+
eah_false_pos = e_above_hull[false_pos]
133+
eah_true_neg = e_above_hull[true_neg]
134+
xlabel = dict(
135+
true="$E_\\mathrm{above\\ hull}$ (eV / atom)",
136+
pred="$E_\\mathrm{above\\ hull\\ pred}$ (eV / atom)",
137+
)[which_energy]
144138

145139
ax.hist(
146-
[true_pos, false_neg, false_pos, true_neg],
140+
[eah_true_pos, eah_false_neg, eah_false_pos, eah_true_neg],
147141
bins=200,
148142
range=x_lim,
149143
alpha=0.5,
@@ -158,7 +152,7 @@ def hist_classified_stable_vs_hull_dist(
158152
)
159153

160154
n_true_pos, n_false_pos, n_true_neg, n_false_neg = map(
161-
len, (true_pos, false_pos, true_neg, false_neg)
155+
len, (eah_true_pos, eah_false_pos, eah_true_neg, eah_false_neg)
162156
)
163157
# null = (tp + fn) / (tp + tn + fp + fn)
164158
precision = n_true_pos / (n_true_pos + n_false_pos)
@@ -181,8 +175,8 @@ def hist_classified_stable_vs_hull_dist(
181175
# compute accuracy within 20 meV/atom intervals
182176
bins = np.arange(x_lim[0], x_lim[1], rolling_accuracy)
183177
bin_counts = np.histogram(e_above_hull_true, bins)[0]
184-
bin_true_pos = np.histogram(true_pos, bins)[0]
185-
bin_true_neg = np.histogram(true_neg, bins)[0]
178+
bin_true_pos = np.histogram(eah_true_pos, bins)[0]
179+
bin_true_neg = np.histogram(eah_true_neg, bins)[0]
186180

187181
# compute accuracy
188182
bin_accuracies = (bin_true_pos + bin_true_neg) / bin_counts
@@ -327,8 +321,8 @@ def rolling_mae_vs_hull_dist(
327321

328322

329323
def cumulative_clf_metric(
330-
e_above_hull_error: pd.Series,
331324
e_above_hull_true: pd.Series,
325+
e_above_hull_pred: pd.Series,
332326
metric: Literal["precision", "recall"],
333327
stability_threshold: float = 0, # set stability threshold as distance to convex
334328
# hull in eV / atom, usually 0 or 0.1 eV
@@ -344,11 +338,11 @@ def cumulative_clf_metric(
344338
predicted stable are included.
345339
346340
Args:
347-
df (pd.DataFrame): Model predictions and target energy values.
348-
e_above_hull_error (str, optional): Column name with residuals of model
349-
predictions, i.e. residual = pred - target. Defaults to "residual".
350-
e_above_hull_true (str, optional): Column name with convex hull distance values.
351-
Defaults to "e_above_hull".
341+
e_above_hull_true (pd.Series): Distance to convex hull according to DFT
342+
ground truth (in eV / atom).
343+
e_above_hull_pred (pd.Series): Distance to convex hull predicted by model
344+
(in eV / atom). Same as true energy to convex hull plus predicted minus true
345+
formation energy.
352346
metric ('precision' | 'recall', optional): Metric to plot.
353347
stability_threshold (float, optional): Max distance from convex hull before
354348
material is considered unstable. Defaults to 0.
@@ -365,25 +359,19 @@ def cumulative_clf_metric(
365359
"""
366360
ax = ax or plt.gca()
367361

368-
e_above_hull_error = e_above_hull_error.sort_values()
369-
e_above_hull_true = e_above_hull_true.loc[e_above_hull_error.index]
362+
e_above_hull_pred = e_above_hull_pred.sort_values()
363+
e_above_hull_true = e_above_hull_true.loc[e_above_hull_pred.index]
370364

371-
true_pos_mask = (e_above_hull_true <= stability_threshold) & (
372-
e_above_hull_error <= stability_threshold
373-
)
374-
false_neg_mask = (e_above_hull_true <= stability_threshold) & (
375-
e_above_hull_error > stability_threshold
376-
)
377-
false_pos_mask = (e_above_hull_true > stability_threshold) & (
378-
e_above_hull_error <= stability_threshold
365+
true_pos, false_neg, false_pos, _true_neg = classify_stable(
366+
e_above_hull_true, e_above_hull_pred, stability_threshold
379367
)
380368

381-
true_pos_cumsum = true_pos_mask.cumsum()
369+
true_pos_cumsum = true_pos.cumsum()
382370

383371
# precision aka positive predictive value (PPV)
384-
precision = true_pos_cumsum / (true_pos_cumsum + false_pos_mask.cumsum()) * 100
385-
n_true_pos = sum(true_pos_mask)
386-
n_false_neg = sum(false_neg_mask)
372+
precision = true_pos_cumsum / (true_pos_cumsum + false_pos.cumsum()) * 100
373+
n_true_pos = sum(true_pos)
374+
n_false_neg = sum(false_neg)
387375
n_total_pos = n_true_pos + n_false_neg
388376
true_pos_rate = true_pos_cumsum / n_total_pos * 100
389377

@@ -443,9 +431,7 @@ def cumulative_clf_metric(
443431
return ax
444432

445433

446-
def wandb_log_scatter(
447-
table: wandb.Table, fields: dict[str, str], **kwargs: Any
448-
) -> None:
434+
def wandb_scatter(table: wandb.Table, fields: dict[str, str], **kwargs: Any) -> None:
449435
"""Log a parity scatter plot using custom vega spec to WandB.
450436
451437
Args:

models/cgcnn/test_cgcnn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from matbench_discovery import DEBUG, ROOT, today
1818
from matbench_discovery.load_preds import df_wbm
19-
from matbench_discovery.plots import wandb_log_scatter
19+
from matbench_discovery.plots import wandb_scatter
2020
from matbench_discovery.slurm import slurm_submit
2121

2222
__author__ = "Janosh Riebesell"
@@ -124,4 +124,4 @@
124124

125125
title = rf"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
126126

127-
wandb_log_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
127+
wandb_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)

models/megnet/test_megnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from matbench_discovery import DEBUG, ROOT, timestamp, today
1414
from matbench_discovery.load_preds import df_wbm
15-
from matbench_discovery.plots import wandb_log_scatter
15+
from matbench_discovery.plots import wandb_scatter
1616
from matbench_discovery.slurm import slurm_submit
1717

1818
"""
@@ -115,4 +115,4 @@
115115
title = f"{model_name} {task_type} {MAE=:.4} {R2=:.4}"
116116
print(title)
117117

118-
wandb_log_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
118+
wandb_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)

models/voronoi/train_test_voronoi_rf.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from matbench_discovery import DEBUG, ROOT, today
1313
from matbench_discovery.load_preds import df_wbm, glob_to_df
14-
from matbench_discovery.plots import wandb_log_scatter
14+
from matbench_discovery.plots import wandb_scatter
1515
from matbench_discovery.slurm import slurm_submit
1616
from models.voronoi import featurizer
1717

@@ -127,4 +127,4 @@
127127
title = f"{model_name} {task_type} {MAE=:.3} {R2=:.3}"
128128
print(title)
129129

130-
wandb_log_scatter(table, fields=dict(x=test_target_col, y=pred_col), title=title)
130+
wandb_scatter(table, fields=dict(x=test_target_col, y=pred_col), title=title)

models/wrenformer/test_wrenformer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from aviary.wrenformer.model import Wrenformer
1313

1414
from matbench_discovery import DEBUG, ROOT, today
15-
from matbench_discovery.plots import wandb_log_scatter
15+
from matbench_discovery.plots import wandb_scatter
1616
from matbench_discovery.slurm import slurm_submit
1717

1818
__author__ = "Janosh Riebesell"
@@ -110,4 +110,4 @@
110110

111111
title = rf"Wrenformer {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
112112

113-
wandb_log_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)
113+
wandb_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)

scripts/hist_classified_stable_vs_hull_dist.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
# %%
2727
target_col = "e_form_per_atom_mp2020_corrected"
28+
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
2829
which_energy: WhichEnergy = "true"
2930
# std_factor=0,+/-1,+/-2,... changes the criterion for material stability to
3031
# energy+std_factor*std. energy+std means predicted energy plus the model's uncertainty
@@ -40,10 +41,15 @@
4041
var_epistemic = df_wbm.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
4142
std_total = (var_epistemic + var_aleatoric) ** 0.5
4243
std_total = df_wbm[f"{model_name}_std"]
44+
e_above_hull_pred = (
45+
df_wbm[e_above_hull_col]
46+
+ (df_wbm[model_name] + std_factor * std_total)
47+
- df_wbm[target_col]
48+
)
4349

4450
ax, metrics = hist_classified_stable_vs_hull_dist(
45-
e_above_hull_pred=df_wbm[model_name] - std_factor * std_total - df_wbm[target_col],
46-
e_above_hull_true=df_wbm.e_above_hull_mp2020_corrected_ppd_mp,
51+
e_above_hull_true=df_wbm[e_above_hull_col],
52+
e_above_hull_pred=e_above_hull_pred,
4753
which_energy=which_energy,
4854
# stability_threshold=-0.05,
4955
rolling_accuracy=0,

scripts/hist_classified_stable_vs_hull_dist_batches.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@
3939
assert 1e4 < len(batch_df) < 1e5, print(f"{len(batch_df) = :,}")
4040

4141
ax, metrics = hist_classified_stable_vs_hull_dist(
42-
e_above_hull_pred=batch_df[model_name] - batch_df[target_col],
4342
e_above_hull_true=batch_df[e_above_hull_col],
43+
e_above_hull_pred=batch_df[e_above_hull_col]
44+
+ (batch_df[model_name] - batch_df[target_col]),
4445
which_energy=which_energy,
4546
ax=ax,
4647
)
@@ -53,8 +54,9 @@
5354

5455

5556
ax, metrics = hist_classified_stable_vs_hull_dist(
56-
e_above_hull_pred=df_wbm[model_name] - df_wbm[target_col],
5757
e_above_hull_true=df_wbm[e_above_hull_col],
58+
e_above_hull_pred=df_wbm[e_above_hull_col]
59+
+ (df_wbm[model_name] - df_wbm[target_col]),
5860
which_energy=which_energy,
5961
ax=axs.flat[-1],
6062
)
@@ -69,5 +71,5 @@
6971

7072

7173
# %%
72-
img_name = f"{today}-{model_name}-wbm-hull-dist-hist-batches"
73-
ax.figure.savefig(f"{ROOT}/figures/{img_name}.pdf")
74+
img_path = f"{ROOT}/figures/{today}-{model_name}-wbm-hull-dist-hist-batches.pdf"
75+
# ax.figure.savefig(img_path)

scripts/precision_recall.py

+8-21
Original file line numberDiff line numberDiff line change
@@ -26,39 +26,26 @@
2626

2727
for model_name, color in zip(models, colors):
2828

29-
e_above_hull_pred = df_wbm[model_name] - df_wbm[target_col]
30-
31-
F1 = f1_score(df_wbm[e_above_hull_col] < 0, e_above_hull_pred < 0)
32-
33-
e_above_hull_error = e_above_hull_pred + df_wbm[e_above_hull_col]
34-
cumulative_clf_metric(
35-
e_above_hull_error,
36-
df_wbm[e_above_hull_col],
37-
color=color,
38-
label=f"{model_name}\n{F1=:.3}",
39-
project_end_point="xy",
40-
ax=ax_prec,
41-
metric="precision",
29+
e_above_hull_pred = df_wbm[e_above_hull_col] + (
30+
df_wbm[model_name] - df_wbm[target_col]
4231
)
43-
44-
cumulative_clf_metric(
45-
e_above_hull_error,
46-
df_wbm[e_above_hull_col],
32+
F1 = f1_score(df_wbm[e_above_hull_col] < 0, e_above_hull_pred < 0)
33+
in_common = dict(
34+
e_above_hull_true=df_wbm[e_above_hull_col],
35+
e_above_hull_pred=e_above_hull_pred,
4736
color=color,
4837
label=f"{model_name}\n{F1=:.3}",
4938
project_end_point="xy",
50-
ax=ax_recall,
51-
metric="recall",
5239
)
40+
cumulative_clf_metric(**in_common, ax=ax_prec, metric="precision")
5341

42+
cumulative_clf_metric(**in_common, ax=ax_recall, metric="recall")
5443

5544
for ax in (ax_prec, ax_recall):
5645
ax.set(xlim=(0, None))
5746

58-
5947
# x-ticks every 10k materials
6048
# ax.set(xticks=range(0, int(ax.get_xlim()[1]), 10_000))
61-
6249
fig.suptitle(f"{today} {model_name}")
6350
xlabel_cumulative = "Materials predicted stable sorted by hull distance"
6451
fig.text(0.5, -0.08, xlabel_cumulative, ha="center")

0 commit comments

Comments
 (0)