|
6 | 6 | import pandas as pd
|
7 | 7 | from scipy.interpolate import interp1d
|
8 | 8 |
|
9 |
| -from ml_stability import ROOT |
| 9 | +from ml_stability import PKG_DIR, ROOT |
10 | 10 |
|
11 | 11 |
|
12 | 12 | __author__ = "Rhys Goodall, Janosh Riebesell"
|
|
34 | 34 | ["wren", "voronoi", "cgcnn"],
|
35 | 35 | ["tab:blue", "tab:orange", "tab:red"],
|
36 | 36 | ):
|
37 |
| - df = pd.read_csv( |
| 37 | + data_path = ( |
38 | 38 | f"{ROOT}/data/2022-06-11-from-rhys/{model_name}-mp-initial-structures.csv"
|
39 | 39 | )
|
40 |
| - df = df.set_index("material_id") |
| 40 | + df = pd.read_csv(data_path).set_index("material_id") |
41 | 41 |
|
42 | 42 | df["e_above_hull"] = df_hull.e_above_hull
|
43 | 43 |
|
44 | 44 | df = df.dropna(subset=["e_above_hull"])
|
45 | 45 |
|
46 | 46 | rare = "all"
|
47 | 47 |
|
48 |
| - # rare = "nla" |
49 |
| - # df = df[ |
50 |
| - # ~df["composition"].apply( |
51 |
| - # lambda x: any(el.is_rare_earth_metal for el in Composition(x).elements) |
52 |
| - # ) |
53 |
| - # ] |
| 48 | + # from pymatgen.core import Composition |
| 49 | + # rare = "no-lanthanides" |
| 50 | + # df["contains_rare_earths"] = df.composition.map( |
| 51 | + # lambda x: any(el.is_rare_earth_metal for el in Composition(x)) |
| 52 | + # ) |
| 53 | + # df = df.query("~contains_rare_earths") |
54 | 54 |
|
55 | 55 | e_above_hull = df.e_above_hull.to_numpy().ravel()
|
56 | 56 |
|
|
62 | 62 |
|
63 | 63 | # epistemic_std = np.var(pred, axis=0, ddof=0)
|
64 | 64 |
|
65 |
| - # aleatoric_std = np.mean(np.square(df.filter(like="ale")), axis=0) |
| 65 | + aleatoric_std = (df.filter(like="ale") ** 2).mean(axis=0) ** 0.5 |
66 | 66 |
|
67 | 67 | # full_std = np.sqrt(epistemic_std + aleatoric_std)
|
68 | 68 |
|
|
73 | 73 | # test = mean - full_std
|
74 | 74 |
|
75 | 75 | crit = "ene"
|
76 |
| - test = mean |
77 | 76 |
|
78 | 77 | bins = 200
|
79 | 78 | # xlim = (-0.2, 0.2)
|
|
86 | 85 | xticks = (-0.4, -0.2, 0, 0.2, 0.4)
|
87 | 86 | # yticks = (0, 300, 600, 900, 1200)
|
88 | 87 |
|
89 |
| - tp = len(e_above_hull[(e_above_hull <= thresh) & (test <= thresh)]) |
90 |
| - fn = len(e_above_hull[(e_above_hull <= thresh) & (test > thresh)]) |
| 88 | + tp = len(e_above_hull[(e_above_hull <= thresh) & (mean <= thresh)]) |
| 89 | + fn = len(e_above_hull[(e_above_hull <= thresh) & (mean > thresh)]) |
91 | 90 |
|
92 | 91 | pos = tp + fn
|
93 | 92 |
|
94 |
| - sort = np.argsort(test) |
| 93 | + sort = np.argsort(mean) |
95 | 94 | e_above_hull = e_above_hull[sort]
|
96 |
| - test = test[sort] |
| 95 | + mean = mean[sort] |
97 | 96 |
|
98 | 97 | e_type = "pred"
|
99 |
| - tp = np.asarray((e_above_hull <= thresh) & (test <= thresh)) |
100 |
| - fn = np.asarray((e_above_hull <= thresh) & (test > thresh)) |
101 |
| - fp = np.asarray((e_above_hull > thresh) & (test <= thresh)) |
102 |
| - tn = np.asarray((e_above_hull > thresh) & (test > thresh)) |
| 98 | + tp = np.asarray((e_above_hull <= thresh) & (mean <= thresh)) |
| 99 | + fn = np.asarray((e_above_hull <= thresh) & (mean > thresh)) |
| 100 | + fp = np.asarray((e_above_hull > thresh) & (mean <= thresh)) |
| 101 | + tn = np.asarray((e_above_hull > thresh) & (mean > thresh)) |
103 | 102 | xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
|
104 | 103 |
|
105 | 104 | c_tp = np.cumsum(tp)
|
|
114 | 113 |
|
115 | 114 | x = np.arange(len(ppv))[:end]
|
116 | 115 |
|
117 |
| - f_ppv = interp1d(x, ppv[:end], kind="cubic") |
118 |
| - f_tpr = interp1d(x, tpr[:end], kind="cubic") |
| 116 | + precision_curve = interp1d(x, ppv[:end], kind="cubic") |
| 117 | + rolling_recall_curve = interp1d(x, tpr[:end], kind="cubic") |
119 | 118 |
|
120 | 119 | line_kwargs = dict(
|
121 | 120 | linewidth=2, color=color, markevery=[-1], marker="x", markersize=14, mew=2.5
|
122 | 121 | )
|
123 |
| - ax.plot(x[::100], f_tpr(x[::100]), linestyle=":", **line_kwargs) |
124 |
| - |
125 |
| - ax.plot(x[::100], f_ppv(x[::100]), linestyle="-", **line_kwargs) |
| 122 | + ax.plot(x[::100], precision_curve(x[::100]), linestyle="-", **line_kwargs) |
| 123 | + ax.plot(x[::100], rolling_recall_curve(x[::100]), linestyle=":", **line_kwargs) |
126 | 124 |
|
127 | 125 |
|
128 | 126 | ax.set(xlabel="Number of Calculations", ylabel="Percentage")
|
129 | 127 |
|
130 |
| -ax.set(xlim=(0, 8e4), ylim=(0, 100), xticks=(0, 2e4, 4e4, 6e4, 8e4)) |
131 |
| - |
132 |
| -ax.plot((-1, -1), (-1, -1), color="tab:blue") |
133 |
| -ax.plot((-1, -1), (-1, -1), color="tab:red") |
134 |
| -ax.plot((-1, -1), (-1, -1), color="tab:orange") |
135 |
| - |
136 |
| -# ax.plot((-1, -1), (-1, -1), color="tab:purple") |
137 |
| - |
138 |
| -ax.plot((-1, -1), (-1, -1), "k", linestyle="-") |
139 |
| -ax.plot((-1, -1), (-1, -1), "k", linestyle=":") |
140 |
| - |
141 |
| -lines = ax.get_lines() |
142 |
| - |
143 |
| -legend1 = ax.legend( |
144 |
| - lines[-2:], ["Precision", "Recall"], frameon=False, loc="upper right" |
145 |
| -) |
146 |
| -legend2 = ax.legend( |
147 |
| - lines[-5:-2], |
148 |
| - # ["Wren (This Work)", "CGCNN Pre-relax", "CGCNN-D Pre-relax"], |
149 |
| - ["Wren (This Work)", "CGCNN Pre-relax", "Voronoi Pre-relax"], |
150 |
| - frameon=False, |
151 |
| - loc="lower right", |
152 |
| -) |
153 |
| - |
154 |
| -ax.add_artist(legend1) |
155 |
| - |
156 |
| -ax.set_aspect(1.0 / ax.get_data_ratio()) |
| 128 | +xlim = (0, 8e4) |
| 129 | +ax.set(xlim=xlim, ylim=(0, 100), xticks=np.linspace(*xlim, 5)) |
157 | 130 |
|
| 131 | +ax.plot((0, 0), (0, 0), color="tab:blue", label="Wren (This Work)") |
| 132 | +ax.plot((0, 0), (0, 0), color="tab:red", label="CGCNN Pre-relax") |
| 133 | +ax.plot((0, 0), (0, 0), color="tab:orange", label="Voronoi Pre-relax") |
| 134 | +legend_1 = ax.legend(frameon=False, loc="lower right") |
| 135 | +ax.add_artist(legend_1) |
158 | 136 |
|
159 |
| -# plt.savefig(f"{PKG_DIR}/plots/{today}-vary-{e_type}-{crit}-{rare}.pdf") |
160 |
| -# # plt.savefig(f"{PKG_DIR}/plots/{today}-vary-{e_type}-{crit}-{rare}.png") |
| 137 | +[prec] = ax.plot((0, 0), (0, 0), "black", linestyle="-") |
| 138 | +[recall] = ax.plot((0, 0), (0, 0), "black", linestyle=":") |
| 139 | +ax.legend([prec, recall], ["Precision", "Recall"], frameon=False, loc="upper right") |
161 | 140 |
|
162 |
| -# plt.show() |
| 141 | +img_path = f"{PKG_DIR}/plots/{today}-vary-{e_type=}-{crit=}-{rare=}.pdf" |
| 142 | +# plt.savefig(img_path) |
0 commit comments