Skip to content

Commit aec9933

Browse files
committed
move mb_discovery/{energy/__init__.py -> energy.py} and add tests/test_energy.py
1 parent 8d9e346 commit aec9933

File tree

3 files changed

+123
-19
lines changed

3 files changed

+123
-19
lines changed

.gitignore

+1-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ __pycache__
1111
*.csv.bz2
1212
*.pkl.gz
1313
data/**/raw
14+
data/**/202*
1415

1516
# checkpoint files of trained models
1617
pretrained/
@@ -22,9 +23,6 @@ job-logs/
2223
# slurm logs
2324
slurm-*out
2425
models/**/*.csv
25-
mb_discovery/energy/**/*.csv
26-
mb_discovery/energy/**/*.json
27-
mb_discovery/energy/**/*.gzip
2826

2927
# temporary ignore rule
3028
paper

mb_discovery/energy/__init__.py mb_discovery/energy.py

+48-16
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,48 @@
11
import itertools
2+
from collections.abc import Sequence
23

34
import pandas as pd
45
from pymatgen.analysis.phase_diagram import Entry, PDEntry
6+
from pymatgen.core import Composition
7+
from pymatgen.util.typing import EntryLike
58
from tqdm import tqdm
69

710
from mb_discovery import ROOT
811

912

1013
def get_elemental_ref_entries(
11-
entries: list[Entry], verbose: bool = False
14+
entries: Sequence[EntryLike], verbose: bool = False
1215
) -> dict[str, Entry]:
16+
"""Get the lowest energy entry for each element in a list of entries.
1317
18+
Args:
19+
entries (Sequence[EntryLike]): pymatgen Entries (PDEntry, ComputedEntry or
20+
ComputedStructureEntry) to find elemental reference entries for.
21+
verbose (bool, optional): Whether to print a status message and show a progress bar. Defaults to False.
22+
23+
Raises:
24+
ValueError: If some elements are missing terminal reference entries.
25+
ValueError: If there are more terminal entries than dimensions. Should never
26+
happen.
27+
28+
Returns:
29+
dict[str, Entry]: Map from element symbol to its lowest energy entry.
30+
"""
31+
entries = [PDEntry.from_dict(e) if isinstance(e, dict) else e for e in entries]
1432
elements = {elems for entry in entries for elems in entry.composition.elements}
1533
dim = len(elements)
1634

1735
if verbose:
1836
print(f"Sorting {len(entries)} entries with {dim} dimensions...")
37+
1938
entries = sorted(entries, key=lambda e: e.composition.reduced_composition)
2039

