Skip to content

Commit a111222

Browse files
committed
initial commit
0 parents  commit a111222

File tree

5 files changed

+313
-0
lines changed

5 files changed

+313
-0
lines changed

.gitignore

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
# package install
*.egg-info

# cache
__pycache__
# NOTE(review): mypy's default cache directory is `.mypy_cache`; confirm that
# the bare `.mypy` pattern below is what was intended.
.mypy
.db_cache

# Fireworks config
/FW_config.yaml

# Atomate workflow output
fw_logs
block*
launcher*

# datasets
*.json.gz
*.json.bz2
*.csv.bz2

# data files
*.pkl*

# deploy config
.netlify

.pre-commit-config.yaml

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
# pre-commit hook configuration for this repository.

# Settings for the hosted CI runner (presumably pre-commit.ci -- confirm).
ci:
  autoupdate_schedule: quarterly

# Hooks run at the commit stage unless a hook declares its own `stages`.
default_stages: [commit]

# `commit-msg` is installed in addition to `pre-commit` because the codespell
# hook below also declares the `commit-msg` stage.
default_install_hook_types: [pre-commit, commit-msg]

repos:
  - repo: https://github.com/PyCQA/isort
    rev: 5.10.1
    hooks:
      - id: isort

  - repo: https://github.com/psf/black
    rev: 22.3.0
    hooks:
      - id: black

  - repo: https://github.com/pycqa/flake8
    rev: 4.0.1
    hooks:
      - id: flake8

  - repo: https://github.com/asottile/pyupgrade
    rev: v2.31.1
    hooks:
      - id: pyupgrade
        # Upgrade syntax assuming Python 3.9 is the oldest supported version.
        args: [--py39-plus]

  - repo: https://github.com/janosh/format-ipy-cells
    rev: v0.1.10
    hooks:
      - id: format-ipy-cells

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.1.0
    hooks:
      - id: check-case-conflict
      - id: check-symlinks
      - id: check-yaml
      - id: destroyed-symlinks
      - id: end-of-file-fixer
      - id: mixed-line-ending
      - id: trailing-whitespace

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.942
    hooks:
      - id: mypy
        # Extra package installed into the mypy hook environment
        # (presumably the type stubs for PyYAML -- confirm).
        additional_dependencies: [types-pyyaml]

  - repo: https://github.com/codespell-project/codespell
    rev: v2.1.0
    hooks:
      - id: codespell
        # Overrides default_stages: runs on staged files and on commit messages.
        stages: [commit, commit-msg]
        # File types left out of the spell check.
        exclude_types: [csv, html, json]

  - repo: https://github.com/myint/autoflake
    rev: v1.4
    hooks:
      - id: autoflake
        args:
          - --in-place
          - --remove-unused-variables
          - --remove-all-unused-imports
          - --expand-star-imports
          - --ignore-init-module-imports

ml_stability/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
"""Filesystem locations derived from where this package is installed."""
from os.path import dirname


# Directory that contains this ``__init__.py`` (the package directory).
PKG_DIR = dirname(__file__)

# Directory one level above the package directory.
ROOT = dirname(dirname(__file__))

ml_stability/hist_clf_vary.py

