Commit a42472c

fix scripts/prc_roc_curves_models.py and roc-models.svelte fig

something wrong in stable_metrics() maybe?

1 parent 5d98946

10 files changed: +119 -100 lines

data/mp/get_mp_energies.py (+5 -5)

@@ -58,7 +58,7 @@
 df["wyckoff_spglib"] = [get_aflow_label_from_spglib(x) for x in tqdm(df.structure)]
 
 df.reset_index().to_json(
-    f"{module_dir}/{today}-mp-energies.json.gz", default_handler=as_dict_handler
+    f"{module_dir}/mp-energies.json.gz", default_handler=as_dict_handler
 )
 
 # df = pd.read_json(f"{module_dir}/2022-08-13-mp-energies.json.gz")
@@ -78,9 +78,9 @@
 )
 
 annotate_mae_r2(df.formation_energy_per_atom, df.decomposition_enthalpy)
-# result on 2023-01-10: plots match. no correlation between formation energy and decomposition
-# enthalpy. R^2 = -1.571, MAE = 1.604
-# ax.figure.savefig(f"{module_dir}/{today}-mp-decomp-enth-vs-e-form.webp", dpi=300)
+# result on 2023-01-10: plots match. no correlation between formation energy and
+# decomposition enthalpy. R^2 = -1.571, MAE = 1.604
+# ax.figure.savefig(f"{module_dir}/mp-decomp-enth-vs-e-form.webp", dpi=300)
 
 
 # %% scatter plot energy above convex hull vs decomposition enthalpy
@@ -99,4 +99,4 @@
     title=f"{n_above_line:,} / {len(df):,} = {n_above_line/len(df):.1%} "
     "MP materials with\nenergy_above_hull - decomposition_enthalpy.clip(0) > 0.1"
 )
-# ax.figure.savefig(f"{module_dir}/{today}-mp-e-above-hull-vs-decomp-enth.webp", dpi=300)
+# ax.figure.savefig(f"{module_dir}/mp-e-above-hull-vs-decomp-enth.webp", dpi=300)

matbench_discovery/plots.py (-7)

@@ -699,13 +699,6 @@ def cumulative_precision_recall(
         facet_col="metric",
         facet_col_wrap=2,
         facet_col_spacing=0.03,
-        # pivot df in case we want to show all 3 metrics in each plot's hover tooltip
-        # requires fixing index mismatch due to df sub-sampling above
-        # customdata=dict(
-        #     df_cum.reset_index()
-        #     .pivot(index="index", columns="metric")
-        #     .items()
-        # ),
         **kwargs,
     )
 

matbench_discovery/preds.py (+5)

@@ -37,3 +37,8 @@
 df_each_pred = pd.DataFrame()
 for model in df_metrics.T.MAE.sort_values().index:
     df_each_pred[model] = df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col]
+
+
+df_each_err = pd.DataFrame()
+for model in df_metrics.T.MAE.sort_values().index:
+    df_each_err[model] = df_wbm[model] - df_wbm[e_form_col]
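
For intuition, here is a minimal sketch of what these two derived frames hold, using toy data and illustrative column names rather than the real WBM frames: df_each_err stores each model's formation-energy error, and df_each_pred shifts the DFT hull distance by that error to obtain a model-predicted distance to the convex hull.

import pandas as pd

# toy stand-ins for df_wbm columns (names are illustrative, not the repo's)
df_toy = pd.DataFrame(
    {
        "each_true": [0.02, -0.05, 0.31],  # DFT energy above hull (eV/atom)
        "e_form": [-1.2, -0.8, 0.1],  # DFT formation energy (eV/atom)
        "ToyModel": [-1.1, -0.9, 0.2],  # model-predicted formation energy
    }
)
err = df_toy["ToyModel"] - df_toy["e_form"]  # what df_each_err stores per model
each_pred = df_toy["each_true"] + err  # what df_each_pred stores per model
print(each_pred.round(2).tolist())  # [0.12, -0.15, 0.41]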

scripts/hist_classified_stable_vs_hull_dist_batches.py (+1 -1)

@@ -86,7 +86,7 @@
 for anno in fig.layout.annotations:
     if not anno.text.startswith("batch_idx="):
         continue
-    batch_idx = int(anno.text.split("=")[-1])
+    batch_idx = int(anno.text.split("=", 1)[-1])
     len_df = sum(df_wbm[batch_col] == int(batch_idx))
     anno.text = f"Batch {batch_idx} ({len_df:,})"
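
The split("=", 1) changes in this and the following scripts guard against values that themselves contain an equals sign, which a bare split would truncate. A quick sketch with a hypothetical annotation string (not taken from the repo):

text = "Model=model=wrenformer"  # hypothetical subplot title with an extra "="
print(text.split("=")[1])  # "model" — truncated at the second "="
print(text.split("=", 1)[1])  # "model=wrenformer" — maxsplit=1 keeps the full value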

scripts/hist_classified_stable_vs_hull_dist_models.py (+1 -1)

@@ -88,7 +88,7 @@
     ax.set(title=f"{model_name} · {F1=:.2f} · {FPR=:.2f} · {FNR=:.2f} · {DAF=:.2f}")
 else:
     for anno in fig.layout.annotations:
-        model_name = anno.text = anno.text.split("=").pop()
+        model_name = anno.text = anno.text.split("=", 1).pop()
         if model_name not in models or not show_metrics:
             continue
         F1, FPR, FNR, DAF = (

scripts/prc_roc_curves_models.py (+62 -43)

@@ -1,19 +1,12 @@
 # %%
-import numpy as np
 import pandas as pd
 from pymatviz.utils import save_fig
+from sklearn.metrics import auc, precision_recall_curve, roc_curve
 from tqdm import tqdm
 
 from matbench_discovery import FIGS
-from matbench_discovery.metrics import stable_metrics
 from matbench_discovery.plots import pio
-from matbench_discovery.preds import (
-    df_wbm,
-    e_form_col,
-    each_pred_col,
-    each_true_col,
-    models,
-)
+from matbench_discovery.preds import df_each_pred, df_wbm, each_true_col
 
 __author__ = "Janosh Riebesell"
 __date__ = "2023-01-30"
@@ -34,47 +27,49 @@
 # %%
 df_roc = pd.DataFrame()
 
-for model in (pbar := tqdm(models)):
-    pbar.set_description(model)
-    df_wbm[f"{model}_{each_pred_col}"] = df_wbm[each_true_col] + (
-        df_wbm[model] - df_wbm[e_form_col]
-    )
-    for stab_treshold in np.arange(-0.4, 0.4, 0.01):
-        metrics = stable_metrics(
-            df_wbm[each_true_col], df_wbm[f"{model}_{each_pred_col}"], stab_treshold
-        )
-        df_tmp = pd.DataFrame(
-            {facet_col: model, color_col: stab_treshold, **metrics}, index=[0]
-        )
-        df_roc = pd.concat([df_roc, df_tmp])
-
+for model in (pbar := tqdm(list(df_each_pred), desc="Calculating ROC curves")):
+    pbar.set_postfix_str(model)
+    na_mask = df_wbm[each_true_col].isna() | df_each_pred[model].isna()
+    y_true = (df_wbm[~na_mask][each_true_col] <= 0).astype(int)
+    y_pred = df_each_pred[model][~na_mask]
+    fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=0)
+    AUC = auc(fpr, tpr)
+    title = f"{model} · {AUC=:.2f}"
+    df_tmp = pd.DataFrame(
+        {"FPR": fpr, "TPR": tpr, color_col: thresholds, "AUC": AUC, facet_col: title}
+    ).round(3)
 
-df_roc = df_roc.round(3)
+    df_roc = pd.concat([df_roc, df_tmp])
 
 
 # %%
-fig = df_roc.plot.scatter(
-    x="FPR",
-    y="TPR",
-    facet_col=facet_col,
-    facet_col_wrap=2,
-    backend="plotly",
-    height=800,
-    color=color_col,
-    range_x=(0, 1),
-    range_y=(0, 1),
+fig = (
+    df_roc.iloc[:: len(df_roc) // 500 or 1]
+    .sort_values(["AUC", "FPR"], ascending=False)
+    .plot.scatter(
+        x="FPR",
+        y="TPR",
+        facet_col=facet_col,
+        facet_col_wrap=2,
+        backend="plotly",
+        height=150 * len(df_roc[facet_col].unique()),
+        color=color_col,
+        range_x=(0, 1),
+        range_y=(0, 1),
+        range_color=(-0.5, 0.5),
+        hover_name=facet_col,
+        hover_data={facet_col: False},
+    )
 )
 
 for anno in fig.layout.annotations:
-    anno.text = anno.text.split("=")[1]  # remove Model= from subplot titles
+    anno.text = anno.text.split("=", 1)[1]  # remove Model= from subplot titles
 
 fig.layout.coloraxis.colorbar.update(
-    x=1, y=1, xanchor="right", yanchor="top", thickness=14, len=0.27, title_side="right"
+    x=1, y=1, xanchor="right", yanchor="top", thickness=14, len=0.2, title_side="right"
 )
 fig.add_shape(type="line", x0=0, y0=0, x1=1, y1=1, line=line, row="all", col="all")
-fig.add_annotation(
-    text="No skill", x=0.5, y=0.5, showarrow=False, yshift=-10, textangle=-30
-)
+fig.add_annotation(text="No skill", x=0.5, y=0.5, showarrow=False, yshift=-10)
 # allow scrolling and zooming each subplot individually
 fig.update_xaxes(matches=None)
 fig.update_yaxes(matches=None)
@@ -86,20 +81,44 @@
 
 
 # %%
-fig = df_roc.plot.scatter(
+df_prc = pd.DataFrame()
+
+for model in (pbar := tqdm(list(df_each_pred), desc="Calculating ROC curves")):
+    pbar.set_postfix_str(model)
+    na_mask = df_wbm[each_true_col].isna() | df_each_pred[model].isna()
+    y_true = (df_wbm[~na_mask][each_true_col] <= 0).astype(int)
+    y_pred = df_each_pred[model][~na_mask]
+    prec, recall, thresholds = precision_recall_curve(y_true, y_pred, pos_label=0)
+    df_tmp = pd.DataFrame(
+        {
+            "Precision": prec[:-1],
+            "Recall": recall[:-1],
+            color_col: thresholds,
+            facet_col: model,
+        }
+    ).round(3)
+
+    df_prc = pd.concat([df_prc, df_tmp])
+
+
+# %%
+fig = df_prc.iloc[:: len(df_roc) // 500 or 1].plot.scatter(
     x="Recall",
     y="Precision",
     facet_col=facet_col,
     facet_col_wrap=2,
     backend="plotly",
-    height=800,
+    height=150 * len(df_roc[facet_col].unique()),
     color=color_col,
     range_x=(0, 1),
-    range_y=(0, 1),
+    range_y=(0.5, 1),
+    range_color=(-0.5, 1),
+    hover_name=facet_col,
+    hover_data={facet_col: False},
 )
 
 for anno in fig.layout.annotations:
-    anno.text = anno.text.split("=")[1]  # remove Model= from subplot titles
+    anno.text = anno.text.split("=", 1)[1]  # remove Model= from subplot titles
 
 fig.layout.coloraxis.colorbar.update(
     x=0.5, y=1.1, thickness=14, len=0.4, orientation="h"
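
A sketch of the two sklearn calls this script now leans on, run on made-up labels and scores. Here y_true is 1 for stable (DFT E above hull ≤ 0) and the "score" is a predicted energy where lower means more stable, so pos_label=0 tells sklearn to treat the unstable class as positive — the class for which a higher score (energy) means greater confidence:

import numpy as np
from sklearn.metrics import auc, precision_recall_curve, roc_curve

y_true = np.array([1, 1, 0, 0, 1])  # 1 = stable, 0 = unstable (toy labels)
y_pred = np.array([-0.2, 0.05, 0.3, 0.1, -0.4])  # toy predicted hull distances

# higher predicted energy = more likely unstable, hence pos_label=0
fpr, tpr, roc_thresholds = roc_curve(y_true, y_pred, pos_label=0)
print(f"{auc(fpr, tpr)=:.2f}")

prec, recall, prc_thresholds = precision_recall_curve(y_true, y_pred, pos_label=0)
# precision/recall carry one more point than thresholds (a final (1, 0) endpoint),
# which is why the script slices prec[:-1] and recall[:-1] to align column lengths
assert len(prec) == len(recall) == len(prc_thresholds) + 1

The df.iloc[:: len(df) // 500 or 1] stride seen above down-samples each curve to roughly 500 evenly spaced points before plotting; the "or 1" guards against a zero step when the frame holds fewer than 500 rows.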

scripts/scatter_e_above_hull_models.py (+34 -34)

@@ -127,7 +127,7 @@
     traces = [t for t in fig.data if t.xaxis == f"x{idx if idx > 1 else ''}"]
     assert len(traces) == 4, f"Expected 4 traces, got {len(traces)=}"
 
-    model = anno.text.split("=")[1]
+    model = anno.text.split("=", 1)[1]
     assert model in df_wbm, f"Unexpected {model=} not in {list(df_wbm)=}"
     # add MAE and R2 to subplot titles
    MAE, R2 = df_metrics[model][["MAE", "R2"]]
@@ -139,39 +139,39 @@
     fig.layout[f"xaxis{idx}"].title.text = ""
     fig.layout[f"yaxis{idx}"].title.text = ""
 
-    # add transparent rectangle with TN, TP, FN, FP labels in each quadrant
-    for sign_x, sign_y, color, label in zip(
-        [-1, -1, 1, 1], [-1, 1, -1, 1], clf_colors, ("TP", "FN", "FP", "TN")
-    ):
-        # instead of coloring points in each quadrant, we can add a transparent
-        # background to each quadrant (looks worse maybe than coloring points)
-        # fig.add_shape(
-        #     type="rect",
-        #     x0=0,
-        #     y0=0,
-        #     x1=sign_x * 100,
-        #     y1=sign_y * 100,
-        #     fillcolor=color,
-        #     opacity=0.5,
-        #     layer="below",
-        #     xref=f"x{idx}",
-        #     yref=f"y{idx}",
-        # )
-        fig.add_annotation(
-            xref=f"x{idx}",
-            yref=f"y{idx}",
-            x=sign_x * xy_max,
-            y=sign_y * xy_max,
-            xshift=-20 * sign_x,
-            yshift=-20 * sign_y,
-            text=label,
-            showarrow=False,
-            font=dict(size=16, color=color),
-        )
-
-    # add dashed quadrant separators
-    fig.add_vline(x=0, line=dict(width=0.5, dash="dash"))
-    fig.add_hline(y=0, line=dict(width=0.5, dash="dash"))
+# add transparent rectangle with TN, TP, FN, FP labels in each quadrant
+for sign_x, sign_y, color, label in zip(
+    [-1, -1, 1, 1], [-1, 1, -1, 1], clf_colors, ("TP", "FN", "FP", "TN")
+):
+    # instead of coloring points in each quadrant, we can add a transparent
+    # background to each quadrant (looks worse maybe than coloring points)
+    # fig.add_shape(
+    #     type="rect",
+    #     x0=0,
+    #     y0=0,
+    #     x1=sign_x * 100,
+    #     y1=sign_y * 100,
+    #     fillcolor=color,
+    #     opacity=0.2,
+    #     layer="below",
+    #     row="all",
+    #     col="all",
+    # )
+    fig.add_annotation(
+        x=sign_x * xy_max,
+        y=sign_y * xy_max,
+        xshift=-20 * sign_x,
+        yshift=-20 * sign_y,
+        text=label,
+        showarrow=False,
+        font=dict(size=16, color=color),
+        row="all",
+        col="all",
+    )
+
+# add dashed quadrant separators
+fig.add_vline(x=0, line=dict(width=0.5, dash="dash"))
+fig.add_hline(y=0, line=dict(width=0.5, dash="dash"))
 
 fig.update_xaxes(nticks=5)
 fig.update_yaxes(nticks=5)
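
This hunk swaps per-subplot xref/yref targeting for plotly's row="all", col="all" broadcasting, which stamps one annotation onto every facet in a single call and lets the block move out of the per-subplot loop. A minimal sketch on a toy faceted figure — the iris example data and coordinates are purely illustrative:

import plotly.express as px

fig = px.scatter(px.data.iris(), x="sepal_width", y="sepal_length", facet_col="species")
# one call stamps the label onto all three facets instead of looping over axis ids
fig.add_annotation(x=3, y=6, text="label", showarrow=False, row="all", col="all")
fig.add_vline(x=3, line=dict(width=0.5, dash="dash"), row="all", col="all")
fig.show()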

site/src/routes/about-the-test-set/tmi/+page.svelte (+1 -2)

@@ -40,13 +40,12 @@
 
 Stuff that didn't make the cut into the main page describing the WBM test set.
 
-<ColorScaleSelect bind:selected />
-
 <h2>WBM Element Counts for <code>{filter}</code></h2>
 
 Filter WBM element counts by composition arity (how many elements in the formula) or batch
 index (which iteration of elemental substitution the structure was generated in).
 
+<ColorScaleSelect bind:selected />
 <ul>
   <li>
     composition arity:

site/src/routes/si/+page.md (+1 -1)

@@ -16,7 +16,7 @@
 <RocModels />
 {/if}
 
-> @label:fig:roc-models Receiver operating characteristic (ROC) curve for each model. TPR/FPR is the true and false positive rates. Points are colored by their stability threshold. A material is classified as stable if E<sub>above hull</sub> is less than the stability threshold. Since all models predict E<sub>form</sub> (and M3GNet predicted energies are converted to formation energy before stability classification), they are insensitive to changes in the threshold.
+> @label:fig:roc-models Receiver operating characteristic (ROC) curve for each model. TPR/FPR are the true/false positive rates: FPR on the $x$-axis is the fraction of unstable structures classified as stable, while TPR on the $y$-axis is the fraction of stable structures classified as stable. Points are colored by the stability threshold $t$, which sweeps the range $-0.4 \ \frac{\text{eV}}{\text{atom}} \leq t \leq 0.4 \ \frac{\text{eV}}{\text{atom}}$ above the hull. A material is classified as stable if its predicted E<sub>above hull</sub> is less than the stability threshold. Since all models predict E<sub>form</sub> (and M3GNet predicted energies are converted to formation energy before stability classification), they are insensitive to changes in the threshold $t$. M3GNet wins on area under curve (AUC) with 0.87, 34% higher than the worst model, Voronoi Random Forest. The diagonal 'No skill' line shows the performance of a dummy model that ranks materials randomly.
 
 ## Model Run Times
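
For reference, the two rates in the caption derive from confusion-matrix counts (TP/FP/TN/FN = true/false positives/negatives):

$$\text{TPR} = \frac{\text{TP}}{\text{TP} + \text{FN}} \qquad \text{FPR} = \frac{\text{FP}}{\text{FP} + \text{TN}}$$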

tests/test_preds.py (+9 -6)

@@ -1,5 +1,6 @@
 from matbench_discovery.data import PRED_FILENAMES
 from matbench_discovery.preds import (
+    df_each_err,
     df_each_pred,
     df_metrics,
     df_wbm,
@@ -23,9 +24,11 @@ def test_df_metrics() -> None:
 
 def test_df_each_pred() -> None:
     assert len(df_each_pred) == len(df_wbm)
-    assert (
-        {*df_each_pred} == {*df_metrics} < {*df_wbm}
-    ), "df_each_pred has wrong columns"
-    assert all(
-        df_each_pred.isna().sum() / len(df_each_pred) < 0.05
-    ), "too many NaNs in df_each_pred"
+    assert {*df_each_pred} == {*df_metrics}, "df_each_pred has wrong columns"
+    assert all(df_each_pred.isna().mean() < 0.05), "too many NaNs in df_each_pred"
+
+
+def test_df_each_err() -> None:
+    assert len(df_each_err) == len(df_wbm)
+    assert {*df_each_err} == {*df_metrics}, "df_each_err has wrong columns"
+    assert all(df_each_err.isna().mean() < 0.05), "too many NaNs in df_each_err"
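
The simplified assertions work because DataFrame.isna().mean() is the per-column NaN fraction, identical to isna().sum() / len(df). A quick check on toy data:

import pandas as pd

df = pd.DataFrame({"a": [1.0, None, 3.0, None], "b": [1.0, 2.0, None, 4.0]})
assert (df.isna().mean() == df.isna().sum() / len(df)).all()
print(df.isna().mean().to_dict())  # {'a': 0.5, 'b': 0.25}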
