Skip to content

Commit 383b8f4

Browse files
committed
add code to generate hist_classify_stable_as_func_of_hull_dist at end of ml_stability/m3gnet/m3gnet_relax_wbm.py
powered by new module ml_stability/plots/plot_funcs.py
1 parent c2a45b3 commit 383b8f4

File tree

1 file changed

+155
-0
lines changed

1 file changed

+155
-0
lines changed

ml_stability/plots/plot_funcs.py

+155
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# %%
2+
from typing import Literal
3+
4+
import matplotlib.pyplot as plt
5+
import pandas as pd
6+
7+
8+
__author__ = "Janosh Riebesell"
9+
__date__ = "2022-08-05"
10+
11+
"""
12+
Histogram of the energy difference (either according to DFT ground truth [default] or
13+
model predicted energy) to the convex hull for materials in the WBM data set. The
14+
histogram is broken down into true positives, false negatives, false positives, and true
15+
negatives based on whether the model predicts candidates to be below the known convex
16+
hull. Ideally, in a discovery setting, a model should exhibit high recall, i.e. the
17+
majority of materials below the convex hull being correctly identified by the model.
18+
19+
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
20+
"""
21+
22+
23+
# Module-wide matplotlib defaults; these take effect as a side effect of
# importing this file and apply to every figure created afterwards.
plt.rcParams["font.size"] = 18
plt.rcParams["savefig.bbox"] = "tight"
plt.rcParams["savefig.dpi"] = 200
plt.rcParams["figure.constrained_layout.use"] = True
plt.rcParams["figure.dpi"] = 150
plt.rcParams["figure.titlesize"] = 20
27+
28+
29+
def _safe_ratio(numerator: float, denominator: float) -> float:
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def hist_classify_stable_as_func_of_hull_dist(
    formation_energy_targets: pd.Series,
    formation_energy_preds: pd.Series,
    e_above_hull_vals: pd.Series,
    rare: str = "all",
    std_vals: "pd.Series | None" = None,
    criterion: Literal["energy", "std", "neg"] = "energy",
    energy_type: Literal["true", "pred"] = "true",
    stability_threshold: float = 0,
) -> plt.Axes:
    """Plot a stacked histogram of hull distances split by classification outcome.

    Histogram of the energy difference (either according to DFT ground truth
    [default] or model predicted energy) to the convex hull for materials in the WBM
    data set. The histogram is broken down into true positives, false negatives,
    false positives, and true negatives based on whether the model predicts
    candidates to be below the known convex hull. Ideally, in a discovery setting, a
    model should exhibit high recall, i.e. the majority of materials below the convex
    hull being correctly identified by the model.

    See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.

    NOTE this figure plots hist bars separately which causes aliasing in pdf. To
    resolve this, take the figure into Inkscape and merge regions by color.

    Besides drawing the figure, PPV, TPR, F1, enrichment factor, prevalence ("Null"),
    MAE and RMSE are printed to stdout. Ratios whose denominator is zero (e.g. no
    predicted positives) are reported as NaN instead of raising ZeroDivisionError.

    Args:
        formation_energy_targets: Reference formation energies.
        formation_energy_preds: Model-predicted formation energies.
        e_above_hull_vals: Reference energies above the convex hull. Must not
            contain NaN.
        rare: Currently unused; kept so existing callers keep working.
        std_vals: Optional per-sample uncertainties. Only used when criterion is
            "std" or "neg"; if None, those criteria fall back to the plain
            predicted hull distance.
        criterion: "energy" classifies on the predicted hull distance, "std" on the
            prediction plus std_vals, "neg" on the prediction minus std_vals.
        energy_type: "true" bins the histogram by e_above_hull_vals, "pred" by the
            predicted hull distance (prediction error + e_above_hull_vals).
        stability_threshold: Entries with a hull distance at or below this value
            count as stable (positive). Defaults to 0, i.e. on the hull.

    Returns:
        The matplotlib Axes holding the histogram.

    Raises:
        ValueError: If energy_type is not "true" or "pred", if e_above_hull_vals
            contains NaN, or if not every entry could be assigned to exactly one of
            the four outcome classes (e.g. NaN predictions or misaligned indices).
    """
    if energy_type not in ("true", "pred"):
        raise ValueError(f"energy_type must be 'true' or 'pred', got {energy_type!r}")
    if e_above_hull_vals.isna().any():
        raise ValueError("e_above_hull_vals must not contain NaN")

    error = formation_energy_preds - formation_energy_targets
    # predicted distance to the hull = reference distance shifted by the model error
    mean = error + e_above_hull_vals

    # Build the decision variable as a *new* Series. The previous in-place
    # `test += std_vals` also modified `mean` (same object), which corrupted the
    # values plotted for energy_type="pred".
    test = mean
    if std_vals is not None:
        if criterion == "std":
            test = mean + std_vals
        elif criterion == "neg":
            test = mean - std_vals

    xlim = (-0.4, 0.4)

    actual_pos = e_above_hull_vals <= stability_threshold
    actual_neg = e_above_hull_vals > stability_threshold
    model_pos = test <= stability_threshold
    model_neg = test > stability_threshold

    if energy_type == "true":
        # --- histogram by DFT-computed distance to convex hull
        hist_vals = e_above_hull_vals
        xlabel = r"$\Delta E_{Hull-MP}$ / eV per atom"
    else:
        # --- histogram by model-predicted distance to convex hull
        hist_vals = mean
        xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"

    true_pos = hist_vals[actual_pos & model_pos]
    false_neg = hist_vals[actual_pos & model_neg]
    false_pos = hist_vals[actual_neg & model_pos]
    true_neg = hist_vals[actual_neg & model_neg]

    n_true_pos, n_false_pos = len(true_pos), len(false_pos)
    n_true_neg, n_false_neg = len(true_neg), len(false_neg)
    n_total_pos = n_true_pos + n_false_neg

    # Check before any figure is created so a failure leaves no stray figure open.
    n_classified = n_true_pos + n_false_pos + n_true_neg + n_false_neg
    if n_classified != len(formation_energy_targets):
        raise ValueError(
            f"classified {n_classified} of {len(formation_energy_targets)} entries; "
            "check inputs for NaN values or misaligned indices"
        )

    # null = (tp + fn) / (tp + tn + fp + fn), i.e. the prevalence of stable entries
    null = _safe_ratio(n_total_pos, len(e_above_hull_vals))
    ppv = _safe_ratio(n_true_pos, n_true_pos + n_false_pos)  # precision
    tpr = _safe_ratio(n_true_pos, n_total_pos)  # recall
    f1 = _safe_ratio(2 * ppv * tpr, ppv + tpr)
    enrichment = _safe_ratio(ppv, null)

    _, ax = plt.subplots(1, 1, figsize=(10, 9))

    ax.hist(
        [true_pos, false_neg, false_pos, true_neg],
        bins=200,
        range=xlim,
        alpha=0.5,
        color=["tab:green", "tab:orange", "tab:red", "tab:blue"],
        label=[
            "True Positives",
            "False Negatives",
            "False Positives",
            "True Negatives",
        ],
        stacked=True,
    )

    ax.legend(frameon=False, loc="upper left")

    print(f"PPV: {ppv:.2f}")
    print(f"TPR: {tpr:.2f}")
    print(f"F1: {f1:.2f}")
    print(f"Enrich: {enrichment:.2f}")
    print(f"Null: {null:.2f}")

    # Upper-case names are kept on purpose: the `=` f-string specifier prints the
    # variable name, so renaming would change the printed output.
    RMSE = (error**2.0).mean() ** 0.5
    MAE = error.abs().mean()
    print(f"{MAE=:.3}")
    print(f"{RMSE=:.3}")

    anno_text = f"Enrichment\nFactor = {enrichment:.1f}"

    ax.text(0.75, 0.9, anno_text, transform=ax.transAxes, fontsize=20)

    ax.set(
        xlabel=xlabel,
        ylabel="Number of Compounds",
        title=f"data size = {len(e_above_hull_vals):,}",
    )

    return ax

0 commit comments

Comments
 (0)