Skip to content

Commit ed970e1

Browse files
committed
Reconnect stream on timeout
1 parent 742357d commit ed970e1

File tree

3 files changed

+403
-30
lines changed

3 files changed

+403
-30
lines changed

src/runloop_api_client/_streaming.py

Lines changed: 189 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,31 @@
44
import json
55
import inspect
66
from types import TracebackType
7-
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
8-
from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
7+
from typing import (
8+
TYPE_CHECKING,
9+
Any,
10+
Generic,
11+
TypeVar,
12+
Callable,
13+
Iterator,
14+
Optional,
15+
Awaitable,
16+
AsyncIterator,
17+
cast,
18+
)
19+
from typing_extensions import (
20+
Self,
21+
Protocol,
22+
TypeGuard,
23+
override,
24+
get_origin,
25+
runtime_checkable,
26+
)
927

1028
import httpx
1129

1230
from ._utils import extract_type_var_from_base
31+
from ._exceptions import APIStatusError, APITimeoutError
1332

1433
if TYPE_CHECKING:
1534
from ._client import Runloop, AsyncRunloop
@@ -55,6 +74,17 @@ def __stream__(self) -> Iterator[_T]:
5574
iterator = self._iter_events()
5675

5776
for sse in iterator:
77+
# Surface server-sent error events as API errors to allow callers to handle/retry
78+
if sse.event == "error":
79+
try:
80+
error_obj = json.loads(sse.data)
81+
status_code = int(error_obj.get("code", 500))
82+
# Build a synthetic response to mirror normal error handling
83+
fake_resp = httpx.Response(status_code, request=response.request, content=sse.data)
84+
except Exception:
85+
fake_resp = httpx.Response(500, request=response.request, content=sse.data)
86+
raise self._client._make_status_error_from_response(fake_resp)
87+
5888
yield process_data(data=sse.json(), cast_to=cast_to, response=response)
5989

6090
# Ensure the entire stream is consumed
@@ -119,6 +149,17 @@ async def __stream__(self) -> AsyncIterator[_T]:
119149
iterator = self._iter_events()
120150

121151
async for sse in iterator:
152+
# Surface server-sent error events as API errors to allow callers to handle/retry
153+
if sse.event == "error":
154+
try:
155+
error_obj = json.loads(sse.data)
156+
status_code = int(error_obj.get("code", 500))
157+
# Build a synthetic response to mirror normal error handling
158+
fake_resp = httpx.Response(status_code, request=response.request, content=sse.data)
159+
except Exception:
160+
fake_resp = httpx.Response(500, request=response.request, content=sse.data)
161+
raise self._client._make_status_error_from_response(fake_resp)
162+
122163
yield process_data(data=sse.json(), cast_to=cast_to, response=response)
123164

