Skip to content

Commit 967b482

Browse files
committed
refactor plot func hist_classified_stable_as_func_of_hull_dist()
add script mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py
1 parent 8daf748 commit 967b482

File tree

3 files changed

+229
-85
lines changed

3 files changed

+229
-85
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# %%
"""
Histogram of the energy difference (either according to DFT ground truth [default] or
model predicted energy) to the convex hull for materials in the WBM data set. The
histogram is broken down into true positives, false negatives, false positives, and true
negatives based on whether the model predicts candidates to be below the known convex
hull. Ideally, in a discovery setting, a model should exhibit high recall, i.e. the
majority of materials below the convex hull being correctly identified by the model.

See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
"""
# NOTE the description above is placed first so that it is the actual module
# docstring (available as __doc__) instead of a discarded string expression.
from datetime import datetime
from typing import Literal

import matplotlib.pyplot as plt
import pandas as pd

from mb_discovery import ROOT
from mb_discovery.plot_scripts.plot_funcs import (
    hist_classified_stable_as_func_of_hull_dist,
)


__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-06-18"


# date prefix for the output figure file name
today = f"{datetime.now():%Y-%m-%d}"

plt.rc("savefig", bbox="tight", dpi=200)
plt.rcParams["figure.constrained_layout.use"] = True
plt.rc("figure", dpi=150)
plt.rc("font", size=16)


# %%
df = pd.read_csv(
    f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
).set_index("material_id")

df_hull = pd.read_csv(
    f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
).set_index("material_id")

# assignment aligns on the shared material_id index; IDs missing from df_hull
# become NaN and are caught by the missing-value assertion below
df["e_above_mp_hull"] = df_hull.e_above_mp_hull

# download wbm-steps-summary.csv (23.31 MB)
# NOTE(review): df_summary is not referenced again in this script; presumably kept
# for interactive use in the notebook cells - confirm before removing
df_summary = pd.read_csv(
    "https://figshare.com/ndownloader/files/36714216?private_link=ff0ad14505f9624f0c05"
).set_index("material_id")


# %%
nan_counts = df.isna().sum()
assert all(nan_counts == 0), f"df should not have missing values: {nan_counts}"

target_col = "e_form_target"
criterion: Literal["energy", "std", "neg_std"] = "energy"
energy_type: Literal["true", "pred"] = "true"


# make sure we average the expected number of ensemble member predictions
pred_cols = df.filter(regex=r"_pred_\d").columns
assert len(pred_cols) == 10

ax = hist_classified_stable_as_func_of_hull_dist(
    df,
    target_col,
    pred_cols,
    e_above_hull_col="e_above_mp_hull",
    energy_type=energy_type,
    criterion=criterion,
)

ax.figure.set_size_inches(10, 9)

ax.legend(loc="upper left", frameon=False)

# the {name=} f-string specifier embeds both the variable name and its repr
# in the file name
img_path = (
    f"{ROOT}/figures/{today}-wren-wbm-hull-dist-hist-{energy_type=}-{criterion=}.pdf"
)
# plt.savefig(img_path)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# %%
"""
Histograms of the energy difference (either according to DFT ground truth [default] or
model predicted energy) to the convex hull for materials in the WBM data set, drawn as
one subplot per WBM batch plus a final subplot for all batches combined. Each
histogram is broken down into true positives, false negatives, false positives, and true
negatives based on whether the model predicts candidates to be below the known convex
hull. Ideally, in a discovery setting, a model should exhibit high recall, i.e. the
majority of materials below the convex hull being correctly identified by the model.

See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
"""
# NOTE the description above is placed first so that it is the actual module
# docstring (available as __doc__) instead of a discarded string expression.
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd

from mb_discovery import ROOT
from mb_discovery.plot_scripts.plot_funcs import (
    hist_classified_stable_as_func_of_hull_dist,
)


__author__ = "Rhys Goodall, Janosh Riebesell"
__date__ = "2022-08-25"


# date prefix for the output figure file name
today = f"{datetime.now():%Y-%m-%d}"

