From 2af16f1f253f5272412b02701282bdbdd43e7279 Mon Sep 17 00:00:00 2001 From: yuejiaointel Date: Thu, 2 Apr 2026 14:01:43 -0700 Subject: [PATCH 1/5] add adaptive batch size heuristic for filtered search --- bindings/cpp/src/dynamic_vamana_index_impl.h | 21 ++++++++- bindings/cpp/tests/runtime_test.cpp | 49 ++++++++++++++++++++ bindings/cpp/tests/utils.h | 13 ++++++ 3 files changed, 82 insertions(+), 1 deletion(-) diff --git a/bindings/cpp/src/dynamic_vamana_index_impl.h b/bindings/cpp/src/dynamic_vamana_index_impl.h index 4b16cf4bc..c7bd75041 100644 --- a/bindings/cpp/src/dynamic_vamana_index_impl.h +++ b/bindings/cpp/src/dynamic_vamana_index_impl.h @@ -38,6 +38,20 @@ namespace svs { namespace runtime { +// Compute the next batch size based on observed filter hit rate. +// On the first round (found == 0), returns initial_batch_size unchanged. +// On subsequent rounds, estimates how many candidates are needed to find the +// remaining results given the observed hit rate. +inline size_t compute_filtered_batch_size( + size_t found, size_t needed, size_t total_checked, size_t initial_batch_size +) { + if (found == 0 || found >= needed) { + return initial_batch_size; + } + double hit_rate = static_cast(found) / total_checked; + return static_cast((needed - found) / hit_rate); +} + // Dynamic Vamana index implementation class DynamicVamanaIndexImpl { using allocator_type = svs::data::Blocked>; @@ -125,9 +139,12 @@ class DynamicVamanaIndexImpl { auto query = queries.get_datum(i); auto iterator = impl_->batch_iterator(query); size_t found = 0; + size_t total_checked = 0; + auto batch_size = sp.buffer_config_.get_search_window_size(); do { - iterator.next(k); + iterator.next(batch_size); for (auto& neighbor : iterator.results()) { + total_checked++; if (filter->is_member(neighbor.id())) { result.set(neighbor, i, found); found++; @@ -136,6 +153,8 @@ class DynamicVamanaIndexImpl { } } } + batch_size = + compute_filtered_batch_size(found, k, total_checked, batch_size); } while (found < k && !iterator.done()); // Pad results if not enough neighbors found diff --git a/bindings/cpp/tests/runtime_test.cpp b/bindings/cpp/tests/runtime_test.cpp index 201375d3c..2f296790c 100644 --- a/bindings/cpp/tests/runtime_test.cpp +++ b/bindings/cpp/tests/runtime_test.cpp @@ -501,6 +501,55 @@ CATCH_TEST_CASE("SearchWithIDFilter", "[runtime]") { svs::runtime::v0::DynamicVamanaIndex::destroy(index); } +CATCH_TEST_CASE("SearchWithRestrictiveFilter", "[runtime][filtered_search]") { + const auto& test_data = get_test_data(); + // Build index + svs::runtime::v0::DynamicVamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + svs::runtime::v0::Status status = svs::runtime::v0::DynamicVamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + // Add data + std::vector labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(test_n, labels.data(), test_data.data()); + CATCH_REQUIRE(status.ok()); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + + // 10% selectivity: accept every 10th ID + std::unordered_set valid_ids; + for (size_t i = 0; i < test_n; i += 10) { + valid_ids.insert(i); + } + test_utils::IDFilterSet filter(valid_ids); + + std::vector distances(nq * k); + std::vector result_labels(nq * k); + + status = + index->search(nq, xq, k, distances.data(), result_labels.data(), nullptr, &filter); + CATCH_REQUIRE(status.ok()); + + // All returned labels must be in the valid set + for (int i = 0; i < nq * k; ++i) { + if (svs::runtime::v0::is_specified(result_labels[i])) { + CATCH_REQUIRE(valid_ids.contains(result_labels[i])); + } + } + + svs::runtime::v0::DynamicVamanaIndex::destroy(index); +} + CATCH_TEST_CASE("RangeSearchFunctional", "[runtime]") { const auto& test_data = get_test_data(); // Build index diff --git a/bindings/cpp/tests/utils.h b/bindings/cpp/tests/utils.h index 8d1bc89f6..e2174b938 100644 --- a/bindings/cpp/tests/utils.h +++ b/bindings/cpp/tests/utils.h @@ -22,6 +22,7 @@ #include #include #include +#include #include namespace svs_test { @@ -73,6 +74,18 @@ class IDFilterRange : public svs::runtime::v0::IDFilter { bool is_member(size_t id) const override { return id >= min_id_ && id < max_id_; } }; +// ID filter that accepts only IDs in a given set +class IDFilterSet : public svs::runtime::v0::IDFilter { + private: + std::unordered_set valid_ids_; + + public: + IDFilterSet(std::unordered_set ids) + : valid_ids_(std::move(ids)) {} + + bool is_member(size_t id) const override { return valid_ids_.contains(id); } +}; + // Custom results allocator for testing class TestResultsAllocator : public svs::runtime::v0::ResultsAllocator { private: From 605a0bef2b9ec8de29b0b720a365d38be9775c66 Mon Sep 17 00:00:00 2001 From: yuejiaointel Date: Thu, 2 Apr 2026 14:16:16 -0700 Subject: [PATCH 2/5] use IDFilterRange instead of IDFilterSet in test --- bindings/cpp/tests/runtime_test.cpp | 15 +++++++-------- bindings/cpp/tests/utils.h | 13 ------------- 2 files changed, 7 insertions(+), 21 deletions(-) diff --git a/bindings/cpp/tests/runtime_test.cpp b/bindings/cpp/tests/runtime_test.cpp index 2f296790c..92b819894 100644 --- a/bindings/cpp/tests/runtime_test.cpp +++ b/bindings/cpp/tests/runtime_test.cpp @@ -526,12 +526,10 @@ CATCH_TEST_CASE("SearchWithRestrictiveFilter", "[runtime][filtered_search]") { const float* xq = test_data.data(); const int k = 5; - // 10% selectivity: accept every 10th ID - std::unordered_set valid_ids; - for (size_t i = 0; i < test_n; i += 10) { - valid_ids.insert(i); - } - test_utils::IDFilterSet filter(valid_ids); + // 10% selectivity: accept only IDs 0-9 out of 100 + size_t min_id = 0; + size_t max_id = test_n / 10; + test_utils::IDFilterRange filter(min_id, max_id); std::vector distances(nq * k); std::vector result_labels(nq * k); @@ -540,10 +538,11 @@ CATCH_TEST_CASE("SearchWithRestrictiveFilter", "[runtime][filtered_search]") { index->search(nq, xq, k, distances.data(), result_labels.data(), nullptr, &filter); CATCH_REQUIRE(status.ok()); - // All returned labels must be in the valid set + // All returned labels must fall inside the filter range for (int i = 0; i < nq * k; ++i) { if (svs::runtime::v0::is_specified(result_labels[i])) { - CATCH_REQUIRE(valid_ids.contains(result_labels[i])); + CATCH_REQUIRE(result_labels[i] >= min_id); + CATCH_REQUIRE(result_labels[i] < max_id); } } diff --git a/bindings/cpp/tests/utils.h b/bindings/cpp/tests/utils.h index e2174b938..8d1bc89f6 100644 --- a/bindings/cpp/tests/utils.h +++ b/bindings/cpp/tests/utils.h @@ -22,7 +22,6 @@ #include #include #include -#include #include namespace svs_test { @@ -74,18 +73,6 @@ class IDFilterRange : public svs::runtime::v0::IDFilter { bool is_member(size_t id) const override { return id >= min_id_ && id < max_id_; } }; -// ID filter that accepts only IDs in a given set -class IDFilterSet : public svs::runtime::v0::IDFilter { - private: - std::unordered_set valid_ids_; - - public: - IDFilterSet(std::unordered_set ids) - : valid_ids_(std::move(ids)) {} - - bool is_member(size_t id) const override { return valid_ids_.contains(id); } -}; - // Custom results allocator for testing class TestResultsAllocator : public svs::runtime::v0::ResultsAllocator { private: From 309d0add26f4709a8e811ffd704854ae11588a6f Mon Sep 17 00:00:00 2001 From: yuejiaointel Date: Fri, 3 Apr 2026 17:12:36 -0700 Subject: [PATCH 3/5] address PR review: refactor and optimize adaptive batch size - Rename compute_filtered_batch_size to predict_further_processing and move to svs_runtime_utils.h for reuse - Use float arithmetic instead of double for hit rate calculation - Compute batch size at loop start to avoid unnecessary computation - Use iterator.size() instead of per-element increment for total_checked - Initial batch size = max(k, search_window_size) - Apply adaptive batch size to vamana_index_impl.h filtered search --- bindings/cpp/src/dynamic_vamana_index_impl.h | 22 ++++---------------- bindings/cpp/src/svs_runtime_utils.h | 14 +++++++++++++ bindings/cpp/src/vamana_index_impl.h | 7 ++++++- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/bindings/cpp/src/dynamic_vamana_index_impl.h b/bindings/cpp/src/dynamic_vamana_index_impl.h index c7bd75041..fe4c0b49b 100644 --- a/bindings/cpp/src/dynamic_vamana_index_impl.h +++ b/bindings/cpp/src/dynamic_vamana_index_impl.h @@ -38,20 +38,6 @@ namespace svs { namespace runtime { -// Compute the next batch size based on observed filter hit rate. -// On the first round (found == 0), returns initial_batch_size unchanged. -// On subsequent rounds, estimates how many candidates are needed to find the -// remaining results given the observed hit rate. -inline size_t compute_filtered_batch_size( - size_t found, size_t needed, size_t total_checked, size_t initial_batch_size -) { - if (found == 0 || found >= needed) { - return initial_batch_size; - } - double hit_rate = static_cast(found) / total_checked; - return static_cast((needed - found) / hit_rate); -} - // Dynamic Vamana index implementation class DynamicVamanaIndexImpl { using allocator_type = svs::data::Blocked>; @@ -140,11 +126,13 @@ class DynamicVamanaIndexImpl { auto iterator = impl_->batch_iterator(query); size_t found = 0; size_t total_checked = 0; - auto batch_size = sp.buffer_config_.get_search_window_size(); + auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); do { + batch_size = + predict_further_processing(total_checked, found, k, batch_size); iterator.next(batch_size); + total_checked += iterator.size(); for (auto& neighbor : iterator.results()) { - total_checked++; if (filter->is_member(neighbor.id())) { result.set(neighbor, i, found); found++; @@ -153,8 +141,6 @@ class DynamicVamanaIndexImpl { } } } - batch_size = - compute_filtered_batch_size(found, k, total_checked, batch_size); } while (found < k && !iterator.done()); // Pad results if not enough neighbors found diff --git a/bindings/cpp/src/svs_runtime_utils.h b/bindings/cpp/src/svs_runtime_utils.h index e0d7c68af..b5fd12756 100644 --- a/bindings/cpp/src/svs_runtime_utils.h +++ b/bindings/cpp/src/svs_runtime_utils.h @@ -431,6 +431,20 @@ auto dispatch_storage_kind(StorageKind kind, F&& f, Args&&... args) { } } // namespace storage +// Predict how many more items need to be processed to reach the goal, +// based on the observed hit rate so far. +// If no hits yet, returns `hint` unchanged. +// The caller should cap the result to a max batch size if needed. +inline size_t predict_further_processing( + size_t processed, size_t hits, size_t goal, size_t hint +) { + if (hits == 0 || hits >= goal) { + return hint; + } + float batch_size = static_cast(goal - hits) * processed / hits; + return std::max(static_cast(batch_size), size_t{1}); +} + inline svs::threads::ThreadPoolHandle default_threadpool() { return svs::threads::ThreadPoolHandle(svs::threads::OMPThreadPool(omp_get_max_threads()) ); diff --git a/bindings/cpp/src/vamana_index_impl.h b/bindings/cpp/src/vamana_index_impl.h index 4cf58d7e0..d5a731017 100644 --- a/bindings/cpp/src/vamana_index_impl.h +++ b/bindings/cpp/src/vamana_index_impl.h @@ -131,8 +131,13 @@ class VamanaIndexImpl { auto query = queries.get_datum(i); auto iterator = get_impl()->batch_iterator(query); size_t found = 0; + size_t total_checked = 0; + auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); do { - iterator.next(k); + batch_size = + predict_further_processing(total_checked, found, k, batch_size); + iterator.next(batch_size); + total_checked += iterator.size(); for (auto& neighbor : iterator.results()) { if (filter->is_member(neighbor.id())) { result.set(neighbor, i, found); From 62d9bdff2144be154e3488ba35fd64df5ffff670 Mon Sep 17 00:00:00 2001 From: yuejiaointel Date: Fri, 3 Apr 2026 22:40:09 -0700 Subject: [PATCH 4/5] add batch size cap and comments to adaptive filtered search - Cap batch size with std::min instead of modulo to avoid SIGFPE - Add comments explaining adaptive batch sizing logic --- bindings/cpp/src/dynamic_vamana_index_impl.h | 12 ++++++++++-- bindings/cpp/src/vamana_index_impl.h | 12 ++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/bindings/cpp/src/dynamic_vamana_index_impl.h b/bindings/cpp/src/dynamic_vamana_index_impl.h index fe4c0b49b..f74978591 100644 --- a/bindings/cpp/src/dynamic_vamana_index_impl.h +++ b/bindings/cpp/src/dynamic_vamana_index_impl.h @@ -126,10 +126,18 @@ class DynamicVamanaIndexImpl { auto iterator = impl_->batch_iterator(query); size_t found = 0; size_t total_checked = 0; + // Use adaptive batch sizing: start with at least k candidates, + // then adjust based on observed filter hit rate. auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); + const auto max_batch_size = batch_size; do { - batch_size = - predict_further_processing(total_checked, found, k, batch_size); + // Estimate how many candidates we need to find remaining + // results given the observed hit rate so far. + batch_size = predict_further_processing( + total_checked, found, k, batch_size + ); + // Cap to avoid oversized batches in the iterator. + batch_size = std::min(batch_size, max_batch_size); iterator.next(batch_size); total_checked += iterator.size(); for (auto& neighbor : iterator.results()) { diff --git a/bindings/cpp/src/vamana_index_impl.h b/bindings/cpp/src/vamana_index_impl.h index d5a731017..65cee325f 100644 --- a/bindings/cpp/src/vamana_index_impl.h +++ b/bindings/cpp/src/vamana_index_impl.h @@ -132,10 +132,18 @@ class VamanaIndexImpl { auto iterator = get_impl()->batch_iterator(query); size_t found = 0; size_t total_checked = 0; + // Use adaptive batch sizing: start with at least k candidates, + // then adjust based on observed filter hit rate. auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size()); + const auto max_batch_size = batch_size; do { - batch_size = - predict_further_processing(total_checked, found, k, batch_size); + // Estimate how many candidates we need to find remaining + // results given the observed hit rate so far. + batch_size = predict_further_processing( + total_checked, found, k, batch_size + ); + // Cap to avoid oversized batches in the iterator. + batch_size = std::min(batch_size, max_batch_size); iterator.next(batch_size); total_checked += iterator.size(); for (auto& neighbor : iterator.results()) { From ee06f00e480ad68281a806e3916dea25ecd5bf99 Mon Sep 17 00:00:00 2001 From: yuejiaointel Date: Mon, 6 Apr 2026 13:17:34 -0700 Subject: [PATCH 5/5] apply clang-format to adaptive batch size code --- bindings/cpp/src/dynamic_vamana_index_impl.h | 5 ++--- bindings/cpp/src/svs_runtime_utils.h | 5 ++--- bindings/cpp/src/vamana_index_impl.h | 5 ++--- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/bindings/cpp/src/dynamic_vamana_index_impl.h b/bindings/cpp/src/dynamic_vamana_index_impl.h index f74978591..b27958703 100644 --- a/bindings/cpp/src/dynamic_vamana_index_impl.h +++ b/bindings/cpp/src/dynamic_vamana_index_impl.h @@ -133,9 +133,8 @@ class DynamicVamanaIndexImpl { do { // Estimate how many candidates we need to find remaining // results given the observed hit rate so far. - batch_size = predict_further_processing( - total_checked, found, k, batch_size - ); + batch_size = + predict_further_processing(total_checked, found, k, batch_size); // Cap to avoid oversized batches in the iterator. batch_size = std::min(batch_size, max_batch_size); iterator.next(batch_size); diff --git a/bindings/cpp/src/svs_runtime_utils.h b/bindings/cpp/src/svs_runtime_utils.h index b5fd12756..6caa1a325 100644 --- a/bindings/cpp/src/svs_runtime_utils.h +++ b/bindings/cpp/src/svs_runtime_utils.h @@ -435,9 +435,8 @@ auto dispatch_storage_kind(StorageKind kind, F&& f, Args&&... args) { // based on the observed hit rate so far. // If no hits yet, returns `hint` unchanged. // The caller should cap the result to a max batch size if needed. -inline size_t predict_further_processing( - size_t processed, size_t hits, size_t goal, size_t hint -) { +inline size_t +predict_further_processing(size_t processed, size_t hits, size_t goal, size_t hint) { if (hits == 0 || hits >= goal) { return hint; } diff --git a/bindings/cpp/src/vamana_index_impl.h b/bindings/cpp/src/vamana_index_impl.h index 65cee325f..2fd1f1452 100644 --- a/bindings/cpp/src/vamana_index_impl.h +++ b/bindings/cpp/src/vamana_index_impl.h @@ -139,9 +139,8 @@ class VamanaIndexImpl { do { // Estimate how many candidates we need to find remaining // results given the observed hit rate so far. - batch_size = predict_further_processing( - total_checked, found, k, batch_size - ); + batch_size = + predict_further_processing(total_checked, found, k, batch_size); // Cap to avoid oversized batches in the iterator. batch_size = std::min(batch_size, max_batch_size); iterator.next(batch_size);