+213
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
# %%
"""Plot cumulative precision and recall of hull-stability predictions.

For each model in ``MODELS`` the materials are ranked by the value in ``test``
(lowest first). A material counts as positive when its tabulated hull energy
is ``<= thresh`` and as predicted-positive when ``test <= thresh``. Cumulative
precision and recall (both in percent) are drawn against the number of
materials taken from the top of the ranking, and the figure is saved as PDF.
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d


# Directory holding the input CSV files.
# NOTE(review): absolute path on the original author's machine; adjust locally.
DATA_DIR = "/home/reag2/PhD/aviary/examples/manuscript/new_figs"

# (CSV file-name prefix, line colour, line alpha) for every model drawn.
# Previously used alternatives:
#   ("wren", "tab:blue", 1), ("cgcnn", "tab:red", 0.8), ("cgcnn-d", "tab:purple", 0.8)
#   ("wren", "tab:blue", 1), ("cgcnn", "tab:red", 0.8)
MODELS = (
    ("wren", "tab:blue", 1),
    ("voro", "tab:orange", 0.8),
    ("cgcnn", "tab:red", 0.8),
)

# Threshold (same units as the E_hull column) at or below which a material is
# counted as positive. Other values tried: 0.02, 0.10.
thresh = 0.00

# Quantity compared against ``thresh`` for the predictions:
#   "ene": mean prediction
#   "std": mean prediction + combined uncertainty
#   "neg": mean prediction - combined uncertainty
crit = "ene"

# Tags that only end up in the output file name.
rare = "all"
e_type = "pred"

plt.rcParams.update(
    {
        "font.size": 20,
        "axes.linewidth": 2.5,
        "lines.linewidth": 3.5,
        "xtick.major.size": 7,
        "xtick.major.width": 2.5,
        "xtick.minor.size": 5,
        "xtick.minor.width": 2.5,
        "ytick.major.size": 7,
        "ytick.major.width": 2.5,
        "ytick.minor.size": 5,
        "ytick.minor.width": 2.5,
        "legend.fontsize": 20,
    }
)

fig, ax = plt.subplots(1, 1, figsize=(10, 9))

df_hull = pd.read_csv(
    f"{DATA_DIR}/wbm_e_above_mp.csv",
    comment="#",
    na_filter=False,
)

# material_id -> energy above hull, used to attach E_hull to every model's rows.
e_hull_dict = dict(zip(df_hull.material_id, df_hull.E_above_hull))

# Shared by both curves of a model: a single "x" marker on the last point.
marker_kwargs = {"markevery": [-1], "marker": "x", "markersize": 14, "mew": 2.5}

for name, c, a in MODELS:
    df = pd.read_csv(
        f"{DATA_DIR}/{name}-mp-init.csv",
        comment="#",
        na_filter=False,
    )

    df["E_hull"] = pd.to_numeric(df["material_id"].map(e_hull_dict))

    # Rows whose material_id has no entry in e_hull_dict are dropped.
    df = df.dropna(axis=0, subset=["E_hull"])

    # Optional filter for rare = "nla". ``Composition`` is not imported in
    # this file, so it has to be imported before re-enabling this.
    # df = df[
    #     ~df["composition"].apply(
    #         lambda x: any(el.is_rare_earth_metal for el in Composition(x).elements)
    #     )
    # ]

    tar = df["E_hull"].to_numpy().ravel()

    print(len(tar))

    tar_cols = [col for col in df.columns if "target" in col]
    tar_f = df[tar_cols].to_numpy().ravel()

    pred_cols = [col for col in df.columns if "pred" in col]
    pred = df[pred_cols].to_numpy().T

    # Mean over all "pred" columns, moved onto the hull-energy scale by adding
    # the prediction error (mean - target) to the tabulated hull energy.
    mean = np.average(pred, axis=0) - tar_f + tar

    # Spread between the "pred" columns (presumably epistemic uncertainty).
    epi = np.var(pred, axis=0, ddof=0)

    # Mean square of the "ale" columns when the CSV has any
    # (presumably aleatoric uncertainty).
    ale_cols = [col for col in df.columns if "ale" in col]
    if len(ale_cols) > 0:
        ales = df[ale_cols].to_numpy().T
        ale = np.mean(np.square(ales), axis=0)
    else:
        ale = 0

    both = np.sqrt(epi + ale)

    if crit == "ene":
        test = mean
    elif crit == "std":
        test = mean + both
    elif crit == "neg":
        test = mean - both
    else:
        raise ValueError(f"unknown crit={crit!r}, expected 'ene', 'std' or 'neg'")

    # Total number of positives; the denominator of recall.
    pos = np.count_nonzero(tar <= thresh)

    # Rank materials from lowest to highest ``test`` value.
    order = np.argsort(test)
    tar = tar[order]
    test = test[order]

    tp = np.asarray((tar <= thresh) & (test <= thresh))
    fp = np.asarray((tar > thresh) & (test <= thresh))

    c_tp = np.cumsum(tp)
    c_fp = np.cumsum(fp)

    ppv = c_tp / (c_tp + c_fp) * 100  # precision in percent
    tpr = c_tp / pos * 100  # recall in percent

    # First index at which recall reaches its maximum; the curves stop there.
    end = np.argmax(tpr)

    x = np.arange(end)

    # NOTE(review): the splines are only evaluated at x[::100], which are all
    # sample points, so plain slicing of tpr/ppv would give the same values.
    f_ppv = interp1d(x, ppv[:end], kind="cubic")
    f_tpr = interp1d(x, tpr[:end], kind="cubic")

    # Recall is dotted, precision is solid; this must match the legend below.
    for curve, linestyle in ((f_tpr, ":"), (f_ppv, "-")):
        ax.plot(
            x[::100],
            curve(x[::100]),
            linestyle=linestyle,
            color=c,
            alpha=a,
            **marker_kwargs,
        )


# ax.set_xticks((0, 2.5e4, 5e4, 7.5e4))
ax.set_xticks((0, 2e4, 4e4, 6e4, 8e4))

ax.set_ylabel("Percentage")
ax.set_xlabel("Number of Calculations")

ax.set_xlim((0, 8e4))
# ax.set_xlim((0, 75000))
ax.set_ylim((0, 100))

# Proxy lines placed outside the visible axes range; they exist only to give
# the two legends handles. Colours and their order must match the labels of
# legend2 below.
ax.plot((-1, -1), (-1, -1), color="tab:blue")
ax.plot((-1, -1), (-1, -1), color="tab:red")
ax.plot((-1, -1), (-1, -1), color="tab:orange")

# ax.plot((-1, -1), (-1, -1), color="tab:purple")

ax.plot((-1, -1), (-1, -1), "k", linestyle="-")
ax.plot((-1, -1), (-1, -1), "k", linestyle=":")

lines = ax.get_lines()

# The last two proxies are the black solid / dotted line-style handles.
legend1 = ax.legend(
    lines[-2:], ["Precision", "Recall"], frameon=False, loc="upper right"
)
# The three proxies before them are the per-model colour handles.
legend2 = ax.legend(
    lines[-5:-2],
    # ["Wren (This Work)", "CGCNN Pre-relax", "CGCNN-D Pre-relax"],
    ["Wren (This Work)", "CGCNN Pre-relax", "Voronoi Pre-relax"],
    frameon=False,
    loc="lower right",
)

# The second ax.legend call replaces the first legend, so add it back by hand.
ax.add_artist(legend1)
# plt.gca().add_artist(legend1)

# Make the plotting area square.
ax.set_aspect(1.0 / ax.get_data_ratio())


fig.tight_layout()
plt.savefig(f"examples/manuscript/new_figs/vary-{e_type}-{crit}-{rare}.pdf")
# plt.savefig(f"examples/manuscript/pdf/vary-{e_type}-{crit}-{rare}.png")

plt.show()

readme.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# ML Stability

0 commit comments

Comments
 (0)