Skip to content

Commit 32a02d8

Browse files
committed
add scripts/roc_models.py
displayed on new route si/+page.md site/src/figs/2023-01-30-roc-models.svelte rename paper/+page.(svx->md)
1 parent f0363b4 commit 32a02d8

File tree

5 files changed

+149
-16
lines changed

5 files changed

+149
-16
lines changed

scripts/roc_models.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# %%
2+
import numpy as np
3+
import pandas as pd
4+
from pymatviz.utils import save_fig
5+
6+
from matbench_discovery import FIGS, today
7+
from matbench_discovery.data import load_df_wbm_preds
8+
from matbench_discovery.energy import stable_metrics
9+
from matbench_discovery.plots import pio
10+
11+
__author__ = "Janosh Riebesell"
12+
__date__ = "2023-01-30"
13+
14+
"""
15+
Histogram of the energy difference (either according to DFT ground truth [default] or
16+
model predicted energy) to the convex hull for materials in the WBM data set. The
17+
histogram stacks true/false positives/negatives with different colors.
18+
"""
19+
20+
pio.templates.default
21+
22+
23+
# %%
24+
models = sorted(
25+
"CGCNN, Voronoi Random Forest, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet".split(", ")
26+
)
27+
df_wbm = load_df_wbm_preds(models).round(3)
28+
29+
e_form_col = "e_form_per_atom_mp2020_corrected"
30+
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
31+
each_pred_col = "e_above_hull_pred"
32+
facet_col = "Model"
33+
color_col = "Stability Threshold"
34+
35+
36+
# %%
37+
df_roc = pd.DataFrame()
38+
39+
for model in models:
40+
df_wbm[f"{model}_{each_pred_col}"] = df_wbm[each_true_col] + (
41+
df_wbm[model] - df_wbm[e_form_col]
42+
)
43+
for stab_treshold in np.arange(-0.4, 0.4, 0.01):
44+
45+
metrics = stable_metrics(
46+
df_wbm[each_true_col], df_wbm[f"{model}_{each_pred_col}"], stab_treshold
47+
)
48+
df_tmp = pd.DataFrame(
49+
{facet_col: model, color_col: stab_treshold, **metrics}, index=[0]
50+
)
51+
df_roc = pd.concat([df_roc, df_tmp])
52+
53+
54+
df_roc = df_roc.round(3)
55+
56+
57+
# %%
58+
fig = df_roc.plot.scatter(
59+
x="FPR",
60+
y="TPR",
61+
facet_col=facet_col,
62+
facet_col_wrap=2,
63+
backend="plotly",
64+
height=800,
65+
color=color_col,
66+
range_x=(0, 1),
67+
range_y=(0, 1),
68+
)
69+
70+
for anno in fig.layout.annotations:
71+
anno.text = anno.text.split("=")[1] # remove Model= from subplot titles
72+
73+
fig.layout.coloraxis.colorbar.update(
74+
x=1,
75+
y=1,
76+
xanchor="right",
77+
yanchor="top",
78+
thickness=14,
79+
len=0.27,
80+
title_side="right",
81+
)
82+
fig.show()
83+
84+
85+
# %%
86+
save_fig(fig, f"{FIGS}/{today}-roc-models.svelte")

scripts/rolling_mae_vs_hull_dist_all_models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
MAEs[MAE] = model
3232

3333
# sort df columns by MAE (so that the legend is sorted too)
34-
for MAE, model in sorted(MAEs.items()):
34+
for MAE, model in sorted(MAEs.items(), reverse=True):
3535
df_wbm[f"{model} {MAE=:.2f}"] = df_wbm[e_form_col] - df_wbm[model]
3636

3737
fig, df_err, df_std = rolling_mae_vs_hull_dist(
@@ -40,6 +40,7 @@
4040
backend=backend,
4141
with_sem=False,
4242
# template="plotly_white",
43+
height=800,
4344
)
4445

4546

0 commit comments

Comments
 (0)