Skip to content

Update chemenv documentation and suggestions #771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 61 additions & 23 deletions mp_api/client/routes/chemenv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from collections import defaultdict
from typing import List, Optional, Tuple, Union

from emmet.core.chemenv import ChemEnvDoc
from emmet.core.chemenv import (
COORDINATION_GEOMETRIES,
COORDINATION_GEOMETRIES_IUCR,
COORDINATION_GEOMETRIES_IUPAC,
COORDINATION_GEOMETRIES_NAMES,
ChemEnvDoc,
)

from mp_api.client.core import BaseRester
from mp_api.client.core.utils import validate_ids
Expand All @@ -15,9 +21,21 @@ class ChemenvRester(BaseRester[ChemEnvDoc]):
def search(
self,
material_ids: Optional[Union[str, List[str]]] = None,
chemenv_iucr: Optional[Union[str, List[str]]] = None,
chemenv_iupac: Optional[Union[str, List[str]]] = None,
chemenv_name: Optional[Union[str, List[str]]] = None,
chemenv_iucr: Optional[
Union[COORDINATION_GEOMETRIES_IUCR, List[COORDINATION_GEOMETRIES_IUCR]]
] = None,
chemenv_iupac: Optional[
Union[COORDINATION_GEOMETRIES_IUPAC, List[COORDINATION_GEOMETRIES_IUPAC]]
] = None,
chemenv_name: Optional[
Union[COORDINATION_GEOMETRIES_NAMES, List[COORDINATION_GEOMETRIES_NAMES]]
] = None,
chemenv_symbol: Optional[
Union[COORDINATION_GEOMETRIES, List[COORDINATION_GEOMETRIES]]
] = None,
species: Optional[Union[str, List[str]]] = None,
elements: Optional[Union[str, List[str]]] = None,
exclude_elements: Optional[List[str]] = None,
csm: Optional[Tuple[float, float]] = None,
density: Optional[Tuple[float, float]] = None,
num_elements: Optional[Tuple[int, int]] = None,
Expand All @@ -28,15 +46,23 @@ def search(
chunk_size: int = 1000,
all_fields: bool = True,
fields: Optional[List[str]] = None,
) -> List[ChemEnvDoc]:
):
"""Query for chemical environment data.

Arguments:
material_ids (str, List[str]): Search for chemical environment associated with the specified Material IDs.
chemenv_iucr (str, List[str]): Unique cationic species in IUCR format.
chemenv_iupac (str, List[str]): Unique cationic species in IUPAC format.
chemenv_iupac (str, List[str]): Coordination environment descriptions for unique cationic species.
density (Tuple[float,float]): Minimum and maximum value of continuous symmetry measure to consider.
chemenv_iucr (COORDINATION_GEOMETRIES_IUCR, List[COORDINATION_GEOMETRIES_IUCR]): Unique cationic species in
IUCR format, e.g. "[3n]".
chemenv_iupac (COORDINATION_GEOMETRIES_IUPAC, List[COORDINATION_GEOMETRIES_IUPAC]): Unique cationic species
in IUPAC format, e.g., "T-4".
chemenv_name (COORDINATION_GEOMETRIES_NAMES, List[COORDINATION_GEOMETRIES_NAMES]): Coordination environment
descriptions in text form for unique cationic species, e.g. "Tetrahedron".
chemenv_symbol (COORDINATION_GEOMETRIES, List[COORDINATION_GEOMETRIES]): Coordination environment
descriptions as used in ChemEnv package for unique cationic species, e.g. "T:4".
species (str, List[str]): Cationic species in the crystal structure, e.g. "Ti4+".
elements (str, List[str]): Element names in the crystal structure, e.g., "Ti".
exclude_elements (List[str]): A list of elements to exclude.
csm (Tuple[float,float]): Minimum and maximum value of continuous symmetry measure to consider.
density (Tuple[float,float]): Minimum and maximum density to consider.
num_elements (Tuple[int,int]): Minimum and maximum number of elements to consider.
num_sites (Tuple[int,int]): Minimum and maximum number of sites to consider.
Expand Down Expand Up @@ -66,6 +92,12 @@ def search(
{"nsites_min": num_sites[0], "nsites_max": num_sites[1]}
)

if elements:
query_params.update({"elements": ",".join(elements)})

if exclude_elements:
query_params.update({"exclude_elements": ",".join(exclude_elements)})

if num_elements:
if isinstance(num_elements, int):
num_elements = (num_elements, num_elements)
Expand All @@ -79,23 +111,29 @@ def search(

query_params.update({"material_ids": ",".join(validate_ids(material_ids))})

if chemenv_iucr:
if isinstance(chemenv_iucr, str):
chemenv_iucr = [chemenv_iucr]

query_params.update({"chemenv_iucr": ",".join(chemenv_iucr)})
chemenv_literals = {
"chemenv_iucr": (chemenv_iucr, COORDINATION_GEOMETRIES_IUCR),
"chemenv_iupac": (chemenv_iupac, COORDINATION_GEOMETRIES_IUPAC),
"chemenv_name": (chemenv_name, COORDINATION_GEOMETRIES_NAMES),
"chemenv_symbol": (chemenv_symbol, COORDINATION_GEOMETRIES),
}

if chemenv_iupac:
if isinstance(chemenv_iupac, str):
chemenv_iupac = [chemenv_iupac]
for chemenv_var_name, (chemenv_var, literals) in chemenv_literals.items():
if chemenv_var:
t_types = {t if isinstance(t, str) else t.value for t in chemenv_var}
valid_types = {*map(str, literals.__args__)}
if invalid_types := t_types - valid_types:
raise ValueError(
f"Invalid type(s) passed for {chemenv_var_name}: {invalid_types}, valid types are: {valid_types}"
)

query_params.update({"chemenv_iupac": ",".join(chemenv_iupac)})
query_params.update({chemenv_var_name: ",".join(t_types)})

if chemenv_name:
if isinstance(chemenv_name, str):
chemenv_name = [chemenv_name]
if species:
if isinstance(species, str):
species = [species]

query_params.update({"chemenv_name": ",".join(chemenv_name)})
query_params.update({"species": ",".join(species)})

if sort_fields:
query_params.update(
Expand All @@ -113,5 +151,5 @@ def search(
chunk_size=chunk_size,
all_fields=all_fields,
fields=fields,
**query_params
**query_params,
)
98 changes: 98 additions & 0 deletions tests/test_chemenv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os
import typing

import pytest

from mp_api.client.routes.chemenv import ChemenvRester


@pytest.fixture
def rester():
    """Yield a ``ChemenvRester`` and close its session when the test is done."""
    client = ChemenvRester()
    yield client
    # Teardown: pytest resumes the generator here after the requesting test.
    client.session.close()


# Names that test_client never turns into a query of their own.
excluded_params = ["sort_fields", "chunk_size", "num_chunks", "all_fields", "fields", "volume"]

# Keys test_client descends into (when present) before checking a field.
sub_doc_fields: list = []

# Search parameter -> document field to check, for parameters whose name
# differs from the field inspected on the returned document.
alt_name_dict: dict = {
    "material_ids": "material_id", "exclude_elements": "material_id",
    "num_elements": "nelements", "num_sites": "nsites",
}

# Explicit query values for parameters that test_client cannot derive a
# probe value for from the type annotation alone.
custom_field_tests: dict = {
    "material_ids": ["mp-22526"],
    "elements": ["Si", "O"], "exclude_elements": ["Si", "O"],
    "chemenv_symbol": ["S:1"], "chemenv_iupac": ["IC-12"],
    "chemenv_iucr": ["[2l]"], "chemenv_name": ["Octahedron"],
    "species": ["Cu2+"],
}


@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.")
def test_client(rester):
    """Query ``rester.search`` once per testable parameter and check the result.

    The test is skipped when the ``MP_API_KEY`` environment variable is unset.

    For every annotated parameter of ``search`` that is not listed in
    ``excluded_params``, a one-document query is built.  The probe value is
    chosen from the parameter's first annotated type argument (integer range,
    float range or boolean) or, failing that, taken from
    ``custom_field_tests``.  The first returned document must hold a
    non-``None`` value under the parameter's name, or under its
    ``alt_name_dict`` alias when one is defined.

    A parameter for which no probe value can be chosen fails the test with an
    explicit message instead of an opaque ``TypeError`` from ``**None``.
    """
    search_method = rester.search
    if search_method is None:
        return

    # Keep every request down to a single document.
    paging = {"chunk_size": 1, "num_chunks": 1}

    for param, hint in typing.get_type_hints(search_method).items():
        # get_type_hints() also reports the return annotation under "return";
        # that entry is not a search parameter.
        if param == "return" or param in excluded_params:
            continue

        # Parametrised hints (e.g. Optional[...]) expose their arguments via
        # __args__; a bare hint such as ``bool`` is used as-is.
        type_args = getattr(hint, "__args__", None)
        param_type = type_args[0] if type_args else hint

        if param_type == typing.Tuple[int, int]:
            value = (-100, 100)
        elif param_type == typing.Tuple[float, float]:
            value = (-1.12, 1.12)
        elif param_type is bool:
            value = False
        elif param in custom_field_tests:
            value = custom_field_tests[param]
        else:
            pytest.fail(
                f"No test query defined for search parameter {param!r}; "
                "add it to custom_field_tests or excluded_params."
            )

        results = search_method(**{param: value, **paging})
        assert results, f"search({param}={value!r}) returned no documents"
        doc = results[0].dict()

        for sub_field in sub_doc_fields:
            if sub_field in doc:
                doc = doc[sub_field]

        # Check the aliased document field when the parameter has one.
        field = alt_name_dict.get(param, None)
        if field is None:
            field = param

        assert (
            doc[field] is not None
        ), f"field {field!r} is None in the document returned for {param!r}"