Skip to content

Commit 2517855

Browse files
committed
make 2nd arg elemental_ref_entries to get_form_energy_per_atom() optional
now defaults to mp_elem_reference_entries
1 parent 0473994 commit 2517855

9 files changed

+96
-81
lines changed

mb_discovery/build_phase_diagram.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -89,22 +89,14 @@
8989
json.dump(elemental_ref_entries, file, default=lambda x: x.as_dict())
9090

9191

92-
# %% load MP elemental reference entries to compute formation energies
93-
mp_elem_refs_path = f"{ROOT}/data/2022-09-19-mp-elemental-reference-entries.json"
94-
mp_reference_entries = (
95-
pd.read_json(mp_elem_refs_path, typ="series").map(ComputedEntry.from_dict).to_dict()
96-
)
97-
98-
9992
df_mp = pd.read_json(f"{ROOT}/data/2022-08-13-mp-all-energies.json.gz").set_index(
10093
"material_id"
10194
)
10295

10396

10497
# %%
10598
df_mp["our_mp_e_form"] = [
106-
get_form_energy_per_atom(all_mp_computed_entries[mp_id], mp_reference_entries)
107-
for mp_id in df_mp.index
99+
get_form_energy_per_atom(all_mp_computed_entries[mp_id]) for mp_id in df_mp.index
108100
]
109101

110102

mb_discovery/compute_formation_energy.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import itertools
22

3+
import pandas as pd
34
from pymatgen.analysis.phase_diagram import Entry
5+
from pymatgen.entries.computed_entries import ComputedEntry
46
from tqdm import tqdm
57

8+
from mb_discovery import ROOT
9+
610