plt.rc("savefig", bbox="tight", dpi=200)
plt.rcParams["figure.constrained_layout.use"] = True
plt.rc("figure", dpi=150)
plt.rc("font", size=16)


# %%
df = pd.read_csv(
    f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
).set_index("material_id")

df_hull = pd.read_csv(
    f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
).set_index("material_id")

# assignment aligns on the shared material_id index; IDs missing from df_hull
# become NaN and are caught by the assertion in the next cell
df["e_above_mp_hull"] = df_hull.e_above_mp_hull

# download wbm-steps-summary.csv (23.31 MB)
# NOTE(review): df_summary is not referenced again in this script; presumably kept
# for interactive use in the notebook cells - confirm before removing
df_summary = pd.read_csv(
    "https://figshare.com/ndownloader/files/36714216?private_link=ff0ad14505f9624f0c05"
).set_index("material_id")


# %%
assert df.e_above_mp_hull.isna().sum() == 0

energy_type = "true"
criterion = "energy"
# batch label is the third dash-separated token of the material_id index
df["wbm_batch"] = df.index.str.split("-").str[2]
fig, axs = plt.subplots(2, 3, figsize=(18, 9))

# the last axes is reserved for the combined histogram, so every batch needs its
# own axes among the remaining ones; zip() below would otherwise silently drop
# batches or let the combined plot draw over the last batch
n_batches = df.wbm_batch.nunique()
assert (
    n_batches < axs.size
), f"{n_batches} batches do not fit into {axs.size - 1} per-batch subplots"

# make sure we average the expected number of ensemble member predictions
pred_cols = df.filter(regex=r"_pred_\d").columns
assert len(pred_cols) == 10

common_kwargs = dict(
    target_col="e_form_target",
    pred_cols=pred_cols,
    energy_type=energy_type,
    criterion=criterion,
    e_above_hull_col="e_above_mp_hull",
)

for (batch_idx, batch_df), ax in zip(df.groupby("wbm_batch"), axs.flat):
    hist_classified_stable_as_func_of_hull_dist(batch_df, ax=ax, **common_kwargs)

    # report the size of this batch, not of the whole data set
    title = f"Batch {batch_idx} ({len(batch_df):,})"
    ax.set(title=title)


# final subplot: all batches together
hist_classified_stable_as_func_of_hull_dist(df, ax=axs.flat[-1], **common_kwargs)

# the loop variable batch_idx is deliberately not used here: it would only hold
# the label of the last batch iterated above
axs.flat[-1].set(title=f"Combined ({len(df):,})")
axs.flat[0].legend(frameon=False, loc="upper left")

# the {name=} f-string specifier embeds both the variable name and its repr
# in the file name
img_name = f"{today}-wren-wbm-hull-dist-hist-{energy_type=}-{criterion=}.pdf"
# plt.savefig(f"{ROOT}/figures/{img_name}")
+62-85
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,33 @@
1-
# %%
2-
from typing import Literal
1+
from __future__ import annotations
2+
3+
from typing import Literal, Sequence
34

45
import matplotlib.pyplot as plt
56
import pandas as pd
6-
from matplotlib.offsetbox import AnchoredText
77

88

99
__author__ = "Janosh Riebesell"
1010
__date__ = "2022-08-05"
1111

12-
"""
13-
Histogram of the energy difference (either according to DFT ground truth [default] or
14-
model predicted energy) to the convex hull for materials in the WBM data set. The
15-
histogram is broken down into true positives, false negatives, false positives, and true
16-
negatives based on whether the model predicts candidates to be below the known convex
17-
hull. Ideally, in discovery setting a model should exhibit high recall, i.e. the
18-
majority of materials below the convex hull being correctly identified by the model.
19-
20-
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
21-
"""
22-
2312

2413
plt.rc("savefig", bbox="tight", dpi=200)
2514
plt.rcParams["figure.constrained_layout.use"] = True
2615
plt.rc("figure", dpi=150)
2716
plt.rc("font", size=16)
2817

