diff --git a/mp_api/client/routes/chemenv.py b/mp_api/client/routes/chemenv.py index f82cae344..8b9fe5bda 100644 --- a/mp_api/client/routes/chemenv.py +++ b/mp_api/client/routes/chemenv.py @@ -1,7 +1,13 @@ from collections import defaultdict from typing import List, Optional, Tuple, Union -from emmet.core.chemenv import ChemEnvDoc +from emmet.core.chemenv import ( + COORDINATION_GEOMETRIES, + COORDINATION_GEOMETRIES_IUCR, + COORDINATION_GEOMETRIES_IUPAC, + COORDINATION_GEOMETRIES_NAMES, + ChemEnvDoc, +) from mp_api.client.core import BaseRester from mp_api.client.core.utils import validate_ids @@ -15,9 +21,21 @@ class ChemenvRester(BaseRester[ChemEnvDoc]): def search( self, material_ids: Optional[Union[str, List[str]]] = None, - chemenv_iucr: Optional[Union[str, List[str]]] = None, - chemenv_iupac: Optional[Union[str, List[str]]] = None, - chemenv_name: Optional[Union[str, List[str]]] = None, + chemenv_iucr: Optional[ + Union[COORDINATION_GEOMETRIES_IUCR, List[COORDINATION_GEOMETRIES_IUCR]] + ] = None, + chemenv_iupac: Optional[ + Union[COORDINATION_GEOMETRIES_IUPAC, List[COORDINATION_GEOMETRIES_IUPAC]] + ] = None, + chemenv_name: Optional[ + Union[COORDINATION_GEOMETRIES_NAMES, List[COORDINATION_GEOMETRIES_NAMES]] + ] = None, + chemenv_symbol: Optional[ + Union[COORDINATION_GEOMETRIES, List[COORDINATION_GEOMETRIES]] + ] = None, + species: Optional[Union[str, List[str]]] = None, + elements: Optional[Union[str, List[str]]] = None, + exclude_elements: Optional[List[str]] = None, csm: Optional[Tuple[float, float]] = None, density: Optional[Tuple[float, float]] = None, num_elements: Optional[Tuple[int, int]] = None, @@ -28,15 +46,23 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: Optional[List[str]] = None, - ) -> List[ChemEnvDoc]: + ): """Query for chemical environment data. Arguments: material_ids (str, List[str]): Search forchemical environment associated with the specified Material IDs. - chemenv_iucr (str, List[str]): Unique cationic species in IUCR format. - chemenv_iupac (str, List[str]): Unique cationic species in IUPAC format. - chemenv_iupac (str, List[str]): Coordination environment descriptions for unique cationic species. - density (Tuple[float,float]): Minimum and maximum value of continuous symmetry measure to consider. + chemenv_iucr (COORDINATION_GEOMETRIES_IUCR, List[COORDINATION_GEOMETRIES_IUCR]): Unique cationic species in + IUCR format, e.g. "[3n]". + chemenv_iupac (COORDINATION_GEOMETRIES_IUPAC, List[COORDINATION_GEOMETRIES_IUPAC]): Unique cationic species + in IUPAC format, e.g., "T-4". + chemenv_name (COORDINATION_GEOMETRIES_NAMES, List[COORDINATION_GEOMETRIES_NAMES]): Coordination environment + descriptions in text form for unique cationic species, e.g. "Tetrahedron". + chemenv_symbol (COORDINATION_GEOMETRIES, List[COORDINATION_GEOMETRIES]): Coordination environment + descriptions as used in ChemEnv package for unique cationic species, e.g. "T:4". + species (str, List[str]): Cationic species in the crystal structure, e.g. "Ti4+". + elements (str, List[str]): Element names in the crystal structure, e.g., "Ti". + exclude_elements (List[str]): A list of elements to exclude. + csm (Tuple[float,float]): Minimum and maximum value of continuous symmetry measure to consider. density (Tuple[float,float]): Minimum and maximum density to consider. num_elements (Tuple[int,int]): Minimum and maximum number of elements to consider. num_sites (Tuple[int,int]): Minimum and maximum number of sites to consider. @@ -66,6 +92,12 @@ def search( {"nsites_min": num_sites[0], "nsites_max": num_sites[1]} ) + if elements: + query_params.update({"elements": ",".join(elements)}) + + if exclude_elements: + query_params.update({"exclude_elements": ",".join(exclude_elements)}) + if num_elements: if isinstance(num_elements, int): num_elements = (num_elements, num_elements) @@ -79,23 +111,29 @@ def search( query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) - if chemenv_iucr: - if isinstance(chemenv_iucr, str): - chemenv_iucr = [chemenv_iucr] - - query_params.update({"chemenv_iucr": ",".join(chemenv_iucr)}) + chemenv_literals = { + "chemenv_iucr": (chemenv_iucr, COORDINATION_GEOMETRIES_IUCR), + "chemenv_iupac": (chemenv_iupac, COORDINATION_GEOMETRIES_IUPAC), + "chemenv_name": (chemenv_name, COORDINATION_GEOMETRIES_NAMES), + "chemenv_symbol": (chemenv_symbol, COORDINATION_GEOMETRIES), + } - if chemenv_iupac: - if isinstance(chemenv_iupac, str): - chemenv_iupac = [chemenv_iupac] + for chemenv_var_name, (chemenv_var, literals) in chemenv_literals.items(): + if chemenv_var: + t_types = {t if isinstance(t, str) else t.value for t in chemenv_var} + valid_types = {*map(str, literals.__args__)} + if invalid_types := t_types - valid_types: + raise ValueError( + f"Invalid type(s) passed for {chemenv_var_name}: {invalid_types}, valid types are: {valid_types}" + ) - query_params.update({"chemenv_iupac": ",".join(chemenv_iupac)}) + query_params.update({chemenv_var_name: ",".join(t_types)}) - if chemenv_name: - if isinstance(chemenv_name, str): - chemenv_name = [chemenv_name] + if species: + if isinstance(species, str): + species = [species] - query_params.update({"chemenv_name": ",".join(chemenv_name)}) + query_params.update({"species": ",".join(species)}) if sort_fields: query_params.update( @@ -113,5 +151,5 @@ def search( chunk_size=chunk_size, all_fields=all_fields, fields=fields, - **query_params + **query_params, ) diff --git a/tests/test_chemenv.py b/tests/test_chemenv.py new file mode 100644 index 000000000..b899e66e8 --- /dev/null +++ b/tests/test_chemenv.py @@ -0,0 +1,98 @@ +import os +import typing + +import pytest + +from mp_api.client.routes.chemenv import ChemenvRester + + +@pytest.fixture +def rester(): + rester = ChemenvRester() + yield rester + rester.session.close() + + +excluded_params = [ + "sort_fields", + "chunk_size", + "num_chunks", + "all_fields", + "fields", + "volume", +] + +sub_doc_fields = [] # type: list + +alt_name_dict = { + "material_ids": "material_id", + "exclude_elements": "material_id", + "num_elements": "nelements", + "num_sites": "nsites", +} # type: dict + +custom_field_tests = { + "material_ids": ["mp-22526"], + "elements": ["Si", "O"], + "exclude_elements": ["Si", "O"], + "chemenv_symbol": ["S:1"], + "chemenv_iupac": ["IC-12"], + "chemenv_iucr": ["[2l]"], + "chemenv_name": ["Octahedron"], + "species": ["Cu2+"], +} # type: dict + + +@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.") +def test_client(rester): + search_method = rester.search + + if search_method is not None: + # Get list of parameters + param_tuples = list(typing.get_type_hints(search_method).items()) + + # Query API for each numeric and boolean parameter and check if returned + for entry in param_tuples: + param = entry[0] + if param not in excluded_params: + param_type = entry[1].__args__[0] + q = None + + if param_type == typing.Tuple[int, int]: + project_field = alt_name_dict.get(param, None) + q = { + param: (-100, 100), + "chunk_size": 1, + "num_chunks": 1, + } + elif param_type == typing.Tuple[float, float]: + project_field = alt_name_dict.get(param, None) + q = { + param: (-1.12, 1.12), + "chunk_size": 1, + "num_chunks": 1, + } + elif param_type is bool: + project_field = alt_name_dict.get(param, None) + q = { + param: False, + "chunk_size": 1, + "num_chunks": 1, + } + elif param in custom_field_tests: + project_field = alt_name_dict.get(param, None) + q = { + param: custom_field_tests[param], + "chunk_size": 1, + "num_chunks": 1, + } + doc = search_method(**q)[0].dict() + + for sub_field in sub_doc_fields: + if sub_field in doc: + doc = doc[sub_field] + + assert ( + doc[project_field if project_field is not None else param] + is not None + )