Skip to content

Ensure parity with top level legacy methods #691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 152 additions & 33 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
from multiprocessing.sharedctypes import Value
import warnings
from functools import lru_cache
from os import environ
Expand All @@ -16,13 +17,14 @@
from pymatgen.analysis.pourbaix_diagram import IonEntry
from pymatgen.core import Composition, Element, Structure
from pymatgen.core.ion import Ion
from pymatgen.entries.computed_entries import ComputedEntry
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
from pymatgen.io.vasp import Chgcar
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from requests import get
from typing import Literal

from mp_api.client.core import BaseRester, MPRestError
from mp_api.client.core.utils import validate_ids
from mp_api.client.routes import *

_DEPRECATION_WARNING = (
Expand Down Expand Up @@ -438,50 +440,111 @@ def find_structure(
)

def get_entries(
self, chemsys_formula: Union[str, List[str]], sort_by_e_above_hull=False,
):
self,
chemsys_formula_mpids: Union[str, List[str]],
compatible_only: bool = True,
inc_structure: bool = None,
property_data: List[str] = None,
conventional_unit_cell: bool = False,
sort_by_e_above_hull=False,
) -> List[ComputedStructureEntry]:
"""
Get a list of ComputedEntries or ComputedStructureEntries corresponding
to a chemical system or formula.

Args:
chemsys_formula (str): A chemical system, list of chemical systems
(e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]), or single formula (e.g., Fe2O3, Si*).
chemsys_formula_mpids (str, List[str]): A chemical system, list of chemical systems
(e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]), formula, list of formulas
(e.g., Fe2O3, Si*, [SiO2, BiFeO3]), Materials Project ID, or list of Materials
Project IDs (e.g., mp-22526, [mp-22526, mp-149]).
compatible_only (bool): Whether to return only "compatible"
entries. Compatible entries are entries that have been
processed using the MaterialsProject2020Compatibility class,
which performs adjustments to allow mixing of GGA and GGA+U
calculations for more accurate phase diagrams and reaction
energies. This data is obtained from the core "thermo" API endpoint.
inc_structure (str): *This is a deprecated argument*. Previously, if None, entries
returned were ComputedEntries. If inc_structure="initial",
ComputedStructureEntries with initial structures were returned.
Otherwise, ComputedStructureEntries with final structures
were returned. This is no longer needed as all entries will contain
structure data by default.
property_data (list): Specify additional properties to include in
entry.data. If None, only default data is included. Should be a subset of
input parameters in the 'MPRester.thermo.available_fields' list.
conventional_unit_cell (bool): Whether to get the standard
conventional unit cell
sort_by_e_above_hull (bool): Whether to sort the list of entries by
e_above_hull in ascending order.

Returns:
List of ComputedEntry or ComputedStructureEntry objects.
List ComputedStructureEntry objects.
"""

if isinstance(chemsys_formula, list) or (
isinstance(chemsys_formula, str) and "-" in chemsys_formula
):
input_params = {"chemsys": chemsys_formula}
else:
input_params = {"formula": chemsys_formula}
if inc_structure is not None:
warnings.warn("The 'inc_structure' argument is deprecated as structure "
"data is now always included in all returned entry objects.")

if isinstance(chemsys_formula_mpids, str):
chemsys_formula_mpids = [chemsys_formula_mpids]

try:
input_params = {"material_ids": validate_ids(chemsys_formula_mpids)}
except ValueError:

if any("-" in entry for entry in chemsys_formula_mpids):
input_params = {"chemsys": chemsys_formula_mpids}
else:
input_params = {"formula": chemsys_formula_mpids}

entries = []

if sort_by_e_above_hull:
fields = ["entries"] if not property_data else ["entries"] + property_data

for doc in self.thermo.search(
if sort_by_e_above_hull:
docs = self.thermo.search(
**input_params, # type: ignore
all_fields=False,
fields=["entries"],
fields=fields,
sort_fields=["energy_above_hull"],
):
entries.extend(list(doc.entries.values()))
)
else:
docs = self.thermo.search(
**input_params, all_fields=False, fields=fields, # type: ignore
)

return entries
for doc in docs:
for entry in doc.entries.values():
if not compatible_only:
entry.correction = 0.0
entry.energy_adjustments = []

else:
for doc in self.thermo.search(
**input_params, all_fields=False, fields=["entries"], # type: ignore
):
entries.extend(list(doc.entries.values()))
if property_data:
for property in property_data:
entry.data[property] = doc.dict()[property]