124165
# Ensure the entire stream is consumed
@@ -331,3 +372,149 @@ class MyStream(Stream[bytes]):
331372
generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
332373
failure_message=failure_message,
333374
)
375+
376+
377+
class ReconnectingStream(Generic[_T]):
378+
"""Wraps a Stream with automatic reconnection on timeout (HTTP 408) or read timeouts.
379+
380+
The reconnection uses the last observed offset from each item, as provided by
381+
the given `get_offset` callback. The `stream_creator` will be called with the
382+
last known offset to resume the stream.
383+
"""
384+
385+
def __init__(
386+
self,
387+
*,
388+
current_stream: Stream[_T],
389+
stream_creator: Callable[[Optional[str]], Stream[_T]],
390+
get_offset: Callable[[_T], Optional[str]],
391+
) -> None:
392+
self._current_stream = current_stream
393+
self._stream_creator = stream_creator
394+
self._get_offset = get_offset
395+
self._last_offset: Optional[str] = None
396+
self._iterator = self.__stream__()
397+
398+
@property
399+
def response(self) -> httpx.Response:
400+
return self._current_stream.response
401+
402+
def __next__(self) -> _T:
403+
return self._iterator.__next__()
404+
405+
def __iter__(self) -> Iterator[_T]:
406+
for item in self._iterator:
407+
yield item
408+
409+
def __enter__(self) -> "ReconnectingStream[_T]":
410+
return self
411+
412+
def __exit__(
413+
self,
414+
exc_type: type[BaseException] | None,
415+
exc: BaseException | None,
416+
exc_tb: TracebackType | None,
417+
) -> None:
418+
self.close()
419+
420+
def close(self) -> None:
421+
self._current_stream.close()
422+
423+
def __stream__(self) -> Iterator[_T]:
424+
while True:
425+
try:
426+
for item in self._current_stream:
427+
offset = self._get_offset(item)
428+
if offset is not None:
429+
self._last_offset = offset
430+
yield item
431+
return
432+
except Exception as e:
433+
# Reconnect on timeouts
434+
should_reconnect = False
435+
if isinstance(e, APITimeoutError):
436+
should_reconnect = True
437+
elif isinstance(e, APIStatusError) and getattr(e, "status_code", None) == 408:
438+
should_reconnect = True
439+
elif isinstance(e, httpx.TimeoutException):
440+
should_reconnect = True
441+
442+
if should_reconnect:
443+
# Close existing response before reconnecting
444+
try:
445+
self._current_stream.close()
446+
except Exception:
447+
pass
448+
self._current_stream = self._stream_creator(self._last_offset)
449+
continue
450+
raise
451+
452+
453+
class AsyncReconnectingStream(Generic[_T]):
454+
"""Async variant of ReconnectingStream supporting auto-reconnect on timeouts."""
455+
456+
def __init__(
457+
self,
458+
*,
459+
current_stream: AsyncStream[_T],
460+
stream_creator: Callable[[Optional[str]], Awaitable[AsyncStream[_T]]],
461+
get_offset: Callable[[_T], Optional[str]],
462+
) -> None:
463+
self._current_stream = current_stream
464+
self._stream_creator = stream_creator
465+
self._get_offset = get_offset
466+
self._last_offset: Optional[str] = None
467+
self._iterator = self.__stream__()
468+
469+
@property
470+
def response(self) -> httpx.Response:
471+
return self._current_stream.response
472+
473+
async def __anext__(self) -> _T:
474+
return await self._iterator.__anext__()
475+
476+
async def __aiter__(self) -> AsyncIterator[_T]:
477+
async for item in self._iterator:
478+
yield item
479+
480+
async def __aenter__(self) -> "AsyncReconnectingStream[_T]":
481+
return self
482+
483+
async def __aexit__(
484+
self,
485+
exc_type: type[BaseException] | None,
486+
exc: BaseException | None,
487+
exc_tb: TracebackType | None,
488+
) -> None:
489+
await self.close()
490+
491+
async def close(self) -> None:
492+
await self._current_stream.close()
493+
494+
async def __stream__(self) -> AsyncIterator[_T]:
495+
while True:
496+
try:
497+
async for item in self._current_stream:
498+
offset = self._get_offset(item)
499+
if offset is not None:
500+
self._last_offset = offset
501+
yield item
502+
return
503+
except Exception as e:
504+
# Reconnect on timeouts
505+
should_reconnect = False
506+
if isinstance(e, APITimeoutError):
507+
should_reconnect = True
508+
elif isinstance(e, APIStatusError) and getattr(e, "status_code", None) == 408:
509+
should_reconnect = True
510+
elif isinstance(e, httpx.TimeoutException):
511+
should_reconnect = True
512+
513+
if should_reconnect:
514+
try:
515+
await self._current_stream.close()
516+
except Exception:
517+
pass
518+
self._current_stream = await self._stream_creator(self._last_offset)
519+
continue
520+
raise

src/runloop_api_client/resources/devboxes/executions.py

Lines changed: 97 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Optional
5+
from typing import Optional, cast
66

77
import httpx
88

