Skip to content

Commit 387184c

Browse files
committed
move load_df_wbm_preds+PRED_FILES from matbench_discovery/{data->preds}.py
change model label 'Voronoi Random Forest' to 'Voronoi RF'
1 parent 8166801 commit 387184c

15 files changed

+159
-155
lines changed

matbench_discovery/data.py

+2-100
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from tqdm import tqdm
1414

1515
from matbench_discovery import ROOT
16-
from matbench_discovery.plots import model_labels
1716

1817
df_wbm = pd.read_csv(f"{ROOT}/data/wbm/2022-10-19-wbm-summary.csv")
1918
df_wbm.index = df_wbm.material_id
@@ -39,7 +38,7 @@ def __init__(self) -> None:
3938
"""Create a Files instance."""
4039
key_map = getattr(self, "_key_map", {})
4140
dct = {
42-
key_map.get(key, key): f"{self._root}/{file}" # type: ignore
41+
key_map.get(key, key): f"{self._root}{file}" # type: ignore
4342
for key, file in type(self).__dict__.items()
4443
if not key.startswith("_")
4544
}
@@ -52,7 +51,7 @@ class DataFiles(Files):
5251
See https://janosh.github.io/matbench-discovery/contribute for data descriptions.
5352
"""
5453

55-
_root = f"{ROOT}/data"
54+
_root = f"{ROOT}/data/"
5655

5756
mp_computed_structure_entries = (
5857
"mp/2023-02-07-mp-computed-structure-entries.json.gz"
@@ -176,41 +175,6 @@ def load_train_test(
176175
return dfs
177176

178177

179-
class PredFiles(Files):
180-
"""Data files provided by Matbench Discovery.
181-
See https://janosh.github.io/matbench-discovery/contribute for data descriptions.
182-
"""
183-
184-
_root = f"{ROOT}/models"
185-
_key_map = model_labels
186-
187-
# CGCnn 10-member ensemble
188-
cgcnn = "cgcnn/2023-01-26-test-cgcnn-wbm-IS2RE/cgcnn-ensemble-preds.csv"
189-
190-
# cgcnn 10-member ensemble with 5-fold training set perturbations
191-
cgcnn_p = "cgcnn/2023-02-05-cgcnn-perturb=5.csv"
192-
193-
# magpie composition+voronoi tessellation structure features + sklearn random forest
194-
voronoi_rf = "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv"
195-
196-
# wrenformer 10-member ensemble
197-
wrenformer = "wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
198-
199-
# original megnet straight from publication, not re-trained
200-
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
201-
202-
# original m3gnet straight from publication, not re-trained
203-
m3gnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
204-
205-
# m3gnet-relaxed structures fed into megnet for formation energy prediction
206-
m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
207-
# bowsr optimizer coupled with original megnet
208-
bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv"
209-
210-
211-
PRED_FILES = PredFiles()
212-
213-
214178
def glob_to_df(
215179
pattern: str,
216180
reader: Callable[[Any], pd.DataFrame] = None,
@@ -242,65 +206,3 @@ def glob_to_df(
242206
sub_dfs[file] = df
243207

244208
return pd.concat(sub_dfs.values())
245-
246-
247-
def load_df_wbm_preds(
248-
models: Sequence[str] = (*PRED_FILES,),
249-
pbar: bool = True,
250-
id_col: str = "material_id",
251-
**kwargs: Any,
252-
) -> pd.DataFrame:
253-
"""Load WBM summary dataframe with model predictions from disk.
254-
255-
Args:
256-
models (Sequence[str], optional): Model names must be keys of
257-
matbench_discovery.data.PRED_FILES. Defaults to all models.
258-
pbar (bool, optional): Whether to show progress bar. Defaults to True.
259-
id_col (str, optional): Column to set as df.index. Defaults to "material_id".
260-
**kwargs: Keyword arguments passed to glob_to_df().
261-
262-
Raises:
263-
ValueError: On unknown model names.
264-
265-
Returns:
266-
pd.DataFrame: WBM summary dataframe with model predictions.
267-
"""
268-
if mismatch := ", ".join(set(models) - set(PRED_FILES)):
269-
raise ValueError(f"Unknown models: {mismatch}")
270-
271-
dfs: dict[str, pd.DataFrame] = {}
272-
273-
for model_name in (bar := tqdm(models, disable=not pbar, desc="Loading preds")):
274-
bar.set_postfix_str(model_name)
275-
df = glob_to_df(PRED_FILES[model_name], pbar=False, **kwargs).set_index(id_col)
276-
dfs[model_name] = df
277-
278-
df_out = df_wbm.copy()
279-
for model_name, df in dfs.items():
280-
model_key = model_name.lower().replace(" + ", "_").replace(" ", "_")
281-
if f"e_form_per_atom_{model_key}" in df:
282-
df_out[model_name] = df[f"e_form_per_atom_{model_key}"]
283-
284-
elif len(pred_cols := df.filter(like="_pred_ens").columns) > 0:
285-
assert len(pred_cols) == 1
286-
df_out[model_name] = df[pred_cols[0]]
287-
if len(std_cols := df.filter(like="_std_ens").columns) > 0:
288-
df_out[f"{model_name}_std"] = df[std_cols[0]]
289-
290-
elif len(pred_cols := df.filter(like=r"_pred_").columns) > 1:
291-
# make sure we average the expected number of ensemble member predictions
292-
assert len(pred_cols) == 10, f"{len(pred_cols) = }, expected 10"
293-
df_out[model_name] = df[pred_cols].mean(axis=1)
294-
295-
elif "e_form_per_atom_voronoi_rf" in df: # new voronoi
296-
df_out[model_name] = df.e_form_per_atom_voronoi_rf
297-
298-
elif "e_form_pred" in df: # old voronoi
299-
df_out[model_name] = df.e_form_pred
300-
301-
else:
302-
raise ValueError(
303-
f"No pred col for {model_name=}, available cols={list(df)}"
304-
)
305-
306-
return df_out

matbench_discovery/plots.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
m3gnet="M3GNet",
5757
megnet="MEGNet",
5858
megnet_old="MEGNet Old",
59-
voronoi_rf="Voronoi Random Forest",
59+
voronoi_rf="Voronoi RF",
6060
wrenformer="Wrenformer",
6161
dft="DFT",
6262
wbm="WBM",

matbench_discovery/preds.py

+102-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
from __future__ import annotations
22

3+
from collections.abc import Sequence
4+
from typing import Any
5+
36
import pandas as pd
7+
from tqdm import tqdm
48

5-
from matbench_discovery.data import PRED_FILES, load_df_wbm_preds
9+
from matbench_discovery import ROOT
10+
from matbench_discovery.data import Files, glob_to_df
611
from matbench_discovery.metrics import stable_metrics
12+
from matbench_discovery.plots import model_labels
713

814
"""Centralize data-loading and computing metrics for plotting scripts"""
915

@@ -14,6 +20,101 @@
1420
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
1521
each_pred_col = "e_above_hull_pred"
1622

23+
24+
class PredFiles(Files):
25+
"""Model prediction files provided by Matbench Discovery.
26+
See https://janosh.github.io/matbench-discovery/contribute for data descriptions.
27+
"""
28+
29+
_root = f"{ROOT}/models/"
30+
_key_map = model_labels # remap model keys below to pretty plot labels (see Files)
31+
32+
# CGCNN 10-member ensemble
33+
cgcnn = "cgcnn/2023-01-26-test-cgcnn-wbm-IS2RE/cgcnn-ensemble-preds.csv"
34+
35+
# cgcnn 10-member ensemble with 5-fold training set perturbations
36+
cgcnn_p = "cgcnn/2023-02-05-cgcnn-perturb=5.csv"
37+
38+
# magpie composition+voronoi tessellation structure features + sklearn random forest
39+
voronoi_rf = "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv"
40+
41+
# wrenformer 10-member ensemble
42+
wrenformer = "wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
43+
44+
# original megnet straight from publication, not re-trained
45+
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
46+
megnet_old = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
47+
48+
# original m3gnet straight from publication, not re-trained
49+
m3gnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
50+
51+
# m3gnet-relaxed structures fed into megnet for formation energy prediction
52+
m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
53+
# bowsr optimizer coupled with original megnet
54+
bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv"
55+
56+
57+
PRED_FILES = PredFiles()
58+
59+
60+
def load_df_wbm_preds(
61+
models: Sequence[str] = (*PRED_FILES,),
62+
pbar: bool = True,
63+
id_col: str = "material_id",
64+
**kwargs: Any,
65+
) -> pd.DataFrame:
66+
"""Load WBM summary dataframe with model predictions from disk.
67+
68+
Args:
69+
models (Sequence[str], optional): Model names must be keys of
70+
matbench_discovery.preds.PRED_FILES. Defaults to all models.
71+
pbar (bool, optional): Whether to show progress bar. Defaults to True.
72+
id_col (str, optional): Column to set as df.index. Defaults to "material_id".
73+
**kwargs: Keyword arguments passed to glob_to_df().
74+
75+
Raises:
76+
ValueError: On unknown model names.
77+
78+
Returns:
79+
pd.DataFrame: WBM summary dataframe with model predictions.
80+
"""
81+
if mismatch := ", ".join(set(models) - set(PRED_FILES)):
82+
raise ValueError(f"Unknown models: {mismatch}")
83+
84+
dfs: dict[str, pd.DataFrame] = {}
85+
86+
for model_name in (bar := tqdm(models, disable=not pbar, desc="Loading preds")):
87+
bar.set_postfix_str(model_name)
88+
df = glob_to_df(PRED_FILES[model_name], pbar=False, **kwargs).set_index(id_col)
89+
dfs[model_name] = df
90+
91+
from matbench_discovery.data import df_wbm
92+
93+
df_out = df_wbm.copy()
94+
for model_name, df in dfs.items():
95+
model_key = model_name.lower().replace(" + ", "_").replace(" ", "_")
96+
if (col := f"e_form_per_atom_{model_key}") in df:
97+
df_out[model_name] = df[col]
98+
99+
elif pred_cols := list(df.filter(like="_pred_ens")):
100+
assert len(pred_cols) == 1
101+
df_out[model_name] = df[pred_cols[0]]
102+
if std_cols := list(df.filter(like="_std_ens")):
103+
df_out[f"{model_name}_std"] = df[std_cols[0]]
104+
105+
elif pred_cols := list(df.filter(like=r"_pred_")):
106+
# make sure we average the expected number of ensemble member predictions
107+
assert len(pred_cols) == 10, f"{len(pred_cols) = }, expected 10"
108+
df_out[model_name] = df[pred_cols].mean(axis=1)
109+
110+
else:
111+
raise ValueError(
112+
f"No pred col for {model_name=}, available cols={list(df)}"
113+
)
114+
115+
return df_out
116+
117+
17118
df_wbm = load_df_wbm_preds().round(3)
18119

19120

models/chgnet/metadata.yml

+5-4
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ authors:
1919
- name: Christopher J. Bartel
2020
affiliation: University of Minnesota
2121
orcid: https://orcid.org/0000-0002-5198-5036
22-
- name: Christopher J. Bartel
23-
affiliation: Gerbrand Ceder
22+
- name: Gerbrand Ceder
23+
affiliation: UC Berkeley
2424
orcid: https://orcid.org/0000-0001-9275-3605
2525
2626
repo: https://github.com/CederGroupHub/chgnet
@@ -31,9 +31,10 @@ requirements:
3131
ase: 3.22.0
3232
pymatgen: 2022.10.22
3333
numpy: 1.24.0
34-
pandas: 1.5.1
3534
trained_on_benchmark: false
3635

3736
notes:
38-
description: The Crystal Hamiltonian Graph Neural Network (CHGNet) is a universal GNN-based interatomic potential trained on energies, forces, stresses and magnetic moments from the MP trajectory dataset containing ∼1.5 million inorganic structures.
37+
description: |
38+
The Crystal Hamiltonian Graph Neural Network (CHGNet) is a universal GNN-based interatomic potential trained on energies, forces, stresses and magnetic moments from the MP trajectory dataset containing ∼1.5 million inorganic structures.
39+
![CHGNet Pipeline](https://user-images.githubusercontent.com/30958850/222842305-b6ed2468-8773-4e03-9de5-20c8e8de030e.svg)
3940
training: Using pre-trained model released with preprint. Training set unreleased until after review.

models/voronoi/metadata.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
model_name: Voronoi Random Forest
1+
model_name: Voronoi RF
22
model_version: 1.1.2 # scikit learn version which implements the random forest
33
matbench_discovery_version: 1.0
44
date_added: "2022-11-26"

scripts/compile_metrics.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717
from tqdm import tqdm
1818

1919
from matbench_discovery import FIGS, MODELS, ROOT, WANDB_PATH
20-
from matbench_discovery.data import PRED_FILES
2120
from matbench_discovery.plots import px
22-
from matbench_discovery.preds import df_metrics, df_wbm
21+
from matbench_discovery.preds import PRED_FILES, df_metrics, df_wbm
2322

2423
__author__ = "Janosh Riebesell"
2524
__date__ = "2022-11-28"
@@ -35,7 +34,7 @@
3534
display_name={"$regex": "cgcnn-robust-formation_energy_per_atom"},
3635
),
3736
),
38-
"Voronoi Random Forest": dict(
37+
"Voronoi RF": dict(
3938
n_runs=68,
4039
filters=dict(
4140
created_at={"$gt": "2022-11-17", "$lt": "2022-11-28"},

scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454

5555

5656
# %% plotly
57-
model = "Wrenformer" # ["M3GNet", "Wrenformer", "MEGNet", "Voronoi Random Forest"]
57+
model = "Wrenformer" # ["M3GNet", "Wrenformer", "MEGNet", "Voronoi RF"]
5858
df_pivot = df_each_pred.pivot(columns=batch_col, values=model)
5959

6060
# unstack two-level column index into new model column

site/src/lib/ModelCard.svelte

+2-8
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
const target = { target: `_blank`, rel: `noopener` }
2424
</script>
2525

26-
<h2>
26+
<h2 id={model_name.toLowerCase().replaceAll(` `, `-`)}>
2727
{model_name}
2828
<button
2929
on:click={() => (show_details = !show_details)}
@@ -161,10 +161,7 @@
161161
font: inherit;
162162
}
163163
h3 {
164-
margin: 1em 0 0;
165-
}
166-
div h3 {
167-
margin: 0;
164+
margin: 1ex 0 3pt;
168165
}
169166
ul {
170167
list-style: disc;
@@ -204,9 +201,6 @@
204201
flex-direction: column;
205202
max-height: 10em;
206203
}
207-
section.metrics > h3 {
208-
margin: 0;
209-
}
210204
section.metrics > ul > li {
211205
font-weight: lighter;
212206
display: flex;
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"CGCNN":{"Run Time (h)":11.35,"GPU":1.0,"CPU":128.0,"Slurm Jobs":10.0,"DAF":1.97,"Precision":0.74,"Recall":0.79,"Accuracy":0.82,"F1":0.77,"TPR":0.79,"FPR":0.17,"TNR":0.83,"FNR":0.21,"MAE":0.13,"RMSE":0.23,"R2":0.16,"missing_preds":2,"missing_percent":"0.00%"},"Voronoi Random Forest":{"Run Time (h)":204.7,"GPU":0.0,"CPU":76.0,"Slurm Jobs":68.0,"DAF":1.55,"Precision":0.59,"Recall":0.79,"Accuracy":0.71,"F1":0.67,"TPR":0.79,"FPR":0.34,"TNR":0.66,"FNR":0.21,"MAE":0.14,"RMSE":0.21,"R2":0.32,"missing_preds":17,"missing_percent":"0.01%"},"Wrenformer":{"Run Time (h)":57.89,"GPU":1.0,"CPU":128.0,"Slurm Jobs":10.0,"DAF":1.78,"Precision":0.67,"Recall":0.91,"Accuracy":0.8,"F1":0.77,"TPR":0.91,"FPR":0.27,"TNR":0.73,"FNR":0.09,"MAE":0.11,"RMSE":0.18,"R2":0.46,"missing_preds":0,"missing_percent":"0.00%"},"MEGNet":{"Run Time (h)":3.44,"GPU":1.0,"CPU":128.0,"Slurm Jobs":1.0,"DAF":2.1,"Precision":0.8,"Recall":0.66,"Accuracy":0.81,"F1":0.72,"TPR":0.66,"FPR":0.1,"TNR":0.9,"FNR":0.34,"MAE":0.13,"RMSE":0.21,"R2":0.3,"missing_preds":0,"missing_percent":"0.00%"},"M3GNet":{"Run Time (h)":83.65,"GPU":0.0,"CPU":76.0,"Slurm Jobs":99.0,"DAF":1.88,"Precision":0.71,"Recall":0.93,"Accuracy":0.83,"F1":0.8,"TPR":0.93,"FPR":0.24,"TNR":0.76,"FNR":0.07,"MAE":0.07,"RMSE":0.11,"R2":0.79,"missing_preds":2569,"missing_percent":"1.00%"},"BOWSR + MEGNet":{"Run Time (h)":2776.57,"GPU":0.0,"CPU":32.0,"Slurm Jobs":500.0,"DAF":1.63,"Precision":0.59,"Recall":0.87,"Accuracy":0.73,"F1":0.7,"TPR":0.87,"FPR":0.36,"TNR":0.64,"FNR":0.13,"MAE":0.11,"RMSE":0.16,"R2":0.55,"missing_preds":6184,"missing_percent":"2.41%"},"M3GNet + MEGNet":{"Run Time (h)":87.09,"GPU":0.0,"CPU":76.0,"Slurm Jobs":99.0,"DAF":1.98,"Precision":0.74,"Recall":0.79,"Accuracy":0.82,"F1":0.76,"TPR":0.79,"FPR":0.17,"TNR":0.83,"FNR":0.21,"MAE":0.09,"RMSE":0.13,"R2":0.72,"missing_preds":2576,"missing_percent":"1.00%"}}
1+
{"CGCNN":{"Run Time (h)":11.35,"GPU":1.0,"CPU":128.0,"Slurm Jobs":10.0,"DAF":1.97,"Precision":0.74,"Recall":0.79,"Accuracy":0.82,"F1":0.77,"TPR":0.79,"FPR":0.17,"TNR":0.83,"FNR":0.21,"MAE":0.13,"RMSE":0.23,"R2":0.16,"missing_preds":2,"missing_percent":"0.00%"},"Voronoi RF":{"Run Time (h)":204.7,"GPU":0.0,"CPU":76.0,"Slurm Jobs":68.0,"DAF":1.55,"Precision":0.59,"Recall":0.79,"Accuracy":0.71,"F1":0.67,"TPR":0.79,"FPR":0.34,"TNR":0.66,"FNR":0.21,"MAE":0.14,"RMSE":0.21,"R2":0.32,"missing_preds":17,"missing_percent":"0.01%"},"Wrenformer":{"Run Time (h)":57.89,"GPU":1.0,"CPU":128.0,"Slurm Jobs":10.0,"DAF":1.78,"Precision":0.67,"Recall":0.91,"Accuracy":0.8,"F1":0.77,"TPR":0.91,"FPR":0.27,"TNR":0.73,"FNR":0.09,"MAE":0.11,"RMSE":0.18,"R2":0.46,"missing_preds":0,"missing_percent":"0.00%"},"MEGNet":{"Run Time (h)":3.44,"GPU":1.0,"CPU":128.0,"Slurm Jobs":1.0,"DAF":2.1,"Precision":0.8,"Recall":0.66,"Accuracy":0.81,"F1":0.72,"TPR":0.66,"FPR":0.1,"TNR":0.9,"FNR":0.34,"MAE":0.13,"RMSE":0.21,"R2":0.3,"missing_preds":0,"missing_percent":"0.00%"},"M3GNet":{"Run Time (h)":83.65,"GPU":0.0,"CPU":76.0,"Slurm Jobs":99.0,"DAF":1.88,"Precision":0.71,"Recall":0.93,"Accuracy":0.83,"F1":0.8,"TPR":0.93,"FPR":0.24,"TNR":0.76,"FNR":0.07,"MAE":0.07,"RMSE":0.11,"R2":0.79,"missing_preds":2569,"missing_percent":"1.00%"},"BOWSR + MEGNet":{"Run Time (h)":2776.57,"GPU":0.0,"CPU":32.0,"Slurm Jobs":500.0,"DAF":1.63,"Precision":0.59,"Recall":0.87,"Accuracy":0.73,"F1":0.7,"TPR":0.87,"FPR":0.36,"TNR":0.64,"FNR":0.13,"MAE":0.11,"RMSE":0.16,"R2":0.55,"missing_preds":6184,"missing_percent":"2.41%"},"M3GNet + MEGNet":{"Run Time (h)":87.09,"GPU":0.0,"CPU":76.0,"Slurm Jobs":99.0,"DAF":1.98,"Precision":0.74,"Recall":0.79,"Accuracy":0.82,"F1":0.76,"TPR":0.79,"FPR":0.17,"TNR":0.83,"FNR":0.21,"MAE":0.09,"RMSE":0.13,"R2":0.72,"missing_preds":2576,"missing_percent":"1.00%"}}

0 commit comments

Comments
 (0)