|
4 | 4 | import json |
5 | 5 | import inspect |
6 | 6 | from types import TracebackType |
7 | | -from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast |
8 | | -from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable |
| 7 | +from typing import ( |
| 8 | + TYPE_CHECKING, |
| 9 | + Any, |
| 10 | + Generic, |
| 11 | + TypeVar, |
| 12 | + Callable, |
| 13 | + Iterator, |
| 14 | + Optional, |
| 15 | + Awaitable, |
| 16 | + AsyncIterator, |
| 17 | + cast, |
| 18 | +) |
| 19 | +from typing_extensions import ( |
| 20 | + Self, |
| 21 | + Protocol, |
| 22 | + TypeGuard, |
| 23 | + override, |
| 24 | + get_origin, |
| 25 | + runtime_checkable, |
| 26 | +) |
9 | 27 |
|
10 | 28 | import httpx |
11 | 29 |
|
12 | 30 | from ._utils import extract_type_var_from_base |
| 31 | +from ._exceptions import APIStatusError, APITimeoutError |
13 | 32 |
|
14 | 33 | if TYPE_CHECKING: |
15 | 34 | from ._client import Runloop, AsyncRunloop |
@@ -55,6 +74,17 @@ def __stream__(self) -> Iterator[_T]: |
55 | 74 | iterator = self._iter_events() |
56 | 75 |
|
57 | 76 | for sse in iterator: |
| 77 | + # Surface server-sent error events as API errors to allow callers to handle/retry |
| 78 | + if sse.event == "error": |
| 79 | + try: |
| 80 | + error_obj = json.loads(sse.data) |
| 81 | + status_code = int(error_obj.get("code", 500)) |
| 82 | + # Build a synthetic response to mirror normal error handling |
| 83 | + fake_resp = httpx.Response(status_code, request=response.request, content=sse.data) |
| 84 | + except Exception: |
| 85 | + fake_resp = httpx.Response(500, request=response.request, content=sse.data) |
| 86 | + raise self._client._make_status_error_from_response(fake_resp) |
| 87 | + |
58 | 88 | yield process_data(data=sse.json(), cast_to=cast_to, response=response) |
59 | 89 |
|
60 | 90 | # Ensure the entire stream is consumed |
@@ -119,6 +149,17 @@ async def __stream__(self) -> AsyncIterator[_T]: |
119 | 149 | iterator = self._iter_events() |
120 | 150 |
|
121 | 151 | async for sse in iterator: |
| 152 | + # Surface server-sent error events as API errors to allow callers to handle/retry |
| 153 | + if sse.event == "error": |
| 154 | + try: |
| 155 | + error_obj = json.loads(sse.data) |
| 156 | + status_code = int(error_obj.get("code", 500)) |
| 157 | + # Build a synthetic response to mirror normal error handling |
| 158 | + fake_resp = httpx.Response(status_code, request=response.request, content=sse.data) |
| 159 | + except Exception: |
| 160 | + fake_resp = httpx.Response(500, request=response.request, content=sse.data) |
| 161 | + raise self._client._make_status_error_from_response(fake_resp) |
| 162 | + |
122 | 163 | yield process_data(data=sse.json(), cast_to=cast_to, response=response) |
123 | 164 |
|
124 | 165 | # Ensure the entire stream is consumed |
@@ -331,3 +372,149 @@ class MyStream(Stream[bytes]): |
331 | 372 | generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)), |
332 | 373 | failure_message=failure_message, |
333 | 374 | ) |
| 375 | + |
| 376 | + |
| 377 | +class ReconnectingStream(Generic[_T]): |
| 378 | + """Wraps a Stream with automatic reconnection on timeout (HTTP 408) or read timeouts. |
| 379 | +
|
| 380 | + The reconnection uses the last observed offset from each item, as provided by |
| 381 | + the given `get_offset` callback. The `stream_creator` will be called with the |
| 382 | + last known offset to resume the stream. |
| 383 | + """ |
| 384 | + |
| 385 | + def __init__( |
| 386 | + self, |
| 387 | + *, |
| 388 | + current_stream: Stream[_T], |
| 389 | + stream_creator: Callable[[Optional[str]], Stream[_T]], |
| 390 | + get_offset: Callable[[_T], Optional[str]], |
| 391 | + ) -> None: |
| 392 | + self._current_stream = current_stream |
| 393 | + self._stream_creator = stream_creator |
| 394 | + self._get_offset = get_offset |
| 395 | + self._last_offset: Optional[str] = None |
| 396 | + self._iterator = self.__stream__() |
| 397 | + |
| 398 | + @property |
| 399 | + def response(self) -> httpx.Response: |
| 400 | + return self._current_stream.response |
| 401 | + |
| 402 | + def __next__(self) -> _T: |
| 403 | + return self._iterator.__next__() |
| 404 | + |
| 405 | + def __iter__(self) -> Iterator[_T]: |
| 406 | + for item in self._iterator: |
| 407 | + yield item |
| 408 | + |
| 409 | + def __enter__(self) -> "ReconnectingStream[_T]": |
| 410 | + return self |
| 411 | + |
| 412 | + def __exit__( |
| 413 | + self, |
| 414 | + exc_type: type[BaseException] | None, |
| 415 | + exc: BaseException | None, |
| 416 | + exc_tb: TracebackType | None, |
| 417 | + ) -> None: |
| 418 | + self.close() |
| 419 | + |
| 420 | + def close(self) -> None: |
| 421 | + self._current_stream.close() |
| 422 | + |
| 423 | + def __stream__(self) -> Iterator[_T]: |
| 424 | + while True: |
| 425 | + try: |
| 426 | + for item in self._current_stream: |
| 427 | + offset = self._get_offset(item) |
| 428 | + if offset is not None: |
| 429 | + self._last_offset = offset |
| 430 | + yield item |
| 431 | + return |
| 432 | + except Exception as e: |
| 433 | + # Reconnect on timeouts |
| 434 | + should_reconnect = False |
| 435 | + if isinstance(e, APITimeoutError): |
| 436 | + should_reconnect = True |
| 437 | + elif isinstance(e, APIStatusError) and getattr(e, "status_code", None) == 408: |
| 438 | + should_reconnect = True |
| 439 | + elif isinstance(e, httpx.TimeoutException): |
| 440 | + should_reconnect = True |
| 441 | + |
| 442 | + if should_reconnect: |
| 443 | + # Close existing response before reconnecting |
| 444 | + try: |
| 445 | + self._current_stream.close() |
| 446 | + except Exception: |
| 447 | + pass |
| 448 | + self._current_stream = self._stream_creator(self._last_offset) |
| 449 | + continue |
| 450 | + raise |
| 451 | + |
| 452 | + |
| 453 | +class AsyncReconnectingStream(Generic[_T]): |
| 454 | + """Async variant of ReconnectingStream supporting auto-reconnect on timeouts.""" |
| 455 | + |
| 456 | + def __init__( |
| 457 | + self, |
| 458 | + *, |
| 459 | + current_stream: AsyncStream[_T], |
| 460 | + stream_creator: Callable[[Optional[str]], Awaitable[AsyncStream[_T]]], |
| 461 | + get_offset: Callable[[_T], Optional[str]], |
| 462 | + ) -> None: |
| 463 | + self._current_stream = current_stream |
| 464 | + self._stream_creator = stream_creator |
| 465 | + self._get_offset = get_offset |
| 466 | + self._last_offset: Optional[str] = None |
| 467 | + self._iterator = self.__stream__() |
| 468 | + |
| 469 | + @property |
| 470 | + def response(self) -> httpx.Response: |
| 471 | + return self._current_stream.response |
| 472 | + |
| 473 | + async def __anext__(self) -> _T: |
| 474 | + return await self._iterator.__anext__() |
| 475 | + |
| 476 | + async def __aiter__(self) -> AsyncIterator[_T]: |
| 477 | + async for item in self._iterator: |
| 478 | + yield item |
| 479 | + |
| 480 | + async def __aenter__(self) -> "AsyncReconnectingStream[_T]": |
| 481 | + return self |
| 482 | + |
| 483 | + async def __aexit__( |
| 484 | + self, |
| 485 | + exc_type: type[BaseException] | None, |
| 486 | + exc: BaseException | None, |
| 487 | + exc_tb: TracebackType | None, |
| 488 | + ) -> None: |
| 489 | + await self.close() |
| 490 | + |
| 491 | + async def close(self) -> None: |
| 492 | + await self._current_stream.close() |
| 493 | + |
| 494 | + async def __stream__(self) -> AsyncIterator[_T]: |
| 495 | + while True: |
| 496 | + try: |
| 497 | + async for item in self._current_stream: |
| 498 | + offset = self._get_offset(item) |
| 499 | + if offset is not None: |
| 500 | + self._last_offset = offset |
| 501 | + yield item |
| 502 | + return |
| 503 | + except Exception as e: |
| 504 | + # Reconnect on timeouts |
| 505 | + should_reconnect = False |
| 506 | + if isinstance(e, APITimeoutError): |
| 507 | + should_reconnect = True |
| 508 | + elif isinstance(e, APIStatusError) and getattr(e, "status_code", None) == 408: |
| 509 | + should_reconnect = True |
| 510 | + elif isinstance(e, httpx.TimeoutException): |
| 511 | + should_reconnect = True |
| 512 | + |
| 513 | + if should_reconnect: |
| 514 | + try: |
| 515 | + await self._current_stream.close() |
| 516 | + except Exception: |
| 517 | + pass |
| 518 | + self._current_stream = await self._stream_creator(self._last_offset) |
| 519 | + continue |
| 520 | + raise |
0 commit comments