Skip to content

Commit d5422bd

Browse files
committed
refactor hist_clf_vary.py
1 parent a9645b7 commit d5422bd

File tree

1 file changed

+34
-54
lines changed

1 file changed

+34
-54
lines changed

ml_stability/stability_plot_scripts/hist_clf_vary.py

+34-54
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pandas as pd
77
from scipy.interpolate import interp1d
88

9-
from ml_stability import ROOT
9+
from ml_stability import PKG_DIR, ROOT
1010

1111

1212
__author__ = "Rhys Goodall, Janosh Riebesell"
@@ -34,23 +34,23 @@
3434
["wren", "voronoi", "cgcnn"],
3535
["tab:blue", "tab:orange", "tab:red"],
3636
):
37-
df = pd.read_csv(
37+
data_path = (
3838
f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv"
3939
)
40-
df = df.set_index("material_id")
40+
df = pd.read_csv(data_path).set_index("material_id")
4141

4242
df["e_above_hull"] = df_hull.e_above_hull
4343

4444
df = df.dropna(subset=["e_above_hull"])
4545

4646
rare = "all"
4747

48-
# rare = "nla"
49-
# df = df[
50-
# ~df["composition"].apply(
51-
# lambda x: any(el.is_rare_earth_metal for el in Composition(x).elements)
52-
# )
53-
# ]
48+
# from pymatgen.core import Composition
49+
# rare = "no-lanthanides"
50+
# df["contains_rare_earths"] = df.composition.map(
51+
# lambda x: any(el.is_rare_earth_metal for el in Composition(x))
52+
# )
53+
# df = df.query("~contains_rare_earths")
5454

5555
e_above_hull = df.e_above_hull.to_numpy().ravel()
5656

@@ -62,7 +62,7 @@
6262

6363
# epistemic_std = np.var(pred, axis=0, ddof=0)
6464

65-
# aleatoric_std = np.mean(np.square(df.filter(like="ale")), axis=0)
65+
aleatoric_std = (df.filter(like="ale") ** 2).mean(axis=0) ** 0.5
6666

6767
# full_std = np.sqrt(epistemic_std + aleatoric_std)
6868

@@ -73,7 +73,6 @@
7373
# test = mean - full_std
7474

7575
crit = "ene"
76-
test = mean
7776

7877
bins = 200
7978
# xlim = (-0.2, 0.2)
@@ -86,20 +85,20 @@
8685
xticks = (-0.4, -0.2, 0, 0.2, 0.4)
8786
# yticks = (0, 300, 600, 900, 1200)
8887

89-
tp = len(e_above_hull[(e_above_hull <= thresh) & (test <= thresh)])
90-
fn = len(e_above_hull[(e_above_hull <= thresh) & (test > thresh)])
88+
tp = len(e_above_hull[(e_above_hull <= thresh) & (mean <= thresh)])
89+
fn = len(e_above_hull[(e_above_hull <= thresh) & (mean > thresh)])
9190

9291
pos = tp + fn
9392

94-
sort = np.argsort(test)
93+
sort = np.argsort(mean)
9594
e_above_hull = e_above_hull[sort]
96-
test = test[sort]
95+
mean = mean[sort]
9796

9897
e_type = "pred"
99-
tp = np.asarray((e_above_hull <= thresh) & (test <= thresh))
100-
fn = np.asarray((e_above_hull <= thresh) & (test > thresh))
101-
fp = np.asarray((e_above_hull > thresh) & (test <= thresh))
102-
tn = np.asarray((e_above_hull > thresh) & (test > thresh))
98+
tp = np.asarray((e_above_hull <= thresh) & (mean <= thresh))
99+
fn = np.asarray((e_above_hull <= thresh) & (mean > thresh))
100+
fp = np.asarray((e_above_hull > thresh) & (mean <= thresh))
101+
tn = np.asarray((e_above_hull > thresh) & (mean > thresh))
103102
xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
104103

105104
c_tp = np.cumsum(tp)
@@ -114,49 +113,30 @@
114113

115114
x = np.arange(len(ppv))[:end]
116115

117-
f_ppv = interp1d(x, ppv[:end], kind="cubic")
118-
f_tpr = interp1d(x, tpr[:end], kind="cubic")
116+
precision_curve = interp1d(x, ppv[:end], kind="cubic")
117+
rolling_recall_curve = interp1d(x, tpr[:end], kind="cubic")
119118

120119
line_kwargs = dict(
121120
linewidth=2, color=color, markevery=[-1], marker="x", markersize=14, mew=2.5
122121
)
123-
ax.plot(x[::100], f_tpr(x[::100]), linestyle=":", **line_kwargs)
124-
125-
ax.plot(x[::100], f_ppv(x[::100]), linestyle="-", **line_kwargs)
122+
ax.plot(x[::100], precision_curve(x[::100]), linestyle="-", **line_kwargs)
123+
ax.plot(x[::100], rolling_recall_curve(x[::100]), linestyle=":", **line_kwargs)
126124

127125

128126
ax.set(xlabel="Number of Calculations", ylabel="Percentage")
129127

130-
ax.set(xlim=(0, 8e4), ylim=(0, 100), xticks=(0, 2e4, 4e4, 6e4, 8e4))
131-
132-
ax.plot((-1, -1), (-1, -1), color="tab:blue")
133-
ax.plot((-1, -1), (-1, -1), color="tab:red")
134-
ax.plot((-1, -1), (-1, -1), color="tab:orange")
135-
136-
# ax.plot((-1, -1), (-1, -1), color="tab:purple")
137-
138-
ax.plot((-1, -1), (-1, -1), "k", linestyle="-")
139-
ax.plot((-1, -1), (-1, -1), "k", linestyle=":")
140-
141-
lines = ax.get_lines()
142-
143-
legend1 = ax.legend(
144-
lines[-2:], ["Precision", "Recall"], frameon=False, loc="upper right"
145-
)
146-
legend2 = ax.legend(
147-
lines[-5:-2],
148-
# ["Wren (This Work)", "CGCNN Pre-relax", "CGCNN-D Pre-relax"],
149-
["Wren (This Work)", "CGCNN Pre-relax", "Voronoi Pre-relax"],
150-
frameon=False,
151-
loc="lower right",
152-
)
153-
154-
ax.add_artist(legend1)
155-
156-
ax.set_aspect(1.0 / ax.get_data_ratio())
128+
xlim = (0, 8e4)
129+
ax.set(xlim=xlim, ylim=(0, 100), xticks=np.linspace(*xlim, 5))
157130

131+
ax.plot((0, 0), (0, 0), color="tab:blue", label="Wren (This Work)")
132+
ax.plot((0, 0), (0, 0), color="tab:red", label="CGCNN Pre-relax")
133+
ax.plot((0, 0), (0, 0), color="tab:orange", label="Voronoi Pre-relax")
134+
legend_1 = ax.legend(frameon=False, loc="lower right")
135+
ax.add_artist(legend_1)
158136

159-
# plt.savefig(f"{PKG_DIR}/plots/{today}-vary-{e_type}-{crit}-{rare}.pdf")
160-
# # plt.savefig(f"{PKG_DIR}/plots/{today}-vary-{e_type}-{crit}-{rare}.png")
137+
[prec] = ax.plot((0, 0), (0, 0), "black", linestyle="-")
138+
[recall] = ax.plot((0, 0), (0, 0), "black", linestyle=":")
139+
ax.legend([prec, recall], ["Precision", "Recall"], frameon=False, loc="upper right")
161140

162-
# plt.show()
141+
img_path = f"{PKG_DIR}/plots/{today}-vary-{e_type=}-{crit=}-{rare=}.pdf"
142+
# plt.savefig(img_path)

0 commit comments

Comments
 (0)