2140
elemental_ref_entries = {}
22-
if verbose:
23-
print("Finding elemental reference entries...", flush=True)
24-
for composition, group in tqdm(
25-
itertools.groupby(entries, key=lambda e: e.composition.reduced_composition)
41+
for composition, entry_group in tqdm(
42+
itertools.groupby(entries, key=lambda e: e.composition.reduced_composition),
43+
disable=not verbose,
2644
):
27-
min_entry = min(group, key=lambda e: e.energy_per_atom)
45+
min_entry = min(entry_group, key=lambda e: e.energy_per_atom)
2846
if composition.is_element:
2947
elem_symb = str(composition.elements[0])
3048
elemental_ref_entries[elem_symb] = min_entry
@@ -53,14 +71,16 @@ def get_elemental_ref_entries(
5371

5472

5573
def get_e_form_per_atom(
56-
entry: Entry, elemental_ref_entries: dict[str, Entry] = None
74+
entry: EntryLike,
75+
elemental_ref_entries: dict[str, EntryLike] = None,
5776
) -> float:
5877
"""Get the formation energy per atom of an entry's composition relative to elemental
5978
reference energies.
6079
6180
Args:
62-
entry (Entry): pymatgen Entry (PDEntry, ComputedEntry or ComputedStructureEntry)
63-
to compute formation energy of.
81+
entry (Entry | dict[str, float | str | Composition]): pymatgen Entry (PDEntry,
82+
ComputedEntry or ComputedStructureEntry) or dict with energy and composition
83+
keys to compute formation energy of.
6484
elemental_ref_entries (dict[str, EntryLike], optional): Must be a complete set of
6585
terminal (i.e. elemental) reference entries containing the lowest energy
6686
phase for each element present in entry. Defaults to MP elemental reference
@@ -76,13 +96,25 @@ def get_e_form_per_atom(
7696
f"Couldn't load {mp_elem_refs_path=}, you must pass "
7797
f"{elemental_ref_entries=} explicitly."
7898
)
79-
8099
elemental_ref_entries = mp_elem_reference_entries
81100

82-
comp = entry.composition
83-
form_energy = entry.uncorrected_energy - sum(
84-
comp[el] * elemental_ref_entries[str(el)].energy_per_atom
85-
for el in entry.composition.elements
86-
)
101+
if isinstance(entry, dict):
102+
energy = entry["energy"]
103+
comp = Composition(entry["composition"]) # is idempotent if already Composition
104+
elif isinstance(entry, Entry):
105+
energy = entry.energy
106+
comp = entry.composition
107+
else:
108+
raise TypeError(
109+
f"{entry=} must be Entry (or subclass like ComputedEntry) or dict"
110+
)
111+
112+
refs = {str(el): elemental_ref_entries[str(el)] for el in comp}
113+
114+
for key, ref_entry in refs.items():
115+
if isinstance(ref_entry, dict):
116+
refs[key] = PDEntry.from_dict(ref_entry)
117+
118+
form_energy = energy - sum(comp[el] * refs[str(el)].energy_per_atom for el in comp)
87119

88-
return form_energy / entry.composition.num_atoms
120+
return form_energy / comp.num_atoms

tests/test_energy.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Callable
4+
5+
import pytest
6+
from pymatgen.analysis.phase_diagram import PDEntry
7+
from pymatgen.core import Lattice, Structure
8+
from pymatgen.entries.computed_entries import (
9+
ComputedEntry,
10+
ComputedStructureEntry,
11+
Entry,
12+
)
13+
14+
from mb_discovery.energy import get_e_form_per_atom, get_elemental_ref_entries
15+
16+
dummy_struct = Structure(
17+
lattice=Lattice.cubic(5),
18+
species=("Fe", "O"),
19+
coords=((0, 0, 0), (0.5, 0.5, 0.5)),
20+
)
21+
22+
23+
@pytest.mark.parametrize(
24+
"constructor", [PDEntry, ComputedEntry, ComputedStructureEntry, lambda **x: x]
25+
)
26+
def test_get_e_form_per_atom(
27+
constructor: Callable[..., Entry | dict[str, Any]]
28+
) -> None:
29+
"""Test that the formation energy of a composition is computed correctly."""
30+
31+
entry = {"composition": {"Fe": 1, "O": 1}, "energy": -2.5}
32+
elemental_ref_entries = {
33+
"Fe": {"composition": {"Fe": 1}, "energy": -1.0},
34+
"O": {"composition": {"O": 1}, "energy": -1.0},
35+
}
36+
if constructor == ComputedStructureEntry:
37+
entry["structure"] = dummy_struct
38+
entry.pop("composition")
39+
40+
entry = constructor(**entry)
41+
42+
# don't use ComputedStructureEntry for elemental ref entries, would need many
43+
# dummy structures
44+
if constructor == ComputedStructureEntry:
45+
constructor = ComputedEntry
46+
elemental_ref_entries = {
47+
k: constructor(**v) for k, v in elemental_ref_entries.items()
48+
}
49+
assert get_e_form_per_atom(entry, elemental_ref_entries) == -0.25
50+
51+
52+
@pytest.mark.parametrize("constructor", [PDEntry, ComputedEntry, lambda **x: x])
53+
@pytest.mark.parametrize("verbose", [True, False])
54+
def test_get_elemental_ref_entries(
55+
constructor: Callable[..., Entry | dict[str, Any]], verbose: bool
56+
) -> None:
57+
"""Test that the elemental reference entries are correctly identified."""
58+
entries = [
59+
("Fe1 O1", -2.5),
60+
("Fe1", -1.0),
61+
("Fe1", -2.0),
62+
("O1", -1.0),
63+
("O3", -2.0),
64+
]
65+
elemental_ref_entries = get_elemental_ref_entries(
66+
[constructor(composition=comp, energy=energy) for comp, energy in entries],
67+
verbose=verbose,
68+
)
69+
if constructor.__name__ == "<lambda>":
70+
expected = {"Fe": PDEntry(*entries[2]), "O": PDEntry(*entries[3])}
71+
else:
72+
expected = {"Fe": constructor(*entries[2]), "O": constructor(*entries[3])}
73+
74+
assert elemental_ref_entries == expected

0 commit comments

Comments
 (0)