Skip to content

Commit 72eafaa

Browse files
author
Jason Munro
authored
Add back additional_criteria as an input to get_entries and get_entries_in_chemsys (#693)
* Add back additiona_critiera to get_entries * Change task test input * Revert task test * Update generic rester tests * Update gitignore and tests * Update imports and default thread number for multithreading * Remove formula in client test for tasks * Allow top level entry related methods to work without de-serialization enabled * Linting * Fix energy key * Reduce parallel test runs
1 parent 2292fc1 commit 72eafaa

File tree

7 files changed

+136
-65
lines changed

7 files changed

+136
-65
lines changed

.github/workflows/testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
4747
test:
4848
strategy:
49-
max-parallel: 6
49+
max-parallel: 2
5050
matrix:
5151
os: [ubuntu-latest, macos-latest, windows-latest]
5252
python-version: [3.8, 3.9]

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ ENV/
106106
.project
107107
.pydevproject
108108

109+
# fleet
110+
.fleet
111+
109112
*~
110113

111114
.idea

mp_api/client/core/settings.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
from pydantic import BaseSettings, Field
22
from mp_api.client import __file__ as root_dir
3+
from multiprocessing import cpu_count
34
from typing import List
45
import os
56

7+
CPU_COUNT = 8
8+
9+
try:
10+
CPU_COUNT = cpu_count()
11+
except NotImplementedError:
12+
pass
13+
614

715
class MAPIClientSettings(BaseSettings):
816
"""
@@ -41,7 +49,7 @@ class MAPIClientSettings(BaseSettings):
4149
)
4250

4351
NUM_PARALLEL_REQUESTS: int = Field(
44-
8, description="Number of parallel requests to send.",
52+
CPU_COUNT, description="Number of parallel requests to send.",
4553
)
4654

4755
MAX_RETRIES: int = Field(3, description="Maximum number of retries for requests.")

mp_api/client/mprester.py

Lines changed: 95 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
import itertools
2-
from multiprocessing.sharedctypes import Value
32
import warnings
43
from functools import lru_cache
54
from os import environ
6-
from typing import Dict, List, Optional, Tuple, Union
5+
from typing import Dict, List, Optional, Union
76

87
from emmet.core.charge_density import ChgcarDataDoc
98
from emmet.core.electronic_structure import BSPathType
109
from emmet.core.mpid import MPID
1110
from emmet.core.settings import EmmetSettings
12-
from emmet.core.summary import HasProps
13-
from emmet.core.symmetry import CrystalSystem
1411
from emmet.core.vasp.calc_types import CalcType
15-
from pymatgen.analysis.magnetism import Ordering
1612
from pymatgen.analysis.phase_diagram import PhaseDiagram
1713
from pymatgen.analysis.pourbaix_diagram import IonEntry
18-
from pymatgen.core import Composition, Element, Structure
14+
from pymatgen.core import Element, Structure
1915
from pymatgen.core.ion import Ion
2016
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
2117
from pymatgen.io.vasp import Chgcar
@@ -132,15 +128,20 @@ def __init__(
132128
self.session = BaseRester._create_session(
133129
api_key=api_key, include_user_agent=include_user_agent
134130
)
131+
self.use_document_model = use_document_model
132+
self.monty_decode = monty_decode
135133

136134
try:
137135
from mpcontribs.client import Client
136+
138137
self.contribs = Client(api_key)
139138
except ImportError:
140139
self.contribs = None
141-
warnings.warn("mpcontribs-client not installed. "
142-
"Install the package to query MPContribs data, or construct pourbaix diagrams: "
143-
"'pip install mpcontribs-client'")
140+
warnings.warn(
141+
"mpcontribs-client not installed. "
142+
"Install the package to query MPContribs data, or construct pourbaix diagrams: "
143+
"'pip install mpcontribs-client'"
144+
)
144145
except Exception as error:
145146
self.contribs = None
146147
warnings.warn(f"Problem loading MPContribs client: {error}")
@@ -186,15 +187,19 @@ def __exit__(self, exc_type, exc_val, exc_tb):
186187

187188
def __getattr__(self, attr):
188189
if attr == "alloys":
189-
raise MPRestError("Alloy addon package not installed. "
190-
"To query alloy data first install with: 'pip install pymatgen-analysis-alloys'")
190+
raise MPRestError(
191+
"Alloy addon package not installed. "
192+
"To query alloy data first install with: 'pip install pymatgen-analysis-alloys'"
193+
)
191194
elif attr == "charge_density":
192-
raise MPRestError("boto3 not installed. "
193-
"To query charge density data first install with: 'pip install boto3'")
195+
raise MPRestError(
196+
"boto3 not installed. "
197+
"To query charge density data first install with: 'pip install boto3'"
198+
)
194199
else:
195200
raise AttributeError(
196-
f"{self.__class__.__name__!r} object has no attribute {attr!r}"
197-
)
201+
f"{self.__class__.__name__!r} object has no attribute {attr!r}"
202+
)
198203

199204
def get_task_ids_associated_with_material_id(
200205
self, material_id: str, calc_types: Optional[List[CalcType]] = None
@@ -446,7 +451,8 @@ def get_entries(
446451
inc_structure: bool = None,
447452
property_data: List[str] = None,
448453
conventional_unit_cell: bool = False,
449-
sort_by_e_above_hull=False,
454+
sort_by_e_above_hull: bool = False,
455+
additional_criteria: dict = None,
450456
) -> List[ComputedStructureEntry]:
451457
"""
452458
Get a list of ComputedEntries or ComputedStructureEntries corresponding
@@ -476,14 +482,20 @@ def get_entries(
476482
conventional unit cell
477483
sort_by_e_above_hull (bool): Whether to sort the list of entries by
478484
e_above_hull in ascending order.
485+
additional_criteria (dict): Any additional criteria to pass. The keys and values should
486+
correspond to proper function inputs to `MPRester.thermo.search`. For instance,
487+
if you are only interested in entries on the convex hull, you could pass
488+
{"energy_above_hull": (0.0, 0.0)} or {"is_stable": True}.
479489
480490
Returns:
481491
List ComputedStructureEntry objects.
482492
"""
483493

484494
if inc_structure is not None:
485-
warnings.warn("The 'inc_structure' argument is deprecated as structure "
486-
"data is now always included in all returned entry objects.")
495+
warnings.warn(
496+
"The 'inc_structure' argument is deprecated as structure "
497+
"data is now always included in all returned entry objects."
498+
)
487499

488500
if isinstance(chemsys_formula_mpids, str):
489501
chemsys_formula_mpids = [chemsys_formula_mpids]
@@ -497,6 +509,9 @@ def get_entries(
497509
else:
498510
input_params = {"formula": chemsys_formula_mpids}
499511

512+
if additional_criteria:
513+
input_params = {**input_params, **additional_criteria}
514+
500515
entries = []
501516

502517
fields = ["entries"] if not property_data else ["entries"] + property_data
@@ -514,22 +529,25 @@ def get_entries(
514529
)
515530

516531
for doc in docs:
517-
for entry in doc.entries.values():
532+
entry_list = doc.entries.values() if self.use_document_model else doc["entries"].values()
533+
for entry in entry_list:
534+
entry_dict = entry.as_dict() if self.monty_decode else entry
518535
if not compatible_only:
519-
entry.correction = 0.0
520-
entry.energy_adjustments = []
536+
entry_dict["correction"] = 0.0
537+
entry_dict["energy_adjustments"] = []
521538

522539
if property_data:
523540
for property in property_data:
524-
entry.data[property] = doc.dict()[property]
541+
entry_dict["data"][property] = doc.dict()[property] if self.use_document_model else doc[
542+
property]
525543

526544
if conventional_unit_cell:
527545

528-
s = SpacegroupAnalyzer(entry.structure).get_conventional_standard_structure()
529-
site_ratio = (len(s) / len(entry.structure))
530-
new_energy = entry.uncorrected_energy * site_ratio
546+
entry_struct = Structure.from_dict(entry_dict["structure"])
547+
s = SpacegroupAnalyzer(entry_struct).get_conventional_standard_structure()
548+
site_ratio = len(s) / len(entry_struct)
549+
new_energy = entry_dict["energy"] * site_ratio
531550

532-
entry_dict = entry.as_dict()
533551
entry_dict["energy"] = new_energy
534552
entry_dict["structure"] = s.as_dict()
535553
entry_dict["correction"] = 0.0
@@ -540,7 +558,7 @@ def get_entries(
540558
for correction in entry_dict["energy_adjustments"]:
541559
correction["n_atoms"] *= site_ratio
542560

543-
entry = ComputedStructureEntry.from_dict(entry_dict)
561+
entry = ComputedStructureEntry.from_dict(entry_dict) if self.monty_decode else entry_dict
544562

545563
entries.append(entry)
546564

@@ -575,9 +593,11 @@ def get_pourbaix_entries(
575593
# imports are not top-level due to expense
576594
from pymatgen.analysis.pourbaix_diagram import PourbaixEntry
577595
from pymatgen.entries.compatibility import (
578-
Compatibility, MaterialsProject2020Compatibility,
596+
Compatibility,
597+
MaterialsProject2020Compatibility,
579598
MaterialsProjectAqueousCompatibility,
580-
MaterialsProjectCompatibility)
599+
MaterialsProjectCompatibility,
600+
)
581601
from pymatgen.entries.computed_entries import ComputedEntry
582602

583603
if solid_compat == "MaterialsProjectCompatibility":
@@ -638,8 +658,7 @@ def get_pourbaix_entries(
638658
# could be removed
639659
if use_gibbs:
640660
# replace the entries with GibbsComputedStructureEntry
641-
from pymatgen.entries.computed_entries import \
642-
GibbsComputedStructureEntry
661+
from pymatgen.entries.computed_entries import GibbsComputedStructureEntry
643662

644663
ion_ref_entries = GibbsComputedStructureEntry.from_entries(
645664
ion_ref_entries, temp=use_gibbs
@@ -846,11 +865,14 @@ def get_ion_entries(
846865

847866
return ion_entries
848867

849-
def get_entry_by_material_id(self, material_id: str,
850-
compatible_only: bool = True,
851-
inc_structure: bool = None,
852-
property_data: List[str] = None,
853-
conventional_unit_cell: bool = False,):
868+
def get_entry_by_material_id(
869+
self,
870+
material_id: str,
871+
compatible_only: bool = True,
872+
inc_structure: bool = None,
873+
property_data: List[str] = None,
874+
conventional_unit_cell: bool = False,
875+
):
854876
"""
855877
Get all ComputedEntry objects corresponding to a material_id.
856878
@@ -877,14 +899,17 @@ def get_entry_by_material_id(self, material_id: str,
877899
Returns:
878900
List of ComputedEntry or ComputedStructureEntry object.
879901
"""
880-
return self.get_entries(material_id,
881-
compatible_only=compatible_only,
882-
inc_structure=inc_structure,
883-
property_data=property_data,
884-
conventional_unit_cell=conventional_unit_cell)
902+
return self.get_entries(
903+
material_id,
904+
compatible_only=compatible_only,
905+
inc_structure=inc_structure,
906+
property_data=property_data,
907+
conventional_unit_cell=conventional_unit_cell,
908+
)
885909

886910
def get_entries_in_chemsys(
887-
self, elements: Union[str, List[str]],
911+
self,
912+
elements: Union[str, List[str]],
888913
use_gibbs: Optional[int] = None,
889914
compatible_only: bool = True,
890915
inc_structure: bool = None,
@@ -924,17 +949,14 @@ def get_entries_in_chemsys(
924949
input parameters in the 'MPRester.thermo.available_fields' list.
925950
conventional_unit_cell (bool): Whether to get the standard
926951
conventional unit cell
927-
additional_criteria (dict): *This is a deprecated argument*. To obtain entry objects
928-
with additional criteria, use the `MPRester.thermo.search` method directly.
952+
additional_criteria (dict): Any additional criteria to pass. The keys and values should
953+
correspond to proper function inputs to `MPRester.thermo.search`. For instance,
954+
if you are only interested in entries on the convex hull, you could pass
955+
{"energy_above_hull": (0.0, 0.0)} or {"is_stable": True}.
929956
Returns:
930957
List of ComputedStructureEntries.
931958
"""
932959

933-
if additional_criteria is not None:
934-
warnings.warn("The 'additional_criteria' argument is deprecated. "
935-
"To obtain entry objects with additional criteria, use "
936-
"the 'MPRester.thermo.search' method directly")
937-
938960
if isinstance(elements, str):
939961
elements = elements.split("-")
940962

@@ -945,19 +967,29 @@ def get_entries_in_chemsys(
945967

946968
entries = [] # type: List[ComputedEntry]
947969

948-
entries.extend(self.get_entries(all_chemsyses,
949-
compatible_only=compatible_only,
950-
inc_structure=inc_structure,
951-
property_data=property_data,
952-
conventional_unit_cell=conventional_unit_cell))
970+
entries.extend(
971+
self.get_entries(
972+
all_chemsyses,
973+
compatible_only=compatible_only,
974+
inc_structure=inc_structure,
975+
property_data=property_data,
976+
conventional_unit_cell=conventional_unit_cell,
977+
additional_criteria=additional_criteria,
978+
)
979+
)
980+
981+
if not self.monty_decode:
982+
entries = [ComputedStructureEntry.from_dict(entry) for entry in entries]
953983

954984
if use_gibbs:
955985
# replace the entries with GibbsComputedStructureEntry
956-
from pymatgen.entries.computed_entries import \
957-
GibbsComputedStructureEntry
986+
from pymatgen.entries.computed_entries import GibbsComputedStructureEntry
958987

959988
entries = GibbsComputedStructureEntry.from_entries(entries, temp=use_gibbs)
960989

990+
if not self.monty_decode:
991+
entries = [entry.as_dict() for entry in entries]
992+
961993
return entries
962994

963995
def get_bandstructure_by_material_id(
@@ -970,7 +1002,7 @@ def get_bandstructure_by_material_id(
9701002
Get the band structure pymatgen object associated with a Materials Project ID.
9711003
9721004
Arguments:
973-
materials_id (str): Materials Project ID for a material
1005+
material_id (str): Materials Project ID for a material
9741006
path_type (BSPathType): k-point path selection convention
9751007
line_mode (bool): Whether to return data for a line-mode calculation
9761008
@@ -986,7 +1018,7 @@ def get_dos_by_material_id(self, material_id: str):
9861018
Get the complete density of states pymatgen object associated with a Materials Project ID.
9871019
9881020
Arguments:
989-
materials_id (str): Materials Project ID for a material
1021+
material_id (str): Materials Project ID for a material
9901022
9911023
Returns:
9921024
dos (CompleteDos): CompleteDos object
@@ -1028,6 +1060,7 @@ def submit_structures(self, structures, public_name, public_email):
10281060
Args:
10291061
structures: A list of Structure objects
10301062
1063+
10311064
Returns:
10321065
?
10331066
"""
@@ -1077,8 +1110,10 @@ def get_charge_density_from_material_id(
10771110
"""
10781111

10791112
if not hasattr(self, "charge_density"):
1080-
raise MPRestError("boto3 not installed. "
1081-
"To query charge density data install the boto3 package.")
1113+
raise MPRestError(
1114+
"boto3 not installed. "
1115+
"To query charge density data install the boto3 package."
1116+
)
10821117

10831118
# TODO: really we want a recommended task_id for charge densities here
10841119
# this could potentially introduce an ambiguity

tests/test_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030
"_user_settings",
3131
"_general_store",
3232
"tasks",
33+
"bonds",
3334
"xas",
3435
"elasticity",
36+
"fermi",
3537
"alloys",
3638
"summary",
3739
] # temp

0 commit comments

Comments
 (0)