Skip to content

Commit 8166801

Browse files
committed
include legacy MP energy corrections in data/wbm/2022-10-19-wbm-summary.csv, use them to remove old and apply new corrections in test_megnet.py
1 parent da39074 commit 8166801

File tree

8 files changed

+254458
-254416
lines changed

8 files changed

+254458
-254416
lines changed

data/wbm/fetch_process_wbm_dataset.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from aviary.wren.utils import get_aflow_label_from_spglib
1111
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram
1212
from pymatgen.core import Composition, Structure
13-
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
13+
from pymatgen.entries.compatibility import (
14+
MaterialsProject2020Compatibility,
15+
MaterialsProjectCompatibility,
16+
)
1417
from pymatgen.entries.computed_entries import ComputedStructureEntry
1518
from pymatviz import density_scatter
1619
from pymatviz.utils import save_fig
@@ -184,7 +187,6 @@ def increment_wbm_material_id(wbm_id: str) -> str:
184187
cse_step_paths = sorted(glob(f"{module_dir}/raw/wbm-cse-step-*.json.bz2"))
185188
assert len(cse_step_paths) == 5
186189

187-
188190
"""
189191
There is a discrepancy of 6 entries between the files on Materials Cloud containing the
190192
ComputedStructureEntries (CSE) and those on Google Drive containing initial+relaxed
@@ -496,10 +498,25 @@ def fix_bad_struct_index_mismatch(material_id: str) -> str:
496498
assert all(df_summary.n_sites == [len(cse.structure) for cse in df_wbm.cse])
497499

498500

499-
compat_out = MaterialsProject2020Compatibility().process_entries(
500-
entries=df_wbm.cse, clean=True, verbose=True
501+
# entries are corrected in-place by default so we apply legacy corrections first
502+
# and then leave the new corrections in place below
503+
# having both old and new corrections allows updating predictions from older models
504+
# like MEGNet that were trained on MP release prior to new corrections by subtracting
505+
# old corrections and adding the new ones
506+
entries_old_corr = MaterialsProjectCompatibility().process_entries(
507+
df_wbm.cse, clean=True, verbose=True
508+
)
509+
assert len(entries_old_corr) == len(df_wbm), f"{len(entries_old_corr)=} {len(df_wbm)=}"
510+
511+
# store legacy MP energy corrections (per atom) in df_wbm
512+
e_correction_col = "e_correction_per_atom_mp_legacy"
513+
df_wbm[e_correction_col] = [cse.correction_per_atom for cse in df_wbm.cse]
514+
515+
# clean up legacy corrections and apply new corrections
516+
entries_new_corr = MaterialsProject2020Compatibility().process_entries(
517+
df_wbm.cse, clean=True, verbose=True
501518
)
502-
assert len(compat_out) == len(df_wbm) == len(df_summary)
519+
assert len(entries_new_corr) == len(df_wbm), f"{len(entries_new_corr)=} {len(df_wbm)=}"
503520

504521
n_corrected = sum(cse.uncorrected_energy != cse.energy for cse in df_wbm.cse)
505522
assert n_corrected == 100_930, f"{n_corrected=} expected 100,930"

matbench_discovery/plots.py

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
m3gnet_megnet="M3GNet + MEGNet",
5656
m3gnet="M3GNet",
5757
megnet="MEGNet",
58+
megnet_old="MEGNet Old",
5859
voronoi_rf="Voronoi Random Forest",
5960
wrenformer="Wrenformer",
6061
dft="DFT",

models/m3gnet/2022-10-31-m3gnet-wbm-IS2RE.csv

+254,395-254,395
Large diffs are not rendered by default.

models/m3gnet/join_m3gnet_results.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from tqdm import tqdm
2020

2121
from matbench_discovery import today
22-
from matbench_discovery.data import DATA_FILES, as_dict_handler
22+
from matbench_discovery.data import DATA_FILES, as_dict_handler, df_wbm
2323
from matbench_discovery.energy import get_e_form_per_atom
2424

2525
__author__ = "Janosh Riebesell"
@@ -124,7 +124,15 @@
124124
except Exception as exc:
125125
print(f"Failed to predict {material_id=}: {exc}")
126126

127-
df_m3gnet["e_form_per_atom_m3gnet_megnet"] = pd.Series(megnet_e_form_preds)
127+
pred_col_megnet = "e_form_per_atom_m3gnet_megnet"
128+
df_m3gnet[f"{pred_col_megnet}_old"] = pd.Series(megnet_e_form_preds)
129+
# remove legacy MP corrections that MEGNet was trained on and apply newer MP2020
130+
# corrections instead
131+
df_m3gnet[pred_col_megnet] = (
132+
df_m3gnet[f"{pred_col_megnet}_old"]
133+
- df_wbm.e_correction_per_atom_mp_legacy
134+
+ df_wbm.e_correction_per_atom_mp2020
135+
)
128136

129137
assert (
130138
n_isna := df_m3gnet.e_form_per_atom_m3gnet_megnet.isna().sum()
@@ -145,5 +153,5 @@
145153
df_m3gnet.select_dtypes("number").to_csv(out_path.replace(".json.gz", ".csv"))
146154

147155
# in_path = f"{module_dir}/2022-10-31-m3gnet-wbm-IS2RE.json.gz"
148-
# df_m3gnet_csv = pd.read_csv(in_path.replace(".json.gz", ".csv"))
156+
# df_m3gnet = pd.read_csv(in_path.replace(".json.gz", ".csv")).set_index("material_id")
149157
# df_m3gnet = pd.read_json(in_path).set_index("material_id")

models/megnet/test_megnet.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import pandas as pd
1515
import wandb
1616
from megnet.utils.models import load_model
17+
from pymatgen.core import Structure
18+
from pymatgen.entries.computed_entries import ComputedStructureEntry
1719
from sklearn.metrics import r2_score
1820
from tqdm import tqdm
1921

@@ -75,24 +77,19 @@
7577

7678
# %%
7779
if task_type == "IS2RE":
78-
from pymatgen.core import Structure
79-
8080
structures = df_wbm_structs.initial_structure.map(Structure.from_dict)
8181
elif task_type == "RS2RE":
82-
from pymatgen.entries.computed_entries import ComputedStructureEntry
83-
8482
df_wbm_structs.cse = df_wbm_structs.cse.map(ComputedStructureEntry.from_dict)
8583
structures = df_wbm_structs.cse.map(lambda x: x.structure)
8684
else:
8785
raise ValueError(f"Unknown {task_type = }")
8886

8987
megnet_e_form_preds = {}
90-
for material_id, structure in tqdm(
91-
structures.items(), disable=None, total=len(structures)
92-
):
88+
for material_id in tqdm(structures, disable=None):
9389
if material_id in megnet_e_form_preds:
9490
continue
9591
try:
92+
structure = structures[material_id]
9693
e_form_per_atom = megnet_mp_e_form.predict_structure(structure)[0]
9794
megnet_e_form_preds[material_id] = e_form_per_atom
9895
except Exception as exc:
@@ -104,9 +101,23 @@
104101
print(f"{len(structures)=:,}")
105102
print(f"missing: {len(structures) - len(megnet_e_form_preds):,}")
106103
pred_col = "e_form_per_atom_megnet"
107-
df_wbm[pred_col] = pd.Series(megnet_e_form_preds)
104+
# old column holds raw predictions of MEGNet, which was trained on legacy-corrected
105+
# MP formation energies
106+
df_wbm[f"{pred_col}_old"] = pd.Series(megnet_e_form_preds)
107+
108+
# remove legacy MP corrections that MEGNet was trained on and apply newer MP2020
109+
# corrections instead
110+
df_wbm[pred_col] = (
111+
df_wbm[pred_col]
112+
- df_wbm.e_correction_per_atom_mp_legacy
113+
+ df_wbm.e_correction_per_atom_mp2020
114+
)
115+
116+
df_wbm.filter(like=pred_col).round(4).to_csv(
117+
"2022-11-18-megnet-wbm-IS2RE/megnet-e-form-preds.csv"
118+
)
108119

109-
df_wbm[pred_col].round(4).to_csv(out_path)
120+
# df_megnet = pd.read_csv(f"{ROOT}/models/{PRED_FILES.megnet}").set_index("material_id")
110121

111122

112123
# %%

scripts/rolling_mae_vs_hull_dist.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# %%
1414
# model = "Wrenformer"
1515
model = "M3GNet + MEGNet"
16+
model = "MEGNet"
17+
model = "MEGNet Old"
1618
ax, df_err, df_std = rolling_mae_vs_hull_dist(
1719
e_above_hull_true=df_wbm[each_true_col],
1820
e_above_hull_errors={model: df_wbm[e_form_col] - df_wbm[model]},
@@ -21,8 +23,8 @@
2123
# template="plotly_white",
2224
)
2325

24-
MAE, DAF = df_metrics[model].MAE, df_metrics[model].DAF
25-
title = f"{today} {model} · {MAE=:.2f} · {DAF=:.2f}"
26+
MAE, DAF, F1 = df_metrics[model][["MAE", "DAF", "F1"]]
27+
title = f"{today} {model} · {MAE=:.2f} · {DAF=:.2f} · {F1=:.2f}"
2628
if backend == "matplotlib":
2729
fig = ax.figure
2830
fig.set_size_inches(6, 5)

site/src/routes/contribute/+page.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ assert sorted(DATA_FILES) == [
3535

3636
df_wbm = load_train_test("wbm-summary", version="v1.0.0")
3737

38-
assert df_wbm.shape == (256963, 14)
38+
assert df_wbm.shape == (256963, 15)
3939

4040
assert list(df_wbm) == [
4141
"formula",
@@ -47,6 +47,7 @@ assert list(df_wbm) == [
4747
"bandgap_pbe",
4848
"uncorrected_energy_from_cse",
4949
"e_correction_per_atom_mp2020",
50+
"e_correction_per_atom_mp_legacy",
5051
"e_above_hull_mp2020_corrected_ppd_mp",
5152
"e_form_per_atom_uncorrected",
5253
"e_form_per_atom_mp2020_corrected",
@@ -65,6 +66,8 @@ assert list(df_wbm) == [
6566
1. `bandgap_pbe`: PBE-level DFT band gap from [WBM paper]
6667
1. `uncorrected_energy_from_cse`: Should be the same as `uncorrected_energy`. There are 2 cases where the absolute difference between the value reported in the summary file and the one in the computed structure entries exceeds 0.1 eV (`wbm-2-3218`, `wbm-1-56320`), which we attribute to rounding errors.
6768
1. `e_form_per_atom_mp2020_corrected`: Matbench Discovery takes these as ground truth for the formation energy. Includes MP2020 energy corrections (latest correction scheme at time of release).
69+
1. `e_correction_per_atom_mp2020`: [`MaterialsProject2020Compatibility`](https://pymatgen.org/pymatgen.entries.compatibility.html#pymatgen.entries.compatibility.MaterialsProject2020Compatibility) energy corrections in eV/atom.
70+
1. `e_correction_per_atom_mp_legacy`: Legacy [`MaterialsProjectCompatibility`](https://pymatgen.org/pymatgen.entries.compatibility.html#pymatgen.entries.compatibility.MaterialsProjectCompatibility) energy corrections in eV/atom. Having both old and new corrections allows updating predictions from older models like MEGNet that were trained on MP formation energies treated with the old correction scheme.
6871
1. `e_above_hull_mp2020_corrected_ppd_mp`: Energy above hull distances in eV/atom after applying the MP2020 correction scheme. The convex hull in question is the one spanned by all ~145k Materials Project `ComputedStructureEntries`. Matbench Discovery takes these as ground truth for material stability. Any value above 0 is assumed to be an unstable/metastable material.
6972
<!-- TODO document remaining columns, or maybe drop them from df -->
7073

tests/test_data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def as_dict(self) -> dict[str, Any]:
173173

174174

175175
def test_df_wbm() -> None:
176-
assert df_wbm.shape == (256963, 14)
176+
assert df_wbm.shape == (256963, 15)
177177
assert df_wbm.index.name == "material_id"
178178
assert set(df_wbm) > {"bandgap_pbe", "formula", "material_id"}
179179

0 commit comments

Comments
 (0)