Skip to content

Commit 42a7909

Browse files
committed
add mb_discovery/build_phase_diagram.py
add missing MaterialsProject2020Compatibility processing to process_wbm_cleaned.py
1 parent e7f4582 commit 42a7909

9 files changed

+141
-73
lines changed

mb_discovery/build_phase_diagram.py

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# %%
2+
import gzip
3+
import json
4+
import os
5+
import pickle
6+
from datetime import datetime
7+
8+
import pandas as pd
9+
import pymatviz
10+
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram
11+
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
12+
from pymatgen.entries.computed_entries import ComputedEntry
13+
from pymatgen.ext.matproj import MPRester
14+
15+
from mb_discovery import ROOT
16+
from mb_discovery.compute_formation_energy import (
17+
get_elemental_ref_entries,
18+
get_form_energy_per_atom,
19+
)
20+
21+
today = f"{datetime.now():%Y-%m-%d}"
22+
module_dir = os.path.dirname(__file__)
23+
24+
25+
# %%
26+
all_mp_computed_structure_entries = MPRester().get_entries("") # run on 2022-09-16
27+
28+
# save all ComputedStructureEntries to disk
29+
pd.Series(
30+
{e.entry_id: e for e in all_mp_computed_structure_entries}
31+
).drop_duplicates().to_json( # mp-15590 appears twice so we drop_duplicates()
32+
f"{ROOT}/data/{today}-all-mp-entries.json.gz", default_handler=lambda x: x.as_dict()
33+
)
34+
35+
36+
# %%
37+
all_mp_computed_entries = (
38+
pd.read_json(f"{ROOT}/data/2022-09-16-all-mp-entries.json.gz")
39+
.set_index("material_id")
40+
.entry.map(ComputedEntry.from_dict) # drop the structure, just load ComputedEntry
41+
.to_dict()
42+
)
43+
44+
45+
print(f"{len(all_mp_computed_entries) = :,}")
46+
# len(all_mp_computed_entries) = 146,323
47+
48+
49+
# %% build phase diagram with MP entries only
50+
# all_mp_computed_entries is a dict keyed by material_id; pass the entries, not the keys
ppd_mp = PatchedPhaseDiagram(list(all_mp_computed_entries.values()))
51+
# prints:
52+
# PatchedPhaseDiagram
53+
# Covering 44805 Sub-Spaces
54+
55+
# save MP PPD to disk
56+
with gzip.open(f"{module_dir}/{today}-ppd-mp.pkl.gz", "wb") as zip_file:
57+
pickle.dump(ppd_mp, zip_file)
58+
59+
60+
# %% build phase diagram with both MP entries + WBM entries
61+
df_wbm = pd.read_json(
62+
f"{ROOT}/data/2022-06-26-wbm-cses-and-initial-structures.json.gz"
63+
).set_index("material_id")
64+
65+
wbm_computed_entries: list[ComputedEntry] = df_wbm.query("n_elements > 1").cse.map(
66+
ComputedEntry.from_dict
67+
)
68+
69+
wbm_computed_entries = MaterialsProject2020Compatibility().process_entries(
70+
wbm_computed_entries, verbose=True, clean=True
71+
)
72+
73+
n_skipped = len(df_wbm) - len(wbm_computed_entries)
74+
print(f"{n_skipped:,} ({n_skipped / len(df_wbm):.1%}) entries not processed")
75+
76+
77+
# %% merge MP and WBM entries into a single PatchedPhaseDiagram
78+
mp_wbm_ppd = PatchedPhaseDiagram(
79+
wbm_computed_entries + all_mp_computed_entries, verbose=True
80+
)
81+
82+
83+
# %% compute terminal reference entries across all MP (can be used to compute MP
84+
# compatible formation energies quickly)
85+
# get_elemental_ref_entries() reads entry.composition, so hand it the entries rather
# than the dict itself (iterating the dict would yield material_id strings)
elemental_ref_entries = get_elemental_ref_entries(
list(all_mp_computed_entries.values())
)
86+
87+
# save elemental_ref_entries to disk as json
88+
with open(f"{module_dir}/{today}-elemental-ref-entries.json", "w") as file:
89+
json.dump(elemental_ref_entries, file, default=lambda x: x.as_dict())
90+
91+
92+
# %% load MP elemental reference entries to compute formation energies
93+
mp_elem_refs_path = f"{ROOT}/data/2022-09-19-mp-elemental-reference-entries.json"
94+
mp_reference_entries = (
95+
pd.read_json(mp_elem_refs_path, typ="series").map(ComputedEntry.from_dict).to_dict()
96+
)
97+
98+
99+
df_mp = pd.read_json(f"{ROOT}/data/2022-08-13-mp-all-energies.json.gz").set_index(
100+
"material_id"
101+
)
102+
103+
104+
# %%
105+
df_mp["our_mp_e_form"] = [
106+
get_form_energy_per_atom(all_mp_computed_entries[mp_id], mp_reference_entries)
107+
for mp_id in df_mp.index
108+
]
109+
110+
111+
# make sure get_form_energy_per_atom() reproduces MP formation energies
112+
ax = pymatviz.density_scatter(
113+
df_mp["formation_energy_per_atom"], df_mp["our_mp_e_form"]
114+
)
115+
ax.set(
116+
title="MP Formation Energy Comparison",
117+
xlabel="MP Formation Energy (eV/atom)",
118+
ylabel="Our Formation Energy (eV/atom)",
119+
)
120+
ax.figure.savefig(f"{ROOT}/tmp/{today}-mp-formation-energy-comparison.png", dpi=300)
+4-55
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,12 @@
1-
# %%
2-
import gzip
31
import itertools
4-
import json
5-
import os
6-
import pickle
7-
from datetime import datetime
82

