diff --git a/examples/mistral/chat/function_calling.py b/examples/mistral/chat/function_calling.py index 68e9d91c..9299a822 100644 --- a/examples/mistral/chat/function_calling.py +++ b/examples/mistral/chat/function_calling.py @@ -7,6 +7,7 @@ from mistralai.client.models import ( AssistantMessage, ChatCompletionRequestMessage, + ChatCompletionRequestTools1, Function, Tool, ToolMessage, @@ -48,7 +49,7 @@ def retrieve_payment_date(data: dict[str, list[Any]], transaction_id: str) -> st "retrieve_payment_date": functools.partial(retrieve_payment_date, data=data), } -tools: list[Tool] = [ +tools: list[ChatCompletionRequestTools1] = [ Tool( function=Function( name="retrieve_payment_status", diff --git a/examples/mistral/jobs/async_jobs_ocr_batch_annotation.py b/examples/mistral/jobs/async_jobs_ocr_batch_annotation.py index f209507d..8ddde775 100644 --- a/examples/mistral/jobs/async_jobs_ocr_batch_annotation.py +++ b/examples/mistral/jobs/async_jobs_ocr_batch_annotation.py @@ -29,9 +29,7 @@ def create_ocr_batch_request(custom_id: str, document_url: str) -> dict: "custom_id": custom_id, "body": { "document": {"type": "document_url", "document_url": document_url}, - "document_annotation_format": response_format.model_dump( - by_alias=True, exclude_none=True - ), + "document_annotation_format": response_format, "pages": [0, 1, 2, 3, 4, 5, 6, 7], "include_image_base64": False, }, diff --git a/src/mistralai/extra/run/context.py b/src/mistralai/extra/run/context.py index 7ade705f..d253edd5 100644 --- a/src/mistralai/extra/run/context.py +++ b/src/mistralai/extra/run/context.py @@ -243,7 +243,9 @@ async def prepare_model_request( def response_format(self) -> ResponseFormat: if not self.output_format: raise RunException("No response format exist for the current RunContext.") - return response_format_from_pydantic_model(self.output_format) + return ResponseFormat.model_validate( + response_format_from_pydantic_model(self.output_format) + ) async def _validate_run( diff --git a/src/mistralai/extra/tests/test_utils.py b/src/mistralai/extra/tests/test_utils.py index b0e5cdbc..a071b3ab 100644 --- a/src/mistralai/extra/tests/test_utils.py +++ b/src/mistralai/extra/tests/test_utils.py @@ -5,9 +5,6 @@ ) from pydantic import BaseModel, Field, ValidationError -from mistralai.client.models import ResponseFormat, JSONSchema -from mistralai.client.types.basemodel import Unset - import unittest @@ -55,15 +52,14 @@ class MathDemonstration(BaseModel): mathdemo_strict_schema["$defs"]["Explanation"]["additionalProperties"] = False # type: ignore mathdemo_strict_schema["additionalProperties"] = False -mathdemo_response_format = ResponseFormat( - type="json_schema", - json_schema=JSONSchema( - name="MathDemonstration", - schema_definition=mathdemo_strict_schema, - description=Unset(), - strict=True, - ), -) +mathdemo_response_format = { + "type": "json_schema", + "json_schema": { + "name": "MathDemonstration", + "schema": mathdemo_strict_schema, + "strict": True, + }, +} class TestResponseFormat(unittest.TestCase): @@ -220,10 +216,10 @@ class ModelWithConstraints(BaseModel): # Should not raise ValueError result = response_format_from_pydantic_model(ModelWithConstraints) - # Verify it returns a valid ResponseFormat - self.assertIsInstance(result, ResponseFormat) - self.assertEqual(result.type, "json_schema") - self.assertIsNotNone(result.json_schema) + # Verify it returns a valid response format dict + self.assertIsInstance(result, dict) + self.assertEqual(result.get("type"), "json_schema") + self.assertIsNotNone(result.get("json_schema")) def test_rec_strict_json_schema_with_invalid_type(self): """Test that rec_strict_json_schema raises ValueError for truly invalid types.""" diff --git a/src/mistralai/extra/utils/response_format.py b/src/mistralai/extra/utils/response_format.py index 2378b562..3600156b 100644 --- a/src/mistralai/extra/utils/response_format.py +++ b/src/mistralai/extra/utils/response_format.py @@ -1,7 +1,7 @@ -from typing import Any, TypeVar +from typing import Any, TypeVar, cast from pydantic import BaseModel -from mistralai.client.models import JSONSchema, ResponseFormat +from mistralai.client.models import ResponseFormatTypedDict from ._pydantic_helper import rec_strict_json_schema CustomPydanticModel = TypeVar("CustomPydanticModel", bound=BaseModel) @@ -9,13 +9,24 @@ def response_format_from_pydantic_model( model: type[CustomPydanticModel], -) -> ResponseFormat: - """Generate a strict JSON schema from a pydantic model.""" +) -> ResponseFormatTypedDict: + """Generate a strict JSON schema response format from a pydantic model. + + Returns a TypedDict compatible with both the main SDK's and Azure SDK's + ResponseFormat / ResponseFormatTypedDict. + """ model_schema = rec_strict_json_schema(model.model_json_schema()) - json_schema = JSONSchema.model_validate( - {"name": model.__name__, "schema": model_schema, "strict": True} + return cast( + ResponseFormatTypedDict, + { + "type": "json_schema", + "json_schema": { + "name": model.__name__, + "schema": model_schema, + "strict": True, + }, + }, ) - return ResponseFormat(type="json_schema", json_schema=json_schema) def pydantic_model_from_json(