Skip to content

Commit db23e1a

Browse files
kavanasejanosh
authored andcommitted
Add formal_chempots option to ChemicalPotentialDiagram to plot the formal chemical potentials rather than the DFT energies (materialsproject#2916)
* Add `formal_chempots` option to `ChemicalPotentialDiagram` to plot the formal chemical potentials rather than the DFT energies * Add tests for `formal_chempots` option in `ChemicalPotentialDiagram` to all applicable tests in `test_chempot_diagram.py` * remove useless dict.copy() in _renormalize_entry(), fix some doc strings and typos * replace assertArrayAlmostEqual() with pytest.approx() in test_chempot_diagram.py * use dot access for plot attributes --------- Co-authored-by: Janosh Riebesell <[email protected]>
1 parent 3a2380d commit db23e1a

File tree

2 files changed

+184
-26
lines changed

2 files changed

+184
-26
lines changed

pymatgen/analysis/chempot_diagram.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,24 +65,42 @@ class ChemicalPotentialDiagram(MSONable):
6565
def __init__(
6666
self,
6767
entries: list[PDEntry],
68-
limits: dict[Element, float] | None = None,
68+
limits: dict[Element, tuple[float, float]] | None = None,
6969
default_min_limit: float = -50.0,
70-
):
70+
formal_chempots: bool = True,
71+
) -> None:
7172
"""
7273
Args:
73-
entries: List of PDEntry-like objects containing a composition and
74+
entries (list[PDEntry]): PDEntry-like objects containing a composition and
7475
energy. Must contain elemental references and be suitable for typical
7576
phase diagram construction. Entries must be within a chemical system
7677
of with 2+ elements.
77-
limits: Bounds of elemental chemical potentials (min, max), which are
78-
used to construct the border hyperplanes used in the
79-
HalfSpaceIntersection algorithm; these constrain the space over which the
80-
domains are calculated and also determine the size of the plotted
81-
diagram. Any elemental limits not specified are covered in the
82-
default_min_limit argument. e.g., {Element("Li"): [-12.0, 0.0], ...}
78+
limits (dict[Element, float] | None): Bounds of elemental chemical potentials (min, max),
79+
which are used to construct the border hyperplanes used in the HalfSpaceIntersection
80+
algorithm; these constrain the space over which the domains are calculated and also
81+
determine the size of the plotted diagram. Any elemental limits not specified are
82+
covered in the default_min_limit argument. e.g., {Element("Li"): [-12.0, 0.0], ...}
8383
default_min_limit (float): Default minimum chemical potential limit (i.e.,
8484
lower bound) for unspecified elements within the "limits" argument.
85+
formal_chempots (bool): Whether to plot the formal ('reference') chemical potentials
86+
(i.e. μ_X - μ_X^0) or the absolute DFT reference energies (i.e. μ_X(DFT)).
87+
Default is True (i.e. plot formal chemical potentials).
8588
"""
89+
entries = sorted(entries, key=lambda e: e.composition.reduced_composition)
90+
_min_entries, _el_refs = self._get_min_entries_and_el_refs(entries)
91+
92+
if formal_chempots:
93+
# renormalize entry energies to be relative to the elemental references
94+
renormalized_entries = []
95+
for entry in entries:
96+
comp_dict = entry.composition.as_dict()
97+
renormalization_energy = sum(
98+
[comp_dict[el] * _el_refs[Element(el)].energy_per_atom for el in comp_dict]
99+
)
100+
renormalized_entries.append(_renormalize_entry(entry, renormalization_energy / sum(comp_dict.values())))
101+
102+
entries = renormalized_entries
103+
86104
self.entries = sorted(entries, key=lambda e: e.composition.reduced_composition)
87105
self.limits = limits
88106
self.default_min_limit = default_min_limit
@@ -622,7 +640,7 @@ def __repr__(self):
622640

