Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions weaviate/classes/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
BM25OperatorFactory as BM25Operator,
)
from weaviate.collections.classes.grpc import (
Diversity,
GroupBy,
HybridFusion,
HybridVector,
Expand All @@ -21,6 +22,7 @@
from weaviate.collections.classes.types import GeoCoordinate

__all__ = [
"Diversity",
"Filter",
"FilterReturn",
"GeoCoordinate",
Expand Down
17 changes: 17 additions & 0 deletions weaviate/collections/classes/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,23 @@ class Rerank(_WeaviateInput):
query: Optional[str] = Field(default=None)


@dataclass
class _DiversityMMR:
"""Define MMR (Maximal Marginal Relevance) diversity selection."""

limit: Optional[int] = None
balance: Optional[float] = None


class Diversity:
"""Use this factory class to apply diversity selection to search results via MMR."""

def __init__(self) -> None:
raise TypeError("Diversity cannot be instantiated directly. Use Diversity.MMR(...).")

MMR = _DiversityMMR


@dataclass
class BM25OperatorOptions:
# replace with ClassVar[base_search_pb2.SearchOperatorOptions.Operator] once python 3.10 is removed
Expand Down
13 changes: 11 additions & 2 deletions weaviate/collections/grpc/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
QueryNested,
Rerank,
TargetVectorJoinType,
_DiversityMMR,
_MetadataQuery,
_QueryReference,
_QueryReferenceMultiTarget,
Expand Down Expand Up @@ -262,6 +263,7 @@ def near_vector(
return_metadata: Optional[_MetadataQuery] = None,
return_properties: Union[PROPERTIES, bool, None] = None,
return_references: Optional[REFERENCES] = None,
selection: Optional[_DiversityMMR] = None,
) -> search_get_pb2.SearchRequest:
return self.__create_request(
limit=limit,
Expand All @@ -275,7 +277,7 @@ def near_vector(
autocut=autocut,
group_by=group_by,
near_vector=self._parse_near_vector(
near_vector, certainty, distance, target_vector=target_vector
near_vector, certainty, distance, target_vector=target_vector, selection=selection
),
)

Expand All @@ -296,6 +298,7 @@ def near_object(
return_metadata: Optional[_MetadataQuery] = None,
return_properties: Union[PROPERTIES, bool, None] = None,
return_references: Optional[REFERENCES] = None,
selection: Optional[_DiversityMMR] = None,
) -> search_get_pb2.SearchRequest:
return self.__create_request(
limit=limit,
Expand All @@ -308,7 +311,9 @@ def near_object(
rerank=rerank,
autocut=autocut,
group_by=group_by,
near_object=self._parse_near_object(near_object, certainty, distance, target_vector),
near_object=self._parse_near_object(
near_object, certainty, distance, target_vector, selection=selection
),
)

def near_text(
Expand All @@ -330,6 +335,7 @@ def near_text(
return_metadata: Optional[_MetadataQuery] = None,
return_properties: Union[PROPERTIES, bool, None] = None,
return_references: Optional[REFERENCES] = None,
selection: Optional[_DiversityMMR] = None,
) -> search_get_pb2.SearchRequest:
return self.__create_request(
limit=limit,
Expand All @@ -349,6 +355,7 @@ def near_text(
move_away=move_away,
move_to=move_to,
target_vector=target_vector,
selection=selection,
),
)

Expand All @@ -370,6 +377,7 @@ def near_media(
return_metadata: Optional[_MetadataQuery] = None,
return_properties: Union[PROPERTIES, bool, None] = None,
return_references: Optional[REFERENCES] = None,
selection: Optional[_DiversityMMR] = None,
) -> search_get_pb2.SearchRequest:
return self.__create_request(
limit=limit,
Expand All @@ -388,6 +396,7 @@ def near_media(
certainty,
distance,
target_vector,
selection=selection,
),
)

Expand Down
28 changes: 28 additions & 0 deletions weaviate/collections/grpc/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
PrimitiveVectorType,
TargetVectorJoinType,
TwoDimensionalVectorType,
_DiversityMMR,
_HybridNearText,
_HybridNearVector,
_ListOfVectorsQuery,
Expand Down Expand Up @@ -310,12 +311,26 @@ def _parse_near_options(
float(distance) if distance is not None else None,
)

@staticmethod
def _selection_to_grpc(
selection: Optional[_DiversityMMR],
) -> Optional[base_search_pb2.Selection]:
if selection is None:
return None
return base_search_pb2.Selection(
mmr=base_search_pb2.Selection.MMR(
limit=selection.limit,
balance=selection.balance,
)
)

def _parse_near_vector(
self,
near_vector: NearVectorInputType,
certainty: Optional[NUMBER],
distance: Optional[NUMBER],
target_vector: Optional[TargetVectorJoinType],
selection: Optional[_DiversityMMR] = None,
) -> base_search_pb2.NearVector:
if self._validate_arguments:
_validate_input(
Expand Down Expand Up @@ -399,6 +414,7 @@ def _parse_near_vector(
vector_per_target=vector_per_target_tmp,
vector_for_targets=vector_for_targets,
vectors=vectors,
selection=self._selection_to_grpc(selection),
)

@staticmethod
Expand All @@ -423,6 +439,7 @@ def _parse_near_text(
move_to: Optional[Move],
move_away: Optional[Move],
target_vector: Optional[TargetVectorJoinType],
selection: Optional[_DiversityMMR] = None,
) -> base_search_pb2.NearTextSearch:
if self._validate_arguments:
_validate_input(
Expand Down Expand Up @@ -451,6 +468,7 @@ def _parse_near_text(
move_to=self.__parse_move(move_to),
targets=targets,
target_vectors=target_vector,
selection=self._selection_to_grpc(selection),
)

def _parse_near_object(
Expand All @@ -459,6 +477,7 @@ def _parse_near_object(
certainty: Optional[NUMBER],
distance: Optional[NUMBER],
target_vector: Optional[TargetVectorJoinType],
selection: Optional[_DiversityMMR] = None,
) -> base_search_pb2.NearObject:
if self._validate_arguments:
_validate_input(
Expand All @@ -482,6 +501,7 @@ def _parse_near_object(
distance=distance,
targets=targets,
target_vectors=target_vector,
selection=self._selection_to_grpc(selection),
)

def _parse_media(
Expand All @@ -491,6 +511,7 @@ def _parse_media(
certainty: Optional[NUMBER],
distance: Optional[NUMBER],
target_vector: Optional[TargetVectorJoinType],
selection: Optional[_DiversityMMR] = None,
) -> dict:
if self._validate_arguments:
_validate_input(
Expand All @@ -508,13 +529,15 @@ def _parse_media(

kwargs: Dict[str, Any] = {}
targets, target_vector = self.__target_vector_to_grpc(target_vector)
selection_grpc = self._selection_to_grpc(selection)
if type_ == "audio":
kwargs["near_audio"] = base_search_pb2.NearAudioSearch(
audio=media,
distance=distance,
certainty=certainty,
target_vectors=target_vector,
targets=targets,
selection=selection_grpc,
)
elif type_ == "depth":
kwargs["near_depth"] = base_search_pb2.NearDepthSearch(
Expand All @@ -523,6 +546,7 @@ def _parse_media(
certainty=certainty,
target_vectors=target_vector,
targets=targets,
selection=selection_grpc,
)
elif type_ == "image":
kwargs["near_image"] = base_search_pb2.NearImageSearch(
Expand All @@ -531,6 +555,7 @@ def _parse_media(
certainty=certainty,
target_vectors=target_vector,
targets=targets,
selection=selection_grpc,
)
elif type_ == "imu":
kwargs["near_imu"] = base_search_pb2.NearIMUSearch(
Expand All @@ -539,6 +564,7 @@ def _parse_media(
certainty=certainty,
target_vectors=target_vector,
targets=targets,
selection=selection_grpc,
)
elif type_ == "thermal":
kwargs["near_thermal"] = base_search_pb2.NearThermalSearch(
Expand All @@ -547,6 +573,7 @@ def _parse_media(
certainty=certainty,
target_vectors=target_vector,
targets=targets,
selection=selection_grpc,
)
elif type_ == "video":
kwargs["near_video"] = base_search_pb2.NearVideoSearch(
Expand All @@ -555,6 +582,7 @@ def _parse_media(
certainty=certainty,
target_vectors=target_vector,
targets=targets,
selection=selection_grpc,
)
else:
raise ValueError(
Expand Down
16 changes: 16 additions & 0 deletions weaviate/collections/queries/near_image/query/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
NearMediaType,
Rerank,
TargetVectorJoinType,
_DiversityMMR,
)
from weaviate.collections.classes.internal import (
CrossReferences,
Expand Down Expand Up @@ -52,6 +53,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: Literal[None] = None,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -72,6 +74,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: Literal[None] = None,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -92,6 +95,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: Literal[None] = None,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -112,6 +116,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: Literal[None] = None,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -132,6 +137,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: Literal[None] = None,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -152,6 +158,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: Literal[None] = None,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -174,6 +181,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: GroupBy,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -194,6 +202,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: GroupBy,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -214,6 +223,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: GroupBy,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -234,6 +244,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: GroupBy,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -254,6 +265,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: GroupBy,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -274,6 +286,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: GroupBy,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -295,6 +308,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: Optional[GroupBy] = None,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand All @@ -316,6 +330,7 @@ def near_image(
filters: Optional[FilterReturn] = None,
group_by: Optional[GroupBy] = None,
rerank: Optional[Rerank] = None,
selection: Optional[_DiversityMMR] = None,
target_vector: Optional[TargetVectorJoinType] = None,
include_vector: INCLUDE_VECTOR = False,
return_metadata: Optional[METADATA] = None,
Expand Down Expand Up @@ -385,6 +400,7 @@ def resp(
filters=filters,
group_by=_GroupBy.from_input(group_by),
rerank=rerank,
selection=selection,
target_vector=target_vector,
limit=limit,
offset=offset,
Expand Down
Loading
Loading