From 9ac27bc3ccda17858afd3053ca175cbc85f8dc9e Mon Sep 17 00:00:00 2001 From: yuejiaointel Date: Wed, 8 Apr 2026 16:57:31 -0700 Subject: [PATCH] [CPP Runtime] Add get_distance and reconstruct_at to VamanaIndex API Add get_distance() and reconstruct_at() methods to the runtime library for OpenSearch integration. These expose existing orchestrator-layer functionality through the shared library ABI. - get_distance: computes distance between a stored vector and a query - reconstruct_at: decompresses/reconstructs vectors to float32 by ID - Works with all storage kinds (FP32, FP16, SQI8, LVQ, LeanVec) - Added to both VamanaIndex (static) and DynamicVamanaIndex - Includes tests for both index types --- .../cpp/include/svs/runtime/vamana_index.h | 7 + bindings/cpp/src/dynamic_vamana_index.cpp | 9 ++ bindings/cpp/src/dynamic_vamana_index_impl.h | 17 +++ bindings/cpp/src/vamana_index.cpp | 9 ++ bindings/cpp/src/vamana_index_impl.h | 17 +++ bindings/cpp/tests/runtime_test.cpp | 138 ++++++++++++++++++ 6 files changed, 197 insertions(+) diff --git a/bindings/cpp/include/svs/runtime/vamana_index.h b/bindings/cpp/include/svs/runtime/vamana_index.h index 98831952..dc320376 100644 --- a/bindings/cpp/include/svs/runtime/vamana_index.h +++ b/bindings/cpp/include/svs/runtime/vamana_index.h @@ -76,6 +76,13 @@ struct SVS_RUNTIME_API VamanaIndex { IDFilter* filter = nullptr ) const noexcept = 0; + // Compute distance between stored vector `id` and `query` (dim floats). + virtual Status + get_distance(double* distance, size_t id, const float* query) const noexcept = 0; + + // Reconstruct `n` vectors by ID into `output` buffer (n * dim floats). + virtual Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept = 0; + // Utility function to check storage kind support static Status check_storage_kind(StorageKind storage_kind) noexcept; diff --git a/bindings/cpp/src/dynamic_vamana_index.cpp b/bindings/cpp/src/dynamic_vamana_index.cpp index 0c1a6a89..14866e6d 100644 --- a/bindings/cpp/src/dynamic_vamana_index.cpp +++ b/bindings/cpp/src/dynamic_vamana_index.cpp @@ -118,6 +118,15 @@ struct DynamicVamanaIndexManagerBase : public DynamicVamanaIndex { Status save(std::ostream& out) const noexcept override { return runtime_error_wrapper([&] { impl_->save(out); }); } + + Status + get_distance(double* distance, size_t id, const float* query) const noexcept override { + return runtime_error_wrapper([&] { *distance = impl_->get_distance(id, query); }); + } + + Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept override { + return runtime_error_wrapper([&] { impl_->reconstruct_at(n, ids, output); }); + } }; } // namespace diff --git a/bindings/cpp/src/dynamic_vamana_index_impl.h b/bindings/cpp/src/dynamic_vamana_index_impl.h index 4b16cf4b..4254c5f5 100644 --- a/bindings/cpp/src/dynamic_vamana_index_impl.h +++ b/bindings/cpp/src/dynamic_vamana_index_impl.h @@ -305,6 +305,23 @@ class DynamicVamanaIndexImpl { return remove(ids_to_delete); } + double get_distance(size_t id, const float* query) const { + if (!impl_) { + throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"}; + } + auto query_span = std::span(query, dim_); + return impl_->get_distance(id, query_span); + } + + void reconstruct_at(size_t n, const size_t* ids, float* output) { + if (!impl_) { + throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"}; + } + svs::data::SimpleDataView dst{output, n, dim_}; + std::span id_span{reinterpret_cast(ids), n}; + impl_->reconstruct_at(dst, id_span); + } + void reset() { impl_.reset(); ntotal_soft_deleted = 0; diff --git a/bindings/cpp/src/vamana_index.cpp b/bindings/cpp/src/vamana_index.cpp index c015dd21..8e7cb7dd 100644 --- a/bindings/cpp/src/vamana_index.cpp +++ b/bindings/cpp/src/vamana_index.cpp @@ -88,6 +88,15 @@ struct VamanaIndexManagerBase : public VamanaIndex { Status save(std::ostream& out) const noexcept override { return runtime_error_wrapper([&] { impl_->save(out); }); } + + Status + get_distance(double* distance, size_t id, const float* query) const noexcept override { + return runtime_error_wrapper([&] { *distance = impl_->get_distance(id, query); }); + } + + Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept override { + return runtime_error_wrapper([&] { impl_->reconstruct_at(n, ids, output); }); + } }; } // namespace diff --git a/bindings/cpp/src/vamana_index_impl.h b/bindings/cpp/src/vamana_index_impl.h index 4cf58d7e..ae704fdf 100644 --- a/bindings/cpp/src/vamana_index_impl.h +++ b/bindings/cpp/src/vamana_index_impl.h @@ -269,6 +269,23 @@ class VamanaIndexImpl { } } + double get_distance(size_t id, const float* query) const { + if (!impl_) { + throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"}; + } + auto query_span = std::span(query, dim_); + return get_impl()->get_distance(id, query_span); + } + + void reconstruct_at(size_t n, const size_t* ids, float* output) { + if (!impl_) { + throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"}; + } + svs::data::SimpleDataView dst{output, n, dim_}; + std::span id_span{reinterpret_cast(ids), n}; + get_impl()->reconstruct_at(dst, id_span); + } + void reset() { impl_.reset(); } void save(std::ostream& out) const { diff --git a/bindings/cpp/tests/runtime_test.cpp b/bindings/cpp/tests/runtime_test.cpp index 201375d3..3e15fd2c 100644 --- a/bindings/cpp/tests/runtime_test.cpp +++ b/bindings/cpp/tests/runtime_test.cpp @@ -881,3 +881,141 @@ CATCH_TEST_CASE("RangeSearchFunctionalStatic", "[runtime][static_vamana]") { svs::runtime::v0::VamanaIndex::destroy(index); } + +CATCH_TEST_CASE("GetDistanceDynamic", "[runtime]") { + const auto& test_data = get_test_data(); + svs::runtime::v0::DynamicVamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + auto status = svs::runtime::v0::DynamicVamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + + std::vector labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(test_n, labels.data(), test_data.data()); + CATCH_REQUIRE(status.ok()); + + // Self-distance should be approximately 0 + double dist = -1.0; + const float* vec0 = test_data.data(); + status = index->get_distance(&dist, 0, vec0); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(dist < 1e-6); + + // Distance to a different vector should be positive + const float* vec1 = test_data.data() + test_d; + status = index->get_distance(&dist, 0, vec1); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(dist > 0.0); + + svs::runtime::v0::DynamicVamanaIndex::destroy(index); +} + +CATCH_TEST_CASE("GetDistanceStatic", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + svs::runtime::v0::VamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + auto status = svs::runtime::v0::VamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + + status = index->add(test_n, test_data.data()); + CATCH_REQUIRE(status.ok()); + + // Self-distance should be approximately 0 + double dist = -1.0; + const float* vec0 = test_data.data(); + status = index->get_distance(&dist, 0, vec0); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(dist < 1e-6); + + // Distance to a different vector should be positive + const float* vec1 = test_data.data() + test_d; + status = index->get_distance(&dist, 0, vec1); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(dist > 0.0); + + svs::runtime::v0::VamanaIndex::destroy(index); +} + +CATCH_TEST_CASE("ReconstructAtDynamic", "[runtime]") { + const auto& test_data = get_test_data(); + svs::runtime::v0::DynamicVamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + auto status = svs::runtime::v0::DynamicVamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + + std::vector labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(test_n, labels.data(), test_data.data()); + CATCH_REQUIRE(status.ok()); + + // Reconstruct first 5 vectors + constexpr size_t nrecon = 5; + std::vector ids(nrecon); + std::iota(ids.begin(), ids.end(), 0); + std::vector output(nrecon * test_d, 0.0f); + + status = index->reconstruct_at(nrecon, ids.data(), output.data()); + CATCH_REQUIRE(status.ok()); + + // For FP32 storage, reconstructed vectors should match originals exactly + for (size_t i = 0; i < nrecon; ++i) { + for (size_t j = 0; j < test_d; ++j) { + CATCH_REQUIRE(output[i * test_d + j] == test_data[i * test_d + j]); + } + } + + svs::runtime::v0::DynamicVamanaIndex::destroy(index); +} + +CATCH_TEST_CASE("ReconstructAtStatic", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + svs::runtime::v0::VamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + auto status = svs::runtime::v0::VamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + + status = index->add(test_n, test_data.data()); + CATCH_REQUIRE(status.ok()); + + // Reconstruct first 5 vectors + constexpr size_t nrecon = 5; + std::vector ids(nrecon); + std::iota(ids.begin(), ids.end(), 0); + std::vector output(nrecon * test_d, 0.0f); + + status = index->reconstruct_at(nrecon, ids.data(), output.data()); + CATCH_REQUIRE(status.ok()); + + // For FP32 storage, reconstructed vectors should match originals exactly + for (size_t i = 0; i < nrecon; ++i) { + for (size_t j = 0; j < test_d; ++j) { + CATCH_REQUIRE(output[i * test_d + j] == test_data[i * test_d + j]); + } + } + + svs::runtime::v0::VamanaIndex::destroy(index); +}