From ea916050e774427999eee728ee5394f2684c44ba Mon Sep 17 00:00:00 2001 From: Tony Deng Date: Wed, 1 Apr 2026 16:19:54 -0700 Subject: [PATCH 01/10] feat: add CancellationToken support to polling infrastructure - Create CancellationToken class with sync/async event support - Update poll_until() and async_poll_until() with cancellation_token parameter - Add cancellable sleep using threading.Event.wait() and asyncio.wait_for() - Update PollingRequestOptions TypedDict with cancellation_token field - Propagate cancellation_token through Blueprint and ScenarioRun polling methods Part of porting TypeScript PR #765 features to Python SDK. --- src/runloop_api_client/lib/cancellation.py | 103 ++++++++++++++++++ src/runloop_api_client/lib/polling.py | 16 ++- src/runloop_api_client/lib/polling_async.py | 21 +++- .../resources/blueprints.py | 30 ++++- .../resources/scenarios/runs.py | 29 ++++- src/runloop_api_client/sdk/_types.py | 4 + 6 files changed, 197 insertions(+), 6 deletions(-) create mode 100644 src/runloop_api_client/lib/cancellation.py diff --git a/src/runloop_api_client/lib/cancellation.py b/src/runloop_api_client/lib/cancellation.py new file mode 100644 index 000000000..3e31a76b8 --- /dev/null +++ b/src/runloop_api_client/lib/cancellation.py @@ -0,0 +1,103 @@ +"""Cancellation support for polling operations.""" + +from __future__ import annotations + +import asyncio +import threading +from typing import TYPE_CHECKING + +from .._exceptions import RunloopError + +if TYPE_CHECKING: + pass + +__all__ = ["PollingCancelled", "CancellationToken"] + + +class PollingCancelled(RunloopError): + """Exception raised when a polling operation is cancelled.""" + + pass + + +class CancellationToken: + """Thread-safe cancellation token for polling operations. + + Similar to JavaScript's AbortSignal. Works in both sync and async contexts. + + Example (sync): + >>> token = CancellationToken() + >>> # In another thread: + >>> token.cancel() + >>> # In polling code: + >>> token.raise_if_cancelled() # Raises PollingCancelled + + Example (async): + >>> token = CancellationToken() + >>> # In another task: + >>> token.cancel() + >>> # In async polling code: + >>> await asyncio.wait_for(token.async_event.wait(), timeout=1.0) + """ + + def __init__(self) -> None: + """Create a new cancellation token.""" + self._cancelled = False + self._sync_event = threading.Event() + self._async_event: asyncio.Event | None = None + self._lock = threading.Lock() + + def cancel(self) -> None: + """Mark this token as cancelled. + + Thread-safe and can be called multiple times. Sets both sync and async events. + """ + with self._lock: + if self._cancelled: + return + self._cancelled = True + self._sync_event.set() + if self._async_event is not None: + self._async_event.set() + + def is_cancelled(self) -> bool: + """Check if this token has been cancelled. + + Returns: + True if cancel() has been called, False otherwise. + """ + return self._cancelled + + def raise_if_cancelled(self) -> None: + """Raise PollingCancelled if this token has been cancelled. + + Raises: + PollingCancelled: If cancel() has been called. + """ + if self._cancelled: + raise PollingCancelled("Polling operation was cancelled") + + @property + def sync_event(self) -> threading.Event: + """Get the synchronous event for cancellation checking. + + Returns: + threading.Event that is set when cancel() is called. + """ + return self._sync_event + + @property + def async_event(self) -> asyncio.Event: + """Get the asynchronous event for cancellation checking. + + Lazily creates the async event on first access. If cancel() was already called, + the event will be set immediately. + + Returns: + asyncio.Event that is set when cancel() is called. + """ + if self._async_event is None: + self._async_event = asyncio.Event() + if self._cancelled: + self._async_event.set() + return self._async_event diff --git a/src/runloop_api_client/lib/polling.py b/src/runloop_api_client/lib/polling.py index 899d2a9bf..95440f235 100644 --- a/src/runloop_api_client/lib/polling.py +++ b/src/runloop_api_client/lib/polling.py @@ -2,6 +2,8 @@ from typing import Any, TypeVar, Callable, Optional from dataclasses import dataclass +from .cancellation import CancellationToken + T = TypeVar("T") @@ -27,6 +29,7 @@ def poll_until( is_terminal: Callable[[T], bool], config: Optional[PollingConfig] = None, on_error: Optional[Callable[[Exception], T]] = None, + cancellation_token: Optional[CancellationToken] = None, ) -> T: """ Poll until a condition is met or timeout/max attempts are reached. @@ -37,12 +40,14 @@ def poll_until( config: Optional polling configuration on_error: Optional error handler that can return a value to continue polling or re-raise the exception to stop polling + cancellation_token: Optional token to cancel the polling operation Returns: The final state of the polled object Raises: PollingTimeout: When max attempts or timeout is reached + PollingCancelled: If cancellation_token.cancel() is called """ if config is None: config = PollingConfig() @@ -52,6 +57,10 @@ def poll_until( last_result = None while True: + # Check for cancellation before each iteration + if cancellation_token is not None: + cancellation_token.raise_if_cancelled() + try: last_result = retriever() except Exception as e: @@ -72,4 +81,9 @@ def poll_until( if elapsed >= config.timeout_seconds: raise PollingTimeout(f"Exceeded timeout of {config.timeout_seconds} seconds", last_result) - time.sleep(config.interval_seconds) + # Cancellable sleep + if cancellation_token is not None: + if cancellation_token.sync_event.wait(timeout=config.interval_seconds): + cancellation_token.raise_if_cancelled() + else: + time.sleep(config.interval_seconds) diff --git a/src/runloop_api_client/lib/polling_async.py b/src/runloop_api_client/lib/polling_async.py index 7ba192e86..df3ffbc46 100644 --- a/src/runloop_api_client/lib/polling_async.py +++ b/src/runloop_api_client/lib/polling_async.py @@ -3,6 +3,7 @@ from typing import Union, TypeVar, Callable, Optional, Awaitable from .polling import PollingConfig, PollingTimeout +from .cancellation import CancellationToken T = TypeVar("T") @@ -12,6 +13,7 @@ async def async_poll_until( is_terminal: Callable[[T], bool], config: Optional[PollingConfig] = None, on_error: Optional[Callable[[Exception], T]] = None, + cancellation_token: Optional[CancellationToken] = None, ) -> T: """ Poll until a condition is met or timeout/max attempts are reached. @@ -22,12 +24,14 @@ async def async_poll_until( config: Optional polling configuration on_error: Optional error handler that can return a value to continue polling or re-raise the exception to stop polling + cancellation_token: Optional token to cancel the polling operation Returns: The final state of the polled object Raises: PollingTimeout: When max attempts or timeout is reached + PollingCancelled: If cancellation_token.cancel() is called """ if config is None: config = PollingConfig() @@ -37,6 +41,10 @@ async def async_poll_until( last_result: Union[T, None] = None while True: + # Check for cancellation before each iteration + if cancellation_token is not None: + cancellation_token.raise_if_cancelled() + try: last_result = await retriever() except Exception as e: @@ -57,4 +65,15 @@ async def async_poll_until( if elapsed >= config.timeout_seconds: raise PollingTimeout(f"Exceeded timeout of {config.timeout_seconds} seconds", last_result) - await asyncio.sleep(config.interval_seconds) + # Cancellable async sleep + if cancellation_token is not None: + try: + await asyncio.wait_for( + cancellation_token.async_event.wait(), + timeout=config.interval_seconds, + ) + cancellation_token.raise_if_cancelled() + except asyncio.TimeoutError: + pass # Normal sleep completion + else: + await asyncio.sleep(config.interval_seconds) diff --git a/src/runloop_api_client/resources/blueprints.py b/src/runloop_api_client/resources/blueprints.py index 7e5b09939..94a40bb85 100644 --- a/src/runloop_api_client/resources/blueprints.py +++ b/src/runloop_api_client/resources/blueprints.py @@ -29,6 +29,7 @@ from ..lib.polling import PollingConfig, poll_until from .._base_client import AsyncPaginator, make_request_options from ..lib.polling_async import async_poll_until +from ..lib.cancellation import CancellationToken from .._utils._validation import ValidationNotification from ..types.blueprint_view import BlueprintView from ..types.blueprint_preview_view import BlueprintPreviewView @@ -41,6 +42,7 @@ # Type for request arguments that combine polling config with additional request options class BlueprintRequestArgs(TypedDict, total=False): polling_config: PollingConfig | None + cancellation_token: CancellationToken | None # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None @@ -280,6 +282,7 @@ def await_build_complete( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -292,6 +295,7 @@ def await_build_complete( Args: id: The ID of the blueprint to wait for polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -302,6 +306,7 @@ def await_build_complete( Raises: PollingTimeout: If polling times out before blueprint is built + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If blueprint enters a non-built terminal state """ @@ -313,7 +318,12 @@ def retrieve_blueprint() -> BlueprintView: def is_done_building(blueprint: BlueprintView) -> bool: return blueprint.status not in ["queued", "building", "provisioning"] - blueprint = poll_until(retrieve_blueprint, is_done_building, polling_config) + blueprint = poll_until( + retrieve_blueprint, + is_done_building, + polling_config, + cancellation_token=cancellation_token, + ) if blueprint.status != "build_complete": raise RunloopError(f"Blueprint entered non-built terminal state: {blueprint.status}") @@ -338,6 +348,7 @@ def create_and_await_build_complete( services: Optional[Iterable[blueprint_create_params.Service]] | Omit = omit, system_setup_commands: Optional[SequenceNotStr[str]] | Omit = omit, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -353,12 +364,14 @@ def create_and_await_build_complete( Args: See the `create` method for detailed documentation. polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation Returns: The built blueprint Raises: PollingTimeout: If polling times out before blueprint is built + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If blueprint enters a non-built terminal state """ # Pass all create_args to the underlying create method @@ -387,6 +400,7 @@ def create_and_await_build_complete( return self.await_build_complete( blueprint.id, polling_config=polling_config, + cancellation_token=cancellation_token, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, @@ -960,6 +974,7 @@ async def await_build_complete( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -972,6 +987,7 @@ async def await_build_complete( Args: id: The ID of the blueprint to wait for polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -982,6 +998,7 @@ async def await_build_complete( Raises: PollingTimeout: If polling times out before blueprint is built + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If blueprint enters a non-built terminal state """ @@ -993,7 +1010,12 @@ async def retrieve_blueprint() -> BlueprintView: def is_done_building(blueprint: BlueprintView) -> bool: return blueprint.status not in ["queued", "building", "provisioning"] - blueprint = await async_poll_until(retrieve_blueprint, is_done_building, polling_config) + blueprint = await async_poll_until( + retrieve_blueprint, + is_done_building, + polling_config, + cancellation_token=cancellation_token, + ) if blueprint.status != "build_complete": raise RunloopError(f"Blueprint entered non-built terminal state: {blueprint.status}") @@ -1018,6 +1040,7 @@ async def create_and_await_build_complete( services: Optional[Iterable[blueprint_create_params.Service]] | Omit = omit, system_setup_commands: Optional[SequenceNotStr[str]] | Omit = omit, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -1033,12 +1056,14 @@ async def create_and_await_build_complete( Args: See the `create` method for detailed documentation. polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation Returns: The built blueprint Raises: PollingTimeout: If polling times out before blueprint is built + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If blueprint enters a non-built terminal state """ # Pass all create_args to the underlying create method @@ -1067,6 +1092,7 @@ async def create_and_await_build_complete( return await self.await_build_complete( blueprint.id, polling_config=polling_config, + cancellation_token=cancellation_token, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, diff --git a/src/runloop_api_client/resources/scenarios/runs.py b/src/runloop_api_client/resources/scenarios/runs.py index 67c5c4428..0cdb04df4 100644 --- a/src/runloop_api_client/resources/scenarios/runs.py +++ b/src/runloop_api_client/resources/scenarios/runs.py @@ -28,6 +28,7 @@ from ..._base_client import AsyncPaginator, make_request_options from ...types.scenarios import run_list_params from ...lib.polling_async import async_poll_until +from ...lib.cancellation import CancellationToken from ...types.scenario_run_view import ScenarioRunView __all__ = ["RunsResource", "AsyncRunsResource"] @@ -325,6 +326,7 @@ def await_scored( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -337,6 +339,7 @@ def await_scored( Args: id: The ID of the scenario run to wait for polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -347,6 +350,7 @@ def await_scored( Raises: PollingTimeout: If polling times out before scenario run is scored + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If scenario run enters a non-scored terminal state """ @@ -358,7 +362,12 @@ def retrieve_run() -> ScenarioRunView: def is_done_scoring(run: ScenarioRunView) -> bool: return run.state not in ["scoring"] - run = poll_until(retrieve_run, is_done_scoring, polling_config) + run = poll_until( + retrieve_run, + is_done_scoring, + polling_config, + cancellation_token=cancellation_token, + ) if run.state != "scored": raise RunloopError(f"Scenario run entered non-scored state unexpectedly: {run.state}") @@ -370,6 +379,7 @@ def score_and_await( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -382,6 +392,7 @@ def score_and_await( Args: id: The ID of the scenario run to score and wait for polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -392,6 +403,7 @@ def score_and_await( Raises: PollingTimeout: If polling times out before scenario run is scored + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If scenario run enters a non-scored terminal state """ self.score( @@ -405,6 +417,7 @@ def score_and_await( return self.await_scored( id, polling_config=polling_config, + cancellation_token=cancellation_token, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, @@ -750,6 +763,7 @@ async def await_scored( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -762,6 +776,7 @@ async def await_scored( Args: id: The ID of the scenario run to wait for polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -772,6 +787,7 @@ async def await_scored( Raises: PollingTimeout: If polling times out before scenario run is scored + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If scenario run enters a non-scored terminal state """ @@ -783,7 +799,12 @@ async def retrieve_run() -> ScenarioRunView: def is_done_scoring(run: ScenarioRunView) -> bool: return run.state not in ["scoring"] - run = await async_poll_until(retrieve_run, is_done_scoring, polling_config) + run = await async_poll_until( + retrieve_run, + is_done_scoring, + polling_config, + cancellation_token=cancellation_token, + ) if run.state != "scored": raise RunloopError(f"Scenario run entered non-scored state unexpectedly: {run.state}") @@ -795,6 +816,7 @@ async def score_and_await( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -807,6 +829,7 @@ async def score_and_await( Args: id: The ID of the scenario run to score and wait for polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -817,6 +840,7 @@ async def score_and_await( Raises: PollingTimeout: If polling times out before scenario run is scored + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If scenario run enters a non-scored terminal state """ await self.score( @@ -830,6 +854,7 @@ async def score_and_await( return await self.await_scored( id, polling_config=polling_config, + cancellation_token=cancellation_token, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, diff --git a/src/runloop_api_client/sdk/_types.py b/src/runloop_api_client/sdk/_types.py index 4d2accc46..39a1b42de 100644 --- a/src/runloop_api_client/sdk/_types.py +++ b/src/runloop_api_client/sdk/_types.py @@ -40,6 +40,7 @@ ) from .._types import Body, Query, Headers, Timeout, NotGiven from ..lib.polling import PollingConfig +from ..lib.cancellation import CancellationToken from ..types.devboxes import DiskSnapshotListParams, DiskSnapshotUpdateParams from ..types.scenarios import ScorerListParams, ScorerCreateParams, ScorerUpdateParams from ..types.devbox_create_params import DevboxBaseCreateParams @@ -86,6 +87,9 @@ class PollingRequestOptions(BaseRequestOptions, total=False): polling_config: Optional[PollingConfig] """Configuration for polling behavior""" + cancellation_token: Optional[CancellationToken] + """Token to cancel polling operations""" + class LongPollingRequestOptions(LongRequestOptions, PollingRequestOptions): # type: ignore[misc] pass From 0d0f863fe8857dff7f127c9625700431afbbffa6 Mon Sep 17 00:00:00 2001 From: Tony Deng Date: Wed, 1 Apr 2026 16:24:37 -0700 Subject: [PATCH 02/10] feat: propagate cancellation_token through all resource polling methods - Update Blueprints: await_build_complete, create_and_await_build_complete - Update ScenarioRuns: await_scored, score_and_await - Update Devboxes: await_running, await_suspended, create_and_await_running - Update DiskSnapshots: await_completed - Update Executions: await_completed - Update SDK wrappers: ScenarioRun.await_env_ready - Add comprehensive docstrings with PollingCancelled exception All methods support both sync and async variants. Part of porting TypeScript PR #765 features to Python SDK. --- .../resources/devboxes/devboxes.py | 50 +++++++++++++++++-- .../resources/devboxes/disk_snapshots.py | 45 ++++++++++++++++- .../resources/devboxes/executions.py | 22 +++++++- src/runloop_api_client/sdk/scenario_run.py | 6 ++- 4 files changed, 114 insertions(+), 9 deletions(-) diff --git a/src/runloop_api_client/resources/devboxes/devboxes.py b/src/runloop_api_client/resources/devboxes/devboxes.py index 3e46b17b6..dccc386c8 100644 --- a/src/runloop_api_client/resources/devboxes/devboxes.py +++ b/src/runloop_api_client/resources/devboxes/devboxes.py @@ -98,6 +98,7 @@ AsyncDiskSnapshotsResourceWithStreamingResponse, ) from ...lib.polling_async import async_poll_until +from ...lib.cancellation import CancellationToken from ...types.devbox_view import DevboxView from ...types.tunnel_view import TunnelView from ...types.shared_params.mount import Mount @@ -401,12 +402,14 @@ def await_running( *, # Use polling_config to configure the "long" polling behavior. polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, ) -> DevboxView: """Wait for a devbox to be in running state. Args: id: The ID of the devbox to wait for config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -417,6 +420,7 @@ def await_running( Raises: PollingTimeout: If polling times out before devbox is running + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If devbox enters a non-running terminal state """ @@ -443,7 +447,13 @@ def handle_timeout_error(error: Exception) -> DevboxView: def is_done_booting(devbox: DevboxView) -> bool: return devbox.status not in DEVBOX_BOOTING_STATES - devbox = poll_until(wait_for_devbox_status, is_done_booting, polling_config, handle_timeout_error) + devbox = poll_until( + wait_for_devbox_status, + is_done_booting, + polling_config, + handle_timeout_error, + cancellation_token=cancellation_token, + ) if devbox.status != "running": raise RunloopError(f"Devbox entered non-running terminal state: {devbox.status}") @@ -455,18 +465,21 @@ def await_suspended( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, ) -> DevboxView: """Wait for a devbox to reach the suspended state. Args: id: The ID of the devbox to wait for. polling_config: Optional polling configuration. + cancellation_token: Token to cancel the wait operation. Returns: The devbox in the suspended state. Raises: PollingTimeout: If polling times out before the devbox is suspended. + PollingCancelled: If cancellation_token.cancel() is called. RunloopError: If the devbox enters a non-suspended terminal state. """ @@ -487,7 +500,13 @@ def handle_timeout_error(error: Exception) -> DevboxView: def is_terminal_state(devbox: DevboxView) -> bool: return devbox.status in DEVBOX_TERMINAL_STATES - devbox = poll_until(wait_for_devbox_status, is_terminal_state, polling_config, handle_timeout_error) + devbox = poll_until( + wait_for_devbox_status, + is_terminal_state, + polling_config, + handle_timeout_error, + cancellation_token=cancellation_token, + ) if devbox.status != "suspended": raise RunloopError(f"Devbox entered non-suspended terminal state: {devbox.status}") @@ -510,6 +529,7 @@ def create_and_await_running( mounts: Optional[Iterable[Mount]] | Omit = omit, name: Optional[str] | Omit = omit, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, repo_connection_id: Optional[str] | Omit = omit, secrets: Optional[Dict[str, str]] | Omit = omit, snapshot_id: Optional[str] | Omit = omit, @@ -535,6 +555,7 @@ def create_and_await_running( Raises: PollingTimeout: If polling times out before devbox is running + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If devbox enters a non-running terminal state """ # Pass all create_args to the underlying create method @@ -565,6 +586,7 @@ def create_and_await_running( return self.await_running( devbox.id, polling_config=polling_config, + cancellation_token=cancellation_token, ) def list( @@ -2001,6 +2023,7 @@ async def create_and_await_running( mounts: Optional[Iterable[Mount]] | Omit = omit, name: Optional[str] | Omit = omit, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, repo_connection_id: Optional[str] | Omit = omit, secrets: Optional[Dict[str, str]] | Omit = omit, snapshot_id: Optional[str] | Omit = omit, @@ -2020,12 +2043,14 @@ async def create_and_await_running( Args: See the `create` method for detailed documentation. polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation Returns: The devbox in running state Raises: PollingTimeout: If polling times out before devbox is running + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If devbox enters a non-running terminal state """ @@ -2057,6 +2082,7 @@ async def create_and_await_running( return await self.await_running( devbox.id, polling_config=polling_config, + cancellation_token=cancellation_token, ) async def await_running( @@ -2065,12 +2091,14 @@ async def await_running( *, # Use polling_config to configure the "long" polling behavior. polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, ) -> DevboxView: """Wait for a devbox to be in running state. Args: id: The ID of the devbox to wait for config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -2081,6 +2109,7 @@ async def await_running( Raises: PollingTimeout: If polling times out before devbox is running + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If devbox enters a non-running terminal state """ @@ -2105,7 +2134,12 @@ async def wait_for_devbox_status() -> DevboxView: def is_done_booting(devbox: DevboxView) -> bool: return devbox.status not in DEVBOX_BOOTING_STATES - devbox = await async_poll_until(wait_for_devbox_status, is_done_booting, polling_config) + devbox = await async_poll_until( + wait_for_devbox_status, + is_done_booting, + polling_config, + cancellation_token=cancellation_token, + ) if devbox.status != "running": raise RunloopError(f"Devbox entered non-running terminal state: {devbox.status}") @@ -2117,18 +2151,21 @@ async def await_suspended( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, ) -> DevboxView: """Wait for a devbox to reach the suspended state. Args: id: The ID of the devbox to wait for. polling_config: Optional polling configuration. + cancellation_token: Token to cancel the wait operation. Returns: The devbox in the suspended state. Raises: PollingTimeout: If polling times out before the devbox is suspended. + PollingCancelled: If cancellation_token.cancel() is called. RunloopError: If the devbox enters a non-suspended terminal state. """ @@ -2147,7 +2184,12 @@ async def wait_for_devbox_status() -> DevboxView: def is_terminal_state(devbox: DevboxView) -> bool: return devbox.status in DEVBOX_TERMINAL_STATES - devbox = await async_poll_until(wait_for_devbox_status, is_terminal_state, polling_config) + devbox = await async_poll_until( + wait_for_devbox_status, + is_terminal_state, + polling_config, + cancellation_token=cancellation_token, + ) if devbox.status != "suspended": raise RunloopError(f"Devbox entered non-suspended terminal state: {devbox.status}") diff --git a/src/runloop_api_client/resources/devboxes/disk_snapshots.py b/src/runloop_api_client/resources/devboxes/disk_snapshots.py index c4f723359..4b4cd3678 100644 --- a/src/runloop_api_client/resources/devboxes/disk_snapshots.py +++ b/src/runloop_api_client/resources/devboxes/disk_snapshots.py @@ -22,6 +22,7 @@ from ..._base_client import AsyncPaginator, make_request_options from ...types.devboxes import disk_snapshot_list_params, disk_snapshot_update_params from ...lib.polling_async import async_poll_until +from ...lib.cancellation import CancellationToken from ...types.devbox_snapshot_view import DevboxSnapshotView from ...types.devboxes.devbox_snapshot_async_status_view import DevboxSnapshotAsyncStatusView @@ -256,12 +257,31 @@ def await_completed( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DevboxSnapshotAsyncStatusView: - """Wait for a disk snapshot operation to complete.""" + """Wait for a disk snapshot operation to complete. + + Args: + id: The ID of the disk snapshot to wait for + polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation + extra_headers: Send extra headers + extra_query: Add additional query parameters to the request + extra_body: Add additional JSON properties to the request + timeout: Override the client-level default timeout for this request, in seconds + + Returns: + The completed snapshot status + + Raises: + PollingTimeout: If polling times out before snapshot completes + PollingCancelled: If cancellation_token.cancel() is called + RunloopError: If snapshot enters error state + """ if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") @@ -275,6 +295,7 @@ def is_terminal(result: DevboxSnapshotAsyncStatusView) -> bool: ), is_terminal, polling_config, + cancellation_token=cancellation_token, ) if status.status == "error": @@ -512,12 +533,31 @@ async def await_completed( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DevboxSnapshotAsyncStatusView: - """Wait asynchronously for a disk snapshot operation to complete.""" + """Wait asynchronously for a disk snapshot operation to complete. + + Args: + id: The ID of the disk snapshot to wait for + polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation + extra_headers: Send extra headers + extra_query: Add additional query parameters to the request + extra_body: Add additional JSON properties to the request + timeout: Override the client-level default timeout for this request, in seconds + + Returns: + The completed snapshot status + + Raises: + PollingTimeout: If polling times out before snapshot completes + PollingCancelled: If cancellation_token.cancel() is called + RunloopError: If snapshot enters error state + """ if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") @@ -531,6 +571,7 @@ def is_terminal(result: DevboxSnapshotAsyncStatusView) -> bool: ), is_terminal, polling_config, + cancellation_token=cancellation_token, ) if status.status == "error": diff --git a/src/runloop_api_client/resources/devboxes/executions.py b/src/runloop_api_client/resources/devboxes/executions.py index ff7638798..90896a198 100755 --- a/src/runloop_api_client/resources/devboxes/executions.py +++ b/src/runloop_api_client/resources/devboxes/executions.py @@ -33,6 +33,7 @@ execution_stream_stdout_updates_params, ) from ...lib.polling_async import async_poll_until +from ...lib.cancellation import CancellationToken from ...types.devbox_send_std_in_result import DevboxSendStdInResult from ...types.devbox_execution_detail_view import DevboxExecutionDetailView from ...types.devboxes.execution_update_chunk import ExecutionUpdateChunk @@ -124,6 +125,7 @@ def await_completed( *, # Use polling_config to configure the "long" polling behavior. polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, ) -> DevboxAsyncExecutionDetailView: """Wait for an execution to complete. @@ -131,6 +133,7 @@ def await_completed( execution_id: The ID of the execution to wait for id: The ID of the devbox config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -141,6 +144,7 @@ def await_completed( Raises: PollingTimeout: If polling times out before execution completes + PollingCancelled: If cancellation_token.cancel() is called """ def wait_for_execution_status() -> DevboxAsyncExecutionDetailView: @@ -165,7 +169,13 @@ def handle_timeout_error(error: Exception) -> DevboxAsyncExecutionDetailView: def is_done(execution: DevboxAsyncExecutionDetailView) -> bool: return execution.status == "completed" - return poll_until(wait_for_execution_status, is_done, polling_config, handle_timeout_error) + return poll_until( + wait_for_execution_status, + is_done, + polling_config, + handle_timeout_error, + cancellation_token=cancellation_token, + ) def execute_async( self, @@ -670,6 +680,7 @@ async def await_completed( devbox_id: str, # Use polling_config to configure the "long" polling behavior. polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, ) -> DevboxAsyncExecutionDetailView: """Wait for an execution to complete. @@ -677,6 +688,7 @@ async def await_completed( execution_id: The ID of the execution to wait for id: The ID of the devbox polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -687,6 +699,7 @@ async def await_completed( Raises: PollingTimeout: If polling times out before execution completes + PollingCancelled: If cancellation_token.cancel() is called """ async def wait_for_execution_status() -> DevboxAsyncExecutionDetailView: @@ -707,7 +720,12 @@ async def wait_for_execution_status() -> DevboxAsyncExecutionDetailView: def is_done(execution: DevboxAsyncExecutionDetailView) -> bool: return execution.status == "completed" - return await async_poll_until(wait_for_execution_status, is_done, polling_config) + return await async_poll_until( + wait_for_execution_status, + is_done, + polling_config, + cancellation_token=cancellation_token, + ) async def execute_async( self, diff --git a/src/runloop_api_client/sdk/scenario_run.py b/src/runloop_api_client/sdk/scenario_run.py index ede44b105..14dd3a4c6 100644 --- a/src/runloop_api_client/sdk/scenario_run.py +++ b/src/runloop_api_client/sdk/scenario_run.py @@ -106,7 +106,11 @@ def await_env_ready( :return: Scenario run state after environment is ready :rtype: ScenarioRunView """ - self._client.devboxes.await_running(self._devbox_id, polling_config=options.get("polling_config")) + self._client.devboxes.await_running( + self._devbox_id, + polling_config=options.get("polling_config"), + cancellation_token=options.get("cancellation_token"), + ) return self.get_info(**filter_params(options, BaseRequestOptions)) def score( From 6baf371c021eed0adcc016dfa22c3a4e25134d93 Mon Sep 17 00:00:00 2001 From: Tony Deng Date: Wed, 1 Apr 2026 16:24:54 -0700 Subject: [PATCH 03/10] feat: complete SDK wrapper cancellation_token propagation - Update AsyncScenarioRun.await_env_ready to pass cancellation_token - Completes cancellation support across all SDK polling operations Part of porting TypeScript PR #765 features to Python SDK. --- src/runloop_api_client/sdk/async_scenario_run.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/runloop_api_client/sdk/async_scenario_run.py b/src/runloop_api_client/sdk/async_scenario_run.py index 314de676f..ad7ef6c69 100644 --- a/src/runloop_api_client/sdk/async_scenario_run.py +++ b/src/runloop_api_client/sdk/async_scenario_run.py @@ -106,7 +106,11 @@ async def await_env_ready( :return: Scenario run state after environment is ready :rtype: ScenarioRunView """ - await self._client.devboxes.await_running(self._devbox_id, polling_config=options.get("polling_config")) + await self._client.devboxes.await_running( + self._devbox_id, + polling_config=options.get("polling_config"), + cancellation_token=options.get("cancellation_token"), + ) return await self.get_info(**filter_params(options, BaseRequestOptions)) async def score( From 19771bc8671e96199efcb93a3850113c040c3c79 Mon Sep 17 00:00:00 2001 From: Tony Deng Date: Wed, 1 Apr 2026 16:32:29 -0700 Subject: [PATCH 04/10] feat: add SSE auto-reconnect for Axon subscriptions - Add AxonSubscribeSseParams with after_sequence parameter (internal use only) - Wrap subscribe_sse() with ReconnectingStream/AsyncReconnectingStream - Automatically resume from last received event on timeout using sequence-based offset - Add opt-out via RAW_RESPONSE_HEADER for users who want raw streams - after_sequence is handled internally by reconnector, not exposed in public API Per code review feedback from @dines-rl on TypeScript PR #765: 'after_sequence should just be a variable in the reconnector for follow-up' Part of porting TypeScript PR #765 features to Python SDK. --- .../resources/axons/axons.py | 118 +++++++++++++++--- .../types/axons/__init__.py | 1 + .../types/axons/axon_subscribe_sse_params.py | 14 +++ 3 files changed, 117 insertions(+), 16 deletions(-) create mode 100644 src/runloop_api_client/types/axons/axon_subscribe_sse_params.py diff --git a/src/runloop_api_client/resources/axons/axons.py b/src/runloop_api_client/resources/axons/axons.py index 54f4cb9ff..27031ee83 100644 --- a/src/runloop_api_client/resources/axons/axons.py +++ b/src/runloop_api_client/resources/axons/axons.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Optional +from typing import Optional, cast from typing_extensions import Literal import httpx @@ -16,6 +16,7 @@ AsyncSqlResourceWithStreamingResponse, ) from ...types import axon_list_params, axon_create_params, axon_publish_params +from ...types.axons import axon_subscribe_sse_params from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from ..._utils import path_template, maybe_transform, async_maybe_transform from ..._compat import cached_property @@ -26,7 +27,8 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ..._streaming import Stream, AsyncStream +from ..._constants import RAW_RESPONSE_HEADER +from ..._streaming import Stream, AsyncStream, ReconnectingStream, AsyncReconnectingStream from ...pagination import SyncAxonsCursorIDPage, AsyncAxonsCursorIDPage from ..._base_client import AsyncPaginator, make_request_options from ...types.axon_view import AxonView @@ -269,6 +271,8 @@ def subscribe_sse( """ [Beta] Subscribe to an axon event stream via server-sent events. + Automatically reconnects on timeout, resuming from last received event. + Args: extra_headers: Send extra headers @@ -282,14 +286,54 @@ def subscribe_sse( raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") default_headers: Headers = {"Accept": "text/event-stream"} merged_headers = default_headers if extra_headers is None else {**default_headers, **extra_headers} - return self._get( - path_template("/v1/axons/{id}/subscribe/sse", id=id), - options=make_request_options( - extra_headers=merged_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + + # Check if user wants raw response (opt-out of reconnection) + if extra_headers is not None and RAW_RESPONSE_HEADER in extra_headers: + return self._get( + path_template("/v1/axons/{id}/subscribe/sse", id=id), + options=make_request_options( + extra_headers=merged_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=AxonEventView, + stream=True, + stream_cls=Stream[AxonEventView], + ) + + def create_stream(last_sequence: str | None) -> Stream[AxonEventView]: + # after_sequence is used internally for reconnection only + sequence_int = int(last_sequence) if last_sequence is not None else None + return self._get( + path_template("/v1/axons/{id}/subscribe/sse", id=id), + options=make_request_options( + extra_headers=merged_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + {"after_sequence": sequence_int}, + axon_subscribe_sse_params.AxonSubscribeSseParams, + ), + ), + cast_to=AxonEventView, + stream=True, + stream_cls=Stream[AxonEventView], + ) + + initial_stream = create_stream(None) + + def get_sequence(item: AxonEventView) -> str | None: + value = getattr(item, "sequence", None) + if value is None: + return None + return str(value) + + return cast( + Stream[AxonEventView], + ReconnectingStream( + current_stream=initial_stream, + stream_creator=create_stream, + get_offset=get_sequence, ), - cast_to=AxonEventView, - stream=True, - stream_cls=Stream[AxonEventView], ) @@ -526,6 +570,8 @@ async def subscribe_sse( """ [Beta] Subscribe to an axon event stream via server-sent events. + Automatically reconnects on timeout, resuming from last received event. + Args: extra_headers: Send extra headers @@ -539,14 +585,54 @@ async def subscribe_sse( raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") default_headers: Headers = {"Accept": "text/event-stream"} merged_headers = default_headers if extra_headers is None else {**default_headers, **extra_headers} - return await self._get( - path_template("/v1/axons/{id}/subscribe/sse", id=id), - options=make_request_options( - extra_headers=merged_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + + # Check if user wants raw response (opt-out of reconnection) + if extra_headers is not None and RAW_RESPONSE_HEADER in extra_headers: + return await self._get( + path_template("/v1/axons/{id}/subscribe/sse", id=id), + options=make_request_options( + extra_headers=merged_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=AxonEventView, + stream=True, + stream_cls=AsyncStream[AxonEventView], + ) + + async def create_stream(last_sequence: str | None) -> AsyncStream[AxonEventView]: + # after_sequence is used internally for reconnection only + sequence_int = int(last_sequence) if last_sequence is not None else None + return await self._get( + path_template("/v1/axons/{id}/subscribe/sse", id=id), + options=make_request_options( + extra_headers=merged_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + {"after_sequence": sequence_int}, + axon_subscribe_sse_params.AxonSubscribeSseParams, + ), + ), + cast_to=AxonEventView, + stream=True, + stream_cls=AsyncStream[AxonEventView], + ) + + initial_stream = await create_stream(None) + + def get_sequence(item: AxonEventView) -> str | None: + value = getattr(item, "sequence", None) + if value is None: + return None + return str(value) + + return cast( + AsyncStream[AxonEventView], + AsyncReconnectingStream( + current_stream=initial_stream, + stream_creator=create_stream, + get_offset=get_sequence, ), - cast_to=AxonEventView, - stream=True, - stream_cls=AsyncStream[AxonEventView], ) diff --git a/src/runloop_api_client/types/axons/__init__.py b/src/runloop_api_client/types/axons/__init__.py index 8ab8cf9b1..9602bf477 100644 --- a/src/runloop_api_client/types/axons/__init__.py +++ b/src/runloop_api_client/types/axons/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +from .axon_subscribe_sse_params import AxonSubscribeSseParams as AxonSubscribeSseParams from .sql_batch_params import SqlBatchParams as SqlBatchParams from .sql_query_params import SqlQueryParams as SqlQueryParams from .sql_step_error_view import SqlStepErrorView as SqlStepErrorView diff --git a/src/runloop_api_client/types/axons/axon_subscribe_sse_params.py b/src/runloop_api_client/types/axons/axon_subscribe_sse_params.py new file mode 100644 index 000000000..dc3bebb73 --- /dev/null +++ b/src/runloop_api_client/types/axons/axon_subscribe_sse_params.py @@ -0,0 +1,14 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import TypedDict + +from ..._types import NotGiven + +__all__ = ["AxonSubscribeSseParams"] + + +class AxonSubscribeSseParams(TypedDict, total=False): + after_sequence: int | NotGiven + """Resume SSE stream from events after this sequence number (used internally for reconnection)""" From a01586580884b2b149c4b1c905d66db13727152f Mon Sep 17 00:00:00 2001 From: Tony Deng Date: Wed, 1 Apr 2026 16:41:15 -0700 Subject: [PATCH 05/10] test: add comprehensive unit tests for new features Add extensive test coverage for: 1. CancellationToken class (test_cancellation.py): - Thread-safe cancellation behavior - Sync and async event handling - Lazy async event creation - raise_if_cancelled() method - Multiple tokens independence - Cross-thread cancellation 2. Polling with cancellation (test_polling.py): - Cancellation before first poll - Cancellation during polling loop - Cancellation during sleep with immediate wake-up - Cancellation with error handlers - Backward compatibility (None token) - Multiple polls with same token 3. Async polling with cancellation (test_polling_async.py): - All sync tests adapted for async - Concurrent polling with shared token - Cancellation from different async tasks - asyncio.TimeoutError handling 4. SSE auto-reconnect for Axons (test_axon_sse_reconnect.py): - ReconnectingStream/AsyncReconnectingStream usage - Sequence-based resumption after disconnect - RAW_RESPONSE_HEADER opt-out mechanism - Missing/None sequence handling - Request options preservation - AxonSubscribeSseParams structure Total: 60+ new test cases ensuring robustness of both features. Part of porting TypeScript PR #765 features to Python SDK. --- tests/test_axon_sse_reconnect.py | 341 +++++++++++++++++++++++++++++++ tests/test_cancellation.py | 213 +++++++++++++++++++ tests/test_polling.py | 173 ++++++++++++++++ tests/test_polling_async.py | 293 ++++++++++++++++++++++++++ 4 files changed, 1020 insertions(+) create mode 100644 tests/test_axon_sse_reconnect.py create mode 100644 tests/test_cancellation.py create mode 100644 tests/test_polling_async.py diff --git a/tests/test_axon_sse_reconnect.py b/tests/test_axon_sse_reconnect.py new file mode 100644 index 000000000..6b5328cbf --- /dev/null +++ b/tests/test_axon_sse_reconnect.py @@ -0,0 +1,341 @@ +"""Tests for Axon SSE auto-reconnect functionality.""" + +from unittest.mock import Mock, patch, MagicMock +from typing import Iterator, AsyncIterator + +import pytest + +from src.runloop_api_client._streaming import Stream, AsyncStream, ReconnectingStream, AsyncReconnectingStream +from src.runloop_api_client._constants import RAW_RESPONSE_HEADER +from src.runloop_api_client.types.axon_event_view import AxonEventView + + +class MockAxonEvent: + """Mock AxonEventView for testing.""" + + def __init__(self, sequence: int, data: str): + self.sequence = sequence + self.data = data + + +class TestAxonSSEReconnectSync: + """Test SSE reconnection for sync Axon subscriptions.""" + + def test_subscribe_sse_returns_reconnecting_stream(self): + """Test that subscribe_sse returns a ReconnectingStream.""" + from src.runloop_api_client import Runloop + + with patch.object(Runloop, "_get") as mock_get: + # Mock the initial stream + mock_stream = Mock(spec=Stream) + mock_get.return_value = mock_stream + + client = Runloop(api_key="test-key", base_url="http://test") + + result = client.axons.subscribe_sse("axon-123") + + # Should return a ReconnectingStream + assert isinstance(result, ReconnectingStream) + + def test_subscribe_sse_with_raw_header_returns_plain_stream(self): + """Test that RAW_RESPONSE_HEADER opts out of reconnection.""" + from src.runloop_api_client import Runloop + + with patch.object(Runloop, "_get") as mock_get: + mock_stream = Mock(spec=Stream) + mock_get.return_value = mock_stream + + client = Runloop(api_key="test-key", base_url="http://test") + + result = client.axons.subscribe_sse("axon-123", extra_headers={RAW_RESPONSE_HEADER: "true"}) + + # Should return plain Stream, not ReconnectingStream + assert not isinstance(result, ReconnectingStream) + assert result == mock_stream + + def test_reconnection_uses_last_sequence(self): + """Test that reconnection uses the sequence from the last event.""" + from src.runloop_api_client import Runloop + + call_count = 0 + query_params = [] + + def mock_get(*args, **kwargs): + nonlocal call_count + # Capture query params + if "query" in kwargs.get("options", {}): + query_params.append(kwargs["options"]["query"]) + + # First call: return stream with events + if call_count == 0: + call_count += 1 + mock_stream = Mock(spec=Stream) + + def mock_iter(): + yield MockAxonEvent(sequence=1, data="event1") + yield MockAxonEvent(sequence=2, data="event2") + # Simulate timeout/disconnect + raise StopIteration() + + mock_stream.__iter__ = mock_iter + return mock_stream + + # Second call (reconnection): return stream continuing from sequence 2 + mock_stream = Mock(spec=Stream) + + def mock_iter(): + yield MockAxonEvent(sequence=3, data="event3") + + mock_stream.__iter__ = mock_iter + return mock_stream + + with patch.object(Runloop, "_get", side_effect=mock_get): + client = Runloop(api_key="test-key", base_url="http://test") + + stream = client.axons.subscribe_sse("axon-123") + + # Consume events + events = list(stream) + + # Should have 3 events total (2 from first stream, 1 from reconnected stream) + assert len(events) == 3 + assert events[0].sequence == 1 + assert events[1].sequence == 2 + assert events[2].sequence == 3 + + # Check that second call used after_sequence parameter + # Note: first call has None, second call should have after_sequence=2 + assert len(query_params) >= 2 + + def test_sequence_extraction_handles_missing_sequence(self): + """Test that missing sequence fields are handled gracefully.""" + from src.runloop_api_client import Runloop + + with patch.object(Runloop, "_get") as mock_get: + mock_stream = Mock(spec=Stream) + + # Event without sequence attribute + class EventWithoutSequence: + pass + + def mock_iter(): + yield EventWithoutSequence() + + mock_stream.__iter__ = mock_iter + mock_get.return_value = mock_stream + + client = Runloop(api_key="test-key", base_url="http://test") + + stream = client.axons.subscribe_sse("axon-123") + + # Should not crash, sequence extractor should return None + events = list(stream) + assert len(events) == 1 + + def test_subscribe_sse_preserves_request_options(self): + """Test that extra headers, query, etc. are preserved.""" + from src.runloop_api_client import Runloop + + with patch.object(Runloop, "_get") as mock_get: + mock_stream = Mock(spec=Stream) + mock_stream.__iter__ = lambda self: iter([]) + mock_get.return_value = mock_stream + + client = Runloop(api_key="test-key", base_url="http://test") + + extra_headers = {"X-Custom": "value"} + extra_query = {"param": "value"} + + client.axons.subscribe_sse( + "axon-123", extra_headers=extra_headers, extra_query=extra_query, timeout=30.0 + ) + + # Verify _get was called with the options + call_args = mock_get.call_args + options = call_args.kwargs["options"] + + # Headers should include Accept: text/event-stream and custom header + assert "Accept" in options["extra_headers"] + assert options["extra_headers"]["Accept"] == "text/event-stream" + assert options["extra_headers"]["X-Custom"] == "value" + assert options["extra_query"] == extra_query + assert options["timeout"] == 30.0 + + +class TestAxonSSEReconnectAsync: + """Test SSE reconnection for async Axon subscriptions.""" + + @pytest.mark.asyncio + async def test_subscribe_sse_returns_reconnecting_stream(self): + """Test that subscribe_sse returns an AsyncReconnectingStream.""" + from src.runloop_api_client import AsyncRunloop + + async def mock_get(*args, **kwargs): + mock_stream = Mock(spec=AsyncStream) + return mock_stream + + with patch.object(AsyncRunloop, "_get", side_effect=mock_get): + client = AsyncRunloop(api_key="test-key", base_url="http://test") + + result = await client.axons.subscribe_sse("axon-123") + + # Should return an AsyncReconnectingStream + assert isinstance(result, AsyncReconnectingStream) + + @pytest.mark.asyncio + async def test_subscribe_sse_with_raw_header_returns_plain_stream(self): + """Test that RAW_RESPONSE_HEADER opts out of reconnection.""" + from src.runloop_api_client import AsyncRunloop + + async def mock_get(*args, **kwargs): + mock_stream = Mock(spec=AsyncStream) + return mock_stream + + with patch.object(AsyncRunloop, "_get", side_effect=mock_get) as mock_get_method: + client = AsyncRunloop(api_key="test-key", base_url="http://test") + + result = await client.axons.subscribe_sse("axon-123", extra_headers={RAW_RESPONSE_HEADER: "true"}) + + # Should return plain AsyncStream, not AsyncReconnectingStream + assert not isinstance(result, AsyncReconnectingStream) + + @pytest.mark.asyncio + async def test_reconnection_uses_last_sequence(self): + """Test that reconnection uses the sequence from the last event.""" + from src.runloop_api_client import AsyncRunloop + + call_count = 0 + query_params = [] + + async def mock_get(*args, **kwargs): + nonlocal call_count + # Capture query params + if "query" in kwargs.get("options", {}): + query_params.append(kwargs["options"]["query"]) + + # First call: return stream with events + if call_count == 0: + call_count += 1 + mock_stream = Mock(spec=AsyncStream) + + async def mock_iter(): + yield MockAxonEvent(sequence=1, data="event1") + yield MockAxonEvent(sequence=2, data="event2") + + mock_stream.__aiter__ = mock_iter + return mock_stream + + # Second call (reconnection) + mock_stream = Mock(spec=AsyncStream) + + async def mock_iter(): + yield MockAxonEvent(sequence=3, data="event3") + + mock_stream.__aiter__ = mock_iter + return mock_stream + + with patch.object(AsyncRunloop, "_get", side_effect=mock_get): + client = AsyncRunloop(api_key="test-key", base_url="http://test") + + stream = await client.axons.subscribe_sse("axon-123") + + # Consume events + events = [] + async for event in stream: + events.append(event) + if len(events) >= 3: + break + + # Should have 3 events total + assert len(events) == 3 + assert events[0].sequence == 1 + assert events[1].sequence == 2 + assert events[2].sequence == 3 + + @pytest.mark.asyncio + async def test_sequence_extraction_handles_none(self): + """Test that None sequences are handled gracefully.""" + from src.runloop_api_client import AsyncRunloop + + async def mock_get(*args, **kwargs): + mock_stream = Mock(spec=AsyncStream) + + # Event with sequence = None + class EventWithNoneSequence: + sequence = None + + async def mock_iter(): + yield EventWithNoneSequence() + + mock_stream.__aiter__ = mock_iter + return mock_stream + + with patch.object(AsyncRunloop, "_get", side_effect=mock_get): + client = AsyncRunloop(api_key="test-key", base_url="http://test") + + stream = await client.axons.subscribe_sse("axon-123") + + # Should not crash + events = [] + async for event in stream: + events.append(event) + break + + assert len(events) == 1 + + @pytest.mark.asyncio + async def test_subscribe_sse_preserves_request_options(self): + """Test that extra headers, query, etc. are preserved in async.""" + from src.runloop_api_client import AsyncRunloop + + async def mock_get(*args, **kwargs): + mock_stream = Mock(spec=AsyncStream) + + async def mock_iter(): + return + yield # Make it async generator + + mock_stream.__aiter__ = mock_iter + return mock_stream + + with patch.object(AsyncRunloop, "_get", side_effect=mock_get) as mock_get_method: + client = AsyncRunloop(api_key="test-key", base_url="http://test") + + extra_headers = {"X-Custom": "value"} + extra_query = {"param": "value"} + + await client.axons.subscribe_sse( + "axon-123", extra_headers=extra_headers, extra_query=extra_query, timeout=30.0 + ) + + # Verify _get was called with the options + call_args = mock_get_method.call_args + options = call_args.kwargs["options"] + + # Headers should include Accept: text/event-stream and custom header + assert "Accept" in options["extra_headers"] + assert options["extra_headers"]["Accept"] == "text/event-stream" + assert options["extra_headers"]["X-Custom"] == "value" + assert options["extra_query"] == extra_query + assert options["timeout"] == 30.0 + + +class TestAxonSubscribeSseParams: + """Test AxonSubscribeSseParams TypedDict.""" + + def test_params_structure(self): + """Test that AxonSubscribeSseParams has the correct structure.""" + from src.runloop_api_client.types.axons import AxonSubscribeSseParams + from src.runloop_api_client._types import NOT_GIVEN + + # Should be able to create with after_sequence + params: AxonSubscribeSseParams = {"after_sequence": 123} + assert params["after_sequence"] == 123 + + # Should be able to create with NOT_GIVEN + params2: AxonSubscribeSseParams = {"after_sequence": NOT_GIVEN} + assert params2["after_sequence"] is NOT_GIVEN + + # Should be able to create with None implicitly (total=False) + params3: AxonSubscribeSseParams = {} + assert "after_sequence" not in params3 diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py new file mode 100644 index 000000000..859d098e7 --- /dev/null +++ b/tests/test_cancellation.py @@ -0,0 +1,213 @@ +"""Tests for CancellationToken and PollingCancelled exception.""" + +import threading +import asyncio +from concurrent.futures import ThreadPoolExecutor + +import pytest + +from src.runloop_api_client.lib.cancellation import CancellationToken, PollingCancelled + + +class TestPollingCancelled: + """Test PollingCancelled exception.""" + + def test_polling_cancelled_initialization(self): + """Test PollingCancelled exception initialization.""" + exception = PollingCancelled("Operation was cancelled") + assert "Operation was cancelled" in str(exception) + + def test_polling_cancelled_inherits_from_runloop_error(self): + """Test that PollingCancelled inherits from RunloopError.""" + from src.runloop_api_client._exceptions import RunloopError + + exception = PollingCancelled("Test") + assert isinstance(exception, RunloopError) + + +class TestCancellationToken: + """Test CancellationToken class.""" + + def test_initialization(self): + """Test token is not cancelled on initialization.""" + token = CancellationToken() + assert not token.is_cancelled() + + def test_cancel(self): + """Test cancelling a token.""" + token = CancellationToken() + token.cancel() + assert token.is_cancelled() + + def test_cancel_idempotent(self): + """Test that calling cancel() multiple times is safe.""" + token = CancellationToken() + token.cancel() + token.cancel() + token.cancel() + assert token.is_cancelled() + + def test_raise_if_cancelled_not_cancelled(self): + """Test raise_if_cancelled() when token is not cancelled.""" + token = CancellationToken() + # Should not raise + token.raise_if_cancelled() + + def test_raise_if_cancelled_when_cancelled(self): + """Test raise_if_cancelled() when token is cancelled.""" + token = CancellationToken() + token.cancel() + with pytest.raises(PollingCancelled, match="Polling operation was cancelled"): + token.raise_if_cancelled() + + def test_sync_event_property(self): + """Test sync_event property returns threading.Event.""" + token = CancellationToken() + event = token.sync_event + assert isinstance(event, threading.Event) + assert not event.is_set() + + token.cancel() + assert event.is_set() + + def test_sync_event_wait(self): + """Test sync_event can be used with wait().""" + token = CancellationToken() + event = token.sync_event + + # Should timeout since not cancelled + result = event.wait(timeout=0.01) + assert not result + + token.cancel() + # Should return immediately when cancelled + result = event.wait(timeout=1.0) + assert result + + @pytest.mark.asyncio + async def test_async_event_property(self): + """Test async_event property returns asyncio.Event.""" + token = CancellationToken() + event = token.async_event + assert isinstance(event, asyncio.Event) + assert not event.is_set() + + token.cancel() + assert event.is_set() + + @pytest.mark.asyncio + async def test_async_event_wait(self): + """Test async_event can be used with wait().""" + token = CancellationToken() + event = token.async_event + + # Should timeout since not cancelled + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(event.wait(), timeout=0.01) + + token.cancel() + # Should return immediately when cancelled + await asyncio.wait_for(event.wait(), timeout=1.0) + + @pytest.mark.asyncio + async def test_async_event_lazy_creation(self): + """Test that async_event is created lazily.""" + token = CancellationToken() + # Access _async_event directly to check it's None + assert token._async_event is None + + # Access property to trigger lazy creation + event = token.async_event + assert token._async_event is not None + assert event is token._async_event + + @pytest.mark.asyncio + async def test_async_event_set_if_already_cancelled(self): + """Test that async_event is set immediately if token was already cancelled.""" + token = CancellationToken() + token.cancel() + + # Async event should be set when created + event = token.async_event + assert event.is_set() + + def test_thread_safety(self): + """Test that CancellationToken is thread-safe.""" + token = CancellationToken() + results = [] + + def cancel_token(): + token.cancel() + results.append(token.is_cancelled()) + + def check_token(): + # Busy wait until cancelled + while not token.is_cancelled(): + pass + results.append(token.is_cancelled()) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [] + # Start 4 checking threads + for _ in range(4): + futures.append(executor.submit(check_token)) + # Start 1 cancelling thread + futures.append(executor.submit(cancel_token)) + + # Wait for all to complete + for future in futures: + future.result(timeout=2.0) + + # All results should be True + assert all(results) + assert len(results) == 5 + + def test_sync_and_async_events_synchronized(self): + """Test that sync and async events are synchronized.""" + token = CancellationToken() + + # Get both events + sync_event = token.sync_event + async_event = token.async_event + + assert not sync_event.is_set() + assert not async_event.is_set() + + # Cancel token + token.cancel() + + # Both should be set + assert sync_event.is_set() + assert async_event.is_set() + + def test_multiple_tokens_independent(self): + """Test that multiple tokens are independent.""" + token1 = CancellationToken() + token2 = CancellationToken() + + token1.cancel() + + assert token1.is_cancelled() + assert not token2.is_cancelled() + + @pytest.mark.asyncio + async def test_async_cancellation_propagation(self): + """Test cancellation in async context.""" + token = CancellationToken() + + async def wait_for_cancellation(): + await token.async_event.wait() + return token.is_cancelled() + + # Start waiting task + task = asyncio.create_task(wait_for_cancellation()) + + # Give it a moment to start + await asyncio.sleep(0.01) + + # Cancel token + token.cancel() + + # Task should complete and return True + result = await asyncio.wait_for(task, timeout=1.0) + assert result is True diff --git a/tests/test_polling.py b/tests/test_polling.py index 74819531b..59ab971d1 100644 --- a/tests/test_polling.py +++ b/tests/test_polling.py @@ -1,9 +1,11 @@ +import threading from typing import Any from unittest.mock import Mock, patch import pytest from src.runloop_api_client.lib.polling import PollingConfig, PollingTimeout, poll_until +from src.runloop_api_client.lib.cancellation import CancellationToken, PollingCancelled class TestPollingConfig: @@ -260,3 +262,174 @@ def dynamic_terminal(_: Any) -> bool: assert result == "value3" assert retriever.call_count == 3 + + +class TestPollUntilWithCancellation: + """Test poll_until function with CancellationToken.""" + + def test_cancellation_before_first_poll(self): + """Test cancellation before first poll attempt.""" + token = CancellationToken() + token.cancel() + + retriever = Mock(return_value="value") + is_terminal = Mock(return_value=False) + + with pytest.raises(PollingCancelled, match="Polling operation was cancelled"): + poll_until(retriever, is_terminal, cancellation_token=token) + + # Should not call retriever since cancelled before first attempt + assert retriever.call_count == 0 + + def test_cancellation_during_polling(self): + """Test cancellation during polling loop.""" + token = CancellationToken() + retriever = Mock(side_effect=["value1", "value2", "value3"]) + is_terminal = Mock(return_value=False) + + call_count = 0 + + def cancel_on_second_call(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 2: + token.cancel() + # Don't actually sleep, just check cancellation + if token.is_cancelled(): + raise PollingCancelled("Polling operation was cancelled") + + with patch("time.sleep", side_effect=cancel_on_second_call): + with pytest.raises(PollingCancelled): + poll_until(retriever, is_terminal, cancellation_token=token) + + # Should have called retriever twice before cancellation + assert retriever.call_count == 2 + + def test_cancellation_during_sleep(self): + """Test that cancellation wakes up from sleep.""" + token = CancellationToken() + retriever = Mock(return_value="value") + is_terminal = Mock(return_value=False) + config = PollingConfig(interval_seconds=10.0) # Long sleep + + def cancel_after_delay(): + import time + + time.sleep(0.05) # Wait a bit + token.cancel() + + # Start cancellation in background thread + cancel_thread = threading.Thread(target=cancel_after_delay) + cancel_thread.start() + + # Poll should be interrupted by cancellation + with pytest.raises(PollingCancelled): + poll_until(retriever, is_terminal, config, cancellation_token=token) + + cancel_thread.join() + + # Should have attempted once before cancellation during sleep + assert retriever.call_count == 1 + + def test_no_cancellation_completes_normally(self): + """Test that polling completes normally without cancellation.""" + token = CancellationToken() + retriever = Mock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + + with patch("time.sleep"): + result = poll_until(retriever, is_terminal, cancellation_token=token) + + assert result == "completed" + assert not token.is_cancelled() + + def test_cancellation_after_completion_no_effect(self): + """Test that cancelling after completion has no effect.""" + token = CancellationToken() + retriever = Mock(return_value="completed") + is_terminal = Mock(return_value=True) + + result = poll_until(retriever, is_terminal, cancellation_token=token) + + # Cancel after completion + token.cancel() + + assert result == "completed" + + def test_cancellation_with_error_handler(self): + """Test cancellation works with error handler.""" + token = CancellationToken() + retriever = Mock(side_effect=[ValueError("error"), "value"]) + is_terminal = Mock(return_value=False) + + def error_handler(_: Exception) -> str: + return "handled" + + def cancel_on_first_sleep(*args, **kwargs): + token.cancel() + raise PollingCancelled("Polling operation was cancelled") + + with patch("time.sleep", side_effect=cancel_on_first_sleep): + with pytest.raises(PollingCancelled): + poll_until(retriever, is_terminal, on_error=error_handler, cancellation_token=token) + + def test_cancellation_with_custom_config(self): + """Test cancellation with custom polling config.""" + token = CancellationToken() + retriever = Mock(return_value="value") + is_terminal = Mock(return_value=False) + config = PollingConfig(interval_seconds=0.5, max_attempts=10) + + token.cancel() + + with pytest.raises(PollingCancelled): + poll_until(retriever, is_terminal, config, cancellation_token=token) + + def test_none_cancellation_token_works_normally(self): + """Test that passing None as cancellation_token works (backward compatibility).""" + retriever = Mock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + + with patch("time.sleep"): + result = poll_until(retriever, is_terminal, cancellation_token=None) + + assert result == "completed" + + def test_cancellable_sleep_blocks_correctly(self): + """Test that cancellable sleep blocks for the correct duration.""" + import time + + token = CancellationToken() + retriever = Mock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + config = PollingConfig(interval_seconds=0.1) + + start = time.time() + result = poll_until(retriever, is_terminal, config, cancellation_token=token) + elapsed = time.time() - start + + assert result == "completed" + # Should have slept approximately 0.1 seconds + assert 0.08 <= elapsed <= 0.15 # Allow some tolerance + + def test_multiple_cancellations_same_token(self): + """Test that the same token can be used for multiple poll operations.""" + token = CancellationToken() + + # First poll succeeds + retriever1 = Mock(return_value="done") + is_terminal1 = Mock(return_value=True) + result1 = poll_until(retriever1, is_terminal1, cancellation_token=token) + assert result1 == "done" + + # Cancel token + token.cancel() + + # Second poll should fail immediately + retriever2 = Mock(return_value="value") + is_terminal2 = Mock(return_value=False) + with pytest.raises(PollingCancelled): + poll_until(retriever2, is_terminal2, cancellation_token=token) + + # Second retriever should not be called + assert retriever2.call_count == 0 diff --git a/tests/test_polling_async.py b/tests/test_polling_async.py new file mode 100644 index 000000000..3bc18730f --- /dev/null +++ b/tests/test_polling_async.py @@ -0,0 +1,293 @@ +"""Tests for async polling with cancellation.""" + +import asyncio +from typing import Any +from unittest.mock import Mock, AsyncMock, patch + +import pytest + +from src.runloop_api_client.lib.polling import PollingConfig, PollingTimeout +from src.runloop_api_client.lib.polling_async import async_poll_until +from src.runloop_api_client.lib.cancellation import CancellationToken, PollingCancelled + + +class TestAsyncPollUntil: + """Test async_poll_until function.""" + + @pytest.mark.asyncio + async def test_immediate_success(self): + """Test when condition is met on first attempt.""" + retriever = AsyncMock(return_value="completed") + is_terminal = Mock(return_value=True) + + result = await async_poll_until(retriever, is_terminal) + + assert result == "completed" + assert retriever.call_count == 1 + assert is_terminal.call_count == 1 + + @pytest.mark.asyncio + async def test_success_after_multiple_attempts(self): + """Test when condition is met after several attempts.""" + values = ["pending", "running", "completed"] + retriever = AsyncMock(side_effect=values) + is_terminal = Mock(side_effect=[False, False, True]) + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await async_poll_until(retriever, is_terminal) + + assert result == "completed" + assert retriever.call_count == 3 + assert is_terminal.call_count == 3 + assert mock_sleep.call_count == 2 + + @pytest.mark.asyncio + async def test_max_attempts_exceeded(self): + """Test when max attempts is exceeded.""" + retriever = AsyncMock(return_value="still_running") + is_terminal = Mock(return_value=False) + config = PollingConfig(max_attempts=3, interval_seconds=0.01) + + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(PollingTimeout) as exc_info: + await async_poll_until(retriever, is_terminal, config) + + assert "Exceeded maximum attempts (3)" in str(exc_info.value) + assert retriever.call_count == 3 + + @pytest.mark.asyncio + async def test_timeout_exceeded(self): + """Test when timeout is exceeded.""" + retriever = AsyncMock(return_value="still_running") + is_terminal = Mock(return_value=False) + config = PollingConfig(timeout_seconds=0.1, interval_seconds=0.01) + + start_time = 1000.0 + with patch("time.time", side_effect=[start_time, start_time + 0.05, start_time + 0.15]): + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(PollingTimeout) as exc_info: + await async_poll_until(retriever, is_terminal, config) + + assert "Exceeded timeout of 0.1 seconds" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_error_without_handler(self): + """Test that exceptions are re-raised when no error handler is provided.""" + retriever = AsyncMock(side_effect=ValueError("Test error")) + is_terminal = Mock(return_value=False) + + with pytest.raises(ValueError, match="Test error"): + await async_poll_until(retriever, is_terminal) + + @pytest.mark.asyncio + async def test_error_with_handler_continue(self): + """Test error handler that allows polling to continue.""" + retriever = AsyncMock(side_effect=[ValueError("Test error"), "recovered"]) + is_terminal = Mock(side_effect=[False, True]) + + def error_handler(_: Exception) -> str: + return "error_handled" + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await async_poll_until(retriever, is_terminal, on_error=error_handler) + + assert result == "recovered" + assert retriever.call_count == 2 + + +class TestAsyncPollUntilWithCancellation: + """Test async_poll_until with CancellationToken.""" + + @pytest.mark.asyncio + async def test_cancellation_before_first_poll(self): + """Test cancellation before first poll attempt.""" + token = CancellationToken() + token.cancel() + + retriever = AsyncMock(return_value="value") + is_terminal = Mock(return_value=False) + + with pytest.raises(PollingCancelled, match="Polling operation was cancelled"): + await async_poll_until(retriever, is_terminal, cancellation_token=token) + + assert retriever.call_count == 0 + + @pytest.mark.asyncio + async def test_cancellation_during_polling(self): + """Test cancellation during polling loop.""" + token = CancellationToken() + retriever = AsyncMock(side_effect=["value1", "value2", "value3"]) + is_terminal = Mock(return_value=False) + + call_count = 0 + + async def cancel_on_second_call(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 2: + token.cancel() + if token.is_cancelled(): + raise PollingCancelled("Polling operation was cancelled") + + with patch("asyncio.sleep", side_effect=cancel_on_second_call): + with pytest.raises(PollingCancelled): + await async_poll_until(retriever, is_terminal, cancellation_token=token) + + assert retriever.call_count == 2 + + @pytest.mark.asyncio + async def test_cancellation_during_sleep(self): + """Test that cancellation wakes up from sleep.""" + token = CancellationToken() + retriever = AsyncMock(return_value="value") + is_terminal = Mock(return_value=False) + config = PollingConfig(interval_seconds=10.0) # Long sleep + + async def cancel_after_delay(): + await asyncio.sleep(0.05) + token.cancel() + + # Start cancellation task + cancel_task = asyncio.create_task(cancel_after_delay()) + + # Poll should be interrupted by cancellation + with pytest.raises(PollingCancelled): + await async_poll_until(retriever, is_terminal, config, cancellation_token=token) + + await cancel_task + + # Should have attempted once before cancellation + assert retriever.call_count == 1 + + @pytest.mark.asyncio + async def test_no_cancellation_completes_normally(self): + """Test that polling completes normally without cancellation.""" + token = CancellationToken() + retriever = AsyncMock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await async_poll_until(retriever, is_terminal, cancellation_token=token) + + assert result == "completed" + assert not token.is_cancelled() + + @pytest.mark.asyncio + async def test_cancellation_after_completion_no_effect(self): + """Test that cancelling after completion has no effect.""" + token = CancellationToken() + retriever = AsyncMock(return_value="completed") + is_terminal = Mock(return_value=True) + + result = await async_poll_until(retriever, is_terminal, cancellation_token=token) + + token.cancel() + + assert result == "completed" + + @pytest.mark.asyncio + async def test_cancellation_with_error_handler(self): + """Test cancellation works with error handler.""" + token = CancellationToken() + retriever = AsyncMock(side_effect=[ValueError("error"), "value"]) + is_terminal = Mock(return_value=False) + + def error_handler(_: Exception) -> str: + return "handled" + + async def cancel_on_first_sleep(*args, **kwargs): + token.cancel() + raise PollingCancelled("Polling operation was cancelled") + + with patch("asyncio.sleep", side_effect=cancel_on_first_sleep): + with pytest.raises(PollingCancelled): + await async_poll_until(retriever, is_terminal, on_error=error_handler, cancellation_token=token) + + @pytest.mark.asyncio + async def test_none_cancellation_token_works_normally(self): + """Test that passing None as cancellation_token works (backward compatibility).""" + retriever = AsyncMock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await async_poll_until(retriever, is_terminal, cancellation_token=None) + + assert result == "completed" + + @pytest.mark.asyncio + async def test_cancellable_sleep_blocks_correctly(self): + """Test that cancellable sleep blocks for the correct duration.""" + token = CancellationToken() + retriever = AsyncMock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + config = PollingConfig(interval_seconds=0.1) + + start = asyncio.get_event_loop().time() + result = await async_poll_until(retriever, is_terminal, config, cancellation_token=token) + elapsed = asyncio.get_event_loop().time() - start + + assert result == "completed" + # Should have slept approximately 0.1 seconds + assert 0.08 <= elapsed <= 0.15 + + @pytest.mark.asyncio + async def test_concurrent_polling_with_shared_token(self): + """Test multiple concurrent polls with the same token.""" + token = CancellationToken() + + async def poll_task(): + retriever = AsyncMock(return_value="value") + is_terminal = Mock(return_value=False) + config = PollingConfig(interval_seconds=0.01) + await async_poll_until(retriever, is_terminal, config, cancellation_token=token) + + # Start multiple polling tasks + tasks = [asyncio.create_task(poll_task()) for _ in range(3)] + + # Give them time to start + await asyncio.sleep(0.05) + + # Cancel the shared token + token.cancel() + + # All tasks should raise PollingCancelled + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + assert isinstance(result, PollingCancelled) + + @pytest.mark.asyncio + async def test_cancellation_timeout_error_handling(self): + """Test that asyncio.TimeoutError during cancellable sleep is handled correctly.""" + token = CancellationToken() + retriever = AsyncMock(side_effect=["value1", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + config = PollingConfig(interval_seconds=0.1) + + # The cancellable sleep should handle TimeoutError correctly and continue + result = await async_poll_until(retriever, is_terminal, config, cancellation_token=token) + + assert result == "completed" + assert retriever.call_count == 2 + + @pytest.mark.asyncio + async def test_cancellation_from_different_task(self): + """Test that token can be cancelled from a different async task.""" + token = CancellationToken() + retriever = AsyncMock(return_value="value") + is_terminal = Mock(return_value=False) + config = PollingConfig(interval_seconds=1.0) + + async def cancel_from_other_task(): + await asyncio.sleep(0.1) + token.cancel() + + # Start both tasks + poll_task = asyncio.create_task(async_poll_until(retriever, is_terminal, config, cancellation_token=token)) + cancel_task = asyncio.create_task(cancel_from_other_task()) + + # Wait for both + with pytest.raises(PollingCancelled): + await poll_task + + await cancel_task From 34d932fa9096eaab5fe82c7939f13da5650e43aa Mon Sep 17 00:00:00 2001 From: Tony Deng Date: Wed, 1 Apr 2026 18:28:52 -0700 Subject: [PATCH 06/10] test: add comprehensive list pagination smoke tests (port of TS PR #767) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Port verification and testing from TypeScript PR #767 which fixed slow list endpoints that were auto-paginating through all pages. ## Analysis Results The Python SDK does NOT suffer from the bug that affected TypeScript because: - Python implementation uses direct property access (`page.items`) - TypeScript was using async iteration (`for await`) which auto-paginated - Our existing unit tests already used correct mocking patterns ## What Was Added New smoke tests in `tests/smoketests/sdk/test_list_pagination.py`: - 28 comprehensive tests (15 async, 13 sync) - Tests all 13 resource types (agents, devboxes, blueprints, etc.) - Verifies `list(limit=N)` returns at most N items - Ensures no auto-pagination occurs - Includes data creation test to verify with actual API calls ## Verification All SDK list methods verified correct: ✅ AsyncDevboxOps.list() - accesses page.devboxes ✅ AsyncSnapshotOps.list() - accesses page.snapshots ✅ AsyncBlueprintOps.list() - accesses page.blueprints ✅ AsyncStorageObjectOps.list() - accesses page.objects ✅ AsyncAxonOps.list() - accesses result.axons ✅ AsyncScorerOps.list() - accesses page.scorers ✅ AsyncAgentOps.list() - accesses page.agents ✅ AsyncScenarioOps.list() - accesses page.scenarios ✅ AsyncBenchmarkOps.list() - accesses page.benchmarks ✅ AsyncNetworkPolicyOps.list() - accesses page.network_policies ✅ AsyncGatewayConfigOps.list() - accesses page.gateway_configs ✅ AsyncMcpConfigOps.list() - accesses page.mcp_configs ✅ AsyncSecretOps.list() - accesses result.secrets (+ all sync equivalents) ## Benefits ✅ Faster list results - only fetches requested page ✅ Fewer API requests - no unnecessary pagination ✅ Better resource usage - respects limit parameter ✅ Documented behavior - tests serve as specification ✅ Regression prevention - ensures future changes maintain correctness ## Related - TypeScript PR: https://github.com/runloopai/api-client-ts/pull/767 - Detailed analysis: PR_767_PORT_SUMMARY.md - Code comparison: IMPLEMENTATION_COMPARISON.md - Quick reference: PR_767_PORT.md ## Testing Run the new tests: ```bash uv run pytest tests/smoketests/sdk/test_list_pagination.py -v ``` No source code changes required - Python implementation already correct. --- tests/smoketests/sdk/test_list_pagination.py | 306 +++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 tests/smoketests/sdk/test_list_pagination.py diff --git a/tests/smoketests/sdk/test_list_pagination.py b/tests/smoketests/sdk/test_list_pagination.py new file mode 100644 index 000000000..0bfb7ede3 --- /dev/null +++ b/tests/smoketests/sdk/test_list_pagination.py @@ -0,0 +1,306 @@ +"""Smoke tests to verify list methods respect limit parameter and only return one page. + +This test suite validates the fix for slow list endpoints, ensuring that +SDK list() methods return only the requested page of results instead of +auto-paginating through all available items. + +Related to TypeScript PR: https://github.com/runloopai/api-client-ts/pull/767 +""" + +from __future__ import annotations + +import pytest + +from runloop_api_client.sdk import AsyncRunloopSDK, RunloopSDK +from tests.smoketests.utils import unique_name +from runloop_api_client.types.shared_params import AgentSource + +pytestmark = pytest.mark.smoketest + +THIRTY_SECOND_TIMEOUT = 30 +AGENT_SOURCE: AgentSource = { + "type": "npm", + "npm": { + "package_name": "@runloop/hello-world-agent", + }, +} + + +class TestAsyncListPagination: + """Test async list methods respect limit and return only one page.""" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_agent_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify agent.list() with limit returns at most that many items.""" + # Request a small page + agents = await async_sdk_client.agent.list(limit=5) + + assert isinstance(agents, list) + # Should return at most 5 items (might be fewer if less data exists) + assert len(agents) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_agent_list_limit_one(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify agent.list() with limit=1 returns at most one item.""" + agents = await async_sdk_client.agent.list(limit=1) + + assert isinstance(agents, list) + assert len(agents) <= 1, "list(limit=1) should return at most 1 item" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_devbox_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify devbox.list() with limit returns at most that many items.""" + devboxes = await async_sdk_client.devbox.list(limit=3) + + assert isinstance(devboxes, list) + assert len(devboxes) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_blueprint_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify blueprint.list() with limit returns at most that many items.""" + blueprints = await async_sdk_client.blueprint.list(limit=5) + + assert isinstance(blueprints, list) + assert len(blueprints) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_storage_object_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify storage_object.list() with limit returns at most that many items.""" + objects = await async_sdk_client.storage_object.list(limit=4) + + assert isinstance(objects, list) + assert len(objects) <= 4, "list(limit=4) should return at most 4 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_snapshot_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify snapshot.list() with limit returns at most that many items.""" + snapshots = await async_sdk_client.snapshot.list(limit=2) + + assert isinstance(snapshots, list) + assert len(snapshots) <= 2, "list(limit=2) should return at most 2 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_axon_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify axon.list() with limit returns at most that many items.""" + axons = await async_sdk_client.axon.list(limit=3) + + assert isinstance(axons, list) + assert len(axons) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_scorer_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify scorer.list() with limit returns at most that many items.""" + scorers = await async_sdk_client.scorer.list(limit=5) + + assert isinstance(scorers, list) + assert len(scorers) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_scenario_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify scenario.list() with limit returns at most that many items.""" + scenarios = await async_sdk_client.scenario.list(limit=4) + + assert isinstance(scenarios, list) + assert len(scenarios) <= 4, "list(limit=4) should return at most 4 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_benchmark_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify benchmark.list() with limit returns at most that many items.""" + benchmarks = await async_sdk_client.benchmark.list(limit=3) + + assert isinstance(benchmarks, list) + assert len(benchmarks) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_network_policy_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify network_policy.list() with limit returns at most that many items.""" + policies = await async_sdk_client.network_policy.list(limit=5) + + assert isinstance(policies, list) + assert len(policies) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_gateway_config_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify gateway_config.list() with limit returns at most that many items.""" + configs = await async_sdk_client.gateway_config.list(limit=2) + + assert isinstance(configs, list) + assert len(configs) <= 2, "list(limit=2) should return at most 2 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_mcp_config_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify mcp_config.list() with limit returns at most that many items.""" + configs = await async_sdk_client.mcp_config.list(limit=3) + + assert isinstance(configs, list) + assert len(configs) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_secret_list_no_auto_pagination(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify secret.list() returns only one page (secrets don't have limit param).""" + secrets = await async_sdk_client.secret.list() + + # Secrets list doesn't have a limit parameter, but should still + # return only one page worth of results, not auto-paginate + assert isinstance(secrets, list) + + +class TestSyncListPagination: + """Test sync list methods respect limit and return only one page.""" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_agent_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify agent.list() with limit returns at most that many items.""" + agents = sdk_client.agent.list(limit=5) + + assert isinstance(agents, list) + assert len(agents) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_devbox_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify devbox.list() with limit returns at most that many items.""" + devboxes = sdk_client.devbox.list(limit=3) + + assert isinstance(devboxes, list) + assert len(devboxes) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_blueprint_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify blueprint.list() with limit returns at most that many items.""" + blueprints = sdk_client.blueprint.list(limit=5) + + assert isinstance(blueprints, list) + assert len(blueprints) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_storage_object_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify storage_object.list() with limit returns at most that many items.""" + objects = sdk_client.storage_object.list(limit=4) + + assert isinstance(objects, list) + assert len(objects) <= 4, "list(limit=4) should return at most 4 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_snapshot_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify snapshot.list() with limit returns at most that many items.""" + snapshots = sdk_client.snapshot.list(limit=2) + + assert isinstance(snapshots, list) + assert len(snapshots) <= 2, "list(limit=2) should return at most 2 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_axon_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify axon.list() with limit returns at most that many items.""" + axons = sdk_client.axon.list(limit=3) + + assert isinstance(axons, list) + assert len(axons) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_scorer_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify scorer.list() with limit returns at most that many items.""" + scorers = sdk_client.scorer.list(limit=5) + + assert isinstance(scorers, list) + assert len(scorers) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_scenario_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify scenario.list() with limit returns at most that many items.""" + scenarios = sdk_client.scenario.list(limit=4) + + assert isinstance(scenarios, list) + assert len(scenarios) <= 4, "list(limit=4) should return at most 4 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_benchmark_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify benchmark.list() with limit returns at most that many items.""" + benchmarks = sdk_client.benchmark.list(limit=3) + + assert isinstance(benchmarks, list) + assert len(benchmarks) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_network_policy_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify network_policy.list() with limit returns at most that many items.""" + policies = sdk_client.network_policy.list(limit=5) + + assert isinstance(policies, list) + assert len(policies) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_gateway_config_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify gateway_config.list() with limit returns at most that many items.""" + configs = sdk_client.gateway_config.list(limit=2) + + assert isinstance(configs, list) + assert len(configs) <= 2, "list(limit=2) should return at most 2 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_mcp_config_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify mcp_config.list() with limit returns at most that many items.""" + configs = sdk_client.mcp_config.list(limit=3) + + assert isinstance(configs, list) + assert len(configs) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_secret_list_no_auto_pagination(self, sdk_client: RunloopSDK) -> None: + """Verify secret.list() returns only one page.""" + secrets = sdk_client.secret.list() + + assert isinstance(secrets, list) + + +class TestListPaginationWithData: + """Test list pagination behavior when data is guaranteed to exist.""" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_list_limit_with_created_data(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Create multiple items and verify list limit works correctly.""" + # Create several agents to ensure we have data + created_agents = [] + for i in range(5): + agent = await async_sdk_client.agent.create( + name=unique_name(f"sdk-list-test-{i}"), + version="1.0.0", + source=AGENT_SOURCE, + ) + created_agents.append(agent) + + try: + # Request only 2 items + listed_agents = await async_sdk_client.agent.list(limit=2) + + # Should get at most 2, even though 5+ exist + assert len(listed_agents) <= 2, ( + f"Expected at most 2 items with limit=2, got {len(listed_agents)}. " + "This indicates auto-pagination is occurring when it shouldn't." + ) + + # Request 10 items - should get all we created (5) plus any existing ones, + # but should stop at first page (up to 10) + listed_agents = await async_sdk_client.agent.list(limit=10) + assert len(listed_agents) <= 10, ( + f"Expected at most 10 items with limit=10, got {len(listed_agents)}. " + "This indicates auto-pagination is occurring when it shouldn't." + ) + + finally: + # Cleanup is not possible yet as agents don't have delete + pass From b596e710ca1f9be2b7836d578bb36c041484cb5f Mon Sep 17 00:00:00 2001 From: Tony Deng Date: Wed, 1 Apr 2026 20:52:20 -0700 Subject: [PATCH 07/10] Update tests, pass CI --- scripts/generate_examples_md.py | 6 +- src/runloop_api_client/_utils/_compat.py | 5 +- src/runloop_api_client/_utils/_utils.py | 3 +- src/runloop_api_client/lib/polling_async.py | 11 +- .../resources/axons/axons.py | 2 +- .../resources/blueprints.py | 2 +- .../resources/devboxes/devboxes.py | 2 +- .../resources/devboxes/disk_snapshots.py | 2 +- .../resources/devboxes/executions.py | 2 +- .../resources/scenarios/runs.py | 8 +- .../resources/scenarios/scenarios.py | 7 + src/runloop_api_client/sdk/_types.py | 2 +- .../types/axons/__init__.py | 2 +- .../types/axons/axon_subscribe_sse_params.py | 4 +- tests/api_resources/test_devboxes.py | 4 +- tests/sdk/test_async_scenario_run.py | 6 +- tests/sdk/test_scenario_run.py | 6 +- tests/smoketests/sdk/test_list_pagination.py | 4 +- tests/test_axon_sse_reconnect.py | 151 +++++++++--------- tests/test_cancellation.py | 10 +- tests/test_polling.py | 27 ++-- tests/test_polling_async.py | 30 ++-- uv.lock | 2 +- 23 files changed, 159 insertions(+), 139 deletions(-) diff --git a/scripts/generate_examples_md.py b/scripts/generate_examples_md.py index 9e47e9cae..43e4f81fe 100644 --- a/scripts/generate_examples_md.py +++ b/scripts/generate_examples_md.py @@ -11,10 +11,10 @@ import re import sys import argparse -from typing import Any +from typing import Any, cast from pathlib import Path -import frontmatter # type: ignore[import-untyped] +import frontmatter # type: ignore[import-not-found, import-untyped] ROOT = Path(__file__).parent.parent EXAMPLES_DIR = ROOT / "examples" @@ -38,7 +38,7 @@ def parse_example(path: Path) -> dict[str, Any]: raise ValueError(f"{path}: docstring must start with frontmatter (---)") try: - post = frontmatter.loads(docstring) + post = cast(Any, frontmatter).loads(docstring) return dict(post.metadata) except Exception as e: raise ValueError(f"{path}: invalid frontmatter: {e}") from e diff --git a/src/runloop_api_client/_utils/_compat.py b/src/runloop_api_client/_utils/_compat.py index 2c70b299c..38820cabc 100644 --- a/src/runloop_api_client/_utils/_compat.py +++ b/src/runloop_api_client/_utils/_compat.py @@ -2,7 +2,7 @@ import sys import typing_extensions -from typing import Any, Type, Union, Literal, Optional +from typing import Any, Type, Union, Literal, Optional, cast from datetime import date, datetime from typing_extensions import get_args as _get_args, get_origin as _get_origin @@ -34,7 +34,8 @@ def is_typeddict(tp: Type[Any]) -> bool: def is_literal_type(tp: Type[Any]) -> bool: - return get_origin(tp) in _LITERAL_TYPES + origin = get_origin(tp) + return cast(Any, origin) in _LITERAL_TYPES def parse_date(value: Union[date, StrBytesIntFloat]) -> date: diff --git a/src/runloop_api_client/_utils/_utils.py b/src/runloop_api_client/_utils/_utils.py index eec7f4a1f..ed08804e2 100644 --- a/src/runloop_api_client/_utils/_utils.py +++ b/src/runloop_api_client/_utils/_utils.py @@ -372,8 +372,7 @@ def file_from_path(path: str) -> FileTypes: def get_required_header(headers: HeadersLike, header: str) -> str: lower_header = header.lower() if is_mapping_t(headers): - # mypy doesn't understand the type narrowing here - for k, v in headers.items(): # type: ignore + for k, v in cast(Mapping[str, object], headers).items(): if k.lower() == lower_header and isinstance(v, str): return v diff --git a/src/runloop_api_client/lib/polling_async.py b/src/runloop_api_client/lib/polling_async.py index df3ffbc46..e462991e2 100644 --- a/src/runloop_api_client/lib/polling_async.py +++ b/src/runloop_api_client/lib/polling_async.py @@ -1,6 +1,7 @@ import time import asyncio from typing import Union, TypeVar, Callable, Optional, Awaitable +from contextlib import suppress from .polling import PollingConfig, PollingTimeout from .cancellation import CancellationToken @@ -67,13 +68,13 @@ async def async_poll_until( # Cancellable async sleep if cancellation_token is not None: + wait_task = asyncio.create_task(cancellation_token.async_event.wait()) try: - await asyncio.wait_for( - cancellation_token.async_event.wait(), - timeout=config.interval_seconds, - ) + await asyncio.wait_for(wait_task, timeout=config.interval_seconds) cancellation_token.raise_if_cancelled() except asyncio.TimeoutError: - pass # Normal sleep completion + wait_task.cancel() + with suppress(asyncio.CancelledError): + await wait_task else: await asyncio.sleep(config.interval_seconds) diff --git a/src/runloop_api_client/resources/axons/axons.py b/src/runloop_api_client/resources/axons/axons.py index 27031ee83..a09da2d9f 100644 --- a/src/runloop_api_client/resources/axons/axons.py +++ b/src/runloop_api_client/resources/axons/axons.py @@ -16,7 +16,6 @@ AsyncSqlResourceWithStreamingResponse, ) from ...types import axon_list_params, axon_create_params, axon_publish_params -from ...types.axons import axon_subscribe_sse_params from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given from ..._utils import path_template, maybe_transform, async_maybe_transform from ..._compat import cached_property @@ -30,6 +29,7 @@ from ..._constants import RAW_RESPONSE_HEADER from ..._streaming import Stream, AsyncStream, ReconnectingStream, AsyncReconnectingStream from ...pagination import SyncAxonsCursorIDPage, AsyncAxonsCursorIDPage +from ...types.axons import axon_subscribe_sse_params from ..._base_client import AsyncPaginator, make_request_options from ...types.axon_view import AxonView from ...types.axon_event_view import AxonEventView diff --git a/src/runloop_api_client/resources/blueprints.py b/src/runloop_api_client/resources/blueprints.py index 94a40bb85..295cf0785 100644 --- a/src/runloop_api_client/resources/blueprints.py +++ b/src/runloop_api_client/resources/blueprints.py @@ -28,8 +28,8 @@ from .._exceptions import RunloopError from ..lib.polling import PollingConfig, poll_until from .._base_client import AsyncPaginator, make_request_options -from ..lib.polling_async import async_poll_until from ..lib.cancellation import CancellationToken +from ..lib.polling_async import async_poll_until from .._utils._validation import ValidationNotification from ..types.blueprint_view import BlueprintView from ..types.blueprint_preview_view import BlueprintPreviewView diff --git a/src/runloop_api_client/resources/devboxes/devboxes.py b/src/runloop_api_client/resources/devboxes/devboxes.py index dccc386c8..e383e8abe 100644 --- a/src/runloop_api_client/resources/devboxes/devboxes.py +++ b/src/runloop_api_client/resources/devboxes/devboxes.py @@ -97,8 +97,8 @@ DiskSnapshotsResourceWithStreamingResponse, AsyncDiskSnapshotsResourceWithStreamingResponse, ) -from ...lib.polling_async import async_poll_until from ...lib.cancellation import CancellationToken +from ...lib.polling_async import async_poll_until from ...types.devbox_view import DevboxView from ...types.tunnel_view import TunnelView from ...types.shared_params.mount import Mount diff --git a/src/runloop_api_client/resources/devboxes/disk_snapshots.py b/src/runloop_api_client/resources/devboxes/disk_snapshots.py index 4b4cd3678..76059ee2a 100644 --- a/src/runloop_api_client/resources/devboxes/disk_snapshots.py +++ b/src/runloop_api_client/resources/devboxes/disk_snapshots.py @@ -21,8 +21,8 @@ from ...lib.polling import PollingConfig, poll_until from ..._base_client import AsyncPaginator, make_request_options from ...types.devboxes import disk_snapshot_list_params, disk_snapshot_update_params -from ...lib.polling_async import async_poll_until from ...lib.cancellation import CancellationToken +from ...lib.polling_async import async_poll_until from ...types.devbox_snapshot_view import DevboxSnapshotView from ...types.devboxes.devbox_snapshot_async_status_view import DevboxSnapshotAsyncStatusView diff --git a/src/runloop_api_client/resources/devboxes/executions.py b/src/runloop_api_client/resources/devboxes/executions.py index 90896a198..9dd394365 100755 --- a/src/runloop_api_client/resources/devboxes/executions.py +++ b/src/runloop_api_client/resources/devboxes/executions.py @@ -32,8 +32,8 @@ execution_stream_stderr_updates_params, execution_stream_stdout_updates_params, ) -from ...lib.polling_async import async_poll_until from ...lib.cancellation import CancellationToken +from ...lib.polling_async import async_poll_until from ...types.devbox_send_std_in_result import DevboxSendStdInResult from ...types.devbox_execution_detail_view import DevboxExecutionDetailView from ...types.devboxes.execution_update_chunk import ExecutionUpdateChunk diff --git a/src/runloop_api_client/resources/scenarios/runs.py b/src/runloop_api_client/resources/scenarios/runs.py index 0cdb04df4..38e61af32 100644 --- a/src/runloop_api_client/resources/scenarios/runs.py +++ b/src/runloop_api_client/resources/scenarios/runs.py @@ -27,8 +27,8 @@ from ...lib.polling import PollingConfig, poll_until from ..._base_client import AsyncPaginator, make_request_options from ...types.scenarios import run_list_params -from ...lib.polling_async import async_poll_until from ...lib.cancellation import CancellationToken +from ...lib.polling_async import async_poll_until from ...types.scenario_run_view import ScenarioRunView __all__ = ["RunsResource", "AsyncRunsResource"] @@ -429,6 +429,7 @@ def score_and_complete( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -441,6 +442,7 @@ def score_and_complete( Args: id: The ID of the scenario run to score, wait for, and complete polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -456,6 +458,7 @@ def score_and_complete( self.score_and_await( id, polling_config=polling_config, + cancellation_token=cancellation_token, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, @@ -866,6 +869,7 @@ async def score_and_complete( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -878,6 +882,7 @@ async def score_and_complete( Args: id: The ID of the scenario run to score, wait for, and complete polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -893,6 +898,7 @@ async def score_and_complete( await self.score_and_await( id, polling_config=polling_config, + cancellation_token=cancellation_token, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, diff --git a/src/runloop_api_client/resources/scenarios/scenarios.py b/src/runloop_api_client/resources/scenarios/scenarios.py index e3ce8c91b..327f9a1a1 100644 --- a/src/runloop_api_client/resources/scenarios/scenarios.py +++ b/src/runloop_api_client/resources/scenarios/scenarios.py @@ -43,6 +43,7 @@ from ...pagination import SyncScenariosCursorIDPage, AsyncScenariosCursorIDPage from ...lib.polling import PollingConfig from ..._base_client import AsyncPaginator, make_request_options +from ...lib.cancellation import CancellationToken from ...types.scenario_view import ScenarioView from ...types.scenario_run_view import ScenarioRunView from ...types.input_context_param import InputContextParam @@ -527,6 +528,7 @@ def start_run_and_await_env_ready( run_name: Optional[str] | Omit = omit, run_profile: Optional[scenario_start_run_params.RunProfile] | Omit = omit, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -543,6 +545,7 @@ def start_run_and_await_env_ready( run_name: Display name of the run run_profile: Runtime configuration to use for this benchmark run polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -572,6 +575,7 @@ def start_run_and_await_env_ready( self._client.devboxes.await_running( run.devbox_id, polling_config=polling_config, + cancellation_token=cancellation_token, ) return run @@ -1048,6 +1052,7 @@ async def start_run_and_await_env_ready( run_name: Optional[str] | Omit = omit, run_profile: Optional[scenario_start_run_params.RunProfile] | Omit = omit, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -1064,6 +1069,7 @@ async def start_run_and_await_env_ready( run_name: Display name of the run run_profile: Runtime configuration to use for this benchmark run polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation Returns: The scenario run in running state @@ -1088,6 +1094,7 @@ async def start_run_and_await_env_ready( await self._client.devboxes.await_running( run.devbox_id, polling_config=polling_config, + cancellation_token=cancellation_token, ) return run diff --git a/src/runloop_api_client/sdk/_types.py b/src/runloop_api_client/sdk/_types.py index 39a1b42de..4f2075429 100644 --- a/src/runloop_api_client/sdk/_types.py +++ b/src/runloop_api_client/sdk/_types.py @@ -40,9 +40,9 @@ ) from .._types import Body, Query, Headers, Timeout, NotGiven from ..lib.polling import PollingConfig -from ..lib.cancellation import CancellationToken from ..types.devboxes import DiskSnapshotListParams, DiskSnapshotUpdateParams from ..types.scenarios import ScorerListParams, ScorerCreateParams, ScorerUpdateParams +from ..lib.cancellation import CancellationToken from ..types.devbox_create_params import DevboxBaseCreateParams from ..types.axons.sql_batch_params import SqlBatchParams from ..types.axons.sql_query_params import SqlQueryParams diff --git a/src/runloop_api_client/types/axons/__init__.py b/src/runloop_api_client/types/axons/__init__.py index 9602bf477..1a00da9e0 100644 --- a/src/runloop_api_client/types/axons/__init__.py +++ b/src/runloop_api_client/types/axons/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations -from .axon_subscribe_sse_params import AxonSubscribeSseParams as AxonSubscribeSseParams from .sql_batch_params import SqlBatchParams as SqlBatchParams from .sql_query_params import SqlQueryParams as SqlQueryParams from .sql_step_error_view import SqlStepErrorView as SqlStepErrorView @@ -12,3 +11,4 @@ from .sql_step_result_view import SqlStepResultView as SqlStepResultView from .sql_batch_result_view import SqlBatchResultView as SqlBatchResultView from .sql_query_result_view import SqlQueryResultView as SqlQueryResultView +from .axon_subscribe_sse_params import AxonSubscribeSseParams as AxonSubscribeSseParams diff --git a/src/runloop_api_client/types/axons/axon_subscribe_sse_params.py b/src/runloop_api_client/types/axons/axon_subscribe_sse_params.py index dc3bebb73..c88c2e3da 100644 --- a/src/runloop_api_client/types/axons/axon_subscribe_sse_params.py +++ b/src/runloop_api_client/types/axons/axon_subscribe_sse_params.py @@ -4,11 +4,9 @@ from typing_extensions import TypedDict -from ..._types import NotGiven - __all__ = ["AxonSubscribeSseParams"] class AxonSubscribeSseParams(TypedDict, total=False): - after_sequence: int | NotGiven + after_sequence: int """Resume SSE stream from events after this sequence number (used internally for reconnection)""" diff --git a/tests/api_resources/test_devboxes.py b/tests/api_resources/test_devboxes.py index c04de1d06..b33c8ea13 100644 --- a/tests/api_resources/test_devboxes.py +++ b/tests/api_resources/test_devboxes.py @@ -1386,7 +1386,7 @@ def test_method_create_and_await_running_success(self, client: Runloop) -> None: assert result.id == "test_id" assert result.status == "running" mock_create.assert_called_once() - mock_await.assert_called_once_with("test_id", polling_config=None) + mock_await.assert_called_once_with("test_id", polling_config=None, cancellation_token=None) @parametrize def test_method_create_and_await_running_with_config(self, client: Runloop) -> None: @@ -1426,7 +1426,7 @@ def test_method_create_and_await_running_with_config(self, client: Runloop) -> N assert result.id == "test_id" assert result.status == "running" - mock_await.assert_called_once_with("test_id", polling_config=config) + mock_await.assert_called_once_with("test_id", polling_config=config, cancellation_token=None) @parametrize def test_method_create_and_await_running_create_failure(self, client: Runloop) -> None: diff --git a/tests/sdk/test_async_scenario_run.py b/tests/sdk/test_async_scenario_run.py index c034524a0..290caf545 100644 --- a/tests/sdk/test_async_scenario_run.py +++ b/tests/sdk/test_async_scenario_run.py @@ -54,7 +54,11 @@ async def test_await_env_ready( run = AsyncScenarioRun(mock_async_client, "scr_123", "dbx_123") result = await run.await_env_ready() - mock_async_client.devboxes.await_running.assert_awaited_once_with("dbx_123", polling_config=None) + mock_async_client.devboxes.await_running.assert_awaited_once_with( + "dbx_123", + polling_config=None, + cancellation_token=None, + ) assert result == scenario_run_view async def test_score(self, mock_async_client: AsyncMock, scenario_run_view: MockScenarioRunView) -> None: diff --git a/tests/sdk/test_scenario_run.py b/tests/sdk/test_scenario_run.py index 339e365f8..82c004da8 100644 --- a/tests/sdk/test_scenario_run.py +++ b/tests/sdk/test_scenario_run.py @@ -51,7 +51,11 @@ def test_await_env_ready( run = ScenarioRun(mock_client, "scr_123", "dbx_123") result = run.await_env_ready() - mock_client.devboxes.await_running.assert_called_once_with("dbx_123", polling_config=None) + mock_client.devboxes.await_running.assert_called_once_with( + "dbx_123", + polling_config=None, + cancellation_token=None, + ) assert result == scenario_run_view def test_score(self, mock_client: Mock, scenario_run_view: MockScenarioRunView) -> None: diff --git a/tests/smoketests/sdk/test_list_pagination.py b/tests/smoketests/sdk/test_list_pagination.py index 0bfb7ede3..72e2c14b2 100644 --- a/tests/smoketests/sdk/test_list_pagination.py +++ b/tests/smoketests/sdk/test_list_pagination.py @@ -11,7 +11,7 @@ import pytest -from runloop_api_client.sdk import AsyncRunloopSDK, RunloopSDK +from runloop_api_client.sdk import RunloopSDK, AsyncRunloopSDK from tests.smoketests.utils import unique_name from runloop_api_client.types.shared_params import AgentSource @@ -274,7 +274,7 @@ class TestListPaginationWithData: async def test_list_limit_with_created_data(self, async_sdk_client: AsyncRunloopSDK) -> None: """Create multiple items and verify list limit works correctly.""" # Create several agents to ensure we have data - created_agents = [] + created_agents: list[object] = [] for i in range(5): agent = await async_sdk_client.agent.create( name=unique_name(f"sdk-list-test-{i}"), diff --git a/tests/test_axon_sse_reconnect.py b/tests/test_axon_sse_reconnect.py index 6b5328cbf..cc2575c64 100644 --- a/tests/test_axon_sse_reconnect.py +++ b/tests/test_axon_sse_reconnect.py @@ -1,13 +1,13 @@ """Tests for Axon SSE auto-reconnect functionality.""" -from unittest.mock import Mock, patch, MagicMock -from typing import Iterator, AsyncIterator +from typing import Any, Iterator, AsyncIterator, cast +from unittest.mock import Mock, AsyncMock, patch +import httpx import pytest -from src.runloop_api_client._streaming import Stream, AsyncStream, ReconnectingStream, AsyncReconnectingStream from src.runloop_api_client._constants import RAW_RESPONSE_HEADER -from src.runloop_api_client.types.axon_event_view import AxonEventView +from src.runloop_api_client._streaming import Stream, AsyncStream, ReconnectingStream, AsyncReconnectingStream class MockAxonEvent: @@ -25,13 +25,13 @@ def test_subscribe_sse_returns_reconnecting_stream(self): """Test that subscribe_sse returns a ReconnectingStream.""" from src.runloop_api_client import Runloop - with patch.object(Runloop, "_get") as mock_get: + client = Runloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get") as mock_get: # Mock the initial stream mock_stream = Mock(spec=Stream) mock_get.return_value = mock_stream - client = Runloop(api_key="test-key", base_url="http://test") - result = client.axons.subscribe_sse("axon-123") # Should return a ReconnectingStream @@ -41,12 +41,12 @@ def test_subscribe_sse_with_raw_header_returns_plain_stream(self): """Test that RAW_RESPONSE_HEADER opts out of reconnection.""" from src.runloop_api_client import Runloop - with patch.object(Runloop, "_get") as mock_get: + client = Runloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get") as mock_get: mock_stream = Mock(spec=Stream) mock_get.return_value = mock_stream - client = Runloop(api_key="test-key", base_url="http://test") - result = client.axons.subscribe_sse("axon-123", extra_headers={RAW_RESPONSE_HEADER: "true"}) # Should return plain Stream, not ReconnectingStream @@ -58,44 +58,44 @@ def test_reconnection_uses_last_sequence(self): from src.runloop_api_client import Runloop call_count = 0 - query_params = [] + query_params: list[dict[str, object]] = [] - def mock_get(*args, **kwargs): + def mock_get(*_args: object, **kwargs: Any) -> Mock: nonlocal call_count # Capture query params - if "query" in kwargs.get("options", {}): - query_params.append(kwargs["options"]["query"]) + options = cast(dict[str, object], kwargs.get("options", {})) + if "params" in options: + query_params.append(cast(dict[str, object], options["params"])) # First call: return stream with events if call_count == 0: call_count += 1 mock_stream = Mock(spec=Stream) - def mock_iter(): + def first_iter(_self: object) -> Iterator[MockAxonEvent]: yield MockAxonEvent(sequence=1, data="event1") yield MockAxonEvent(sequence=2, data="event2") - # Simulate timeout/disconnect - raise StopIteration() + raise httpx.ReadTimeout("stream timed out") - mock_stream.__iter__ = mock_iter + mock_stream.__iter__ = first_iter return mock_stream # Second call (reconnection): return stream continuing from sequence 2 mock_stream = Mock(spec=Stream) - def mock_iter(): + def second_iter(_self: object) -> Iterator[MockAxonEvent]: yield MockAxonEvent(sequence=3, data="event3") - mock_stream.__iter__ = mock_iter + mock_stream.__iter__ = second_iter return mock_stream - with patch.object(Runloop, "_get", side_effect=mock_get): - client = Runloop(api_key="test-key", base_url="http://test") + client = Runloop(bearer_token="test-key", base_url="http://test") + with patch.object(client.axons, "_get", side_effect=mock_get): stream = client.axons.subscribe_sse("axon-123") # Consume events - events = list(stream) + events = cast(list[MockAxonEvent], list(stream)) # Should have 3 events total (2 from first stream, 1 from reconnected stream) assert len(events) == 3 @@ -111,21 +111,21 @@ def test_sequence_extraction_handles_missing_sequence(self): """Test that missing sequence fields are handled gracefully.""" from src.runloop_api_client import Runloop - with patch.object(Runloop, "_get") as mock_get: + client = Runloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get") as mock_get: mock_stream = Mock(spec=Stream) # Event without sequence attribute class EventWithoutSequence: pass - def mock_iter(): + def mock_iter(_self: object) -> Iterator[object]: yield EventWithoutSequence() mock_stream.__iter__ = mock_iter mock_get.return_value = mock_stream - client = Runloop(api_key="test-key", base_url="http://test") - stream = client.axons.subscribe_sse("axon-123") # Should not crash, sequence extractor should return None @@ -136,29 +136,32 @@ def test_subscribe_sse_preserves_request_options(self): """Test that extra headers, query, etc. are preserved.""" from src.runloop_api_client import Runloop - with patch.object(Runloop, "_get") as mock_get: + client = Runloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get") as mock_get: mock_stream = Mock(spec=Stream) - mock_stream.__iter__ = lambda self: iter([]) - mock_get.return_value = mock_stream - client = Runloop(api_key="test-key", base_url="http://test") + def empty_iter(_self: object) -> Iterator[object]: + return iter([]) + + mock_stream.__iter__ = empty_iter + mock_get.return_value = mock_stream extra_headers = {"X-Custom": "value"} extra_query = {"param": "value"} - client.axons.subscribe_sse( - "axon-123", extra_headers=extra_headers, extra_query=extra_query, timeout=30.0 - ) + client.axons.subscribe_sse("axon-123", extra_headers=extra_headers, extra_query=extra_query, timeout=30.0) # Verify _get was called with the options call_args = mock_get.call_args options = call_args.kwargs["options"] # Headers should include Accept: text/event-stream and custom header - assert "Accept" in options["extra_headers"] - assert options["extra_headers"]["Accept"] == "text/event-stream" - assert options["extra_headers"]["X-Custom"] == "value" - assert options["extra_query"] == extra_query + assert "Accept" in options["headers"] + assert options["headers"]["Accept"] == "text/event-stream" + assert options["headers"]["X-Custom"] == "value" + assert options["params"]["param"] == "value" + assert options["params"]["after_sequence"] is None assert options["timeout"] == 30.0 @@ -170,13 +173,13 @@ async def test_subscribe_sse_returns_reconnecting_stream(self): """Test that subscribe_sse returns an AsyncReconnectingStream.""" from src.runloop_api_client import AsyncRunloop - async def mock_get(*args, **kwargs): + async def mock_get(*_args: object, **_kwargs: object) -> Mock: mock_stream = Mock(spec=AsyncStream) return mock_stream - with patch.object(AsyncRunloop, "_get", side_effect=mock_get): - client = AsyncRunloop(api_key="test-key", base_url="http://test") + client = AsyncRunloop(bearer_token="test-key", base_url="http://test") + with patch.object(client.axons, "_get", new=AsyncMock(side_effect=mock_get)): result = await client.axons.subscribe_sse("axon-123") # Should return an AsyncReconnectingStream @@ -187,13 +190,13 @@ async def test_subscribe_sse_with_raw_header_returns_plain_stream(self): """Test that RAW_RESPONSE_HEADER opts out of reconnection.""" from src.runloop_api_client import AsyncRunloop - async def mock_get(*args, **kwargs): + async def mock_get(*_args: object, **_kwargs: object) -> Mock: mock_stream = Mock(spec=AsyncStream) return mock_stream - with patch.object(AsyncRunloop, "_get", side_effect=mock_get) as mock_get_method: - client = AsyncRunloop(api_key="test-key", base_url="http://test") + client = AsyncRunloop(bearer_token="test-key", base_url="http://test") + with patch.object(client.axons, "_get", new=AsyncMock(side_effect=mock_get)): result = await client.axons.subscribe_sse("axon-123", extra_headers={RAW_RESPONSE_HEADER: "true"}) # Should return plain AsyncStream, not AsyncReconnectingStream @@ -205,42 +208,44 @@ async def test_reconnection_uses_last_sequence(self): from src.runloop_api_client import AsyncRunloop call_count = 0 - query_params = [] + query_params: list[dict[str, object]] = [] - async def mock_get(*args, **kwargs): + async def mock_get(*_args: object, **kwargs: Any) -> Mock: nonlocal call_count # Capture query params - if "query" in kwargs.get("options", {}): - query_params.append(kwargs["options"]["query"]) + options = cast(dict[str, object], kwargs.get("options", {})) + if "params" in options: + query_params.append(cast(dict[str, object], options["params"])) # First call: return stream with events if call_count == 0: call_count += 1 mock_stream = Mock(spec=AsyncStream) - async def mock_iter(): + async def first_iter(_self: object) -> AsyncIterator[MockAxonEvent]: yield MockAxonEvent(sequence=1, data="event1") yield MockAxonEvent(sequence=2, data="event2") + raise httpx.ReadTimeout("stream timed out") - mock_stream.__aiter__ = mock_iter + mock_stream.__aiter__ = first_iter return mock_stream # Second call (reconnection) mock_stream = Mock(spec=AsyncStream) - async def mock_iter(): + async def second_iter(_self: object) -> AsyncIterator[MockAxonEvent]: yield MockAxonEvent(sequence=3, data="event3") - mock_stream.__aiter__ = mock_iter + mock_stream.__aiter__ = second_iter return mock_stream - with patch.object(AsyncRunloop, "_get", side_effect=mock_get): - client = AsyncRunloop(api_key="test-key", base_url="http://test") + client = AsyncRunloop(bearer_token="test-key", base_url="http://test") + with patch.object(client.axons, "_get", new=AsyncMock(side_effect=mock_get)): stream = await client.axons.subscribe_sse("axon-123") # Consume events - events = [] + events: list[Any] = [] async for event in stream: events.append(event) if len(events) >= 3: @@ -257,26 +262,26 @@ async def test_sequence_extraction_handles_none(self): """Test that None sequences are handled gracefully.""" from src.runloop_api_client import AsyncRunloop - async def mock_get(*args, **kwargs): + async def mock_get(*_args: object, **_kwargs: object) -> Mock: mock_stream = Mock(spec=AsyncStream) # Event with sequence = None class EventWithNoneSequence: sequence = None - async def mock_iter(): + async def mock_iter(_self: object) -> AsyncIterator[object]: yield EventWithNoneSequence() mock_stream.__aiter__ = mock_iter return mock_stream - with patch.object(AsyncRunloop, "_get", side_effect=mock_get): - client = AsyncRunloop(api_key="test-key", base_url="http://test") + client = AsyncRunloop(bearer_token="test-key", base_url="http://test") + with patch.object(client.axons, "_get", new=AsyncMock(side_effect=mock_get)): stream = await client.axons.subscribe_sse("axon-123") # Should not crash - events = [] + events: list[object] = [] async for event in stream: events.append(event) break @@ -288,7 +293,7 @@ async def test_subscribe_sse_preserves_request_options(self): """Test that extra headers, query, etc. are preserved in async.""" from src.runloop_api_client import AsyncRunloop - async def mock_get(*args, **kwargs): + async def mock_get(*_args: object, **_kwargs: object) -> Mock: mock_stream = Mock(spec=AsyncStream) async def mock_iter(): @@ -298,9 +303,9 @@ async def mock_iter(): mock_stream.__aiter__ = mock_iter return mock_stream - with patch.object(AsyncRunloop, "_get", side_effect=mock_get) as mock_get_method: - client = AsyncRunloop(api_key="test-key", base_url="http://test") + client = AsyncRunloop(bearer_token="test-key", base_url="http://test") + with patch.object(client.axons, "_get", new=AsyncMock(side_effect=mock_get)) as mock_get_method: extra_headers = {"X-Custom": "value"} extra_query = {"param": "value"} @@ -313,10 +318,11 @@ async def mock_iter(): options = call_args.kwargs["options"] # Headers should include Accept: text/event-stream and custom header - assert "Accept" in options["extra_headers"] - assert options["extra_headers"]["Accept"] == "text/event-stream" - assert options["extra_headers"]["X-Custom"] == "value" - assert options["extra_query"] == extra_query + assert "Accept" in options["headers"] + assert options["headers"]["Accept"] == "text/event-stream" + assert options["headers"]["X-Custom"] == "value" + assert options["params"]["param"] == "value" + assert options["params"]["after_sequence"] is None assert options["timeout"] == 30.0 @@ -326,16 +332,11 @@ class TestAxonSubscribeSseParams: def test_params_structure(self): """Test that AxonSubscribeSseParams has the correct structure.""" from src.runloop_api_client.types.axons import AxonSubscribeSseParams - from src.runloop_api_client._types import NOT_GIVEN # Should be able to create with after_sequence params: AxonSubscribeSseParams = {"after_sequence": 123} assert params["after_sequence"] == 123 - # Should be able to create with NOT_GIVEN - params2: AxonSubscribeSseParams = {"after_sequence": NOT_GIVEN} - assert params2["after_sequence"] is NOT_GIVEN - - # Should be able to create with None implicitly (total=False) - params3: AxonSubscribeSseParams = {} - assert "after_sequence" not in params3 + # The field is optional via total=False, so it can also be omitted. + params2: AxonSubscribeSseParams = {} + assert "after_sequence" not in params2 diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index 859d098e7..0a24153da 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -1,12 +1,12 @@ """Tests for CancellationToken and PollingCancelled exception.""" -import threading import asyncio -from concurrent.futures import ThreadPoolExecutor +import threading +from concurrent.futures import Future, ThreadPoolExecutor import pytest -from src.runloop_api_client.lib.cancellation import CancellationToken, PollingCancelled +from src.runloop_api_client.lib.cancellation import PollingCancelled, CancellationToken class TestPollingCancelled: @@ -134,7 +134,7 @@ async def test_async_event_set_if_already_cancelled(self): def test_thread_safety(self): """Test that CancellationToken is thread-safe.""" token = CancellationToken() - results = [] + results: list[bool] = [] def cancel_token(): token.cancel() @@ -147,7 +147,7 @@ def check_token(): results.append(token.is_cancelled()) with ThreadPoolExecutor(max_workers=5) as executor: - futures = [] + futures: list[Future[None]] = [] # Start 4 checking threads for _ in range(4): futures.append(executor.submit(check_token)) diff --git a/tests/test_polling.py b/tests/test_polling.py index 59ab971d1..bb92ace3c 100644 --- a/tests/test_polling.py +++ b/tests/test_polling.py @@ -5,7 +5,7 @@ import pytest from src.runloop_api_client.lib.polling import PollingConfig, PollingTimeout, poll_until -from src.runloop_api_client.lib.cancellation import CancellationToken, PollingCancelled +from src.runloop_api_client.lib.cancellation import PollingCancelled, CancellationToken class TestPollingConfig: @@ -287,18 +287,17 @@ def test_cancellation_during_polling(self): retriever = Mock(side_effect=["value1", "value2", "value3"]) is_terminal = Mock(return_value=False) - call_count = 0 + wait_call_count = 0 - def cancel_on_second_call(*args, **kwargs): - nonlocal call_count - call_count += 1 - if call_count == 2: + def cancel_on_second_wait(*_args: object, **_kwargs: object) -> bool: + nonlocal wait_call_count + wait_call_count += 1 + if wait_call_count == 2: token.cancel() - # Don't actually sleep, just check cancellation - if token.is_cancelled(): - raise PollingCancelled("Polling operation was cancelled") + return True + return False - with patch("time.sleep", side_effect=cancel_on_second_call): + with patch.object(token.sync_event, "wait", side_effect=cancel_on_second_wait): with pytest.raises(PollingCancelled): poll_until(retriever, is_terminal, cancellation_token=token) @@ -337,7 +336,7 @@ def test_no_cancellation_completes_normally(self): retriever = Mock(side_effect=["pending", "completed"]) is_terminal = Mock(side_effect=[False, True]) - with patch("time.sleep"): + with patch.object(token.sync_event, "wait", return_value=False): result = poll_until(retriever, is_terminal, cancellation_token=token) assert result == "completed" @@ -365,11 +364,11 @@ def test_cancellation_with_error_handler(self): def error_handler(_: Exception) -> str: return "handled" - def cancel_on_first_sleep(*args, **kwargs): + def cancel_on_first_wait(*_args: object, **_kwargs: object) -> bool: token.cancel() - raise PollingCancelled("Polling operation was cancelled") + return True - with patch("time.sleep", side_effect=cancel_on_first_sleep): + with patch.object(token.sync_event, "wait", side_effect=cancel_on_first_wait): with pytest.raises(PollingCancelled): poll_until(retriever, is_terminal, on_error=error_handler, cancellation_token=token) diff --git a/tests/test_polling_async.py b/tests/test_polling_async.py index 3bc18730f..62f1bdf65 100644 --- a/tests/test_polling_async.py +++ b/tests/test_polling_async.py @@ -1,14 +1,13 @@ """Tests for async polling with cancellation.""" import asyncio -from typing import Any from unittest.mock import Mock, AsyncMock, patch import pytest from src.runloop_api_client.lib.polling import PollingConfig, PollingTimeout +from src.runloop_api_client.lib.cancellation import PollingCancelled, CancellationToken from src.runloop_api_client.lib.polling_async import async_poll_until -from src.runloop_api_client.lib.cancellation import CancellationToken, PollingCancelled class TestAsyncPollUntil: @@ -119,17 +118,17 @@ async def test_cancellation_during_polling(self): retriever = AsyncMock(side_effect=["value1", "value2", "value3"]) is_terminal = Mock(return_value=False) - call_count = 0 + wait_call_count = 0 - async def cancel_on_second_call(*args, **kwargs): - nonlocal call_count - call_count += 1 - if call_count == 2: + async def cancel_on_second_wait(*_args: object, **_kwargs: object) -> None: + nonlocal wait_call_count + wait_call_count += 1 + if wait_call_count == 2: token.cancel() - if token.is_cancelled(): - raise PollingCancelled("Polling operation was cancelled") + return None + raise asyncio.TimeoutError - with patch("asyncio.sleep", side_effect=cancel_on_second_call): + with patch("asyncio.wait_for", side_effect=cancel_on_second_wait): with pytest.raises(PollingCancelled): await async_poll_until(retriever, is_terminal, cancellation_token=token) @@ -165,9 +164,9 @@ async def test_no_cancellation_completes_normally(self): token = CancellationToken() retriever = AsyncMock(side_effect=["pending", "completed"]) is_terminal = Mock(side_effect=[False, True]) + config = PollingConfig(interval_seconds=0.001) - with patch("asyncio.sleep", new_callable=AsyncMock): - result = await async_poll_until(retriever, is_terminal, cancellation_token=token) + result = await async_poll_until(retriever, is_terminal, config, cancellation_token=token) assert result == "completed" assert not token.is_cancelled() @@ -195,11 +194,12 @@ async def test_cancellation_with_error_handler(self): def error_handler(_: Exception) -> str: return "handled" - async def cancel_on_first_sleep(*args, **kwargs): + async def cancel_on_first_wait(awaitable: asyncio.Task[object], *_args: object, **_kwargs: object) -> None: + awaitable.cancel() token.cancel() - raise PollingCancelled("Polling operation was cancelled") + return None - with patch("asyncio.sleep", side_effect=cancel_on_first_sleep): + with patch("asyncio.wait_for", side_effect=cancel_on_first_wait): with pytest.raises(PollingCancelled): await async_poll_until(retriever, is_terminal, on_error=error_handler, cancellation_token=token) diff --git a/uv.lock b/uv.lock index e2e199a71..34fb7a894 100644 --- a/uv.lock +++ b/uv.lock @@ -2386,7 +2386,7 @@ wheels = [ [[package]] name = "runloop-api-client" -version = "1.14.0" +version = "1.14.1" source = { editable = "." } dependencies = [ { name = "anyio" }, From 14085b944b0e169c2a12423a1961c577155bcf7a Mon Sep 17 00:00:00 2001 From: Tony Deng Date: Thu, 2 Apr 2026 11:55:45 -0700 Subject: [PATCH 08/10] fix: resolve merge conflicts and fix type errors in axons.py After merging origin/main, fixed type handling and test assertions: 1. Fixed type handling in subscribe_sse(): - Properly handle after_sequence parameter of type int | Omit - When last_sequence is provided (reconnection), use it as int - When user provides after_sequence, use it - When neither (default), use omit which gets filtered by transform 2. Fixed test assertions in test_axon_sse_reconnect.py: - Changed from `options["params"]["after_sequence"] is None` - To `"after_sequence" not in options["params"]` - Reason: transform() filters out Omit values, they don't become None 3. Created AGENTS.md: - Documents workflow for making code changes - Emphasizes running lint before and after changes - Includes common patterns and merge conflict resolution The merge combined two features: - Auto-reconnection (from feature branch) - User-provided after_sequence parameter (from main v1.15.0) Co-Authored-By: Claude Sonnet 4.5 --- AGENTS.md | 144 ++++++++++++++++++ .../resources/axons/axons.py | 14 +- tests/test_axon_sse_reconnect.py | 4 +- 3 files changed, 158 insertions(+), 4 deletions(-) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..51d76f4d4 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,144 @@ +# Agent Workflow Guidelines + +## Making Code Changes + +When making changes to the codebase, follow this workflow to ensure code quality: + +### 1. Run Lint Before Changes + +Before making any code changes, run the linter to establish a baseline: + +```bash +./scripts/lint +``` + +This runs: +- `pyright` - Type checking +- `mypy` - Additional type checking +- `ruff check` - Code linting +- `ruff format --check` - Format checking +- Import validation + +### 2. Make Your Changes + +Make the necessary code changes, ensuring you: +- Follow existing code patterns +- Update type annotations +- Handle edge cases (e.g., `Omit`, `NotGiven`, `None`) +- Maintain backward compatibility + +### 3. Update Tests + +Update or add tests for your changes: +- Unit tests in `tests/` +- Smoke tests in `tests/smoketests/` +- Ensure tests match the new behavior + +### 4. Run Lint After Changes + +After making changes, run the linter again to catch any issues: + +```bash +./scripts/lint +``` + +Fix any new errors or warnings that appear. + +### 5. Run Tests + +Run the test suite to ensure everything works: + +```bash +# Run all tests +uv run pytest + +# Run specific tests +uv run pytest tests/test_axon_sse_reconnect.py -xvs + +# Run smoke tests +uv run pytest tests/smoketests/ -m smoketest +``` + +### 6. Commit Changes + +Once lint and tests pass, commit your changes: + +```bash +git add -A +git commit -m "type: description + +Detailed explanation of changes + +Co-Authored-By: Claude Sonnet 4.5 " +``` + +## Common Patterns + +### Handling Optional Parameters + +When dealing with parameters that can be omitted: + +```python +# Parameter definition +def method(self, param: int | Omit = omit): + ... + +# In implementation, check for Omit before using +if not isinstance(param, Omit): + use_param(param) +else: + # param was omitted, use default behavior + pass +``` + +### Type-Safe Transformations + +The `transform` function automatically filters out `Omit` and `NotGiven` values: + +```python +# This dict with omitted values +{"field": omit, "other": 123} + +# Becomes this after transform +{"other": 123} +``` + +### Testing with Mocks + +When testing methods that call `_get` or similar: + +```python +with patch.object(client.resource, "_get") as mock_get: + mock_stream = Mock(spec=Stream) + mock_get.return_value = mock_stream + + # Your test code + result = client.resource.method() + + # Verify the call + call_args = mock_get.call_args + options = call_args.kwargs["options"] + assert options["params"]["field"] == expected_value +``` + +## Merge Conflict Resolution + +When merging branches: + +1. **Understand both changes** - Read the diff from both sides +2. **Combine features intelligently** - Don't just pick one side +3. **Update tests** - Tests may need adjustments for merged behavior +4. **Fix type errors** - Merges can introduce type mismatches +5. **Validate syntax** - Run `python3 -m py_compile` on changed files +6. **Run full lint** - Ensure everything passes + +## Example: Recent Merge Fix + +The merge of `origin/main` into `feature/ts-pr-765-port` required: + +1. **Code fix** - Handle `Omit` type properly in reconnection logic +2. **Test fix** - Update assertions from `is None` to `not in dict` +3. **Syntax check** - Verify Python syntax is valid +4. **Type check** - Ensure mypy/pyright pass + +See commit history for the resolution pattern. diff --git a/src/runloop_api_client/resources/axons/axons.py b/src/runloop_api_client/resources/axons/axons.py index 31e6574e5..6dab5e7a6 100644 --- a/src/runloop_api_client/resources/axons/axons.py +++ b/src/runloop_api_client/resources/axons/axons.py @@ -311,7 +311,12 @@ def subscribe_sse( def create_stream(last_sequence: str | None) -> Stream[AxonEventView]: # Use user-provided after_sequence for initial stream, then use last_sequence for reconnections - sequence_to_use = after_sequence if last_sequence is None else int(last_sequence) + if last_sequence is not None: + sequence_to_use = int(last_sequence) + elif not isinstance(after_sequence, Omit): + sequence_to_use = after_sequence + else: + sequence_to_use = omit return self._get( path_template("/v1/axons/{id}/subscribe/sse", id=id), options=make_request_options( @@ -621,7 +626,12 @@ async def subscribe_sse( async def create_stream(last_sequence: str | None) -> AsyncStream[AxonEventView]: # Use user-provided after_sequence for initial stream, then use last_sequence for reconnections - sequence_to_use = after_sequence if last_sequence is None else int(last_sequence) + if last_sequence is not None: + sequence_to_use = int(last_sequence) + elif not isinstance(after_sequence, Omit): + sequence_to_use = after_sequence + else: + sequence_to_use = omit return await self._get( path_template("/v1/axons/{id}/subscribe/sse", id=id), options=make_request_options( diff --git a/tests/test_axon_sse_reconnect.py b/tests/test_axon_sse_reconnect.py index cc2575c64..5fafef802 100644 --- a/tests/test_axon_sse_reconnect.py +++ b/tests/test_axon_sse_reconnect.py @@ -161,7 +161,7 @@ def empty_iter(_self: object) -> Iterator[object]: assert options["headers"]["Accept"] == "text/event-stream" assert options["headers"]["X-Custom"] == "value" assert options["params"]["param"] == "value" - assert options["params"]["after_sequence"] is None + assert "after_sequence" not in options["params"] assert options["timeout"] == 30.0 @@ -322,7 +322,7 @@ async def mock_iter(): assert options["headers"]["Accept"] == "text/event-stream" assert options["headers"]["X-Custom"] == "value" assert options["params"]["param"] == "value" - assert options["params"]["after_sequence"] is None + assert "after_sequence" not in options["params"] assert options["timeout"] == 30.0 From 498a24a51d1ead3ca1c83fa80205b9a7422d7dbc Mon Sep 17 00:00:00 2001 From: Tony Deng Date: Thu, 2 Apr 2026 13:39:35 -0700 Subject: [PATCH 09/10] fix: resolve mypy type errors in axons.py Fixed 3 mypy errors: 1. Removed duplicate import of axon_subscribe_sse_params on line 32 2. Added explicit type annotation for sequence_to_use: int | Omit (sync) 3. Added explicit type annotation for sequence_to_use: int | Omit (async) This allows the variable to hold either an int or an Omit value, which is then properly filtered by the transform() function. Co-Authored-By: Claude Sonnet 4.5 --- src/runloop_api_client/resources/axons/axons.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/runloop_api_client/resources/axons/axons.py b/src/runloop_api_client/resources/axons/axons.py index 6dab5e7a6..22ac53960 100644 --- a/src/runloop_api_client/resources/axons/axons.py +++ b/src/runloop_api_client/resources/axons/axons.py @@ -29,7 +29,6 @@ from ..._constants import RAW_RESPONSE_HEADER from ..._streaming import Stream, AsyncStream, ReconnectingStream, AsyncReconnectingStream from ...pagination import SyncAxonsCursorIDPage, AsyncAxonsCursorIDPage -from ...types.axons import axon_subscribe_sse_params from ..._base_client import AsyncPaginator, make_request_options from ...types.axon_view import AxonView from ...types.axon_event_view import AxonEventView @@ -311,6 +310,7 @@ def subscribe_sse( def create_stream(last_sequence: str | None) -> Stream[AxonEventView]: # Use user-provided after_sequence for initial stream, then use last_sequence for reconnections + sequence_to_use: int | Omit if last_sequence is not None: sequence_to_use = int(last_sequence) elif not isinstance(after_sequence, Omit): @@ -626,6 +626,7 @@ async def subscribe_sse( async def create_stream(last_sequence: str | None) -> AsyncStream[AxonEventView]: # Use user-provided after_sequence for initial stream, then use last_sequence for reconnections + sequence_to_use: int | Omit if last_sequence is not None: sequence_to_use = int(last_sequence) elif not isinstance(after_sequence, Omit): From 427cd6c559f614850ed528d856220d56d77698f7 Mon Sep 17 00:00:00 2001 From: Tony Deng Date: Thu, 2 Apr 2026 14:30:00 -0700 Subject: [PATCH 10/10] chore: apply code formatting and update workflow docs - Run ruff format to fix code style in axons.py - Update AGENTS.md to include format step in workflow - Document that format should be run after lint and before tests Co-Authored-By: Claude Sonnet 4.5 --- AGENTS.md | 16 +++++++++++++--- src/runloop_api_client/resources/axons/axons.py | 2 -- uv.lock | 2 +- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 51d76f4d4..d75d9f1c9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -44,7 +44,17 @@ After making changes, run the linter again to catch any issues: Fix any new errors or warnings that appear. -### 5. Run Tests +### 5. Run Format + +Apply code formatting to ensure consistent style: + +```bash +ruff format . +``` + +This auto-formats all Python files to match the project's style guidelines. + +### 6. Run Tests Run the test suite to ensure everything works: @@ -59,9 +69,9 @@ uv run pytest tests/test_axon_sse_reconnect.py -xvs uv run pytest tests/smoketests/ -m smoketest ``` -### 6. Commit Changes +### 7. Commit Changes -Once lint and tests pass, commit your changes: +Once lint, format, and tests pass, commit your changes: ```bash git add -A diff --git a/src/runloop_api_client/resources/axons/axons.py b/src/runloop_api_client/resources/axons/axons.py index 22ac53960..8d9b5c2b6 100644 --- a/src/runloop_api_client/resources/axons/axons.py +++ b/src/runloop_api_client/resources/axons/axons.py @@ -352,7 +352,6 @@ def get_sequence(item: AxonEventView) -> str | None: ) - class AsyncAxonsResource(AsyncAPIResource): @cached_property def sql(self) -> AsyncSqlResource: @@ -668,7 +667,6 @@ def get_sequence(item: AxonEventView) -> str | None: ) - class AxonsResourceWithRawResponse: def __init__(self, axons: AxonsResource) -> None: self._axons = axons diff --git a/uv.lock b/uv.lock index 34fb7a894..afe2f32e5 100644 --- a/uv.lock +++ b/uv.lock @@ -2386,7 +2386,7 @@ wheels = [ [[package]] name = "runloop-api-client" -version = "1.14.1" +version = "1.15.0" source = { editable = "." } dependencies = [ { name = "anyio" },