if conventional_unit_cell:

s = SpacegroupAnalyzer(entry.structure).get_conventional_standard_structure()
site_ratio = (len(s) / len(entry.structure))
new_energy = entry.uncorrected_energy * site_ratio

entry_dict = entry.as_dict()
entry_dict["energy"] = new_energy
entry_dict["structure"] = s.as_dict()
entry_dict["correction"] = 0.0

for element in entry_dict["composition"]:
entry_dict["composition"][element] *= site_ratio

return entries
for correction in entry_dict["energy_adjustments"]:
correction["n_atoms"] *= site_ratio

entry = ComputedStructureEntry.from_dict(entry_dict)

entries.append(entry)

return entries

def get_pourbaix_entries(
self,
Expand Down Expand Up @@ -783,24 +846,51 @@ def get_ion_entries(

return ion_entries

def get_entry_by_material_id(self, material_id: str):
def get_entry_by_material_id(self, material_id: str,
compatible_only: bool = True,
inc_structure: bool = None,
property_data: List[str] = None,
conventional_unit_cell: bool = False,):
"""
Get all ComputedEntry objects corresponding to a material_id.

Args:
material_id (str): Materials Project material_id (a string,
e.g., mp-1234).
compatible_only (bool): Whether to return only "compatible"
entries. Compatible entries are entries that have been
processed using the MaterialsProject2020Compatibility class,
which performs adjustments to allow mixing of GGA and GGA+U
calculations for more accurate phase diagrams and reaction
energies. This data is obtained from the core "thermo" API endpoint.
inc_structure (str): *This is a deprecated argument*. Previously, if None, entries
returned were ComputedEntries. If inc_structure="initial",
ComputedStructureEntries with initial structures were returned.
Otherwise, ComputedStructureEntries with final structures
were returned. This is no longer needed as all entries will contain
structure data by default.
property_data (list): Specify additional properties to include in
entry.data. If None, only default data is included. Should be a subset of
input parameters in the 'MPRester.thermo.available_fields' list.
conventional_unit_cell (bool): Whether to get the standard
conventional unit cell
Returns:
List of ComputedEntry or ComputedStructureEntry object.
"""
return list(
self.thermo.get_data_by_id(
document_id=material_id, fields=["entries"]
).entries.values()
)
return self.get_entries(material_id,
compatible_only=compatible_only,
inc_structure=inc_structure,
property_data=property_data,
conventional_unit_cell=conventional_unit_cell)

def get_entries_in_chemsys(
self, elements: Union[str, List[str]], use_gibbs: Optional[int] = None,
self, elements: Union[str, List[str]],
use_gibbs: Optional[int] = None,
compatible_only: bool = True,
inc_structure: bool = None,
property_data: List[str] = None,
conventional_unit_cell: bool = False,
additional_criteria=None,
):
"""
Helper method to get a list of ComputedEntries in a chemical system.
Expand All @@ -817,9 +907,34 @@ def get_entries_in_chemsys(
(see GibbsComputedStructureEntry). The number is the temperature in
Kelvin at which to estimate the free energy. Must be between 300 K and
2000 K.
compatible_only (bool): Whether to return only "compatible"
entries. Compatible entries are entries that have been
processed using the MaterialsProject2020Compatibility class,
which performs adjustments to allow mixing of GGA and GGA+U
calculations for more accurate phase diagrams and reaction
energies. This data is obtained from the core "thermo" API endpoint.
inc_structure (str): *This is a deprecated argument*. Previously, if None, entries
returned were ComputedEntries. If inc_structure="initial",
ComputedStructureEntries with initial structures were returned.
Otherwise, ComputedStructureEntries with final structures
were returned. This is no longer needed as all entries will contain
structure data by default.
property_data (list): Specify additional properties to include in
entry.data. If None, only default data is included. Should be a subset of
input parameters in the 'MPRester.thermo.available_fields' list.
conventional_unit_cell (bool): Whether to get the standard
conventional unit cell
additional_criteria (dict): *This is a deprecated argument*. To obtain entry objects
with additional criteria, use the `MPRester.thermo.search` method directly.
Returns:
List of ComputedEntries.
List of ComputedStructureEntries.
"""

if additional_criteria is not None:
warnings.warn("The 'additional_criteria' argument is deprecated. "
"To obtain entry objects with additional criteria, use "
"the 'MPRester.thermo.search' method directly")

if isinstance(elements, str):
elements = elements.split("-")

Expand All @@ -830,7 +945,11 @@ def get_entries_in_chemsys(

entries = [] # type: List[ComputedEntry]

entries.extend(self.get_entries(all_chemsyses))
entries.extend(self.get_entries(all_chemsyses,
compatible_only=compatible_only,
inc_structure=inc_structure,
property_data=property_data,
conventional_unit_cell=conventional_unit_cell))

if use_gibbs:
# replace the entries with GibbsComputedStructureEntry
Expand Down
36 changes: 32 additions & 4 deletions tests/test_mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from emmet.core.symmetry import CrystalSystem
from emmet.core.tasks import TaskDoc
from emmet.core.vasp.calc_types import CalcType
from sympy import prime
from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client import MPRester
from pymatgen.analysis.magnetism import Ordering
Expand Down Expand Up @@ -90,7 +91,7 @@ def test_get_structures(self, mpr):
structs = mpr.get_structures("Mn-O", final=False)
assert len(structs) > 0

@pytest.mark.skip(reason="endpoint issues")
@pytest.mark.skip(reason="Endpoint issues")
def test_find_structure(self, mpr):
path = os.path.join(MAPIClientSettings().TEST_FILES, "Si_mp_149.cif")
with open(path) as file:
Expand Down Expand Up @@ -130,12 +131,41 @@ def test_get_entries(self, mpr):

assert sorted_entries != entries

# Formula
formula = "SiO2"
entries = mpr.get_entries(formula)

for e in entries:
assert isinstance(e, ComputedEntry)

# Property data
formula = "BiFeO3"
entries = mpr.get_entries(formula, property_data=["energy_above_hull"])

for e in entries:
assert e.data.get("energy_above_hull", None) is not None

# Conventional structure
formula = "BiFeO3"
entry = mpr.get_entry_by_material_id("mp-22526", inc_structure=True, conventional_unit_cell=True)[0]

s = entry.structure
assert pytest.approx(s.lattice.a) == s.lattice.b
assert pytest.approx(s.lattice.a) != s.lattice.c
assert pytest.approx(s.lattice.alpha) == 90
assert pytest.approx(s.lattice.beta) == 90
assert pytest.approx(s.lattice.gamma) == 120

# Ensure energy per atom is same
prim = mpr.get_entry_by_material_id("mp-22526", inc_structure=True, conventional_unit_cell=False)[0]
assert pytest.approx(prim.energy_per_atom) == entry.energy_per_atom

s = prim.structure
assert pytest.approx(s.lattice.a) == s.lattice.b
assert pytest.approx(s.lattice.a) == s.lattice.c
assert pytest.approx(s.lattice.alpha) == s.lattice.beta
assert pytest.approx(s.lattice.alpha) == s.lattice.gamma

def test_get_entries_in_chemsys(self, mpr):
syms = ["Li", "Fe", "O"]
syms2 = "Li-Fe-O"
Expand All @@ -154,7 +184,6 @@ def test_get_entries_in_chemsys(self, mpr):
for e in gibbs_entries:
assert isinstance(e, GibbsComputedStructureEntry)

@pytest.mark.skip(reason="Until SSL issue fix")
def test_get_pourbaix_entries(self, mpr):
# test input chemsys as a list of elements
pbx_entries = mpr.get_pourbaix_entries(["Fe", "Cr"])
Expand Down Expand Up @@ -195,7 +224,6 @@ def test_get_pourbaix_entries(self, mpr):
# so4_two_minus = pbx_entries[9]
# self.assertAlmostEqual(so4_two_minus.energy, 0.301511, places=3)

@pytest.mark.skip(reason="Until SSL issue fix")
def test_get_ion_entries(self, mpr):
entries = mpr.get_entries_in_chemsys("Ti-O-H")
pd = PhaseDiagram(entries)
Expand Down Expand Up @@ -249,7 +277,7 @@ def test_get_phonon_data_by_material_id(self, mpr):
dos = mpr.get_phonon_dos_by_material_id("mp-11659")
assert isinstance(dos, PhononDos)

@pytest.mark.xfail(reason="SSL issue")
@pytest.mark.skip(reason="Test needs fixing with ENV variables")
def test_get_charge_density_data(self, mpr):
chgcar = mpr.get_charge_density_from_material_id("mp-149")
assert isinstance(chgcar, Chgcar)
Expand Down