Skip to content

Commit e61c8c7

Browse files
committed
convert old data and plot file paths from rhys to new repo
partly clean up commented out and verbose code
1 parent a111222 commit e61c8c7

File tree

2 files changed

+47
-92
lines changed

2 files changed

+47
-92
lines changed

ml_stability/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import os
12
from os.path import dirname
23

34

45
PKG_DIR = dirname(__file__)
56
ROOT = dirname(PKG_DIR)
7+
8+
os.makedirs(f"{PKG_DIR}/plots", exist_ok=True)

ml_stability/hist_clf_vary.py

+44-92
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,43 @@
11
# %%
2+
from datetime import datetime
3+
24
import matplotlib.pyplot as plt
35
import numpy as np
46
import pandas as pd
57
from scipy.interpolate import interp1d
68

9+
from ml_stability import ROOT
10+
11+
12+
__author__ = "Rhys Goodall, Janosh Riebesell"
13+
__date__ = "2022-06-18"
714

8-
plt.rcParams.update({"font.size": 20})
15+
today = f"{datetime.now():%Y-%m-%d}"
916

10-
plt.rcParams["axes.linewidth"] = 2.5
11-
plt.rcParams["lines.linewidth"] = 3.5
12-
plt.rcParams["xtick.major.size"] = 7
13-
plt.rcParams["xtick.major.width"] = 2.5
14-
plt.rcParams["xtick.minor.size"] = 5
15-
plt.rcParams["xtick.minor.width"] = 2.5
16-
plt.rcParams["ytick.major.size"] = 7
17-
plt.rcParams["ytick.major.width"] = 2.5
18-
plt.rcParams["ytick.minor.size"] = 5
19-
plt.rcParams["ytick.minor.width"] = 2.5
20-
plt.rcParams["legend.fontsize"] = 20
17+
plt.rc("font", size=18)
18+
plt.rc("savefig", bbox="tight", dpi=200)
19+
plt.rcParams["figure.constrained_layout.use"] = True
20+
plt.rc("figure", dpi=150, titlesize=20)
2121

22+
23+
# %%
2224
fig, ax = plt.subplots(1, 1, figsize=(10, 9))
2325

24-
df_hull = pd.read_csv(
25-
f"/home/reag2/PhD/aviary/examples/manuscript/new_figs/wbm_e_above_mp.csv",
26-
comment="#",
27-
na_filter=False,
28-
)
26+
df_hull = pd.read_csv(f"{ROOT}/data/wbm_e_above_mp.csv")
2927

30-
e_hull_dict = dict(zip(df_hull.material_id, df_hull.E_above_hull))
28+
e_hull_dict = dict(zip(df_hull.material_id, df_hull.e_above_hull))
3129

