diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..d75d9f1c9 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,154 @@ +# Agent Workflow Guidelines + +## Making Code Changes + +When making changes to the codebase, follow this workflow to ensure code quality: + +### 1. Run Lint Before Changes + +Before making any code changes, run the linter to establish a baseline: + +```bash +./scripts/lint +``` + +This runs: +- `pyright` - Type checking +- `mypy` - Additional type checking +- `ruff check` - Code linting +- `ruff format --check` - Format checking +- Import validation + +### 2. Make Your Changes + +Make the necessary code changes, ensuring you: +- Follow existing code patterns +- Update type annotations +- Handle edge cases (e.g., `Omit`, `NotGiven`, `None`) +- Maintain backward compatibility + +### 3. Update Tests + +Update or add tests for your changes: +- Unit tests in `tests/` +- Smoke tests in `tests/smoketests/` +- Ensure tests match the new behavior + +### 4. Run Lint After Changes + +After making changes, run the linter again to catch any issues: + +```bash +./scripts/lint +``` + +Fix any new errors or warnings that appear. + +### 5. Run Format + +Apply code formatting to ensure consistent style: + +```bash +ruff format . +``` + +This auto-formats all Python files to match the project's style guidelines. + +### 6. Run Tests + +Run the test suite to ensure everything works: + +```bash +# Run all tests +uv run pytest + +# Run specific tests +uv run pytest tests/test_axon_sse_reconnect.py -xvs + +# Run smoke tests +uv run pytest tests/smoketests/ -m smoketest +``` + +### 7. Commit Changes + +Once lint, format, and tests pass, commit your changes: + +```bash +git add -A +git commit -m "type: description + +Detailed explanation of changes + +Co-Authored-By: Claude Sonnet 4.5 " +``` + +## Common Patterns + +### Handling Optional Parameters + +When dealing with parameters that can be omitted: + +```python +# Parameter definition +def method(self, param: int | Omit = omit): + ... + +# In implementation, check for Omit before using +if not isinstance(param, Omit): + use_param(param) +else: + # param was omitted, use default behavior + pass +``` + +### Type-Safe Transformations + +The `transform` function automatically filters out `Omit` and `NotGiven` values: + +```python +# This dict with omitted values +{"field": omit, "other": 123} + +# Becomes this after transform +{"other": 123} +``` + +### Testing with Mocks + +When testing methods that call `_get` or similar: + +```python +with patch.object(client.resource, "_get") as mock_get: + mock_stream = Mock(spec=Stream) + mock_get.return_value = mock_stream + + # Your test code + result = client.resource.method() + + # Verify the call + call_args = mock_get.call_args + options = call_args.kwargs["options"] + assert options["params"]["field"] == expected_value +``` + +## Merge Conflict Resolution + +When merging branches: + +1. **Understand both changes** - Read the diff from both sides +2. **Combine features intelligently** - Don't just pick one side +3. **Update tests** - Tests may need adjustments for merged behavior +4. **Fix type errors** - Merges can introduce type mismatches +5. **Validate syntax** - Run `python3 -m py_compile` on changed files +6. **Run full lint** - Ensure everything passes + +## Example: Recent Merge Fix + +The merge of `origin/main` into `feature/ts-pr-765-port` required: + +1. **Code fix** - Handle `Omit` type properly in reconnection logic +2. **Test fix** - Update assertions from `is None` to `not in dict` +3. **Syntax check** - Verify Python syntax is valid +4. **Type check** - Ensure mypy/pyright pass + +See commit history for the resolution pattern. diff --git a/scripts/generate_examples_md.py b/scripts/generate_examples_md.py index 9e47e9cae..43e4f81fe 100644 --- a/scripts/generate_examples_md.py +++ b/scripts/generate_examples_md.py @@ -11,10 +11,10 @@ import re import sys import argparse -from typing import Any +from typing import Any, cast from pathlib import Path -import frontmatter # type: ignore[import-untyped] +import frontmatter # type: ignore[import-not-found, import-untyped] ROOT = Path(__file__).parent.parent EXAMPLES_DIR = ROOT / "examples" @@ -38,7 +38,7 @@ def parse_example(path: Path) -> dict[str, Any]: raise ValueError(f"{path}: docstring must start with frontmatter (---)") try: - post = frontmatter.loads(docstring) + post = cast(Any, frontmatter).loads(docstring) return dict(post.metadata) except Exception as e: raise ValueError(f"{path}: invalid frontmatter: {e}") from e diff --git a/src/runloop_api_client/_utils/_compat.py b/src/runloop_api_client/_utils/_compat.py index 2c70b299c..38820cabc 100644 --- a/src/runloop_api_client/_utils/_compat.py +++ b/src/runloop_api_client/_utils/_compat.py @@ -2,7 +2,7 @@ import sys import typing_extensions -from typing import Any, Type, Union, Literal, Optional +from typing import Any, Type, Union, Literal, Optional, cast from datetime import date, datetime from typing_extensions import get_args as _get_args, get_origin as _get_origin @@ -34,7 +34,8 @@ def is_typeddict(tp: Type[Any]) -> bool: def is_literal_type(tp: Type[Any]) -> bool: - return get_origin(tp) in _LITERAL_TYPES + origin = get_origin(tp) + return cast(Any, origin) in _LITERAL_TYPES def parse_date(value: Union[date, StrBytesIntFloat]) -> date: diff --git a/src/runloop_api_client/_utils/_utils.py b/src/runloop_api_client/_utils/_utils.py index eec7f4a1f..ed08804e2 100644 --- a/src/runloop_api_client/_utils/_utils.py +++ b/src/runloop_api_client/_utils/_utils.py @@ -372,8 +372,7 @@ def file_from_path(path: str) -> FileTypes: def get_required_header(headers: HeadersLike, header: str) -> str: lower_header = header.lower() if is_mapping_t(headers): - # mypy doesn't understand the type narrowing here - for k, v in headers.items(): # type: ignore + for k, v in cast(Mapping[str, object], headers).items(): if k.lower() == lower_header and isinstance(v, str): return v diff --git a/src/runloop_api_client/lib/cancellation.py b/src/runloop_api_client/lib/cancellation.py new file mode 100644 index 000000000..3e31a76b8 --- /dev/null +++ b/src/runloop_api_client/lib/cancellation.py @@ -0,0 +1,103 @@ +"""Cancellation support for polling operations.""" + +from __future__ import annotations + +import asyncio +import threading +from typing import TYPE_CHECKING + +from .._exceptions import RunloopError + +if TYPE_CHECKING: + pass + +__all__ = ["PollingCancelled", "CancellationToken"] + + +class PollingCancelled(RunloopError): + """Exception raised when a polling operation is cancelled.""" + + pass + + +class CancellationToken: + """Thread-safe cancellation token for polling operations. + + Similar to JavaScript's AbortSignal. Works in both sync and async contexts. + + Example (sync): + >>> token = CancellationToken() + >>> # In another thread: + >>> token.cancel() + >>> # In polling code: + >>> token.raise_if_cancelled() # Raises PollingCancelled + + Example (async): + >>> token = CancellationToken() + >>> # In another task: + >>> token.cancel() + >>> # In async polling code: + >>> await asyncio.wait_for(token.async_event.wait(), timeout=1.0) + """ + + def __init__(self) -> None: + """Create a new cancellation token.""" + self._cancelled = False + self._sync_event = threading.Event() + self._async_event: asyncio.Event | None = None + self._lock = threading.Lock() + + def cancel(self) -> None: + """Mark this token as cancelled. + + Thread-safe and can be called multiple times. Sets both sync and async events. + """ + with self._lock: + if self._cancelled: + return + self._cancelled = True + self._sync_event.set() + if self._async_event is not None: + self._async_event.set() + + def is_cancelled(self) -> bool: + """Check if this token has been cancelled. + + Returns: + True if cancel() has been called, False otherwise. + """ + return self._cancelled + + def raise_if_cancelled(self) -> None: + """Raise PollingCancelled if this token has been cancelled. + + Raises: + PollingCancelled: If cancel() has been called. + """ + if self._cancelled: + raise PollingCancelled("Polling operation was cancelled") + + @property + def sync_event(self) -> threading.Event: + """Get the synchronous event for cancellation checking. + + Returns: + threading.Event that is set when cancel() is called. + """ + return self._sync_event + + @property + def async_event(self) -> asyncio.Event: + """Get the asynchronous event for cancellation checking. + + Lazily creates the async event on first access. If cancel() was already called, + the event will be set immediately. + + Returns: + asyncio.Event that is set when cancel() is called. + """ + if self._async_event is None: + self._async_event = asyncio.Event() + if self._cancelled: + self._async_event.set() + return self._async_event diff --git a/src/runloop_api_client/lib/polling.py b/src/runloop_api_client/lib/polling.py index 899d2a9bf..95440f235 100644 --- a/src/runloop_api_client/lib/polling.py +++ b/src/runloop_api_client/lib/polling.py @@ -2,6 +2,8 @@ from typing import Any, TypeVar, Callable, Optional from dataclasses import dataclass +from .cancellation import CancellationToken + T = TypeVar("T") @@ -27,6 +29,7 @@ def poll_until( is_terminal: Callable[[T], bool], config: Optional[PollingConfig] = None, on_error: Optional[Callable[[Exception], T]] = None, + cancellation_token: Optional[CancellationToken] = None, ) -> T: """ Poll until a condition is met or timeout/max attempts are reached. @@ -37,12 +40,14 @@ def poll_until( config: Optional polling configuration on_error: Optional error handler that can return a value to continue polling or re-raise the exception to stop polling + cancellation_token: Optional token to cancel the polling operation Returns: The final state of the polled object Raises: PollingTimeout: When max attempts or timeout is reached + PollingCancelled: If cancellation_token.cancel() is called """ if config is None: config = PollingConfig() @@ -52,6 +57,10 @@ def poll_until( last_result = None while True: + # Check for cancellation before each iteration + if cancellation_token is not None: + cancellation_token.raise_if_cancelled() + try: last_result = retriever() except Exception as e: @@ -72,4 +81,9 @@ def poll_until( if elapsed >= config.timeout_seconds: raise PollingTimeout(f"Exceeded timeout of {config.timeout_seconds} seconds", last_result) - time.sleep(config.interval_seconds) + # Cancellable sleep + if cancellation_token is not None: + if cancellation_token.sync_event.wait(timeout=config.interval_seconds): + cancellation_token.raise_if_cancelled() + else: + time.sleep(config.interval_seconds) diff --git a/src/runloop_api_client/lib/polling_async.py b/src/runloop_api_client/lib/polling_async.py index 7ba192e86..e462991e2 100644 --- a/src/runloop_api_client/lib/polling_async.py +++ b/src/runloop_api_client/lib/polling_async.py @@ -1,8 +1,10 @@ import time import asyncio from typing import Union, TypeVar, Callable, Optional, Awaitable +from contextlib import suppress from .polling import PollingConfig, PollingTimeout +from .cancellation import CancellationToken T = TypeVar("T") @@ -12,6 +14,7 @@ async def async_poll_until( is_terminal: Callable[[T], bool], config: Optional[PollingConfig] = None, on_error: Optional[Callable[[Exception], T]] = None, + cancellation_token: Optional[CancellationToken] = None, ) -> T: """ Poll until a condition is met or timeout/max attempts are reached. @@ -22,12 +25,14 @@ async def async_poll_until( config: Optional polling configuration on_error: Optional error handler that can return a value to continue polling or re-raise the exception to stop polling + cancellation_token: Optional token to cancel the polling operation Returns: The final state of the polled object Raises: PollingTimeout: When max attempts or timeout is reached + PollingCancelled: If cancellation_token.cancel() is called """ if config is None: config = PollingConfig() @@ -37,6 +42,10 @@ async def async_poll_until( last_result: Union[T, None] = None while True: + # Check for cancellation before each iteration + if cancellation_token is not None: + cancellation_token.raise_if_cancelled() + try: last_result = await retriever() except Exception as e: @@ -57,4 +66,15 @@ async def async_poll_until( if elapsed >= config.timeout_seconds: raise PollingTimeout(f"Exceeded timeout of {config.timeout_seconds} seconds", last_result) - await asyncio.sleep(config.interval_seconds) + # Cancellable async sleep + if cancellation_token is not None: + wait_task = asyncio.create_task(cancellation_token.async_event.wait()) + try: + await asyncio.wait_for(wait_task, timeout=config.interval_seconds) + cancellation_token.raise_if_cancelled() + except asyncio.TimeoutError: + wait_task.cancel() + with suppress(asyncio.CancelledError): + await wait_task + else: + await asyncio.sleep(config.interval_seconds) diff --git a/src/runloop_api_client/resources/axons/axons.py b/src/runloop_api_client/resources/axons/axons.py index 977d2dd93..8d9b5c2b6 100644 --- a/src/runloop_api_client/resources/axons/axons.py +++ b/src/runloop_api_client/resources/axons/axons.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Optional +from typing import Optional, cast from typing_extensions import Literal import httpx @@ -26,7 +26,8 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from ..._streaming import Stream, AsyncStream +from ..._constants import RAW_RESPONSE_HEADER +from ..._streaming import Stream, AsyncStream, ReconnectingStream, AsyncReconnectingStream from ...pagination import SyncAxonsCursorIDPage, AsyncAxonsCursorIDPage from ..._base_client import AsyncPaginator, make_request_options from ...types.axon_view import AxonView @@ -270,6 +271,8 @@ def subscribe_sse( """ [Beta] Subscribe to an axon event stream via server-sent events. + Automatically reconnects on timeout, resuming from last received event. + Args: after_sequence: Sequence number after which to start streaming. Events with sequence > this value are returned. If unset, replay from the beginning. @@ -286,20 +289,66 @@ def subscribe_sse( raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") default_headers: Headers = {"Accept": "text/event-stream"} merged_headers = default_headers if extra_headers is None else {**default_headers, **extra_headers} - return self._get( - path_template("/v1/axons/{id}/subscribe/sse", id=id), - options=make_request_options( - extra_headers=merged_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=maybe_transform( - {"after_sequence": after_sequence}, axon_subscribe_sse_params.AxonSubscribeSseParams + + # Check if user wants raw response (opt-out of reconnection) + if extra_headers is not None and RAW_RESPONSE_HEADER in extra_headers: + return self._get( + path_template("/v1/axons/{id}/subscribe/sse", id=id), + options=make_request_options( + extra_headers=merged_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + {"after_sequence": after_sequence}, axon_subscribe_sse_params.AxonSubscribeSseParams + ), + ), + cast_to=AxonEventView, + stream=True, + stream_cls=Stream[AxonEventView], + ) + + def create_stream(last_sequence: str | None) -> Stream[AxonEventView]: + # Use user-provided after_sequence for initial stream, then use last_sequence for reconnections + sequence_to_use: int | Omit + if last_sequence is not None: + sequence_to_use = int(last_sequence) + elif not isinstance(after_sequence, Omit): + sequence_to_use = after_sequence + else: + sequence_to_use = omit + return self._get( + path_template("/v1/axons/{id}/subscribe/sse", id=id), + options=make_request_options( + extra_headers=merged_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + {"after_sequence": sequence_to_use}, + axon_subscribe_sse_params.AxonSubscribeSseParams, + ), ), + cast_to=AxonEventView, + stream=True, + stream_cls=Stream[AxonEventView], + ) + + initial_stream = create_stream(None) + + def get_sequence(item: AxonEventView) -> str | None: + value = getattr(item, "sequence", None) + if value is None: + return None + return str(value) + + return cast( + Stream[AxonEventView], + ReconnectingStream( + current_stream=initial_stream, + stream_creator=create_stream, + get_offset=get_sequence, ), - cast_to=AxonEventView, - stream=True, - stream_cls=Stream[AxonEventView], ) @@ -537,6 +586,8 @@ async def subscribe_sse( """ [Beta] Subscribe to an axon event stream via server-sent events. + Automatically reconnects on timeout, resuming from last received event. + Args: after_sequence: Sequence number after which to start streaming. Events with sequence > this value are returned. If unset, replay from the beginning. @@ -553,20 +604,66 @@ async def subscribe_sse( raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") default_headers: Headers = {"Accept": "text/event-stream"} merged_headers = default_headers if extra_headers is None else {**default_headers, **extra_headers} - return await self._get( - path_template("/v1/axons/{id}/subscribe/sse", id=id), - options=make_request_options( - extra_headers=merged_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=await async_maybe_transform( - {"after_sequence": after_sequence}, axon_subscribe_sse_params.AxonSubscribeSseParams + + # Check if user wants raw response (opt-out of reconnection) + if extra_headers is not None and RAW_RESPONSE_HEADER in extra_headers: + return await self._get( + path_template("/v1/axons/{id}/subscribe/sse", id=id), + options=make_request_options( + extra_headers=merged_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform( + {"after_sequence": after_sequence}, axon_subscribe_sse_params.AxonSubscribeSseParams + ), + ), + cast_to=AxonEventView, + stream=True, + stream_cls=AsyncStream[AxonEventView], + ) + + async def create_stream(last_sequence: str | None) -> AsyncStream[AxonEventView]: + # Use user-provided after_sequence for initial stream, then use last_sequence for reconnections + sequence_to_use: int | Omit + if last_sequence is not None: + sequence_to_use = int(last_sequence) + elif not isinstance(after_sequence, Omit): + sequence_to_use = after_sequence + else: + sequence_to_use = omit + return await self._get( + path_template("/v1/axons/{id}/subscribe/sse", id=id), + options=make_request_options( + extra_headers=merged_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform( + {"after_sequence": sequence_to_use}, + axon_subscribe_sse_params.AxonSubscribeSseParams, + ), ), + cast_to=AxonEventView, + stream=True, + stream_cls=AsyncStream[AxonEventView], + ) + + initial_stream = await create_stream(None) + + def get_sequence(item: AxonEventView) -> str | None: + value = getattr(item, "sequence", None) + if value is None: + return None + return str(value) + + return cast( + AsyncStream[AxonEventView], + AsyncReconnectingStream( + current_stream=initial_stream, + stream_creator=create_stream, + get_offset=get_sequence, ), - cast_to=AxonEventView, - stream=True, - stream_cls=AsyncStream[AxonEventView], ) diff --git a/src/runloop_api_client/resources/blueprints.py b/src/runloop_api_client/resources/blueprints.py index 7e5b09939..295cf0785 100644 --- a/src/runloop_api_client/resources/blueprints.py +++ b/src/runloop_api_client/resources/blueprints.py @@ -28,6 +28,7 @@ from .._exceptions import RunloopError from ..lib.polling import PollingConfig, poll_until from .._base_client import AsyncPaginator, make_request_options +from ..lib.cancellation import CancellationToken from ..lib.polling_async import async_poll_until from .._utils._validation import ValidationNotification from ..types.blueprint_view import BlueprintView @@ -41,6 +42,7 @@ # Type for request arguments that combine polling config with additional request options class BlueprintRequestArgs(TypedDict, total=False): polling_config: PollingConfig | None + cancellation_token: CancellationToken | None # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None @@ -280,6 +282,7 @@ def await_build_complete( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -292,6 +295,7 @@ def await_build_complete( Args: id: The ID of the blueprint to wait for polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -302,6 +306,7 @@ def await_build_complete( Raises: PollingTimeout: If polling times out before blueprint is built + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If blueprint enters a non-built terminal state """ @@ -313,7 +318,12 @@ def retrieve_blueprint() -> BlueprintView: def is_done_building(blueprint: BlueprintView) -> bool: return blueprint.status not in ["queued", "building", "provisioning"] - blueprint = poll_until(retrieve_blueprint, is_done_building, polling_config) + blueprint = poll_until( + retrieve_blueprint, + is_done_building, + polling_config, + cancellation_token=cancellation_token, + ) if blueprint.status != "build_complete": raise RunloopError(f"Blueprint entered non-built terminal state: {blueprint.status}") @@ -338,6 +348,7 @@ def create_and_await_build_complete( services: Optional[Iterable[blueprint_create_params.Service]] | Omit = omit, system_setup_commands: Optional[SequenceNotStr[str]] | Omit = omit, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -353,12 +364,14 @@ def create_and_await_build_complete( Args: See the `create` method for detailed documentation. polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation Returns: The built blueprint Raises: PollingTimeout: If polling times out before blueprint is built + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If blueprint enters a non-built terminal state """ # Pass all create_args to the underlying create method @@ -387,6 +400,7 @@ def create_and_await_build_complete( return self.await_build_complete( blueprint.id, polling_config=polling_config, + cancellation_token=cancellation_token, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, @@ -960,6 +974,7 @@ async def await_build_complete( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -972,6 +987,7 @@ async def await_build_complete( Args: id: The ID of the blueprint to wait for polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -982,6 +998,7 @@ async def await_build_complete( Raises: PollingTimeout: If polling times out before blueprint is built + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If blueprint enters a non-built terminal state """ @@ -993,7 +1010,12 @@ async def retrieve_blueprint() -> BlueprintView: def is_done_building(blueprint: BlueprintView) -> bool: return blueprint.status not in ["queued", "building", "provisioning"] - blueprint = await async_poll_until(retrieve_blueprint, is_done_building, polling_config) + blueprint = await async_poll_until( + retrieve_blueprint, + is_done_building, + polling_config, + cancellation_token=cancellation_token, + ) if blueprint.status != "build_complete": raise RunloopError(f"Blueprint entered non-built terminal state: {blueprint.status}") @@ -1018,6 +1040,7 @@ async def create_and_await_build_complete( services: Optional[Iterable[blueprint_create_params.Service]] | Omit = omit, system_setup_commands: Optional[SequenceNotStr[str]] | Omit = omit, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -1033,12 +1056,14 @@ async def create_and_await_build_complete( Args: See the `create` method for detailed documentation. polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation Returns: The built blueprint Raises: PollingTimeout: If polling times out before blueprint is built + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If blueprint enters a non-built terminal state """ # Pass all create_args to the underlying create method @@ -1067,6 +1092,7 @@ async def create_and_await_build_complete( return await self.await_build_complete( blueprint.id, polling_config=polling_config, + cancellation_token=cancellation_token, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, diff --git a/src/runloop_api_client/resources/devboxes/devboxes.py b/src/runloop_api_client/resources/devboxes/devboxes.py index 3e46b17b6..e383e8abe 100644 --- a/src/runloop_api_client/resources/devboxes/devboxes.py +++ b/src/runloop_api_client/resources/devboxes/devboxes.py @@ -97,6 +97,7 @@ DiskSnapshotsResourceWithStreamingResponse, AsyncDiskSnapshotsResourceWithStreamingResponse, ) +from ...lib.cancellation import CancellationToken from ...lib.polling_async import async_poll_until from ...types.devbox_view import DevboxView from ...types.tunnel_view import TunnelView @@ -401,12 +402,14 @@ def await_running( *, # Use polling_config to configure the "long" polling behavior. polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, ) -> DevboxView: """Wait for a devbox to be in running state. Args: id: The ID of the devbox to wait for config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -417,6 +420,7 @@ def await_running( Raises: PollingTimeout: If polling times out before devbox is running + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If devbox enters a non-running terminal state """ @@ -443,7 +447,13 @@ def handle_timeout_error(error: Exception) -> DevboxView: def is_done_booting(devbox: DevboxView) -> bool: return devbox.status not in DEVBOX_BOOTING_STATES - devbox = poll_until(wait_for_devbox_status, is_done_booting, polling_config, handle_timeout_error) + devbox = poll_until( + wait_for_devbox_status, + is_done_booting, + polling_config, + handle_timeout_error, + cancellation_token=cancellation_token, + ) if devbox.status != "running": raise RunloopError(f"Devbox entered non-running terminal state: {devbox.status}") @@ -455,18 +465,21 @@ def await_suspended( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, ) -> DevboxView: """Wait for a devbox to reach the suspended state. Args: id: The ID of the devbox to wait for. polling_config: Optional polling configuration. + cancellation_token: Token to cancel the wait operation. Returns: The devbox in the suspended state. Raises: PollingTimeout: If polling times out before the devbox is suspended. + PollingCancelled: If cancellation_token.cancel() is called. RunloopError: If the devbox enters a non-suspended terminal state. """ @@ -487,7 +500,13 @@ def handle_timeout_error(error: Exception) -> DevboxView: def is_terminal_state(devbox: DevboxView) -> bool: return devbox.status in DEVBOX_TERMINAL_STATES - devbox = poll_until(wait_for_devbox_status, is_terminal_state, polling_config, handle_timeout_error) + devbox = poll_until( + wait_for_devbox_status, + is_terminal_state, + polling_config, + handle_timeout_error, + cancellation_token=cancellation_token, + ) if devbox.status != "suspended": raise RunloopError(f"Devbox entered non-suspended terminal state: {devbox.status}") @@ -510,6 +529,7 @@ def create_and_await_running( mounts: Optional[Iterable[Mount]] | Omit = omit, name: Optional[str] | Omit = omit, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, repo_connection_id: Optional[str] | Omit = omit, secrets: Optional[Dict[str, str]] | Omit = omit, snapshot_id: Optional[str] | Omit = omit, @@ -535,6 +555,7 @@ def create_and_await_running( Raises: PollingTimeout: If polling times out before devbox is running + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If devbox enters a non-running terminal state """ # Pass all create_args to the underlying create method @@ -565,6 +586,7 @@ def create_and_await_running( return self.await_running( devbox.id, polling_config=polling_config, + cancellation_token=cancellation_token, ) def list( @@ -2001,6 +2023,7 @@ async def create_and_await_running( mounts: Optional[Iterable[Mount]] | Omit = omit, name: Optional[str] | Omit = omit, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, repo_connection_id: Optional[str] | Omit = omit, secrets: Optional[Dict[str, str]] | Omit = omit, snapshot_id: Optional[str] | Omit = omit, @@ -2020,12 +2043,14 @@ async def create_and_await_running( Args: See the `create` method for detailed documentation. polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation Returns: The devbox in running state Raises: PollingTimeout: If polling times out before devbox is running + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If devbox enters a non-running terminal state """ @@ -2057,6 +2082,7 @@ async def create_and_await_running( return await self.await_running( devbox.id, polling_config=polling_config, + cancellation_token=cancellation_token, ) async def await_running( @@ -2065,12 +2091,14 @@ async def await_running( *, # Use polling_config to configure the "long" polling behavior. polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, ) -> DevboxView: """Wait for a devbox to be in running state. Args: id: The ID of the devbox to wait for config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -2081,6 +2109,7 @@ async def await_running( Raises: PollingTimeout: If polling times out before devbox is running + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If devbox enters a non-running terminal state """ @@ -2105,7 +2134,12 @@ async def wait_for_devbox_status() -> DevboxView: def is_done_booting(devbox: DevboxView) -> bool: return devbox.status not in DEVBOX_BOOTING_STATES - devbox = await async_poll_until(wait_for_devbox_status, is_done_booting, polling_config) + devbox = await async_poll_until( + wait_for_devbox_status, + is_done_booting, + polling_config, + cancellation_token=cancellation_token, + ) if devbox.status != "running": raise RunloopError(f"Devbox entered non-running terminal state: {devbox.status}") @@ -2117,18 +2151,21 @@ async def await_suspended( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, ) -> DevboxView: """Wait for a devbox to reach the suspended state. Args: id: The ID of the devbox to wait for. polling_config: Optional polling configuration. + cancellation_token: Token to cancel the wait operation. Returns: The devbox in the suspended state. Raises: PollingTimeout: If polling times out before the devbox is suspended. + PollingCancelled: If cancellation_token.cancel() is called. RunloopError: If the devbox enters a non-suspended terminal state. """ @@ -2147,7 +2184,12 @@ async def wait_for_devbox_status() -> DevboxView: def is_terminal_state(devbox: DevboxView) -> bool: return devbox.status in DEVBOX_TERMINAL_STATES - devbox = await async_poll_until(wait_for_devbox_status, is_terminal_state, polling_config) + devbox = await async_poll_until( + wait_for_devbox_status, + is_terminal_state, + polling_config, + cancellation_token=cancellation_token, + ) if devbox.status != "suspended": raise RunloopError(f"Devbox entered non-suspended terminal state: {devbox.status}") diff --git a/src/runloop_api_client/resources/devboxes/disk_snapshots.py b/src/runloop_api_client/resources/devboxes/disk_snapshots.py index c4f723359..76059ee2a 100644 --- a/src/runloop_api_client/resources/devboxes/disk_snapshots.py +++ b/src/runloop_api_client/resources/devboxes/disk_snapshots.py @@ -21,6 +21,7 @@ from ...lib.polling import PollingConfig, poll_until from ..._base_client import AsyncPaginator, make_request_options from ...types.devboxes import disk_snapshot_list_params, disk_snapshot_update_params +from ...lib.cancellation import CancellationToken from ...lib.polling_async import async_poll_until from ...types.devbox_snapshot_view import DevboxSnapshotView from ...types.devboxes.devbox_snapshot_async_status_view import DevboxSnapshotAsyncStatusView @@ -256,12 +257,31 @@ def await_completed( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DevboxSnapshotAsyncStatusView: - """Wait for a disk snapshot operation to complete.""" + """Wait for a disk snapshot operation to complete. + + Args: + id: The ID of the disk snapshot to wait for + polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation + extra_headers: Send extra headers + extra_query: Add additional query parameters to the request + extra_body: Add additional JSON properties to the request + timeout: Override the client-level default timeout for this request, in seconds + + Returns: + The completed snapshot status + + Raises: + PollingTimeout: If polling times out before snapshot completes + PollingCancelled: If cancellation_token.cancel() is called + RunloopError: If snapshot enters error state + """ if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") @@ -275,6 +295,7 @@ def is_terminal(result: DevboxSnapshotAsyncStatusView) -> bool: ), is_terminal, polling_config, + cancellation_token=cancellation_token, ) if status.status == "error": @@ -512,12 +533,31 @@ async def await_completed( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> DevboxSnapshotAsyncStatusView: - """Wait asynchronously for a disk snapshot operation to complete.""" + """Wait asynchronously for a disk snapshot operation to complete. + + Args: + id: The ID of the disk snapshot to wait for + polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation + extra_headers: Send extra headers + extra_query: Add additional query parameters to the request + extra_body: Add additional JSON properties to the request + timeout: Override the client-level default timeout for this request, in seconds + + Returns: + The completed snapshot status + + Raises: + PollingTimeout: If polling times out before snapshot completes + PollingCancelled: If cancellation_token.cancel() is called + RunloopError: If snapshot enters error state + """ if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") @@ -531,6 +571,7 @@ def is_terminal(result: DevboxSnapshotAsyncStatusView) -> bool: ), is_terminal, polling_config, + cancellation_token=cancellation_token, ) if status.status == "error": diff --git a/src/runloop_api_client/resources/devboxes/executions.py b/src/runloop_api_client/resources/devboxes/executions.py index ff7638798..9dd394365 100755 --- a/src/runloop_api_client/resources/devboxes/executions.py +++ b/src/runloop_api_client/resources/devboxes/executions.py @@ -32,6 +32,7 @@ execution_stream_stderr_updates_params, execution_stream_stdout_updates_params, ) +from ...lib.cancellation import CancellationToken from ...lib.polling_async import async_poll_until from ...types.devbox_send_std_in_result import DevboxSendStdInResult from ...types.devbox_execution_detail_view import DevboxExecutionDetailView @@ -124,6 +125,7 @@ def await_completed( *, # Use polling_config to configure the "long" polling behavior. polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, ) -> DevboxAsyncExecutionDetailView: """Wait for an execution to complete. @@ -131,6 +133,7 @@ def await_completed( execution_id: The ID of the execution to wait for id: The ID of the devbox config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -141,6 +144,7 @@ def await_completed( Raises: PollingTimeout: If polling times out before execution completes + PollingCancelled: If cancellation_token.cancel() is called """ def wait_for_execution_status() -> DevboxAsyncExecutionDetailView: @@ -165,7 +169,13 @@ def handle_timeout_error(error: Exception) -> DevboxAsyncExecutionDetailView: def is_done(execution: DevboxAsyncExecutionDetailView) -> bool: return execution.status == "completed" - return poll_until(wait_for_execution_status, is_done, polling_config, handle_timeout_error) + return poll_until( + wait_for_execution_status, + is_done, + polling_config, + handle_timeout_error, + cancellation_token=cancellation_token, + ) def execute_async( self, @@ -670,6 +680,7 @@ async def await_completed( devbox_id: str, # Use polling_config to configure the "long" polling behavior. polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, ) -> DevboxAsyncExecutionDetailView: """Wait for an execution to complete. @@ -677,6 +688,7 @@ async def await_completed( execution_id: The ID of the execution to wait for id: The ID of the devbox polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -687,6 +699,7 @@ async def await_completed( Raises: PollingTimeout: If polling times out before execution completes + PollingCancelled: If cancellation_token.cancel() is called """ async def wait_for_execution_status() -> DevboxAsyncExecutionDetailView: @@ -707,7 +720,12 @@ async def wait_for_execution_status() -> DevboxAsyncExecutionDetailView: def is_done(execution: DevboxAsyncExecutionDetailView) -> bool: return execution.status == "completed" - return await async_poll_until(wait_for_execution_status, is_done, polling_config) + return await async_poll_until( + wait_for_execution_status, + is_done, + polling_config, + cancellation_token=cancellation_token, + ) async def execute_async( self, diff --git a/src/runloop_api_client/resources/scenarios/runs.py b/src/runloop_api_client/resources/scenarios/runs.py index 67c5c4428..38e61af32 100644 --- a/src/runloop_api_client/resources/scenarios/runs.py +++ b/src/runloop_api_client/resources/scenarios/runs.py @@ -27,6 +27,7 @@ from ...lib.polling import PollingConfig, poll_until from ..._base_client import AsyncPaginator, make_request_options from ...types.scenarios import run_list_params +from ...lib.cancellation import CancellationToken from ...lib.polling_async import async_poll_until from ...types.scenario_run_view import ScenarioRunView @@ -325,6 +326,7 @@ def await_scored( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -337,6 +339,7 @@ def await_scored( Args: id: The ID of the scenario run to wait for polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -347,6 +350,7 @@ def await_scored( Raises: PollingTimeout: If polling times out before scenario run is scored + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If scenario run enters a non-scored terminal state """ @@ -358,7 +362,12 @@ def retrieve_run() -> ScenarioRunView: def is_done_scoring(run: ScenarioRunView) -> bool: return run.state not in ["scoring"] - run = poll_until(retrieve_run, is_done_scoring, polling_config) + run = poll_until( + retrieve_run, + is_done_scoring, + polling_config, + cancellation_token=cancellation_token, + ) if run.state != "scored": raise RunloopError(f"Scenario run entered non-scored state unexpectedly: {run.state}") @@ -370,6 +379,7 @@ def score_and_await( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -382,6 +392,7 @@ def score_and_await( Args: id: The ID of the scenario run to score and wait for polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -392,6 +403,7 @@ def score_and_await( Raises: PollingTimeout: If polling times out before scenario run is scored + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If scenario run enters a non-scored terminal state """ self.score( @@ -405,6 +417,7 @@ def score_and_await( return self.await_scored( id, polling_config=polling_config, + cancellation_token=cancellation_token, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, @@ -416,6 +429,7 @@ def score_and_complete( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -428,6 +442,7 @@ def score_and_complete( Args: id: The ID of the scenario run to score, wait for, and complete polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -443,6 +458,7 @@ def score_and_complete( self.score_and_await( id, polling_config=polling_config, + cancellation_token=cancellation_token, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, @@ -750,6 +766,7 @@ async def await_scored( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -762,6 +779,7 @@ async def await_scored( Args: id: The ID of the scenario run to wait for polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -772,6 +790,7 @@ async def await_scored( Raises: PollingTimeout: If polling times out before scenario run is scored + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If scenario run enters a non-scored terminal state """ @@ -783,7 +802,12 @@ async def retrieve_run() -> ScenarioRunView: def is_done_scoring(run: ScenarioRunView) -> bool: return run.state not in ["scoring"] - run = await async_poll_until(retrieve_run, is_done_scoring, polling_config) + run = await async_poll_until( + retrieve_run, + is_done_scoring, + polling_config, + cancellation_token=cancellation_token, + ) if run.state != "scored": raise RunloopError(f"Scenario run entered non-scored state unexpectedly: {run.state}") @@ -795,6 +819,7 @@ async def score_and_await( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -807,6 +832,7 @@ async def score_and_await( Args: id: The ID of the scenario run to score and wait for polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -817,6 +843,7 @@ async def score_and_await( Raises: PollingTimeout: If polling times out before scenario run is scored + PollingCancelled: If cancellation_token.cancel() is called RunloopError: If scenario run enters a non-scored terminal state """ await self.score( @@ -830,6 +857,7 @@ async def score_and_await( return await self.await_scored( id, polling_config=polling_config, + cancellation_token=cancellation_token, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, @@ -841,6 +869,7 @@ async def score_and_complete( id: str, *, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -853,6 +882,7 @@ async def score_and_complete( Args: id: The ID of the scenario run to score, wait for, and complete polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -868,6 +898,7 @@ async def score_and_complete( await self.score_and_await( id, polling_config=polling_config, + cancellation_token=cancellation_token, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, diff --git a/src/runloop_api_client/resources/scenarios/scenarios.py b/src/runloop_api_client/resources/scenarios/scenarios.py index e3ce8c91b..327f9a1a1 100644 --- a/src/runloop_api_client/resources/scenarios/scenarios.py +++ b/src/runloop_api_client/resources/scenarios/scenarios.py @@ -43,6 +43,7 @@ from ...pagination import SyncScenariosCursorIDPage, AsyncScenariosCursorIDPage from ...lib.polling import PollingConfig from ..._base_client import AsyncPaginator, make_request_options +from ...lib.cancellation import CancellationToken from ...types.scenario_view import ScenarioView from ...types.scenario_run_view import ScenarioRunView from ...types.input_context_param import InputContextParam @@ -527,6 +528,7 @@ def start_run_and_await_env_ready( run_name: Optional[str] | Omit = omit, run_profile: Optional[scenario_start_run_params.RunProfile] | Omit = omit, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -543,6 +545,7 @@ def start_run_and_await_env_ready( run_name: Display name of the run run_profile: Runtime configuration to use for this benchmark run polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request @@ -572,6 +575,7 @@ def start_run_and_await_env_ready( self._client.devboxes.await_running( run.devbox_id, polling_config=polling_config, + cancellation_token=cancellation_token, ) return run @@ -1048,6 +1052,7 @@ async def start_run_and_await_env_ready( run_name: Optional[str] | Omit = omit, run_profile: Optional[scenario_start_run_params.RunProfile] | Omit = omit, polling_config: PollingConfig | None = None, + cancellation_token: CancellationToken | None = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -1064,6 +1069,7 @@ async def start_run_and_await_env_ready( run_name: Display name of the run run_profile: Runtime configuration to use for this benchmark run polling_config: Optional polling configuration + cancellation_token: Token to cancel the wait operation Returns: The scenario run in running state @@ -1088,6 +1094,7 @@ async def start_run_and_await_env_ready( await self._client.devboxes.await_running( run.devbox_id, polling_config=polling_config, + cancellation_token=cancellation_token, ) return run diff --git a/src/runloop_api_client/sdk/_types.py b/src/runloop_api_client/sdk/_types.py index 4d2accc46..4f2075429 100644 --- a/src/runloop_api_client/sdk/_types.py +++ b/src/runloop_api_client/sdk/_types.py @@ -42,6 +42,7 @@ from ..lib.polling import PollingConfig from ..types.devboxes import DiskSnapshotListParams, DiskSnapshotUpdateParams from ..types.scenarios import ScorerListParams, ScorerCreateParams, ScorerUpdateParams +from ..lib.cancellation import CancellationToken from ..types.devbox_create_params import DevboxBaseCreateParams from ..types.axons.sql_batch_params import SqlBatchParams from ..types.axons.sql_query_params import SqlQueryParams @@ -86,6 +87,9 @@ class PollingRequestOptions(BaseRequestOptions, total=False): polling_config: Optional[PollingConfig] """Configuration for polling behavior""" + cancellation_token: Optional[CancellationToken] + """Token to cancel polling operations""" + class LongPollingRequestOptions(LongRequestOptions, PollingRequestOptions): # type: ignore[misc] pass diff --git a/src/runloop_api_client/sdk/async_scenario_run.py b/src/runloop_api_client/sdk/async_scenario_run.py index 314de676f..ad7ef6c69 100644 --- a/src/runloop_api_client/sdk/async_scenario_run.py +++ b/src/runloop_api_client/sdk/async_scenario_run.py @@ -106,7 +106,11 @@ async def await_env_ready( :return: Scenario run state after environment is ready :rtype: ScenarioRunView """ - await self._client.devboxes.await_running(self._devbox_id, polling_config=options.get("polling_config")) + await self._client.devboxes.await_running( + self._devbox_id, + polling_config=options.get("polling_config"), + cancellation_token=options.get("cancellation_token"), + ) return await self.get_info(**filter_params(options, BaseRequestOptions)) async def score( diff --git a/src/runloop_api_client/sdk/scenario_run.py b/src/runloop_api_client/sdk/scenario_run.py index ede44b105..14dd3a4c6 100644 --- a/src/runloop_api_client/sdk/scenario_run.py +++ b/src/runloop_api_client/sdk/scenario_run.py @@ -106,7 +106,11 @@ def await_env_ready( :return: Scenario run state after environment is ready :rtype: ScenarioRunView """ - self._client.devboxes.await_running(self._devbox_id, polling_config=options.get("polling_config")) + self._client.devboxes.await_running( + self._devbox_id, + polling_config=options.get("polling_config"), + cancellation_token=options.get("cancellation_token"), + ) return self.get_info(**filter_params(options, BaseRequestOptions)) def score( diff --git a/src/runloop_api_client/types/axons/__init__.py b/src/runloop_api_client/types/axons/__init__.py index 8ab8cf9b1..1a00da9e0 100644 --- a/src/runloop_api_client/types/axons/__init__.py +++ b/src/runloop_api_client/types/axons/__init__.py @@ -11,3 +11,4 @@ from .sql_step_result_view import SqlStepResultView as SqlStepResultView from .sql_batch_result_view import SqlBatchResultView as SqlBatchResultView from .sql_query_result_view import SqlQueryResultView as SqlQueryResultView +from .axon_subscribe_sse_params import AxonSubscribeSseParams as AxonSubscribeSseParams diff --git a/src/runloop_api_client/types/axons/axon_subscribe_sse_params.py b/src/runloop_api_client/types/axons/axon_subscribe_sse_params.py new file mode 100644 index 000000000..c88c2e3da --- /dev/null +++ b/src/runloop_api_client/types/axons/axon_subscribe_sse_params.py @@ -0,0 +1,12 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import TypedDict + +__all__ = ["AxonSubscribeSseParams"] + + +class AxonSubscribeSseParams(TypedDict, total=False): + after_sequence: int + """Resume SSE stream from events after this sequence number (used internally for reconnection)""" diff --git a/tests/api_resources/test_devboxes.py b/tests/api_resources/test_devboxes.py index c04de1d06..b33c8ea13 100644 --- a/tests/api_resources/test_devboxes.py +++ b/tests/api_resources/test_devboxes.py @@ -1386,7 +1386,7 @@ def test_method_create_and_await_running_success(self, client: Runloop) -> None: assert result.id == "test_id" assert result.status == "running" mock_create.assert_called_once() - mock_await.assert_called_once_with("test_id", polling_config=None) + mock_await.assert_called_once_with("test_id", polling_config=None, cancellation_token=None) @parametrize def test_method_create_and_await_running_with_config(self, client: Runloop) -> None: @@ -1426,7 +1426,7 @@ def test_method_create_and_await_running_with_config(self, client: Runloop) -> N assert result.id == "test_id" assert result.status == "running" - mock_await.assert_called_once_with("test_id", polling_config=config) + mock_await.assert_called_once_with("test_id", polling_config=config, cancellation_token=None) @parametrize def test_method_create_and_await_running_create_failure(self, client: Runloop) -> None: diff --git a/tests/sdk/test_async_scenario_run.py b/tests/sdk/test_async_scenario_run.py index c034524a0..290caf545 100644 --- a/tests/sdk/test_async_scenario_run.py +++ b/tests/sdk/test_async_scenario_run.py @@ -54,7 +54,11 @@ async def test_await_env_ready( run = AsyncScenarioRun(mock_async_client, "scr_123", "dbx_123") result = await run.await_env_ready() - mock_async_client.devboxes.await_running.assert_awaited_once_with("dbx_123", polling_config=None) + mock_async_client.devboxes.await_running.assert_awaited_once_with( + "dbx_123", + polling_config=None, + cancellation_token=None, + ) assert result == scenario_run_view async def test_score(self, mock_async_client: AsyncMock, scenario_run_view: MockScenarioRunView) -> None: diff --git a/tests/sdk/test_scenario_run.py b/tests/sdk/test_scenario_run.py index 339e365f8..82c004da8 100644 --- a/tests/sdk/test_scenario_run.py +++ b/tests/sdk/test_scenario_run.py @@ -51,7 +51,11 @@ def test_await_env_ready( run = ScenarioRun(mock_client, "scr_123", "dbx_123") result = run.await_env_ready() - mock_client.devboxes.await_running.assert_called_once_with("dbx_123", polling_config=None) + mock_client.devboxes.await_running.assert_called_once_with( + "dbx_123", + polling_config=None, + cancellation_token=None, + ) assert result == scenario_run_view def test_score(self, mock_client: Mock, scenario_run_view: MockScenarioRunView) -> None: diff --git a/tests/smoketests/sdk/test_list_pagination.py b/tests/smoketests/sdk/test_list_pagination.py new file mode 100644 index 000000000..72e2c14b2 --- /dev/null +++ b/tests/smoketests/sdk/test_list_pagination.py @@ -0,0 +1,306 @@ +"""Smoke tests to verify list methods respect limit parameter and only return one page. + +This test suite validates the fix for slow list endpoints, ensuring that +SDK list() methods return only the requested page of results instead of +auto-paginating through all available items. + +Related to TypeScript PR: https://github.com/runloopai/api-client-ts/pull/767 +""" + +from __future__ import annotations + +import pytest + +from runloop_api_client.sdk import RunloopSDK, AsyncRunloopSDK +from tests.smoketests.utils import unique_name +from runloop_api_client.types.shared_params import AgentSource + +pytestmark = pytest.mark.smoketest + +THIRTY_SECOND_TIMEOUT = 30 +AGENT_SOURCE: AgentSource = { + "type": "npm", + "npm": { + "package_name": "@runloop/hello-world-agent", + }, +} + + +class TestAsyncListPagination: + """Test async list methods respect limit and return only one page.""" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_agent_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify agent.list() with limit returns at most that many items.""" + # Request a small page + agents = await async_sdk_client.agent.list(limit=5) + + assert isinstance(agents, list) + # Should return at most 5 items (might be fewer if less data exists) + assert len(agents) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_agent_list_limit_one(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify agent.list() with limit=1 returns at most one item.""" + agents = await async_sdk_client.agent.list(limit=1) + + assert isinstance(agents, list) + assert len(agents) <= 1, "list(limit=1) should return at most 1 item" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_devbox_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify devbox.list() with limit returns at most that many items.""" + devboxes = await async_sdk_client.devbox.list(limit=3) + + assert isinstance(devboxes, list) + assert len(devboxes) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_blueprint_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify blueprint.list() with limit returns at most that many items.""" + blueprints = await async_sdk_client.blueprint.list(limit=5) + + assert isinstance(blueprints, list) + assert len(blueprints) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_storage_object_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify storage_object.list() with limit returns at most that many items.""" + objects = await async_sdk_client.storage_object.list(limit=4) + + assert isinstance(objects, list) + assert len(objects) <= 4, "list(limit=4) should return at most 4 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_snapshot_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify snapshot.list() with limit returns at most that many items.""" + snapshots = await async_sdk_client.snapshot.list(limit=2) + + assert isinstance(snapshots, list) + assert len(snapshots) <= 2, "list(limit=2) should return at most 2 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_axon_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify axon.list() with limit returns at most that many items.""" + axons = await async_sdk_client.axon.list(limit=3) + + assert isinstance(axons, list) + assert len(axons) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_scorer_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify scorer.list() with limit returns at most that many items.""" + scorers = await async_sdk_client.scorer.list(limit=5) + + assert isinstance(scorers, list) + assert len(scorers) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_scenario_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify scenario.list() with limit returns at most that many items.""" + scenarios = await async_sdk_client.scenario.list(limit=4) + + assert isinstance(scenarios, list) + assert len(scenarios) <= 4, "list(limit=4) should return at most 4 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_benchmark_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify benchmark.list() with limit returns at most that many items.""" + benchmarks = await async_sdk_client.benchmark.list(limit=3) + + assert isinstance(benchmarks, list) + assert len(benchmarks) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_network_policy_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify network_policy.list() with limit returns at most that many items.""" + policies = await async_sdk_client.network_policy.list(limit=5) + + assert isinstance(policies, list) + assert len(policies) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_gateway_config_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify gateway_config.list() with limit returns at most that many items.""" + configs = await async_sdk_client.gateway_config.list(limit=2) + + assert isinstance(configs, list) + assert len(configs) <= 2, "list(limit=2) should return at most 2 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_mcp_config_list_respects_limit(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify mcp_config.list() with limit returns at most that many items.""" + configs = await async_sdk_client.mcp_config.list(limit=3) + + assert isinstance(configs, list) + assert len(configs) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_secret_list_no_auto_pagination(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Verify secret.list() returns only one page (secrets don't have limit param).""" + secrets = await async_sdk_client.secret.list() + + # Secrets list doesn't have a limit parameter, but should still + # return only one page worth of results, not auto-paginate + assert isinstance(secrets, list) + + +class TestSyncListPagination: + """Test sync list methods respect limit and return only one page.""" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_agent_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify agent.list() with limit returns at most that many items.""" + agents = sdk_client.agent.list(limit=5) + + assert isinstance(agents, list) + assert len(agents) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_devbox_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify devbox.list() with limit returns at most that many items.""" + devboxes = sdk_client.devbox.list(limit=3) + + assert isinstance(devboxes, list) + assert len(devboxes) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_blueprint_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify blueprint.list() with limit returns at most that many items.""" + blueprints = sdk_client.blueprint.list(limit=5) + + assert isinstance(blueprints, list) + assert len(blueprints) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_storage_object_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify storage_object.list() with limit returns at most that many items.""" + objects = sdk_client.storage_object.list(limit=4) + + assert isinstance(objects, list) + assert len(objects) <= 4, "list(limit=4) should return at most 4 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_snapshot_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify snapshot.list() with limit returns at most that many items.""" + snapshots = sdk_client.snapshot.list(limit=2) + + assert isinstance(snapshots, list) + assert len(snapshots) <= 2, "list(limit=2) should return at most 2 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_axon_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify axon.list() with limit returns at most that many items.""" + axons = sdk_client.axon.list(limit=3) + + assert isinstance(axons, list) + assert len(axons) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_scorer_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify scorer.list() with limit returns at most that many items.""" + scorers = sdk_client.scorer.list(limit=5) + + assert isinstance(scorers, list) + assert len(scorers) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_scenario_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify scenario.list() with limit returns at most that many items.""" + scenarios = sdk_client.scenario.list(limit=4) + + assert isinstance(scenarios, list) + assert len(scenarios) <= 4, "list(limit=4) should return at most 4 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_benchmark_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify benchmark.list() with limit returns at most that many items.""" + benchmarks = sdk_client.benchmark.list(limit=3) + + assert isinstance(benchmarks, list) + assert len(benchmarks) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_network_policy_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify network_policy.list() with limit returns at most that many items.""" + policies = sdk_client.network_policy.list(limit=5) + + assert isinstance(policies, list) + assert len(policies) <= 5, "list(limit=5) should return at most 5 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_gateway_config_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify gateway_config.list() with limit returns at most that many items.""" + configs = sdk_client.gateway_config.list(limit=2) + + assert isinstance(configs, list) + assert len(configs) <= 2, "list(limit=2) should return at most 2 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_mcp_config_list_respects_limit(self, sdk_client: RunloopSDK) -> None: + """Verify mcp_config.list() with limit returns at most that many items.""" + configs = sdk_client.mcp_config.list(limit=3) + + assert isinstance(configs, list) + assert len(configs) <= 3, "list(limit=3) should return at most 3 items" + + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + def test_secret_list_no_auto_pagination(self, sdk_client: RunloopSDK) -> None: + """Verify secret.list() returns only one page.""" + secrets = sdk_client.secret.list() + + assert isinstance(secrets, list) + + +class TestListPaginationWithData: + """Test list pagination behavior when data is guaranteed to exist.""" + + @pytest.mark.asyncio + @pytest.mark.timeout(THIRTY_SECOND_TIMEOUT) + async def test_list_limit_with_created_data(self, async_sdk_client: AsyncRunloopSDK) -> None: + """Create multiple items and verify list limit works correctly.""" + # Create several agents to ensure we have data + created_agents: list[object] = [] + for i in range(5): + agent = await async_sdk_client.agent.create( + name=unique_name(f"sdk-list-test-{i}"), + version="1.0.0", + source=AGENT_SOURCE, + ) + created_agents.append(agent) + + try: + # Request only 2 items + listed_agents = await async_sdk_client.agent.list(limit=2) + + # Should get at most 2, even though 5+ exist + assert len(listed_agents) <= 2, ( + f"Expected at most 2 items with limit=2, got {len(listed_agents)}. " + "This indicates auto-pagination is occurring when it shouldn't." + ) + + # Request 10 items - should get all we created (5) plus any existing ones, + # but should stop at first page (up to 10) + listed_agents = await async_sdk_client.agent.list(limit=10) + assert len(listed_agents) <= 10, ( + f"Expected at most 10 items with limit=10, got {len(listed_agents)}. " + "This indicates auto-pagination is occurring when it shouldn't." + ) + + finally: + # Cleanup is not possible yet as agents don't have delete + pass diff --git a/tests/test_axon_sse_reconnect.py b/tests/test_axon_sse_reconnect.py new file mode 100644 index 000000000..5fafef802 --- /dev/null +++ b/tests/test_axon_sse_reconnect.py @@ -0,0 +1,342 @@ +"""Tests for Axon SSE auto-reconnect functionality.""" + +from typing import Any, Iterator, AsyncIterator, cast +from unittest.mock import Mock, AsyncMock, patch + +import httpx +import pytest + +from src.runloop_api_client._constants import RAW_RESPONSE_HEADER +from src.runloop_api_client._streaming import Stream, AsyncStream, ReconnectingStream, AsyncReconnectingStream + + +class MockAxonEvent: + """Mock AxonEventView for testing.""" + + def __init__(self, sequence: int, data: str): + self.sequence = sequence + self.data = data + + +class TestAxonSSEReconnectSync: + """Test SSE reconnection for sync Axon subscriptions.""" + + def test_subscribe_sse_returns_reconnecting_stream(self): + """Test that subscribe_sse returns a ReconnectingStream.""" + from src.runloop_api_client import Runloop + + client = Runloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get") as mock_get: + # Mock the initial stream + mock_stream = Mock(spec=Stream) + mock_get.return_value = mock_stream + + result = client.axons.subscribe_sse("axon-123") + + # Should return a ReconnectingStream + assert isinstance(result, ReconnectingStream) + + def test_subscribe_sse_with_raw_header_returns_plain_stream(self): + """Test that RAW_RESPONSE_HEADER opts out of reconnection.""" + from src.runloop_api_client import Runloop + + client = Runloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get") as mock_get: + mock_stream = Mock(spec=Stream) + mock_get.return_value = mock_stream + + result = client.axons.subscribe_sse("axon-123", extra_headers={RAW_RESPONSE_HEADER: "true"}) + + # Should return plain Stream, not ReconnectingStream + assert not isinstance(result, ReconnectingStream) + assert result == mock_stream + + def test_reconnection_uses_last_sequence(self): + """Test that reconnection uses the sequence from the last event.""" + from src.runloop_api_client import Runloop + + call_count = 0 + query_params: list[dict[str, object]] = [] + + def mock_get(*_args: object, **kwargs: Any) -> Mock: + nonlocal call_count + # Capture query params + options = cast(dict[str, object], kwargs.get("options", {})) + if "params" in options: + query_params.append(cast(dict[str, object], options["params"])) + + # First call: return stream with events + if call_count == 0: + call_count += 1 + mock_stream = Mock(spec=Stream) + + def first_iter(_self: object) -> Iterator[MockAxonEvent]: + yield MockAxonEvent(sequence=1, data="event1") + yield MockAxonEvent(sequence=2, data="event2") + raise httpx.ReadTimeout("stream timed out") + + mock_stream.__iter__ = first_iter + return mock_stream + + # Second call (reconnection): return stream continuing from sequence 2 + mock_stream = Mock(spec=Stream) + + def second_iter(_self: object) -> Iterator[MockAxonEvent]: + yield MockAxonEvent(sequence=3, data="event3") + + mock_stream.__iter__ = second_iter + return mock_stream + + client = Runloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get", side_effect=mock_get): + stream = client.axons.subscribe_sse("axon-123") + + # Consume events + events = cast(list[MockAxonEvent], list(stream)) + + # Should have 3 events total (2 from first stream, 1 from reconnected stream) + assert len(events) == 3 + assert events[0].sequence == 1 + assert events[1].sequence == 2 + assert events[2].sequence == 3 + + # Check that second call used after_sequence parameter + # Note: first call has None, second call should have after_sequence=2 + assert len(query_params) >= 2 + + def test_sequence_extraction_handles_missing_sequence(self): + """Test that missing sequence fields are handled gracefully.""" + from src.runloop_api_client import Runloop + + client = Runloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get") as mock_get: + mock_stream = Mock(spec=Stream) + + # Event without sequence attribute + class EventWithoutSequence: + pass + + def mock_iter(_self: object) -> Iterator[object]: + yield EventWithoutSequence() + + mock_stream.__iter__ = mock_iter + mock_get.return_value = mock_stream + + stream = client.axons.subscribe_sse("axon-123") + + # Should not crash, sequence extractor should return None + events = list(stream) + assert len(events) == 1 + + def test_subscribe_sse_preserves_request_options(self): + """Test that extra headers, query, etc. are preserved.""" + from src.runloop_api_client import Runloop + + client = Runloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get") as mock_get: + mock_stream = Mock(spec=Stream) + + def empty_iter(_self: object) -> Iterator[object]: + return iter([]) + + mock_stream.__iter__ = empty_iter + mock_get.return_value = mock_stream + + extra_headers = {"X-Custom": "value"} + extra_query = {"param": "value"} + + client.axons.subscribe_sse("axon-123", extra_headers=extra_headers, extra_query=extra_query, timeout=30.0) + + # Verify _get was called with the options + call_args = mock_get.call_args + options = call_args.kwargs["options"] + + # Headers should include Accept: text/event-stream and custom header + assert "Accept" in options["headers"] + assert options["headers"]["Accept"] == "text/event-stream" + assert options["headers"]["X-Custom"] == "value" + assert options["params"]["param"] == "value" + assert "after_sequence" not in options["params"] + assert options["timeout"] == 30.0 + + +class TestAxonSSEReconnectAsync: + """Test SSE reconnection for async Axon subscriptions.""" + + @pytest.mark.asyncio + async def test_subscribe_sse_returns_reconnecting_stream(self): + """Test that subscribe_sse returns an AsyncReconnectingStream.""" + from src.runloop_api_client import AsyncRunloop + + async def mock_get(*_args: object, **_kwargs: object) -> Mock: + mock_stream = Mock(spec=AsyncStream) + return mock_stream + + client = AsyncRunloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get", new=AsyncMock(side_effect=mock_get)): + result = await client.axons.subscribe_sse("axon-123") + + # Should return an AsyncReconnectingStream + assert isinstance(result, AsyncReconnectingStream) + + @pytest.mark.asyncio + async def test_subscribe_sse_with_raw_header_returns_plain_stream(self): + """Test that RAW_RESPONSE_HEADER opts out of reconnection.""" + from src.runloop_api_client import AsyncRunloop + + async def mock_get(*_args: object, **_kwargs: object) -> Mock: + mock_stream = Mock(spec=AsyncStream) + return mock_stream + + client = AsyncRunloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get", new=AsyncMock(side_effect=mock_get)): + result = await client.axons.subscribe_sse("axon-123", extra_headers={RAW_RESPONSE_HEADER: "true"}) + + # Should return plain AsyncStream, not AsyncReconnectingStream + assert not isinstance(result, AsyncReconnectingStream) + + @pytest.mark.asyncio + async def test_reconnection_uses_last_sequence(self): + """Test that reconnection uses the sequence from the last event.""" + from src.runloop_api_client import AsyncRunloop + + call_count = 0 + query_params: list[dict[str, object]] = [] + + async def mock_get(*_args: object, **kwargs: Any) -> Mock: + nonlocal call_count + # Capture query params + options = cast(dict[str, object], kwargs.get("options", {})) + if "params" in options: + query_params.append(cast(dict[str, object], options["params"])) + + # First call: return stream with events + if call_count == 0: + call_count += 1 + mock_stream = Mock(spec=AsyncStream) + + async def first_iter(_self: object) -> AsyncIterator[MockAxonEvent]: + yield MockAxonEvent(sequence=1, data="event1") + yield MockAxonEvent(sequence=2, data="event2") + raise httpx.ReadTimeout("stream timed out") + + mock_stream.__aiter__ = first_iter + return mock_stream + + # Second call (reconnection) + mock_stream = Mock(spec=AsyncStream) + + async def second_iter(_self: object) -> AsyncIterator[MockAxonEvent]: + yield MockAxonEvent(sequence=3, data="event3") + + mock_stream.__aiter__ = second_iter + return mock_stream + + client = AsyncRunloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get", new=AsyncMock(side_effect=mock_get)): + stream = await client.axons.subscribe_sse("axon-123") + + # Consume events + events: list[Any] = [] + async for event in stream: + events.append(event) + if len(events) >= 3: + break + + # Should have 3 events total + assert len(events) == 3 + assert events[0].sequence == 1 + assert events[1].sequence == 2 + assert events[2].sequence == 3 + + @pytest.mark.asyncio + async def test_sequence_extraction_handles_none(self): + """Test that None sequences are handled gracefully.""" + from src.runloop_api_client import AsyncRunloop + + async def mock_get(*_args: object, **_kwargs: object) -> Mock: + mock_stream = Mock(spec=AsyncStream) + + # Event with sequence = None + class EventWithNoneSequence: + sequence = None + + async def mock_iter(_self: object) -> AsyncIterator[object]: + yield EventWithNoneSequence() + + mock_stream.__aiter__ = mock_iter + return mock_stream + + client = AsyncRunloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get", new=AsyncMock(side_effect=mock_get)): + stream = await client.axons.subscribe_sse("axon-123") + + # Should not crash + events: list[object] = [] + async for event in stream: + events.append(event) + break + + assert len(events) == 1 + + @pytest.mark.asyncio + async def test_subscribe_sse_preserves_request_options(self): + """Test that extra headers, query, etc. are preserved in async.""" + from src.runloop_api_client import AsyncRunloop + + async def mock_get(*_args: object, **_kwargs: object) -> Mock: + mock_stream = Mock(spec=AsyncStream) + + async def mock_iter(): + return + yield # Make it async generator + + mock_stream.__aiter__ = mock_iter + return mock_stream + + client = AsyncRunloop(bearer_token="test-key", base_url="http://test") + + with patch.object(client.axons, "_get", new=AsyncMock(side_effect=mock_get)) as mock_get_method: + extra_headers = {"X-Custom": "value"} + extra_query = {"param": "value"} + + await client.axons.subscribe_sse( + "axon-123", extra_headers=extra_headers, extra_query=extra_query, timeout=30.0 + ) + + # Verify _get was called with the options + call_args = mock_get_method.call_args + options = call_args.kwargs["options"] + + # Headers should include Accept: text/event-stream and custom header + assert "Accept" in options["headers"] + assert options["headers"]["Accept"] == "text/event-stream" + assert options["headers"]["X-Custom"] == "value" + assert options["params"]["param"] == "value" + assert "after_sequence" not in options["params"] + assert options["timeout"] == 30.0 + + +class TestAxonSubscribeSseParams: + """Test AxonSubscribeSseParams TypedDict.""" + + def test_params_structure(self): + """Test that AxonSubscribeSseParams has the correct structure.""" + from src.runloop_api_client.types.axons import AxonSubscribeSseParams + + # Should be able to create with after_sequence + params: AxonSubscribeSseParams = {"after_sequence": 123} + assert params["after_sequence"] == 123 + + # The field is optional via total=False, so it can also be omitted. + params2: AxonSubscribeSseParams = {} + assert "after_sequence" not in params2 diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py new file mode 100644 index 000000000..0a24153da --- /dev/null +++ b/tests/test_cancellation.py @@ -0,0 +1,213 @@ +"""Tests for CancellationToken and PollingCancelled exception.""" + +import asyncio +import threading +from concurrent.futures import Future, ThreadPoolExecutor + +import pytest + +from src.runloop_api_client.lib.cancellation import PollingCancelled, CancellationToken + + +class TestPollingCancelled: + """Test PollingCancelled exception.""" + + def test_polling_cancelled_initialization(self): + """Test PollingCancelled exception initialization.""" + exception = PollingCancelled("Operation was cancelled") + assert "Operation was cancelled" in str(exception) + + def test_polling_cancelled_inherits_from_runloop_error(self): + """Test that PollingCancelled inherits from RunloopError.""" + from src.runloop_api_client._exceptions import RunloopError + + exception = PollingCancelled("Test") + assert isinstance(exception, RunloopError) + + +class TestCancellationToken: + """Test CancellationToken class.""" + + def test_initialization(self): + """Test token is not cancelled on initialization.""" + token = CancellationToken() + assert not token.is_cancelled() + + def test_cancel(self): + """Test cancelling a token.""" + token = CancellationToken() + token.cancel() + assert token.is_cancelled() + + def test_cancel_idempotent(self): + """Test that calling cancel() multiple times is safe.""" + token = CancellationToken() + token.cancel() + token.cancel() + token.cancel() + assert token.is_cancelled() + + def test_raise_if_cancelled_not_cancelled(self): + """Test raise_if_cancelled() when token is not cancelled.""" + token = CancellationToken() + # Should not raise + token.raise_if_cancelled() + + def test_raise_if_cancelled_when_cancelled(self): + """Test raise_if_cancelled() when token is cancelled.""" + token = CancellationToken() + token.cancel() + with pytest.raises(PollingCancelled, match="Polling operation was cancelled"): + token.raise_if_cancelled() + + def test_sync_event_property(self): + """Test sync_event property returns threading.Event.""" + token = CancellationToken() + event = token.sync_event + assert isinstance(event, threading.Event) + assert not event.is_set() + + token.cancel() + assert event.is_set() + + def test_sync_event_wait(self): + """Test sync_event can be used with wait().""" + token = CancellationToken() + event = token.sync_event + + # Should timeout since not cancelled + result = event.wait(timeout=0.01) + assert not result + + token.cancel() + # Should return immediately when cancelled + result = event.wait(timeout=1.0) + assert result + + @pytest.mark.asyncio + async def test_async_event_property(self): + """Test async_event property returns asyncio.Event.""" + token = CancellationToken() + event = token.async_event + assert isinstance(event, asyncio.Event) + assert not event.is_set() + + token.cancel() + assert event.is_set() + + @pytest.mark.asyncio + async def test_async_event_wait(self): + """Test async_event can be used with wait().""" + token = CancellationToken() + event = token.async_event + + # Should timeout since not cancelled + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(event.wait(), timeout=0.01) + + token.cancel() + # Should return immediately when cancelled + await asyncio.wait_for(event.wait(), timeout=1.0) + + @pytest.mark.asyncio + async def test_async_event_lazy_creation(self): + """Test that async_event is created lazily.""" + token = CancellationToken() + # Access _async_event directly to check it's None + assert token._async_event is None + + # Access property to trigger lazy creation + event = token.async_event + assert token._async_event is not None + assert event is token._async_event + + @pytest.mark.asyncio + async def test_async_event_set_if_already_cancelled(self): + """Test that async_event is set immediately if token was already cancelled.""" + token = CancellationToken() + token.cancel() + + # Async event should be set when created + event = token.async_event + assert event.is_set() + + def test_thread_safety(self): + """Test that CancellationToken is thread-safe.""" + token = CancellationToken() + results: list[bool] = [] + + def cancel_token(): + token.cancel() + results.append(token.is_cancelled()) + + def check_token(): + # Busy wait until cancelled + while not token.is_cancelled(): + pass + results.append(token.is_cancelled()) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures: list[Future[None]] = [] + # Start 4 checking threads + for _ in range(4): + futures.append(executor.submit(check_token)) + # Start 1 cancelling thread + futures.append(executor.submit(cancel_token)) + + # Wait for all to complete + for future in futures: + future.result(timeout=2.0) + + # All results should be True + assert all(results) + assert len(results) == 5 + + def test_sync_and_async_events_synchronized(self): + """Test that sync and async events are synchronized.""" + token = CancellationToken() + + # Get both events + sync_event = token.sync_event + async_event = token.async_event + + assert not sync_event.is_set() + assert not async_event.is_set() + + # Cancel token + token.cancel() + + # Both should be set + assert sync_event.is_set() + assert async_event.is_set() + + def test_multiple_tokens_independent(self): + """Test that multiple tokens are independent.""" + token1 = CancellationToken() + token2 = CancellationToken() + + token1.cancel() + + assert token1.is_cancelled() + assert not token2.is_cancelled() + + @pytest.mark.asyncio + async def test_async_cancellation_propagation(self): + """Test cancellation in async context.""" + token = CancellationToken() + + async def wait_for_cancellation(): + await token.async_event.wait() + return token.is_cancelled() + + # Start waiting task + task = asyncio.create_task(wait_for_cancellation()) + + # Give it a moment to start + await asyncio.sleep(0.01) + + # Cancel token + token.cancel() + + # Task should complete and return True + result = await asyncio.wait_for(task, timeout=1.0) + assert result is True diff --git a/tests/test_polling.py b/tests/test_polling.py index 74819531b..bb92ace3c 100644 --- a/tests/test_polling.py +++ b/tests/test_polling.py @@ -1,9 +1,11 @@ +import threading from typing import Any from unittest.mock import Mock, patch import pytest from src.runloop_api_client.lib.polling import PollingConfig, PollingTimeout, poll_until +from src.runloop_api_client.lib.cancellation import PollingCancelled, CancellationToken class TestPollingConfig: @@ -260,3 +262,173 @@ def dynamic_terminal(_: Any) -> bool: assert result == "value3" assert retriever.call_count == 3 + + +class TestPollUntilWithCancellation: + """Test poll_until function with CancellationToken.""" + + def test_cancellation_before_first_poll(self): + """Test cancellation before first poll attempt.""" + token = CancellationToken() + token.cancel() + + retriever = Mock(return_value="value") + is_terminal = Mock(return_value=False) + + with pytest.raises(PollingCancelled, match="Polling operation was cancelled"): + poll_until(retriever, is_terminal, cancellation_token=token) + + # Should not call retriever since cancelled before first attempt + assert retriever.call_count == 0 + + def test_cancellation_during_polling(self): + """Test cancellation during polling loop.""" + token = CancellationToken() + retriever = Mock(side_effect=["value1", "value2", "value3"]) + is_terminal = Mock(return_value=False) + + wait_call_count = 0 + + def cancel_on_second_wait(*_args: object, **_kwargs: object) -> bool: + nonlocal wait_call_count + wait_call_count += 1 + if wait_call_count == 2: + token.cancel() + return True + return False + + with patch.object(token.sync_event, "wait", side_effect=cancel_on_second_wait): + with pytest.raises(PollingCancelled): + poll_until(retriever, is_terminal, cancellation_token=token) + + # Should have called retriever twice before cancellation + assert retriever.call_count == 2 + + def test_cancellation_during_sleep(self): + """Test that cancellation wakes up from sleep.""" + token = CancellationToken() + retriever = Mock(return_value="value") + is_terminal = Mock(return_value=False) + config = PollingConfig(interval_seconds=10.0) # Long sleep + + def cancel_after_delay(): + import time + + time.sleep(0.05) # Wait a bit + token.cancel() + + # Start cancellation in background thread + cancel_thread = threading.Thread(target=cancel_after_delay) + cancel_thread.start() + + # Poll should be interrupted by cancellation + with pytest.raises(PollingCancelled): + poll_until(retriever, is_terminal, config, cancellation_token=token) + + cancel_thread.join() + + # Should have attempted once before cancellation during sleep + assert retriever.call_count == 1 + + def test_no_cancellation_completes_normally(self): + """Test that polling completes normally without cancellation.""" + token = CancellationToken() + retriever = Mock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + + with patch.object(token.sync_event, "wait", return_value=False): + result = poll_until(retriever, is_terminal, cancellation_token=token) + + assert result == "completed" + assert not token.is_cancelled() + + def test_cancellation_after_completion_no_effect(self): + """Test that cancelling after completion has no effect.""" + token = CancellationToken() + retriever = Mock(return_value="completed") + is_terminal = Mock(return_value=True) + + result = poll_until(retriever, is_terminal, cancellation_token=token) + + # Cancel after completion + token.cancel() + + assert result == "completed" + + def test_cancellation_with_error_handler(self): + """Test cancellation works with error handler.""" + token = CancellationToken() + retriever = Mock(side_effect=[ValueError("error"), "value"]) + is_terminal = Mock(return_value=False) + + def error_handler(_: Exception) -> str: + return "handled" + + def cancel_on_first_wait(*_args: object, **_kwargs: object) -> bool: + token.cancel() + return True + + with patch.object(token.sync_event, "wait", side_effect=cancel_on_first_wait): + with pytest.raises(PollingCancelled): + poll_until(retriever, is_terminal, on_error=error_handler, cancellation_token=token) + + def test_cancellation_with_custom_config(self): + """Test cancellation with custom polling config.""" + token = CancellationToken() + retriever = Mock(return_value="value") + is_terminal = Mock(return_value=False) + config = PollingConfig(interval_seconds=0.5, max_attempts=10) + + token.cancel() + + with pytest.raises(PollingCancelled): + poll_until(retriever, is_terminal, config, cancellation_token=token) + + def test_none_cancellation_token_works_normally(self): + """Test that passing None as cancellation_token works (backward compatibility).""" + retriever = Mock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + + with patch("time.sleep"): + result = poll_until(retriever, is_terminal, cancellation_token=None) + + assert result == "completed" + + def test_cancellable_sleep_blocks_correctly(self): + """Test that cancellable sleep blocks for the correct duration.""" + import time + + token = CancellationToken() + retriever = Mock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + config = PollingConfig(interval_seconds=0.1) + + start = time.time() + result = poll_until(retriever, is_terminal, config, cancellation_token=token) + elapsed = time.time() - start + + assert result == "completed" + # Should have slept approximately 0.1 seconds + assert 0.08 <= elapsed <= 0.15 # Allow some tolerance + + def test_multiple_cancellations_same_token(self): + """Test that the same token can be used for multiple poll operations.""" + token = CancellationToken() + + # First poll succeeds + retriever1 = Mock(return_value="done") + is_terminal1 = Mock(return_value=True) + result1 = poll_until(retriever1, is_terminal1, cancellation_token=token) + assert result1 == "done" + + # Cancel token + token.cancel() + + # Second poll should fail immediately + retriever2 = Mock(return_value="value") + is_terminal2 = Mock(return_value=False) + with pytest.raises(PollingCancelled): + poll_until(retriever2, is_terminal2, cancellation_token=token) + + # Second retriever should not be called + assert retriever2.call_count == 0 diff --git a/tests/test_polling_async.py b/tests/test_polling_async.py new file mode 100644 index 000000000..62f1bdf65 --- /dev/null +++ b/tests/test_polling_async.py @@ -0,0 +1,293 @@ +"""Tests for async polling with cancellation.""" + +import asyncio +from unittest.mock import Mock, AsyncMock, patch + +import pytest + +from src.runloop_api_client.lib.polling import PollingConfig, PollingTimeout +from src.runloop_api_client.lib.cancellation import PollingCancelled, CancellationToken +from src.runloop_api_client.lib.polling_async import async_poll_until + + +class TestAsyncPollUntil: + """Test async_poll_until function.""" + + @pytest.mark.asyncio + async def test_immediate_success(self): + """Test when condition is met on first attempt.""" + retriever = AsyncMock(return_value="completed") + is_terminal = Mock(return_value=True) + + result = await async_poll_until(retriever, is_terminal) + + assert result == "completed" + assert retriever.call_count == 1 + assert is_terminal.call_count == 1 + + @pytest.mark.asyncio + async def test_success_after_multiple_attempts(self): + """Test when condition is met after several attempts.""" + values = ["pending", "running", "completed"] + retriever = AsyncMock(side_effect=values) + is_terminal = Mock(side_effect=[False, False, True]) + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await async_poll_until(retriever, is_terminal) + + assert result == "completed" + assert retriever.call_count == 3 + assert is_terminal.call_count == 3 + assert mock_sleep.call_count == 2 + + @pytest.mark.asyncio + async def test_max_attempts_exceeded(self): + """Test when max attempts is exceeded.""" + retriever = AsyncMock(return_value="still_running") + is_terminal = Mock(return_value=False) + config = PollingConfig(max_attempts=3, interval_seconds=0.01) + + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(PollingTimeout) as exc_info: + await async_poll_until(retriever, is_terminal, config) + + assert "Exceeded maximum attempts (3)" in str(exc_info.value) + assert retriever.call_count == 3 + + @pytest.mark.asyncio + async def test_timeout_exceeded(self): + """Test when timeout is exceeded.""" + retriever = AsyncMock(return_value="still_running") + is_terminal = Mock(return_value=False) + config = PollingConfig(timeout_seconds=0.1, interval_seconds=0.01) + + start_time = 1000.0 + with patch("time.time", side_effect=[start_time, start_time + 0.05, start_time + 0.15]): + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(PollingTimeout) as exc_info: + await async_poll_until(retriever, is_terminal, config) + + assert "Exceeded timeout of 0.1 seconds" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_error_without_handler(self): + """Test that exceptions are re-raised when no error handler is provided.""" + retriever = AsyncMock(side_effect=ValueError("Test error")) + is_terminal = Mock(return_value=False) + + with pytest.raises(ValueError, match="Test error"): + await async_poll_until(retriever, is_terminal) + + @pytest.mark.asyncio + async def test_error_with_handler_continue(self): + """Test error handler that allows polling to continue.""" + retriever = AsyncMock(side_effect=[ValueError("Test error"), "recovered"]) + is_terminal = Mock(side_effect=[False, True]) + + def error_handler(_: Exception) -> str: + return "error_handled" + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await async_poll_until(retriever, is_terminal, on_error=error_handler) + + assert result == "recovered" + assert retriever.call_count == 2 + + +class TestAsyncPollUntilWithCancellation: + """Test async_poll_until with CancellationToken.""" + + @pytest.mark.asyncio + async def test_cancellation_before_first_poll(self): + """Test cancellation before first poll attempt.""" + token = CancellationToken() + token.cancel() + + retriever = AsyncMock(return_value="value") + is_terminal = Mock(return_value=False) + + with pytest.raises(PollingCancelled, match="Polling operation was cancelled"): + await async_poll_until(retriever, is_terminal, cancellation_token=token) + + assert retriever.call_count == 0 + + @pytest.mark.asyncio + async def test_cancellation_during_polling(self): + """Test cancellation during polling loop.""" + token = CancellationToken() + retriever = AsyncMock(side_effect=["value1", "value2", "value3"]) + is_terminal = Mock(return_value=False) + + wait_call_count = 0 + + async def cancel_on_second_wait(*_args: object, **_kwargs: object) -> None: + nonlocal wait_call_count + wait_call_count += 1 + if wait_call_count == 2: + token.cancel() + return None + raise asyncio.TimeoutError + + with patch("asyncio.wait_for", side_effect=cancel_on_second_wait): + with pytest.raises(PollingCancelled): + await async_poll_until(retriever, is_terminal, cancellation_token=token) + + assert retriever.call_count == 2 + + @pytest.mark.asyncio + async def test_cancellation_during_sleep(self): + """Test that cancellation wakes up from sleep.""" + token = CancellationToken() + retriever = AsyncMock(return_value="value") + is_terminal = Mock(return_value=False) + config = PollingConfig(interval_seconds=10.0) # Long sleep + + async def cancel_after_delay(): + await asyncio.sleep(0.05) + token.cancel() + + # Start cancellation task + cancel_task = asyncio.create_task(cancel_after_delay()) + + # Poll should be interrupted by cancellation + with pytest.raises(PollingCancelled): + await async_poll_until(retriever, is_terminal, config, cancellation_token=token) + + await cancel_task + + # Should have attempted once before cancellation + assert retriever.call_count == 1 + + @pytest.mark.asyncio + async def test_no_cancellation_completes_normally(self): + """Test that polling completes normally without cancellation.""" + token = CancellationToken() + retriever = AsyncMock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + config = PollingConfig(interval_seconds=0.001) + + result = await async_poll_until(retriever, is_terminal, config, cancellation_token=token) + + assert result == "completed" + assert not token.is_cancelled() + + @pytest.mark.asyncio + async def test_cancellation_after_completion_no_effect(self): + """Test that cancelling after completion has no effect.""" + token = CancellationToken() + retriever = AsyncMock(return_value="completed") + is_terminal = Mock(return_value=True) + + result = await async_poll_until(retriever, is_terminal, cancellation_token=token) + + token.cancel() + + assert result == "completed" + + @pytest.mark.asyncio + async def test_cancellation_with_error_handler(self): + """Test cancellation works with error handler.""" + token = CancellationToken() + retriever = AsyncMock(side_effect=[ValueError("error"), "value"]) + is_terminal = Mock(return_value=False) + + def error_handler(_: Exception) -> str: + return "handled" + + async def cancel_on_first_wait(awaitable: asyncio.Task[object], *_args: object, **_kwargs: object) -> None: + awaitable.cancel() + token.cancel() + return None + + with patch("asyncio.wait_for", side_effect=cancel_on_first_wait): + with pytest.raises(PollingCancelled): + await async_poll_until(retriever, is_terminal, on_error=error_handler, cancellation_token=token) + + @pytest.mark.asyncio + async def test_none_cancellation_token_works_normally(self): + """Test that passing None as cancellation_token works (backward compatibility).""" + retriever = AsyncMock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await async_poll_until(retriever, is_terminal, cancellation_token=None) + + assert result == "completed" + + @pytest.mark.asyncio + async def test_cancellable_sleep_blocks_correctly(self): + """Test that cancellable sleep blocks for the correct duration.""" + token = CancellationToken() + retriever = AsyncMock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + config = PollingConfig(interval_seconds=0.1) + + start = asyncio.get_event_loop().time() + result = await async_poll_until(retriever, is_terminal, config, cancellation_token=token) + elapsed = asyncio.get_event_loop().time() - start + + assert result == "completed" + # Should have slept approximately 0.1 seconds + assert 0.08 <= elapsed <= 0.15 + + @pytest.mark.asyncio + async def test_concurrent_polling_with_shared_token(self): + """Test multiple concurrent polls with the same token.""" + token = CancellationToken() + + async def poll_task(): + retriever = AsyncMock(return_value="value") + is_terminal = Mock(return_value=False) + config = PollingConfig(interval_seconds=0.01) + await async_poll_until(retriever, is_terminal, config, cancellation_token=token) + + # Start multiple polling tasks + tasks = [asyncio.create_task(poll_task()) for _ in range(3)] + + # Give them time to start + await asyncio.sleep(0.05) + + # Cancel the shared token + token.cancel() + + # All tasks should raise PollingCancelled + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + assert isinstance(result, PollingCancelled) + + @pytest.mark.asyncio + async def test_cancellation_timeout_error_handling(self): + """Test that asyncio.TimeoutError during cancellable sleep is handled correctly.""" + token = CancellationToken() + retriever = AsyncMock(side_effect=["value1", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + config = PollingConfig(interval_seconds=0.1) + + # The cancellable sleep should handle TimeoutError correctly and continue + result = await async_poll_until(retriever, is_terminal, config, cancellation_token=token) + + assert result == "completed" + assert retriever.call_count == 2 + + @pytest.mark.asyncio + async def test_cancellation_from_different_task(self): + """Test that token can be cancelled from a different async task.""" + token = CancellationToken() + retriever = AsyncMock(return_value="value") + is_terminal = Mock(return_value=False) + config = PollingConfig(interval_seconds=1.0) + + async def cancel_from_other_task(): + await asyncio.sleep(0.1) + token.cancel() + + # Start both tasks + poll_task = asyncio.create_task(async_poll_until(retriever, is_terminal, config, cancellation_token=token)) + cancel_task = asyncio.create_task(cancel_from_other_task()) + + # Wait for both + with pytest.raises(PollingCancelled): + await poll_task + + await cancel_task diff --git a/uv.lock b/uv.lock index e2e199a71..afe2f32e5 100644 --- a/uv.lock +++ b/uv.lock @@ -2386,7 +2386,7 @@ wheels = [ [[package]] name = "runloop-api-client" -version = "1.14.0" +version = "1.15.0" source = { editable = "." } dependencies = [ { name = "anyio" },