diff --git a/src/cohere/client.py b/src/cohere/client.py
index 501338d3c..afe643213 100644
--- a/src/cohere/client.py
+++ b/src/cohere/client.py
@@ -12,7 +12,7 @@
 from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate
 from .base_client import BaseCohere, AsyncBaseCohere, OMIT
-from .config import embed_batch_size
+from .config import embed_batch_size, embed_stream_batch_size
 from .core import RequestOptions
 from .environment import ClientEnvironment
 from .manually_maintained.cache import CacheMixin
@@ -223,6 +223,61 @@ def embed(
         return merge_embed_responses(responses)
 
+    def embed_stream(
+        self,
+        *,
+        texts: typing.Sequence[str],
+        model: typing.Optional[str] = OMIT,
+        input_type: typing.Optional[EmbedInputType] = OMIT,
+        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
+        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
+        batch_size: int = embed_stream_batch_size,
+        request_options: typing.Optional[RequestOptions] = None,
+    ) -> typing.Iterator[typing.Any]:
+        """
+        Memory-efficient embed that yields embeddings one batch at a time.
+
+        Processes texts in batches and yields individual StreamedEmbedding objects
+        as they come back, so you can write to a vector store incrementally without
+        holding all embeddings in memory.
+
+        Args:
+            texts: Texts to embed.
+            model: Embedding model ID.
+            input_type: Input type (search_document, search_query, etc.).
+            embedding_types: Types of embeddings to return (float, int8, etc.).
+            truncate: How to handle inputs longer than the max token length.
+            batch_size: Texts per API call. Defaults to 96 (API max).
+            request_options: Request-specific configuration.
+
+        Yields:
+            StreamedEmbedding with index, embedding, embedding_type, and text.
+        """
+        from .manually_maintained.streaming_embed import extract_embeddings_from_response
+
+        if not texts:
+            return
+        if batch_size < 1:
+            raise ValueError("batch_size must be at least 1")
+
+        texts_list = list(texts)
+
+        for batch_start in range(0, len(texts_list), batch_size):
+            batch_texts = texts_list[batch_start : batch_start + batch_size]
+
+            response = BaseCohere.embed(
+                self,
+                texts=batch_texts,
+                model=model,
+                input_type=input_type,
+                embedding_types=embedding_types,
+                truncate=truncate,
+                request_options=request_options,
+            )
+
+            response_data = response.dict() if hasattr(response, "dict") else response.__dict__
+            yield from extract_embeddings_from_response(response_data, batch_texts, batch_start)
+
     """
     The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
     Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
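For reviewers, a minimal usage sketch of the new method (not part of the diff). The model ID is illustrative, `CO_API_KEY` is assumed to be set in the environment, and the print would typically be replaced by a write to your vector store:

```python
import cohere

co = cohere.Client()  # assumes CO_API_KEY is set in the environment

# Embeddings arrive batch by batch; nothing is accumulated client-side.
for emb in co.embed_stream(
    texts=["first document", "second document"],
    model="embed-english-v3.0",  # illustrative model ID
    input_type="search_document",
    embedding_types=["float"],
):
    # emb.index is global across batches, so it can serve as a stable ID
    # when upserting into a vector store instead of printing.
    print(emb.index, emb.embedding_type, emb.text)
```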
diff --git a/src/cohere/config.py b/src/cohere/config.py
index c666c9eed..dcdce09a2 100644
--- a/src/cohere/config.py
+++ b/src/cohere/config.py
@@ -1 +1,2 @@
 embed_batch_size = 96
+embed_stream_batch_size = 96  # Max texts per API request (API limit)
diff --git a/src/cohere/manually_maintained/streaming_embed.py b/src/cohere/manually_maintained/streaming_embed.py
new file mode 100644
index 000000000..6b392fa77
--- /dev/null
+++ b/src/cohere/manually_maintained/streaming_embed.py
@@ -0,0 +1,74 @@
+"""Utilities for streaming embed responses without loading all embeddings into memory."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Iterator, List, Optional, Union
+
+
+@dataclass
+class StreamedEmbedding:
+    """A single embedding yielded incrementally from embed_stream()."""
+    index: int
+    embedding: Union[List[float], List[int]]
+    embedding_type: str
+    text: Optional[str] = None
+
+
+def extract_embeddings_from_response(
+    response_data: dict,
+    batch_texts: List[str],
+    global_offset: int = 0,
+) -> Iterator[StreamedEmbedding]:
+    """
+    Extract individual embeddings from a Cohere embed response dict.
+
+    Works for both V1 (embeddings_floats / embeddings_by_type) and V2 response formats.
+
+    Args:
+        response_data: Parsed JSON response from the embed endpoint.
+        batch_texts: The texts that were embedded in this batch.
+        global_offset: Starting index for this batch within the full dataset.
+
+    Yields:
+        StreamedEmbedding objects.
+    """
+    response_type = response_data.get("response_type", "")
+
+    if response_type == "embeddings_floats":
+        embeddings = response_data.get("embeddings", [])
+        for i, embedding in enumerate(embeddings):
+            yield StreamedEmbedding(
+                index=global_offset + i,
+                embedding=embedding,
+                embedding_type="float",
+                text=batch_texts[i] if i < len(batch_texts) else None,
+            )
+
+    elif response_type == "embeddings_by_type":
+        embeddings_obj = response_data.get("embeddings", {})
+        for emb_type, embeddings_list in embeddings_obj.items():
+            type_name = emb_type.rstrip("_")
+            if isinstance(embeddings_list, list):
+                for i, embedding in enumerate(embeddings_list):
+                    yield StreamedEmbedding(
+                        index=global_offset + i,
+                        embedding=embedding,
+                        embedding_type=type_name,
+                        text=batch_texts[i] if i < len(batch_texts) else None,
+                    )
+
+    else:
+        # V2 format: embeddings is a dict with type keys directly
+        embeddings_obj = response_data.get("embeddings", {})
+        if isinstance(embeddings_obj, dict):
+            for emb_type, embeddings_list in embeddings_obj.items():
+                type_name = emb_type.rstrip("_")
+                if isinstance(embeddings_list, list):
+                    for i, embedding in enumerate(embeddings_list):
+                        yield StreamedEmbedding(
+                            index=global_offset + i,
+                            embedding=embedding,
+                            embedding_type=type_name,
+                            text=batch_texts[i] if i < len(batch_texts) else None,
+                        )
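A quick illustration of the extraction helper on a hand-written V1 `embeddings_by_type` payload (not part of the diff). The `float_` key mimics the pydantic field name that `.dict()` produces, which `rstrip("_")` normalizes back to `float`:

```python
from cohere.manually_maintained.streaming_embed import extract_embeddings_from_response

# Two texts and two requested types yield four StreamedEmbedding objects.
response_data = {
    "response_type": "embeddings_by_type",
    "embeddings": {
        "float_": [[0.1, 0.2], [0.3, 0.4]],  # pydantic field name, normalized to "float"
        "int8": [[1, 2], [3, 4]],
    },
}
for emb in extract_embeddings_from_response(response_data, ["a", "b"], global_offset=96):
    print(emb.index, emb.embedding_type)  # indices 96 and 97 for each type
```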
diff --git a/tests/test_embed_streaming.py b/tests/test_embed_streaming.py
new file mode 100644
index 000000000..f9348e5c0
--- /dev/null
+++ b/tests/test_embed_streaming.py
@@ -0,0 +1,120 @@
+"""Tests for memory-efficient embed_stream functionality.
+
+All embed_stream code lives in manually maintained files (.fernignore protected):
+- src/cohere/client.py — Client.embed_stream()
+- src/cohere/manually_maintained/streaming_embed.py — StreamedEmbedding, extraction helpers
+"""
+
+import unittest
+
+from cohere.manually_maintained.streaming_embed import (
+    StreamedEmbedding,
+    extract_embeddings_from_response,
+)
+from cohere.config import embed_stream_batch_size
+
+
+class TestStreamedEmbedding(unittest.TestCase):
+    """Test the StreamedEmbedding dataclass."""
+
+    def test_creation(self):
+        emb = StreamedEmbedding(index=0, embedding=[0.1, 0.2], embedding_type="float", text="hello")
+        self.assertEqual(emb.index, 0)
+        self.assertEqual(emb.embedding, [0.1, 0.2])
+        self.assertEqual(emb.embedding_type, "float")
+        self.assertEqual(emb.text, "hello")
+
+    def test_text_optional(self):
+        emb = StreamedEmbedding(index=0, embedding=[0.1], embedding_type="float")
+        self.assertIsNone(emb.text)
+
+
+class TestExtractEmbeddings(unittest.TestCase):
+    """Test extract_embeddings_from_response for V1 and V2 formats."""
+
+    def test_v1_embeddings_floats(self):
+        """V1 embeddings_floats response returns flat float embeddings."""
+        response = {
+            "response_type": "embeddings_floats",
+            "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
+        }
+        results = list(extract_embeddings_from_response(response, ["hello", "world"]))
+
+        self.assertEqual(len(results), 2)
+        self.assertEqual(results[0].index, 0)
+        self.assertEqual(results[0].embedding, [0.1, 0.2, 0.3])
+        self.assertEqual(results[0].embedding_type, "float")
+        self.assertEqual(results[0].text, "hello")
+        self.assertEqual(results[1].index, 1)
+        self.assertEqual(results[1].text, "world")
+
+    def test_v1_embeddings_by_type(self):
+        """V1 embeddings_by_type response returns typed embeddings."""
+        response = {
+            "response_type": "embeddings_by_type",
+            "embeddings": {
+                "float_": [[0.1, 0.2], [0.3, 0.4]],
+                "int8": [[1, 2], [3, 4]],
+            },
+        }
+        results = list(extract_embeddings_from_response(response, ["a", "b"]))
+
+        # 2 texts * 2 types = 4 embeddings
+        self.assertEqual(len(results), 4)
+        float_results = [r for r in results if r.embedding_type == "float"]
+        int8_results = [r for r in results if r.embedding_type == "int8"]
+        self.assertEqual(len(float_results), 2)
+        self.assertEqual(len(int8_results), 2)
+
+    def test_v2_response_format(self):
+        """V2 response (no response_type) returns dict embeddings."""
+        response = {
+            "embeddings": {
+                "float_": [[0.1, 0.2], [0.3, 0.4]],
+            },
+        }
+        results = list(extract_embeddings_from_response(response, ["x", "y"]))
+
+        self.assertEqual(len(results), 2)
+        self.assertEqual(results[0].embedding_type, "float")
+        self.assertEqual(results[0].text, "x")
+
+    def test_global_offset(self):
+        """Global offset adjusts indices for batched processing."""
+        response = {
+            "response_type": "embeddings_floats",
+            "embeddings": [[0.1], [0.2]],
+        }
+        results = list(extract_embeddings_from_response(response, ["c", "d"], global_offset=100))
+
+        self.assertEqual(results[0].index, 100)
+        self.assertEqual(results[1].index, 101)
+
+    def test_empty_embeddings(self):
+        """Empty response yields nothing."""
+        response = {"response_type": "embeddings_floats", "embeddings": []}
+        results = list(extract_embeddings_from_response(response, []))
+        self.assertEqual(results, [])
+
+    def test_texts_shorter_than_embeddings(self):
+        """Text is None when batch_texts runs out."""
+        response = {
+            "response_type": "embeddings_floats",
+            "embeddings": [[0.1], [0.2], [0.3]],
+        }
+        results = list(extract_embeddings_from_response(response, ["only_one"]))
+
+        self.assertEqual(results[0].text, "only_one")
+        self.assertIsNone(results[1].text)
+        self.assertIsNone(results[2].text)
+
+
+class TestBatchSizeConstant(unittest.TestCase):
+    """Test that batch_size defaults come from config, not magic numbers."""
+
+    def test_default_batch_size_matches_api_limit(self):
+        self.assertEqual(embed_stream_batch_size, 96)
+
+
+if __name__ == "__main__":
+    unittest.main()
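As a companion to the unit tests, a standalone sketch (not part of the diff, assuming this branch is installed) that checks the batching arithmetic end to end by stubbing `BaseCohere.embed`; the fake response follows the V1 `embeddings_floats` shape used above:

```python
from unittest import mock

import cohere
from cohere.base_client import BaseCohere

texts = [f"doc {i}" for i in range(250)]

def fake_embed(self, *, texts, **kwargs):
    # Stand-in for the API: one zero-vector per input text.
    return mock.Mock(dict=lambda: {
        "response_type": "embeddings_floats",
        "embeddings": [[0.0] for _ in texts],
    })

with mock.patch.object(BaseCohere, "embed", fake_embed):
    co = cohere.Client(api_key="dummy")
    results = list(co.embed_stream(texts=texts))

# 250 texts split into batches of 96, 96, and 58; indices stay global.
assert len(results) == 250
assert [r.index for r in results] == list(range(250))
```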