Skip to content

Commit 24f6868

Browse files
committed
allow multiple models per metadata.yml file, add models M3GNet + MEGNet and CGCNN+P to respective model cards
rename augment kwarg to perturb in models/cgcnn/train_cgcnn.py; improve model card metrics section layout
1 parent e526e62 commit 24f6868

16 files changed

+232
-148
lines changed

matbench_discovery/data.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ def load_train_test(
144144
"Wrenformer": "wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv",
145145
"MEGNet": "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv",
146146
"M3GNet": "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv",
147-
"M3GNet MEGNet": "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv",
148-
"BOWSR MEGNet": "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv",
147+
"M3GNet + MEGNet": "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv",
148+
"BOWSR + MEGNet": "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv",
149149
}
150150

151151

@@ -222,7 +222,7 @@ def load_df_wbm_preds(
222222

223223
df_out = df_wbm.copy()
224224
for model_name, df in dfs.items():
225-
model_key = model_name.lower().replace(" ", "_")
225+
model_key = model_name.lower().replace(" + ", "_").replace(" ", "_")
226226
if f"e_form_per_atom_{model_key}" in df:
227227
df_out[model_name] = df[f"e_form_per_atom_{model_key}"]
228228

matbench_discovery/metrics.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Centralize data-loading and computing metrics for plotting scripts"""
22

3+
from __future__ import annotations
4+
35
from collections.abc import Sequence
46

57
import numpy as np
@@ -99,8 +101,8 @@ def stable_metrics(
99101

100102

101103
models = sorted(
102-
"Wrenformer, CGCNN, Voronoi Random Forest, MEGNet, M3GNet "
103-
"MEGNet, BOWSR MEGNet".split(", ")
104+
"Wrenformer, CGCNN, Voronoi Random Forest, MEGNet, M3GNet + MEGNet, "
105+
"BOWSR + MEGNet".split(", ")
104106
)
105107
e_form_col = "e_form_per_atom_mp2020_corrected"
106108
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"

models/bowsr/metadata.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
model_name: BOWSR MEGNet
1+
model_name: BOWSR + MEGNet
22
model_version: 2022.9.20
33
matbench_discovery_version: 1.0
44
date_added: "2022-11-17"

models/cgcnn/metadata.yml

+52-24
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,53 @@
1-
model_name: CGCNN
2-
model_version: 0.1.0 # the aviary version
3-
matbench_discovery_version: 1.0
4-
date_added: "2022-12-28"
5-
authors:
6-
- name: Tian Xie
7-
8-
affiliation: Massachusetts Institute of Technology
9-
url: https://txie.me
10-
- name: Jeffrey C. Grossman
11-
affiliation: Massachusetts Institute of Technology
12-
url: https://dmse.mit.edu/people/jeffrey-c-grossman
13-
repo: https://github.com/txie-93/cgcnn
14-
doi: https://doi.org/10.1103/PhysRevLett.120.145301
15-
preprint: https://arxiv.org/abs/1710.10324
16-
requirements:
17-
aviary: 0.1.0
18-
torch: 1.11.0
19-
torch-scatter: 2.0.9
20-
numpy: 1.24.0
21-
pandas: 1.5.1
22-
trained_on_benchmark: true
1+
- model_name: CGCNN
2+
model_version: 0.1.0 # the aviary version
3+
matbench_discovery_version: 1.0
4+
date_added: "2022-12-28"
5+
authors:
6+
- name: Tian Xie
7+
8+
affiliation: Massachusetts Institute of Technology
9+
url: https://txie.me
10+
- name: Jeffrey C. Grossman
11+
affiliation: Massachusetts Institute of Technology
12+
url: https://dmse.mit.edu/people/jeffrey-c-grossman
13+
repo: https://github.com/txie-93/cgcnn
14+
doi: https://doi.org/10.1103/PhysRevLett.120.145301
15+
preprint: https://arxiv.org/abs/1710.10324
16+
requirements:
17+
aviary: 0.1.0
18+
torch: 1.11.0
19+
torch-scatter: 2.0.9
20+
numpy: 1.24.0
21+
pandas: 1.5.1
22+
trained_on_benchmark: true
2323

24-
hyperparams:
25-
Ensemble Size: 10
24+
hyperparams:
25+
Ensemble Size: 10
26+
27+
- model_name: CGCNN+P
28+
model_version: 0.1.0 # the aviary version
29+
matbench_discovery_version: 1.0
30+
date_added: "2023-02-03"
31+
authors:
32+
- name: Jason B. Gibson
33+
affiliation: University of Florida
34+
- name: Ajinkya C. Hire
35+
affiliation: University of Florida
36+
- name: Richard G. Hennig
37+
affiliation: University of Florida
38+
url: https://hennig.mse.ufl.edu
39+
40+
repo: https://github.com/JasonGibsonUfl/Augmented_CGCNN
41+
doi: https://doi.org/10.1038/s41524-022-00891-8
42+
preprint: https://arxiv.org/abs/2202.13947
43+
requirements:
44+
aviary: 0.1.0
45+
torch: 1.11.0
46+
torch-scatter: 2.0.9
47+
numpy: 1.24.0
48+
pandas: 1.5.1
49+
trained_on_benchmark: true
50+
51+
hyperparams:
52+
Ensemble Size: 10
53+
Perturbations: 5

models/cgcnn/train_cgcnn.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
target_col = "formation_energy_per_atom"
2929
input_col = "structure"
3030
id_col = "material_id"
31-
augment = 0 # 0 for no augmentation, n>1 means train on n perturbations of each crystal
31+
perturb = 0 # 0 for no perturbation, n>=1 means train on n perturbations of each crystal
3232
# in the training set all assigned the same original target energy
33-
job_name = f"train-cgcnn-robust-{augment=}{'-debug' if DEBUG else ''}"
33+
job_name = f"train-cgcnn-robust-{perturb=}{'-debug' if DEBUG else ''}"
3434
print(f"{job_name=}")
3535
robust = "robust" in job_name.lower()
3636
ensemble_size = 10
@@ -67,7 +67,7 @@
6767

6868
df_aug = df.copy()
6969
structs = df_aug.pop(input_col)
70-
for idx in trange(augment, desc="Augmenting"):
70+
for idx in trange(perturb, desc="Generating perturbed structures"):
7171
df_aug[input_col] = [perturb_structure(x) for x in structs]
7272
df = pd.concat([df, df_aug.set_index(f"{x}-aug={idx+1}" for x in df_aug.index)])
7373

@@ -108,7 +108,7 @@
108108
train_df=dict(shape=str(train_data.df.shape), columns=", ".join(train_df)),
109109
test_df=dict(shape=str(test_data.df.shape), columns=", ".join(test_df)),
110110
slurm_vars=slurm_vars,
111-
augment=augment,
111+
perturb=perturb,
112112
input_col=input_col,
113113
)
114114

models/m3gnet/metadata.yml

+56-24
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,56 @@
1-
model_name: M3GNet
2-
model_version: 2022.9.20
3-
matbench_discovery_version: 1.0
4-
date_added: "2022-09-20"
5-
authors:
6-
- name: Chi Chen
7-
affiliation: UC San Diego
8-
role: Model
9-
- name: Shyue Ping Ong
10-
affiliation: UC San Diego
11-
orcid: https://orcid.org/0000-0001-5726-2587
12-
13-
repo: https://github.com/materialsvirtuallab/m3gnet
14-
url: https://materialsvirtuallab.github.io/m3gnet
15-
doi: https://doi.org/10.1038/s43588-022-00349-3
16-
preprint: https://arxiv.org/abs/2202.02450
17-
requirements:
18-
m3gnet: 0.1.0
19-
pymatgen: 2022.10.22
20-
numpy: 1.24.0
21-
pandas: 1.5.1
22-
trained_on_benchmark: false
23-
notes:
24-
training: Using pre-trained model released with paper. Was only trained on a subset of 62,783 MP relaxation trajectories in the 2018 database release (see [related issue](https://github.com/materialsvirtuallab/m3gnet/issues/20#issuecomment-1207087219)).
1+
- model_name: M3GNet
2+
model_version: 2022.9.20
3+
matbench_discovery_version: 1.0
4+
date_added: "2022-09-20"
5+
authors:
6+
- name: Chi Chen
7+
affiliation: UC San Diego
8+
role: Model
9+
- name: Shyue Ping Ong
10+
affiliation: UC San Diego
11+
orcid: https://orcid.org/0000-0001-5726-2587
12+
13+
repo: https://github.com/materialsvirtuallab/m3gnet
14+
url: https://materialsvirtuallab.github.io/m3gnet
15+
doi: https://doi.org/10.1038/s43588-022-00349-3
16+
preprint: https://arxiv.org/abs/2202.02450
17+
requirements:
18+
m3gnet: 0.1.0
19+
pymatgen: 2022.10.22
20+
numpy: 1.24.0
21+
pandas: 1.5.1
22+
trained_on_benchmark: false
23+
notes:
24+
training: Using pre-trained model released with paper. Was only trained on a subset of 62,783 MP relaxation trajectories in the 2018 database release (see [related issue](https://github.com/materialsvirtuallab/m3gnet/issues/20#issuecomment-1207087219)).
25+
26+
- model_name: M3GNet + MEGNet
27+
model_version: 2022.9.20
28+
matbench_discovery_version: 1.0
29+
date_added: "2023-02-03"
30+
authors:
31+
- name: Chi Chen
32+
affiliation: UC San Diego
33+
role: Model
34+
- name: Weike Ye
35+
affiliation: UC San Diego
36+
- name: Yunxing Zuo
37+
affiliation: UC San Diego
38+
- name: Chen Zheng
39+
affiliation: UC San Diego
40+
- name: Shyue Ping Ong
41+
affiliation: UC San Diego
42+
orcid: https://orcid.org/0000-0001-5726-2587
43+
44+
repo: https://github.com/materialsvirtuallab/m3gnet
45+
url: https://materialsvirtuallab.github.io/m3gnet
46+
doi: https://doi.org/10.1038/s43588-022-00349-3
47+
preprint: https://arxiv.org/abs/2202.02450
48+
requirements:
49+
m3gnet: 0.1.0
50+
megnet: 1.3.2
51+
pymatgen: 2022.10.22
52+
numpy: 1.24.0
53+
pandas: 1.5.1
54+
trained_on_benchmark: false
55+
notes:
56+
training: Using pre-trained model released with paper. Was only trained on a subset of 62,783 MP relaxation trajectories in the 2018 database release (see [related issue](https://github.com/materialsvirtuallab/m3gnet/issues/20#issuecomment-1207087219)).

scripts/compile_metrics.py

+15-17
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222

2323
# %%
24+
model_stats: dict[str, dict[str, str | int | float]] = {}
2425
models: dict[str, dict[str, Any]] = {
2526
"CGCNN": dict(
2627
n_runs=10,
@@ -57,7 +58,7 @@
5758
display_name={"$regex": "m3gnet-wbm-IS2RE"},
5859
),
5960
),
60-
"BOWSR MEGNet": dict(
61+
"BOWSR + MEGNet": dict(
6162
n_runs=500,
6263
filters=dict(
6364
created_at={"$gt": "2023-01-20", "$lt": "2023-01-22"},
@@ -66,15 +67,14 @@
6667
),
6768
}
6869

69-
assert set(models) == set(PRED_FILENAMES), f"{set(models)=} != {set(PRED_FILENAMES)=}"
70-
71-
72-
model_stats: dict[str, dict[str, str | int | float]] = {}
70+
assert not (
71+
unknown_models := set(models) - set(PRED_FILENAMES)
72+
), f"{unknown_models=} missing predictions file"
7373

7474

7575
# %% calculate total model run times from wandb logs
7676
# NOTE these model run times are pretty meaningless since some models were run on GPU
77-
# (Wrenformer and CGCNN), others on CPU. Also BOWSR MEGNet, M3GNet and MEGNet weren't
77+
# (Wrenformer and CGCNN), others on CPU. Also BOWSR + MEGNet, M3GNet and MEGNet weren't
7878
# trained from scratch. Their run times only indicate the time needed to predict the
7979
# test set.
8080

@@ -110,24 +110,23 @@
110110
title=f"Run time distribution for {model}", xlabel="Run time [h]", ylabel="Count"
111111
)
112112

113+
model_stats["M3GNet + MEGNet"] = model_stats["M3GNet"].copy()
114+
model_stats["M3GNet + MEGNet"][time_col] = (
115+
model_stats["MEGNet"][time_col] + model_stats["M3GNet"][time_col] # type: ignore
116+
)
117+
113118
df_metrics = pd.DataFrame(model_stats).T
114119
df_metrics.index.name = "Model"
115-
# on 2022-11-28:
116-
# run_times = {'Voronoi Random Forest': 739608,
117-
# 'Wrenformer': 208399,
118-
# 'MEGNet': 12396,
119-
# 'M3GNet': 301138,
120-
# 'BOWSR MEGNet': 9105237}
121120

122121

123122
# %%
124-
df_wbm = load_df_wbm_preds(list(models))
123+
df_wbm = load_df_wbm_preds(list(model_stats))
125124
e_form_col = "e_form_per_atom_mp2020_corrected"
126125
each_true_col = "e_above_hull_mp2020_corrected_ppd_mp"
127126

128127

129128
# %%
130-
for model in models:
129+
for model in model_stats:
131130
each_pred = df_wbm[each_true_col] + df_wbm[model] - df_wbm[e_form_col]
132131

133132
metrics = stable_metrics(df_wbm[each_true_col], each_pred)
@@ -165,12 +164,11 @@
165164
}
166165
df_styled.set_table_styles([dict(selector=sel, props=styles[sel]) for sel in styles])
167166

168-
html_path = f"{FIGS}/{today}-metrics-table.svelte"
169-
df_styled.to_html(html_path)
167+
# df_styled.to_html(f"{FIGS}/{today}-metrics-table.svelte")
170168

171169

172170
# %% write model metrics to json for use by the website
173-
df_metrics["missing_preds"] = df_wbm[list(models)].isna().sum()
171+
df_metrics["missing_preds"] = df_wbm[list(model_stats)].isna().sum()
174172
df_metrics["missing_percent"] = [
175173
f"{x / len(df_wbm):.2%}" for x in df_metrics.missing_preds
176174
]

scripts/rolling_mae_vs_hull_dist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
# %%
1111
# model = "Wrenformer"
12-
model = "M3GNet MEGNet"
12+
model = "M3GNet + MEGNet"
1313
ax, df_err, df_std = rolling_mae_vs_hull_dist(
1414
e_above_hull_true=df_wbm[each_true_col],
1515
e_above_hull_errors={model: df_wbm[e_form_col] - df_wbm[model]},

site/src/app.css

-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ ul {
8181
padding-left: 1em;
8282
}
8383
label {
84-
font-weight: bold;
8584
cursor: pointer;
8685
}
8786
img {

0 commit comments

Comments
 (0)