Skip to content

Commit 6030ef3

Browse files
committed
add scripts/hist_classified_stable_vs_hull_dist_models.py
1 parent d593ae2 commit 6030ef3

3 files changed

+63
-8
lines changed

scripts/hist_classified_stable_vs_hull_dist.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
"""
1010
Histogram of the energy difference (either according to DFT ground truth [default] or
1111
model predicted energy) to the convex hull for materials in the WBM data set. The
12-
histogram is broken down into true positives, false negatives, false positives, and true
13-
negatives based on whether the model predicts candidates to be below the known convex
14-
hull. Ideally, in discovery setting a model should exhibit high recall, i.e. the
15-
majority of materials below the convex hull being correctly identified by the model.
12+
histogram stacks true/false positives/negatives with different colors.
1613
1714
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
1815
"""

scripts/hist_classified_stable_vs_hull_dist_batches.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
"""
1414
Histogram of the energy difference (either according to DFT ground truth [default] or
1515
model predicted energy) to the convex hull for materials in the WBM data set. The
16-
histogram is broken down into true positives, false negatives, false positives, and true
17-
negatives based on whether the model predicts candidates to be below the known convex
18-
hull. Ideally, in discovery setting a model should exhibit high recall, i.e. the
19-
majority of materials below the convex hull being correctly identified by the model.
16+
histogram stacks true/false positives/negatives with different colors.
2017
2118
See fig. S1 in https://science.org/doi/10.1126/sciadv.abn4117.
2219
"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# %%
2+
from matbench_discovery import ROOT, today
3+
from matbench_discovery.load_preds import load_df_wbm_with_preds
4+
from matbench_discovery.plots import (
5+
WhichEnergy,
6+
hist_classified_stable_vs_hull_dist,
7+
plt,
8+
)
9+
10+
__author__ = "Janosh Riebesell"
11+
__date__ = "2022-12-01"
12+
13+
"""
14+
Histogram of the energy difference (either according to DFT ground truth [default] or
15+
model predicted energy) to the convex hull for materials in the WBM data set. The
16+
histogram stacks true/false positives/negatives with different colors.
17+
"""
18+
19+
20+
# %%
# Names of the models to compare; each is also used below as a column of df_wbm.
models = [
    "Wren",
    "CGCNN",
    "CGCNN IS2RE",
    "CGCNN RS2RE",
    "Voronoi RF",
    "Wrenformer",
    "MEGNet",
    "M3GNet",
    "BOWSR MEGNet",
]
# Load the WBM dataframe together with the predictions of the selected models,
# rounded to 3 decimals.
df_wbm = load_df_wbm_with_preds(models=models).round(3)

# Column holding the target energy and column holding the true hull distance.
target_col = "e_form_per_atom_mp2020_corrected"
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
29+
30+
31+
# %%
# Draw one classified hull-distance histogram per model on a 3x3 grid of axes.
which_energy: WhichEnergy = "true"
fig, axs = plt.subplots(3, 3, figsize=(18, 12))

# strict=True makes zip raise if the number of models and axes ever diverge,
# instead of silently leaving subplots empty or dropping models.
for model_name, ax in zip(models, axs.flat, strict=True):
    # Predicted hull distance = true hull distance + (model prediction - target).
    e_above_hull_pred = df_wbm[e_above_hull_col] + (
        df_wbm[model_name] - df_wbm[target_col]
    )

    ax, metrics = hist_classified_stable_vs_hull_dist(
        e_above_hull_true=df_wbm[e_above_hull_col],
        e_above_hull_pred=e_above_hull_pred,
        which_energy=which_energy,
        ax=ax,
    )

    # Annotate each subplot with the enrichment factor reported in metrics.
    text = f"Enrichment\nFactor = {metrics['enrichment']:.3}"
    ax.text(0.02, 0.25, text, fontsize=16, transform=ax.transAxes)

    # Title shows the model name and its number of non-NaN predictions.
    title = f"{model_name} ({len(df_wbm[model_name].dropna()):,})"
    ax.set(title=title)


# axs.flat[0].legend(frameon=False, loc="upper left")

fig.suptitle(f"{today} {which_energy=}", y=1.07, fontsize=16)
57+
58+
59+
# %%
# Save via the figure handle from plt.subplots rather than reaching it through
# `ax`, which is only the loop variable left over from the plotting cell.
img_path = f"{ROOT}/figures/{today}-wbm-hull-dist-hist-models.pdf"
fig.savefig(img_path)

0 commit comments

Comments
 (0)