Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 39 additions & 31 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def _query_resource(
use_document_model: if None, will defer to the self.use_document_model attribute
num_chunks: Maximum number of chunks of data to yield. None will yield all possible.
chunk_size: Number of data entries per chunk.
timeout : Time in seconds to wait until a request timeout error is thrown
timeout (float or None): Time in seconds to wait until a request timeout error is thrown

Returns:
A Resource, a dict with two keys, "data" containing a list of documents, and
Expand Down Expand Up @@ -805,27 +805,32 @@ def _query_resource(
except RequestException as ex:
raise MPRestError(str(ex))

def _submit_requests( # noqa
def _submit_requests(
self,
url,
criteria,
use_document_model,
chunk_size,
num_chunks=None,
timeout=None,
url: str,
criteria: dict[str, Any],
use_document_model: bool,
chunk_size: int | None,
num_chunks: int | None = None,
timeout: int | None = None,
max_batch_size: int = 100,
norecur: bool = False,
) -> dict:
"""Handle submitting requests sequentially with pagination.

If criteria contains comma-separated parameters (except those that are naturally comma-separated),
split them into multiple sequential requests and combine results.

Arguments:
criteria: dictionary of criteria to filter down
url: url used to make request
use_document_model: if None, will defer to the self.use_document_model attribute
num_chunks: Maximum number of chunks of data to yield. None will yield all possible.
chunk_size: Number of data entries per chunk.
timeout: Time in seconds to wait until a request timeout error is thrown
url (str): url used to make request
criteria (dict[str, Any]): dictionary of criteria to filter down
use_document_model (bool): Whether to use the document model
num_chunks (int or None): Maximum number of chunks of data to yield. None will yield all possible.
chunk_size (int or None): Number of data entries per chunk.
timeout (int or None): Time in seconds to wait until a request timeout error is thrown
max_batch_size (int): Maximum size of a batch when retrieving batches in parallel
norecur (bool): Whether to forbid recursive splitting of a query field
when a direct query fails

Returns:
Dictionary containing data and metadata
Expand Down Expand Up @@ -884,14 +889,6 @@ def _submit_requests( # noqa
timeout=timeout,
)

# Check if we got 0 results - some parameters are silently ignored by the API
# when passed as comma-separated values, so we need to split them anyway
if total_num_docs == 0 and len(split_values) > 1:
# Treat this the same as a 422 error - split into batches
raise MPRestError(
"Got 0 results for comma-separated parameter, will try splitting"
)

# If successful, continue with normal pagination
data_chunks = [data["data"]]
total_data: dict[str, Any] = {"data": []}
Expand All @@ -903,18 +900,26 @@ def _submit_requests( # noqa
# Continue with pagination if needed (handled below)

except MPRestError as e:
# If we get 422 or 414 error, or 0 results for comma-separated params, split into batches
if any(trace in str(e) for trace in ("422", "414", "Got 0 results")):
# If we get 422 or 414 error, split into batches
if not norecur and any(
trace in str(e)
for trace in (
"422",
"414",
)
):
total_data = {"data": []}
total_num_docs = 0
data_chunks = []

# Batch the split values to reduce number of requests
# Use batches of up to 100 values to balance URL length and request count
batch_size = min(100, max(1, len(split_values) // 10))
num_batches = min(
max_batch_size, max(1, len(split_values) // max_batch_size)
)
batch_size = min(len(split_values), max_batch_size)

# Setup progress bar for split parameter requests
num_batches = ceil(len(split_values) / batch_size)
pbar_message = f"Retrieving {len(split_values)} {split_param} values in {num_batches} batches"
pbar = (
tqdm(
Expand All @@ -938,6 +943,7 @@ def _submit_requests( # noqa
chunk_size=chunk_size,
num_chunks=num_chunks,
timeout=timeout,
norecur=len(batch) <= max_batch_size,
)

data_chunks.append(result["data"])
Expand Down Expand Up @@ -979,6 +985,12 @@ def _submit_requests( # noqa
if "meta" in data:
total_data["meta"] = data["meta"]

# otherwise, paginate sequentially
if chunk_size is None or chunk_size < 1:
raise ValueError(
"A positive chunk size must be provided to enable pagination"
)

# Get max number of response pages
max_pages = (
num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size)
Expand All @@ -998,7 +1010,7 @@ def _submit_requests( # noqa
desc=pbar_message,
total=num_docs_needed,
)
if not self.mute_progress_bars
if not self.mute_progress_bars and total_num_docs > 0
else None
)

Expand All @@ -1018,10 +1030,6 @@ def _submit_requests( # noqa
pbar.close()
return new_total_data

# otherwise, paginate sequentially
if chunk_size is None:
raise ValueError("A chunk size must be provided to enable pagination")

# Warning to select specific fields only for many results
if criteria.get("_all_fields", False) and (total_num_docs / chunk_size > 10):
warnings.warn(
Expand Down
27 changes: 24 additions & 3 deletions mp_api/client/routes/materials/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import warnings
from collections import defaultdict
from itertools import chain, product

from emmet.core.summary import HasProps, SummaryDoc
from emmet.core.symmetry import CrystalSystem
Expand Down Expand Up @@ -200,8 +201,9 @@ def search( # noqa: D417
mmnd_inv = {v: k for k, v in min_max_name_dict.items() if k != v}

# Set user query params from `locals`
_locals = locals()
user_settings = {
k: v for k, v in locals().items() if k in min_max_name_dict and v
k: v for k, v in _locals.items() if k in min_max_name_dict and v
}

# Check to see if user specified _search fields using **kwargs,
Expand Down Expand Up @@ -328,10 +330,11 @@ def _csrc(x):
"spacegroup_number": 230,
"spacegroup_symbol": 230,
}
batched_symm_query = {}
for k, cardinality in symm_cardinality.items():
if isinstance(symm_vals := locals().get(k), list | tuple | set):
if isinstance(symm_vals := _locals.get(k), list | tuple | set):
if len(symm_vals) < cardinality // 2:
query_params.update({k: ",".join(str(v) for v in symm_vals)})
batched_symm_query[k] = symm_vals
else:
raise MPRestError(
f"Querying `{k}` by a list of values is only "
Expand Down Expand Up @@ -378,6 +381,24 @@ def _csrc(x):
if query_params[entry] is not None
}

if batched_symm_query:
ordered_symm_key = sorted(batched_symm_query)
return list(
chain.from_iterable(
self._search( # type: ignore[return-value]
num_chunks=num_chunks,
chunk_size=chunk_size,
all_fields=all_fields,
fields=fields,
**query_params,
**{sk: symm_params[i] for i, sk in enumerate(ordered_symm_key)},
)
for symm_params in product(
*[batched_symm_query[k] for k in ordered_symm_key]
)
)
)

return super()._search( # type: ignore[return-value]
num_chunks=num_chunks,
chunk_size=chunk_size,
Expand Down
Loading