Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion src/cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate
from .base_client import BaseCohere, AsyncBaseCohere, OMIT
from .config import embed_batch_size
from .config import embed_batch_size, embed_stream_batch_size
from .core import RequestOptions
from .environment import ClientEnvironment
from .manually_maintained.cache import CacheMixin
Expand Down Expand Up @@ -223,6 +223,61 @@ def embed(

return merge_embed_responses(responses)

def embed_stream(
    self,
    *,
    texts: typing.Sequence[str],
    model: typing.Optional[str] = OMIT,
    input_type: typing.Optional[EmbedInputType] = OMIT,
    embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
    truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
    batch_size: int = embed_stream_batch_size,
    request_options: typing.Optional[RequestOptions] = None,
) -> typing.Iterator[typing.Any]:
    """
    Memory-efficient embed that yields embeddings one batch at a time.

    Processes texts in batches and yields individual StreamedEmbedding objects
    as they come back, so you can write to a vector store incrementally without
    holding all embeddings in memory.

    Args:
        texts: Texts to embed.
        model: Embedding model ID.
        input_type: Input type (search_document, search_query, etc.).
        embedding_types: Types of embeddings to return (float, int8, etc.).
        truncate: How to handle inputs longer than the max token length.
        batch_size: Texts per API call. Defaults to 96 (API max).
        request_options: Request-specific configuration.

    Yields:
        StreamedEmbedding with index, embedding, embedding_type, and text.

    Raises:
        ValueError: If batch_size is less than 1.
    """
    # Function-scope import — presumably to avoid a circular import at
    # module load time; confirm against the package layout before hoisting.
    from .manually_maintained.streaming_embed import extract_embeddings_from_response

    # Validate arguments BEFORE the empty-input fast path so an invalid
    # batch_size is always reported, even when there is nothing to embed.
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    if not texts:
        return

    texts_list = list(texts)

    for batch_start in range(0, len(texts_list), batch_size):
        batch_texts = texts_list[batch_start : batch_start + batch_size]

        # Call the base implementation explicitly on BaseCohere; presumably
        # this bypasses this class's own embed() (which merges batches) —
        # verify if the override ever changes.
        response = BaseCohere.embed(
            self,
            texts=batch_texts,
            model=model,
            input_type=input_type,
            embedding_types=embedding_types,
            truncate=truncate,
            request_options=request_options,
        )

        # Pydantic models expose .dict(); fall back to __dict__ for plain objects.
        response_data = response.dict() if hasattr(response, "dict") else response.__dict__
        yield from extract_embeddings_from_response(response_data, batch_texts, batch_start)

"""
The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
Expand Down
1 change: 1 addition & 0 deletions src/cohere/config.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
embed_batch_size = 96  # Max texts per embed() API call
embed_stream_batch_size = 96 # Max texts per API request (API limit)
74 changes: 74 additions & 0 deletions src/cohere/manually_maintained/streaming_embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Utilities for streaming embed responses without loading all embeddings into memory."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterator, List, Optional, Union


@dataclass
class StreamedEmbedding:
"""A single embedding yielded incrementally from embed_stream()."""
index: int
embedding: Union[List[float], List[int]]
embedding_type: str
text: Optional[str] = None


def extract_embeddings_from_response(
response_data: dict,
batch_texts: List[str],
global_offset: int = 0,
) -> Iterator[StreamedEmbedding]:
"""
Extract individual embeddings from a Cohere embed response dict.

Works for both V1 (embeddings_floats / embeddings_by_type) and V2 response formats.

Args:
response_data: Parsed JSON response from embed endpoint
batch_texts: The texts that were embedded in this batch
global_offset: Starting index for this batch within the full dataset

Yields:
StreamedEmbedding objects
"""
response_type = response_data.get("response_type", "")

if response_type == "embeddings_floats":
embeddings = response_data.get("embeddings", [])
for i, embedding in enumerate(embeddings):
yield StreamedEmbedding(
index=global_offset + i,
embedding=embedding,
embedding_type="float",
text=batch_texts[i] if i < len(batch_texts) else None,
)

elif response_type == "embeddings_by_type":
embeddings_obj = response_data.get("embeddings", {})
for emb_type, embeddings_list in embeddings_obj.items():
type_name = emb_type.rstrip("_")
if isinstance(embeddings_list, list):
for i, embedding in enumerate(embeddings_list):
yield StreamedEmbedding(
index=global_offset + i,
embedding=embedding,
embedding_type=type_name,
text=batch_texts[i] if i < len(batch_texts) else None,
)

else:
# V2 format: embeddings is a dict with type keys directly
embeddings_obj = response_data.get("embeddings", {})
if isinstance(embeddings_obj, dict):
for emb_type, embeddings_list in embeddings_obj.items():
type_name = emb_type.rstrip("_")
if isinstance(embeddings_list, list):
for i, embedding in enumerate(embeddings_list):
yield StreamedEmbedding(
index=global_offset + i,
embedding=embedding,
embedding_type=type_name,
text=batch_texts[i] if i < len(batch_texts) else None,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicated extraction logic between response type branches

Low Severity

The else branch (V2 format) at lines 61–74 is a near-exact copy of the embeddings_by_type branch at lines 48–59, differing only by an extra isinstance(embeddings_obj, dict) guard. This duplication means any future bug fix or enhancement needs to be applied in both places. Additionally, since embed_stream only calls the V1 BaseCohere.embed() — which always returns a response with response_type set to "embeddings_floats" or "embeddings_by_type" — the else branch is unreachable dead code in the current usage.

Additional Locations (1)
Fix in Cursor Fix in Web

120 changes: 120 additions & 0 deletions tests/test_embed_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Tests for memory-efficient embed_stream functionality.

All embed_stream code lives in manually maintained files (.fernignore protected):
- src/cohere/client.py — Client.embed_stream()
- src/cohere/manually_maintained/streaming_embed.py — StreamedEmbedding, extraction helpers
"""

import unittest

from cohere.manually_maintained.streaming_embed import (
StreamedEmbedding,
extract_embeddings_from_response,
)
from cohere.config import embed_stream_batch_size


class TestStreamedEmbedding(unittest.TestCase):
    """Unit tests for the StreamedEmbedding dataclass."""

    def test_creation(self):
        result = StreamedEmbedding(index=0, embedding=[0.1, 0.2], embedding_type="float", text="hello")
        # Compare all fields at once via a tuple snapshot.
        self.assertEqual(
            (result.index, result.embedding, result.embedding_type, result.text),
            (0, [0.1, 0.2], "float", "hello"),
        )

    def test_text_optional(self):
        # text defaults to None when omitted.
        result = StreamedEmbedding(index=0, embedding=[0.1], embedding_type="float")
        self.assertIsNone(result.text)


class TestExtractEmbeddings(unittest.TestCase):
    """Exercise extract_embeddings_from_response across V1 and V2 payload shapes."""

    def test_v1_embeddings_floats(self):
        """A V1 embeddings_floats payload yields flat float embeddings in order."""
        payload = {
            "response_type": "embeddings_floats",
            "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
        }
        got = list(extract_embeddings_from_response(payload, ["hello", "world"]))

        self.assertEqual(2, len(got))
        first, second = got
        self.assertEqual(0, first.index)
        self.assertEqual([0.1, 0.2, 0.3], first.embedding)
        self.assertEqual("float", first.embedding_type)
        self.assertEqual("hello", first.text)
        self.assertEqual(1, second.index)
        self.assertEqual("world", second.text)

    def test_v1_embeddings_by_type(self):
        """A V1 embeddings_by_type payload yields one embedding per (text, type) pair."""
        payload = {
            "response_type": "embeddings_by_type",
            "embeddings": {
                "float_": [[0.1, 0.2], [0.3, 0.4]],
                "int8": [[1, 2], [3, 4]],
            },
        }
        got = list(extract_embeddings_from_response(payload, ["a", "b"]))

        # 2 texts x 2 types = 4 embeddings
        self.assertEqual(4, len(got))
        grouped = {}
        for item in got:
            grouped.setdefault(item.embedding_type, []).append(item)
        self.assertEqual(2, len(grouped["float"]))
        self.assertEqual(2, len(grouped["int8"]))

    def test_v2_response_format(self):
        """A V2 payload (no response_type key) with dict embeddings still works."""
        payload = {
            "embeddings": {
                "float_": [[0.1, 0.2], [0.3, 0.4]],
            },
        }
        got = list(extract_embeddings_from_response(payload, ["x", "y"]))

        self.assertEqual(2, len(got))
        self.assertEqual("float", got[0].embedding_type)
        self.assertEqual("x", got[0].text)

    def test_global_offset(self):
        """global_offset shifts every yielded index for batched processing."""
        payload = {
            "response_type": "embeddings_floats",
            "embeddings": [[0.1], [0.2]],
        }
        got = list(extract_embeddings_from_response(payload, ["c", "d"], global_offset=100))

        self.assertEqual([100, 101], [item.index for item in got])

    def test_empty_embeddings(self):
        """An empty embeddings list yields nothing."""
        payload = {"response_type": "embeddings_floats", "embeddings": []}
        self.assertEqual([], list(extract_embeddings_from_response(payload, [])))

    def test_texts_shorter_than_embeddings(self):
        """text falls back to None once batch_texts is exhausted."""
        payload = {
            "response_type": "embeddings_floats",
            "embeddings": [[0.1], [0.2], [0.3]],
        }
        got = list(extract_embeddings_from_response(payload, ["only_one"]))

        self.assertEqual("only_one", got[0].text)
        self.assertIsNone(got[1].text)
        self.assertIsNone(got[2].text)


class TestBatchSizeConstant(unittest.TestCase):
    """The default stream batch size must come from config, not a magic number."""

    def test_default_batch_size_matches_api_limit(self):
        # 96 is the documented per-request maximum for the embed endpoint.
        self.assertEqual(96, embed_stream_batch_size)


if __name__ == "__main__":
unittest.main()