Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
weaviate-client>=4.16.7
weaviate-client>=4.20.4
click==8.1.7
twine
pytest
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ classifiers =
include_package_data = True
python_requires = >=3.9
install_requires =
weaviate-client>=4.19.0
weaviate-client>=4.20.4
click==8.1.7
semver>=3.0.2
numpy>=1.24.0
Expand Down
41 changes: 41 additions & 0 deletions weaviate_cli/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def create() -> None:
"hnsw_acorn",
"hnsw_multivector",
"flat_bq",
"hfresh",
]
),
help="Vector index type (default: 'hnsw').",
Expand Down Expand Up @@ -184,6 +185,36 @@ def create() -> None:
type=str,
help="Date property name for TTL when object_ttl_type is 'property' (default: 'releaseDate'). Only valid when --object_ttl_type=property.",
)
@click.option(
"--hfresh_max_posting_size_kb",
default=CreateCollectionDefaults.hfresh_max_posting_size_kb,
type=int,
help="hfresh - max posting size in KB (default: None).",
)
@click.option(
"--hfresh_replicas",
default=CreateCollectionDefaults.hfresh_replicas,
type=int,
help="hfresh - number of replicas for each element in different posting lists (default: None).",
)
@click.option(
"--hfresh_search_probe",
default=CreateCollectionDefaults.hfresh_search_probe,
type=int,
help="hfresh - search probe (default: None).",
)
@click.option(
"--distance_metric",
default=CreateCollectionDefaults.distance_metric,
type=click.Choice(["cosine", "dot", "l2-squared", "hamming", "manhattan"]),
help="Distance metric (default: None, set by Weaviate server).",
)
@click.option(
"--rescore_limit",
default=CreateCollectionDefaults.rescore_limit,
type=int,
help="Rescore limit (default: None, set by Weaviate server).",
)
@click.pass_context
def create_collection_cli(
ctx: click.Context,
Expand All @@ -203,6 +234,11 @@ def create_collection_cli(
replication_deletion_strategy: Optional[str],
named_vector: bool,
named_vector_name: Optional[str],
hfresh_max_posting_size_kb: Optional[int],
hfresh_replicas: Optional[int],
hfresh_search_probe: Optional[int],
distance_metric: Optional[str],
rescore_limit: Optional[int],
json_output: bool,
object_ttl_type: str,
object_ttl_time: Optional[int],
Expand Down Expand Up @@ -243,6 +279,11 @@ def create_collection_cli(
replication_deletion_strategy=replication_deletion_strategy,
named_vector=named_vector,
named_vector_name=named_vector_name,
hfresh_max_posting_size_kb=hfresh_max_posting_size_kb,
hfresh_replicas=hfresh_replicas,
hfresh_search_probe=hfresh_search_probe,
distance_metric=distance_metric,
rescore_limit=rescore_limit,
json_output=json_output,
object_ttl_type=object_ttl_type,
object_ttl_time=object_ttl_time,
Expand Down
5 changes: 5 additions & 0 deletions weaviate_cli/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class CreateCollectionDefaults:
replication_deletion_strategy: Optional[str] = None
named_vector: bool = False
named_vector_name: Optional[str] = "default"
hfresh_max_posting_size_kb: Optional[int] = None
hfresh_replicas: Optional[int] = None
hfresh_search_probe: Optional[int] = None
distance_metric: Optional[str] = None
rescore_limit: Optional[int] = None
object_ttl_type: str = "create"
object_ttl_time: Optional[int] = None
object_ttl_filter_expired: Optional[bool] = None
Expand Down
140 changes: 117 additions & 23 deletions weaviate_cli/managers/collection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from weaviate.collections import Collection
from weaviate.collections.classes.config import _CollectionConfigSimple
from weaviate.collections.classes.tenants import TenantActivityStatus
from weaviate.classes.config import VectorFilterStrategy
from weaviate.collections.classes.config_vector_index import VectorFilterStrategy
from weaviate_cli.defaults import (
CreateCollectionDefaults,
UpdateCollectionDefaults,
Expand Down Expand Up @@ -147,6 +147,46 @@ def _print_text():
def get_all_collections(self) -> dict[str, _CollectionConfigSimple]:
return self.client.collections.list_all()

def _build_hfresh_config(
self,
max_posting_size_kb: Optional[int] = None,
distance_metric: Optional[str] = "cosine",
rescore_limit: Optional[int] = None,
replicas: Optional[int] = None,
search_probe: Optional[int] = None,
):
"""Build hfresh configuration with provided parameters."""
# Explicit mapping of distance metric strings to enum values
distance_metric_map = {
"cosine": wvc.VectorDistances.COSINE,
"dot": wvc.VectorDistances.DOT,
"l2-squared": wvc.VectorDistances.L2_SQUARED,
"hamming": wvc.VectorDistances.HAMMING,
"manhattan": wvc.VectorDistances.MANHATTAN,
}

kwargs = {}

if max_posting_size_kb is not None:
kwargs["max_posting_size_kb"] = max_posting_size_kb
if distance_metric is not None:
if distance_metric not in distance_metric_map:
raise ValueError(
f"Invalid distance_metric: '{distance_metric}'. "
f"Must be one of: {list(distance_metric_map.keys())}"
)
kwargs["distance_metric"] = distance_metric_map[distance_metric]
if replicas is not None:
kwargs["replicas"] = replicas
if search_probe is not None:
kwargs["search_probe"] = search_probe
if rescore_limit is not None:
kwargs["quantizer"] = wvc.Configure.VectorIndex.Quantizer.rq(
bits=8, rescore_limit=rescore_limit
)

return wvc.Configure.VectorIndex.hfresh(**kwargs)

def create_collection(
self,
collection: str = CreateCollectionDefaults.collection,
Expand All @@ -169,6 +209,15 @@ def create_collection(
] = CreateCollectionDefaults.replication_deletion_strategy,
named_vector: bool = CreateCollectionDefaults.named_vector,
named_vector_name: Optional[str] = CreateCollectionDefaults.named_vector_name,
hfresh_max_posting_size_kb: Optional[
int
] = CreateCollectionDefaults.hfresh_max_posting_size_kb,
hfresh_replicas: Optional[int] = CreateCollectionDefaults.hfresh_replicas,
hfresh_search_probe: Optional[
int
] = CreateCollectionDefaults.hfresh_search_probe,
distance_metric: Optional[str] = CreateCollectionDefaults.distance_metric,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is out of topic for this PR...but as we are already adding both parameters distance_metric and rescore_limit we should also pass it to other vector indeces...what do you think? would it be too much effort? I think all of them support these two arguments, it would be just bypassing it (if it's None it will use the default value anyway)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion. Done!

rescore_limit: Optional[int] = CreateCollectionDefaults.rescore_limit,
json_output: bool = False,
object_ttl_type: str = CreateCollectionDefaults.object_ttl_type,
object_ttl_time: Optional[int] = CreateCollectionDefaults.object_ttl_time,
Expand Down Expand Up @@ -200,40 +249,55 @@ def create_collection(
)

vector_index_map: Dict[str, wvc.VectorIndexConfig] = {
"hnsw": wvc.Configure.VectorIndex.hnsw(),
"flat": wvc.Configure.VectorIndex.flat(),
"hnsw": wvc.Configure.VectorIndex.hnsw(distance_metric=distance_metric),
"flat": wvc.Configure.VectorIndex.flat(distance_metric=distance_metric),
"dynamic": wvc.Configure.VectorIndex.dynamic(),
"dynamic_flat_bq": wvc.Configure.VectorIndex.dynamic(
flat=wvc.Configure.VectorIndex.flat(
quantizer=wvc.Configure.VectorIndex.Quantizer.bq()
quantizer=wvc.Configure.VectorIndex.Quantizer.bq(),
distance_metric=distance_metric,
)
),
"dynamic_flat_bq_hnsw_pq": wvc.Configure.VectorIndex.dynamic(
flat=wvc.Configure.VectorIndex.flat(
quantizer=wvc.Configure.VectorIndex.Quantizer.bq()
quantizer=wvc.Configure.VectorIndex.Quantizer.bq(
rescore_limit=rescore_limit
),
distance_metric=distance_metric,
),
hnsw=wvc.Configure.VectorIndex.hnsw(
quantizer=wvc.Configure.VectorIndex.Quantizer.pq(
training_limit=training_limit
)
),
distance_metric=distance_metric,
),
),
"dynamic_flat_bq_hnsw_sq": wvc.Configure.VectorIndex.dynamic(
flat=wvc.Configure.VectorIndex.flat(
quantizer=wvc.Configure.VectorIndex.Quantizer.bq()
quantizer=wvc.Configure.VectorIndex.Quantizer.bq(
rescore_limit=rescore_limit
),
distance_metric=distance_metric,
),
hnsw=wvc.Configure.VectorIndex.hnsw(
quantizer=wvc.Configure.VectorIndex.Quantizer.sq(
training_limit=training_limit
)
rescore_limit=rescore_limit, training_limit=training_limit
),
distance_metric=distance_metric,
),
),
"dynamic_flat_bq_hnsw_bq": wvc.Configure.VectorIndex.dynamic(
flat=wvc.Configure.VectorIndex.flat(
quantizer=wvc.Configure.VectorIndex.Quantizer.bq()
quantizer=wvc.Configure.VectorIndex.Quantizer.bq(
rescore_limit=rescore_limit
),
distance_metric=distance_metric,
),
hnsw=wvc.Configure.VectorIndex.hnsw(
quantizer=wvc.Configure.VectorIndex.Quantizer.bq()
quantizer=wvc.Configure.VectorIndex.Quantizer.bq(
rescore_limit=rescore_limit
),
distance_metric=distance_metric,
),
),
"dynamic_hnsw_pq": wvc.Configure.VectorIndex.dynamic(
Expand All @@ -246,45 +310,75 @@ def create_collection(
"dynamic_hnsw_sq": wvc.Configure.VectorIndex.dynamic(
hnsw=wvc.Configure.VectorIndex.hnsw(
quantizer=wvc.Configure.VectorIndex.Quantizer.sq(
training_limit=training_limit
)
rescore_limit=rescore_limit, training_limit=training_limit
),
distance_metric=distance_metric,
)
),
"dynamic_hnsw_bq": wvc.Configure.VectorIndex.dynamic(
hnsw=wvc.Configure.VectorIndex.hnsw(
quantizer=wvc.Configure.VectorIndex.Quantizer.bq()
quantizer=wvc.Configure.VectorIndex.Quantizer.bq(
rescore_limit=rescore_limit
),
distance_metric=distance_metric,
)
),
"hnsw_pq": wvc.Configure.VectorIndex.hnsw(
quantizer=wvc.Configure.VectorIndex.Quantizer.pq(
training_limit=training_limit
)
),
distance_metric=distance_metric,
),
"hnsw_bq": wvc.Configure.VectorIndex.hnsw(
quantizer=wvc.Configure.VectorIndex.Quantizer.bq()
quantizer=wvc.Configure.VectorIndex.Quantizer.bq(
rescore_limit=rescore_limit
),
distance_metric=distance_metric,
),
"hnsw_bq_cache": wvc.Configure.VectorIndex.hnsw(
quantizer=wvc.Configure.VectorIndex.Quantizer.bq(cache=True)
quantizer=wvc.Configure.VectorIndex.Quantizer.bq(
cache=True, rescore_limit=rescore_limit
),
distance_metric=distance_metric,
),
"hnsw_sq": wvc.Configure.VectorIndex.hnsw(
quantizer=wvc.Configure.VectorIndex.Quantizer.sq(
training_limit=training_limit
)
rescore_limit=rescore_limit, training_limit=training_limit
),
distance_metric=distance_metric,
),
"hnsw_rq": wvc.Configure.VectorIndex.hnsw(
quantizer=wvc.Configure.VectorIndex.Quantizer.rq()
quantizer=wvc.Configure.VectorIndex.Quantizer.rq(
rescore_limit=rescore_limit
),
distance_metric=distance_metric,
),
"hnsw_acorn": wvc.Configure.VectorIndex.hnsw(
filter_strategy=VectorFilterStrategy.ACORN
filter_strategy=VectorFilterStrategy.ACORN,
distance_metric=distance_metric,
),
"hnsw_multivector": wvc.Configure.VectorIndex.hnsw(
multi_vector=wvc.Configure.VectorIndex.MultiVector.multi_vector(),
distance_metric=distance_metric,
),
"flat_bq": wvc.Configure.VectorIndex.flat(
quantizer=wvc.Configure.VectorIndex.Quantizer.bq()
quantizer=wvc.Configure.VectorIndex.Quantizer.bq(
rescore_limit=rescore_limit
),
distance_metric=distance_metric,
),
"flat_bq_cache": wvc.Configure.VectorIndex.flat(
quantizer=wvc.Configure.VectorIndex.Quantizer.bq(cache=True)
quantizer=wvc.Configure.VectorIndex.Quantizer.bq(
cache=True, rescore_limit=rescore_limit
),
distance_metric=distance_metric,
),
"hfresh": self._build_hfresh_config(
max_posting_size_kb=hfresh_max_posting_size_kb,
distance_metric=distance_metric,
rescore_limit=rescore_limit,
replicas=hfresh_replicas,
search_probe=hfresh_search_probe,
),
}

Expand Down
Loading