diff --git a/requirements-dev.txt b/requirements-dev.txt index 3646e43..64cd126 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,4 @@ -weaviate-client>=4.16.7 +weaviate-client>=4.20.4 click==8.1.7 twine pytest diff --git a/setup.cfg b/setup.cfg index bae19a6..c7e11bf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ classifiers = include_package_data = True python_requires = >=3.9 install_requires = - weaviate-client>=4.19.0 + weaviate-client>=4.20.4 click==8.1.7 semver>=3.0.2 numpy>=1.24.0 diff --git a/weaviate_cli/commands/create.py b/weaviate_cli/commands/create.py index 71a96ad..0276cf8 100644 --- a/weaviate_cli/commands/create.py +++ b/weaviate_cli/commands/create.py @@ -77,6 +77,7 @@ def create() -> None: "hnsw_acorn", "hnsw_multivector", "flat_bq", + "hfresh", ] ), help="Vector index type (default: 'hnsw').", @@ -184,6 +185,36 @@ def create() -> None: type=str, help="Date property name for TTL when object_ttl_type is 'property' (default: 'releaseDate'). Only valid when --object_ttl_type=property.", ) +@click.option( + "--hfresh_max_posting_size_kb", + default=CreateCollectionDefaults.hfresh_max_posting_size_kb, + type=int, + help="hfresh - max posting size in KB (default: None).", +) +@click.option( + "--hfresh_replicas", + default=CreateCollectionDefaults.hfresh_replicas, + type=int, + help="hfresh - number of replicas for each element in different posting lists (default: None).", +) +@click.option( + "--hfresh_search_probe", + default=CreateCollectionDefaults.hfresh_search_probe, + type=int, + help="hfresh - search probe (default: None).", +) +@click.option( + "--distance_metric", + default=CreateCollectionDefaults.distance_metric, + type=click.Choice(["cosine", "dot", "l2-squared", "hamming", "manhattan"]), + help="Distance metric (default: None, set by Weaviate server).", +) +@click.option( + "--rescore_limit", + default=CreateCollectionDefaults.rescore_limit, + type=int, + help="Rescore limit (default: None, set by Weaviate server).", +) @click.pass_context def create_collection_cli( ctx: click.Context, @@ -203,6 +234,11 @@ def create_collection_cli( replication_deletion_strategy: Optional[str], named_vector: bool, named_vector_name: Optional[str], + hfresh_max_posting_size_kb: Optional[int], + hfresh_replicas: Optional[int], + hfresh_search_probe: Optional[int], + distance_metric: Optional[str], + rescore_limit: Optional[int], json_output: bool, object_ttl_type: str, object_ttl_time: Optional[int], @@ -243,6 +279,11 @@ def create_collection_cli( replication_deletion_strategy=replication_deletion_strategy, named_vector=named_vector, named_vector_name=named_vector_name, + hfresh_max_posting_size_kb=hfresh_max_posting_size_kb, + hfresh_replicas=hfresh_replicas, + hfresh_search_probe=hfresh_search_probe, + distance_metric=distance_metric, + rescore_limit=rescore_limit, json_output=json_output, object_ttl_type=object_ttl_type, object_ttl_time=object_ttl_time, diff --git a/weaviate_cli/defaults.py b/weaviate_cli/defaults.py index e6a5d4f..9d0fb4e 100644 --- a/weaviate_cli/defaults.py +++ b/weaviate_cli/defaults.py @@ -78,6 +78,11 @@ class CreateCollectionDefaults: replication_deletion_strategy: Optional[str] = None named_vector: bool = False named_vector_name: Optional[str] = "default" + hfresh_max_posting_size_kb: Optional[int] = None + hfresh_replicas: Optional[int] = None + hfresh_search_probe: Optional[int] = None + distance_metric: Optional[str] = None + rescore_limit: Optional[int] = None object_ttl_type: str = "create" object_ttl_time: Optional[int] = None object_ttl_filter_expired: Optional[bool] = None diff --git a/weaviate_cli/managers/collection_manager.py b/weaviate_cli/managers/collection_manager.py index 92c7555..d783252 100644 --- a/weaviate_cli/managers/collection_manager.py +++ b/weaviate_cli/managers/collection_manager.py @@ -5,7 +5,7 @@ from weaviate.collections import Collection from weaviate.collections.classes.config import _CollectionConfigSimple from weaviate.collections.classes.tenants import TenantActivityStatus -from weaviate.classes.config import VectorFilterStrategy +from weaviate.collections.classes.config_vector_index import VectorFilterStrategy from weaviate_cli.defaults import ( CreateCollectionDefaults, UpdateCollectionDefaults, @@ -147,6 +147,46 @@ def _print_text(): def get_all_collections(self) -> dict[str, _CollectionConfigSimple]: return self.client.collections.list_all() + def _build_hfresh_config( + self, + max_posting_size_kb: Optional[int] = None, + distance_metric: Optional[str] = "cosine", + rescore_limit: Optional[int] = None, + replicas: Optional[int] = None, + search_probe: Optional[int] = None, + ): + """Build hfresh configuration with provided parameters.""" + # Explicit mapping of distance metric strings to enum values + distance_metric_map = { + "cosine": wvc.VectorDistances.COSINE, + "dot": wvc.VectorDistances.DOT, + "l2-squared": wvc.VectorDistances.L2_SQUARED, + "hamming": wvc.VectorDistances.HAMMING, + "manhattan": wvc.VectorDistances.MANHATTAN, + } + + kwargs = {} + + if max_posting_size_kb is not None: + kwargs["max_posting_size_kb"] = max_posting_size_kb + if distance_metric is not None: + if distance_metric not in distance_metric_map: + raise ValueError( + f"Invalid distance_metric: '{distance_metric}'. " + f"Must be one of: {list(distance_metric_map.keys())}" + ) + kwargs["distance_metric"] = distance_metric_map[distance_metric] + if replicas is not None: + kwargs["replicas"] = replicas + if search_probe is not None: + kwargs["search_probe"] = search_probe + if rescore_limit is not None: + kwargs["quantizer"] = wvc.Configure.VectorIndex.Quantizer.rq( + bits=8, rescore_limit=rescore_limit + ) + + return wvc.Configure.VectorIndex.hfresh(**kwargs) + def create_collection( self, collection: str = CreateCollectionDefaults.collection, @@ -169,6 +209,15 @@ def create_collection( ] = CreateCollectionDefaults.replication_deletion_strategy, named_vector: bool = CreateCollectionDefaults.named_vector, named_vector_name: Optional[str] = CreateCollectionDefaults.named_vector_name, + hfresh_max_posting_size_kb: Optional[ + int + ] = CreateCollectionDefaults.hfresh_max_posting_size_kb, + hfresh_replicas: Optional[int] = CreateCollectionDefaults.hfresh_replicas, + hfresh_search_probe: Optional[ + int + ] = CreateCollectionDefaults.hfresh_search_probe, + distance_metric: Optional[str] = CreateCollectionDefaults.distance_metric, + rescore_limit: Optional[int] = CreateCollectionDefaults.rescore_limit, json_output: bool = False, object_ttl_type: str = CreateCollectionDefaults.object_ttl_type, object_ttl_time: Optional[int] = CreateCollectionDefaults.object_ttl_time, @@ -200,40 +249,55 @@ def create_collection( ) vector_index_map: Dict[str, wvc.VectorIndexConfig] = { - "hnsw": wvc.Configure.VectorIndex.hnsw(), - "flat": wvc.Configure.VectorIndex.flat(), + "hnsw": wvc.Configure.VectorIndex.hnsw(distance_metric=distance_metric), + "flat": wvc.Configure.VectorIndex.flat(distance_metric=distance_metric), "dynamic": wvc.Configure.VectorIndex.dynamic(), "dynamic_flat_bq": wvc.Configure.VectorIndex.dynamic( flat=wvc.Configure.VectorIndex.flat( - quantizer=wvc.Configure.VectorIndex.Quantizer.bq() + quantizer=wvc.Configure.VectorIndex.Quantizer.bq(), + distance_metric=distance_metric, ) ), "dynamic_flat_bq_hnsw_pq": wvc.Configure.VectorIndex.dynamic( flat=wvc.Configure.VectorIndex.flat( - quantizer=wvc.Configure.VectorIndex.Quantizer.bq() + quantizer=wvc.Configure.VectorIndex.Quantizer.bq( + rescore_limit=rescore_limit + ), + distance_metric=distance_metric, ), hnsw=wvc.Configure.VectorIndex.hnsw( quantizer=wvc.Configure.VectorIndex.Quantizer.pq( training_limit=training_limit - ) + ), + distance_metric=distance_metric, ), ), "dynamic_flat_bq_hnsw_sq": wvc.Configure.VectorIndex.dynamic( flat=wvc.Configure.VectorIndex.flat( - quantizer=wvc.Configure.VectorIndex.Quantizer.bq() + quantizer=wvc.Configure.VectorIndex.Quantizer.bq( + rescore_limit=rescore_limit + ), + distance_metric=distance_metric, ), hnsw=wvc.Configure.VectorIndex.hnsw( quantizer=wvc.Configure.VectorIndex.Quantizer.sq( - training_limit=training_limit - ) + rescore_limit=rescore_limit, training_limit=training_limit + ), + distance_metric=distance_metric, ), ), "dynamic_flat_bq_hnsw_bq": wvc.Configure.VectorIndex.dynamic( flat=wvc.Configure.VectorIndex.flat( - quantizer=wvc.Configure.VectorIndex.Quantizer.bq() + quantizer=wvc.Configure.VectorIndex.Quantizer.bq( + rescore_limit=rescore_limit + ), + distance_metric=distance_metric, ), hnsw=wvc.Configure.VectorIndex.hnsw( - quantizer=wvc.Configure.VectorIndex.Quantizer.bq() + quantizer=wvc.Configure.VectorIndex.Quantizer.bq( + rescore_limit=rescore_limit + ), + distance_metric=distance_metric, ), ), "dynamic_hnsw_pq": wvc.Configure.VectorIndex.dynamic( @@ -246,45 +310,75 @@ def create_collection( "dynamic_hnsw_sq": wvc.Configure.VectorIndex.dynamic( hnsw=wvc.Configure.VectorIndex.hnsw( quantizer=wvc.Configure.VectorIndex.Quantizer.sq( - training_limit=training_limit - ) + rescore_limit=rescore_limit, training_limit=training_limit + ), + distance_metric=distance_metric, ) ), "dynamic_hnsw_bq": wvc.Configure.VectorIndex.dynamic( hnsw=wvc.Configure.VectorIndex.hnsw( - quantizer=wvc.Configure.VectorIndex.Quantizer.bq() + quantizer=wvc.Configure.VectorIndex.Quantizer.bq( + rescore_limit=rescore_limit + ), + distance_metric=distance_metric, ) ), "hnsw_pq": wvc.Configure.VectorIndex.hnsw( quantizer=wvc.Configure.VectorIndex.Quantizer.pq( training_limit=training_limit - ) + ), + distance_metric=distance_metric, ), "hnsw_bq": wvc.Configure.VectorIndex.hnsw( - quantizer=wvc.Configure.VectorIndex.Quantizer.bq() + quantizer=wvc.Configure.VectorIndex.Quantizer.bq( + rescore_limit=rescore_limit + ), + distance_metric=distance_metric, ), "hnsw_bq_cache": wvc.Configure.VectorIndex.hnsw( - quantizer=wvc.Configure.VectorIndex.Quantizer.bq(cache=True) + quantizer=wvc.Configure.VectorIndex.Quantizer.bq( + cache=True, rescore_limit=rescore_limit + ), + distance_metric=distance_metric, ), "hnsw_sq": wvc.Configure.VectorIndex.hnsw( quantizer=wvc.Configure.VectorIndex.Quantizer.sq( - training_limit=training_limit - ) + rescore_limit=rescore_limit, training_limit=training_limit + ), + distance_metric=distance_metric, ), "hnsw_rq": wvc.Configure.VectorIndex.hnsw( - quantizer=wvc.Configure.VectorIndex.Quantizer.rq() + quantizer=wvc.Configure.VectorIndex.Quantizer.rq( + rescore_limit=rescore_limit + ), + distance_metric=distance_metric, ), "hnsw_acorn": wvc.Configure.VectorIndex.hnsw( - filter_strategy=VectorFilterStrategy.ACORN + filter_strategy=VectorFilterStrategy.ACORN, + distance_metric=distance_metric, ), "hnsw_multivector": wvc.Configure.VectorIndex.hnsw( multi_vector=wvc.Configure.VectorIndex.MultiVector.multi_vector(), + distance_metric=distance_metric, ), "flat_bq": wvc.Configure.VectorIndex.flat( - quantizer=wvc.Configure.VectorIndex.Quantizer.bq() + quantizer=wvc.Configure.VectorIndex.Quantizer.bq( + rescore_limit=rescore_limit + ), + distance_metric=distance_metric, ), "flat_bq_cache": wvc.Configure.VectorIndex.flat( - quantizer=wvc.Configure.VectorIndex.Quantizer.bq(cache=True) + quantizer=wvc.Configure.VectorIndex.Quantizer.bq( + cache=True, rescore_limit=rescore_limit + ), + distance_metric=distance_metric, + ), + "hfresh": self._build_hfresh_config( + max_posting_size_kb=hfresh_max_posting_size_kb, + distance_metric=distance_metric, + rescore_limit=rescore_limit, + replicas=hfresh_replicas, + search_probe=hfresh_search_probe, ), }