Header dictionary pass through and BaseRester nesting fix #715

Merged
merged 7 commits on Dec 6, 2022
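
In short, this PR threads an optional headers dictionary from MPRester and BaseRester into the underlying requests.Session (and forwards it to the mpcontribs client), and replaces the nested MPRester instantiation inside BaseRester.get_data_by_id with a direct MaterialsRester lookup. A minimal usage sketch of the new parameter follows; the API key and header values are illustrative placeholders, not part of the diff.

from mp_api.client import MPRester

# Hypothetical custom headers, e.g. for a localhost deployment; any
# non-legacy (not 16-character) placeholder key passes the constructor check.
mpr = MPRester(api_key="X" * 32, headers={"x-forwarded-host": "localhost"})

# Custom entries are merged into the session headers alongside the default
# "x-api-key" entry set by BaseRester._create_session.
assert mpr.session.headers["x-forwarded-host"] == "localhost"
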
156 changes: 51 additions & 105 deletions mp_api/client/core/client.py
@@ -62,6 +62,7 @@ def __init__(
monty_decode: bool = True,
use_document_model: bool = True,
timeout: int = 20,
headers: dict = None,
):
"""
Args:
@@ -89,6 +90,7 @@ def __init__(
as a dictionary. This can be simpler to work with but bypasses data validation
and will not give auto-complete for available fields.
timeout: Time in seconds to wait until a request timeout error is thrown
headers (dict): Custom headers for localhost connections.
"""

self.api_key = api_key
@@ -99,6 +101,7 @@ def __init__(
self.monty_decode = monty_decode
self.use_document_model = use_document_model
self.timeout = timeout
self.headers = headers or {}

if self.suffix:
self.endpoint = urljoin(self.endpoint, self.suffix)
@@ -117,20 +120,20 @@ def __init__(
@property
def session(self) -> requests.Session:
if not self._session:
self._session = self._create_session(self.api_key, self.include_user_agent)
self._session = self._create_session(self.api_key, self.include_user_agent, self.headers)
return self._session

@staticmethod
def _create_session(api_key, include_user_agent):
def _create_session(api_key, include_user_agent, headers):
session = requests.Session()
session.headers = {"x-api-key": api_key}
session.headers.update(headers)

if include_user_agent:
pymatgen_info = "pymatgen/" + pmg_version
python_info = f"Python/{sys.version.split()[0]}"
platform_info = f"{platform.system()}/{platform.release()}"
session.headers[
"user-agent"
] = f"{pymatgen_info} ({python_info} {platform_info})"
session.headers["user-agent"] = f"{pymatgen_info} ({python_info} {platform_info})"

max_retry_num = MAPIClientSettings().MAX_RETRIES
retry = Retry(
@@ -219,9 +222,7 @@ def _post_resource(
message = data
else:
try:
message = ", ".join(
f"{entry['loc'][1]} - {entry['msg']}" for entry in data
)
message = ", ".join(f"{entry['loc'][1]} - {entry['msg']}" for entry in data)
except (KeyError, IndexError):
message = str(data)

@@ -352,17 +353,13 @@ def _submit_requests(
url_string += f"{key}={parsed_val}&"

bare_url_len = len(url_string)
max_param_str_length = (
MAPIClientSettings().MAX_HTTP_URL_LENGTH - bare_url_len
)
max_param_str_length = MAPIClientSettings().MAX_HTTP_URL_LENGTH - bare_url_len

# Next, check if default number of parallel requests works.
# If not, make slice size the minimum number of param entries
# contained in any substring of length max_param_str_length.
param_length = len(criteria[parallel_param].split(","))
slice_size = (
int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1
)
slice_size = int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1

url_param_string = quote(criteria[parallel_param])

@@ -374,9 +371,7 @@

if len(parallel_param_str_chunks) > 0:

params_min_chunk = min(
parallel_param_str_chunks, key=lambda x: len(x.split("%2C"))
)
params_min_chunk = min(parallel_param_str_chunks, key=lambda x: len(x.split("%2C")))

num_params_min_chunk = len(params_min_chunk.split("%2C"))

@@ -406,11 +401,7 @@
# Split list and generate multiple criteria
new_criteria = [
{
**{
key: criteria[key]
for key in criteria
if key not in [parallel_param, "_limit"]
},
**{key: criteria[key] for key in criteria if key not in [parallel_param, "_limit"]},
parallel_param: ",".join(list_chunk),
"_limit": new_limits[list_num],
}
@@ -433,13 +424,9 @@
subtotals = []
remaining_docs_avail = {}

initial_params_list = [
{"url": url, "verify": True, "params": copy(crit)} for crit in new_criteria
]
initial_params_list = [{"url": url, "verify": True, "params": copy(crit)} for crit in new_criteria]

initial_data_tuples = self._multi_thread(
use_document_model, initial_params_list
)
initial_data_tuples = self._multi_thread(use_document_model, initial_params_list)

for data, subtotal, crit_ind in initial_data_tuples:

@@ -452,9 +439,7 @@

# Rebalance if some parallel queries produced too few results
if len(remaining_docs_avail) > 1 and len(total_data["data"]) < chunk_size:
remaining_docs_avail = dict(
sorted(remaining_docs_avail.items(), key=lambda item: item[1])
)
remaining_docs_avail = dict(sorted(remaining_docs_avail.items(), key=lambda item: item[1]))

# Redistribute missing docs from initial chunk among queries
# which have head room with respect to remaining document number.
@@ -481,19 +466,15 @@
new_limits[crit_ind] += fill_docs
fill_docs = 0

rebalance_params.append(
{"url": url, "verify": True, "params": copy(crit)}
)
rebalance_params.append({"url": url, "verify": True, "params": copy(crit)})

new_criteria[crit_ind]["_skip"] += crit["_limit"]
new_criteria[crit_ind]["_limit"] = chunk_size

# Obtain missing initial data after rebalancing
if len(rebalance_params) > 0:

rebalance_data_tuples = self._multi_thread(
use_document_model, rebalance_params
)
rebalance_data_tuples = self._multi_thread(use_document_model, rebalance_params)

for data, _, _ in rebalance_data_tuples:
total_data["data"].extend(data["data"])
@@ -507,9 +488,7 @@
total_data["meta"] = last_data_entry["meta"]

# Get max number of response pages
max_pages = (
num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size)
)
max_pages = num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size)

# Get total number of docs needed
num_docs_needed = min((max_pages * chunk_size), total_num_docs)
@@ -625,22 +604,16 @@ def _multi_thread(

return_data = []

params_gen = iter(
params_list
) # Iter necessary for islice to keep track of what has been accessed
params_gen = iter(params_list) # Iter necessary for islice to keep track of what has been accessed

params_ind = 0

with ThreadPoolExecutor(
max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS
) as executor:
with ThreadPoolExecutor(max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS) as executor:

# Get list of initial futures defined by max number of parallel requests
futures = set()

for params in itertools.islice(
params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS
):
for params in itertools.islice(params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS):

future = executor.submit(
self._submit_request_and_process,
@@ -702,13 +675,9 @@ def _submit_request_and_process(
Tuple with data and total number of docs in matching the query in the database.
"""
try:
response = self.session.get(
url=url, verify=verify, params=params, timeout=timeout
)
response = self.session.get(url=url, verify=verify, params=params, timeout=timeout)
except requests.exceptions.ConnectTimeout:
raise MPRestError(
f"REST query timed out on URL {url}. Try again with a smaller request."
)
raise MPRestError(f"REST query timed out on URL {url}. Try again with a smaller request.")

if response.status_code == 200:

@@ -724,18 +693,10 @@
raw_doc_list = [self.document_model.parse_obj(d) for d in data["data"]] # type: ignore

if len(raw_doc_list) > 0:
data_model, set_fields, _ = self._generate_returned_model(
raw_doc_list[0]
)
data_model, set_fields, _ = self._generate_returned_model(raw_doc_list[0])

data["data"] = [
data_model(
**{
field: value
for field, value in raw_doc.dict().items()
if field in set_fields
}
)
data_model(**{field: value for field, value in raw_doc.dict().items() if field in set_fields})
for raw_doc in raw_doc_list
]

@@ -754,9 +715,7 @@
message = data
else:
try:
message = ", ".join(
f"{entry['loc'][1]} - {entry['msg']}" for entry in data
)
message = ", ".join(f"{entry['loc'][1]} - {entry['msg']}" for entry in data)
except (KeyError, IndexError):
message = str(data)

@@ -767,9 +726,7 @@

def _generate_returned_model(self, doc):

set_fields = [
field for field, _ in doc if field in doc.dict(exclude_unset=True)
]
set_fields = [field for field, _ in doc if field in doc.dict(exclude_unset=True)]
unset_fields = [field for field in doc.__fields__ if field not in set_fields]

data_model = create_model(
@@ -779,19 +736,12 @@
)

data_model.__fields__ = {
**{
name: description
for name, description in data_model.__fields__.items()
if name in set_fields
},
**{name: description for name, description in data_model.__fields__.items() if name in set_fields},
"fields_not_requested": data_model.__fields__["fields_not_requested"],
}

def new_repr(self) -> str:
extra = ",\n".join(
f"\033[1m{n}\033[0;0m={getattr(self, n)!r}"
for n in data_model.__fields__
)
extra = ",\n".join(f"\033[1m{n}\033[0;0m={getattr(self, n)!r}" for n in data_model.__fields__)

s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501
return s
Expand All @@ -813,9 +763,7 @@ def new_getattr(self, attr) -> str:
" A full list of unrequested fields can be found in `fields_not_requested`."
)
else:
raise AttributeError(
f"{self.__class__.__name__!r} object has no attribute {attr!r}"
)
raise AttributeError(f"{self.__class__.__name__!r} object has no attribute {attr!r}")

data_model.__repr__ = new_repr
data_model.__str__ = new_str
@@ -872,10 +820,7 @@ def get_data_by_id(
"""

if document_id is None:
raise ValueError(
"Please supply a specific ID. You can use the query method to find "
"ids of interest."
)
raise ValueError("Please supply a specific ID. You can use the query method to find " "ids of interest.")

if self.primary_key in ["material_id", "task_id"]:
validate_ids([document_id])
@@ -897,28 +842,31 @@
if self.primary_key == "material_id":
# see if the material_id has changed, perhaps a task_id was supplied
# this should likely be re-thought
from mp_api.client import MPRester
from mp_api.client.routes.materials import MaterialsRester

with MPRester(api_key=self.api_key, endpoint=self.base_endpoint) as mpr:
new_document_id = mpr.get_materials_id_from_task_id(document_id)
with MaterialsRester(
api_key=self.api_key, endpoint=self.base_endpoint, use_document_model=False, monty_decode=False
) as mpr:
docs = mpr.search(task_ids=[document_id], fields=["material_id"])

if new_document_id is not None:
warnings.warn(
f"Document primary key has changed from {document_id} to {new_document_id}, "
f"returning data for {new_document_id} in {self.suffix} route. "
)
document_id = new_document_id
if len(docs) > 0:

results = self._query_resource_data(
criteria=criteria, fields=fields, suburl=document_id # type: ignore
)
new_document_id = docs[0].get("material_id", None)

if new_document_id is not None:
warnings.warn(
f"Document primary key has changed from {document_id} to {new_document_id}, "
f"returning data for {new_document_id} in {self.suffix} route. "
)

results = self._query_resource_data(
criteria=criteria, fields=fields, suburl=new_document_id # type: ignore
)

if not results:
raise MPRestError(f"No result for record {document_id}.")
elif len(results) > 1: # pragma: no cover
raise ValueError(
f"Multiple records for {document_id}, this shouldn't happen. Please report as a bug."
)
raise ValueError(f"Multiple records for {document_id}, this shouldn't happen. Please report as a bug.")
else:
return results[0]
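
A sketch of the task-to-material ID lookup above as it would run on its own; the task ID is a placeholder, and an API key is assumed to be configured via the environment or pymatgen settings.

from mp_api.client.routes.materials import MaterialsRester

# Bare-dict results (use_document_model=False) and monty_decode=False keep
# the internal lookup lightweight, matching the code above.
with MaterialsRester(use_document_model=False, monty_decode=False) as mpr:
    docs = mpr.search(task_ids=["mp-149"], fields=["material_id"])
    new_document_id = docs[0].get("material_id") if docs else None
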

Expand Down Expand Up @@ -1025,9 +973,7 @@ def count(self, criteria: Optional[Dict] = None) -> Union[int, str]:
False,
False,
) # do not waste cycles decoding
results = self._query_resource(
criteria=criteria, num_chunks=1, chunk_size=1
)
results = self._query_resource(criteria=criteria, num_chunks=1, chunk_size=1)
self.monty_decode, self.use_document_model = user_preferences
return results["meta"]["total_doc"]
except Exception: # pragma: no cover
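
Taken together, the client.py changes let _create_session accept a custom headers mapping. A minimal sketch of calling the static method directly under the new three-argument signature, with a placeholder key and an illustrative header name:

from mp_api.client.core.client import BaseRester

# The "x-api-key" header is set first and user-supplied headers are merged
# on top via dict.update, so a custom "x-api-key" entry would override the
# default.
session = BaseRester._create_session(
    api_key="X" * 32,
    include_user_agent=True,
    headers={"x-test-header": "1"},
)

print(session.headers["x-test-header"])  # "1"
print(session.headers["user-agent"])     # "pymatgen/... (Python/... OS/...)"
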
9 changes: 7 additions & 2 deletions mp_api/client/mprester.py
@@ -85,6 +85,7 @@ def __init__(
include_user_agent=True,
monty_decode: bool = True,
use_document_model: bool = True,
headers: dict = None,
):
"""
Args:
@@ -115,6 +116,7 @@ def __init__(
use_document_model: If False, skip creating the document model and return data
as a dictionary. This can be simpler to work with but bypasses data validation
and will not give auto-complete for available fields.
headers (dict): Custom headers for localhost connections.
"""

if api_key and len(api_key) == 16:
@@ -126,14 +128,17 @@

self.api_key = api_key
self.endpoint = endpoint
self.session = BaseRester._create_session(api_key=api_key, include_user_agent=include_user_agent)
self.headers = headers or {}
self.session = BaseRester._create_session(
api_key=api_key, include_user_agent=include_user_agent, headers=self.headers
)
self.use_document_model = use_document_model
self.monty_decode = monty_decode

try:
from mpcontribs.client import Client

self.contribs = Client(api_key)
self.contribs = Client(api_key, headers=self.headers)
except ImportError:
self.contribs = None
warnings.warn(
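
The headers dictionary is also forwarded to the MPContribs client when the optional mpcontribs-client package is installed. A sketch under that assumption, mirroring the Client call above with placeholder values:

from mpcontribs.client import Client

# Forwarded the same way MPRester now does; key and header are placeholders.
contribs = Client("X" * 32, headers={"x-forwarded-host": "localhost"})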