|
21 | 21 |
|
22 | 22 |
|
23 | 23 | # %%
|
24 |
| -fig, ax = plt.subplots(1, 1, figsize=(10, 9)) |
25 |
| - |
26 | 24 | df_hull = pd.read_csv(
|
27 | 25 | f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
|
28 | 26 | ).set_index("material_id")
|
29 | 27 |
|
30 | 28 |
|
| 29 | +# %% |
| 30 | +fig, ax = plt.subplots(1, 1, figsize=(10, 9)) |
| 31 | + |
31 | 32 | for model_name, color in zip(
|
32 | 33 | # ["wren", "cgcnn", "cgcnn-d"],
|
33 | 34 | # ["tab:blue", "tab:red", "tab:purple"],
|
|
41 | 42 |
|
42 | 43 | df["e_above_mp_hull"] = df_hull.e_above_mp_hull
|
43 | 44 |
|
44 |
| - df = df.dropna(subset=["e_above_mp_hull"]) |
| 45 | + assert df.e_above_mp_hull.isna().sum() == 0 |
45 | 46 |
|
| 47 | + target_col = "e_form_target" |
46 | 48 | rare = "all"
|
47 | 49 |
|
48 | 50 | # from pymatgen.core import Composition
|
|
52 | 54 | # )
|
53 | 55 | # df = df.query("~contains_rare_earths")
|
54 | 56 |
|
55 |
| - e_above_mp_hull = df.e_above_mp_hull.to_numpy().ravel() |
56 |
| - |
57 |
| - # tar = df[tar_cols].to_numpy().ravel() - e_hull |
58 |
| - tar_f = df.filter(like="target").to_numpy().ravel() |
| 57 | + e_above_mp_hull = df.e_above_mp_hull |
59 | 58 |
|
60 |
| - # mean = np.average(pred, axis=0) - e_hull |
61 |
| - mean = df.filter(like="pred").mean(axis=1) - tar_f + e_above_mp_hull |
| 59 | + # mean = df.filter(like="pred").mean(axis=1) - e_hull |
| 60 | + mean = df.filter(like="pred").mean(axis=1) - df[target_col] + e_above_mp_hull |
62 | 61 |
|
63 |
| - # epistemic_std = np.var(pred, axis=0, ddof=0) |
| 62 | + # epistemic_var = df.filter(like="pred").var(axis=1, ddof=0) |
64 | 63 |
|
65 |
| - aleatoric_std = (df.filter(like="ale") ** 2).mean(axis=0) ** 0.5 |
| 64 | + # aleatoric_var = (df.filter(like="ale") ** 2).mean(axis=1) |
66 | 65 |
|
67 |
| - # full_std = np.sqrt(epistemic_std + aleatoric_std) |
| 66 | + # full_std = (epistemic_var + aleatoric_var) ** 0.5 |
68 | 67 |
|
69 | 68 | # crit = "std"
|
70 | 69 | # test = mean + full_std
|
|
85 | 84 | xticks = (-0.4, -0.2, 0, 0.2, 0.4)
|
86 | 85 | # yticks = (0, 300, 600, 900, 1200)
|
87 | 86 |
|
88 |
| - tp = len(e_above_mp_hull[(e_above_mp_hull <= thresh) & (mean <= thresh)]) |
89 |
| - fn = len(e_above_mp_hull[(e_above_mp_hull <= thresh) & (mean > thresh)]) |
| 87 | + n_true_pos = len(e_above_mp_hull[(e_above_mp_hull <= thresh) & (mean <= thresh)]) |
| 88 | + n_false_neg = len(e_above_mp_hull[(e_above_mp_hull <= thresh) & (mean > thresh)]) |
90 | 89 |
|
91 |
| - pos = tp + fn |
| 90 | + n_total_pos = n_true_pos + n_false_neg |
92 | 91 |
|
93 | 92 | sort = np.argsort(mean)
|
94 | 93 | e_above_mp_hull = e_above_mp_hull[sort]
|
95 | 94 | mean = mean[sort]
|
96 | 95 |
|
97 | 96 | e_type = "pred"
|
98 |
| - tp = np.asarray((e_above_mp_hull <= thresh) & (mean <= thresh)) |
99 |
| - fn = np.asarray((e_above_mp_hull <= thresh) & (mean > thresh)) |
100 |
| - fp = np.asarray((e_above_mp_hull > thresh) & (mean <= thresh)) |
101 |
| - tn = np.asarray((e_above_mp_hull > thresh) & (mean > thresh)) |
| 97 | + true_pos_cumsum = ((e_above_mp_hull <= thresh) & (mean <= thresh)).cumsum() |
| 98 | + false_neg_cumsum = ((e_above_mp_hull <= thresh) & (mean > thresh)).cumsum() |
| 99 | + false_pos_cumsum = ((e_above_mp_hull > thresh) & (mean <= thresh)).cumsum() |
| 100 | + true_neg_cumsum = ((e_above_mp_hull > thresh) & (mean > thresh)).cumsum() |
102 | 101 | xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
|
103 | 102 |
|
104 |
| - c_tp = np.cumsum(tp) |
105 |
| - c_fn = np.cumsum(fn) |
106 |
| - c_fp = np.cumsum(fp) |
107 |
| - c_tn = np.cumsum(tn) |
108 |
| - |
109 |
| - ppv = c_tp / (c_tp + c_fp) * 100 |
110 |
| - tpr = c_tp / pos * 100 |
| 103 | + ppv = true_pos_cumsum / (true_pos_cumsum + false_pos_cumsum) * 100 |
| 104 | + tpr = true_pos_cumsum / n_total_pos * 100 |
111 | 105 |
|
112 | 106 | end = np.argmax(tpr)
|
113 | 107 |
|
|
0 commit comments