diff --git a/scottypy/app.py b/scottypy/app.py index 971d235..6a7d677 100644 --- a/scottypy/app.py +++ b/scottypy/app.py @@ -1,5 +1,6 @@ from __future__ import print_function +import asyncio import json import logging import os @@ -102,7 +103,7 @@ def link(beam_id_or_tag: str, url: str, storage_base: str, dest: str) -> None: for beam in scotty.get_beams_by_tag(tag): _link_beam(storage_base, beam, os.path.join(dest, str(beam.id))) else: - beam = scotty.get_beam(beam_id_or_tag) + beam = asyncio.run(scotty.get_beam(beam_id_or_tag)) if dest is None: dest = beam_id_or_tag @@ -132,21 +133,23 @@ def _list(beam: "Beam") -> None: for beam in scotty.get_beams_by_tag(tag): _list(beam) else: - _list(scotty.get_beam(beam_id_or_tag)) + _list(asyncio.get_event_loop().run_until_complete(scotty.get_beam(beam_id_or_tag))) -def _download_beam(beam: "Beam", dest: str, overwrite: bool, filter: str) -> None: +async def _download_beam(beam: "Beam", dest: str, overwrite: bool, filter: str) -> None: if not os.path.isdir(dest): os.makedirs(dest) click.echo("Downloading beam {} to directory {}".format(beam.id, dest)) + tasks = [] for file_ in beam.get_files(filter_=filter): click.echo("Downloading {}".format(file_.file_name)) try: - file_.download(dest, overwrite=overwrite) + tasks.append(file_.download(dest, overwrite=overwrite)) except NotOverwriting as e: click.echo("{} already exists. Use --overwrite to overwrite".format(e.file)) + await asyncio.gather(*tasks) _write_beam_info(beam, dest) @@ -176,18 +179,27 @@ def down( To download an entire tag specify t:[tag_name] as an argument, replacing [tag_name] with the name of the tag""" scotty = Scotty(url) + beams = [] if beam_id_or_tag.startswith("t:"): tag = beam_id_or_tag[2:] - if dest is None: - dest = tag + dest = dest or tag for beam in scotty.get_beams_by_tag(tag): - _download_beam(beam, os.path.join(dest, str(beam.id)), overwrite, filter) + beams.append( + _download_beam( + beam, os.path.join(dest, str(beam.id)), overwrite, filter + ) + ) + else: - beam = scotty.get_beam(beam_id_or_tag) - if dest is None: - dest = beam_id_or_tag - _download_beam(beam, dest, overwrite, filter) + beam = await scotty.get_beam(beam_id_or_tag) + dest = dest or beam_id_or_tag + beams.append( + _download_beam(beam, os.path.join(dest, str(beam.id)), overwrite, filter) + ) + + loop = asyncio.get_event_loop() + loop.run_until_complete(asyncio.gather(*beams)) @main.group() @@ -317,5 +329,5 @@ def set_comment(beam_id: int, url: str, comment: str) -> None: """Set a comment for the specified beam""" scotty = Scotty(url) - beam = scotty.get_beam(beam_id) + beam = asyncio.get_event_loop().run_until_complete(scotty.get_beam(beam_id)) beam.set_comment(comment) diff --git a/scottypy/exc.py b/scottypy/exc.py index 86fdc77..19b0603 100644 --- a/scottypy/exc.py +++ b/scottypy/exc.py @@ -7,3 +7,12 @@ class NotOverwriting(Exception): def __init__(self, file_: str): super(NotOverwriting, self).__init__() self.file = file_ + + +class HTTPError(Exception): + def __init__(self, *, url, code, text): + super().__init__( + "Server responded {code} when accessing {url}:\n{text}".format( + code=code, url=url, text=text + ) + ) diff --git a/scottypy/file.py b/scottypy/file.py index f5f56fa..9fb8974 100644 --- a/scottypy/file.py +++ b/scottypy/file.py @@ -4,6 +4,8 @@ import dateutil.parser +import aiohttp + from .exc import NotOverwriting from .types import JSON from .utils import fix_path_sep_for_current_platform, raise_for_status @@ -66,15 +68,18 @@ def from_json(cls, session: "Session", json_node: JSON) -> "File": mtime, ) - def stream_to(self, fileobj: "typing.BinaryIO") -> None: - """Fetch the file content from the server and write it to fileobj""" - response = self._session.get(self.url, stream=True) - raise_for_status(response) + async def fetch(self, session, url): + return await session.get(url) - for chunk in response.iter_content(chunk_size=_CHUNK_SIZE): - fileobj.write(chunk) + async def stream_to(self, fileobj: "typing.BinaryIO") -> None: + """Fetch the file content from the server and write it to fileobj""" + async with aiohttp.ClientSession() as session: + response = await self.fetch(session, self.url) + response.raise_for_status() + async for chunk in response.content.iter_chunked(_CHUNK_SIZE): + fileobj.write(chunk) - def download(self, directory: str = ".", overwrite: bool = False) -> None: + async def download(self, directory: str = ".", overwrite: bool = False) -> None: """Download the file to the specified directory, retaining its name""" subdir, file_ = os.path.split(fix_path_sep_for_current_platform(self.file_name)) subdir = os.path.join(directory, subdir) @@ -90,7 +95,7 @@ def download(self, directory: str = ".", overwrite: bool = False) -> None: raise NotOverwriting(file_) with open(file_, "wb") as f: - self.stream_to(f) + await self.stream_to(f) if self.mtime is not None: mtime = _to_epoch(self.mtime) diff --git a/scottypy/scotty.py b/scottypy/scotty.py index 55ee33c..fa27bd6 100644 --- a/scottypy/scotty.py +++ b/scottypy/scotty.py @@ -1,4 +1,5 @@ import abc +import asyncio import errno import json import logging @@ -18,11 +19,14 @@ from requests.adapters import HTTPAdapter from requests.packages.urllib3.util.retry import Retry +import aiohttp +from aiohttp import ClientSession + from .beam import Beam from .exc import PathNotExists from .file import File from .types import JSON -from .utils import raise_for_status +from .utils import execute_http, raise_for_status _SLEEP_TIME = 10 _NUM_OF_RETRIES = (60 // _SLEEP_TIME) * 15 @@ -146,8 +150,6 @@ def __init__(self, url: str, retry_times: int = 3, backoff_factor: int = 2): def prefetch_combadge( self, combadge_version: str = _DEFAULT_COMBADGE_VERSION ) -> None: - """Prefetch the combadge to a temporary file. Future beams will use that combadge - instead of having to re-download it.""" self._get_combadge(combadge_version=combadge_version) def remove_combadge(self) -> None: @@ -354,7 +356,7 @@ def remove_tag(self, beam_id: int, tag: str) -> None: ) raise_for_status(response) - def get_beam(self, beam_id: typing.Union[str, int]) -> "Beam": + async def get_beam(self, beam_id: typing.Union[str, int]) -> "Beam": """Retrieve details about the specified beam. :param int beam_id: Beam ID or tag @@ -397,14 +399,20 @@ def get_beams_by_tag(self, tag: str) -> typing.List[Beam]: :param str tag: The name of the tag. :return: a list of :class:`.Beam` objects. """ - response = self._session.get( "{0}/beams?tag={1}".format(self._url, tag), timeout=_TIMEOUT ) - raise_for_status(response) + response.raise_for_status() + + beams_id_list = (b["id"] for b in response.json()["beams"]) + beams = [] + for beam_id in beams_id_list: + beams.append(self.get_beam(beam_id)) + + loop = asyncio.get_event_loop() + beams_result = loop.run_until_complete(asyncio.gather(*beams)) - ids = (b["id"] for b in response.json()["beams"]) - return [self.get_beam(id_) for id_ in ids] + return beams_result def sanity_check(self) -> None: """Check if this instance of Scotty is functioning. Raise an exception if something's wrong""" diff --git a/scottypy/utils.py b/scottypy/utils.py index 60033ed..dbf3ff1 100644 --- a/scottypy/utils.py +++ b/scottypy/utils.py @@ -1,7 +1,19 @@ +import asyncio +import logging import os +from typing import Any, Dict, Optional import requests +import aiohttp +import yarl +from aiohttp import ClientSession +from tenacity import retry, stop_after_attempt, stop_after_delay + +from .exc import HTTPError + +logger = logging.getLogger("scotty") # type: logging.Logger + def raise_for_status(response: requests.Response) -> None: if 400 <= response.status_code < 500: @@ -27,3 +39,39 @@ def raise_for_status(response: requests.Response) -> None: def fix_path_sep_for_current_platform(file_name: str) -> str: return file_name.replace("\\", os.path.sep).replace("/", os.path.sep) + + +class AsyncRequestHelper: + def __init__(self): + self._loop = asyncio.get_event_loop() + self._session = aiohttp.ClientSession( + loop=self._loop, + timeout=aiohttp.ClientTimeout(total=30), + headers={"Accept-Encoding": "gzip", "Content-Type": "application/json"}, + ) + + @retry(stop=(stop_after_delay(5) | stop_after_attempt(3))) + async def execute_http( + self, + url: yarl.URL, + *, + data: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + method = "GET" if data is None else "POST" + logger.info("Async Calling {} {}", method, url) + async with self._session.request( + method, url, params=params, json=data + ) as response: + if response.status != 200: + raise HTTPError( + url=url, code=response.status, text=await response.text() + ) + return await response.json() + + def __del__(self): + self._loop.run_until_complete(self._session.close()) + + +_async_request_helper = AsyncRequestHelper() +execute_http = _async_request_helper.execute_http diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000..5beb5d4 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,5 @@ +slash +scotty +aiohttp +asyncio +tenacity \ No newline at end of file