Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ repos:
rev: v1.15.0
hooks:
- id: mypy
additional_dependencies: [httpx, pydantic]
files: ^(examples/|src/mistralai/|packages/(azure|gcp)/src/mistralai/).*\.py$
exclude: ^src/mistralai/(__init__|sdkhooks|types)\.py$
55 changes: 51 additions & 4 deletions packages/azure/src/mistralai/azure/client/_hooks/registration.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,59 @@
from .types import Hooks
import logging

import httpx

from .types import BeforeRequestContext, BeforeRequestHook, Hooks

logger = logging.getLogger(__name__)


# This file is only ever generated once on the first generation and then is free to be modified.
# Any hooks you wish to add should be registered in the init_hooks function. Feel free to define them
# in this file or in separate files in the hooks folder.


class AzureServerlessPathHook(BeforeRequestHook):
"""Rewrite URL paths for legacy Azure serverless endpoints.

After the spec update, operation paths match Foundry Resource format:
- Chat: /models/chat/completions
- OCR: /providers/mistral/azure/ocr

Legacy serverless endpoints (*.models.ai.azure.com) use different paths:
- Chat: /chat/completions
- OCR: /ocr

This hook rewrites Foundry paths back to serverless paths when
is_foundry=False.
"""

SERVERLESS_PATHS: dict[str, str] = {
"chat": "/chat/completions",
"ocr": "/ocr",
}

def before_request(
self, hook_ctx: BeforeRequestContext, request: httpx.Request
) -> httpx.Request:
for key, path in self.SERVERLESS_PATHS.items():
if key in hook_ctx.operation_id:
query = b""
if b"?" in request.url.raw_path:
query = b"?" + request.url.raw_path.split(b"?", 1)[1]
return httpx.Request(
method=request.method,
url=request.url.copy_with(
raw_path=path.encode("ascii") + query,
),
headers=request.headers,
content=request.content,
)
return request


def init_hooks(_hooks: Hooks) -> None:
"""Add hooks by calling hooks.register{sdk_init/before_request/after_success/after_error}Hook
with an instance of a hook that implements that specific Hook interface
Hooks are registered per SDK instance, and are valid for the lifetime of the SDK instance"""
"""Initialize hooks. Called by SDKHooks.__init__.

Note: AzureServerlessPathHook requires is_foundry context, so it is
registered separately in MistralAzure.__init__ when is_foundry=False.
"""
7 changes: 7 additions & 0 deletions packages/azure/src/mistralai/azure/client/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
from mistralai.azure.client import models, utils
from mistralai.azure.client._hooks import SDKHooks
from mistralai.azure.client._hooks.registration import AzureServerlessPathHook
from mistralai.azure.client.types import OptionalNullable, UNSET
import sys
from typing import Callable, Dict, Optional, TYPE_CHECKING, Union, cast
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(
timeout_ms: Optional[int] = None,
debug_logger: Optional[Logger] = None,
api_version: str = "2024-05-01-preview",
is_foundry: bool = True,
) -> None:
r"""Instantiates the SDK configuring it with the provided parameters.

Expand All @@ -61,6 +63,8 @@ def __init__(
:param retry_config: The retry configuration to use for all supported methods
:param timeout_ms: Optional request timeout applied to each operation in milliseconds
:param api_version: Azure API version to use (injected as query param)
:param is_foundry: True for Foundry Resource endpoints (services.ai.azure.com),
False for legacy serverless endpoints (models.ai.azure.com)
"""
client_supplied = True
if client is None:
Expand Down Expand Up @@ -135,6 +139,9 @@ def get_security() -> models.Security:
hooks = SDKHooks()
self.sdk_configuration.__dict__["_hooks"] = hooks

if not is_foundry:
hooks.register_before_request_hook(AzureServerlessPathHook())

