From 7858f1f96ba55f3602d5ab65211d680a1e107591 Mon Sep 17 00:00:00 2001 From: Nelson PROIA Date: Thu, 19 Mar 2026 18:24:01 +0100 Subject: [PATCH 1/2] fix(extra): return ResponseFormatTypedDict from response_format_from_pydantic_model response_format_from_pydantic_model now returns a ResponseFormatTypedDict (plain dict) instead of a ResponseFormat model instance. This fixes the type mismatch when passing the result to the Azure SDK, which expects its own ResponseFormat class, and avoids the schema alias data loss issue. Fixes AIR-143 / GitHub #367 --- .../jobs/async_jobs_ocr_batch_annotation.py | 4 +-- src/mistralai/extra/run/context.py | 4 ++- src/mistralai/extra/tests/test_utils.py | 28 ++++++++----------- src/mistralai/extra/utils/response_format.py | 25 ++++++++++++----- 4 files changed, 34 insertions(+), 27 deletions(-) diff --git a/examples/mistral/jobs/async_jobs_ocr_batch_annotation.py b/examples/mistral/jobs/async_jobs_ocr_batch_annotation.py index f209507d..8ddde775 100644 --- a/examples/mistral/jobs/async_jobs_ocr_batch_annotation.py +++ b/examples/mistral/jobs/async_jobs_ocr_batch_annotation.py @@ -29,9 +29,7 @@ def create_ocr_batch_request(custom_id: str, document_url: str) -> dict: "custom_id": custom_id, "body": { "document": {"type": "document_url", "document_url": document_url}, - "document_annotation_format": response_format.model_dump( - by_alias=True, exclude_none=True - ), + "document_annotation_format": response_format, "pages": [0, 1, 2, 3, 4, 5, 6, 7], "include_image_base64": False, }, diff --git a/src/mistralai/extra/run/context.py b/src/mistralai/extra/run/context.py index 7ade705f..d253edd5 100644 --- a/src/mistralai/extra/run/context.py +++ b/src/mistralai/extra/run/context.py @@ -243,7 +243,9 @@ async def prepare_model_request( def response_format(self) -> ResponseFormat: if not self.output_format: raise RunException("No response format exist for the current RunContext.") - return response_format_from_pydantic_model(self.output_format) + return ResponseFormat.model_validate( + response_format_from_pydantic_model(self.output_format) + ) async def _validate_run( diff --git a/src/mistralai/extra/tests/test_utils.py b/src/mistralai/extra/tests/test_utils.py index b0e5cdbc..a071b3ab 100644 --- a/src/mistralai/extra/tests/test_utils.py +++ b/src/mistralai/extra/tests/test_utils.py @@ -5,9 +5,6 @@ ) from pydantic import BaseModel, Field, ValidationError -from mistralai.client.models import ResponseFormat, JSONSchema -from mistralai.client.types.basemodel import Unset - import unittest @@ -55,15 +52,14 @@ class MathDemonstration(BaseModel): mathdemo_strict_schema["$defs"]["Explanation"]["additionalProperties"] = False # type: ignore mathdemo_strict_schema["additionalProperties"] = False -mathdemo_response_format = ResponseFormat( - type="json_schema", - json_schema=JSONSchema( - name="MathDemonstration", - schema_definition=mathdemo_strict_schema, - description=Unset(), - strict=True, - ), -) +mathdemo_response_format = { + "type": "json_schema", + "json_schema": { + "name": "MathDemonstration", + "schema": mathdemo_strict_schema, + "strict": True, + }, +} class TestResponseFormat(unittest.TestCase): @@ -220,10 +216,10 @@ class ModelWithConstraints(BaseModel): # Should not raise ValueError result = response_format_from_pydantic_model(ModelWithConstraints) - # Verify it returns a valid ResponseFormat - self.assertIsInstance(result, ResponseFormat) - self.assertEqual(result.type, "json_schema") - self.assertIsNotNone(result.json_schema) + # Verify it returns a valid response format dict + self.assertIsInstance(result, dict) + self.assertEqual(result.get("type"), "json_schema") + self.assertIsNotNone(result.get("json_schema")) def test_rec_strict_json_schema_with_invalid_type(self): """Test that rec_strict_json_schema raises ValueError for truly invalid types.""" diff --git a/src/mistralai/extra/utils/response_format.py b/src/mistralai/extra/utils/response_format.py index 2378b562..3600156b 100644 --- a/src/mistralai/extra/utils/response_format.py +++ b/src/mistralai/extra/utils/response_format.py @@ -1,7 +1,7 @@ -from typing import Any, TypeVar +from typing import Any, TypeVar, cast from pydantic import BaseModel -from mistralai.client.models import JSONSchema, ResponseFormat +from mistralai.client.models import ResponseFormatTypedDict from ._pydantic_helper import rec_strict_json_schema CustomPydanticModel = TypeVar("CustomPydanticModel", bound=BaseModel) @@ -9,13 +9,24 @@ def response_format_from_pydantic_model( model: type[CustomPydanticModel], -) -> ResponseFormat: - """Generate a strict JSON schema from a pydantic model.""" +) -> ResponseFormatTypedDict: + """Generate a strict JSON schema response format from a pydantic model. + + Returns a TypedDict compatible with both the main SDK's and Azure SDK's + ResponseFormat / ResponseFormatTypedDict. + """ model_schema = rec_strict_json_schema(model.model_json_schema()) - json_schema = JSONSchema.model_validate( - {"name": model.__name__, "schema": model_schema, "strict": True} + return cast( + ResponseFormatTypedDict, + { + "type": "json_schema", + "json_schema": { + "name": model.__name__, + "schema": model_schema, + "strict": True, + }, + }, ) - return ResponseFormat(type="json_schema", json_schema=json_schema) def pydantic_model_from_json( From 4fb65be4eab8fcbde6d750ca002a6ac3c1123e60 Mon Sep 17 00:00:00 2001 From: Nelson PROIA Date: Fri, 20 Mar 2026 11:57:24 +0100 Subject: [PATCH 2/2] fix(examples): use ChatCompletionRequestTools1 type for tools list Fix mypy list invariance error introduced by SDK 2.1.0 regen where list[Tool] no longer satisfies the broadened tools union type. --- examples/mistral/chat/function_calling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/mistral/chat/function_calling.py b/examples/mistral/chat/function_calling.py index 68e9d91c..9299a822 100644 --- a/examples/mistral/chat/function_calling.py +++ b/examples/mistral/chat/function_calling.py @@ -7,6 +7,7 @@ from mistralai.client.models import ( AssistantMessage, ChatCompletionRequestMessage, + ChatCompletionRequestTools1, Function, Tool, ToolMessage, @@ -48,7 +49,7 @@ def retrieve_payment_date(data: dict[str, list[Any]], transaction_id: str) -> st "retrieve_payment_date": functools.partial(retrieve_payment_date, data=data), } -tools: list[Tool] = [ +tools: list[ChatCompletionRequestTools1] = [ Tool( function=Function( name="retrieve_payment_status",