|
| 1 | +# %% |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | +from scipy.interpolate import interp1d |
| 6 | + |
| 7 | + |
| 8 | +plt.rcParams.update({"font.size": 20}) |
| 9 | + |
| 10 | +plt.rcParams["axes.linewidth"] = 2.5 |
| 11 | +plt.rcParams["lines.linewidth"] = 3.5 |
| 12 | +plt.rcParams["xtick.major.size"] = 7 |
| 13 | +plt.rcParams["xtick.major.width"] = 2.5 |
| 14 | +plt.rcParams["xtick.minor.size"] = 5 |
| 15 | +plt.rcParams["xtick.minor.width"] = 2.5 |
| 16 | +plt.rcParams["ytick.major.size"] = 7 |
| 17 | +plt.rcParams["ytick.major.width"] = 2.5 |
| 18 | +plt.rcParams["ytick.minor.size"] = 5 |
| 19 | +plt.rcParams["ytick.minor.width"] = 2.5 |
| 20 | +plt.rcParams["legend.fontsize"] = 20 |
| 21 | + |
| 22 | +fig, ax = plt.subplots(1, 1, figsize=(10, 9)) |
| 23 | + |
| 24 | +df_hull = pd.read_csv( |
| 25 | + f"/home/reag2/PhD/aviary/examples/manuscript/new_figs/wbm_e_above_mp.csv", |
| 26 | + comment="#", |
| 27 | + na_filter=False, |
| 28 | +) |
| 29 | + |
| 30 | +e_hull_dict = dict(zip(df_hull.material_id, df_hull.E_above_hull)) |
| 31 | + |
| 32 | +for name, c, a in zip( |
| 33 | + # ["wren", "cgcnn", "cgcnn-d"], |
| 34 | + # ["tab:blue", "tab:red", "tab:purple"], |
| 35 | + ["wren", "voro", "cgcnn"], |
| 36 | + ["tab:blue", "tab:orange", "tab:red"], |
| 37 | + [1, 0.8, 0.8], |
| 38 | + # ["wren", "cgcnn"], |
| 39 | + # ["tab:blue", "tab:red"], |
| 40 | + # [1, 0.8], |
| 41 | +): |
| 42 | + df = pd.read_csv( |
| 43 | + f"/home/reag2/PhD/aviary/examples/manuscript/new_figs/{name}-mp-init.csv", |
| 44 | + comment="#", |
| 45 | + na_filter=False, |
| 46 | + ) |
| 47 | + |
| 48 | + df["E_hull"] = pd.to_numeric(df["material_id"].map(e_hull_dict)) |
| 49 | + |
| 50 | + df = df.dropna(axis=0, subset=["E_hull"]) |
| 51 | + |
| 52 | + init = len(df) |
| 53 | + |
| 54 | + rare = "all" |
| 55 | + |
| 56 | + # rare = "nla" |
| 57 | + # df = df[ |
| 58 | + # ~df["composition"].apply( |
| 59 | + # lambda x: any(el.is_rare_earth_metal for el in Composition(x).elements) |
| 60 | + # ) |
| 61 | + # ] |
| 62 | + |
| 63 | + # print(1-len(df)/init) |
| 64 | + |
| 65 | + tar = df["E_hull"].to_numpy().ravel() |
| 66 | + |
| 67 | + print(len(tar)) |
| 68 | + |
| 69 | + tar_cols = [col for col in df.columns if "target" in col] |
| 70 | + # tar = df[tar_cols].to_numpy().ravel() - e_hull |
| 71 | + tar_f = df[tar_cols].to_numpy().ravel() |
| 72 | + |
| 73 | + pred_cols = [col for col in df.columns if "pred" in col] |
| 74 | + pred = df[pred_cols].to_numpy().T |
| 75 | + # mean = np.average(pred, axis=0) - e_hull |
| 76 | + mean = np.average(pred, axis=0) - tar_f + tar |
| 77 | + |
| 78 | + epi = np.var(pred, axis=0, ddof=0) |
| 79 | + |
| 80 | + ale_cols = [col for col in df.columns if "ale" in col] |
| 81 | + if len(ale_cols) > 0: |
| 82 | + ales = df[ale_cols].to_numpy().T |
| 83 | + ale = np.mean(np.square(ales), axis=0) |
| 84 | + else: |
| 85 | + ale = 0 |
| 86 | + |
| 87 | + both = np.sqrt(epi + ale) |
| 88 | + |
| 89 | + # crit = "std" |
| 90 | + # test = mean + both |
| 91 | + |
| 92 | + # crit = "neg" |
| 93 | + # test = mean - both |
| 94 | + |
| 95 | + crit = "ene" |
| 96 | + test = mean |
| 97 | + |
| 98 | + bins = 200 |
| 99 | + # xlim = (-0.2, 0.2) |
| 100 | + xlim = (-0.4, 0.4) |
| 101 | + # xlim = (-1, 1) |
| 102 | + |
| 103 | + alpha = 0.5 |
| 104 | + # thresh = 0.02 |
| 105 | + thresh = 0.00 |
| 106 | + # thresh = 0.10 |
| 107 | + xticks = (-0.4, -0.2, 0, 0.2, 0.4) |
| 108 | + # yticks = (0, 300, 600, 900, 1200) |
| 109 | + |
| 110 | + tp = len(tar[(tar <= thresh) & (test <= thresh)]) |
| 111 | + fn = len(tar[(tar <= thresh) & (test > thresh)]) |
| 112 | + |
| 113 | + pos = tp + fn |
| 114 | + null = pos / len(tar) |
| 115 | + |
| 116 | + sort = np.argsort(test) |
| 117 | + tar = tar[sort] |
| 118 | + test = test[sort] |
| 119 | + |
| 120 | + e_type = "pred" |
| 121 | + tp = np.asarray((tar <= thresh) & (test <= thresh)) |
| 122 | + fn = np.asarray((tar <= thresh) & (test > thresh)) |
| 123 | + fp = np.asarray((tar > thresh) & (test <= thresh)) |
| 124 | + tn = np.asarray((tar > thresh) & (test > thresh)) |
| 125 | + xlabel = ( |
| 126 | + r"$\Delta$" + r"$\it{E}$" + r"$_{Hull-Pred}$" + " / eV per atom" |
| 127 | + ) # r"$\/(\frac{eV}{atom})$" |
| 128 | + |
| 129 | + # %% |
| 130 | + |
| 131 | + c_tp = np.cumsum(tp) |
| 132 | + c_fn = np.cumsum(fn) |
| 133 | + c_fp = np.cumsum(fp) |
| 134 | + c_tn = np.cumsum(tn) |
| 135 | + |
| 136 | + ppv = c_tp / (c_tp + c_fp) * 100 |
| 137 | + tpr = c_tp / pos * 100 |
| 138 | + |
| 139 | + end = np.argmax(tpr) |
| 140 | + |
| 141 | + x = np.arange(len(ppv))[:end] |
| 142 | + |
| 143 | + f_ppv = interp1d(x, ppv[:end], kind="cubic") |
| 144 | + f_tpr = interp1d(x, tpr[:end], kind="cubic") |
| 145 | + |
| 146 | + ax.plot( |
| 147 | + x[::100], |
| 148 | + f_tpr(x[::100]), |
| 149 | + linestyle=":", |
| 150 | + color=c, |
| 151 | + alpha=a, |
| 152 | + markevery=[-1], |
| 153 | + marker="x", |
| 154 | + markersize=14, |
| 155 | + mew=2.5, |
| 156 | + ) |
| 157 | + |
| 158 | + ax.plot( |
| 159 | + x[::100], |
| 160 | + f_ppv(x[::100]), |
| 161 | + linestyle="-", |
| 162 | + color=c, |
| 163 | + alpha=a, |
| 164 | + markevery=[-1], |
| 165 | + marker="x", |
| 166 | + markersize=14, |
| 167 | + mew=2.5, |
| 168 | + ) |
| 169 | + |
| 170 | + |
| 171 | +# ax.set_xticks((0, 2.5e4, 5e4, 7.5e4)) |
| 172 | +ax.set_xticks((0, 2e4, 4e4, 6e4, 8e4)) |
| 173 | + |
| 174 | +ax.set_ylabel("Percentage") |
| 175 | +ax.set_xlabel("Number of Calculations") |
| 176 | + |
| 177 | +ax.set_xlim((0, 8e4)) |
| 178 | +# ax.set_xlim((0, 75000)) |
| 179 | +ax.set_ylim((0, 100)) |
| 180 | + |
| 181 | +ax.plot((-1, -1), (-1, -1), color="tab:blue") |
| 182 | +ax.plot((-1, -1), (-1, -1), color="tab:red") |
| 183 | +ax.plot((-1, -1), (-1, -1), color="tab:orange") |
| 184 | + |
| 185 | +# ax.plot((-1, -1), (-1, -1), color="tab:purple") |
| 186 | + |
| 187 | +ax.plot((-1, -1), (-1, -1), "k", linestyle="-") |
| 188 | +ax.plot((-1, -1), (-1, -1), "k", linestyle=":") |
| 189 | + |
| 190 | +lines = ax.get_lines() |
| 191 | + |
| 192 | +legend1 = ax.legend( |
| 193 | + lines[-2:], ["Precision", "Recall"], frameon=False, loc="upper right" |
| 194 | +) |
| 195 | +legend2 = ax.legend( |
| 196 | + lines[-5:-2], |
| 197 | + # ["Wren (This Work)", "CGCNN Pre-relax", "CGCNN-D Pre-relax"], |
| 198 | + ["Wren (This Work)", "CGCNN Pre-relax", "Voronoi Pre-relax"], |
| 199 | + frameon=False, |
| 200 | + loc="lower right", |
| 201 | +) |
| 202 | + |
| 203 | +ax.add_artist(legend1) |
| 204 | +# plt.gca().add_artist(legend1) |
| 205 | + |
| 206 | +ax.set_aspect(1.0 / ax.get_data_ratio()) |
| 207 | + |
| 208 | + |
| 209 | +fig.tight_layout() |
| 210 | +plt.savefig(f"examples/manuscript/new_figs/vary-{e_type}-{crit}-{rare}.pdf") |
| 211 | +# plt.savefig(f"examples/manuscript/pdf/vary-{e_type}-{crit}-{rare}.png") |
| 212 | + |
| 213 | +plt.show() |
0 commit comments