From a7fcbd81dbe6eded693d056fd921b1a8bb7aa36a Mon Sep 17 00:00:00 2001 From: Nelson PROIA Date: Thu, 12 Mar 2026 12:36:08 +0100 Subject: [PATCH] feat: adapt non-generated code for SSE overload support (stream param on complete methods) Update custom code regions, examples, and tests to use the new `complete(stream=True)` pattern alongside the existing dedicated `stream()` methods. Custom code regions (preserved by Speakeasy across regenerations): - chat.py: parse/parse_async now call complete(stream=False) with assert isinstance; parse_stream/parse_stream_async now call complete(stream=True) instead of stream() - conversations.py: run_async uses start_async(stream=False/True) and append_async(stream=False/True) instead of separate stream methods Examples (not generated): - Chat and FIM streaming examples updated to use the complete(stream=True) pattern; transcription streaming examples use complete(accept_header_override=CompleteAcceptEnum.TEXT_EVENT_STREAM), since transcriptions expose the SSE overload via the accept header rather than a stream param Tests (not generated): - Integration tests: use complete(stream=True) instead of stream() - Parity tests: add accept_header_override param, add stream/stream_async to known public methods, remove redundant stream-only test methods Note: Speakeasy-generated code is NOT included in this commit. The SDK was locally regenerated with specs that add text/event-stream as an alternative response on non-streaming operations (SSE overload pattern). Once the specs are published, the GitHub Actions workflow will regenerate the SDK code automatically. 
--- examples/mistral/audio/chat_streaming.py | 3 +- .../audio/transcription_segments_stream.py | 4 +- .../audio/transcription_stream_async.py | 4 +- .../mistral/chat/async_chat_with_streaming.py | 3 +- examples/mistral/chat/chat_with_streaming.py | 3 +- .../mistral/chat/chatbot_with_streaming.py | 4 +- .../mistral/chat/completion_with_streaming.py | 3 +- .../structured_outputs_with_json_schema.py | 3 +- src/mistralai/client/chat.py | 80 ++++-- src/mistralai/client/conversations.py | 227 +++++++++++++----- tests/test_azure_integration.py | 15 +- tests/test_azure_v2_parity.py | 48 +--- tests/test_gcp_integration.py | 21 +- tests/test_gcp_v2_parity.py | 102 +------- 14 files changed, 285 insertions(+), 235 deletions(-) diff --git a/examples/mistral/audio/chat_streaming.py b/examples/mistral/audio/chat_streaming.py index b418ef57..d9cbbf5e 100755 --- a/examples/mistral/audio/chat_streaming.py +++ b/examples/mistral/audio/chat_streaming.py @@ -19,7 +19,8 @@ def main(): print(f"Uploaded audio file, id={file.id}") signed_url = client.files.get_signed_url(file_id=file.id) try: - chat_response = client.chat.stream( + chat_response = client.chat.complete( + stream=True, model=model, messages=[ UserMessage( diff --git a/examples/mistral/audio/transcription_segments_stream.py b/examples/mistral/audio/transcription_segments_stream.py index 32edf951..2195857e 100644 --- a/examples/mistral/audio/transcription_segments_stream.py +++ b/examples/mistral/audio/transcription_segments_stream.py @@ -3,6 +3,7 @@ import os from mistralai.client import Mistral +from mistralai.client.transcriptions import CompleteAcceptEnum def main(): @@ -10,10 +11,11 @@ def main(): model = "voxtral-mini-latest" client = Mistral(api_key=api_key) - response = client.audio.transcriptions.stream( + response = client.audio.transcriptions.complete( model=model, file_url="https://docs.mistral.ai/audio/bcn_weather.mp3", timestamp_granularities=["segment"], + 
accept_header_override=CompleteAcceptEnum.TEXT_EVENT_STREAM, ) for chunk in response: print(chunk) diff --git a/examples/mistral/audio/transcription_stream_async.py b/examples/mistral/audio/transcription_stream_async.py index 3055f3de..12582cd5 100644 --- a/examples/mistral/audio/transcription_stream_async.py +++ b/examples/mistral/audio/transcription_stream_async.py @@ -4,6 +4,7 @@ from mistralai.client import Mistral from mistralai.client.models import File +from mistralai.client.transcriptions import CompleteAcceptEnum async def main(): @@ -12,9 +13,10 @@ async def main(): client = Mistral(api_key=api_key) with open("examples/fixtures/bcn_weather.mp3", "rb") as f: - response = await client.audio.transcriptions.stream_async( + response = await client.audio.transcriptions.complete_async( model=model, file=File(content=f, file_name=f.name), + accept_header_override=CompleteAcceptEnum.TEXT_EVENT_STREAM, ) async for chunk in response: print(chunk.event, chunk.data) diff --git a/examples/mistral/chat/async_chat_with_streaming.py b/examples/mistral/chat/async_chat_with_streaming.py index 1642ea41..1a3f1471 100755 --- a/examples/mistral/chat/async_chat_with_streaming.py +++ b/examples/mistral/chat/async_chat_with_streaming.py @@ -14,8 +14,9 @@ async def main(): client = Mistral(api_key=api_key) print("Chat response:") - response = await client.chat.stream_async( + response = await client.chat.complete_async( model=model, + stream=True, messages=[ UserMessage(content="What is the best French cheese?give the best 50") ], diff --git a/examples/mistral/chat/chat_with_streaming.py b/examples/mistral/chat/chat_with_streaming.py index 94a3e29c..ed9f0238 100755 --- a/examples/mistral/chat/chat_with_streaming.py +++ b/examples/mistral/chat/chat_with_streaming.py @@ -12,8 +12,9 @@ def main(): client = Mistral(api_key=api_key) - for chunk in client.chat.stream( + for chunk in client.chat.complete( model=model, + stream=True, messages=[UserMessage(content="What is the best French 
cheese?")], ): print(chunk.data.choices[0].delta.content, end="") diff --git a/examples/mistral/chat/chatbot_with_streaming.py b/examples/mistral/chat/chatbot_with_streaming.py index eae79dcf..f1eec3cd 100755 --- a/examples/mistral/chat/chatbot_with_streaming.py +++ b/examples/mistral/chat/chatbot_with_streaming.py @@ -150,8 +150,8 @@ def run_inference(self, content): f"Running inference with model: {self.model}, temperature: {self.temperature}" ) logger.debug(f"Sending messages: {self.messages}") - for chunk in self.client.chat.stream( - model=self.model, temperature=self.temperature, messages=self.messages + for chunk in self.client.chat.complete( + model=self.model, temperature=self.temperature, stream=True, messages=self.messages ): response = chunk.data.choices[0].delta.content if response is not None: diff --git a/examples/mistral/chat/completion_with_streaming.py b/examples/mistral/chat/completion_with_streaming.py index 399e8638..37a74aa1 100644 --- a/examples/mistral/chat/completion_with_streaming.py +++ b/examples/mistral/chat/completion_with_streaming.py @@ -15,8 +15,9 @@ async def main(): suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))" print(prompt) - for chunk in client.fim.stream( + for chunk in client.fim.complete( model="codestral-latest", + stream=True, prompt=prompt, suffix=suffix, ): diff --git a/examples/mistral/chat/structured_outputs_with_json_schema.py b/examples/mistral/chat/structured_outputs_with_json_schema.py index 2f99f747..1bec8afb 100644 --- a/examples/mistral/chat/structured_outputs_with_json_schema.py +++ b/examples/mistral/chat/structured_outputs_with_json_schema.py @@ -62,8 +62,9 @@ def main(): print(chat_response.choices[0].message.content) # Or with the streaming API - with client.chat.stream( + with client.chat.complete( model="mistral-large-latest", + stream=True, messages=[ { "role": "system", diff --git a/src/mistralai/client/chat.py b/src/mistralai/client/chat.py index 13b9c01f..8e8c22ec 100644 --- 
a/src/mistralai/client/chat.py +++ b/src/mistralai/client/chat.py @@ -2,6 +2,7 @@ # @generated-id: 7eba0f088d47 from .basesdk import BaseSDK +from enum import Enum from mistralai.client import errors, models, utils from mistralai.client._hooks import HookContext from mistralai.client.types import OptionalNullable, UNSET @@ -23,6 +24,11 @@ # endregion imports +class CompleteAcceptEnum(str, Enum): + APPLICATION_JSON = "application/json" + TEXT_EVENT_STREAM = "text/event-stream" + + class Chat(BaseSDK): r"""Chat Completion API.""" @@ -41,7 +47,9 @@ def parse( # Convert the input Pydantic Model to a strict JSON ready to be passed to chat.complete json_response_format = response_format_from_pydantic_model(response_format) # Run the inference + kwargs["stream"] = False response = self.complete(**kwargs, response_format=json_response_format) + assert isinstance(response, models.ChatCompletionResponse) # Parse response back to the input pydantic model parsed_response = convert_to_parsed_chat_completion_response( response, response_format @@ -58,9 +66,11 @@ async def parse_async( :return: The parsed response """ json_response_format = response_format_from_pydantic_model(response_format) + kwargs["stream"] = False response = await self.complete_async( # pylint: disable=E1125 **kwargs, response_format=json_response_format ) + assert isinstance(response, models.ChatCompletionResponse) parsed_response = convert_to_parsed_chat_completion_response( response, response_format ) @@ -73,11 +83,13 @@ def parse_stream( Parse the response using the provided response format. For now the response will be in JSON format not in the input Pydantic model. 
:param Type[CustomPydanticModel] response_format: The Pydantic model to parse the response into - :param Any **kwargs Additional keyword arguments to pass to the .stream method + :param Any **kwargs Additional keyword arguments to pass to the .complete method :return: The JSON parsed response """ json_response_format = response_format_from_pydantic_model(response_format) - response = self.stream(**kwargs, response_format=json_response_format) + kwargs["stream"] = True + response = self.complete(**kwargs, response_format=json_response_format) + assert isinstance(response, eventstreaming.EventStream) return response async def parse_stream_async( @@ -87,13 +99,15 @@ async def parse_stream_async( Asynchronously parse the response using the provided response format. For now the response will be in JSON format not in the input Pydantic model. :param Type[CustomPydanticModel] response_format: The Pydantic model to parse the response into - :param Any **kwargs Additional keyword arguments to pass to the .stream method + :param Any **kwargs Additional keyword arguments to pass to the .complete method :return: The JSON parsed response """ json_response_format = response_format_from_pydantic_model(response_format) - response = await self.stream_async( # pylint: disable=E1125 + kwargs["stream"] = True + response = await self.complete_async( # pylint: disable=E1125 **kwargs, response_format=json_response_format ) + assert isinstance(response, eventstreaming.EventStreamAsync) return response # endregion sdk-class-body @@ -142,8 +156,9 @@ def complete( retries: OptionalNullable[utils.RetryConfig] = UNSET, server_url: Optional[str] = None, timeout_ms: Optional[int] = None, + accept_header_override: Optional[CompleteAcceptEnum] = None, http_headers: Optional[Mapping[str, str]] = None, - ) -> models.ChatCompletionResponse: + ) -> models.ChatCompletionV1ChatCompletionsPostResponse: r"""Chat Completion :param model: ID of the model to use. 
You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions. @@ -168,6 +183,7 @@ def complete( :param retries: Override the default retry configuration for this method :param server_url: Override the default server URL for this method :param timeout_ms: Override the default request timeout configuration for this method in milliseconds + :param accept_header_override: Override the default accept header for this method :param http_headers: Additional headers to set or replace on requests. """ base_url = None @@ -220,7 +236,9 @@ def complete( request_has_path_params=False, request_has_query_params=True, user_agent_header="user-agent", - accept_header_value="application/json", + accept_header_value=accept_header_override.value + if accept_header_override is not None + else "application/json;q=1, text/event-stream;q=0", http_headers=http_headers, security=self.sdk_configuration.security, get_serialized_body=lambda: utils.serialize_request_body( @@ -250,17 +268,29 @@ def complete( ), request=req, error_status_codes=["422", "4XX", "5XX"], + stream=True, retry_config=retry_config, ) response_data: Any = None if utils.match_response(http_res, "200", "application/json"): - return unmarshal_json_response(models.ChatCompletionResponse, http_res) + http_res_text = utils.stream_to_text(http_res) + return unmarshal_json_response( + models.ChatCompletionResponse, http_res, http_res_text + ) + if utils.match_response(http_res, "200", "text/event-stream"): + return eventstreaming.EventStream( + http_res, + lambda raw: utils.unmarshal_json(raw, models.CompletionEvent), + sentinel="[DONE]", + client_ref=self, + ) if utils.match_response(http_res, "422", "application/json"): + http_res_text = utils.stream_to_text(http_res) response_data = unmarshal_json_response( - errors.HTTPValidationErrorData, http_res + errors.HTTPValidationErrorData, http_res, 
http_res_text ) - raise errors.HTTPValidationError(response_data, http_res) + raise errors.HTTPValidationError(response_data, http_res, http_res_text) if utils.match_response(http_res, "4XX", "*"): http_res_text = utils.stream_to_text(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) @@ -268,7 +298,8 @@ def complete( http_res_text = utils.stream_to_text(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) - raise errors.SDKError("Unexpected response received", http_res) + http_res_text = utils.stream_to_text(http_res) + raise errors.SDKError("Unexpected response received", http_res, http_res_text) async def complete_async( self, @@ -314,8 +345,9 @@ async def complete_async( retries: OptionalNullable[utils.RetryConfig] = UNSET, server_url: Optional[str] = None, timeout_ms: Optional[int] = None, + accept_header_override: Optional[CompleteAcceptEnum] = None, http_headers: Optional[Mapping[str, str]] = None, - ) -> models.ChatCompletionResponse: + ) -> models.ChatCompletionV1ChatCompletionsPostResponse: r"""Chat Completion :param model: ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions. @@ -340,6 +372,7 @@ async def complete_async( :param retries: Override the default retry configuration for this method :param server_url: Override the default server URL for this method :param timeout_ms: Override the default request timeout configuration for this method in milliseconds + :param accept_header_override: Override the default accept header for this method :param http_headers: Additional headers to set or replace on requests. 
""" base_url = None @@ -392,7 +425,9 @@ async def complete_async( request_has_path_params=False, request_has_query_params=True, user_agent_header="user-agent", - accept_header_value="application/json", + accept_header_value=accept_header_override.value + if accept_header_override is not None + else "application/json;q=1, text/event-stream;q=0", http_headers=http_headers, security=self.sdk_configuration.security, get_serialized_body=lambda: utils.serialize_request_body( @@ -422,17 +457,29 @@ async def complete_async( ), request=req, error_status_codes=["422", "4XX", "5XX"], + stream=True, retry_config=retry_config, ) response_data: Any = None if utils.match_response(http_res, "200", "application/json"): - return unmarshal_json_response(models.ChatCompletionResponse, http_res) + http_res_text = await utils.stream_to_text_async(http_res) + return unmarshal_json_response( + models.ChatCompletionResponse, http_res, http_res_text + ) + if utils.match_response(http_res, "200", "text/event-stream"): + return eventstreaming.EventStreamAsync( + http_res, + lambda raw: utils.unmarshal_json(raw, models.CompletionEvent), + sentinel="[DONE]", + client_ref=self, + ) if utils.match_response(http_res, "422", "application/json"): + http_res_text = await utils.stream_to_text_async(http_res) response_data = unmarshal_json_response( - errors.HTTPValidationErrorData, http_res + errors.HTTPValidationErrorData, http_res, http_res_text ) - raise errors.HTTPValidationError(response_data, http_res) + raise errors.HTTPValidationError(response_data, http_res, http_res_text) if utils.match_response(http_res, "4XX", "*"): http_res_text = await utils.stream_to_text_async(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) @@ -440,7 +487,8 @@ async def complete_async( http_res_text = await utils.stream_to_text_async(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) - raise errors.SDKError("Unexpected response received", http_res) + 
http_res_text = await utils.stream_to_text_async(http_res) + raise errors.SDKError("Unexpected response received", http_res, http_res_text) def stream( self, diff --git a/src/mistralai/client/conversations.py b/src/mistralai/client/conversations.py index a4af31f3..fc99c292 100644 --- a/src/mistralai/client/conversations.py +++ b/src/mistralai/client/conversations.py @@ -2,6 +2,7 @@ # @generated-id: 40692a878064 from .basesdk import BaseSDK +from enum import Enum from mistralai.client import errors, models, utils from mistralai.client._hooks import HookContext from mistralai.client.types import OptionalNullable, UNSET @@ -38,6 +39,21 @@ # endregion imports +class StartAcceptEnum(str, Enum): + APPLICATION_JSON = "application/json" + TEXT_EVENT_STREAM = "text/event-stream" + + +class AppendAcceptEnum(str, Enum): + APPLICATION_JSON = "application/json" + TEXT_EVENT_STREAM = "text/event-stream" + + +class RestartAcceptEnum(str, Enum): + APPLICATION_JSON = "application/json" + TEXT_EVENT_STREAM = "text/event-stream" + + class Conversations(BaseSDK): r"""(beta) Conversations API""" @@ -88,6 +104,7 @@ async def run_async( if run_ctx.conversation_id is None: res = await self.start_async( inputs=input_entries, + stream=False, http_headers=http_headers, name=name, description=description, @@ -96,6 +113,7 @@ async def run_async( timeout_ms=timeout_ms, **req, # type: ignore ) + assert isinstance(res, models.ConversationResponse) run_result.conversation_id = res.conversation_id run_ctx.conversation_id = res.conversation_id logger.info( # pylint: disable=logging-fstring-interpolation @@ -105,10 +123,12 @@ async def run_async( res = await self.append_async( conversation_id=run_ctx.conversation_id, inputs=input_entries, + stream=False, retries=retries, server_url=server_url, timeout_ms=timeout_ms, ) + assert isinstance(res, models.ConversationResponse) run_ctx.request_count += 1 run_result.output_entries.extend(res.outputs) fcalls = get_function_calls(res.outputs) @@ -167,8 +187,9 
@@ async def run_generator() -> ( int, list[ConversationEventsData] ] = defaultdict(list) if run_ctx.conversation_id is None: - res = await self.start_stream_async( + res = await self.start_async( inputs=current_entries, + stream=True, http_headers=http_headers, name=name, description=description, @@ -178,13 +199,15 @@ async def run_generator() -> ( **req, # type: ignore ) else: - res = await self.append_stream_async( + res = await self.append_async( conversation_id=run_ctx.conversation_id, inputs=current_entries, + stream=True, retries=retries, server_url=server_url, timeout_ms=timeout_ms, ) + assert isinstance(res, eventstreaming.EventStreamAsync) async for event in res: if ( isinstance(event.data, ResponseStartedEvent) @@ -235,12 +258,12 @@ def start( models.ConversationRequestHandoffExecution ] = UNSET, instructions: OptionalNullable[str] = UNSET, - tools: Optional[ + tools: OptionalNullable[ Union[ List[models.ConversationRequestTool], List[models.ConversationRequestToolTypedDict], ] - ] = None, + ] = UNSET, completion_args: OptionalNullable[ Union[models.CompletionArgs, models.CompletionArgsTypedDict] ] = UNSET, @@ -261,8 +284,9 @@ def start( retries: OptionalNullable[utils.RetryConfig] = UNSET, server_url: Optional[str] = None, timeout_ms: Optional[int] = None, + accept_header_override: Optional[StartAcceptEnum] = None, http_headers: Optional[Mapping[str, str]] = None, - ) -> models.ConversationResponse: + ) -> models.AgentsAPIV1ConversationsStartResponse: r"""Create a conversation and append entries to it. Create a new conversation, using a base model or an agent and append entries. Completion and tool executions are run and the response is appended to the conversation.Use the returned conversation_id to continue the conversation. @@ -272,7 +296,7 @@ def start( :param store: :param handoff_execution: :param instructions: - :param tools: List of tools which are available to the model during the conversation. 
+ :param tools: :param completion_args: :param guardrails: :param name: @@ -284,6 +308,7 @@ def start( :param retries: Override the default retry configuration for this method :param server_url: Override the default server URL for this method :param timeout_ms: Override the default request timeout configuration for this method in milliseconds + :param accept_header_override: Override the default accept header for this method :param http_headers: Additional headers to set or replace on requests. """ base_url = None @@ -303,7 +328,7 @@ def start( handoff_execution=handoff_execution, instructions=instructions, tools=utils.get_pydantic_model( - tools, Optional[List[models.ConversationRequestTool]] + tools, OptionalNullable[List[models.ConversationRequestTool]] ), completion_args=utils.get_pydantic_model( completion_args, OptionalNullable[models.CompletionArgs] @@ -329,7 +354,9 @@ def start( request_has_path_params=False, request_has_query_params=True, user_agent_header="user-agent", - accept_header_value="application/json", + accept_header_value=accept_header_override.value + if accept_header_override is not None + else "application/json;q=1, text/event-stream;q=0", http_headers=http_headers, security=self.sdk_configuration.security, get_serialized_body=lambda: utils.serialize_request_body( @@ -359,17 +386,28 @@ def start( ), request=req, error_status_codes=["422", "4XX", "5XX"], + stream=True, retry_config=retry_config, ) response_data: Any = None if utils.match_response(http_res, "200", "application/json"): - return unmarshal_json_response(models.ConversationResponse, http_res) + http_res_text = utils.stream_to_text(http_res) + return unmarshal_json_response( + models.ConversationResponse, http_res, http_res_text + ) + if utils.match_response(http_res, "200", "text/event-stream"): + return eventstreaming.EventStream( + http_res, + lambda raw: utils.unmarshal_json(raw, models.ConversationEvents), + client_ref=self, + ) if utils.match_response(http_res, "422", 
"application/json"): + http_res_text = utils.stream_to_text(http_res) response_data = unmarshal_json_response( - errors.HTTPValidationErrorData, http_res + errors.HTTPValidationErrorData, http_res, http_res_text ) - raise errors.HTTPValidationError(response_data, http_res) + raise errors.HTTPValidationError(response_data, http_res, http_res_text) if utils.match_response(http_res, "4XX", "*"): http_res_text = utils.stream_to_text(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) @@ -377,7 +415,8 @@ def start( http_res_text = utils.stream_to_text(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) - raise errors.SDKError("Unexpected response received", http_res) + http_res_text = utils.stream_to_text(http_res) + raise errors.SDKError("Unexpected response received", http_res, http_res_text) async def start_async( self, @@ -389,12 +428,12 @@ async def start_async( models.ConversationRequestHandoffExecution ] = UNSET, instructions: OptionalNullable[str] = UNSET, - tools: Optional[ + tools: OptionalNullable[ Union[ List[models.ConversationRequestTool], List[models.ConversationRequestToolTypedDict], ] - ] = None, + ] = UNSET, completion_args: OptionalNullable[ Union[models.CompletionArgs, models.CompletionArgsTypedDict] ] = UNSET, @@ -415,8 +454,9 @@ async def start_async( retries: OptionalNullable[utils.RetryConfig] = UNSET, server_url: Optional[str] = None, timeout_ms: Optional[int] = None, + accept_header_override: Optional[StartAcceptEnum] = None, http_headers: Optional[Mapping[str, str]] = None, - ) -> models.ConversationResponse: + ) -> models.AgentsAPIV1ConversationsStartResponse: r"""Create a conversation and append entries to it. Create a new conversation, using a base model or an agent and append entries. Completion and tool executions are run and the response is appended to the conversation.Use the returned conversation_id to continue the conversation. 
@@ -426,7 +466,7 @@ async def start_async( :param store: :param handoff_execution: :param instructions: - :param tools: List of tools which are available to the model during the conversation. + :param tools: :param completion_args: :param guardrails: :param name: @@ -438,6 +478,7 @@ async def start_async( :param retries: Override the default retry configuration for this method :param server_url: Override the default server URL for this method :param timeout_ms: Override the default request timeout configuration for this method in milliseconds + :param accept_header_override: Override the default accept header for this method :param http_headers: Additional headers to set or replace on requests. """ base_url = None @@ -457,7 +498,7 @@ async def start_async( handoff_execution=handoff_execution, instructions=instructions, tools=utils.get_pydantic_model( - tools, Optional[List[models.ConversationRequestTool]] + tools, OptionalNullable[List[models.ConversationRequestTool]] ), completion_args=utils.get_pydantic_model( completion_args, OptionalNullable[models.CompletionArgs] @@ -483,7 +524,9 @@ async def start_async( request_has_path_params=False, request_has_query_params=True, user_agent_header="user-agent", - accept_header_value="application/json", + accept_header_value=accept_header_override.value + if accept_header_override is not None + else "application/json;q=1, text/event-stream;q=0", http_headers=http_headers, security=self.sdk_configuration.security, get_serialized_body=lambda: utils.serialize_request_body( @@ -513,17 +556,28 @@ async def start_async( ), request=req, error_status_codes=["422", "4XX", "5XX"], + stream=True, retry_config=retry_config, ) response_data: Any = None if utils.match_response(http_res, "200", "application/json"): - return unmarshal_json_response(models.ConversationResponse, http_res) + http_res_text = await utils.stream_to_text_async(http_res) + return unmarshal_json_response( + models.ConversationResponse, http_res, http_res_text + ) + 
if utils.match_response(http_res, "200", "text/event-stream"): + return eventstreaming.EventStreamAsync( + http_res, + lambda raw: utils.unmarshal_json(raw, models.ConversationEvents), + client_ref=self, + ) if utils.match_response(http_res, "422", "application/json"): + http_res_text = await utils.stream_to_text_async(http_res) response_data = unmarshal_json_response( - errors.HTTPValidationErrorData, http_res + errors.HTTPValidationErrorData, http_res, http_res_text ) - raise errors.HTTPValidationError(response_data, http_res) + raise errors.HTTPValidationError(response_data, http_res, http_res_text) if utils.match_response(http_res, "4XX", "*"): http_res_text = await utils.stream_to_text_async(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) @@ -531,7 +585,8 @@ async def start_async( http_res_text = await utils.stream_to_text_async(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) - raise errors.SDKError("Unexpected response received", http_res) + http_res_text = await utils.stream_to_text_async(http_res) + raise errors.SDKError("Unexpected response received", http_res, http_res_text) def list( self, @@ -1113,8 +1168,9 @@ def append( retries: OptionalNullable[utils.RetryConfig] = UNSET, server_url: Optional[str] = None, timeout_ms: Optional[int] = None, + accept_header_override: Optional[AppendAcceptEnum] = None, http_headers: Optional[Mapping[str, str]] = None, - ) -> models.ConversationResponse: + ) -> models.AgentsAPIV1ConversationsAppendResponse: r"""Append new entries to an existing conversation. Run completion on the history of the conversation and the user entries. Return the new created entries. 
@@ -1129,6 +1185,7 @@ def append( :param retries: Override the default retry configuration for this method :param server_url: Override the default server URL for this method :param timeout_ms: Override the default request timeout configuration for this method in milliseconds + :param accept_header_override: Override the default accept header for this method :param http_headers: Additional headers to set or replace on requests. """ base_url = None @@ -1170,7 +1227,9 @@ def append( request_has_path_params=True, request_has_query_params=True, user_agent_header="user-agent", - accept_header_value="application/json", + accept_header_value=accept_header_override.value + if accept_header_override is not None + else "application/json;q=1, text/event-stream;q=0", http_headers=http_headers, security=self.sdk_configuration.security, get_serialized_body=lambda: utils.serialize_request_body( @@ -1204,17 +1263,28 @@ def append( ), request=req, error_status_codes=["422", "4XX", "5XX"], + stream=True, retry_config=retry_config, ) response_data: Any = None if utils.match_response(http_res, "200", "application/json"): - return unmarshal_json_response(models.ConversationResponse, http_res) + http_res_text = utils.stream_to_text(http_res) + return unmarshal_json_response( + models.ConversationResponse, http_res, http_res_text + ) + if utils.match_response(http_res, "200", "text/event-stream"): + return eventstreaming.EventStream( + http_res, + lambda raw: utils.unmarshal_json(raw, models.ConversationEvents), + client_ref=self, + ) if utils.match_response(http_res, "422", "application/json"): + http_res_text = utils.stream_to_text(http_res) response_data = unmarshal_json_response( - errors.HTTPValidationErrorData, http_res + errors.HTTPValidationErrorData, http_res, http_res_text ) - raise errors.HTTPValidationError(response_data, http_res) + raise errors.HTTPValidationError(response_data, http_res, http_res_text) if utils.match_response(http_res, "4XX", "*"): http_res_text = 
utils.stream_to_text(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) @@ -1222,7 +1292,8 @@ def append( http_res_text = utils.stream_to_text(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) - raise errors.SDKError("Unexpected response received", http_res) + http_res_text = utils.stream_to_text(http_res) + raise errors.SDKError("Unexpected response received", http_res, http_res_text) async def append_async( self, @@ -1248,8 +1319,9 @@ async def append_async( retries: OptionalNullable[utils.RetryConfig] = UNSET, server_url: Optional[str] = None, timeout_ms: Optional[int] = None, + accept_header_override: Optional[AppendAcceptEnum] = None, http_headers: Optional[Mapping[str, str]] = None, - ) -> models.ConversationResponse: + ) -> models.AgentsAPIV1ConversationsAppendResponse: r"""Append new entries to an existing conversation. Run completion on the history of the conversation and the user entries. Return the new created entries. @@ -1264,6 +1336,7 @@ async def append_async( :param retries: Override the default retry configuration for this method :param server_url: Override the default server URL for this method :param timeout_ms: Override the default request timeout configuration for this method in milliseconds + :param accept_header_override: Override the default accept header for this method :param http_headers: Additional headers to set or replace on requests. 
""" base_url = None @@ -1305,7 +1378,9 @@ async def append_async( request_has_path_params=True, request_has_query_params=True, user_agent_header="user-agent", - accept_header_value="application/json", + accept_header_value=accept_header_override.value + if accept_header_override is not None + else "application/json;q=1, text/event-stream;q=0", http_headers=http_headers, security=self.sdk_configuration.security, get_serialized_body=lambda: utils.serialize_request_body( @@ -1339,17 +1414,28 @@ async def append_async( ), request=req, error_status_codes=["422", "4XX", "5XX"], + stream=True, retry_config=retry_config, ) response_data: Any = None if utils.match_response(http_res, "200", "application/json"): - return unmarshal_json_response(models.ConversationResponse, http_res) + http_res_text = await utils.stream_to_text_async(http_res) + return unmarshal_json_response( + models.ConversationResponse, http_res, http_res_text + ) + if utils.match_response(http_res, "200", "text/event-stream"): + return eventstreaming.EventStreamAsync( + http_res, + lambda raw: utils.unmarshal_json(raw, models.ConversationEvents), + client_ref=self, + ) if utils.match_response(http_res, "422", "application/json"): + http_res_text = await utils.stream_to_text_async(http_res) response_data = unmarshal_json_response( - errors.HTTPValidationErrorData, http_res + errors.HTTPValidationErrorData, http_res, http_res_text ) - raise errors.HTTPValidationError(response_data, http_res) + raise errors.HTTPValidationError(response_data, http_res, http_res_text) if utils.match_response(http_res, "4XX", "*"): http_res_text = await utils.stream_to_text_async(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) @@ -1357,7 +1443,8 @@ async def append_async( http_res_text = await utils.stream_to_text_async(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) - raise errors.SDKError("Unexpected response received", http_res) + http_res_text = await 
utils.stream_to_text_async(http_res) + raise errors.SDKError("Unexpected response received", http_res, http_res_text) def get_history( self, @@ -1748,8 +1835,9 @@ def restart( retries: OptionalNullable[utils.RetryConfig] = UNSET, server_url: Optional[str] = None, timeout_ms: Optional[int] = None, + accept_header_override: Optional[RestartAcceptEnum] = None, http_headers: Optional[Mapping[str, str]] = None, - ) -> models.ConversationResponse: + ) -> models.AgentsAPIV1ConversationsRestartResponse: r"""Restart a conversation starting from a given entry. Given a conversation_id and an id, recreate a conversation from this point and run completion. A new conversation is returned with the new entries returned. @@ -1767,6 +1855,7 @@ def restart( :param retries: Override the default retry configuration for this method :param server_url: Override the default server URL for this method :param timeout_ms: Override the default request timeout configuration for this method in milliseconds + :param accept_header_override: Override the default accept header for this method :param http_headers: Additional headers to set or replace on requests. 
""" base_url = None @@ -1810,7 +1899,9 @@ def restart( request_has_path_params=True, request_has_query_params=True, user_agent_header="user-agent", - accept_header_value="application/json", + accept_header_value=accept_header_override.value + if accept_header_override is not None + else "application/json;q=1, text/event-stream;q=0", http_headers=http_headers, security=self.sdk_configuration.security, get_serialized_body=lambda: utils.serialize_request_body( @@ -1844,17 +1935,28 @@ def restart( ), request=req, error_status_codes=["422", "4XX", "5XX"], + stream=True, retry_config=retry_config, ) response_data: Any = None if utils.match_response(http_res, "200", "application/json"): - return unmarshal_json_response(models.ConversationResponse, http_res) + http_res_text = utils.stream_to_text(http_res) + return unmarshal_json_response( + models.ConversationResponse, http_res, http_res_text + ) + if utils.match_response(http_res, "200", "text/event-stream"): + return eventstreaming.EventStream( + http_res, + lambda raw: utils.unmarshal_json(raw, models.ConversationEvents), + client_ref=self, + ) if utils.match_response(http_res, "422", "application/json"): + http_res_text = utils.stream_to_text(http_res) response_data = unmarshal_json_response( - errors.HTTPValidationErrorData, http_res + errors.HTTPValidationErrorData, http_res, http_res_text ) - raise errors.HTTPValidationError(response_data, http_res) + raise errors.HTTPValidationError(response_data, http_res, http_res_text) if utils.match_response(http_res, "4XX", "*"): http_res_text = utils.stream_to_text(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) @@ -1862,7 +1964,8 @@ def restart( http_res_text = utils.stream_to_text(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) - raise errors.SDKError("Unexpected response received", http_res) + http_res_text = utils.stream_to_text(http_res) + raise errors.SDKError("Unexpected response received", http_res, 
http_res_text) async def restart_async( self, @@ -1893,8 +1996,9 @@ async def restart_async( retries: OptionalNullable[utils.RetryConfig] = UNSET, server_url: Optional[str] = None, timeout_ms: Optional[int] = None, + accept_header_override: Optional[RestartAcceptEnum] = None, http_headers: Optional[Mapping[str, str]] = None, - ) -> models.ConversationResponse: + ) -> models.AgentsAPIV1ConversationsRestartResponse: r"""Restart a conversation starting from a given entry. Given a conversation_id and an id, recreate a conversation from this point and run completion. A new conversation is returned with the new entries returned. @@ -1912,6 +2016,7 @@ async def restart_async( :param retries: Override the default retry configuration for this method :param server_url: Override the default server URL for this method :param timeout_ms: Override the default request timeout configuration for this method in milliseconds + :param accept_header_override: Override the default accept header for this method :param http_headers: Additional headers to set or replace on requests. 
""" base_url = None @@ -1955,7 +2060,9 @@ async def restart_async( request_has_path_params=True, request_has_query_params=True, user_agent_header="user-agent", - accept_header_value="application/json", + accept_header_value=accept_header_override.value + if accept_header_override is not None + else "application/json;q=1, text/event-stream;q=0", http_headers=http_headers, security=self.sdk_configuration.security, get_serialized_body=lambda: utils.serialize_request_body( @@ -1989,17 +2096,28 @@ async def restart_async( ), request=req, error_status_codes=["422", "4XX", "5XX"], + stream=True, retry_config=retry_config, ) response_data: Any = None if utils.match_response(http_res, "200", "application/json"): - return unmarshal_json_response(models.ConversationResponse, http_res) + http_res_text = await utils.stream_to_text_async(http_res) + return unmarshal_json_response( + models.ConversationResponse, http_res, http_res_text + ) + if utils.match_response(http_res, "200", "text/event-stream"): + return eventstreaming.EventStreamAsync( + http_res, + lambda raw: utils.unmarshal_json(raw, models.ConversationEvents), + client_ref=self, + ) if utils.match_response(http_res, "422", "application/json"): + http_res_text = await utils.stream_to_text_async(http_res) response_data = unmarshal_json_response( - errors.HTTPValidationErrorData, http_res + errors.HTTPValidationErrorData, http_res, http_res_text ) - raise errors.HTTPValidationError(response_data, http_res) + raise errors.HTTPValidationError(response_data, http_res, http_res_text) if utils.match_response(http_res, "4XX", "*"): http_res_text = await utils.stream_to_text_async(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) @@ -2007,7 +2125,8 @@ async def restart_async( http_res_text = await utils.stream_to_text_async(http_res) raise errors.SDKError("API error occurred", http_res, http_res_text) - raise errors.SDKError("Unexpected response received", http_res) + http_res_text = await 
utils.stream_to_text_async(http_res) + raise errors.SDKError("Unexpected response received", http_res, http_res_text) def start_stream( self, @@ -2019,12 +2138,12 @@ def start_stream( models.ConversationStreamRequestHandoffExecution ] = UNSET, instructions: OptionalNullable[str] = UNSET, - tools: Optional[ + tools: OptionalNullable[ Union[ List[models.ConversationStreamRequestTool], List[models.ConversationStreamRequestToolTypedDict], ] - ] = None, + ] = UNSET, completion_args: OptionalNullable[ Union[models.CompletionArgs, models.CompletionArgsTypedDict] ] = UNSET, @@ -2056,7 +2175,7 @@ def start_stream( :param store: :param handoff_execution: :param instructions: - :param tools: List of tools which are available to the model during the conversation. + :param tools: :param completion_args: :param guardrails: :param name: @@ -2087,7 +2206,7 @@ def start_stream( handoff_execution=handoff_execution, instructions=instructions, tools=utils.get_pydantic_model( - tools, Optional[List[models.ConversationStreamRequestTool]] + tools, OptionalNullable[List[models.ConversationStreamRequestTool]] ), completion_args=utils.get_pydantic_model( completion_args, OptionalNullable[models.CompletionArgs] @@ -2180,12 +2299,12 @@ async def start_stream_async( models.ConversationStreamRequestHandoffExecution ] = UNSET, instructions: OptionalNullable[str] = UNSET, - tools: Optional[ + tools: OptionalNullable[ Union[ List[models.ConversationStreamRequestTool], List[models.ConversationStreamRequestToolTypedDict], ] - ] = None, + ] = UNSET, completion_args: OptionalNullable[ Union[models.CompletionArgs, models.CompletionArgsTypedDict] ] = UNSET, @@ -2217,7 +2336,7 @@ async def start_stream_async( :param store: :param handoff_execution: :param instructions: - :param tools: List of tools which are available to the model during the conversation. 
+ :param tools: :param completion_args: :param guardrails: :param name: @@ -2248,7 +2367,7 @@ async def start_stream_async( handoff_execution=handoff_execution, instructions=instructions, tools=utils.get_pydantic_model( - tools, Optional[List[models.ConversationStreamRequestTool]] + tools, OptionalNullable[List[models.ConversationStreamRequestTool]] ), completion_args=utils.get_pydantic_model( completion_args, OptionalNullable[models.CompletionArgs] diff --git a/tests/test_azure_integration.py b/tests/test_azure_integration.py index ac4e38a1..5b4d1cfc 100644 --- a/tests/test_azure_integration.py +++ b/tests/test_azure_integration.py @@ -239,7 +239,8 @@ class TestAzureChatStream: def test_basic_stream(self, azure_client): """Test streaming returns chunks with content.""" - stream = azure_client.chat.stream( + stream = azure_client.chat.complete( + stream=True, model=AZURE_MODEL, messages=[ {"role": "user", "content": "Say 'hello' and nothing else."} @@ -258,7 +259,8 @@ def test_basic_stream(self, azure_client): def test_stream_with_max_tokens(self, azure_client): """Test streaming respects max_tokens truncation.""" - stream = azure_client.chat.stream( + stream = azure_client.chat.complete( + stream=True, model=AZURE_MODEL, messages=[ {"role": "user", "content": "Count from 1 to 100."} @@ -280,7 +282,8 @@ def test_stream_with_max_tokens(self, azure_client): def test_stream_finish_reason(self, azure_client): """Test that the last chunk has a finish_reason.""" - stream = azure_client.chat.stream( + stream = azure_client.chat.complete( + stream=True, model=AZURE_MODEL, messages=[ {"role": "user", "content": "Say 'hi'."} @@ -301,7 +304,8 @@ def test_stream_finish_reason(self, azure_client): def test_stream_tool_call(self, azure_client): """Test tool call via streaming, collecting tool_call delta chunks.""" - stream = azure_client.chat.stream( + stream = azure_client.chat.complete( + stream=True, model=AZURE_MODEL, messages=[ {"role": "user", "content": "What is the 
weather in Paris?"} @@ -377,7 +381,8 @@ class TestAzureChatStreamAsync: @pytest.mark.asyncio async def test_basic_stream_async(self, azure_client): """Test async streaming returns chunks with content.""" - stream = await azure_client.chat.stream_async( + stream = await azure_client.chat.complete_async( + stream=True, model=AZURE_MODEL, messages=[ {"role": "user", "content": "Say 'hello' and nothing else."} diff --git a/tests/test_azure_v2_parity.py b/tests/test_azure_v2_parity.py index 8cd89bf4..3da7d0f5 100644 --- a/tests/test_azure_v2_parity.py +++ b/tests/test_azure_v2_parity.py @@ -15,7 +15,7 @@ from mistralai.azure.client.types import UNSET AZURE_METHODS: dict[str, set[str]] = { - "chat": {"complete", "stream"}, + "chat": {"complete"}, "ocr": {"process"}, } @@ -69,13 +69,10 @@ def mark_tested(resource: str, method: str) -> None: ("retries", UNSET), ("server_url", None), ("timeout_ms", None), + ("accept_header_override", None), ("http_headers", None), ] -CHAT_STREAM_PARAMS = [ - (name, True if name == "stream" else default) - for name, default in CHAT_COMPLETE_PARAMS -] OCR_PROCESS_PARAMS = [ ("model", _EMPTY), @@ -133,14 +130,6 @@ def test_has_complete_async(self): assert hasattr(Chat, "complete_async") mark_tested("chat", "complete_async") - def test_has_stream(self): - assert hasattr(Chat, "stream") - mark_tested("chat", "stream") - - def test_has_stream_async(self): - assert hasattr(Chat, "stream_async") - mark_tested("chat", "stream_async") - # -- complete params -- @pytest.mark.parametrize("param_name,expected_default", CHAT_COMPLETE_PARAMS) def test_complete_has_param(self, param_name, expected_default): @@ -151,16 +140,6 @@ def test_complete_has_param(self, param_name, expected_default): f"Chat.complete param {param_name}: expected {expected_default!r}, got {actual!r}" ) - # -- stream params -- - @pytest.mark.parametrize("param_name,expected_default", CHAT_STREAM_PARAMS) - def test_stream_has_param(self, param_name, expected_default): - sig = 
inspect.signature(Chat.stream) - assert param_name in sig.parameters, f"Chat.stream missing param: {param_name}" - actual = sig.parameters[param_name].default - assert actual == expected_default, ( - f"Chat.stream param {param_name}: expected {expected_default!r}, got {actual!r}" - ) - # -- complete_async matches complete -- @pytest.mark.parametrize("param_name,expected_default", CHAT_COMPLETE_PARAMS) def test_complete_async_has_param(self, param_name, expected_default): @@ -171,44 +150,21 @@ def test_complete_async_has_param(self, param_name, expected_default): f"Chat.complete_async param {param_name}: expected {expected_default!r}, got {actual!r}" ) - # -- stream_async matches stream -- - @pytest.mark.parametrize("param_name,expected_default", CHAT_STREAM_PARAMS) - def test_stream_async_has_param(self, param_name, expected_default): - sig = inspect.signature(Chat.stream_async) - assert param_name in sig.parameters, f"Chat.stream_async missing param: {param_name}" - actual = sig.parameters[param_name].default - assert actual == expected_default, ( - f"Chat.stream_async param {param_name}: expected {expected_default!r}, got {actual!r}" - ) - # -- sync/async parity -- def test_complete_async_matches_complete(self): sync_params = set(inspect.signature(Chat.complete).parameters) - {"self"} async_params = set(inspect.signature(Chat.complete_async).parameters) - {"self"} assert sync_params == async_params - def test_stream_async_matches_stream(self): - sync_params = set(inspect.signature(Chat.stream).parameters) - {"self"} - async_params = set(inspect.signature(Chat.stream_async).parameters) - {"self"} - assert sync_params == async_params - # -- key defaults -- def test_complete_model_defaults_azureai(self): sig = inspect.signature(Chat.complete) assert sig.parameters["model"].default == "azureai" - def test_stream_model_defaults_azureai(self): - sig = inspect.signature(Chat.stream) - assert sig.parameters["model"].default == "azureai" - def 
test_complete_stream_defaults_false(self): sig = inspect.signature(Chat.complete) assert sig.parameters["stream"].default is False - def test_stream_stream_defaults_true(self): - sig = inspect.signature(Chat.stream) - assert sig.parameters["stream"].default is True - class TestAzureOcr: def test_has_process(self): diff --git a/tests/test_gcp_integration.py b/tests/test_gcp_integration.py index fe24b8b0..4eafe838 100644 --- a/tests/test_gcp_integration.py +++ b/tests/test_gcp_integration.py @@ -228,8 +228,9 @@ class TestGCPChatStream: def test_basic_stream(self, gcp_client): """Test streaming returns chunks with content.""" - stream = gcp_client.chat.stream( + stream = gcp_client.chat.complete( model=GCP_MODEL, + stream=True, messages=[ {"role": "user", "content": "Say 'hello' and nothing else."} ], @@ -247,8 +248,9 @@ def test_basic_stream(self, gcp_client): def test_stream_with_max_tokens(self, gcp_client): """Test streaming respects max_tokens truncation.""" - stream = gcp_client.chat.stream( + stream = gcp_client.chat.complete( model=GCP_MODEL, + stream=True, messages=[ {"role": "user", "content": "Count from 1 to 100."} ], @@ -269,8 +271,9 @@ def test_stream_with_max_tokens(self, gcp_client): def test_stream_finish_reason(self, gcp_client): """Test that the last chunk has a finish_reason.""" - stream = gcp_client.chat.stream( + stream = gcp_client.chat.complete( model=GCP_MODEL, + stream=True, messages=[ {"role": "user", "content": "Say 'hi'."} ], @@ -290,8 +293,9 @@ def test_stream_finish_reason(self, gcp_client): def test_stream_tool_call(self, gcp_client): """Test tool call via streaming, collecting tool_call delta chunks.""" - stream = gcp_client.chat.stream( + stream = gcp_client.chat.complete( model=GCP_MODEL, + stream=True, messages=[ {"role": "user", "content": "What is the weather in Paris?"} ], @@ -366,8 +370,9 @@ class TestGCPChatStreamAsync: @pytest.mark.asyncio async def test_basic_stream_async(self, gcp_client): """Test async streaming returns 
chunks with content.""" - stream = await gcp_client.chat.stream_async( + stream = await gcp_client.chat.complete_async( model=GCP_MODEL, + stream=True, messages=[ {"role": "user", "content": "Say 'hello' and nothing else."} ], @@ -443,8 +448,9 @@ def test_fim_complete(self): def test_fim_stream(self): """Test FIM streaming returns chunks.""" client = self._make_fim_client() - stream = client.fim.stream( + stream = client.fim.complete( model=GCP_FIM_MODEL, + stream=True, prompt="def hello():", suffix=" return greeting", timeout_ms=10000, @@ -492,8 +498,9 @@ async def test_fim_complete_async(self): async def test_fim_stream_async(self): """Test async FIM streaming returns chunks.""" client = self._make_fim_client() - stream = await client.fim.stream_async( + stream = await client.fim.complete_async( model=GCP_FIM_MODEL, + stream=True, prompt="def hello():", suffix=" return greeting", timeout_ms=10000, diff --git a/tests/test_gcp_v2_parity.py b/tests/test_gcp_v2_parity.py index 0d6471e4..f671bf33 100644 --- a/tests/test_gcp_v2_parity.py +++ b/tests/test_gcp_v2_parity.py @@ -15,8 +15,8 @@ from mistralai.gcp.client.types import UNSET GCP_METHODS: dict[str, set[str]] = { - "chat": {"complete", "stream"}, - "fim": {"complete", "stream"}, + "chat": {"complete"}, + "fim": {"complete"}, } TESTED_METHODS: set[str] = set() @@ -69,14 +69,10 @@ def mark_tested(resource: str, method: str) -> None: ("retries", UNSET), ("server_url", None), ("timeout_ms", None), + ("accept_header_override", None), ("http_headers", None), ] -CHAT_STREAM_PARAMS = [ - (name, True if name == "stream" else default) - for name, default in CHAT_COMPLETE_PARAMS -] - FIM_COMPLETE_PARAMS = [ ("model", _EMPTY), ("prompt", _EMPTY), @@ -92,14 +88,10 @@ def mark_tested(resource: str, method: str) -> None: ("retries", UNSET), ("server_url", None), ("timeout_ms", None), + ("accept_header_override", None), ("http_headers", None), ] -FIM_STREAM_PARAMS = [ - (name, True if name == "stream" else default) - for name, 
default in FIM_COMPLETE_PARAMS -] - # --------------------------------------------------------------------------- # Tests @@ -136,14 +128,6 @@ def test_has_complete_async(self): assert hasattr(Chat, "complete_async") mark_tested("chat", "complete_async") - def test_has_stream(self): - assert hasattr(Chat, "stream") - mark_tested("chat", "stream") - - def test_has_stream_async(self): - assert hasattr(Chat, "stream_async") - mark_tested("chat", "stream_async") - # -- complete params -- @pytest.mark.parametrize("param_name,expected_default", CHAT_COMPLETE_PARAMS) def test_complete_has_param(self, param_name, expected_default): @@ -154,16 +138,6 @@ def test_complete_has_param(self, param_name, expected_default): f"Chat.complete param {param_name}: expected {expected_default!r}, got {actual!r}" ) - # -- stream params -- - @pytest.mark.parametrize("param_name,expected_default", CHAT_STREAM_PARAMS) - def test_stream_has_param(self, param_name, expected_default): - sig = inspect.signature(Chat.stream) - assert param_name in sig.parameters, f"Chat.stream missing param: {param_name}" - actual = sig.parameters[param_name].default - assert actual == expected_default, ( - f"Chat.stream param {param_name}: expected {expected_default!r}, got {actual!r}" - ) - # -- complete_async matches complete -- @pytest.mark.parametrize("param_name,expected_default", CHAT_COMPLETE_PARAMS) def test_complete_async_has_param(self, param_name, expected_default): @@ -174,44 +148,21 @@ def test_complete_async_has_param(self, param_name, expected_default): f"Chat.complete_async param {param_name}: expected {expected_default!r}, got {actual!r}" ) - # -- stream_async matches stream -- - @pytest.mark.parametrize("param_name,expected_default", CHAT_STREAM_PARAMS) - def test_stream_async_has_param(self, param_name, expected_default): - sig = inspect.signature(Chat.stream_async) - assert param_name in sig.parameters, f"Chat.stream_async missing param: {param_name}" - actual = 
sig.parameters[param_name].default - assert actual == expected_default, ( - f"Chat.stream_async param {param_name}: expected {expected_default!r}, got {actual!r}" - ) - # -- sync/async parity -- def test_complete_async_matches_complete(self): sync_params = set(inspect.signature(Chat.complete).parameters) - {"self"} async_params = set(inspect.signature(Chat.complete_async).parameters) - {"self"} assert sync_params == async_params - def test_stream_async_matches_stream(self): - sync_params = set(inspect.signature(Chat.stream).parameters) - {"self"} - async_params = set(inspect.signature(Chat.stream_async).parameters) - {"self"} - assert sync_params == async_params - # -- key defaults -- def test_complete_model_required(self): sig = inspect.signature(Chat.complete) assert sig.parameters["model"].default is _EMPTY - def test_stream_model_required(self): - sig = inspect.signature(Chat.stream) - assert sig.parameters["model"].default is _EMPTY - def test_complete_stream_defaults_false(self): sig = inspect.signature(Chat.complete) assert sig.parameters["stream"].default is False - def test_stream_stream_defaults_true(self): - sig = inspect.signature(Chat.stream) - assert sig.parameters["stream"].default is True - class TestGCPFim: def test_has_complete(self): @@ -222,14 +173,6 @@ def test_has_complete_async(self): assert hasattr(Fim, "complete_async") mark_tested("fim", "complete_async") - def test_has_stream(self): - assert hasattr(Fim, "stream") - mark_tested("fim", "stream") - - def test_has_stream_async(self): - assert hasattr(Fim, "stream_async") - mark_tested("fim", "stream_async") - # -- complete params -- @pytest.mark.parametrize("param_name,expected_default", FIM_COMPLETE_PARAMS) def test_complete_has_param(self, param_name, expected_default): @@ -240,16 +183,6 @@ def test_complete_has_param(self, param_name, expected_default): f"Fim.complete param {param_name}: expected {expected_default!r}, got {actual!r}" ) - # -- stream params -- - 
@pytest.mark.parametrize("param_name,expected_default", FIM_STREAM_PARAMS) - def test_stream_has_param(self, param_name, expected_default): - sig = inspect.signature(Fim.stream) - assert param_name in sig.parameters, f"Fim.stream missing param: {param_name}" - actual = sig.parameters[param_name].default - assert actual == expected_default, ( - f"Fim.stream param {param_name}: expected {expected_default!r}, got {actual!r}" - ) - # -- complete_async matches complete -- @pytest.mark.parametrize("param_name,expected_default", FIM_COMPLETE_PARAMS) def test_complete_async_has_param(self, param_name, expected_default): @@ -260,52 +193,25 @@ def test_complete_async_has_param(self, param_name, expected_default): f"Fim.complete_async param {param_name}: expected {expected_default!r}, got {actual!r}" ) - # -- stream_async matches stream -- - @pytest.mark.parametrize("param_name,expected_default", FIM_STREAM_PARAMS) - def test_stream_async_has_param(self, param_name, expected_default): - sig = inspect.signature(Fim.stream_async) - assert param_name in sig.parameters, f"Fim.stream_async missing param: {param_name}" - actual = sig.parameters[param_name].default - assert actual == expected_default, ( - f"Fim.stream_async param {param_name}: expected {expected_default!r}, got {actual!r}" - ) - # -- sync/async parity -- def test_complete_async_matches_complete(self): sync_params = set(inspect.signature(Fim.complete).parameters) - {"self"} async_params = set(inspect.signature(Fim.complete_async).parameters) - {"self"} assert sync_params == async_params - def test_stream_async_matches_stream(self): - sync_params = set(inspect.signature(Fim.stream).parameters) - {"self"} - async_params = set(inspect.signature(Fim.stream_async).parameters) - {"self"} - assert sync_params == async_params - # -- key defaults -- def test_complete_model_required(self): sig = inspect.signature(Fim.complete) assert sig.parameters["model"].default is _EMPTY - def test_stream_model_required(self): - sig = 
inspect.signature(Fim.stream) - assert sig.parameters["model"].default is _EMPTY - def test_complete_stream_defaults_false(self): sig = inspect.signature(Fim.complete) assert sig.parameters["stream"].default is False - def test_stream_stream_defaults_true(self): - sig = inspect.signature(Fim.stream) - assert sig.parameters["stream"].default is True - def test_complete_top_p_defaults_to_1(self): sig = inspect.signature(Fim.complete) assert sig.parameters["top_p"].default == 1 - def test_stream_top_p_defaults_to_1(self): - sig = inspect.signature(Fim.stream) - assert sig.parameters["top_p"].default == 1 - class TestGCPCoverage: def test_all_methods_tested(self):