|
1 | 1 | # %%
|
| 2 | +from datetime import datetime |
| 3 | + |
2 | 4 | import matplotlib.pyplot as plt
|
3 | 5 | import numpy as np
|
4 | 6 | import pandas as pd
|
5 | 7 | from scipy.interpolate import interp1d
|
6 | 8 |
|
| 9 | +from ml_stability import ROOT |
| 10 | + |
| 11 | + |
| 12 | +__author__ = "Rhys Goodall, Janosh Riebesell" |
| 13 | +__date__ = "2022-06-18" |
7 | 14 |
|
8 |
| -plt.rcParams.update({"font.size": 20}) |
| 15 | +today = f"{datetime.now():%Y-%m-%d}" |
9 | 16 |
|
10 |
| -plt.rcParams["axes.linewidth"] = 2.5 |
11 |
| -plt.rcParams["lines.linewidth"] = 3.5 |
12 |
| -plt.rcParams["xtick.major.size"] = 7 |
13 |
| -plt.rcParams["xtick.major.width"] = 2.5 |
14 |
| -plt.rcParams["xtick.minor.size"] = 5 |
15 |
| -plt.rcParams["xtick.minor.width"] = 2.5 |
16 |
| -plt.rcParams["ytick.major.size"] = 7 |
17 |
| -plt.rcParams["ytick.major.width"] = 2.5 |
18 |
| -plt.rcParams["ytick.minor.size"] = 5 |
19 |
| -plt.rcParams["ytick.minor.width"] = 2.5 |
20 |
| -plt.rcParams["legend.fontsize"] = 20 |
| 17 | +plt.rc("font", size=18) |
| 18 | +plt.rc("savefig", bbox="tight", dpi=200) |
| 19 | +plt.rcParams["figure.constrained_layout.use"] = True |
| 20 | +plt.rc("figure", dpi=150, titlesize=20) |
21 | 21 |
|
| 22 | + |
| 23 | +# %% |
22 | 24 | fig, ax = plt.subplots(1, 1, figsize=(10, 9))
|
23 | 25 |
|
24 |
| -df_hull = pd.read_csv( |
25 |
| - f"/home/reag2/PhD/aviary/examples/manuscript/new_figs/wbm_e_above_mp.csv", |
26 |
| - comment="#", |
27 |
| - na_filter=False, |
28 |
| -) |
| 26 | +df_hull = pd.read_csv(f"{ROOT}/data/wbm_e_above_mp.csv") |
29 | 27 |
|
30 |
| -e_hull_dict = dict(zip(df_hull.material_id, df_hull.E_above_hull)) |
| 28 | +e_hull_dict = dict(zip(df_hull.material_id, df_hull.e_above_hull)) |
31 | 29 |
|
32 |
| -for name, c, a in zip( |
| 30 | +for model_name, color in zip( |
33 | 31 | # ["wren", "cgcnn", "cgcnn-d"],
|
34 | 32 | # ["tab:blue", "tab:red", "tab:purple"],
|
35 |
| - ["wren", "voro", "cgcnn"], |
| 33 | + ["wren", "voronoi", "cgcnn"], |
36 | 34 | ["tab:blue", "tab:orange", "tab:red"],
|
37 |
| - [1, 0.8, 0.8], |
38 |
| - # ["wren", "cgcnn"], |
39 |
| - # ["tab:blue", "tab:red"], |
40 |
| - # [1, 0.8], |
41 | 35 | ):
|
42 |
| - df = pd.read_csv( |
43 |
| - f"/home/reag2/PhD/aviary/examples/manuscript/new_figs/{name}-mp-init.csv", |
44 |
| - comment="#", |
45 |
| - na_filter=False, |
46 |
| - ) |
47 |
| - |
48 |
| - df["E_hull"] = pd.to_numeric(df["material_id"].map(e_hull_dict)) |
| 36 | + df = pd.read_csv(f"{ROOT}/data/{model_name}-mp-initial-structures.csv") |
49 | 37 |
|
50 |
| - df = df.dropna(axis=0, subset=["E_hull"]) |
| 38 | + df["e_above_hull"] = pd.to_numeric(df["material_id"].map(e_hull_dict)) |
51 | 39 |
|
52 |
| - init = len(df) |
| 40 | + df = df.dropna(axis=0, subset=["e_above_hull"]) |
53 | 41 |
|
54 | 42 | rare = "all"
|
55 | 43 |
|
|
60 | 48 | # )
|
61 | 49 | # ]
|
62 | 50 |
|
63 |
| - # print(1-len(df)/init) |
64 |
| - |
65 |
| - tar = df["E_hull"].to_numpy().ravel() |
| 51 | + e_above_hull = df.e_above_hull.to_numpy().ravel() |
66 | 52 |
|
67 |
| - print(len(tar)) |
68 |
| - |
69 |
| - tar_cols = [col for col in df.columns if "target" in col] |
70 | 53 | # tar = df[tar_cols].to_numpy().ravel() - e_hull
|
71 |
| - tar_f = df[tar_cols].to_numpy().ravel() |
| 54 | + tar_f = df.filter(like="target").to_numpy().ravel() |
72 | 55 |
|
73 |
| - pred_cols = [col for col in df.columns if "pred" in col] |
74 |
| - pred = df[pred_cols].to_numpy().T |
75 | 56 | # mean = np.average(pred, axis=0) - e_hull
|
76 |
| - mean = np.average(pred, axis=0) - tar_f + tar |
| 57 | + mean = df.filter(like="pred").T.mean(axis=0) - tar_f + e_above_hull |
77 | 58 |
|
78 |
| - epi = np.var(pred, axis=0, ddof=0) |
| 59 | + # epistemic_std = np.var(pred, axis=0, ddof=0) |
79 | 60 |
|
80 |
| - ale_cols = [col for col in df.columns if "ale" in col] |
81 |
| - if len(ale_cols) > 0: |
82 |
| - ales = df[ale_cols].to_numpy().T |
83 |
| - ale = np.mean(np.square(ales), axis=0) |
84 |
| - else: |
85 |
| - ale = 0 |
| 61 | + # aleatoric_std = np.mean(np.square(df.filter(like="ale")), axis=0) |
86 | 62 |
|
87 |
| - both = np.sqrt(epi + ale) |
| 63 | + # full_std = np.sqrt(epistemic_std + aleatoric_std) |
88 | 64 |
|
89 | 65 | # crit = "std"
|
90 |
| - # test = mean + both |
| 66 | + # test = mean + full_std |
91 | 67 |
|
92 | 68 | # crit = "neg"
|
93 |
| - # test = mean - both |
| 69 | + # test = mean - full_std |
94 | 70 |
|
95 | 71 | crit = "ene"
|
96 | 72 | test = mean
|
|
100 | 76 | xlim = (-0.4, 0.4)
|
101 | 77 | # xlim = (-1, 1)
|
102 | 78 |
|
103 |
| - alpha = 0.5 |
104 | 79 | # thresh = 0.02
|
105 | 80 | thresh = 0.00
|
106 | 81 | # thresh = 0.10
|
107 | 82 | xticks = (-0.4, -0.2, 0, 0.2, 0.4)
|
108 | 83 | # yticks = (0, 300, 600, 900, 1200)
|
109 | 84 |
|
110 |
| - tp = len(tar[(tar <= thresh) & (test <= thresh)]) |
111 |
| - fn = len(tar[(tar <= thresh) & (test > thresh)]) |
| 85 | + tp = len(e_above_hull[(e_above_hull <= thresh) & (test <= thresh)]) |
| 86 | + fn = len(e_above_hull[(e_above_hull <= thresh) & (test > thresh)]) |
112 | 87 |
|
113 | 88 | pos = tp + fn
|
114 |
| - null = pos / len(tar) |
115 | 89 |
|
116 | 90 | sort = np.argsort(test)
|
117 |
| - tar = tar[sort] |
| 91 | + e_above_hull = e_above_hull[sort] |
118 | 92 | test = test[sort]
|
119 | 93 |
|
120 | 94 | e_type = "pred"
|
121 |
| - tp = np.asarray((tar <= thresh) & (test <= thresh)) |
122 |
| - fn = np.asarray((tar <= thresh) & (test > thresh)) |
123 |
| - fp = np.asarray((tar > thresh) & (test <= thresh)) |
124 |
| - tn = np.asarray((tar > thresh) & (test > thresh)) |
125 |
| - xlabel = ( |
126 |
| - r"$\Delta$" + r"$\it{E}$" + r"$_{Hull-Pred}$" + " / eV per atom" |
127 |
| - ) # r"$\/(\frac{eV}{atom})$" |
128 |
| - |
129 |
| - # %% |
| 95 | + tp = np.asarray((e_above_hull <= thresh) & (test <= thresh)) |
| 96 | + fn = np.asarray((e_above_hull <= thresh) & (test > thresh)) |
| 97 | + fp = np.asarray((e_above_hull > thresh) & (test <= thresh)) |
| 98 | + tn = np.asarray((e_above_hull > thresh) & (test > thresh)) |
| 99 | + xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom" |
130 | 100 |
|
131 | 101 | c_tp = np.cumsum(tp)
|
132 | 102 | c_fn = np.cumsum(fn)
|
|
143 | 113 | f_ppv = interp1d(x, ppv[:end], kind="cubic")
|
144 | 114 | f_tpr = interp1d(x, tpr[:end], kind="cubic")
|
145 | 115 |
|
146 |
| - ax.plot( |
147 |
| - x[::100], |
148 |
| - f_tpr(x[::100]), |
149 |
| - linestyle=":", |
150 |
| - color=c, |
151 |
| - alpha=a, |
152 |
| - markevery=[-1], |
153 |
| - marker="x", |
154 |
| - markersize=14, |
155 |
| - mew=2.5, |
| 116 | + line_kwargs = dict( |
| 117 | + linewidth=2, color=color, markevery=[-1], marker="x", markersize=14, mew=2.5 |
156 | 118 | )
|
| 119 | + ax.plot(x[::100], f_tpr(x[::100]), linestyle=":", **line_kwargs) |
157 | 120 |
|
158 |
| - ax.plot( |
159 |
| - x[::100], |
160 |
| - f_ppv(x[::100]), |
161 |
| - linestyle="-", |
162 |
| - color=c, |
163 |
| - alpha=a, |
164 |
| - markevery=[-1], |
165 |
| - marker="x", |
166 |
| - markersize=14, |
167 |
| - mew=2.5, |
168 |
| - ) |
| 121 | + ax.plot(x[::100], f_ppv(x[::100]), linestyle="-", **line_kwargs) |
169 | 122 |
|
170 | 123 |
|
171 | 124 | # ax.set_xticks((0, 2.5e4, 5e4, 7.5e4))
|
|
206 | 159 | ax.set_aspect(1.0 / ax.get_data_ratio())
|
207 | 160 |
|
208 | 161 |
|
209 |
| -fig.tight_layout() |
210 |
| -plt.savefig(f"examples/manuscript/new_figs/vary-{e_type}-{crit}-{rare}.pdf") |
211 |
| -# plt.savefig(f"examples/manuscript/pdf/vary-{e_type}-{crit}-{rare}.png") |
| 162 | +# plt.savefig(f"{PKG_DIR}/plots/{today}-vary-{e_type}-{crit}-{rare}.pdf") |
| 163 | +# # plt.savefig(f"{PKG_DIR}/plots/{today}-vary-{e_type}-{crit}-{rare}.png") |
212 | 164 |
|
213 |
| -plt.show() |
| 165 | +# plt.show() |
0 commit comments