Skip to content

Commit 9c6d61c

Browse files
committed
add rolling_accuracy kwarg and doc str to hist_classified_stable_as_func_of_hull_dist()
extend global plot settings: plt.rc("legend", title_fontsize=16) plt.rc("axes", titlesize=16, labelsize=16)
1 parent fed968f commit 9c6d61c

7 files changed

+117
-45
lines changed

matbench_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,19 @@
3030

3131
# %%
3232
df = pd.read_csv(
33-
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
33+
# f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
34+
f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv"
3435
).set_index("material_id")
3536

36-
df["e_above_hull_mp"] = df_wbm.e_above_hull_mp
37+
df["e_above_hull"] = df_wbm.e_above_hull_mp2020_corrected_ppd_mp
3738

3839

3940
# %%
4041
nan_counts = df.isna().sum()
4142
assert all(nan_counts == 0), f"df should not have missing values: {nan_counts}"
4243

43-
target_col = "e_form_target"
44+
# target_col = "e_form_target"
45+
target_col = "e_form_per_atom"
4446
stability_crit: StabilityCriterion = "energy"
4547
which_energy: WhichEnergy = "true"
4648

@@ -57,19 +59,24 @@
5759
pred_cols = df.filter(regex=r"_pred_\d").columns
5860
assert len(pred_cols) == 10
5961