2918

30-
def hist_classify_stable_as_func_of_hull_dist(
31-
# df: pd.DataFrame,
32-
formation_energy_targets: pd.Series,
33-
formation_energy_preds: pd.Series,
34-
e_above_hull_vals: pd.Series,
35-
rare: str = "all",
36-
std_vals: pd.Series = None,
37-
criterion: Literal["energy", "std", "neg"] = "energy",
19+
def hist_classified_stable_as_func_of_hull_dist(
20+
df: pd.DataFrame,
21+
target_col: str,
22+
pred_cols: Sequence[str],
23+
e_above_hull_col: str,
24+
ax: plt.Axes = None,
3825
energy_type: Literal["true", "pred"] = "true",
39-
annotate_all_stats: bool = False,
26+
criterion: Literal["energy", "std", "neg_std"] = "energy",
27+
show_mae: bool = False,
28+
stability_thresh: float = 0, # set stability threshold as distance to convex hull
29+
# in eV / atom, usually 0 or 0.1 eV
30+
x_lim: tuple[float, float] = (-0.4, 0.4),
4031
) -> plt.Axes:
4132
"""
4233
Histogram of the energy difference (either according to DFT ground truth [default]
@@ -49,40 +40,43 @@ def hist_classify_stable_as_func_of_hull_dist(
4940
5041
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
5142
52-
53-
# NOTE this figure plots hist bars separately which causes aliasing in pdf
54-
# to resolve this take into Inkscape and merge regions by color
43+
NOTE this figure plots hist bars separately which causes aliasing in pdf
44+
to resolve this take into Inkscape and merge regions by color
5545
"""
56-
assert e_above_hull_vals.isna().sum() == 0
46+
if ax is None:
47+
ax = plt.gca()
5748

58-
error = formation_energy_preds - formation_energy_targets
49+
error = df[pred_cols].mean(axis=1) - df[target_col]
50+
e_above_hull_vals = df[e_above_hull_col]
5951
mean = error + e_above_hull_vals
6052

61-
test = mean
53+
if criterion == "energy":
54+
test = mean
55+
elif "std" in criterion:
56+
# TODO column names to compute standard deviation from are currently hardcoded
57+
# needs to be updated when adding non-aviary models with uncertainty estimation
58+
var_aleatoric = (df.filter(like="_ale_") ** 2).mean(axis=1)
59+
var_epistemic = df.filter(regex=r"_pred_\d").var(axis=1, ddof=0)
60+
std_total = (var_epistemic + var_aleatoric) ** 0.5
6261

63-
if std_vals is not None:
6462
if criterion == "std":
65-
test += std_vals
66-
elif criterion == "neg":
67-
test -= std_vals
63+
test += std_total
64+
elif criterion == "neg_std":
65+
test -= std_total
6866

69-
xlim = (-0.4, 0.4)
70-
71-
# set stability threshold at on or 0.1 eV / atom above the hull
72-
stability_thresh = (0, 0.1)[0]
73-
74-
actual_pos = e_above_hull_vals <= stability_thresh
75-
actual_neg = e_above_hull_vals > stability_thresh
76-
model_pos = test <= stability_thresh
77-
model_neg = test > stability_thresh
67+
# --- histogram by DFT-computed distance to convex hull
68+
if energy_type == "true":
69+
actual_pos = e_above_hull_vals <= stability_thresh
70+
actual_neg = e_above_hull_vals > stability_thresh
71+
model_pos = test <= stability_thresh
72+
model_neg = test > stability_thresh
7873

79-
n_true_pos = len(e_above_hull_vals[actual_pos & model_pos])
80-
n_false_neg = len(e_above_hull_vals[actual_pos & model_neg])
74+
n_true_pos = len(e_above_hull_vals[actual_pos & model_pos])
75+
n_false_neg = len(e_above_hull_vals[actual_pos & model_neg])
8176

82-
n_total_pos = n_true_pos + n_false_neg
77+
n_total_pos = n_true_pos + n_false_neg
78+
null = n_total_pos / len(e_above_hull_vals)
8379

84-
# --- histogram by DFT-computed distance to convex hull
85-
if energy_type == "true":
8680
true_pos = e_above_hull_vals[actual_pos & model_pos]
8781
false_neg = e_above_hull_vals[actual_pos & model_neg]
8882
false_pos = e_above_hull_vals[actual_neg & model_pos]
@@ -97,12 +91,10 @@ def hist_classify_stable_as_func_of_hull_dist(
9791
true_neg = mean[actual_neg & model_neg]
9892
xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
9993

100-
fig, ax = plt.subplots(1, 1, figsize=(10, 9))
101-
10294
ax.hist(
10395
[true_pos, false_neg, false_pos, true_neg],
10496
bins=200,
105-
range=xlim,
97+
range=x_lim,
10698
alpha=0.5,
10799
color=["tab:green", "tab:orange", "tab:red", "tab:blue"],
108100
label=[
@@ -114,49 +106,34 @@ def hist_classify_stable_as_func_of_hull_dist(
114106
stacked=True,
115107
)
116108

117-
ax.legend(frameon=False, loc="upper left")
118-
119109
n_true_pos, n_false_pos, n_true_neg, n_false_neg = (
120110
len(true_pos),
121111
len(false_pos),
122112
len(true_neg),
123113
len(false_neg),
124114
)
125115
# null = (tp + fn) / (tp + tn + fp + fn)
126-
Null = n_total_pos / len(e_above_hull_vals)
127-
PPV = n_true_pos / (n_true_pos + n_false_pos)
128-
TPR = n_true_pos / n_total_pos
129-
F1 = 2 * PPV * TPR / (PPV + TPR)
130-
131-
assert n_true_pos + n_false_pos + n_true_neg + n_false_neg == len(
132-
formation_energy_targets
133-
)
134-
135-
RMSE = (error**2).mean() ** 0.5
136-
MAE = error.abs().mean()
137-
138-
# anno_text = f"Prevalence = {null:.2f}\nPrecision = {ppv:.2f}\nRecall = {tpr:.2f}",
139-
anno_text = f"Enrichment Factor = {PPV/Null:.3}"
140-
if annotate_all_stats:
141-
anno_text += f"\n{MAE = :.3}\n{RMSE = :.3}\n{Null = :.3}\n{TPR = :.3}"
142-
else:
143-
print(f"{PPV = :.3}")
144-
print(f"{TPR = :.3}")
145-
print(f"{F1 = :.3}")
146-
print(f"Enrich: {PPV/Null:.2f}")
147-
print(f"{Null = :.3}")
148-
print(f"{MAE = :.3}")
149-
print(f"{RMSE = :.3}")
150-
151-
text_box = AnchoredText(
152-
anno_text, loc="upper right", frameon=False, prop=dict(fontsize=16)
116+
precision = n_true_pos / (n_true_pos + n_false_pos)
117+
118+
assert n_true_pos + n_false_pos + n_true_neg + n_false_neg == len(df)
119+
120+
# recall = n_true_pos / n_total_pos
121+
# f"Prevalence = {null:.2f}\n{precision = :.2f}\n{recall = :.2f}",
122+
text = f"Enrichment\nFactor = {precision/null:.3}"
123+
if show_mae:
124+
MAE = error.abs().mean()
125+
text += f"\n{MAE = :.3}"
126+
127+
ax.text(
128+
0.98,
129+
0.98,
130+
text,
131+
fontsize=18,
132+
verticalalignment="top",
133+
horizontalalignment="right",
134+
transform=ax.transAxes,
153135
)
154-
ax.add_artist(text_box)
155136

156-
ax.set(
157-
xlabel=xlabel,
158-
ylabel="Number of Compounds",
159-
title=f"data size = {len(e_above_hull_vals):,}",
160-
)
137+
ax.set(xlabel=xlabel, ylabel="Number of compounds")
161138

162139
return ax

0 commit comments

Comments
 (0)