623641
def simple_pca(data: np.ndarray, k: int = 2) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
624642
"""
625-
A barebones implementation of principal component analysis (PCA) used in the
643+
A bare-bones implementation of principal component analysis (PCA) used in the
626644
ChemicalPotentialDiagram class for plotting.
627645
628646
Args:
@@ -645,15 +663,15 @@ def simple_pca(data: np.ndarray, k: int = 2) -> tuple[np.ndarray, np.ndarray, np
645663

646664
def get_centroid_2d(vertices: np.ndarray) -> np.ndarray:
647665
"""
648-
A barebones implementation of the formula for calculating the centroid of a 2D
666+
A bare-bones implementation of the formula for calculating the centroid of a 2D
649667
polygon. Useful for calculating the location of an annotation on a chemical
650668
potential domain within a 3D chemical potential diagram.
651669
652-
**NOTE**: vertices must be ordered circumfrentially!
670+
**NOTE**: vertices must be ordered circumferentially!
653671
654672
Args:
655673
vertices: array of 2-d coordinates corresponding to a polygon, ordered
656-
circumfrentially
674+
circumferentially
657675
658676
Returns:
659677
Array giving 2-d centroid coordinates
@@ -690,7 +708,7 @@ def get_2d_orthonormal_vector(line_pts: np.ndarray) -> np.ndarray:
690708
coordinates of a line
691709
692710
Returns:
693-
711+
np.ndarray: A length-2 vector that is orthonormal to the line.
694712
"""
695713
x = line_pts[:, 0]
696714
y = line_pts[:, 1]
@@ -703,3 +721,15 @@ def get_2d_orthonormal_vector(line_pts: np.ndarray) -> np.ndarray:
703721
vec = np.array([np.sin(theta), np.cos(theta)])
704722

705723
return vec
724+
725+
726+
def _renormalize_entry(entry: PDEntry, renormalization_energy_per_atom: float) -> PDEntry:
727+
"""
728+
Regenerate the input entry with an energy per atom decreased by renormalization_energy_per_atom
729+
"""
730+
renormalized_entry_dict = entry.as_dict()
731+
renormalized_entry_dict["energy"] = entry.energy - renormalization_energy_per_atom * sum(
732+
entry.composition.values()
733+
) # entry.energy includes MP corrections as desired
734+
renormalized_entry = PDEntry.from_dict(renormalized_entry_dict)
735+
return renormalized_entry

pymatgen/analysis/tests/test_chempot_diagram.py

Lines changed: 140 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66

77
import numpy as np
8+
import pytest
89
from plotly.graph_objects import Figure
910

1011
from pymatgen.analysis.chempot_diagram import (
@@ -23,15 +24,18 @@
2324
class ChemicalPotentialDiagramTest(PymatgenTest):
2425
def setUp(self):
2526
self.entries = EntrySet.from_csv(str(module_dir / "pdentries_test.csv"))
26-
self.cpd_ternary = ChemicalPotentialDiagram(entries=self.entries, default_min_limit=-25)
27+
self.cpd_ternary = ChemicalPotentialDiagram(entries=self.entries, default_min_limit=-25, formal_chempots=False)
28+
self.cpd_ternary_formal = ChemicalPotentialDiagram(
29+
entries=self.entries, default_min_limit=-25, formal_chempots=True
30+
)
2731
elements = [Element("Fe"), Element("O")]
2832
binary_entries = list(
2933
filter(
3034
lambda e: set(e.composition.elements).issubset(elements),
3135
self.entries,
3236
)
3337
)
34-
self.cpd_binary = ChemicalPotentialDiagram(entries=binary_entries, default_min_limit=-25)
38+
self.cpd_binary = ChemicalPotentialDiagram(entries=binary_entries, default_min_limit=-25, formal_chempots=False)
3539
warnings.simplefilter("ignore")
3640

3741
def tearDown(self):
@@ -40,6 +44,7 @@ def tearDown(self):
4044
def test_dim(self):
4145
assert self.cpd_binary.dim == 2
4246
assert self.cpd_ternary.dim == 3
47+
assert self.cpd_ternary_formal.dim == 3
4348

4449
def test_el_refs(self):
4550
el_refs = {elem: entry.energy for elem, entry in self.cpd_ternary.el_refs.items()}
@@ -48,17 +53,26 @@ def test_el_refs(self):
4853
energies = [-1.91301487, -6.5961471, -25.54966885]
4954
correct_el_refs = dict(zip(elems, energies))
5055

51-
self.assertDictsAlmostEqual(el_refs, correct_el_refs)
56+
assert el_refs == pytest.approx(correct_el_refs)
57+
58+
def test_el_refs_formal(self):
59+
el_refs = {elem: entry.energy for elem, entry in self.cpd_ternary_formal.el_refs.items()}
60+
elems = [Element("Li"), Element("Fe"), Element("O")]
61+
energies = [0, 0, 0]
62+
correct_el_refs = dict(zip(elems, energies))
63+
assert el_refs == pytest.approx(correct_el_refs)
5264

5365
def test_border_hyperplanes(self):
5466
desired = np.array(
5567
[[-1, 0, 0, -25], [1, 0, 0, 0], [0, -1, 0, -25], [0, 1, 0, 0], [0, 0, -1, -25], [0, 0, 1, 0]]
5668
)
57-
self.assertArrayAlmostEqual(self.cpd_ternary.border_hyperplanes, desired)
69+
assert self.cpd_ternary.border_hyperplanes == pytest.approx(desired)
70+
assert self.cpd_ternary_formal.border_hyperplanes == pytest.approx(desired)
5871

5972
def test_lims(self):
6073
desired_lims = np.array([[-25, 0], [-25, 0], [-25, 0]])
61-
self.assertArrayAlmostEqual(self.cpd_ternary.lims, desired_lims)
74+
assert self.cpd_ternary.lims == pytest.approx(desired_lims)
75+
assert self.cpd_ternary_formal.lims == pytest.approx(desired_lims)
6276

6377
def test_pca(self):
6478
points_3d = np.array(
@@ -80,7 +94,7 @@ def test_pca(self):
8094

8195
points_2d, _, _ = simple_pca(points_3d, k=2)
8296

83-
self.assertArrayAlmostEqual(points_2d, points_2d_desired)
97+
assert points_2d == pytest.approx(points_2d_desired)
8498

8599
def test_centroid(self):
86100
vertices = np.array(
@@ -97,7 +111,7 @@ def test_centroid(self):
97111
centroid = get_centroid_2d(vertices)
98112
centroid_desired = np.array([-0.00069433, -0.00886174])
99113

100-
self.assertArrayAlmostEqual(centroid, centroid_desired)
114+
assert centroid == pytest.approx(centroid_desired, abs=1e-6)
101115

102116
def test_get_2d_orthonormal_vector(self):
103117
pts_1 = np.array([[1, 1], [2, 2]])
@@ -109,18 +123,25 @@ def test_get_2d_orthonormal_vector(self):
109123
vec_1_desired = np.array([0.70710678, 0.70710678])
110124
vec_2_desired = np.array([0.98386991, 0.17888544])
111125

112-
self.assertArrayAlmostEqual(vec_1, vec_1_desired)
113-
self.assertArrayAlmostEqual(vec_2, vec_2_desired)
126+
assert vec_1 == pytest.approx(vec_1_desired)
127+
assert vec_2 == pytest.approx(vec_2_desired)
114128

115129
def test_get_plot(self):
116130
fig_2d = self.cpd_binary.get_plot()
117131
fig_3d = self.cpd_ternary.get_plot()
132+
fig_3d_formal = self.cpd_ternary_formal.get_plot()
118133

119134
assert isinstance(fig_2d, Figure)
120-
assert fig_2d["data"][0]["type"] == "scatter"
135+
assert fig_2d.data[0].type == "scatter"
121136

122137
assert isinstance(fig_3d, Figure)
123-
assert fig_3d["data"][0]["type"] == "scatter3d"
138+
assert fig_3d.data[0].type == "scatter3d"
139+
140+
assert isinstance(fig_3d_formal, Figure)
141+
assert fig_3d_formal.data[0].type == "scatter3d"
142+
assert fig_3d_formal.data[0].mode == "lines"
143+
assert fig_3d_formal.layout.plot_bgcolor == "rgba(0,0,0,0)"
144+
assert fig_3d_formal.layout.scene.annotations[0].text == "FeO"
124145

125146
def test_domains(self):
126147
correct_domains = {
@@ -229,7 +250,114 @@ def test_domains(self):
229250
d = self.cpd_ternary.domains[formula]
230251
d = d.round(6) # to get rid of numerical errors from qhull
231252
actual_domain_sorted = d[np.lexsort((d[:, 2], d[:, 1], d[:, 0]))]
232-
self.assertArrayAlmostEqual(actual_domain_sorted, domain)
253+
assert actual_domain_sorted == pytest.approx(domain)
254+
255+
formal_domains = {
256+
"FeO": np.array(
257+
[
258+
[-2.50000000e01, 3.55271368e-15, -2.85707600e00],
259+
[-2.01860032e00, 3.55271368e-15, -2.85707600e00],
260+
[-2.50000000e01, -1.45446765e-01, -2.71162923e00],
261+
[-2.16404709e00, -1.45446765e-01, -2.71162923e00],
262+
]
263+
),
264+
"Fe2O3": np.array(
265+
[
266+
[-25.0, -4.14354109, 0.0],
267+
[-3.637187, -4.14354108, 0.0],
268+
[-3.49325969, -3.85568646, -0.19190308],
269+
[-25.0, -0.70024301, -2.29553205],
270+
[-2.44144521, -0.70024301, -2.29553205],
271+
]
272+
),
273+
"Fe3O4": np.array(
274+
[
275+
[-25.0, -0.70024301, -2.29553205],
276+
[-25.0, -0.14544676, -2.71162923],
277+
[-2.44144521, -0.70024301, -2.29553205],
278+
[-2.16404709, -0.14544676, -2.71162923],
279+
]
280+
),
281+
"LiFeO2": np.array(
282+
[
283+
[-3.49325969e00, -3.85568646e00, -1.91903083e-01],
284+
[-2.01860032e00, 3.55271368e-15, -2.85707600e00],
285+
[-2.44144521e00, -7.00243005e-01, -2.29553205e00],
286+
[-2.16404709e00, -1.45446765e-01, -2.71162923e00],
287+
[-1.71198739e00, 3.55271368e-15, -3.01038246e00],
288+
[-2.74919447e00, -3.11162124e00, -9.35968300e-01],
289+
]
290+
),
291+
"Li2O": np.array(
292+
[
293+
[0.00000000e00, -2.50000000e01, -6.22930387e00],
294+
[-2.69949567e00, -2.50000000e01, -8.30312528e-01],
295+
[3.55271368e-15, 3.55271368e-15, -6.22930387e00],
296+
[-1.43858289e00, 3.55271368e-15, -3.35213809e00],
297+
[-2.69949567e00, -3.78273835e00, -8.30312528e-01],
298+
]
299+
),
300+
"Li2O2": np.array(
301+
[
302+
[-3.52980820e00, -2.50000000e01, 0.00000000e00],
303+
[-2.69949567e00, -2.50000000e01, -8.30312528e-01],
304+
[-3.52980820e00, -4.35829869e00, 3.55271368e-15],
305+
[-2.69949567e00, -3.78273835e00, -8.30312528e-01],
306+
[-2.82687176e00, -3.65536226e00, -7.02936437e-01],
307+
]
308+
),
309+
"Li2FeO3": np.array(
310+
[
311+
[-3.52980820e00, -4.35829869e00, 3.55271368e-15],
312+
[-3.63718700e00, -4.14354108e00, 0.00000000e00],
313+
[-3.49325969e00, -3.85568646e00, -1.91903083e-01],
314+
[-2.74919447e00, -3.11162124e00, -9.35968300e-01],
315+
[-2.82687176e00, -3.65536226e00, -7.02936437e-01],
316+
]
317+
),
318+
"Li5FeO4": np.array(
319+
[
320+
[-1.43858289e00, 3.55271368e-15, -3.35213809e00],
321+
[-1.71198739e00, 3.55271368e-15, -3.01038246e00],
322+
[-2.74919447e00, -3.11162124e00, -9.35968300e-01],
323+
[-2.69949567e00, -3.78273835e00, -8.30312528e-01],
324+
[-2.82687176e00, -3.65536226e00, -7.02936437e-01],
325+
]
326+
),
327+
"O2": np.array(
328+
[
329+
[-2.50000000e01, -2.50000000e01, 3.55271368e-15],
330+
[-3.52980820e00, -2.50000000e01, 0.00000000e00],
331+
[-2.50000000e01, -4.14354109e00, 0.00000000e00],
332+
[-3.52980820e00, -4.35829869e00, 3.55271368e-15],
333+
[-3.63718700e00, -4.14354108e00, 0.00000000e00],
334+
]
335+
),
336+
"Fe": np.array(
337+
[
338+
[0.00000000e00, 0.00000000e00, -2.50000000e01],
339+
[-2.50000000e01, 0.00000000e00, -2.50000000e01],
340+
[3.55271368e-15, 3.55271368e-15, -6.22930387e00],
341+
[-2.50000000e01, 3.55271368e-15, -2.85707600e00],
342+
[-2.01860032e00, 3.55271368e-15, -2.85707600e00],
343+
[-1.43858289e00, 3.55271368e-15, -3.35213809e00],
344+
[-1.71198739e00, 3.55271368e-15, -3.01038246e00],
345+
]
346+
),
347+
"Li": np.array(
348+
[
349+
[3.55271368e-15, -2.50000000e01, -2.50000000e01],
350+
[0.00000000e00, -2.50000000e01, -6.22930387e00],
351+
[0.00000000e00, 0.00000000e00, -2.50000000e01],
352+
[3.55271368e-15, 3.55271368e-15, -6.22930387e00],
353+
]
354+
),
355+
}
356+
357+
for formula, domain in formal_domains.items():
358+
d = self.cpd_ternary_formal.domains[formula]
359+
d = d.round(6) # to get rid of numerical errors from qhull
360+
assert d == pytest.approx(domain, abs=1e-5)
233361

234362

235363
if __name__ == "__main__":

0 commit comments

Comments
 (0)