Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/mistral/audio/chat_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def main():
print(f"Uploaded audio file, id={file.id}")
signed_url = client.files.get_signed_url(file_id=file.id)
try:
chat_response = client.chat.stream(
chat_response = client.chat.complete(
stream=True,
model=model,
messages=[
UserMessage(
Expand Down
4 changes: 3 additions & 1 deletion examples/mistral/audio/transcription_segments_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
import os

from mistralai.client import Mistral
from mistralai.client.transcriptions import CompleteAcceptEnum


def main():
api_key = os.environ["MISTRAL_API_KEY"]
model = "voxtral-mini-latest"

client = Mistral(api_key=api_key)
response = client.audio.transcriptions.stream(
response = client.audio.transcriptions.complete(
model=model,
file_url="https://docs.mistral.ai/audio/bcn_weather.mp3",
timestamp_granularities=["segment"],
accept_header_override=CompleteAcceptEnum.TEXT_EVENT_STREAM,
)
for chunk in response:
print(chunk)
Expand Down
4 changes: 3 additions & 1 deletion examples/mistral/audio/transcription_stream_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from mistralai.client import Mistral
from mistralai.client.models import File
from mistralai.client.transcriptions import CompleteAcceptEnum


async def main():
Expand All @@ -12,9 +13,10 @@ async def main():

client = Mistral(api_key=api_key)
with open("examples/fixtures/bcn_weather.mp3", "rb") as f:
response = await client.audio.transcriptions.stream_async(
response = await client.audio.transcriptions.complete_async(
model=model,
file=File(content=f, file_name=f.name),
accept_header_override=CompleteAcceptEnum.TEXT_EVENT_STREAM,
)
async for chunk in response:
print(chunk.event, chunk.data)
Expand Down
3 changes: 2 additions & 1 deletion examples/mistral/chat/async_chat_with_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ async def main():
client = Mistral(api_key=api_key)

print("Chat response:")
response = await client.chat.stream_async(
response = await client.chat.complete_async(
model=model,
stream=True,
messages=[
UserMessage(content="What is the best French cheese?give the best 50")
],
Expand Down
3 changes: 2 additions & 1 deletion examples/mistral/chat/chat_with_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ def main():

client = Mistral(api_key=api_key)

for chunk in client.chat.stream(
for chunk in client.chat.complete(
model=model,
stream=True,
messages=[UserMessage(content="What is the best French cheese?")],
):
print(chunk.data.choices[0].delta.content, end="")
Expand Down
4 changes: 2 additions & 2 deletions examples/mistral/chat/chatbot_with_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def run_inference(self, content):
f"Running inference with model: {self.model}, temperature: {self.temperature}"
)
logger.debug(f"Sending messages: {self.messages}")
for chunk in self.client.chat.stream(
model=self.model, temperature=self.temperature, messages=self.messages
for chunk in self.client.chat.complete(
model=self.model, temperature=self.temperature, stream=True, messages=self.messages
):
response = chunk.data.choices[0].delta.content
if response is not None:
Expand Down
3 changes: 2 additions & 1 deletion examples/mistral/chat/completion_with_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ async def main():
suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"

print(prompt)
for chunk in client.fim.stream(
for chunk in client.fim.complete(
model="codestral-latest",
stream=True,
prompt=prompt,
suffix=suffix,
):
Expand Down
3 changes: 2 additions & 1 deletion examples/mistral/chat/structured_outputs_with_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ def main():
print(chat_response.choices[0].message.content)

# Or with the streaming API
with client.chat.stream(
with client.chat.complete(
model="mistral-large-latest",
stream=True,
messages=[
{
"role": "system",
Expand Down
80 changes: 64 additions & 16 deletions src/mistralai/client/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# @generated-id: 7eba0f088d47

from .basesdk import BaseSDK
from enum import Enum
from mistralai.client import errors, models, utils
from mistralai.client._hooks import HookContext
from mistralai.client.types import OptionalNullable, UNSET
Expand All @@ -23,6 +24,11 @@
# endregion imports


class CompleteAcceptEnum(str, Enum):
    """Allowed ``Accept`` header values for the chat-completion endpoint.

    ``APPLICATION_JSON`` requests a single JSON response body;
    ``TEXT_EVENT_STREAM`` requests a Server-Sent-Events stream
    (passed via ``accept_header_override`` on ``complete``/``complete_async``).
    """

    APPLICATION_JSON = "application/json"
    TEXT_EVENT_STREAM = "text/event-stream"


class Chat(BaseSDK):
r"""Chat Completion API."""

Expand All @@ -41,7 +47,9 @@ def parse(
# Convert the input Pydantic Model to a strict JSON ready to be passed to chat.complete
json_response_format = response_format_from_pydantic_model(response_format)
# Run the inference
kwargs["stream"] = False
response = self.complete(**kwargs, response_format=json_response_format)
assert isinstance(response, models.ChatCompletionResponse)
# Parse response back to the input pydantic model
parsed_response = convert_to_parsed_chat_completion_response(
response, response_format
Expand All @@ -58,9 +66,11 @@ async def parse_async(
:return: The parsed response
"""
json_response_format = response_format_from_pydantic_model(response_format)
kwargs["stream"] = False
response = await self.complete_async( # pylint: disable=E1125
**kwargs, response_format=json_response_format
)
assert isinstance(response, models.ChatCompletionResponse)
parsed_response = convert_to_parsed_chat_completion_response(
response, response_format
)
Expand All @@ -73,11 +83,13 @@ def parse_stream(
Parse the response using the provided response format.
For now the response will be in JSON format not in the input Pydantic model.
:param Type[CustomPydanticModel] response_format: The Pydantic model to parse the response into
:param Any **kwargs Additional keyword arguments to pass to the .stream method
:param Any **kwargs Additional keyword arguments to pass to the .complete method
:return: The JSON parsed response
"""
json_response_format = response_format_from_pydantic_model(response_format)
response = self.stream(**kwargs, response_format=json_response_format)
kwargs["stream"] = True
response = self.complete(**kwargs, response_format=json_response_format)
assert isinstance(response, eventstreaming.EventStream)
return response

async def parse_stream_async(
Expand All @@ -87,13 +99,15 @@ async def parse_stream_async(
Asynchronously parse the response using the provided response format.
For now the response will be in JSON format not in the input Pydantic model.
:param Type[CustomPydanticModel] response_format: The Pydantic model to parse the response into
:param Any **kwargs Additional keyword arguments to pass to the .stream method
:param Any **kwargs Additional keyword arguments to pass to the .complete method
:return: The JSON parsed response
"""
json_response_format = response_format_from_pydantic_model(response_format)
response = await self.stream_async( # pylint: disable=E1125
kwargs["stream"] = True
response = await self.complete_async( # pylint: disable=E1125
**kwargs, response_format=json_response_format
)
assert isinstance(response, eventstreaming.EventStreamAsync)
return response

# endregion sdk-class-body
Expand Down Expand Up @@ -142,8 +156,9 @@ def complete(
retries: OptionalNullable[utils.RetryConfig] = UNSET,
server_url: Optional[str] = None,
timeout_ms: Optional[int] = None,
accept_header_override: Optional[CompleteAcceptEnum] = None,
http_headers: Optional[Mapping[str, str]] = None,
) -> models.ChatCompletionResponse:
) -> models.ChatCompletionV1ChatCompletionsPostResponse:
r"""Chat Completion

:param model: ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions.
Expand All @@ -168,6 +183,7 @@ def complete(
:param retries: Override the default retry configuration for this method
:param server_url: Override the default server URL for this method
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
:param accept_header_override: Override the default accept header for this method
:param http_headers: Additional headers to set or replace on requests.
"""
base_url = None
Expand Down Expand Up @@ -220,7 +236,9 @@ def complete(
request_has_path_params=False,
request_has_query_params=True,
user_agent_header="user-agent",
accept_header_value="application/json",
accept_header_value=accept_header_override.value
if accept_header_override is not None
else "application/json;q=1, text/event-stream;q=0",
http_headers=http_headers,
security=self.sdk_configuration.security,
get_serialized_body=lambda: utils.serialize_request_body(
Expand Down Expand Up @@ -250,25 +268,38 @@ def complete(
),
request=req,
error_status_codes=["422", "4XX", "5XX"],
stream=True,
retry_config=retry_config,
)

response_data: Any = None
if utils.match_response(http_res, "200", "application/json"):
return unmarshal_json_response(models.ChatCompletionResponse, http_res)
http_res_text = utils.stream_to_text(http_res)
return unmarshal_json_response(
models.ChatCompletionResponse, http_res, http_res_text
)
if utils.match_response(http_res, "200", "text/event-stream"):
return eventstreaming.EventStream(
http_res,
lambda raw: utils.unmarshal_json(raw, models.CompletionEvent),
sentinel="[DONE]",
client_ref=self,
)
if utils.match_response(http_res, "422", "application/json"):
http_res_text = utils.stream_to_text(http_res)
response_data = unmarshal_json_response(
errors.HTTPValidationErrorData, http_res
errors.HTTPValidationErrorData, http_res, http_res_text
)
raise errors.HTTPValidationError(response_data, http_res)
raise errors.HTTPValidationError(response_data, http_res, http_res_text)
if utils.match_response(http_res, "4XX", "*"):
http_res_text = utils.stream_to_text(http_res)
raise errors.SDKError("API error occurred", http_res, http_res_text)
if utils.match_response(http_res, "5XX", "*"):
http_res_text = utils.stream_to_text(http_res)
raise errors.SDKError("API error occurred", http_res, http_res_text)

raise errors.SDKError("Unexpected response received", http_res)
http_res_text = utils.stream_to_text(http_res)
raise errors.SDKError("Unexpected response received", http_res, http_res_text)

async def complete_async(
self,
Expand Down Expand Up @@ -314,8 +345,9 @@ async def complete_async(
retries: OptionalNullable[utils.RetryConfig] = UNSET,
server_url: Optional[str] = None,
timeout_ms: Optional[int] = None,
accept_header_override: Optional[CompleteAcceptEnum] = None,
http_headers: Optional[Mapping[str, str]] = None,
) -> models.ChatCompletionResponse:
) -> models.ChatCompletionV1ChatCompletionsPostResponse:
r"""Chat Completion

:param model: ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions.
Expand All @@ -340,6 +372,7 @@ async def complete_async(
:param retries: Override the default retry configuration for this method
:param server_url: Override the default server URL for this method
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
:param accept_header_override: Override the default accept header for this method
:param http_headers: Additional headers to set or replace on requests.
"""
base_url = None
Expand Down Expand Up @@ -392,7 +425,9 @@ async def complete_async(
request_has_path_params=False,
request_has_query_params=True,
user_agent_header="user-agent",
accept_header_value="application/json",
accept_header_value=accept_header_override.value
if accept_header_override is not None
else "application/json;q=1, text/event-stream;q=0",
http_headers=http_headers,
security=self.sdk_configuration.security,
get_serialized_body=lambda: utils.serialize_request_body(
Expand Down Expand Up @@ -422,25 +457,38 @@ async def complete_async(
),
request=req,
error_status_codes=["422", "4XX", "5XX"],
stream=True,
retry_config=retry_config,
)

response_data: Any = None
if utils.match_response(http_res, "200", "application/json"):
return unmarshal_json_response(models.ChatCompletionResponse, http_res)
http_res_text = await utils.stream_to_text_async(http_res)
return unmarshal_json_response(
models.ChatCompletionResponse, http_res, http_res_text
)
if utils.match_response(http_res, "200", "text/event-stream"):
return eventstreaming.EventStreamAsync(
http_res,
lambda raw: utils.unmarshal_json(raw, models.CompletionEvent),
sentinel="[DONE]",
client_ref=self,
)
if utils.match_response(http_res, "422", "application/json"):
http_res_text = await utils.stream_to_text_async(http_res)
response_data = unmarshal_json_response(
errors.HTTPValidationErrorData, http_res
errors.HTTPValidationErrorData, http_res, http_res_text
)
raise errors.HTTPValidationError(response_data, http_res)
raise errors.HTTPValidationError(response_data, http_res, http_res_text)
if utils.match_response(http_res, "4XX", "*"):
http_res_text = await utils.stream_to_text_async(http_res)
raise errors.SDKError("API error occurred", http_res, http_res_text)
if utils.match_response(http_res, "5XX", "*"):
http_res_text = await utils.stream_to_text_async(http_res)
raise errors.SDKError("API error occurred", http_res, http_res_text)

raise errors.SDKError("Unexpected response received", http_res)
http_res_text = await utils.stream_to_text_async(http_res)
raise errors.SDKError("Unexpected response received", http_res, http_res_text)

def stream(
self,
Expand Down
Loading
Loading