Skip to content

Commit 105e468

Browse files
committed
rename col e_above_{->mp_}hull in data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv
1 parent d5422bd commit 105e468

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

ml_stability/stability_plot_scripts/hist_clf_vary.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@
3939
)
4040
df = pd.read_csv(data_path).set_index("material_id")
4141

42-
df["e_above_hull"] = df_hull.e_above_hull
42+
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
4343

44-
df = df.dropna(subset=["e_above_hull"])
44+
df = df.dropna(subset=["e_above_mp_hull"])
4545

4646
rare = "all"
4747

@@ -52,13 +52,13 @@
5252
# )
5353
# df = df.query("~contains_rare_earths")
5454

55-
e_above_hull = df.e_above_hull.to_numpy().ravel()
55+
e_above_mp_hull = df.e_above_mp_hull.to_numpy().ravel()
5656

5757
# tar = df[tar_cols].to_numpy().ravel() - e_hull
5858
tar_f = df.filter(like="target").to_numpy().ravel()
5959

6060
# mean = np.average(pred, axis=0) - e_hull
61-
mean = df.filter(like="pred").T.mean(axis=0) - tar_f + e_above_hull
61+
mean = df.filter(like="pred").mean(axis=1) - tar_f + e_above_mp_hull
6262

6363
# epistemic_std = np.var(pred, axis=0, ddof=0)
6464

@@ -85,20 +85,20 @@
8585
xticks = (-0.4, -0.2, 0, 0.2, 0.4)
8686
# yticks = (0, 300, 600, 900, 1200)
8787

88-
tp = len(e_above_hull[(e_above_hull <= thresh) & (mean <= thresh)])
89-
fn = len(e_above_hull[(e_above_hull <= thresh) & (mean > thresh)])
88+
tp = len(e_above_mp_hull[(e_above_mp_hull <= thresh) & (mean <= thresh)])
89+
fn = len(e_above_mp_hull[(e_above_mp_hull <= thresh) & (mean > thresh)])
9090

9191
pos = tp + fn
9292

9393
sort = np.argsort(mean)
94-
e_above_hull = e_above_hull[sort]
94+
e_above_mp_hull = e_above_mp_hull[sort]
9595
mean = mean[sort]
9696

9797
e_type = "pred"
98-
tp = np.asarray((e_above_hull <= thresh) & (mean <= thresh))
99-
fn = np.asarray((e_above_hull <= thresh) & (mean > thresh))
100-
fp = np.asarray((e_above_hull > thresh) & (mean <= thresh))
101-
tn = np.asarray((e_above_hull > thresh) & (mean > thresh))
98+
tp = np.asarray((e_above_mp_hull <= thresh) & (mean <= thresh))
99+
fn = np.asarray((e_above_mp_hull <= thresh) & (mean > thresh))
100+
fp = np.asarray((e_above_mp_hull > thresh) & (mean <= thresh))
101+
tn = np.asarray((e_above_mp_hull > thresh) & (mean > thresh))
102102
xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
103103

104104
c_tp = np.cumsum(tp)

0 commit comments

Comments
 (0)