Skip to content

Commit c8eaebd

Browse files
committed
support model_name as array in metadata.yml files to share metadata between multiple types of a model
add chgnet + chgnet_megnet to PredFiles; update model-stats.json, metrics-table.svelte, model-run-times.svelte; show only n_best=8 models by default on model page and sort descending by F1
1 parent d564ade commit c8eaebd

22 files changed

+506580
-254700
lines changed

matbench_discovery/plots.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,13 @@
5050
)
5151
model_labels = dict(
5252
bowsr_megnet="BOWSR + MEGNet",
53+
chgnet="CHGNet",
54+
chgnet_megnet="CHGNet + MEGNet",
5355
cgcnn_p="CGCNN+P",
5456
cgcnn="CGCNN",
5557
m3gnet_megnet="M3GNet + MEGNet",
5658
m3gnet="M3GNet",
5759
megnet="MEGNet",
58-
megnet_old="MEGNet Old",
5960
voronoi_rf="Voronoi RF",
6061
wrenformer="Wrenformer",
6162
dft="DFT",

matbench_discovery/preds.py

+11-15
Original file line numberDiff line numberDiff line change
@@ -29,29 +29,25 @@ class PredFiles(Files):
2929
_root = f"{ROOT}/models/"
3030
_key_map = model_labels # remap model keys below to pretty plot labels (see Files)
3131

32+
# bowsr optimizer coupled with original megnet
33+
bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv"
34+
# default CHGNet model from publication with 400,438 params
35+
chgnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
36+
chgnet_megnet = "chgnet/2023-03-04-chgnet-wbm-IS2RE.csv"
3237
# CGCnn 10-member ensemble
3338
cgcnn = "cgcnn/2023-01-26-test-cgcnn-wbm-IS2RE/cgcnn-ensemble-preds.csv"
34-
3539
# cgcnn 10-member ensemble with 5-fold training set perturbations
3640
cgcnn_p = "cgcnn/2023-02-05-cgcnn-perturb=5.csv"
37-
38-
# magpie composition+voronoi tessellation structure features + sklearn random forest
39-
voronoi_rf = "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv"
40-
41-
# wrenformer 10-member ensemble
42-
wrenformer = "wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
43-
44-
# original megnet straight from publication, not re-trained
45-
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
46-
megnet_old = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
47-
4841
# original m3gnet straight from publication, not re-trained
4942
m3gnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
50-
5143
# m3gnet-relaxed structures fed into megnet for formation energy prediction
5244
m3gnet_megnet = "m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv"
53-
# bowsr optimizer coupled with original megnet
54-
bowsr_megnet = "bowsr/2023-01-23-bowsr-megnet-wbm-IS2RE.csv"
45+
# original megnet straight from publication, not re-trained
46+
megnet = "megnet/2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
47+
# magpie composition+voronoi tessellation structure features + sklearn random forest
48+
voronoi_rf = "voronoi/2022-11-27-train-test/e-form-preds-IS2RE.csv"
49+
# wrenformer 10-member ensemble
50+
wrenformer = "wrenformer/2022-11-15-wrenformer-IS2RE-preds.csv"
5551

5652

5753
PRED_FILES = PredFiles()

models/chgnet/2023-03-04-chgnet-wbm-IS2RE.csv

+251,739
Large diffs are not rendered by default.

models/chgnet/join_chgnet_results.py

+22-28
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
from matbench_discovery import today
2222
from matbench_discovery.data import DATA_FILES, as_dict_handler
2323
from matbench_discovery.energy import get_e_form_per_atom
24-
from matbench_discovery.preds import df_wbm as df_summary
25-
from matbench_discovery.preds import e_form_col
24+
from matbench_discovery.preds import df_wbm, e_form_col
2625

2726
__author__ = "Janosh Riebesell"
2827
__date__ = "2023-03-01"
@@ -55,52 +54,41 @@
5554

5655

5756
# %%
58-
df_wbm = pd.read_json(DATA_FILES.wbm_computed_structure_entries).set_index(
57+
df_cse = pd.read_json(DATA_FILES.wbm_computed_structure_entries).set_index(
5958
"material_id"
6059
)
6160

62-
df_wbm["cse"] = [
63-
ComputedStructureEntry.from_dict(x) for x in tqdm(df_wbm.computed_structure_entry)
61+
df_cse["cse"] = [
62+
ComputedStructureEntry.from_dict(x) for x in tqdm(df_cse.computed_structure_entry)
6463
]
6564

6665

67-
# %% transfer chgnet energies and relaxed structures WBM CSEs
66+
# %% transfer CHGNet energies and relaxed structures to WBM CSEs since MP2020 energy
67+
# corrections applied below are structure-dependent (for oxides and sulfides)
6868
cse: ComputedStructureEntry
6969
for row in tqdm(df_chgnet.itertuples(), total=len(df_chgnet)):
7070
mat_id, struct_dict, chgnet_energy, *_ = row
7171
chgnet_struct = Structure.from_dict(struct_dict)
72-
cse = df_wbm.loc[mat_id, "cse"]
72+
cse = df_cse.loc[mat_id, "cse"]
7373
cse._energy = chgnet_energy # cse._energy is the uncorrected energy
7474
cse._structure = chgnet_struct
7575
df_chgnet.loc[mat_id, "cse"] = cse
7676

7777

78-
# %%
79-
df_chgnet["e_form_per_atom_chgnet_uncorrected"] = [
80-
get_e_form_per_atom(cse) for cse in tqdm(df_chgnet.cse)
81-
]
82-
83-
84-
# %% apply energy corrections
78+
# %% apply energy corrections to CSEs with CHGNet
8579
out = MaterialsProject2020Compatibility().process_entries(
8680
df_chgnet.cse, verbose=True, clean=True
8781
)
8882
assert len(out) == len(df_chgnet)
8983

9084

9185
# %% compute corrected formation energies
92-
df_chgnet["e_form_per_atom_chgnet"] = [
93-
get_e_form_per_atom(cse) for cse in tqdm(df_chgnet.cse)
94-
]
95-
96-
df_chgnet[e_form_col] = df_summary[e_form_col]
86+
e_form_chgnet_col = "e_form_per_atom_chgnet"
87+
df_chgnet[e_form_chgnet_col] = [get_e_form_per_atom(cse) for cse in tqdm(df_chgnet.cse)]
9788

9889

9990
# %%
100-
ax = density_scatter(
101-
df=df_chgnet, x="e_form_per_atom_chgnet", y="e_form_per_atom_chgnet_uncorrected"
102-
)
103-
ax = density_scatter(df=df_chgnet, x="e_form_per_atom_chgnet", y=e_form_col)
91+
ax = density_scatter(x=df_wbm[e_form_col], y=df_chgnet[e_form_chgnet_col])
10492

10593

10694
# %% load 2019 MEGNet formation energy model
@@ -109,7 +97,7 @@
10997

11098

11199
# %% predict formation energies on chgnet relaxed structure with MEGNet
112-
for material_id, cse in tqdm(df_wbm.cse.items(), total=len(df_wbm)):
100+
for material_id, cse in tqdm(df_cse.cse.items(), total=len(df_cse)):
113101
if material_id in megnet_e_form_preds:
114102
continue
115103
try:
@@ -119,17 +107,23 @@
119107
except Exception as exc:
120108
print(f"Failed to predict {material_id=}: {exc}")
121109

122-
df_chgnet["e_form_per_atom_chgnet_megnet"] = pd.Series(megnet_e_form_preds)
110+
e_form_megnet_col = "e_form_per_atom_chgnet_megnet"
111+
# remove legacy MP corrections that MEGNet was trained on and apply newer MP2020
112+
# corrections instead
113+
df_chgnet[e_form_megnet_col] = (
114+
pd.Series(megnet_e_form_preds)
115+
- df_wbm.e_correction_per_atom_mp_legacy
116+
+ df_wbm.e_correction_per_atom_mp2020
117+
)
123118

124119
assert (
125120
n_isna := df_chgnet.e_form_per_atom_chgnet_megnet.isna().sum()
126121
) < 10, f"{n_isna=}, expected 7 or similar"
127122

128123

129124
# %%
130-
ax = density_scatter(
131-
df=df_chgnet, x="e_form_per_atom_chgnet_megnet", y="e_form_per_atom_chgnet"
132-
)
125+
ax = density_scatter(df=df_chgnet, x=e_form_chgnet_col, y=e_form_megnet_col)
126+
ax = density_scatter(df=df_chgnet, x=e_form_col, y=e_form_megnet_col)
133127

134128

135129
# %%

models/chgnet/metadata.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
model_name: CHGNet
1+
model_name: [CHGNet, CHGNet + MEGNet]
22
model_version: 0.0.1
33
matbench_discovery_version: 1.0
44
date_added: "2023-03-03"
@@ -36,5 +36,5 @@ trained_on_benchmark: false
3636
notes:
3737
description: |
3838
The Crystal Hamiltonian Graph Neural Network (CHGNet) is a universal GNN-based interatomic potential trained on energies, forces, stresses and magnetic moments from the MP trajectory dataset containing ∼1.5 million inorganic structures.
39-
![CHGNet Pipeline](https://user-images.githubusercontent.com/30958850/222842305-b6ed2468-8773-4e03-9de5-20c8e8de030e.svg)
40-
training: Using pre-trained model released with preprint. Training set unreleased until after review.
39+
![CHGNet Pipeline](https://user-images.githubusercontent.com/30958850/222924937-1d09bbce-ee18-4b19-8061-ec689cd15887.svg)
40+
training: Using pre-trained model with 400,438 params released with preprint. Training set unreleased at time of writing.

models/chgnet/test_chgnet.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@
3939
slurm_vars = slurm_submit(
4040
job_name=job_name,
4141
out_dir=out_dir,
42-
partition="icelake-himem",
43-
account="LEE-SL3-CPU",
42+
partition="ampere",
43+
account="LEE-SL3-GPU",
4444
time="3:0:0",
45-
array=f"1-{slurm_array_task_count}",
46-
slurm_flags=("--mem", str(12_000)),
45+
# array=f"1-{slurm_array_task_count}",
46+
slurm_flags="--nodes 1 --gpus-per-node 1",
4747
)
4848

4949

@@ -104,14 +104,12 @@
104104
except Exception as error:
105105
print(f"Failed to relax {material_id}: {error}")
106106
continue
107-
relax_dict = {
107+
relax_results[material_id] = {
108108
"chgnet_structure": relax_result["final_structure"],
109109
"chgnet_trajectory": relax_result["trajectory"].__dict__,
110-
e_pred_col: relax_result["energies"][-1],
110+
e_pred_col: relax_result["trajectory"].energies[-1],
111111
}
112112

113-
relax_results[material_id] = relax_dict
114-
115113

116114
# %%
117115
df_out = pd.DataFrame(relax_results).T
@@ -123,7 +121,9 @@
123121
# %%
124122
df_wbm[e_pred_col] = df_out[e_pred_col]
125123
table = wandb.Table(
126-
dataframe=df_wbm[["uncorrected_energy", e_pred_col, "formula"]].reset_index()
124+
dataframe=df_wbm.dropna()[
125+
["uncorrected_energy", e_pred_col, "formula"]
126+
].reset_index()
127127
)
128128

129129
title = f"CHGNet {task_type} ({len(df_wbm):,})"

0 commit comments

Comments
 (0)