@@ -16,10 +16,10 @@
1616
async_to_raw_response_wrapper,
1717
async_to_streamed_response_wrapper,
1818
)
19-
from ..._constants import DEFAULT_TIMEOUT
19+
from ..._constants import DEFAULT_TIMEOUT, RAW_RESPONSE_HEADER
20+
from ..._streaming import Stream, AsyncStream, ReconnectingStream, AsyncReconnectingStream
2021
from ..._exceptions import APIStatusError, APITimeoutError
2122
from ...lib.polling import PollingConfig, poll_until
22-
from ..._streaming import Stream, AsyncStream
2323
from ..._base_client import make_request_options
2424
from ...types.devboxes import (
2525
execution_kill_params,
@@ -357,18 +357,54 @@ def stream_updates(
357357
raise ValueError(f"Expected a non-empty value for `devbox_id` but received {devbox_id!r}")
358358
if not execution_id:
359359
raise ValueError(f"Expected a non-empty value for `execution_id` but received {execution_id!r}")
360-
return self._get(
361-
f"/v1/devboxes/{devbox_id}/executions/{execution_id}/stream_updates",
362-
options=make_request_options(
363-
extra_headers=extra_headers,
364-
extra_query=extra_query,
365-
extra_body=extra_body,
366-
timeout=timeout,
367-
query=maybe_transform({"offset": offset}, execution_stream_updates_params.ExecutionStreamUpdatesParams),
368-
),
369-
cast_to=DevboxAsyncExecutionDetailView,
370-
stream=True,
371-
stream_cls=Stream[ExecutionUpdateChunk],
360+
# If caller requested a raw or streaming response wrapper, return the underlying stream as-is
361+
if extra_headers and extra_headers.get(RAW_RESPONSE_HEADER):
362+
return self._get(
363+
f"/v1/devboxes/{devbox_id}/executions/{execution_id}/stream_updates",
364+
options=make_request_options(
365+
extra_headers=extra_headers,
366+
extra_query=extra_query,
367+
extra_body=extra_body,
368+
timeout=timeout,
369+
query=maybe_transform(
370+
{"offset": offset}, execution_stream_updates_params.ExecutionStreamUpdatesParams
371+
),
372+
),
373+
cast_to=DevboxAsyncExecutionDetailView,
374+
stream=True,
375+
stream_cls=Stream[ExecutionUpdateChunk],
376+
)
377+
378+
# Otherwise, wrap with auto-reconnect using last seen offset
379+
def create_stream(last_offset: str | None) -> Stream[ExecutionUpdateChunk]:
380+
new_offset = last_offset if last_offset is not None else (None if isinstance(offset, NotGiven) else offset)
381+
return self._get(
382+
f"/v1/devboxes/{devbox_id}/executions/{execution_id}/stream_updates",
383+
options=make_request_options(
384+
extra_headers=extra_headers,
385+
extra_query=extra_query,
386+
extra_body=extra_body,
387+
timeout=timeout,
388+
query=maybe_transform(
389+
{"offset": new_offset}, execution_stream_updates_params.ExecutionStreamUpdatesParams
390+
),
391+
),
392+
cast_to=DevboxAsyncExecutionDetailView,
393+
stream=True,
394+
stream_cls=Stream[ExecutionUpdateChunk],
395+
)
396+
397+
initial_stream = create_stream(None)
398+
399+
def get_offset(item: ExecutionUpdateChunk) -> str | None:
400+
value = getattr(item, "offset", None)
401+
if value is None:
402+
return None
403+
return str(value)
404+
405+
return cast(
406+
Stream[ExecutionUpdateChunk],
407+
ReconnectingStream(current_stream=initial_stream, stream_creator=create_stream, get_offset=get_offset),
372408
)
373409

374410

@@ -683,20 +719,53 @@ async def stream_updates(
683719
raise ValueError(f"Expected a non-empty value for `devbox_id` but received {devbox_id!r}")
684720
if not execution_id:
685721
raise ValueError(f"Expected a non-empty value for `execution_id` but received {execution_id!r}")
686-
return await self._get(
687-
f"/v1/devboxes/{devbox_id}/executions/{execution_id}/stream_updates",
688-
options=make_request_options(
689-
extra_headers=extra_headers,
690-
extra_query=extra_query,
691-
extra_body=extra_body,
692-
timeout=timeout,
693-
query=await async_maybe_transform(
694-
{"offset": offset}, execution_stream_updates_params.ExecutionStreamUpdatesParams
722+
# If caller requested a raw or streaming response wrapper, return the underlying stream as-is
723+
if extra_headers and extra_headers.get(RAW_RESPONSE_HEADER):
724+
return await self._get(
725+
f"/v1/devboxes/{devbox_id}/executions/{execution_id}/stream_updates",
726+
options=make_request_options(
727+
extra_headers=extra_headers,
728+
extra_query=extra_query,
729+
extra_body=extra_body,
730+
timeout=timeout,
731+
query=await async_maybe_transform(
732+
{"offset": offset}, execution_stream_updates_params.ExecutionStreamUpdatesParams
733+
),
695734
),
696-
),
697-
cast_to=DevboxAsyncExecutionDetailView,
698-
stream=True,
699-
stream_cls=AsyncStream[ExecutionUpdateChunk],
735+
cast_to=DevboxAsyncExecutionDetailView,
736+
stream=True,
737+
stream_cls=AsyncStream[ExecutionUpdateChunk],
738+
)
739+
740+
async def create_stream(last_offset: str | None) -> AsyncStream[ExecutionUpdateChunk]:
741+
new_offset = last_offset if last_offset is not None else (None if isinstance(offset, NotGiven) else offset)
742+
return await self._get(
743+
f"/v1/devboxes/{devbox_id}/executions/{execution_id}/stream_updates",
744+
options=make_request_options(
745+
extra_headers=extra_headers,
746+
extra_query=extra_query,
747+
extra_body=extra_body,
748+
timeout=timeout,
749+
query=await async_maybe_transform(
750+
{"offset": new_offset}, execution_stream_updates_params.ExecutionStreamUpdatesParams
751+
),
752+
),
753+
cast_to=DevboxAsyncExecutionDetailView,
754+
stream=True,
755+
stream_cls=AsyncStream[ExecutionUpdateChunk],
756+
)
757+
758+
initial_stream = await create_stream(None)
759+
760+
def get_offset(item: ExecutionUpdateChunk) -> str | None:
761+
value = getattr(item, "offset", None)
762+
if value is None:
763+
return None
764+
return str(value)
765+
766+
return cast(
767+
AsyncStream[ExecutionUpdateChunk],
768+
AsyncReconnectingStream(current_stream=initial_stream, stream_creator=create_stream, get_offset=get_offset),
700769
)
701770

702771

0 commit comments

Comments
 (0)