Skip to content

Commit 75fc095

Browse files
committed
use stable_metrics() from matbench_discovery.energy in scripts/compile_metrics.py and update site/src/routes/models/2023-01-23-model-stats.json
add FPR, FNR and DAF to /models page add tooltips to /models page sort-by-metric buttons update deps
1 parent 32a02d8 commit 75fc095

18 files changed

+134
-106
lines changed

.pre-commit-config.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
77

88
repos:
99
- repo: https://github.com/PyCQA/isort
10-
rev: 5.11.4
10+
rev: 5.12.0
1111
hooks:
1212
- id: isort
1313

@@ -59,7 +59,7 @@ repos:
5959
exclude: ^(.+references.yaml)$
6060

6161
- repo: https://github.com/PyCQA/autoflake
62-
rev: v2.0.0
62+
rev: v2.0.1
6363
hooks:
6464
- id: autoflake
6565

@@ -75,7 +75,7 @@ repos:
7575
exclude: ^(site/src/figs/.+\.svelte|data/wbm/20.+\..+|site/src/routes/.+\.(yml|yaml|json))$
7676

7777
- repo: https://github.com/pre-commit/mirrors-eslint
78-
rev: v8.31.0
78+
rev: v8.33.0
7979
hooks:
8080
- id: eslint
8181
types: [file]
@@ -90,7 +90,7 @@ repos:
9090
- "@typescript-eslint/parser"
9191

9292
- repo: https://github.com/PyCQA/pydocstyle
93-
rev: 6.1.1
93+
rev: 6.3.0
9494
hooks:
9595
- id: pydocstyle
9696
exclude: tests

