Bump bibx to 0.4.1: add the COMMON reference-handling mode, a `verbose`
option on the openalex command, and concurrent fetching in the OpenAlex client.

Review note: the src/bibx/clients/openalex.py hunks were amended (request moved
inside the try block, shadowed comprehension variable renamed), so the
post-image blob id on that file's index line is no longer exact.

diff --git a/src/bibx/__init__.py b/src/bibx/__init__.py
index 1f33671..cfb5d5f 100644
--- a/src/bibx/__init__.py
+++ b/src/bibx/__init__.py
@@ -26,7 +26,7 @@
     "read_wos",
 ]
 
-__version__ = "0.4.0"
+__version__ = "0.4.1"
 
 
 def query_openalex(
diff --git a/src/bibx/builders/openalex.py b/src/bibx/builders/openalex.py
index 5b3ac9e..902c72d 100644
--- a/src/bibx/builders/openalex.py
+++ b/src/bibx/builders/openalex.py
@@ -1,4 +1,5 @@
 import logging
+from collections import Counter
 from enum import Enum
 from typing import Optional
 from urllib.parse import urlparse
@@ -11,11 +12,14 @@
 
 logger = logging.getLogger(__name__)
 
+MAX_REFERENCES = 400
+
 
 class HandleReferences(Enum):
     """How to handle references when building an openalex collection."""
 
     BASIC = "basic"
+    COMMON = "common"
     FULL = "full"
 
 
@@ -39,14 +43,22 @@ def build(self) -> Collection:
         logger.info("building collection for query %s", self.query)
         works = self.client.list_recent_articles(self.query, self.limit)
         cache = {work.id: work for work in works}
+        references: list[str] = []
+        for work in works:
+            references.extend(work.referenced_works)
+        if self.references == HandleReferences.COMMON:
+            counter = Counter(references)
+            most_common = {key for key, _ in counter.most_common(MAX_REFERENCES)}
+            missing = most_common - set(cache.keys())
+            logger.info("fetching %d missing references", len(missing))
+            missing_works = self.client.list_articles_by_openalex_id(list(missing))
+            cache.update({work.id: work for work in missing_works})
         if self.references == HandleReferences.FULL:
-            references: list[str] = []
-            for work in works:
-                references.extend(work.referenced_works)
             missing = set(references) - set(cache.keys())
             logger.info("fetching %d missing references", len(missing))
             missing_works = self.client.list_articles_by_openalex_id(list(missing))
             cache.update({work.id: work for work in missing_works})
+
         article_cache = {
             openalexid: self._work_to_article(work)
             for openalexid, work in cache.items()
diff --git a/src/bibx/cli.py b/src/bibx/cli.py
index 323bfd1..453b6e4 100644
--- a/src/bibx/cli.py
+++ b/src/bibx/cli.py
@@ -1,3 +1,4 @@
+import logging
 from collections.abc import Callable
 from enum import Enum
 from typing import TextIO
@@ -83,8 +84,14 @@ def openalex(
         help="how to handle references",
         default=HandleReferences.BASIC,
     ),
+    verbose: bool = typer.Option(
+        help="be more verbose",
+        default=False,
+    ),
 ) -> None:
     """Run the sap algorithm on a seed file of any supported format."""
+    if verbose:
+        logging.basicConfig(level=logging.INFO)
     c = query_openalex(" ".join(query), references=references)
     s = Sap()
     graph = s.create_graph(c)
diff --git a/src/bibx/clients/openalex.py b/src/bibx/clients/openalex.py
index ea39096..2d4981f 100644
--- a/src/bibx/clients/openalex.py
+++ b/src/bibx/clients/openalex.py
@@ -1,4 +1,5 @@
 import logging
+from concurrent.futures import ThreadPoolExecutor, as_completed, wait
 from enum import Enum
 from typing import Optional, Union
 
@@ -122,6 +123,21 @@ def __init__(
             }
         )
 
