Skip to content

Commit 20ef518

Browse files
committed
add scripts/metrics_table.py
1 parent 2779538 commit 20ef518

File tree

4 files changed

+155
-6
lines changed

4 files changed

+155
-6
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ repos:
4848
rev: v0.991
4949
hooks:
5050
- id: mypy
51-
additional_dependencies: [types-pyyaml]
51+
additional_dependencies: [types-pyyaml, types-requests]
5252

5353
- repo: https://github.com/codespell-project/codespell
5454
rev: v2.2.2

models/cgcnn/test_cgcnn.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
if val == runs[0].config[key] or key.startswith(("slurm_", "timestamp")):
7878
continue
7979
raise ValueError(
80-
f"Configs not identical: runs[{idx}][{key}]={val}, {runs[0][key]=}"
80+
f"Run configs not identical: runs[{idx}][{key}]={val}, {runs[0][key]=}"
8181
)
8282

8383
run_params = dict(
@@ -96,7 +96,10 @@
9696
wandb.init(project="matbench-discovery", name=job_name, config=run_params)
9797

9898
cg_data = CrystalGraphData(
99-
df, task_dict={target_col: "regression"}, structure_col=input_col
99+
df,
100+
task_dict={target_col: "regression"},
101+
structure_col=input_col,
102+
identifiers=("material_id", "formula_from_cse"),
100103
)
101104
data_loader = DataLoader(
102105
cg_data, batch_size=1024, shuffle=False, collate_fn=collate_batch
@@ -120,6 +123,5 @@
120123
R2 = ensemble_metrics.R2.mean()
121124

122125
title = rf"CGCNN {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
123-
print(title)
124126

125127
wandb_log_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)

models/wrenformer/test_wrenformer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
if val == runs[0].config[key] or key.startswith(("slurm_", "timestamp")):
6464
continue
6565
raise ValueError(
66-
f"Configs not identical: runs[{idx}][{key}]={val}, {runs[0][key]=}"
66+
f"Run configs not identical: runs[{idx}][{key}]={val}, {runs[0][key]=}"
6767
)
6868

6969
run_params = dict(
@@ -109,6 +109,5 @@
109109
R2 = ensemble_metrics.R2.mean()
110110

111111
title = rf"Wrenformer {task_type} ensemble={len(runs)} {MAE=:.4} {R2=:.4}"
112-
print(title)
113112

114113
wandb_log_scatter(table, fields=dict(x=target_col, y=pred_col), title=title)

scripts/metrics_table.py

+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# %%
2+
from __future__ import annotations
3+
4+
from typing import Any
5+
6+
import pandas as pd
7+
import requests
8+
import wandb
9+
from sklearn.metrics import f1_score, r2_score
10+
from tqdm import tqdm
11+
12+
from matbench_discovery import ROOT, today
13+
from matbench_discovery.load_preds import load_df_wbm_with_preds
14+
15+
__author__ = "Janosh Riebesell"
16+
__date__ = "2022-11-28"
17+
18+
19+
# %%
20+
models: dict[str, dict[str, Any]] = {
21+
"Wren": dict(n_runs=0),
22+
"CGCNN": dict(
23+
n_runs=10,
24+
filters=dict(
25+
created_at={"$gt": "2022-11-21", "$lt": "2022-11-23"},
26+
display_name={"$regex": "cgcnn-robust-formation_energy_per_atom"},
27+
),
28+
),
29+
"Voronoi RF": dict(
30+
n_runs=70,
31+
filters=dict(
32+
created_at={"$gt": "2022-11-17", "$lt": "2022-11-28"},
33+
display_name={"$regex": "voronoi-features"},
34+
),
35+
),
36+
"Wrenformer": dict(
37+
n_runs=10,
38+
filters=dict(
39+
created_at={"$gt": "2022-11-14", "$lt": "2022-11-16"},
40+
display_name={"$regex": "wrenformer-robust-mp-formation_energy"},
41+
),
42+
),
43+
"MEGNet": dict(
44+
n_runs=1,
45+
filters=dict(
46+
created_at={"$gt": "2022-11-17", "$lt": "2022-11-19"},
47+
display_name={"$regex": "megnet-wbm-IS2RE"},
48+
),
49+
),
50+
"M3GNet": dict(
51+
n_runs=99,
52+
filters=dict(
53+
created_at={"$gt": "2022-10-31", "$lt": "2022-11-01"},
54+
display_name={"$regex": "m3gnet-wbm-IS2RE"},
55+
),
56+
),
57+
"BOWSR MEGNet": dict(
58+
n_runs=1000,
59+
filters=dict(
60+
created_at={"$gt": "2022-11-22", "$lt": "2022-11-25"},
61+
display_name={"$regex": "bowsr-megnet"},
62+
),
63+
),
64+
}
65+
66+
run_times: dict[str, dict[str, str | int | float]] = {}
67+
68+
69+
# %% calculate total model run times from wandb logs
70+
# NOTE these model run times are pretty meaningless since some models were run on GPU
71+
# (Wrenformer and CGCNN), others on CPU. Also BOWSR MEGNet, M3GNet and MEGNet weren't
72+
# trained from scratch. Their run times only indicate the time needed to predict the
73+
# test set.
74+
75+
for model in (pbar := tqdm(models)):
76+
model_dict = models[model]
77+
n_runs, filters = (model_dict.get(x) for x in ("n_runs", "filters"))
78+
if n_runs == 0 or model in run_times:
79+
continue
80+
pbar.set_description(model)
81+
82+
runs = wandb.Api().runs("janosh/matbench-discovery", filters=filters)
83+
84+
assert len(runs) == n_runs, f"found {len(runs)=} for {model}, expected {n_runs}"
85+
86+
run_time = sum(run.summary.get("_wandb", {}).get("runtime", 0) for run in runs)
87+
# NOTE we assume all jobs have the same metadata here
88+
metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
89+
90+
n_gpu, n_cpu = metadata.get("gpu_count", 0), metadata.get("cpu_count", 0)
91+
run_times[model] = {"Run time": run_time, "Hardware": f"GPU: {n_gpu}, CPU: {n_cpu}"}
92+
93+
94+
# on 2022-11-28:
95+
# run_times = {'Voronoi RF': 739608,
96+
# 'Wrenformer': 208399,
97+
# 'MEGNet': 12396,
98+
# 'M3GNet': 301138,
99+
# 'BOWSR MEGNet': 9105237}
100+
101+
102+
# %%
103+
df_wbm = load_df_wbm_with_preds(models=models).round(3)
104+
105+
106+
target_col = "e_form_per_atom_mp2020_corrected"
107+
108+
df_wbm = df_wbm.query(f"{target_col} < 5")
109+
e_above_hull_col = "e_above_hull_mp2020_corrected_ppd_mp"
110+
e_above_hull = df_wbm[e_above_hull_col]
111+
112+
113+
# %%
114+
df_metrics = pd.DataFrame(run_times).T
115+
116+
for model in models:
117+
dct = {}
118+
e_above_hull_pred = df_wbm[model] - df_wbm[target_col]
119+
120+
dct["F1"] = f1_score(e_above_hull < 0, e_above_hull_pred < 0)
121+
dct["Precision"] = f1_score(e_above_hull < 0, e_above_hull_pred < 0, pos_label=True)
122+
dct["Recall"] = f1_score(e_above_hull < 0, e_above_hull_pred < 0, pos_label=False)
123+
124+
dct["MAE"] = (e_above_hull_pred - e_above_hull).abs().mean()
125+
126+
dct["RMSE"] = ((e_above_hull_pred - e_above_hull) ** 2).mean() ** 0.5
127+
dct["R2"] = r2_score(
128+
e_above_hull.loc[e_above_hull_pred.dropna().index], e_above_hull_pred.dropna()
129+
)
130+
131+
df_metrics.loc[model, list(dct)] = dct.values()
132+
133+
134+
df_styled = df_metrics.style.format(precision=3).background_gradient(
135+
cmap="viridis",
136+
# gmap=np.log10(df_table) # for log scaled color map
137+
)
138+
df_styled
139+
140+
141+
# %%
142+
styles = {
143+
"": "font-family: sans-serif; border-collapse: collapse;",
144+
"td, th": "border: 1px solid #ddd; text-align: left; padding: 8px;",
145+
}
146+
df_styled.set_table_styles([dict(selector=sel, props=styles[sel]) for sel in styles])
147+
148+
df_styled.to_html(f"{ROOT}/figures/{today}-metrics-table.html")

0 commit comments

Comments
 (0)