diff --git a/bindings/cpp/src/dynamic_vamana_index_impl.h b/bindings/cpp/src/dynamic_vamana_index_impl.h index 4b16cf4b..b2795870 100644 --- a/bindings/cpp/src/dynamic_vamana_index_impl.h +++ b/bindings/cpp/src/dynamic_vamana_index_impl.h @@ -125,8 +125,20 @@ class DynamicVamanaIndexImpl { auto query = queries.get_datum(i); auto iterator = impl_->batch_iterator(query); size_t found = 0; + size_t total_checked = 0; + // Use adaptive batch sizing: start with at least k candidates, + // then adjust based on observed filter hit rate. + auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); + const auto max_batch_size = batch_size; do { - iterator.next(k); + // Estimate how many candidates we need to find remaining + // results given the observed hit rate so far. + batch_size = + predict_further_processing(total_checked, found, k, batch_size); + // Cap to avoid oversized batches in the iterator. + batch_size = std::min(batch_size, max_batch_size); + iterator.next(batch_size); + total_checked += iterator.size(); for (auto& neighbor : iterator.results()) { if (filter->is_member(neighbor.id())) { result.set(neighbor, i, found); diff --git a/bindings/cpp/src/svs_runtime_utils.h b/bindings/cpp/src/svs_runtime_utils.h index e0d7c68a..6caa1a32 100644 --- a/bindings/cpp/src/svs_runtime_utils.h +++ b/bindings/cpp/src/svs_runtime_utils.h @@ -431,6 +431,19 @@ auto dispatch_storage_kind(StorageKind kind, F&& f, Args&&... args) { } } // namespace storage +// Predict how many more items need to be processed to reach the goal, +// based on the observed hit rate so far. +// If no hits yet, returns `hint` unchanged. +// The caller should cap the result to a max batch size if needed. +inline size_t +predict_further_processing(size_t processed, size_t hits, size_t goal, size_t hint) { + if (hits == 0 || hits >= goal) { + return hint; + } + float batch_size = static_cast(goal - hits) * processed / hits; + return std::max(static_cast(batch_size), size_t{1}); +} + inline svs::threads::ThreadPoolHandle default_threadpool() { return svs::threads::ThreadPoolHandle(svs::threads::OMPThreadPool(omp_get_max_threads()) ); diff --git a/bindings/cpp/src/vamana_index_impl.h b/bindings/cpp/src/vamana_index_impl.h index 4cf58d7e..2fd1f145 100644 --- a/bindings/cpp/src/vamana_index_impl.h +++ b/bindings/cpp/src/vamana_index_impl.h @@ -131,8 +131,20 @@ class VamanaIndexImpl { auto query = queries.get_datum(i); auto iterator = get_impl()->batch_iterator(query); size_t found = 0; + size_t total_checked = 0; + // Use adaptive batch sizing: start with at least k candidates, + // then adjust based on observed filter hit rate. + auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); + const auto max_batch_size = batch_size; do { - iterator.next(k); + // Estimate how many candidates we need to find remaining + // results given the observed hit rate so far. + batch_size = + predict_further_processing(total_checked, found, k, batch_size); + // Cap to avoid oversized batches in the iterator. + batch_size = std::min(batch_size, max_batch_size); + iterator.next(batch_size); + total_checked += iterator.size(); for (auto& neighbor : iterator.results()) { if (filter->is_member(neighbor.id())) { result.set(neighbor, i, found); diff --git a/bindings/cpp/tests/runtime_test.cpp b/bindings/cpp/tests/runtime_test.cpp index 201375d3..92b81989 100644 --- a/bindings/cpp/tests/runtime_test.cpp +++ b/bindings/cpp/tests/runtime_test.cpp @@ -501,6 +501,54 @@ CATCH_TEST_CASE("SearchWithIDFilter", "[runtime]") { svs::runtime::v0::DynamicVamanaIndex::destroy(index); } +CATCH_TEST_CASE("SearchWithRestrictiveFilter", "[runtime][filtered_search]") { + const auto& test_data = get_test_data(); + // Build index + svs::runtime::v0::DynamicVamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + svs::runtime::v0::Status status = svs::runtime::v0::DynamicVamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + // Add data + std::vector labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(test_n, labels.data(), test_data.data()); + CATCH_REQUIRE(status.ok()); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + + // 10% selectivity: accept only IDs 0-9 out of 100 + size_t min_id = 0; + size_t max_id = test_n / 10; + test_utils::IDFilterRange filter(min_id, max_id); + + std::vector distances(nq * k); + std::vector result_labels(nq * k); + + status = + index->search(nq, xq, k, distances.data(), result_labels.data(), nullptr, &filter); + CATCH_REQUIRE(status.ok()); + + // All returned labels must fall inside the filter range + for (int i = 0; i < nq * k; ++i) { + if (svs::runtime::v0::is_specified(result_labels[i])) { + CATCH_REQUIRE(result_labels[i] >= min_id); + CATCH_REQUIRE(result_labels[i] < max_id); + } + } + + svs::runtime::v0::DynamicVamanaIndex::destroy(index); +} + CATCH_TEST_CASE("RangeSearchFunctional", "[runtime]") { const auto& test_data = get_test_data(); // Build index