711
def get_elemental_ref_entries(
812
entries: list[Entry], verbose: bool = False
@@ -38,12 +42,46 @@ def get_elemental_ref_entries(
3842
return elemental_ref_entries
3943

4044

45+
# contains all MP elemental reference entries to compute formation energies
46+
# produced by get_elemental_ref_entries() in build_phase_diagram.py
47+
mp_elem_refs_path = f"{ROOT}/data/2022-09-19-mp-elemental-reference-entries.json"
48+
try:
49+
mp_elem_reference_entries = (
50+
pd.read_json(mp_elem_refs_path, typ="series")
51+
.map(ComputedEntry.from_dict)
52+
.to_dict()
53+
)
54+
except FileNotFoundError:
55+
mp_elem_reference_entries = None
56+
57+
4158
def get_form_energy_per_atom(
42-
entry: Entry, elemental_ref_entries: dict[str, Entry]
59+
entry: Entry, elemental_ref_entries: dict[str, Entry] = None
4360
) -> float:
4461
"""Get the formation energy of an entry's composition from a dict of elemental
4562
reference energies.
63+
64+
Args:
65+
entry (Entry): pymatgen Entry (PDEntry, ComputedEntry or ComputedStructureEntry)
66+
to compute formation energy of.
67+
elemental_ref_entries (dict[str, Entry], optional): Must be a complete set of
68+
terminal (i.e. elemental) reference entries containing the lowest energy
69+
phase for each element present in entry. Defaults to MP elemental reference
70+
entries as collected on 2022-09-19 by get_elemental_ref_entries(). This was
71+
tested to give the same formation energies as computed by MP.
72+
73+
Returns:
74+
float: formation energy in eV/atom.
4675
"""
76+
if elemental_ref_entries is None:
77+
if mp_elem_reference_entries is None:
78+
raise ValueError(
79+
f"Couldn't load {mp_elem_refs_path=}, you must pass "
80+
f"{elemental_ref_entries=} explicitly."
81+
)
82+
83+
elemental_ref_entries = mp_elem_reference_entries
84+
4785
comp = entry.composition
4886
form_energy = entry.energy - sum(
4987
comp[el] * elemental_ref_entries[str(el)].energy_per_atom

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
df["e_form_per_atom_pred"] = df[pred_cols].mean(axis=1)
6666
if "m3gnet" in dfs:
6767
df = dfs["m3gnet"]
68-
df["e_form_per_atom_pred"] = df.e_form_ppd_2022_01_25
68+
df["e_form_per_atom_pred"] = df.e_form_m3gnet
6969
if "bowsr_megnet" in dfs:
7070
df = dfs["bowsr_megnet"]
7171
df["e_form_per_atom_pred"] = df.e_form_per_atom_bowsr
@@ -76,7 +76,8 @@
7676
stability_crit: StabilityCriterion = "energy"
7777
fig, axs = plt.subplots(2, 3, figsize=(18, 9))
7878

79-
df = dfs[(model_name := "bowsr_megnet")]
79+
model_name = "m3gnet"
80+
df = dfs[model_name]
8081

8182
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
8283
df["e_form_per_atom"] = df_wbm.e_form_per_atom
@@ -109,14 +110,13 @@
109110
axs.flat[-1].set(title=f"Combined ({len(df.filter(like='e_').dropna()):,})")
110111
axs.flat[0].legend(frameon=False, loc="upper left")
111112

112-
img_name = (
113-
f"{today}-{model_name}-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}.pdf"
114-
)
115-
fig.suptitle(img_name.replace("-", "/", 2).replace("-", " "), y=1.07, fontsize=16)
113+
img_name = f"{today}-{model_name}-wbm-hull-dist-hist-{which_energy=}-{stability_crit=}"
114+
suptitle = img_name.replace("-", "/", 2).replace("-", " ")
115+
fig.suptitle(suptitle, y=1.07, fontsize=16)
116116

117117

118118
# %%
119-
ax.figure.savefig(f"{ROOT}/figures/{img_name}")
119+
ax.figure.savefig(f"{ROOT}/figures/{img_name}.pdf")
120120

121121

122122
# %%

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,21 @@
1717
df_hull = pd.read_csv(f"{DATA_DIR}/wbm-e-above-mp-hull.csv").set_index("material_id")
1818

1919
dfs: dict[str, pd.DataFrame] = {}
20-
for model_name in ("Wren", "CGCNN", "Voronoi"):
21-
df = pd.read_csv(
22-
f"{DATA_DIR}/{model_name.lower()}-mp-initial-structures.csv"
23-
).set_index("material_id")
20+
for model_name in ("wren", "cgcnn", "voronoi"):
21+
csv_path = f"{DATA_DIR}/{model_name}-mp-initial-structures.csv"
22+
df = pd.read_csv(csv_path).set_index("material_id")
2423
dfs[model_name] = df
2524

26-
dfs["M3GNet"] = pd.read_json(
25+
dfs["m3gnet"] = pd.read_json(
2726
f"{ROOT}/models/m3gnet/2022-08-16-m3gnet-wbm-IS2RE.json.gz"
2827
).set_index("material_id")
2928

30-
dfs["Wrenformer"] = pd.read_csv(
29+
dfs["wrenformer"] = pd.read_csv(
3130
f"{ROOT}/models/wrenformer/mp/"
3231
"2022-09-20-wrenformer-e_form-ensemble-1-preds-e_form_per_atom.csv"
3332
).set_index("material_id")
3433

35-
dfs["BOWSR Megnet"] = pd.read_json(
34+
dfs["bowsr_megnet"] = pd.read_json(
3635
f"{ROOT}/models/bowsr/2022-09-22-bowsr-wbm-megnet-IS2RE.json.gz"
3736
).set_index("material_id")
3837

@@ -69,16 +68,16 @@
6968
std_total = None
7069

7170
try:
72-
if model_name == "M3GNet":
71+
if model_name == "m3gnet":
7372
model_preds = df.e_form_m3gnet
74-
elif "Wrenformer" in model_name:
73+
elif "wrenformer" in model_name:
7574
model_preds = df.e_form_per_atom_pred_ens
7675
elif len(pred_cols := df.filter(like="e_form_pred").columns) >= 1:
7776
# Voronoi+RF has single prediction column, Wren and CGCNN each have 10
7877
# other cases are unexpected
7978
assert len(pred_cols) in (1, 10), f"{model_name=} has {len(pred_cols)=}"
8079
model_preds = df[pred_cols].mean(axis=1)
81-
elif "BOWSR" in model_name:
80+
elif "bowsr" in model_name:
8281
model_preds = df.e_form_per_atom_bowsr
8382
else:
8483
raise ValueError(f"Unhandled {model_name = }")
@@ -107,7 +106,9 @@
107106
# keep this outside loop so all model names appear in legend
108107
ax.legend(frameon=False, loc="lower right")
109108

109+
img_name = f"{today}-precision-recall-vs-calc-count-{rare=}"
110+
ax.set(title=img_name.replace("-", "/", 2).replace("-", " ").title())
111+
110112

111113
# %%
112-
img_path = f"{ROOT}/figures/{today}-precision-recall-vs-calc-count-{rare=}.pdf"
113-
ax.figure.savefig(img_path)
114+
ax.figure.savefig(f"{ROOT}/figures/{img_name}.pdf")

mb_discovery/plot_scripts/rolling_mae_vs_hull_dist_wbm_batches.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,24 @@
1515
# %%
1616
rare = "all"
1717

18-
df_wbm = pd.read_csv(
18+
df_wren = pd.read_csv(
1919
f"{ROOT}/data/2022-06-11-from-rhys/wren-mp-initial-structures.csv"
2020
).set_index("material_id")
2121

2222
df_hull = pd.read_csv(
2323
f"{ROOT}/data/2022-06-11-from-rhys/wbm-e-above-mp-hull.csv"
2424
).set_index("material_id")
2525

26-
df_wbm["e_above_mp_hull"] = df_hull.e_above_mp_hull
27-
assert df_wbm.e_above_mp_hull.isna().sum() == 0
26+
df_wren["e_above_mp_hull"] = df_hull.e_above_mp_hull
27+
assert df_wren.e_above_mp_hull.isna().sum() == 0
2828

2929
target_col = "e_form_target"
3030

3131
# make sure we average the expected number of ensemble member predictions
32-
assert df_wbm.filter(regex=r"_pred_\d").shape[1] == 10
32+
assert df_wren.filter(regex=r"_pred_\d").shape[1] == 10
3333

34-
df_wbm["e_above_hull_pred"] = (
35-
df_wbm.filter(regex=r"_pred_\d").mean(axis=1) - df_wbm[target_col]
34+
df_wren["e_above_hull_pred"] = (
35+
df_wren.filter(regex=r"_pred_\d").mean(axis=1) - df_wren[target_col]
3636
)
3737

3838

@@ -42,7 +42,7 @@
4242
assert len(markers) == 5 # number of WBM rounds of element substitution
4343

4444
for idx, marker in enumerate(markers, 1):
45-
df = df_wbm[df_wbm.index.str.startswith(f"wbm-step-{idx}")]
45+
df = df_wren[df_wren.index.str.startswith(f"wbm-step-{idx}")]
4646
title = f"Batch {idx} ({len(df.filter(like='e_').dropna()):,})"
4747
assert 1e4 < len(df) < 1e5, print(f"{len(df) = :,}")
4848

mb_discovery/plots.py

+24-14
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,30 @@
2020

2121

2222
# --- define global plot settings
23-
px.defaults.labels = {
24-
"n_atoms": "Atom Count",
25-
"n_elems": "Element Count",
26-
"crystal_sys": "Crystal system",
27-
"spg_num": "Space group",
28-
"n_wyckoff": "Number of Wyckoff positions",
29-
"n_sites": "Lattice site count",
30-
"energy_per_atom": "Energy (eV/atom)",
31-
"e_form": "Formation energy (eV/atom)",
32-
"e_above_hull": "Energy above convex hull (eV/atom)",
33-
"e_above_hull_pred": "Predicted energy above convex hull (eV/atom)",
34-
"e_above_mp_hull": "Energy above MP convex hull (eV/atom)",
35-
"e_above_hull_error": "Error in energy above convex hull (eV/atom)",
36-
}
23+
quantity_labels = dict(
24+
n_atoms="Atom Count",
25+
n_elems="Element Count",
26+
crystal_sys="Crystal system",
27+
spg_num="Space group",
28+
n_wyckoff="Number of Wyckoff positions",
29+
n_sites="Lattice site count",
30+
energy_per_atom="Energy (eV/atom)",
31+
e_form="Formation energy (eV/atom)",
32+
e_above_hull="Energy above convex hull (eV/atom)",
33+
e_above_hull_pred="Predicted energy above convex hull (eV/atom)",
34+
e_above_mp_hull="Energy above MP convex hull (eV/atom)",
35+
e_above_hull_error="Error in energy above convex hull (eV/atom)",
36+
)
37+
model_labels = dict(
38+
wren="Wren",
39+
wrenformer="Wrenformer",
40+
m3gnet="M3GNet",
41+
bowsr_megnet="BOWSR + MEGNet",
42+
cgcnn="CGCNN",
43+
voronoi="Voronoi",
44+
wbm="WBM",
45+
)
46+
px.defaults.labels = quantity_labels | model_labels
3747

3848
pio.templates.default = "plotly_white"
3949

models/bowsr/slurm_array_bowsr_wbm.py

+1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
data_path=data_path,
8585
bayes_optim_kwargs=bayes_optim_kwargs,
8686
optimize_kwargs=optimize_kwargs,
87+
task_type=task_type,
8788
)
8889
if wandb.run is None:
8990
wandb.login()

models/m3gnet/join_m3gnet_relax_results.py

+3-31
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
# %%
22
from __future__ import annotations
33

4-
import gzip
5-
import io
64
import os
7-
import pickle
8-
import urllib.request
95
from datetime import datetime
106
from glob import glob
117

128
import pandas as pd
13-
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram, PDEntry
9+
from pymatgen.analysis.phase_diagram import PDEntry
1410
from pymatgen.core import Structure
1511
from tqdm import tqdm
1612

1713
from mb_discovery import ROOT, as_dict_handler
14+
from mb_discovery.compute_formation_energy import get_form_energy_per_atom
1815

1916
__author__ = "Janosh Riebesell"
2017
__date__ = "2022-08-16"
@@ -67,21 +64,12 @@
6764

6865

6966
# %%
70-
# 2022-01-25-ppd-mp+wbm.pkl.gz (235 MB)
71-
ppd_pickle_url = "https://figshare.com/files/36669624"
72-
zipped_file = urllib.request.urlopen(ppd_pickle_url)
73-
74-
ppd_mp_wbm: PatchedPhaseDiagram = pickle.load(
75-
io.BytesIO(gzip.decompress(zipped_file.read()))
76-
)
77-
78-
7967
pd_entries_m3gnet = [
8068
PDEntry(row.m3gnet_structure.composition, row.m3gnet_energy)
8169
for row in df_m3gnet.itertuples()
8270
]
8371
df_m3gnet["e_form_m3gnet_from_ppd"] = [
84-
ppd_mp_wbm.get_form_energy_per_atom(x) for x in pd_entries_m3gnet
72+
get_form_energy_per_atom(entry) for entry in pd_entries_m3gnet
8573
]
8674

8775

@@ -93,22 +81,6 @@
9381
df_m3gnet["e_above_mp_hull"] = df_hull.e_above_mp_hull
9482

9583

96-
df_wbm = pd.read_csv( # download wbm-steps-summary.csv (23.31 MB)
97-
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
98-
).set_index("material_id")
99-
100-
df_m3gnet["e_form_wbm"] = df_wbm.e_form_per_atom
101-
df_m3gnet["wbm_energy"] = df_wbm.energy
102-
103-
pd_entries_wbm = [
104-
PDEntry(row.m3gnet_structure.composition, row.wbm_energy)
105-
for row in df_m3gnet.itertuples()
106-
]
107-
df_m3gnet["e_form_ppd_2022_01_25"] = [
108-
ppd_mp_wbm.get_form_energy_per_atom(x) for x in pd_entries_wbm
109-
]
110-
111-
11284
# %%
11385
df_m3gnet.hist(bins=200, figsize=(18, 12))
11486
df_m3gnet.isna().sum()

models/m3gnet/slurm_array_m3gnet_relax_wbm.py

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
slurm_job_id=slurm_job_id,
7575
slurm_array_task_id=slurm_array_task_id,
7676
data_path=data_path,
77+
task_type=task_type,
7778
)
7879
if wandb.run is None:
7980
wandb.login()

0 commit comments

Comments
 (0)