Skip to content

Commit a16b46a

Browse files
committed
more refactoring of rhys plotting scripts for clearer variable names
1 parent 105e468 commit a16b46a

File tree

1 file changed

+20
-26
lines changed

1 file changed

+20
-26
lines changed

ml_stability/stability_plot_scripts/hist_clf_vary.py ml_stability/stability_plot_scripts/precision_recall_as_func_of_calc_count.py

+20-26
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@
2121

2222

2323
# %%
24-
fig, ax = plt.subplots(1, 1, figsize=(10, 9))
25-
2624
df_hull = pd.read_csv(
2725
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
2826
).set_index("material_id")
2927

3028

29+
# %%
30+
fig, ax = plt.subplots(1, 1, figsize=(10, 9))
31+
3132
for model_name, color in zip(
3233
# ["wren", "cgcnn", "cgcnn-d"],
3334
# ["tab:blue", "tab:red", "tab:purple"],
@@ -41,8 +42,9 @@
4142

4243
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
4344

44-
df = df.dropna(subset=["e_above_mp_hull"])
45+
assert df.e_above_mp_hull.isna().sum() == 0
4546

47+
target_col = "e_form_target"
4648
rare = "all"
4749

4850
# from pymatgen.core import Composition
@@ -52,19 +54,16 @@
5254
# )
5355
# df = df.query("~contains_rare_earths")
5456

55-
e_above_mp_hull = df.e_above_mp_hull.to_numpy().ravel()
56-
57-
# tar = df[tar_cols].to_numpy().ravel() - e_hull
58-
tar_f = df.filter(like="target").to_numpy().ravel()
57+
e_above_mp_hull = df.e_above_mp_hull
5958

60-
# mean = np.average(pred, axis=0) - e_hull
61-
mean = df.filter(like="pred").mean(axis=1) - tar_f + e_above_mp_hull
59+
# mean = df.filter(like="pred").mean(axis=1) - e_hull
60+
mean = df.filter(like="pred").mean(axis=1) - df[target_col] + e_above_mp_hull
6261

63-
# epistemic_std = np.var(pred, axis=0, ddof=0)
62+
# epistemic_var = df.filter(like="pred").var(axis=1, ddof=0)
6463

65-
aleatoric_std = (df.filter(like="ale") ** 2).mean(axis=0) ** 0.5
64+
# aleatoric_var = (df.filter(like="ale") ** 2).mean(axis=1)
6665

67-
# full_std = np.sqrt(epistemic_std + aleatoric_std)
66+
# full_std = (epistemic_var + aleatoric_var) ** 0.5
6867

6968
# crit = "std"
7069
# test = mean + full_std
@@ -85,29 +84,24 @@
8584
xticks = (-0.4, -0.2, 0, 0.2, 0.4)
8685
# yticks = (0, 300, 600, 900, 1200)
8786

88-
tp = len(e_above_mp_hull[(e_above_mp_hull <= thresh) & (mean <= thresh)])
89-
fn = len(e_above_mp_hull[(e_above_mp_hull <= thresh) & (mean > thresh)])
87+
n_true_pos = len(e_above_mp_hull[(e_above_mp_hull <= thresh) & (mean <= thresh)])
88+
n_false_neg = len(e_above_mp_hull[(e_above_mp_hull <= thresh) & (mean > thresh)])
9089

91-
pos = tp + fn
90+
n_total_pos = n_true_pos + n_false_neg
9291

9392
sort = np.argsort(mean)
9493
e_above_mp_hull = e_above_mp_hull[sort]
9594
mean = mean[sort]
9695

9796
e_type = "pred"
98-
tp = np.asarray((e_above_mp_hull <= thresh) & (mean <= thresh))
99-
fn = np.asarray((e_above_mp_hull <= thresh) & (mean > thresh))
100-
fp = np.asarray((e_above_mp_hull > thresh) & (mean <= thresh))
101-
tn = np.asarray((e_above_mp_hull > thresh) & (mean > thresh))
97+
true_pos_cumsum = ((e_above_mp_hull <= thresh) & (mean <= thresh)).cumsum()
98+
false_neg_cumsum = ((e_above_mp_hull <= thresh) & (mean > thresh)).cumsum()
99+
false_pos_cumsum = ((e_above_mp_hull > thresh) & (mean <= thresh)).cumsum()
100+
true_neg_cumsum = ((e_above_mp_hull > thresh) & (mean > thresh)).cumsum()
102101
xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
103102

104-
c_tp = np.cumsum(tp)
105-
c_fn = np.cumsum(fn)
106-
c_fp = np.cumsum(fp)
107-
c_tn = np.cumsum(tn)
108-
109-
ppv = c_tp / (c_tp + c_fp) * 100
110-
tpr = c_tp / pos * 100
103+
ppv = true_pos_cumsum / (true_pos_cumsum + false_pos_cumsum) * 100
104+
tpr = true_pos_cumsum / n_total_pos * 100
111105

112106
end = np.argmax(tpr)
113107

0 commit comments

Comments
 (0)