60-
ax = hist_classified_stable_as_func_of_hull_dist(
62+
ax, metrics = hist_classified_stable_as_func_of_hull_dist(
6163
e_above_hull_pred=df[pred_cols].mean(axis=1) - df[target_col],
62-
e_above_hull_true=df.e_above_hull_mp,
64+
e_above_hull_true=df.e_above_hull,
6365
which_energy=which_energy,
6466
stability_crit=stability_crit,
6567
std_pred=std_total,
6668
# stability_threshold=-0.05,
69+
# rolling_accuracy=0,
6770
)
6871

6972
fig = ax.figure
7073
fig.set_size_inches(10, 9)
7174

72-
ax.legend(loc="center left", frameon=False)
75+
ax.legend(
76+
loc="center left",
77+
frameon=False,
78+
title=f"Enrichment Factor = {metrics['enrichment']:.3}",
79+
)
7380

7481
fig_name = f"wren-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}"
7582
# fig.savefig(f"{ROOT}/figures/{today}-{fig_name}.pdf")

matbench_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@
3939
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE.json.gz"
4040
).set_index("material_id")
4141
dfs["wrenformer"] = pd.read_csv(
42-
f"{ROOT}/models/wrenformer/mp/"
43-
"2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
42+
f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv"
4443
).set_index("material_id")
4544
dfs["bowsr_megnet"] = pd.read_json(
4645
f"{ROOT}/models/bowsr/2022-09-22-bowsr-megnet-wbm-IS2RE.json.gz"
@@ -78,26 +77,32 @@
7877
batch_df = df[df.index.str.startswith(f"wbm-step-{batch_idx}-")]
7978
assert 1e4 < len(batch_df) < 1e5, print(f"{len(batch_df) = :,}")
8079

81-
hist_classified_stable_as_func_of_hull_dist(
80+
ax, metrics = hist_classified_stable_as_func_of_hull_dist(
8281
e_above_hull_pred=batch_df.e_form_per_atom_pred - batch_df.e_form_per_atom,
8382
e_above_hull_true=batch_df.e_above_hull_mp,
8483
which_energy=which_energy,
8584
stability_crit=stability_crit,
8685
ax=ax,
8786
)
8887

88+
text = f"Enrichment\nFactor = {metrics['enrichment']:.3}"
89+
ax.text(0.02, 0.25, text, fontsize=16, transform=ax.transAxes)
90+
8991
title = f"Batch {batch_idx} ({len(batch_df.filter(like='e_').dropna()):,})"
9092
ax.set(title=title)
9193

9294

93-
hist_classified_stable_as_func_of_hull_dist(
95+
ax, metrics = hist_classified_stable_as_func_of_hull_dist(
9496
e_above_hull_pred=df.e_form_per_atom_pred - df.e_form_per_atom,
9597
e_above_hull_true=df.e_above_hull_mp,
9698
which_energy=which_energy,
9799
stability_crit=stability_crit,
98100
ax=axs.flat[-1],
99101
)
100102

103+
text = f"Enrichment\nFactor = {metrics['enrichment']:.3}"
104+
ax.text(0.02, 0.25, text, fontsize=16, transform=ax.transAxes)
105+
101106
axs.flat[-1].set(title=f"Combined ({len(df.filter(like='e_').dropna()):,})")
102107
axs.flat[0].legend(frameon=False, loc="upper left")
103108

matbench_discovery/plot_scripts/precision_recall.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
).set_index("material_id")
2929

3030
dfs["wrenformer"] = pd.read_csv(
31-
f"{ROOT}/models/wrenformer/mp/"
32-
"2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
31+
f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv"
3332
).set_index("material_id")
3433

3534
dfs["bowsr_megnet"] = pd.read_json(

matbench_discovery/plot_scripts/rolling_mae_vs_hull_dist.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818

1919
data_path = (
2020
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
21-
# f"{ROOT}/models/wrenformer/mp/"
22-
# "2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
21+
# f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv"
2322
)
2423
df = pd.read_csv(data_path).set_index("material_id")
2524
legend_label = "Wren"

matbench_discovery/plot_scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
).set_index("material_id")
2222

2323
df_wrenformer = pd.read_csv(
24-
f"{ROOT}/models/wrenformer/mp/"
25-
"2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
24+
f"{ROOT}/models/wrenformer/mp/2022-09-20-wrenformer-e_form-ensemble-1-preds.csv"
2625
).set_index("material_id")
2726

2827

matbench_discovery/plots.py

+88-28
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
AxLine = Literal["x", "y", "xy", ""]
2020

2121

22-
# --- define global plot settings
22+
# --- start global plot settings
2323
quantity_labels = dict(
2424
n_atoms="Atom Count",
2525
n_elems="Element Count",
@@ -55,7 +55,8 @@
5555

5656

5757
plt.rc("font", size=14)
58-
plt.rc("legend", fontsize=16)
58+
plt.rc("legend", fontsize=16, title_fontsize=16)
59+
plt.rc("axes", titlesize=16, labelsize=16)
5960
plt.rc("savefig", bbox="tight", dpi=200)
6061
plt.rc("figure", dpi=200, titlesize=16)
6162
plt.rcParams["figure.constrained_layout.use"] = True
@@ -69,11 +70,11 @@ def hist_classified_stable_as_func_of_hull_dist(
6970
ax: plt.Axes = None,
7071
which_energy: WhichEnergy = "true",
7172
stability_crit: StabilityCriterion = "energy",
72-
show_mae: bool = False,
73-
stability_threshold: float = 0, # set stability threshold as distance to convex
74-
# hull in eV / atom, usually 0 or 0.1 eV
75-
x_lim: tuple[float, float] = (-0.4, 0.4),
76-
) -> plt.Axes:
73+
stability_threshold: float = 0,
74+
show_threshold: bool = True,
75+
x_lim: tuple[float | None, float | None] = (-0.4, 0.4),
76+
rolling_accuracy: float = 0.02,
77+
) -> tuple[plt.Axes, dict[str, float]]:
7778
"""
7879
Histogram of the energy difference (either according to DFT ground truth [default]
7980
or model predicted energy) to the convex hull for materials in the WBM data set. The
@@ -85,8 +86,33 @@ def hist_classified_stable_as_func_of_hull_dist(
8586
8687
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
8788
88-
NOTE this figure plots hist bars separately which causes aliasing in pdf
89-
to resolve this take into Inkscape and merge regions by color
89+
Args:
90+
e_above_hull_pred (pd.Series): energy difference to convex hull predicted by
91+
model, i.e. difference between the model's predicted and true formation
92+
energy.
93+
e_above_hull_true (pd.Series): energy diff to convex hull according to DFT
94+
ground truth.
95+
std_pred (pd.Series, optional): standard deviation of the model's predicted
96+
formation energy.
97+
ax (plt.Axes, optional): matplotlib axes to plot on.
98+
which_energy (WhichEnergy, optional): Whether to use the true formation energy
99+
or the model's predicted formation energy for the histogram.
100+
stability_crit (StabilityCriterion, optional): Whether to add/subtract the
101+
model's predicted uncertainty from its energy prediction when measuring
102+
predicted stability.
103+
stability_threshold (float, optional): set stability threshold as distance to
104+
convex hull in eV/atom, usually 0 or 0.1 eV.
105+
show_threshold (bool, optional): Whether to plot stability threshold as dashed
106+
vertical line.
107+
x_lim (tuple[float | None, float | None]): x-axis limits.
108+
rolling_accuracy (float): Rolling accuracy window size in eV / atom. Set to 0 to
109+
disable. Defaults to 0.02.
110+
111+
Returns:
112+
tuple[plt.Axes, dict[str, float]]: plot axes and classification metrics
113+
114+
NOTE this figure plots hist bars separately which causes aliasing in pdf. Can be
115+
fixed in Inkscape or similar by merging regions by color.
90116
"""
91117
ax = ax or plt.gca()
92118

@@ -153,26 +179,60 @@ def hist_classified_stable_as_func_of_hull_dist(
153179
# e_above_hull_true
154180
# ), f"{n_all} != {len(e_above_hull_true)}"
155181

156-
# recall = n_true_pos / n_total_pos
157-
# f"Prevalence = {null:.2f}\n{precision = :.2f}\n{recall = :.2f}",
158-
text = f"Enrichment\nFactor = {precision/null:.3}"
159-
if show_mae:
160-
MAE = e_above_hull_pred.abs().mean()
161-
text += f"\n{MAE = :.3}"
162-
163-
ax.text(
164-
0.98,
165-
0.98,
166-
text,
167-
fontsize=18,
168-
verticalalignment="top",
169-
horizontalalignment="right",
170-
transform=ax.transAxes,
171-
)
172-
173-
ax.set(xlabel=xlabel, ylabel="Number of compounds")
182+
ax.set(xlabel=xlabel, ylabel="Number of compounds", xlim=x_lim)
183+
184+
if rolling_accuracy:
185+
# add moving average of the accuracy (computed within 20 meV/atom intervals) as
186+
# a function of ΔHd,MP, shown as a blue line (right axis)
187+
ax_acc = ax.twinx()
188+
ax_acc.set_ylabel("Accuracy", color="darkblue")
189+
ax_acc.tick_params(labelcolor="darkblue")
190+
ax_acc.set(ylim=(0, 1))
191+
192+
# --- moving average of the accuracy
193+
# compute accuracy within bins of width rolling_accuracy (eV/atom)
194+
bins = np.arange(x_lim[0], x_lim[1], rolling_accuracy)
195+
bin_counts = np.histogram(e_above_hull_true, bins)[0]
196+
bin_true_pos = np.histogram(true_pos, bins)[0]
197+
bin_true_neg = np.histogram(true_neg, bins)[0]
198+
199+
# compute accuracy
200+
bin_accuracies = (bin_true_pos + bin_true_neg) / bin_counts
201+
# plot accuracy
202+
ax_acc.plot(
203+
bins[:-1],
204+
bin_accuracies,
205+
color="tab:blue",
206+
label="Accuracy",
207+
linewidth=3,
208+
)
209+
# ax2.fill_between(
210+
# bin_centers,
211+
# bin_accuracy - bin_accuracy_std,
212+
# bin_accuracy + bin_accuracy_std,
213+
# color="tab:blue",
214+
# alpha=0.2,
215+
# )
216+
217+
if show_threshold:
218+
ax.axvline(
219+
stability_threshold,
220+
color="k",
221+
linestyle="--",
222+
label="Stability Threshold",
223+
)
174224

175-
return ax
225+
recall = n_true_pos / n_total_pos
226+
227+
return ax, {
228+
"enrichment": precision / null,
229+
"precision": precision,
230+
"recall": recall,
231+
"prevalence": null,
232+
"accuracy": (n_true_pos + n_true_neg)
233+
/ (n_true_pos + n_true_neg + n_false_pos + n_false_neg),
234+
"f1": 2 * (precision * recall) / (precision + recall),
235+
}
176236

177237

178238
def rolling_mae_vs_hull_dist(

tests/test_plots.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_hist_classified_stable_as_func_of_hull_dist(
146146
else:
147147
std_total = None
148148

149-
ax = hist_classified_stable_as_func_of_hull_dist(
149+
ax, metrics = hist_classified_stable_as_func_of_hull_dist(
150150
e_above_hull_pred=df.e_above_hull_pred,
151151
e_above_hull_true=df.e_above_hull_mp,
152152
ax=ax,
@@ -160,3 +160,6 @@ def test_hist_classified_stable_as_func_of_hull_dist(
160160
# assert ax.get_ylim() == pytest.approx((0, 6.3))
161161
assert ax.get_ylabel() == "Number of compounds"
162162
assert ax.get_xlabel() == r"$\Delta E_{Hull-MP}$ (eV / atom)"
163+
164+
assert metrics["precision"] > 0.3
165+
assert metrics["recall"] > 0.3

0 commit comments

Comments
 (0)