matbench_discovery/energy.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def stable_metrics(
176176
177177
Returns:
178178
dict[str, float]: dictionary of classification metrics with keys DAF, Precision,
179-
Recall, Prevalence, Accuracy, F1, TPR, FPR, TNR, FNR, MAE, RMSE, R2.
179+
Recall, Accuracy, F1, TPR, FPR, TNR, FNR, MAE, RMSE, R2.
180180
"""
181181
true_pos, false_neg, false_pos, true_neg = classify_stable(
182182
true, pred, stability_threshold
@@ -198,7 +198,6 @@ def stable_metrics(
198198
DAF=precision / prevalence,
199199
Precision=precision,
200200
Recall=recall,
201-
Prevalence=prevalence,
202201
Accuracy=(n_true_pos + n_true_neg) / len(true),
203202
F1=2 * (precision * recall) / (precision + recall),
204203
TPR=n_true_pos / (n_true_pos + n_false_neg),

scripts/compile_metrics.py

+37-29
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33

44
from typing import Any
55

6+
import numpy as np
67
import pandas as pd
78
import requests
89
import wandb
910
import wandb.apis.public
1011
from pymatviz.utils import save_fig
11-
from sklearn.metrics import f1_score, r2_score
1212
from tqdm import tqdm
1313

1414
from matbench_discovery import FIGS, MODELS, WANDB_PATH, today
1515
from matbench_discovery.data import PRED_FILENAMES, load_df_wbm_preds
16+
from matbench_discovery.energy import stable_metrics
1617
from matbench_discovery.plots import px
1718

1819
__author__ = "Janosh Riebesell"
@@ -97,10 +98,10 @@
9798

9899
n_gpu, n_cpu = metadata.get("gpu_count", 0), metadata.get("cpu_count", 0)
99100
model_stats[model] = {
100-
"run_time_h": run_time_total / 3600,
101+
(time_col := "Run Time (h)"): run_time_total / 3600,
101102
"GPU": n_gpu,
102103
"CPU": n_cpu,
103-
"slurm_jobs": n_runs,
104+
"Slurm Jobs": n_runs,
104105
}
105106

106107

@@ -110,6 +111,7 @@
110111
)
111112

112113
df_metrics = pd.DataFrame(model_stats).T
114+
df_metrics.index.name = "Model"
113115
# on 2022-11-28:
114116
# run_times = {'Voronoi Random Forest': 739608,
115117
# 'Wrenformer': 208399,
@@ -121,46 +123,50 @@
121123
# %%
122124
df_wbm = load_df_wbm_preds(list(models))
123125
e_form_col = "e_form_per_atom_mp2020_corrected"
124-
each_col = "e_above_hull_mp2020_corrected_ppd_mp"
126+
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
125127

126128

127129
# %%
128130
for model in models:
129-
dct = {}
130-
e_above_hull_pred = df_wbm[model] - df_wbm[e_form_col]
131-
isna = e_above_hull_pred.isna() | df_wbm[each_col].isna()
131+
each_pred = df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col]
132132

133-
dct["F1"] = f1_score(df_wbm[each_col] < 0, e_above_hull_pred < 0)
134-
dct["Precision"] = f1_score(
135-
df_wbm[each_col] < 0, e_above_hull_pred < 0, pos_label=True
136-
)
137-
dct["Recall"] = f1_score(
138-
df_wbm[each_col] < 0, e_above_hull_pred < 0, pos_label=False
139-
)
140-
141-
dct["MAE"] = (e_above_hull_pred - df_wbm[each_col]).abs().mean()
133+
metrics = stable_metrics(df_wbm[each_true_col], each_pred)
142134

143-
dct["RMSE"] = ((e_above_hull_pred - df_wbm[each_col]) ** 2).mean() ** 0.5
144-
dct["R2"] = r2_score(df_wbm[each_col][~isna], e_above_hull_pred[~isna])
135+
df_metrics.loc[model, list(metrics)] = metrics.values()
145136

146-
df_metrics.loc[model, list(dct)] = dct.values()
147137

148-
149-
df_styled = df_metrics.style.format(precision=3).background_gradient(
150-
cmap="viridis",
151-
# gmap=np.log10(df_table) # for log scaled color map
138+
# %%
139+
df_styled = (
140+
df_metrics.reset_index()
141+
.drop(columns=["GPU", "CPU", "Slurm Jobs"])
142+
.style.format(precision=2)
143+
.background_gradient(
144+
cmap="viridis_r", # lower is better so reverse color map
145+
subset=["MAE", "RMSE", "FNR", "FPR"],
146+
)
147+
.background_gradient(
148+
cmap="viridis_r",
149+
subset=[time_col],
150+
gmap=np.log10(df_metrics[time_col].to_numpy()), # for log scaled color map
151+
)
152+
.background_gradient(
153+
cmap="viridis", # higher is better
154+
subset=["DAF", "R2", "Precision", "Recall", "F1", "Accuracy", "TPR", "TNR"],
155+
)
156+
.hide(axis="index")
152157
)
158+
df_styled
153159

154160

155161
# %% export model metrics as styled HTML table
156162
styles = {
157163
"": "font-family: sans-serif; border-collapse: collapse;",
158-
"td, th": "border: 1px solid #ddd; text-align: left; padding: 8px;",
164+
"td, th": "border: 1px solid #ddd; text-align: left; padding: 8px; white-space: nowrap;",
159165
}
160166
df_styled.set_table_styles([dict(selector=sel, props=styles[sel]) for sel in styles])
161167

162-
html_path = f"{FIGS}/{today}-metrics-table.html"
163-
# df_styled.to_html(html_path)
168+
html_path = f"{FIGS}/{today}-metrics-table.svelte"
169+
df_styled.to_html(html_path)
164170

165171

166172
# %% write model metrics to json for use by the website
@@ -169,14 +175,14 @@
169175
f"{x / len(df_wbm):.2%}" for x in df_metrics.missing_preds
170176
]
171177

172-
df_metrics.attrs["total_run_time"] = df_metrics.run_time.sum()
178+
df_metrics.attrs["Total Run Time"] = df_metrics[time_col].sum()
173179

174180
df_metrics.round(2).to_json(f"{MODELS}/{today}-model-stats.json", orient="index")
175181

176182

177183
# %% plot model run times as pie chart
178184
fig = px.pie(
179-
df_metrics, values="run_time", names=df_metrics.index, hole=0.5
185+
df_metrics, values=time_col, names=df_metrics.index, hole=0.5
180186
).update_traces(
181187
textinfo="percent+label",
182188
textfont_size=14,
@@ -189,12 +195,14 @@
189195
)
190196
fig.add_annotation(
191197
# add title in the middle saying "Total CPU+GPU time used"
192-
text=f"Total CPU+GPU<br>time used:<br>{df_metrics.run_time.sum():.1f} h",
198+
text=f"Total CPU+GPU<br>time used:<br>{df_metrics[time_col].sum():.1f} h",
193199
font=dict(size=18),
194200
x=0.5,
195201
y=0.5,
196202
showarrow=False,
197203
)
198204
fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))
199205

206+
207+
# %%
200208
save_fig(fig, f"{FIGS}/{today}-model-run-times-pie.svelte")

scripts/cumulative_clf_metrics.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"CGCNN, Voronoi Random Forest, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet"
1616
).split(", ")
1717

18-
df_wbm = load_df_wbm_preds(models=models).round(3)
18+
df_wbm = load_df_wbm_preds(models).round(3)
1919

2020
# df_wbm.columns = [f"{col}_e_form" if col in models else col for col in df_wbm]
2121
e_form_col = "e_form_per_atom_mp2020_corrected"

scripts/hist_classified_stable_vs_hull_dist_batches.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@
3535
which_energy: WhichEnergy = "true"
3636
backend: Backend = "matplotlib"
3737
fig, axs = plt.subplots(2, 3, figsize=(18, 9))
38-
df_wbm[each_pred_col] = df_wbm[each_true_col] + (
39-
df_wbm[model_name] - df_wbm[e_form_col]
40-
)
38+
df_wbm[each_pred_col] = df_wbm[each_true_col] + df_wbm[model_name] - df_wbm[e_form_col]
4139

4240

4341
for batch_idx, ax in zip(range(1, 6), axs.flat):

scripts/hist_classified_stable_vs_hull_dist_models.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
models = sorted(
2121
"CGCNN, Voronoi Random Forest, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet".split(", ")
2222
)
23-
df_wbm = load_df_wbm_preds(models=models).round(3)
23+
df_wbm = load_df_wbm_preds(models).round(3)
2424

2525
e_form_col = "e_form_per_atom_mp2020_corrected"
2626
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
@@ -39,8 +39,8 @@
3939
value_name=e_form_preds,
4040
)
4141

42-
df_melt[each_pred_col] = df_melt[each_true_col] + (
43-
df_melt[e_form_preds] - df_melt[e_form_col]
42+
df_melt[each_pred_col] = (
43+
df_melt[each_true_col] + df_melt[e_form_preds] - df_melt[e_form_col]
4444
)
4545

4646

scripts/rolling_mae_vs_hull_dist_all_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
e_form_col = "e_form_per_atom_mp2020_corrected"
1717
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
1818

19-
df_wbm = load_df_wbm_preds(models=models).round(3)
19+
df_wbm = load_df_wbm_preds(models).round(3)
2020

2121

2222
# %%

scripts/scatter_e_above_hull_models.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
models = sorted(
1717
"CGCNN, Voronoi Random Forest, Wrenformer, MEGNet, M3GNet, BOWSR MEGNet".split(", ")
1818
)
19-
df_wbm = load_df_wbm_preds(models=models).round(3)
19+
df_wbm = load_df_wbm_preds(models).round(3)
2020

2121
e_form_col = "e_form_per_atom_mp2020_corrected"
2222
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
@@ -36,8 +36,8 @@
3636
value_name=e_form_pred_col,
3737
)
3838

39-
df_melt[each_pred_col] = df_melt[each_true_col] + (
40-
df_melt[e_form_pred_col] - df_melt[e_form_col]
39+
df_melt[each_pred_col] = (
40+
df_melt[each_true_col] + df_melt[e_form_pred_col] - df_melt[e_form_col]
4141
)
4242

4343

site/package.json

+11-11
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
"make-api-docs": "cd .. && python scripts/make_api_docs.py"
1818
},
1919
"devDependencies": {
20-
"@iconify/svelte": "^3.0.1",
20+
"@iconify/svelte": "^3.1.0",
2121
"@rollup/plugin-yaml": "^4.0.1",
22-
"@sveltejs/adapter-static": "^1.0.3",
23-
"@sveltejs/kit": "^1.1.1",
22+
"@sveltejs/adapter-static": "^1.0.5",
23+
"@sveltejs/kit": "^1.3.7",
2424
"@sveltejs/vite-plugin-svelte": "^2.0.2",
25-
"@typescript-eslint/eslint-plugin": "^5.48.1",
26-
"@typescript-eslint/parser": "^5.48.1",
27-
"eslint": "^8.32.0",
25+
"@typescript-eslint/eslint-plugin": "^5.50.0",
26+
"@typescript-eslint/parser": "^5.50.0",
27+
"eslint": "^8.33.0",
2828
"eslint-plugin-svelte3": "^4.0.0",
2929
"hastscript": "^7.2.0",
3030
"katex": "^0.16.4",
@@ -36,14 +36,14 @@
3636
"rehype-slug": "^5.1.0",
3737
"remark-math": "3.0.0",
3838
"svelte": "^3.55.1",
39-
"svelte-check": "^3.0.2",
40-
"svelte-preprocess": "^5.0.0",
39+
"svelte-check": "^3.0.3",
40+
"svelte-preprocess": "^5.0.1",
4141
"svelte-toc": "^0.5.2",
42-
"svelte-zoo": "^0.2.1",
42+
"svelte-zoo": "^0.2.4",
4343
"svelte2tsx": "^0.6.0",
4444
"sveriodic-table": "^0.1.4",
45-
"tslib": "^2.4.1",
46-
"typescript": "^4.9.4",
45+
"tslib": "^2.5.0",
46+
"typescript": "^4.9.5",
4747
"vite": "^4.0.4"
4848
},
4949
"prettier": {

site/src/lib/ModelCard.svelte

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
<li>
6666
{#if ![`aviary`].includes(name)}
6767
{@const href = `https://pypi.org/project/${name}/${version}`}
68-
{name}: <a {href}>{version}</a>
68+
{name}: <a {href} {...target}>{version}</a>
6969
{:else}
7070
{name}: {version}
7171
{/if}
@@ -77,7 +77,7 @@
7777
<section class="metrics">
7878
<h3 class="toc-exclude">Metrics</h3>
7979
<ul>
80-
{#each stats as [key, label, unit]}
80+
{#each stats as { key, label, unit }}
8181
<li class:active={sort_by == key}>
8282
{@html label ?? key} = {data[key]}
8383
{unit ?? ``}

site/src/lib/index.ts

+10-2
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,23 @@ export type ModelStats = {
2929
missing_preds: number
3030
missing_percent: number
3131
Accuracy: number
32-
run_time_h: string
32+
'Run Time (h)': string
33+
FPR: number
34+
FNR: number
35+
DAF: number
3336
GPUs: number
3437
CPUs: number
3538
slurm_jobs: number
3639
date_added: string
3740
}
3841

3942
// [key, label?, unit?]
40-
export type ModelStatLabel = [keyof ModelStats, (string | null)?, string?]
43+
export type ModelStatLabel = {
44+
key: keyof ModelStats
45+
label?: string
46+
unit?: string
47+
tooltip?: string
48+
}
4149

4250
export type Author = {
4351
name: string

0 commit comments

Comments
 (0)