32-
for name, c, a in zip(
30+
for model_name, color in zip(
3331
# ["wren", "cgcnn", "cgcnn-d"],
3432
# ["tab:blue", "tab:red", "tab:purple"],
35-
["wren", "voro", "cgcnn"],
33+
["wren", "voronoi", "cgcnn"],
3634
["tab:blue", "tab:orange", "tab:red"],
37-
[1, 0.8, 0.8],
38-
# ["wren", "cgcnn"],
39-
# ["tab:blue", "tab:red"],
40-
# [1, 0.8],
4135
):
42-
df = pd.read_csv(
43-
f"/home/reag2/PhD/aviary/examples/manuscript/new_figs/{name}-mp-init.csv",
44-
comment="#",
45-
na_filter=False,
46-
)
47-
48-
df["E_hull"] = pd.to_numeric(df["material_id"].map(e_hull_dict))
36+
df = pd.read_csv(f"{ROOT}/data/{model_name}-mp-initial-structures.csv")
4937

50-
df = df.dropna(axis=0, subset=["E_hull"])
38+
df["e_above_hull"] = pd.to_numeric(df["material_id"].map(e_hull_dict))
5139

52-
init = len(df)
40+
df = df.dropna(axis=0, subset=["e_above_hull"])
5341

5442
rare = "all"
5543

@@ -60,37 +48,25 @@
6048
# )
6149
# ]
6250

63-
# print(1-len(df)/init)
64-
65-
tar = df["E_hull"].to_numpy().ravel()
51+
e_above_hull = df.e_above_hull.to_numpy().ravel()
6652

67-
print(len(tar))
68-
69-
tar_cols = [col for col in df.columns if "target" in col]
7053
# tar = df[tar_cols].to_numpy().ravel() - e_hull
71-
tar_f = df[tar_cols].to_numpy().ravel()
54+
tar_f = df.filter(like="target").to_numpy().ravel()
7255

73-
pred_cols = [col for col in df.columns if "pred" in col]
74-
pred = df[pred_cols].to_numpy().T
7556
# mean = np.average(pred, axis=0) - e_hull
76-
mean = np.average(pred, axis=0) - tar_f + tar
57+
mean = df.filter(like="pred").T.mean(axis=0) - tar_f + e_above_hull
7758

78-
epi = np.var(pred, axis=0, ddof=0)
59+
# epistemic_std = np.var(pred, axis=0, ddof=0)
7960

80-
ale_cols = [col for col in df.columns if "ale" in col]
81-
if len(ale_cols) > 0:
82-
ales = df[ale_cols].to_numpy().T
83-
ale = np.mean(np.square(ales), axis=0)
84-
else:
85-
ale = 0
61+
# aleatoric_std = np.mean(np.square(df.filter(like="ale")), axis=0)
8662

87-
both = np.sqrt(epi + ale)
63+
# full_std = np.sqrt(epistemic_std + aleatoric_std)
8864

8965
# crit = "std"
90-
# test = mean + both
66+
# test = mean + full_std
9167

9268
# crit = "neg"
93-
# test = mean - both
69+
# test = mean - full_std
9470

9571
crit = "ene"
9672
test = mean
@@ -100,33 +76,27 @@
10076
xlim = (-0.4, 0.4)
10177
# xlim = (-1, 1)
10278

103-
alpha = 0.5
10479
# thresh = 0.02
10580
thresh = 0.00
10681
# thresh = 0.10
10782
xticks = (-0.4, -0.2, 0, 0.2, 0.4)
10883
# yticks = (0, 300, 600, 900, 1200)
10984

110-
tp = len(tar[(tar <= thresh) & (test <= thresh)])
111-
fn = len(tar[(tar <= thresh) & (test > thresh)])
85+
tp = len(e_above_hull[(e_above_hull <= thresh) & (test <= thresh)])
86+
fn = len(e_above_hull[(e_above_hull <= thresh) & (test > thresh)])
11287

11388
pos = tp + fn
114-
null = pos / len(tar)
11589

11690
sort = np.argsort(test)
117-
tar = tar[sort]
91+
e_above_hull = e_above_hull[sort]
11892
test = test[sort]
11993

12094
e_type = "pred"
121-
tp = np.asarray((tar <= thresh) & (test <= thresh))
122-
fn = np.asarray((tar <= thresh) & (test > thresh))
123-
fp = np.asarray((tar > thresh) & (test <= thresh))
124-
tn = np.asarray((tar > thresh) & (test > thresh))
125-
xlabel = (
126-
r"$\Delta$" + r"$\it{E}$" + r"$_{Hull-Pred}$" + " / eV per atom"
127-
) # r"$\/(\frac{eV}{atom})$"
128-
129-
# %%
95+
tp = np.asarray((e_above_hull <= thresh) & (test <= thresh))
96+
fn = np.asarray((e_above_hull <= thresh) & (test > thresh))
97+
fp = np.asarray((e_above_hull > thresh) & (test <= thresh))
98+
tn = np.asarray((e_above_hull > thresh) & (test > thresh))
99+
xlabel = r"$\Delta E_{Hull-Pred}$ / eV per atom"
130100

131101
c_tp = np.cumsum(tp)
132102
c_fn = np.cumsum(fn)
@@ -143,29 +113,12 @@
143113
f_ppv = interp1d(x, ppv[:end], kind="cubic")
144114
f_tpr = interp1d(x, tpr[:end], kind="cubic")
145115

146-
ax.plot(
147-
x[::100],
148-
f_tpr(x[::100]),
149-
linestyle=":",
150-
color=c,
151-
alpha=a,
152-
markevery=[-1],
153-
marker="x",
154-
markersize=14,
155-
mew=2.5,
116+
line_kwargs = dict(
117+
linewidth=2, color=color, markevery=[-1], marker="x", markersize=14, mew=2.5
156118
)
119+
ax.plot(x[::100], f_tpr(x[::100]), linestyle=":", **line_kwargs)
157120

158-
ax.plot(
159-
x[::100],
160-
f_ppv(x[::100]),
161-
linestyle="-",
162-
color=c,
163-
alpha=a,
164-
markevery=[-1],
165-
marker="x",
166-
markersize=14,
167-
mew=2.5,
168-
)
121+
ax.plot(x[::100], f_ppv(x[::100]), linestyle="-", **line_kwargs)
169122

170123

171124
# ax.set_xticks((0, 2.5e4, 5e4, 7.5e4))
@@ -206,8 +159,7 @@
206159
ax.set_aspect(1.0 / ax.get_data_ratio())
207160

208161

209-
fig.tight_layout()
210-
plt.savefig(f"examples/manuscript/new_figs/vary-{e_type}-{crit}-{rare}.pdf")
211-
# plt.savefig(f"examples/manuscript/pdf/vary-{e_type}-{crit}-{rare}.png")
162+
# plt.savefig(f"{PKG_DIR}/plots/{today}-vary-{e_type}-{crit}-{rare}.pdf")
163+
# # plt.savefig(f"{PKG_DIR}/plots/{today}-vary-{e_type}-{crit}-{rare}.png")
212164

213-
plt.show()
165+
# plt.show()

0 commit comments

Comments
 (0)