+    def _fetch_works(self, params: dict[str, Union[str, int]]) -> WorkResponse:
+        """Query the works endpoint and return the validated response."""
+        try:
+            # The request itself is inside the try so that connection errors
+            # and timeouts are wrapped too, not only bad HTTP statuses.
+            response = self.session.get(
+                f"{self.base_url}/works",
+                params=params,
+            )
+            response.raise_for_status()
+            data = response.json()
+            return WorkResponse.model_validate(data)
+        except (requests.RequestException, ValidationError) as error:
+            raise OpenAlexError(str(error)) from error
+
     def list_recent_articles(self, query: str, limit: int = 600) -> list[Work]:
         """List recent articles from the openalex API."""
         select = ",".join(Work.model_fields.keys())
@@ -134,56 +150,50 @@ def list_recent_articles(self, query: str, limit: int = 600) -> list[Work]:
         )
         pages = (limit // MAX_WORKS_PER_PAGE) + 1
         results: list[Work] = []
-        for page in range(1, pages + 1):
-            logger.info("fetching page %d with filter %s", page, filter_)
-            params: dict[str, Union[str, int]] = {
-                "select": select,
-                "filter": filter_,
-                "sort": "publication_year:desc",
-                "per_page": MAX_WORKS_PER_PAGE,
-                "page": page,
-            }
-            response = self.session.get(
-                f"{self.base_url}/works",
-                params=params,
-            )
-            try:
-                response.raise_for_status()
-                data = response.json()
-                work_response = WorkResponse.model_validate(data)
-                logger.info(
-                    "fetched %d works in page %d", len(work_response.results), page
+        with ThreadPoolExecutor(max_workers=min(pages, 25)) as executor:
+            futures = [
+                executor.submit(
+                    self._fetch_works,
+                    {
+                        "select": select,
+                        "filter": filter_,
+                        "sort": "publication_year:desc",
+                        "per_page": MAX_WORKS_PER_PAGE,
+                        "page": page,
+                    },
                 )
+                for page in range(1, pages + 1)
+            ]
+            wait(futures)
+            # Consume in submission order so results keep the API's sort order.
+            for future in futures:
+                work_response = future.result()
                 results.extend(work_response.results)
-                if page * MAX_WORKS_PER_PAGE >= min(work_response.meta.count, limit):
+                if len(results) >= limit:
                     break
-            except (requests.RequestException, ValidationError) as error:
-                raise OpenAlexError(str(error)) from error
         return results[:limit]
 
     def list_articles_by_openalex_id(self, ids: list[str]) -> list[Work]:
         """List articles by openalex id."""
         select = ",".join(Work.model_fields.keys())
-        filter_ = ",".join([f"ids.openalex:{id_}" for id_ in ids])
         results: list[Work] = []
-        for ids_ in chunks(ids, MAX_IDS_PER_REQUEST):
-            value = "|".join(ids_)
-            filter_ = f"ids.openalex:{value},type:types/article"
-            logger.info("fetching %d ids from openalex", len(ids_))
-            params: dict[str, Union[str, int]] = {
-                "select": select,
-                "filter": filter_,
-                "per_page": MAX_IDS_PER_REQUEST,
-            }
-            response = self.session.get(
-                f"{self.base_url}/works",
-                params=params,
-            )
-            try:
-                response.raise_for_status()
-                data = response.json()
-                work_response = WorkResponse.model_validate(data)
+        with ThreadPoolExecutor(max_workers=5) as executor:
+            futures = [
+                executor.submit(
+                    self._fetch_works,
+                    {
+                        "select": select,
+                        "filter": f"ids.openalex:{'|'.join(chunk)},type:types/article",
+                        "per_page": MAX_IDS_PER_REQUEST,
+                    },
+                )
+                for chunk in chunks(ids, MAX_IDS_PER_REQUEST)
+            ]
+            # Responses are consumed in completion order, not submission order.
+            for future in as_completed(futures):
+                work_response = future.result()
+                logger.info(
+                    "got %s works from the openalex api", len(work_response.results)
+                )
                 results.extend(work_response.results)
-            except (requests.RequestException, ValidationError) as error:
-                raise OpenAlexError(str(error)) from error
         return results