current_server_url, *_ = self.sdk_configuration.get_server_details()
server_url, self.sdk_configuration.client = hooks.sdk_init(
current_server_url, client
Expand Down
86 changes: 59 additions & 27 deletions tests/test_azure_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,40 @@
These tests require credentials and make real API calls.
Skip if AZURE_API_KEY env var is not set.

Prerequisites:
1. Azure API key (stored in Bitwarden at "[MaaS] - Azure Foundry API key")
2. Tailscale connected via gw-0 exit node

Usage:
AZURE_API_KEY=xxx pytest tests/test_azure_integration.py -v
AZURE_API_KEY=xxx AZURE_SERVER_URL=https://<resource>.services.ai.azure.com AZURE_MODEL=<model> AZURE_OCR_MODEL=<ocr-model> pytest tests/test_azure_integration.py -v

Environment variables:
AZURE_API_KEY: API key (required)
AZURE_ENDPOINT: Base URL (default: https://maas-qa-aifoundry.services.ai.azure.com/models)
AZURE_MODEL: Model name (default: maas-qa-ministral-3b)
AZURE_SERVER_URL: Base host URL (required, e.g. https://<resource>.services.ai.azure.com)
AZURE_MODEL: Chat model name (required)
AZURE_OCR_MODEL: OCR model name (required)
AZURE_API_VERSION: API version (default: 2024-05-01-preview)

Note: AZURE_ENDPOINT should be the base URL without path suffixes.
The SDK appends /chat/completions to this URL. The api_version parameter
is automatically injected as a query parameter by the SDK.

Available models:
Chat: maas-qa-ministral-3b, maas-qa-mistral-large-3, maas-qa-mistral-medium-2505
OCR: maas-qa-mistral-document-ai-2505, maas-qa-mistral-document-ai-2512
(OCR uses a separate endpoint, not tested here)
Note: AZURE_SERVER_URL should be the base host URL without any path suffix.
The SDK appends the correct path per operation type:
- Chat: /models/chat/completions
- OCR: /providers/mistral/azure/ocr
The api_version parameter is automatically injected as a query parameter.
"""
import base64
import json
import os

import pytest

# Configuration from env vars
AZURE_API_KEY = os.environ.get("AZURE_API_KEY")
AZURE_ENDPOINT = os.environ.get(
"AZURE_ENDPOINT",
"https://maas-qa-aifoundry.services.ai.azure.com/models",
)
AZURE_MODEL = os.environ.get("AZURE_MODEL", "maas-qa-ministral-3b")
AZURE_SERVER_URL = os.environ.get("AZURE_SERVER_URL")
AZURE_MODEL = os.environ.get("AZURE_MODEL")
AZURE_OCR_MODEL = os.environ.get("AZURE_OCR_MODEL")
AZURE_API_VERSION = os.environ.get("AZURE_API_VERSION", "2024-05-01-preview")

SKIP_REASON = "AZURE_API_KEY env var required"
SKIP_REASON = "Required env vars: AZURE_API_KEY, AZURE_SERVER_URL, AZURE_MODEL, AZURE_OCR_MODEL"

pytestmark = pytest.mark.skipif(
not AZURE_API_KEY,
reason=SKIP_REASON
not all([AZURE_API_KEY, AZURE_SERVER_URL, AZURE_MODEL, AZURE_OCR_MODEL]),
reason=SKIP_REASON,
)

# Shared tool definition for tool-call tests
Expand All @@ -61,15 +54,23 @@
},
}

# Minimal valid PDF for OCR tests (single blank page)
MINIMAL_PDF = (
b"%PDF-1.0\n1 0 obj<</Pages 2 0 R>>endobj\n"
b"2 0 obj<</Kids[3 0 R]/Count 1>>endobj\n"
b"3 0 obj<</MediaBox[0 0 612 792]>>endobj\n"
b"trailer<</Root 1 0 R>>"
)


@pytest.fixture
def azure_client():
"""Create an Azure client with api_version parameter."""
"""Create an Azure client for Foundry Resource endpoints."""
from mistralai.azure.client import MistralAzure
assert AZURE_API_KEY is not None, "AZURE_API_KEY must be set"
return MistralAzure(
api_key=AZURE_API_KEY,
server_url=AZURE_ENDPOINT,
server_url=AZURE_SERVER_URL,
api_version=AZURE_API_VERSION,
)

Expand Down Expand Up @@ -323,6 +324,37 @@ def test_stream_tool_call(self, azure_client):
assert tool_call_found, "Expected tool_call delta chunks in stream"


class TestAzureOcr:
"""Test OCR endpoint."""

def test_basic_ocr(self, azure_client):
"""Test OCR processes a document and returns pages."""
encoded = base64.b64encode(MINIMAL_PDF).decode("utf-8")
res = azure_client.ocr.process(
model=AZURE_OCR_MODEL,
document={
"type": "document_url",
"document_url": f"data:application/pdf;base64,{encoded}",
},
)
assert res is not None
assert res.pages is not None

@pytest.mark.asyncio
async def test_basic_ocr_async(self, azure_client):
"""Test async OCR processes a document and returns pages."""
encoded = base64.b64encode(MINIMAL_PDF).decode("utf-8")
res = await azure_client.ocr.process_async(
model=AZURE_OCR_MODEL,
document={
"type": "document_url",
"document_url": f"data:application/pdf;base64,{encoded}",
},
)
assert res is not None
assert res.pages is not None


class TestAzureChatCompleteAsync:
"""Test async chat completion."""

Expand Down Expand Up @@ -401,7 +433,7 @@ def test_sync_context_manager(self):
assert AZURE_API_KEY is not None, "AZURE_API_KEY must be set"
with MistralAzure(
api_key=AZURE_API_KEY,
server_url=AZURE_ENDPOINT,
server_url=AZURE_SERVER_URL,
api_version=AZURE_API_VERSION,
) as client:
res = client.chat.complete(
Expand All @@ -420,7 +452,7 @@ async def test_async_context_manager(self):
assert AZURE_API_KEY is not None, "AZURE_API_KEY must be set"
async with MistralAzure(
api_key=AZURE_API_KEY,
server_url=AZURE_ENDPOINT,
server_url=AZURE_SERVER_URL,
api_version=AZURE_API_VERSION,
) as client:
res = await client.chat.complete_async(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_gcp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

Prerequisites:
1. Authenticate with GCP: gcloud auth application-default login
2. Have "Vertex AI User" role on the project (e.g. model-garden-420509)
2. Have "Vertex AI User" role on the project

The SDK automatically:
- Detects credentials via google.auth.default()
Expand All @@ -19,7 +19,7 @@
See: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/mistral

Usage:
GCP_PROJECT_ID=model-garden-420509 pytest tests/test_gcp_integration.py -v
GCP_PROJECT_ID=<your-project-id> pytest tests/test_gcp_integration.py -v

Environment variables:
GCP_PROJECT_ID: GCP project ID (required, or auto-detected from credentials)
Expand Down
Loading