9-
import pandas as pd
10-
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram, PDEntry
11-
from pymatgen.ext.matproj import MPRester
3+
from pymatgen.analysis.phase_diagram import Entry
124
from tqdm import tqdm
135

14-
from mb_discovery import ROOT
156

16-
today = f"{datetime.now():%Y-%m-%d}"
17-
module_dir = os.path.dirname(__file__)
18-
19-
20-
# %%
217
def get_elemental_ref_entries(
22-
entries: list[PDEntry], verbose: bool = False
23-
) -> dict[str, PDEntry]:
8+
entries: list[Entry], verbose: bool = False
9+
) -> dict[str, Entry]:
2410

2511
elements = {elems for entry in entries for elems in entry.composition.elements}
2612
dim = len(elements)
@@ -53,7 +39,7 @@ def get_elemental_ref_entries(
5339

5440

5541
def get_form_energy_per_atom(
56-
entry: PDEntry, elemental_ref_entries: dict[str, PDEntry]
42+
entry: Entry, elemental_ref_entries: dict[str, Entry]
5743
) -> float:
5844
"""Get the formation energy of a composition from a list of entries and elemental
5945
reference energies.
@@ -65,40 +51,3 @@ def get_form_energy_per_atom(
6551
)
6652

6753
return form_energy / entry.composition.num_atoms
68-
69-
70-
# %%
71-
if __name__ == "__main__":
72-
all_mp_entries = MPRester().get_entries("") # run on 2022-09-16
73-
# mp-15590 appears twice so we drop_duplicates()
74-
df_mp_entries = pd.DataFrame(all_mp_entries, columns=["entry"]).drop_duplicates()
75-
df_mp_entries["material_id"] = [x.entry_id for x in df_mp_entries.entry]
76-
df_mp_entries = df_mp_entries.set_index("material_id")
77-
78-
df_mp_entries.reset_index().to_json(
79-
f"{ROOT}/data/{today}-2-all-mp-entries.json.gz",
80-
default_handler=lambda x: x.as_dict(),
81-
)
82-
83-
df_mp_entries = pd.read_json(
84-
f"{ROOT}/data/2022-09-16-all-mp-entries.json.gz"
85-
).set_index("material_id")
86-
all_mp_entries = [PDEntry.from_dict(x) for x in df_mp_entries.entry]
87-
88-
print(f"{len(df_mp_entries) = :,}")
89-
# len(df_mp_entries) = 146,323
90-
91-
ppd_mp = PatchedPhaseDiagram(all_mp_entries)
92-
# prints:
93-
# PatchedPhaseDiagram
94-
# Covering 44805 Sub-Spaces
95-
96-
# save MP PPD to disk
97-
with gzip.open(f"{module_dir}/{today}-ppd-mp.pkl.gz", "wb") as zip_file:
98-
pickle.dump(ppd_mp, zip_file)
99-
100-
elemental_ref_entries = get_elemental_ref_entries(all_mp_entries)
101-
102-
# save elemental_ref_entries to disk as json
103-
with open(f"{module_dir}/{today}-elemental-ref-entries.json", "w") as f:
104-
json.dump(elemental_ref_entries, f, default=lambda x: x.as_dict())

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
# download wbm-steps-summary.csv (23.31 MB)
4343
df_summary = pd.read_csv(
44-
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
44+
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
4545
).set_index("material_id")
4646

4747

mb_discovery/plot_scripts/hist_classified_stable_as_func_of_hull_dist_batches.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
# download wbm-steps-summary.csv (23.31 MB)
4343
df_summary = pd.read_csv(
44-
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
44+
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
4545
).set_index("material_id")
4646

4747

mb_discovery/plot_scripts/precision_recall_vs_calc_count.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
# %% download wbm-steps-summary.csv (23.31 MB)
3939
df_wbm = pd.read_csv(
40-
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
40+
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
4141
).set_index("material_id")
4242

4343

@@ -69,23 +69,24 @@
6969
try:
7070
if model_name == "M3GNet":
7171
model_preds = df.e_form_m3gnet
72-
targets = df.e_form_wbm
7372
elif "Wrenformer" in model_name:
7473
model_preds = df.e_form_per_atom_pred_ens
75-
targets = df.e_form_per_atom
7674
elif len(pred_cols := df.filter(like="e_form_pred").columns) >= 1:
7775
# Voronoi+RF has single prediction column, Wren and CGCNN each have 10
7876
# other cases are unexpected
7977
assert len(pred_cols) in (1, 10), f"{model_name=} has {len(pred_cols)=}"
8078
model_preds = df[pred_cols].mean(axis=1)
81-
targets = df.e_form_target
8279
else:
8380
raise ValueError(f"Unhandled {model_name = }")
8481
except AttributeError as exc:
8582
raise KeyError(f"{model_name = }") from exc
8683

8784
df["e_above_mp_hull"] = df_hull.e_above_mp_hull
88-
df["e_above_hull_pred"] = model_preds - targets
85+
df["e_form_per_atom"] = df_wbm.e_form_per_atom
86+
df["e_above_hull_pred"] = model_preds - df.e_form_per_atom
87+
if n_nans := df.isna().values.sum() > 0:
88+
assert n_nans < 10, f"{model_name=} has {n_nans=}"
89+
df = df.dropna()
8990

9091
ax = precision_recall_vs_calc_count(
9192
e_above_hull_error=df.e_above_hull_pred + df.e_above_mp_hull,
@@ -97,9 +98,10 @@
9798
std_pred=std_total,
9899
)
99100

100-
ax.legend(frameon=False, loc="lower right")
101-
102101
ax.figure.set_size_inches(10, 9)
102+
ax.set(xlim=(0, None))
103+
# keep this outside loop so all model names appear in legend
104+
ax.legend(frameon=False, loc="lower right")
103105

104106
img_path = f"{ROOT}/figures/{today}-precision-recall-vs-calc-count-{rare=}.pdf"
105107
if False:

mb_discovery/plots.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -398,12 +398,9 @@ def precision_recall_vs_calc_count(
398398
# previous call
399399
return ax
400400

401-
ax.set(
402-
xlabel="Number of compounds sorted by model-predicted hull distance",
403-
ylabel="Precision and Recall (%)",
404-
)
405-
406-
ax.set(ylim=(0, 100))
401+
xlabel = "Number of compounds sorted by model-predicted hull distance"
402+
ylabel = "Precision and Recall (%)"
403+
ax.set(ylim=(0, 100), xlabel=xlabel, ylabel=ylabel)
407404

408405
[precision] = ax.plot((0, 0), (0, 0), "black", linestyle="-")
409406
[recall] = ax.plot((0, 0), (0, 0), "black", linestyle=":")

models/bowsr/join_bowsr_results.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656

5757
# %%
5858
df_wbm = pd.read_csv( # download wbm-steps-summary.csv (23.31 MB)
59-
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
59+
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
6060
).set_index("material_id")
6161

6262
df_bowsr["e_form_wbm"] = df_wbm.e_form_per_atom

models/m3gnet/join_m3gnet_relax_results.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@
9595

9696

9797
df_wbm = pd.read_csv( # download wbm-steps-summary.csv (23.31 MB)
98-
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
98+
"https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
9999
).set_index("material_id")
100100

101101
df_m3gnet["e_form_wbm"] = df_wbm.e_form_per_atom

models/wrenformer/mp/use_ensemble.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
# %%
2525
# download wbm-steps-summary.csv (23.31 MB)
26-
data_path = "https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
26+
data_path = "https://figshare.com/files/37570234?private_link=ff0ad14505f9624f0c05"
2727
df = pd.read_csv(data_path).set_index("material_id")
2828

2929

0 commit comments

Comments
 (0)