diff --git a/mp_api/_test_utils.py b/mp_api/_test_utils.py new file mode 100644 index 000000000..5d4044c90 --- /dev/null +++ b/mp_api/_test_utils.py @@ -0,0 +1,120 @@ +"""Define testing utils that need to imported.""" + +# pragma: exclude file + +from __future__ import annotations + +try: + import pytest +except ImportError as exc: + raise ImportError( + "You must `pip install 'mp-api[test]' to use these testing utilities." + ) from exc + +import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from typing import Any + +requires_api_key = pytest.mark.skipif( + os.getenv("MP_API_KEY") is None, + reason="No API key found.", +) + +NUM_DOCS = 5 + + +def client_search_testing( + search_method: Callable, + excluded_params: list[str], + alt_name_dict: dict[str, str], + custom_field_tests: dict[str, Any], + sub_doc_fields: list[str], + int_bounds: tuple[int, int] = (-100, 100), + float_bounds: tuple[float, float] = (-100.12, 100.12), +): + """Function to test a client using its search method. + Each parameter is used to query for data, which is then checked. + + Args: + search_method (Callable): Client search method + excluded_params (list[str]): List of parameters to exclude from testing + alt_name_dict (dict[str, str]): Alternative names for parameters used in the projection and subsequent data checking + custom_field_tests (dict[str, Any]): Custom queries for specific fields. + sub_doc_fields (list[str]): Prefixes for fields to check in resulting data. Useful when data to be tested is nested. + int_bounds (tuple[int,int]) : integer bounds to use in testing int-type query arguments + float_bounds (tuple[float,float]) : float bounds to use in testing float-type query arguments + """ + if search_method is None: + return + # Get list of parameters + param_tuples = list(search_method.__annotations__.items()) + + # Query API for each numeric and boolean parameter and check if returned + for entry in param_tuples: + param = entry[0] + + if param not in excluded_params + ["return"]: + param_type = entry[1] + q: dict[str, Any] = {"chunk_size": 1, "num_chunks": 1} + + if "tuple[int, int]" in param_type: + q[param] = int_bounds + elif "tuple[float, float]" in param_type: + q[param] = float_bounds + elif "bool" in param_type: + q[param] = False + elif param in custom_field_tests: + q[param] = custom_field_tests[param] + else: + raise ValueError( + f"Parameter '{param}' with type '{param_type}' was not " + "properly identified in the generic search method test." + ) + + if len(docs := search_method(**q)) > 0: + doc = docs[0].model_dump() + else: + raise ValueError("No documents returned") + + for sub_field in sub_doc_fields: + if sub_field in doc: + doc = doc[sub_field] + + assert doc[alt_name_dict.get(param, param)] is not None + + +def client_pagination(search_method: Callable, id_name: str): + page_1 = search_method(_page=1, chunk_size=NUM_DOCS, fields=[id_name]) + page_2 = search_method(_page=2, chunk_size=NUM_DOCS, fields=[id_name]) + assert all(len(results) == NUM_DOCS for results in (page_1, page_2)) + assert {str(getattr(doc, id_name)) for doc in page_1}.intersection( + {str(getattr(doc, id_name)) for doc in page_2} + ) == set() + + +def client_sort(search_method: Callable, sort_fields: str | Sequence[str]): + for sort_field in [sort_fields] if isinstance(sort_fields, str) else sort_fields: + asc = search_method( + _page=1, _sort_fields=sort_field, chunk_size=NUM_DOCS, fields=[sort_field] + ) + desc = search_method( + _page=1, + _sort_fields=f"-{sort_field}", + chunk_size=NUM_DOCS, + fields=[sort_field], + ) + + idxs = list(range(NUM_DOCS)) + assert sorted(idxs, key=lambda idx: getattr(asc[idx], sort_field)) == idxs + + assert ( + sorted( + idxs, + key=lambda idx: getattr(desc[idx], sort_field), + reverse=True, + ) + == idxs + ) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index e2127b224..2dace9203 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -965,6 +965,10 @@ def _submit_requests( # noqa # No splitting needed - get first page total_data = {"data": []} initial_criteria = copy(criteria) + if isinstance( + initial_criteria.get("_page"), int + ) and not initial_criteria.get("_per_page"): + initial_criteria["_per_page"] = initial_criteria.get("_limit") data, total_num_docs = self._submit_request_and_process( url=url, verify=True, @@ -1438,6 +1442,9 @@ def _search( # This method should be customized for each end point to give more user friendly, # documented kwargs. + # If user specifies page, ensure only one chunk is returned + if isinstance(kwargs.get("_page"), int) and num_chunks is None: + num_chunks = 1 return self._get_all_documents( kwargs, all_fields=all_fields, diff --git a/mp_api/client/routes/materials/electrodes.py b/mp_api/client/routes/materials/electrodes.py index 71a469d62..1a63b0234 100644 --- a/mp_api/client/routes/materials/electrodes.py +++ b/mp_api/client/routes/materials/electrodes.py @@ -15,7 +15,7 @@ class BaseElectrodeRester(BaseRester): primary_key = "battery_id" _exclude_search_fields: list[str] | None = None - def search( # pragma: ignore + def search( self, battery_ids: str | list[str] | None = None, average_voltage: tuple[float, float] | None = None, @@ -39,6 +39,8 @@ def search( # pragma: ignore chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, + _page: int | None = None, + _sort_fields: str | None = None, ) -> list[InsertionElectrodeDoc | ConversionElectrodeDoc] | list[dict]: """Query using a variety of search criteria. @@ -77,63 +79,45 @@ def search( # pragma: ignore all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in InsertionElectrodeDoc or ConversionElectrodeDoc to return data for. Default is battery_id and last_updated if all_fields is False. + _page (int or None) : Page of the results to skip to. + _sort_fields (str or None) : Field to sort on. Including a leading "-" sign will reverse sort order. Returns: ([InsertionElectrodeDoc or ConversionElectrodeDoc], [dict]) List of insertion/conversion electrode documents or dictionaries. """ query_params: dict = defaultdict(dict) - if battery_ids: - if isinstance(battery_ids, str): - battery_ids = [battery_ids] - - query_params.update({"battery_ids": ",".join(validate_ids(battery_ids))}) - - if working_ion: - if isinstance(working_ion, (str, Element)): - working_ion = [working_ion] # type: ignore - - query_params.update( - {"working_ion": ",".join([str(ele) for ele in working_ion])} # type: ignore - ) - - if formula: - if isinstance(formula, str): - formula = [formula] - - query_params.update({"formula": ",".join(formula)}) - - if elements: - query_params.update({"elements": ",".join(elements)}) - - if num_elements: - if isinstance(num_elements, int): - num_elements = (num_elements, num_elements) - query_params.update( - {"nelements_min": num_elements[0], "nelements_max": num_elements[1]} - ) - - if exclude_elements: - query_params.update({"exclude_elements": ",".join(exclude_elements)}) - for param, value in locals().items(): - if ( - param - not in [ - "__class__", - "self", - "working_ion", - "query_params", - "num_elements", - ] - and value - ): - if isinstance(value, tuple): + if param not in {"__class__", "self", "query_params"} and value is not None: + if param == "num_elements": # this must come first + if isinstance(num_elements, int): + num_elements = (num_elements, num_elements) + query_params.update( + { + "nelements_min": num_elements[0], # type: ignore[index] + "nelements_max": num_elements[1], # type: ignore[index] + } + ) + + elif isinstance(value, tuple): query_params.update( {f"{param}_min": value[0], f"{param}_max": value[1]} ) + elif param == "battery_ids": + query_params[param] = ",".join(validate_ids(value)) + elif param == "working_ion": + query_params["working_ion"] = ",".join( + str(ele) + for ele in ( + [value] if isinstance(value, str | Element) else value + ) + ) + elif param in ("formula", "elements", "exclude_elements"): + query_params[param] = ",".join( + [value] if isinstance(value, str) else value + ) else: - query_params.update({param: value}) + query_params[param] = value excluded_fields = self._exclude_search_fields or [] ignored_fields = { @@ -177,4 +161,6 @@ class ConversionElectrodeRester(BaseElectrodeRester): "stability_charge", "stability_discharge", "exclude_elements", + "_page", + "_sort_fields", ] diff --git a/mp_api/client/routes/materials/summary.py b/mp_api/client/routes/materials/summary.py index 1769781ea..ee5a2dac6 100644 --- a/mp_api/client/routes/materials/summary.py +++ b/mp_api/client/routes/materials/summary.py @@ -73,6 +73,8 @@ def search( # noqa: D417 chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, + _page: int | None = None, + _sort_fields: str | None = None, **kwargs, ) -> list[SummaryDoc] | list[dict]: """Query core data using a variety of search criteria. @@ -150,6 +152,8 @@ def search( # noqa: D417 all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in SummaryDoc to return data for. Default is material_id if all_fields is False. + _page (int or None) : Page of the results to skip to. + _sort_fields (str or None) : Field to sort on. Including a leading "-" sign will reverse sort order. Returns: ([SummaryDoc], [dict]) List of SummaryDoc documents or dictionaries. @@ -181,6 +185,8 @@ def search( # noqa: D417 "weighted_surface_energy", "weighted_work_function", "shape_factor", + "_page", + "_sort_fields", ] min_max_name_dict = { @@ -200,8 +206,9 @@ def search( # noqa: D417 mmnd_inv = {v: k for k, v in min_max_name_dict.items() if k != v} # Set user query params from `locals` + _locals = locals() user_settings = { - k: v for k, v in locals().items() if k in min_max_name_dict and v + k: v for k, v in _locals.items() if k in min_max_name_dict and v is not None } # Check to see if user specified _search fields using **kwargs, @@ -284,14 +291,17 @@ def _csrc(x): ) for param, value in user_settings.items(): - if isinstance(value, (int, float)): - value = (value, value) - query_params.update( - { - f"{min_max_name_dict[param]}_min": value[0], - f"{min_max_name_dict[param]}_max": value[1], - } - ) + if param in {"_page", "_sort_fields"}: + query_params[param] = value + else: + if isinstance(value, (int, float)): + value = (value, value) + query_params.update( + { + f"{min_max_name_dict[param]}_min": value[0], + f"{min_max_name_dict[param]}_max": value[1], + } + ) if material_ids: if isinstance(material_ids, str): @@ -299,29 +309,15 @@ def _csrc(x): query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) - if deprecated is not None: - query_params.update({"deprecated": deprecated}) - - if formula: - if isinstance(formula, str): - formula = [formula] - - query_params.update({"formula": ",".join(formula)}) - - if chemsys: - if isinstance(chemsys, str): - chemsys = [chemsys] - - query_params.update({"chemsys": ",".join(chemsys)}) - - if elements: - query_params.update({"elements": ",".join(elements)}) - - if exclude_elements is not None: - query_params.update({"exclude_elements": ",".join(exclude_elements)}) - - if possible_species is not None: - query_params.update({"possible_species": ",".join(possible_species)}) + for k in ( + "formula", + "chemsys", + "elements", + "exclude_elements", + "possible_species", + ): + if (v := _locals.get(k)) is not None: + query_params[k] = ",".join([v] if isinstance(v, str) else v) symm_cardinality = { "crystal_system": 7, @@ -341,21 +337,20 @@ def _csrc(x): else: query_params.update({k: symm_vals}) - if is_stable is not None: - query_params.update({"is_stable": is_stable}) - - if is_gap_direct is not None: - query_params.update({"is_gap_direct": is_gap_direct}) - - if is_metal is not None: - query_params.update({"is_metal": is_metal}) + for k in ( + "deprecated", + "is_stable", + "is_gap_direct", + "is_metal", + "has_reconstructed", + "theoretical", + ): + if (v := _locals.get(k)) is not None: + query_params[k] = v if magnetic_ordering: query_params.update({"ordering": magnetic_ordering.value}) - if has_reconstructed is not None: - query_params.update({"has_reconstructed": has_reconstructed}) - if has_props: has_props_clean = [] for prop in has_props: @@ -366,9 +361,6 @@ def _csrc(x): query_params.update({"has_props": ",".join(has_props_clean)}) - if theoretical is not None: - query_params.update({"theoretical": theoretical}) - if not include_gnome: query_params.update({"batch_id_not_eq": "gnome_r2scan_statics"}) diff --git a/mp_api/client/routes/materials/xas.py b/mp_api/client/routes/materials/xas.py index 0a8efa36f..fc71ae7ff 100644 --- a/mp_api/client/routes/materials/xas.py +++ b/mp_api/client/routes/materials/xas.py @@ -6,7 +6,6 @@ from pymatgen.core.periodic_table import Element from mp_api.client.core import BaseRester -from mp_api.client.core.utils import validate_ids if TYPE_CHECKING: from typing import Any @@ -33,6 +32,8 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, + _page: int | None = None, + _sort_fields: str | None = None, ): """Query core XAS docs using a variety of search criteria. @@ -54,14 +55,18 @@ def search( all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in MaterialsCoreDoc to return data for. Default is material_id, last_updated, and formula_pretty if all_fields is False. + _page (int or None) : Page of the results to skip to. + _sort_fields (str or None) : Field to sort on. Including a leading "-" sign will reverse sort order. Returns: ([MaterialsDoc]) List of material documents """ - query_params: dict[str, Any] = {} - - if edge: - query_params.update({"edge": edge}) + _locals = locals() + query_params: dict[str, Any] = { + k: _locals[k] + for k in ("edge", "spectrum_type", "formula", "_page", "_sort_fields") + if _locals.get(k) is not None + } if absorbing_element: query_params.update( @@ -71,33 +76,9 @@ def search( else absorbing_element } ) - - if spectrum_type: - query_params.update({"spectrum_type": spectrum_type}) - - if formula: - query_params.update({"formula": formula}) - - if chemsys: - if isinstance(chemsys, str): - chemsys = [chemsys] - - query_params.update({"chemsys": ",".join(chemsys)}) - - if elements: - query_params.update({"elements": ",".join(elements)}) - - if material_ids: - if isinstance(material_ids, str): - material_ids = [material_ids] - - query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) - - if spectrum_ids: - if isinstance(spectrum_ids, str): - spectrum_ids = [spectrum_ids] - - query_params.update({"spectrum_ids": ",".join(spectrum_ids)}) + for k in ("chemsys", "elements", "material_ids", "spectrum_ids"): + if (v := _locals.get(k)) is not None: + query_params[k] = ",".join([v] if isinstance(v, str) else v) query_params = { entry: query_params[entry] diff --git a/mp_api/client/routes/molecules/jcesr.py b/mp_api/client/routes/molecules/jcesr.py index 79f3e55aa..2d462c193 100644 --- a/mp_api/client/routes/molecules/jcesr.py +++ b/mp_api/client/routes/molecules/jcesr.py @@ -40,8 +40,12 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, + _page: int | None = None, + _sort_fields: str | None = None, ): - """Query equations of state docs using a variety of search criteria. + """Query legacy molecule docs using a variety of search criteria. + + JCESR = Joint Center for Energy Storage Research Arguments: task_ids (str, List[str]): A single molecule task ID string or list of strings. @@ -59,6 +63,8 @@ def search( all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in MoleculesDoc to return data for. Default is the material_id only if all_fields is False. + _page (int or None) : Page of the results to skip to. + _sort_fields (str or None) : Field to sort on. Including a leading "-" sign will reverse sort order. Returns: ([MoleculesDoc]) List of molecule documents @@ -74,25 +80,14 @@ def search( if elements: query_params.update({"elements": ",".join([str(ele) for ele in elements])}) - if pointgroup: - query_params.update({"pointgroup": pointgroup}) - - if smiles: - query_params.update({"smiles": smiles}) - - if nelements: - query_params.update( - {"nelements_min": nelements[0], "nelements_max": nelements[1]} - ) - - if EA: - query_params.update({"EA_min": EA[0], "EA_max": EA[1]}) - - if IE: - query_params.update({"IE_min": IE[0], "IE_max": IE[1]}) + _locals = locals() + for k in ("pointgroup", "smiles", "_page", "_sort_fields"): + if (v := _locals.get(k)) is not None: + query_params[k] = v - if charge: - query_params.update({"charge_min": charge[0], "charge_max": charge[1]}) + for k in ("nelements", "EA", "IE", "charge"): + if (vals := _locals.get(k)) is not None: + query_params.update({f"{k}_min": vals[0], f"{k}_max": vals[1]}) query_params = { entry: query_params[entry] diff --git a/mp_api/client/routes/molecules/summary.py b/mp_api/client/routes/molecules/summary.py index 471fd84a8..4be3aab58 100644 --- a/mp_api/client/routes/molecules/summary.py +++ b/mp_api/client/routes/molecules/summary.py @@ -28,6 +28,8 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, + _page: int | None = None, + _sort_fields: str | None = None, ): """Query core data using a variety of search criteria. @@ -50,6 +52,8 @@ def search( all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in SearchDoc to return data for. Default is material_id if all_fields is False. + _page (int or None) : Page of the results to skip to. + _sort_fields (str or None) : Field to sort on. Including a leading "-" sign will reverse sort order. Returns: ([MoleculeSummaryDoc]) List of molecules summary documents @@ -80,29 +84,14 @@ def search( molecule_ids = [molecule_ids] query_params.update({"molecule_ids": ",".join(molecule_ids)}) - if charge: - query_params.update({"charge": charge}) + _locals = locals() + for k in ("charge", "spin_multiplicity", "_page", "_sort_fields"): + if (v := _locals.get(k)) is not None: + query_params[k] = v - if spin_multiplicity: - query_params.update({"spin_multiplicity": spin_multiplicity}) - - if formula: - if isinstance(formula, str): - formula = [formula] - - query_params.update({"formula": ",".join(formula)}) - - if chemsys: - if isinstance(chemsys, str): - chemsys = [chemsys] - - query_params.update({"chemsys": ",".join(chemsys)}) - - if elements: - query_params.update({"elements": ",".join(elements)}) - - if exclude_elements is not None: - query_params.update({"exclude_elements": ",".join(exclude_elements)}) + for k in ("formula", "chemsys", "elements", "exclude_elements"): + if (v := _locals.get(k)) is not None: + query_params[k] = ",".join([v] if isinstance(v, str) else v) if has_props: query_params.update({"has_props": ",".join([i.value for i in has_props])}) diff --git a/pyproject.toml b/pyproject.toml index 3a06a1ccb..105eb4814 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,3 +122,9 @@ isort.required-imports = ["from __future__ import annotations"] [tool.mypy] namespace_packages = true ignore_missing_imports = true + +[tool.coverage.report] +exclude_also = [ + # A pragma comment that excludes an entire file: + '\A(?s:.*# pragma: exclude file.*)\Z', +] diff --git a/tests/client/conftest.py b/tests/client/conftest.py deleted file mode 100644 index ab70586ea..000000000 --- a/tests/client/conftest.py +++ /dev/null @@ -1,91 +0,0 @@ -from __future__ import annotations - -import os -from typing import TYPE_CHECKING - -import pytest - -if TYPE_CHECKING: - from collections.abc import Callable - from typing import Any - -requires_api_key = pytest.mark.skipif( - os.getenv("MP_API_KEY") is None, - reason="No API key found.", -) - - -def client_search_testing( - search_method: Callable, - excluded_params: list[str], - alt_name_dict: dict[str, str], - custom_field_tests: dict[str, Any], - sub_doc_fields: list[str], -): - """ - Function to test a client using its search method. - Each parameter is used to query for data, which is then checked. - - Args: - search_method (Callable): Client search method - excluded_params (list[str]): List of parameters to exclude from testing - alt_name_dict (dict[str, str]): Alternative names for parameters used in the projection and subsequent data checking - custom_field_tests (dict[str, Any]): Custom queries for specific fields. - sub_doc_fields (list[str]): Prefixes for fields to check in resulting data. Useful when data to be tested is nested. - """ - if search_method is None: - return - # Get list of parameters - param_tuples = list(search_method.__annotations__.items()) - - # Query API for each numeric and boolean parameter and check if returned - for entry in param_tuples: - param = entry[0] - - if param not in excluded_params + ["return"]: - param_type = entry[1] - q = None - - if "tuple[int, int]" in param_type: - project_field = alt_name_dict.get(param, None) - q = { - param: (-100, 100), - "chunk_size": 1, - "num_chunks": 1, - } - elif "tuple[float, float]" in param_type: - project_field = alt_name_dict.get(param, None) - q = { - param: (-100.12, 100.12), - "chunk_size": 1, - "num_chunks": 1, - } - elif "bool" in param_type: - project_field = alt_name_dict.get(param, None) - q = { - param: False, - "chunk_size": 1, - "num_chunks": 1, - } - elif param in custom_field_tests: - project_field = alt_name_dict.get(param, None) - q = { - param: custom_field_tests[param], - "chunk_size": 1, - "num_chunks": 1, - } - - if q is None: - raise ValueError( - f"Parameter '{param}' with type '{param_type}' was not " - "properly identified in the generic search method test." - ) - doc = search_method(**q)[0].model_dump() - - for sub_field in sub_doc_fields: - if sub_field in doc: - doc = doc[sub_field] - - assert ( - doc[project_field if project_field is not None else param] is not None - ) diff --git a/tests/client/materials/test_absorption.py b/tests/client/materials/test_absorption.py index cf6c04735..dda6cbcf7 100644 --- a/tests/client/materials/test_absorption.py +++ b/tests/client/materials/test_absorption.py @@ -4,11 +4,11 @@ from emmet.core.phonon import PhononBS, PhononDOS +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.core import MPRestError from mp_api.client.routes.materials.absorption import AbsorptionRester -from ..conftest import client_search_testing, requires_api_key - @requires_api_key def test_absorption_search(): diff --git a/tests/client/materials/test_alloys.py b/tests/client/materials/test_alloys.py index 8de85791b..92ef4a9d7 100644 --- a/tests/client/materials/test_alloys.py +++ b/tests/client/materials/test_alloys.py @@ -1,7 +1,3 @@ -import os - -import pytest - try: import pymatgen.analysis.alloys except ImportError: @@ -10,9 +6,12 @@ allow_module_level=True, ) -from mp_api.client.routes.materials.alloys import AlloysRester +import os +import pytest -from ..conftest import client_search_testing, requires_api_key +from mp_api._test_utils import client_search_testing, requires_api_key + +from mp_api.client.routes.materials.alloys import AlloysRester @requires_api_key diff --git a/tests/client/materials/test_bonds.py b/tests/client/materials/test_bonds.py index 7f520157c..4be6161cf 100644 --- a/tests/client/materials/test_bonds.py +++ b/tests/client/materials/test_bonds.py @@ -1,6 +1,8 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest + +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.bonds import BondsRester diff --git a/tests/client/materials/test_chemenv.py b/tests/client/materials/test_chemenv.py index 72e0e17f7..77e693c07 100644 --- a/tests/client/materials/test_chemenv.py +++ b/tests/client/materials/test_chemenv.py @@ -1,8 +1,9 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.chemenv import ChemenvRester diff --git a/tests/client/materials/test_dielectric.py b/tests/client/materials/test_dielectric.py index a37f580fc..ee7c6671d 100644 --- a/tests/client/materials/test_dielectric.py +++ b/tests/client/materials/test_dielectric.py @@ -1,8 +1,9 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.dielectric import DielectricRester diff --git a/tests/client/materials/test_doi.py b/tests/client/materials/test_doi.py index 72240fbfb..61b3ce8e2 100644 --- a/tests/client/materials/test_doi.py +++ b/tests/client/materials/test_doi.py @@ -1,6 +1,6 @@ -from mp_api.client.routes.materials.doi import DOIRester +from mp_api._test_utils import client_search_testing, requires_api_key -from ..conftest import client_search_testing, requires_api_key +from mp_api.client.routes.materials.doi import DOIRester @requires_api_key diff --git a/tests/client/materials/test_elasticity.py b/tests/client/materials/test_elasticity.py index 0a48c3e7a..cbdc8e566 100644 --- a/tests/client/materials/test_elasticity.py +++ b/tests/client/materials/test_elasticity.py @@ -1,8 +1,9 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.elasticity import ElasticityRester diff --git a/tests/client/materials/test_electrodes.py b/tests/client/materials/test_electrodes.py index bc6e00ef4..e6a53008f 100644 --- a/tests/client/materials/test_electrodes.py +++ b/tests/client/materials/test_electrodes.py @@ -1,9 +1,15 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest from pymatgen.core.periodic_table import Element +from mp_api._test_utils import ( + client_search_testing, + client_pagination, + client_sort, + requires_api_key, +) + from mp_api.client.routes.materials.electrodes import ( ElectrodeRester, ConversionElectrodeRester, @@ -30,6 +36,8 @@ def conversion_rester(): "num_chunks", "all_fields", "fields", + "_page", + "_sort_fields", ] sub_doc_fields: list = [] @@ -80,3 +88,19 @@ def test_conversion_client(conversion_rester): }, sub_doc_fields=sub_doc_fields, ) + + +@requires_api_key +def test_pagination(): + with ElectrodeRester() as rester: + client_pagination(rester.search, "battery_id") + + +@pytest.mark.xfail(reason="Sort requires API redeployment", strict=False) +@requires_api_key +@pytest.mark.parametrize( + "sort_field", ["battery_id", "stability_charge", "average_voltage"] +) +def test_sort(sort_field): + with ElectrodeRester() as rester: + client_sort(rester.search, sort_field) diff --git a/tests/client/materials/test_electronic_structure.py b/tests/client/materials/test_electronic_structure.py index 3ab67e621..a89cc730f 100644 --- a/tests/client/materials/test_electronic_structure.py +++ b/tests/client/materials/test_electronic_structure.py @@ -1,9 +1,9 @@ -from ..conftest import client_search_testing, requires_api_key - import pytest from pymatgen.analysis.magnetism import Ordering from typing import Any +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.core.exceptions import MPRestError from mp_api.client.routes.materials.electronic_structure import ( BandStructureRester, diff --git a/tests/client/materials/test_eos.py b/tests/client/materials/test_eos.py index 8bccbac61..3e633e49b 100644 --- a/tests/client/materials/test_eos.py +++ b/tests/client/materials/test_eos.py @@ -1,8 +1,9 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.eos import EOSRester diff --git a/tests/client/materials/test_grain_boundary.py b/tests/client/materials/test_grain_boundary.py index 6719efe4b..2eb034cd3 100644 --- a/tests/client/materials/test_grain_boundary.py +++ b/tests/client/materials/test_grain_boundary.py @@ -1,9 +1,10 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest from emmet.core.grain_boundary import GBTypeEnum +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.grain_boundaries import GrainBoundaryRester diff --git a/tests/client/materials/test_magnetism.py b/tests/client/materials/test_magnetism.py index 0417aa388..add6453df 100644 --- a/tests/client/materials/test_magnetism.py +++ b/tests/client/materials/test_magnetism.py @@ -1,9 +1,10 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest from pymatgen.analysis.magnetism import Ordering +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.magnetism import MagnetismRester diff --git a/tests/client/materials/test_materials.py b/tests/client/materials/test_materials.py index 67f410200..ccda72c18 100644 --- a/tests/client/materials/test_materials.py +++ b/tests/client/materials/test_materials.py @@ -1,9 +1,10 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest from emmet.core.symmetry import CrystalSystem +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.materials import MaterialsRester from mp_api.client.routes.materials import MATERIALS_RESTERS diff --git a/tests/client/materials/test_oxidation_states.py b/tests/client/materials/test_oxidation_states.py index 042b35116..7274ec7b3 100644 --- a/tests/client/materials/test_oxidation_states.py +++ b/tests/client/materials/test_oxidation_states.py @@ -1,8 +1,9 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.oxidation_states import OxidationStatesRester diff --git a/tests/client/materials/test_phonon.py b/tests/client/materials/test_phonon.py index 0b5aae754..1beb1978e 100644 --- a/tests/client/materials/test_phonon.py +++ b/tests/client/materials/test_phonon.py @@ -5,11 +5,11 @@ from emmet.core.phonon import PhononBS, PhononDOS +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.core.exceptions import MPRestError from mp_api.client.routes.materials.phonon import PhononRester -from ..conftest import client_search_testing, requires_api_key - @requires_api_key def test_phonon_search(): diff --git a/tests/client/materials/test_piezo.py b/tests/client/materials/test_piezo.py index 4f0228d80..8bc1a4734 100644 --- a/tests/client/materials/test_piezo.py +++ b/tests/client/materials/test_piezo.py @@ -1,8 +1,9 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.piezo import PiezoRester diff --git a/tests/client/materials/test_provenance.py b/tests/client/materials/test_provenance.py index 796a32d6d..9a460c7ed 100644 --- a/tests/client/materials/test_provenance.py +++ b/tests/client/materials/test_provenance.py @@ -1,8 +1,8 @@ import os -from ..conftest import client_search_testing, requires_api_key - import pytest +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.provenance import ProvenanceRester diff --git a/tests/client/materials/test_robocrys.py b/tests/client/materials/test_robocrys.py index bee970ed7..d72c65358 100644 --- a/tests/client/materials/test_robocrys.py +++ b/tests/client/materials/test_robocrys.py @@ -2,9 +2,9 @@ import pytest -from mp_api.client.routes.materials.robocrys import RobocrysRester +from mp_api._test_utils import requires_api_key -from ..conftest import requires_api_key +from mp_api.client.routes.materials.robocrys import RobocrysRester @pytest.fixture diff --git a/tests/client/materials/test_similarity.py b/tests/client/materials/test_similarity.py index 6fcf7f4ae..072f30764 100644 --- a/tests/client/materials/test_similarity.py +++ b/tests/client/materials/test_similarity.py @@ -6,10 +6,11 @@ from emmet.core.similarity import SimilarityScorer, SimilarityEntry from pymatgen.core import Structure +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.core import MPRestError from mp_api.client.routes.materials.similarity import SimilarityRester -from ..conftest import client_search_testing, requires_api_key try: import matminer diff --git a/tests/client/materials/test_substrates.py b/tests/client/materials/test_substrates.py index 06558598b..65ab40d01 100644 --- a/tests/client/materials/test_substrates.py +++ b/tests/client/materials/test_substrates.py @@ -1,8 +1,9 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.substrates import SubstratesRester diff --git a/tests/client/materials/test_summary.py b/tests/client/materials/test_summary.py index 9d5b63989..1d9083423 100644 --- a/tests/client/materials/test_summary.py +++ b/tests/client/materials/test_summary.py @@ -1,10 +1,16 @@ import os -from ..conftest import client_search_testing, requires_api_key +from mp_api._test_utils import ( + client_search_testing, + requires_api_key, + client_pagination, + client_sort, +) -import pytest from emmet.core.summary import HasProps from emmet.core.symmetry import CrystalSystem +import numpy as np from pymatgen.analysis.magnetism import Ordering +import pytest from mp_api.client.routes.materials.summary import SummaryRester from mp_api.client.core.exceptions import MPRestWarning, MPRestError @@ -16,6 +22,8 @@ "num_chunks", "all_fields", "fields", + "_page", + "_sort_fields", ] alt_name_dict: dict = { @@ -134,3 +142,23 @@ def test_warning_messages(): with pytest.raises(MPRestError, match="not a valid property"): _ = search_method(num_elements=10, has_props=["apples"]) + + +@requires_api_key +def test_pagination(): + with SummaryRester() as rester: + client_pagination(rester.search, "material_id") + + +summary_sort_fields = [ + "formation_energy_per_atom", + "energy_above_hull", + "band_gap", +] + + +@requires_api_key +@pytest.mark.parametrize("sort_field", summary_sort_fields) +def test_sort(sort_field: str): + with SummaryRester() as rester: + client_sort(rester.search, sort_field) diff --git a/tests/client/materials/test_surface_properties.py b/tests/client/materials/test_surface_properties.py index bd66c00d0..cda19654f 100644 --- a/tests/client/materials/test_surface_properties.py +++ b/tests/client/materials/test_surface_properties.py @@ -1,8 +1,9 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.surface_properties import SurfacePropertiesRester diff --git a/tests/client/materials/test_synthesis.py b/tests/client/materials/test_synthesis.py index 7ec4d9c07..05254e811 100644 --- a/tests/client/materials/test_synthesis.py +++ b/tests/client/materials/test_synthesis.py @@ -1,11 +1,12 @@ import os from typing import List - +from pydantic import BaseModel import pytest from emmet.core.synthesis import SynthesisRecipe, SynthesisTypeEnum +from mp_api._test_utils import requires_api_key + from mp_api.client.routes.materials.synthesis import SynthesisRester -from ..conftest import requires_api_key @pytest.fixture @@ -34,10 +35,13 @@ def test_filters_keywords(rester): search_method = rester.search if search_method is not None: - doc = search_method(keywords=["silicon"])[0] + doc = search_method(keywords=["silicon"], chunk_size=100)[0] assert isinstance(doc.search_score, float) - highlighted = sum((x["texts"] for x in doc.highlights), []) + highlights = [ + x.model_dump() if isinstance(x, BaseModel) else x for x in doc.highlights + ] + highlighted = sum((x["texts"] for x in highlights), []) assert "silicon" in " ".join([x["value"] for x in highlighted]).lower() diff --git a/tests/client/materials/test_tasks.py b/tests/client/materials/test_tasks.py index ea064989c..3173268aa 100644 --- a/tests/client/materials/test_tasks.py +++ b/tests/client/materials/test_tasks.py @@ -5,9 +5,9 @@ from emmet.core.trajectory import RelaxTrajectory from emmet.core.utils import utcnow -from mp_api.client.routes.materials.tasks import TaskRester +from mp_api._test_utils import client_search_testing, requires_api_key -from ..conftest import client_search_testing, requires_api_key +from mp_api.client.routes.materials.tasks import TaskRester @pytest.fixture diff --git a/tests/client/materials/test_thermo.py b/tests/client/materials/test_thermo.py index b64e65726..83885440d 100644 --- a/tests/client/materials/test_thermo.py +++ b/tests/client/materials/test_thermo.py @@ -1,10 +1,11 @@ import os -from ..conftest import client_search_testing, requires_api_key import pytest from emmet.core.types.enums import ThermoType from pymatgen.analysis.phase_diagram import PhaseDiagram +from mp_api._test_utils import client_search_testing, requires_api_key + from mp_api.client.routes.materials.thermo import ThermoRester diff --git a/tests/client/materials/test_xas.py b/tests/client/materials/test_xas.py index c31a5d9f0..850d6b8c9 100644 --- a/tests/client/materials/test_xas.py +++ b/tests/client/materials/test_xas.py @@ -1,11 +1,16 @@ -from ..conftest import client_search_testing, requires_api_key - import pytest from typing import Any from emmet.core.types.enums import XasEdge, XasType from pymatgen.core.periodic_table import Element +from mp_api._test_utils import ( + client_search_testing, + client_pagination, + client_sort, + requires_api_key, +) + from mp_api.client.routes.materials.xas import XASRester @@ -59,3 +64,19 @@ def test_client(rester): custom_field_tests=custom_field_tests, sub_doc_fields=sub_doc_fields, ) + + +@requires_api_key +def test_pagination(): + with XASRester() as rester: + client_pagination(rester.search, "spectrum_id") + + +@pytest.mark.xfail(reason="Sort requires API redeployment", strict=False) +@requires_api_key +@pytest.mark.parametrize( + "sort_field", ["material_id", "absorbing_element", "spectrum_id"] +) +def test_sort(sort_field): + with XASRester() as rester: + client_sort(rester.search, sort_field) diff --git a/tests/client/molecules/core_function.py b/tests/client/molecules/core_function.py deleted file mode 100644 index 557cafb3b..000000000 --- a/tests/client/molecules/core_function.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import annotations -from typing import Callable, Any - - -def client_search_testing( - search_method: Callable, - excluded_params: list[str], - alt_name_dict: dict[str, str], - custom_field_tests: dict[str, Any], - sub_doc_fields: list[str], -): - """ - Function to test a client using its search method. - Each parameter is used to query for data, which is then checked. - - Args: - search_method (Callable): Client search method - excluded_params (list[str]): List of parameters to exclude from testing - alt_name_dict (dict[str, str]): Alternative names for parameters used in the projection and subsequent data checking - custom_field_tests (dict[str, Any]): Custom queries for specific fields. - sub_doc_fields (list[str]): Prefixes for fields to check in resulting data. Useful when data to be tested is nested. - """ - if search_method is not None: - # Get list of parameters - param_tuples = list(search_method.__annotations__.items()) - # Query API for each numeric and boolean parameter and check if returned - for entry in param_tuples: - param = entry[0] - if param not in excluded_params: - param_type = entry[1] - q = None - - if param in custom_field_tests: - q = { - param: custom_field_tests[param], - "chunk_size": 1, - "num_chunks": 1, - } - elif "tuple[int, int]" in param_type: - q = { - param: (-100, 100), - "chunk_size": 1, - "num_chunks": 1, - } - elif "tuple[float, float]" in param_type: - q = { - param: (-3000.12, 3000.12), - "chunk_size": 1, - "num_chunks": 1, - } - elif param_type is bool: - q = { - param: False, - "chunk_size": 1, - "num_chunks": 1, - } - - docs = search_method(**q) - - if len(docs) > 0: - doc = docs[0].model_dump() - else: - raise ValueError("No documents returned") - - assert doc[alt_name_dict.get(param, param)] is not None diff --git a/tests/client/molecules/test_jcesr.py b/tests/client/molecules/test_jcesr.py index c3c828858..deb8bf21d 100644 --- a/tests/client/molecules/test_jcesr.py +++ b/tests/client/molecules/test_jcesr.py @@ -1,5 +1,10 @@ import os -from .core_function import client_search_testing +from mp_api._test_utils import ( + client_search_testing, + client_pagination, + client_sort, + requires_api_key, +) import pytest from pymatgen.core.periodic_table import Element @@ -7,8 +12,6 @@ from mp_api.client.core.exceptions import MPRestWarning from mp_api.client.routes.molecules.jcesr import JcesrMoleculesRester -from ..conftest import requires_api_key - @pytest.fixture def rester(): @@ -24,6 +27,8 @@ def rester(): "all_fields", "fields", "charge", + "_page", + "_sort_fields", ] sub_doc_fields: list = [] @@ -48,9 +53,24 @@ def test_client(rester): alt_name_dict=alt_name_dict, custom_field_tests=custom_field_tests, sub_doc_fields=sub_doc_fields, + float_bounds=(-3000.12, 3000.12), ) def test_warning(): with pytest.warns(MPRestWarning, match="unmaintained legacy molecules"): JcesrMoleculesRester() + + +@requires_api_key +def test_pagination(): + with JcesrMoleculesRester() as rester: + client_pagination(rester.search, "task_id") + + +@pytest.mark.xfail(reason="Sort requires API redeployment", strict=False) +@requires_api_key +@pytest.mark.parametrize("sort_field", ["task_id", "IE", "EA"]) +def test_sort(sort_field): + with JcesrMoleculesRester() as rester: + client_sort(rester.search, sort_field) diff --git a/tests/client/molecules/test_summary.py b/tests/client/molecules/test_summary.py index 3e9ddd2ed..910adf474 100644 --- a/tests/client/molecules/test_summary.py +++ b/tests/client/molecules/test_summary.py @@ -1,13 +1,18 @@ import os -from .core_function import client_search_testing import pytest from emmet.core.molecules.summary import HasProps from emmet.core.mpid import MPculeID +from mp_api._test_utils import ( + client_search_testing, + client_pagination, + client_sort, + requires_api_key, +) from mp_api.client.routes.molecules.summary import MoleculesSummaryRester -from ..conftest import requires_api_key +num_docs = 5 excluded_params = [ "sort_fields", @@ -16,6 +21,8 @@ "all_fields", "fields", "exclude_elements", + "_page", + "_sort_fields", ] alt_name_dict = {"formula": "formula_alphabetical", "molecule_ids": "molecule_id"} @@ -51,3 +58,17 @@ def test_client(): custom_field_tests=custom_field_tests, sub_doc_fields=[], ) + + +@requires_api_key +def test_pagination(): + with MoleculesSummaryRester() as rester: + client_pagination(rester.search, "molecule_id") + + +@pytest.mark.xfail(reason="Sort requires API redeployment", strict=False) +@requires_api_key +@pytest.mark.parametrize("sort_field", ["molecule_id", "charge", "spin_multiplicity"]) +def test_sort(sort_field): + with MoleculesSummaryRester() as rester: + client_sort(rester.search, sort_field) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index c02f63d28..38eec6676 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -2,9 +2,9 @@ import pytest +from mp_api._test_utils import requires_api_key from mp_api.client import MPRester -from .conftest import requires_api_key try: import pymatgen.analysis.alloys as pmg_alloys diff --git a/tests/client/test_core_client.py b/tests/client/test_core_client.py index 89c83c674..4842d44db 100644 --- a/tests/client/test_core_client.py +++ b/tests/client/test_core_client.py @@ -2,13 +2,13 @@ import json +from mp_api._test_utils import requires_api_key + from mp_api.client import MPRester from mp_api.client.core import BaseRester from mp_api.client.core.exceptions import MPRestError, MPRestWarning from mp_api.client.routes.materials.materials import MaterialsRester -from .conftest import requires_api_key - @pytest.fixture def rester(): diff --git a/tests/client/test_heartbeat.py b/tests/client/test_heartbeat.py index 3b17eabed..52c4cfb8b 100644 --- a/tests/client/test_heartbeat.py +++ b/tests/client/test_heartbeat.py @@ -2,9 +2,9 @@ import pytest from unittest.mock import patch, Mock -import mp_api.client.mprester +from mp_api._test_utils import requires_api_key -from .conftest import requires_api_key +import mp_api.client.mprester @pytest.fixture diff --git a/tests/client/test_mprester.py b/tests/client/test_mprester.py index 38efd6e02..9b5dade24 100644 --- a/tests/client/test_mprester.py +++ b/tests/client/test_mprester.py @@ -34,10 +34,11 @@ from pymatgen.io.cif import CifParser from pymatgen.io.vasp import Chgcar +from mp_api._test_utils import requires_api_key + from mp_api.client import MPRester from mp_api.client.core import MPRestError, MPRestWarning -from .conftest import requires_api_key try: import mpcontribs.client as contribs_client