diff --git a/bindings/cpp/include/svs/runtime/ivf_index.h b/bindings/cpp/include/svs/runtime/ivf_index.h index 46811743..3bbe8d10 100644 --- a/bindings/cpp/include/svs/runtime/ivf_index.h +++ b/bindings/cpp/include/svs/runtime/ivf_index.h @@ -54,16 +54,28 @@ struct SVS_RUNTIME_API IVFIndex { size_t n_probes = Unspecify(); /// Level of reordering/reranking done when using compressed datasets (multiplier) float k_reorder = Unspecify(); + /// Minimum filter hit rate to continue filtered search. + /// If the hit rate after the first round falls below this threshold, + /// stop and return empty results (caller can fall back to exact search). + /// Default unspecified means never give up (treated as 0). + float filter_stop = Unspecify(); + /// Enable pre-search filter sampling to estimate hit rate before + /// cluster traversal. Uses a random sample of IDs to set the initial + /// batch size and trigger early exit. + OptionalBool filter_estimate_batch = Unspecify(); }; /// @brief Perform k-NN search on the index. + /// @param filter Optional ID filter; when non-null, only IDs satisfying + /// ``filter->is_member(id)`` are returned. virtual Status search( size_t n, const float* x, size_t k, float* distances, size_t* labels, - const SearchParams* params = nullptr + const SearchParams* params = nullptr, + IDFilter* filter = nullptr ) const noexcept = 0; /// @brief Utility function to check storage kind support. @@ -108,6 +120,13 @@ struct SVS_RUNTIME_API IVFIndex { /// @brief Get the number of threads used for index operations. virtual Status get_num_threads(size_t* num_threads) const noexcept = 0; + /// @brief Set the number of intra-query (cluster-level) threads. + /// Recreates the per-query intra-query thread pools. + virtual Status set_intra_query_threads(size_t intra_query_threads) noexcept = 0; + + /// @brief Get the current number of intra-query (cluster-level) threads. + virtual Status get_intra_query_threads(size_t* intra_query_threads) const noexcept = 0; + /// @brief Load an IVF index from a stream. static Status load( IVFIndex** index, diff --git a/bindings/cpp/src/dynamic_ivf_index.cpp b/bindings/cpp/src/dynamic_ivf_index.cpp index 5059eeae..697bbaba 100644 --- a/bindings/cpp/src/dynamic_ivf_index.cpp +++ b/bindings/cpp/src/dynamic_ivf_index.cpp @@ -54,14 +54,15 @@ struct DynamicIVFIndexManager : public DynamicIVFIndex { size_t k, float* distances, size_t* labels, - const SearchParams* params = nullptr + const SearchParams* params = nullptr, + IDFilter* filter = nullptr ) const noexcept override { return runtime_error_wrapper([&] { auto result = svs::QueryResultView{ svs::MatrixView{svs::make_dims(n, k), labels}, svs::MatrixView{svs::make_dims(n, k), distances}}; auto queries = svs::data::ConstSimpleDataView(x, n, impl_->dimensions()); - impl_->search(result, queries, params); + impl_->search(result, queries, params, filter); }); } @@ -112,6 +113,18 @@ struct DynamicIVFIndexManager : public DynamicIVFIndex { Status get_num_threads(size_t* num_threads) const noexcept override { return runtime_error_wrapper([&] { *num_threads = impl_->get_num_threads(); }); } + + Status set_intra_query_threads(size_t intra_query_threads) noexcept override { + return runtime_error_wrapper([&] { + impl_->set_intra_query_threads(intra_query_threads); + }); + } + + Status get_intra_query_threads(size_t* intra_query_threads) const noexcept override { + return runtime_error_wrapper([&] { + *intra_query_threads = impl_->get_intra_query_threads(); + }); + } }; } // namespace diff --git a/bindings/cpp/src/dynamic_ivf_index_impl.h b/bindings/cpp/src/dynamic_ivf_index_impl.h index c014d8f8..933e7458 100644 --- a/bindings/cpp/src/dynamic_ivf_index_impl.h +++ b/bindings/cpp/src/dynamic_ivf_index_impl.h @@ -125,7 +125,8 @@ class DynamicIVFIndexImpl { void search( svs::QueryResultView result, svs::data::ConstSimpleDataView queries, - const IVFIndex::SearchParams* params = nullptr + const IVFIndex::SearchParams* params = nullptr, + IDFilter* filter = nullptr ) const { if (!impl_) { auto& dists = result.distances(); @@ -145,7 +146,84 @@ class DynamicIVFIndexImpl { } auto sp = make_search_parameters(params); - impl_->search(result, queries, sp); + + // Simple search + if (filter == nullptr) { + impl_->search(result, queries, sp); + return; + } + + // Selective search with IDFilter: use batch iterator to over-fetch and + // filter per-neighbor, mirroring the Vamana approach. + auto old_sp = impl_->get_search_parameters(); + auto sp_restore = svs::lib::make_scope_guard([&]() noexcept { + impl_->set_search_parameters(old_sp); + }); + impl_->set_search_parameters(sp); + + float filter_stop = 0.0f; + bool filter_estimate_batch = true; + if (params) { + set_if_specified(filter_stop, params->filter_stop); + set_if_specified(filter_estimate_batch, params->filter_estimate_batch); + } + const auto max_batch_size = impl_->size(); + + // Pre-search filter sampling: estimate hit rate before cluster traversal. + size_t sampled = 0; + size_t sample_hits = 0; + const auto initial_batch_hint = std::max(k, size_t{1}); + auto initial_batch_size = initial_batch_hint; + if (filter_estimate_batch) { + std::tie(sampled, sample_hits) = sample_filter_hits( + *filter, + max_batch_size, + [this](size_t id) { return impl_->has_id(id); }, + sample_size_for_filter_stop(filter_stop) + ); + if (should_stop_filtered_search(sampled, sample_hits, filter_stop)) { + pad_empty_results(result, queries.size(), k); + return; + } + initial_batch_size = predict_further_processing( + sampled, sample_hits, k, initial_batch_hint, max_batch_size + ); + } + + for (size_t i = 0; i < queries.size(); ++i) { + auto query = queries.get_datum(i); + auto iterator = + impl_->batch_iterator(std::span(query.data(), query.size())); + size_t found = 0; + size_t total_checked = 0; + auto batch_size = initial_batch_size; + do { + batch_size = predict_further_processing( + total_checked, found, k, batch_size, max_batch_size + ); + iterator.next(batch_size); + total_checked += iterator.size(); + for (const auto& neighbor : iterator.results()) { + if (filter->is_member(neighbor.id())) { + result.set(neighbor, i, found); + ++found; + if (found == k) { + break; + } + } + } + if (should_stop_filtered_search(total_checked, found, filter_stop)) { + found = 0; + break; + } + } while (found < k && !iterator.done()); + + for (size_t j = found; j < k; ++j) { + result.set( + svs::Neighbor{Unspecify(), Unspecify()}, i, j + ); + } + } } void save(std::ostream& out) const { @@ -174,6 +252,24 @@ class DynamicIVFIndexImpl { return num_threads_; } + void set_intra_query_threads(size_t intra_query_threads) { + if (intra_query_threads == 0) { + throw StatusException{ + ErrorCode::INVALID_ARGUMENT, "intra_query_threads must be at least 1"}; + } + intra_query_threads_ = intra_query_threads; + if (impl_) { + impl_->set_num_intra_query_threads(intra_query_threads); + } + } + + size_t get_intra_query_threads() const { + if (impl_) { + return impl_->get_num_intra_query_threads(); + } + return intra_query_threads_; + } + static DynamicIVFIndexImpl* load( std::istream& in, MetricType metric, diff --git a/bindings/cpp/src/ivf_index.cpp b/bindings/cpp/src/ivf_index.cpp index 36693fa3..18ad2fe9 100644 --- a/bindings/cpp/src/ivf_index.cpp +++ b/bindings/cpp/src/ivf_index.cpp @@ -54,14 +54,15 @@ struct IVFIndexManager : public IVFIndex { size_t k, float* distances, size_t* labels, - const SearchParams* params = nullptr + const SearchParams* params = nullptr, + IDFilter* filter = nullptr ) const noexcept override { return runtime_error_wrapper([&] { auto result = svs::QueryResultView{ svs::MatrixView{svs::make_dims(n, k), labels}, svs::MatrixView{svs::make_dims(n, k), distances}}; auto queries = svs::data::ConstSimpleDataView(x, n, impl_->dimensions()); - impl_->search(result, queries, params); + impl_->search(result, queries, params, filter); }); } @@ -76,6 +77,18 @@ struct IVFIndexManager : public IVFIndex { Status get_num_threads(size_t* num_threads) const noexcept override { return runtime_error_wrapper([&] { *num_threads = impl_->get_num_threads(); }); } + + Status set_intra_query_threads(size_t intra_query_threads) noexcept override { + return runtime_error_wrapper([&] { + impl_->set_intra_query_threads(intra_query_threads); + }); + } + + Status get_intra_query_threads(size_t* intra_query_threads) const noexcept override { + return runtime_error_wrapper([&] { + *intra_query_threads = impl_->get_intra_query_threads(); + }); + } }; } // namespace diff --git a/bindings/cpp/src/ivf_index_impl.h b/bindings/cpp/src/ivf_index_impl.h index ace92329..1cf12b39 100644 --- a/bindings/cpp/src/ivf_index_impl.h +++ b/bindings/cpp/src/ivf_index_impl.h @@ -363,7 +363,8 @@ class IVFIndexImpl { void search( svs::QueryResultView result, svs::data::ConstSimpleDataView queries, - const IVFIndex::SearchParams* params = nullptr + const IVFIndex::SearchParams* params = nullptr, + IDFilter* filter = nullptr ) const { if (!impl_) { auto& dists = result.distances(); @@ -383,7 +384,85 @@ class IVFIndexImpl { } auto sp = make_search_parameters(params); - impl_->search(result, queries, sp); + + // Simple search + if (filter == nullptr) { + impl_->search(result, queries, sp); + return; + } + + // Selective search with IDFilter: use batch iterator to over-fetch and + // filter per-neighbor, mirroring the Vamana approach. + auto old_sp = impl_->get_search_parameters(); + auto sp_restore = svs::lib::make_scope_guard([&]() noexcept { + impl_->set_search_parameters(old_sp); + }); + impl_->set_search_parameters(sp); + + float filter_stop = 0.0f; + bool filter_estimate_batch = true; + if (params) { + set_if_specified(filter_stop, params->filter_stop); + set_if_specified(filter_estimate_batch, params->filter_estimate_batch); + } + const auto max_batch_size = impl_->size(); + + // Pre-search filter sampling: estimate hit rate before cluster traversal. + size_t sampled = 0; + size_t sample_hits = 0; + const auto initial_batch_hint = std::max(k, size_t{1}); + auto initial_batch_size = initial_batch_hint; + if (filter_estimate_batch) { + std::tie(sampled, sample_hits) = sample_filter_hits( + *filter, + max_batch_size, + [](size_t) { return true; }, + sample_size_for_filter_stop(filter_stop) + ); + if (should_stop_filtered_search(sampled, sample_hits, filter_stop)) { + pad_empty_results(result, queries.size(), k); + return; + } + initial_batch_size = predict_further_processing( + sampled, sample_hits, k, initial_batch_hint, max_batch_size + ); + } + + for (size_t i = 0; i < queries.size(); ++i) { + auto query = queries.get_datum(i); + auto iterator = + impl_->batch_iterator(std::span(query.data(), query.size())); + size_t found = 0; + size_t total_checked = 0; + auto batch_size = initial_batch_size; + do { + batch_size = predict_further_processing( + total_checked, found, k, batch_size, max_batch_size + ); + iterator.next(batch_size); + total_checked += iterator.size(); + for (const auto& neighbor : iterator.results()) { + if (filter->is_member(neighbor.id())) { + result.set(neighbor, i, found); + ++found; + if (found == k) { + break; + } + } + } + if (should_stop_filtered_search(total_checked, found, filter_stop)) { + found = 0; + break; + } + } while (found < k && !iterator.done()); + + // Pad results if not enough neighbors found + for (size_t j = found; j < k; ++j) { + result.set( + svs::Neighbor{Unspecify(), Unspecify()}, i, j + ); + } + } } void save(std::ostream& out) const { @@ -411,6 +490,24 @@ class IVFIndexImpl { return num_threads_; } + void set_intra_query_threads(size_t intra_query_threads) { + if (intra_query_threads == 0) { + throw StatusException{ + ErrorCode::INVALID_ARGUMENT, "intra_query_threads must be at least 1"}; + } + intra_query_threads_ = intra_query_threads; + if (impl_) { + impl_->set_num_intra_query_threads(intra_query_threads); + } + } + + size_t get_intra_query_threads() const { + if (impl_) { + return impl_->get_num_intra_query_threads(); + } + return intra_query_threads_; + } + static IVFIndexImpl* load( std::istream& in, MetricType metric, diff --git a/bindings/cpp/tests/ivf_runtime_test.cpp b/bindings/cpp/tests/ivf_runtime_test.cpp index ad6a275d..1259cb50 100644 --- a/bindings/cpp/tests/ivf_runtime_test.cpp +++ b/bindings/cpp/tests/ivf_runtime_test.cpp @@ -922,3 +922,431 @@ CATCH_TEST_CASE("IVFIndexInnerProduct", "[runtime][ivf]") { svs::runtime::v0::IVFIndex::destroy(index); } + +CATCH_TEST_CASE("IVFIndexSetIntraQueryThreads", "[runtime][ivf]") { + std::cout << "[IVF] Running IVFIndexSetIntraQueryThreads..." << std::endl; + const auto& test_data = get_test_data(); + + svs::runtime::v0::IVFIndex* index = nullptr; + svs::runtime::v0::IVFIndex::BuildParams build_params; + build_params.num_centroids = 10; + build_params.num_iterations = 5; + + svs::runtime::v0::Status status = svs::runtime::v0::IVFIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + test_n, + test_data.data(), + build_params, + /*num_threads=*/2, + /*intra_query_threads=*/1 + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + size_t got = 0; + status = index->get_intra_query_threads(&got); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(got == 1); + + status = index->set_intra_query_threads(3); + CATCH_REQUIRE(status.ok()); + + status = index->get_intra_query_threads(&got); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(got == 3); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 10; + std::vector distances(nq * k); + std::vector result_labels(nq * k); + status = index->search(nq, xq, k, distances.data(), result_labels.data()); + CATCH_REQUIRE(status.ok()); + + status = index->set_intra_query_threads(0); + CATCH_REQUIRE(!status.ok()); + + svs::runtime::v0::IVFIndex::destroy(index); +} + +CATCH_TEST_CASE("DynamicIVFIndexSetIntraQueryThreads", "[runtime][ivf]") { + std::cout << "[IVF] Running DynamicIVFIndexSetIntraQueryThreads..." << std::endl; + const auto& test_data = get_test_data(); + + svs::runtime::v0::DynamicIVFIndex* index = nullptr; + svs::runtime::v0::IVFIndex::BuildParams build_params; + build_params.num_centroids = 10; + build_params.num_iterations = 5; + + std::vector labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + + svs::runtime::v0::Status status = svs::runtime::v0::DynamicIVFIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + test_n, + test_data.data(), + labels.data(), + build_params, + /*default_search_params=*/{}, + /*num_threads=*/2, + /*intra_query_threads=*/1 + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + size_t got = 0; + status = index->get_intra_query_threads(&got); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(got == 1); + + status = index->set_intra_query_threads(2); + CATCH_REQUIRE(status.ok()); + + status = index->get_intra_query_threads(&got); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(got == 2); + + svs::runtime::v0::DynamicIVFIndex::destroy(index); +} + +CATCH_TEST_CASE("IVFIndexSearchWithIDFilter", "[runtime][ivf]") { + std::cout << "[IVF] Running IVFIndexSearchWithIDFilter..." << std::endl; + const auto& test_data = get_test_data(); + + svs::runtime::v0::IVFIndex* index = nullptr; + svs::runtime::v0::IVFIndex::BuildParams build_params; + build_params.num_centroids = 10; + build_params.num_iterations = 5; + + svs::runtime::v0::Status status = svs::runtime::v0::IVFIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + test_n, + test_data.data(), + build_params + ); + CATCH_REQUIRE(status.ok()); + + const size_t min_id = 10; + const size_t max_id = 50; + test_utils::IDFilterRange filter(min_id, max_id); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + std::vector distances(nq * k); + std::vector result_labels(nq * k); + + svs::runtime::v0::IVFIndex::SearchParams sp; + sp.n_probes = 10; + status = index->search(nq, xq, k, distances.data(), result_labels.data(), &sp, &filter); + CATCH_REQUIRE(status.ok()); + + for (int q = 0; q < nq; ++q) { + for (int j = 0; j < k; ++j) { + size_t id = result_labels[q * k + j]; + if (id < test_n) { + CATCH_REQUIRE(id >= min_id); + CATCH_REQUIRE(id < max_id); + } + } + } + + svs::runtime::v0::IVFIndex::destroy(index); +} + +CATCH_TEST_CASE("DynamicIVFIndexSearchWithIDFilter", "[runtime][ivf]") { + std::cout << "[IVF] Running DynamicIVFIndexSearchWithIDFilter..." << std::endl; + const auto& test_data = get_test_data(); + + svs::runtime::v0::DynamicIVFIndex* index = nullptr; + svs::runtime::v0::IVFIndex::BuildParams build_params; + build_params.num_centroids = 10; + build_params.num_iterations = 5; + + std::vector labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + + svs::runtime::v0::Status status = svs::runtime::v0::DynamicIVFIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + test_n, + test_data.data(), + labels.data(), + build_params + ); + CATCH_REQUIRE(status.ok()); + + const size_t min_id = 20; + const size_t max_id = 60; + test_utils::IDFilterRange filter(min_id, max_id); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + std::vector distances(nq * k); + std::vector result_labels(nq * k); + + svs::runtime::v0::IVFIndex::SearchParams sp; + sp.n_probes = 10; + status = index->search(nq, xq, k, distances.data(), result_labels.data(), &sp, &filter); + CATCH_REQUIRE(status.ok()); + + for (int q = 0; q < nq; ++q) { + for (int j = 0; j < k; ++j) { + size_t id = result_labels[q * k + j]; + if (id < test_n) { + CATCH_REQUIRE(id >= min_id); + CATCH_REQUIRE(id < max_id); + } + } + } + + svs::runtime::v0::DynamicIVFIndex::destroy(index); +} + +CATCH_TEST_CASE("IVFIndexSearchWithRestrictiveFilter", "[runtime][ivf][filtered_search]") { + std::cout << "[IVF] Running IVFIndexSearchWithRestrictiveFilter..." << std::endl; + const auto& test_data = get_test_data(); + + svs::runtime::v0::IVFIndex* index = nullptr; + svs::runtime::v0::IVFIndex::BuildParams build_params; + build_params.num_centroids = 10; + build_params.num_iterations = 5; + + svs::runtime::v0::Status status = svs::runtime::v0::IVFIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + test_n, + test_data.data(), + build_params + ); + CATCH_REQUIRE(status.ok()); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + + // 10% selectivity: accept only IDs 0-9 out of 100 + const size_t min_id = 0; + const size_t max_id = test_n / 10; + test_utils::IDFilterRange filter(min_id, max_id); + + std::vector distances(nq * k); + std::vector result_labels(nq * k); + + svs::runtime::v0::IVFIndex::SearchParams sp; + sp.n_probes = 10; + status = index->search(nq, xq, k, distances.data(), result_labels.data(), &sp, &filter); + CATCH_REQUIRE(status.ok()); + + for (int i = 0; i < nq * k; ++i) { + if (svs::runtime::v0::is_specified(result_labels[i])) { + CATCH_REQUIRE(result_labels[i] >= min_id); + CATCH_REQUIRE(result_labels[i] < max_id); + } + } + + svs::runtime::v0::IVFIndex::destroy(index); +} + +CATCH_TEST_CASE("IVFIndexFilterStopEarlyExit", "[runtime][ivf][filtered_search]") { + std::cout << "[IVF] Running IVFIndexFilterStopEarlyExit..." << std::endl; + const auto& test_data = get_test_data(); + + svs::runtime::v0::IVFIndex* index = nullptr; + svs::runtime::v0::IVFIndex::BuildParams build_params; + build_params.num_centroids = 10; + build_params.num_iterations = 5; + + svs::runtime::v0::Status status = svs::runtime::v0::IVFIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + test_n, + test_data.data(), + build_params + ); + CATCH_REQUIRE(status.ok()); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + + // 10% selectivity: accept only IDs 0-9 out of 100 + const size_t min_id = 0; + const size_t max_id = test_n / 10; + test_utils::IDFilterRange filter(min_id, max_id); + + std::vector distances(nq * k); + std::vector result_labels(nq * k); + + // Set filter_stop = 0.5 (50%). With ~10% hit rate, search should give up + // and return unspecified results. + svs::runtime::v0::IVFIndex::SearchParams sp; + sp.n_probes = 10; + sp.filter_stop = 0.5f; + + status = index->search(nq, xq, k, distances.data(), result_labels.data(), &sp, &filter); + CATCH_REQUIRE(status.ok()); + + for (int i = 0; i < nq * k; ++i) { + CATCH_REQUIRE(!svs::runtime::v0::is_specified(result_labels[i])); + } + + // Now search without filter_stop — should find valid results in the filter range. + std::vector distances2(nq * k); + std::vector result_labels2(nq * k); + + svs::runtime::v0::IVFIndex::SearchParams sp_no_stop; + sp_no_stop.n_probes = 10; + status = index->search( + nq, xq, k, distances2.data(), result_labels2.data(), &sp_no_stop, &filter + ); + CATCH_REQUIRE(status.ok()); + + for (int i = 0; i < nq * k; ++i) { + if (svs::runtime::v0::is_specified(result_labels2[i])) { + CATCH_REQUIRE(result_labels2[i] >= min_id); + CATCH_REQUIRE(result_labels2[i] < max_id); + } + } + + svs::runtime::v0::IVFIndex::destroy(index); +} + +CATCH_TEST_CASE( + "DynamicIVFIndexSearchWithRestrictiveFilter", "[runtime][ivf][filtered_search]" +) { + std::cout << "[IVF] Running DynamicIVFIndexSearchWithRestrictiveFilter..." << std::endl; + const auto& test_data = get_test_data(); + + svs::runtime::v0::DynamicIVFIndex* index = nullptr; + svs::runtime::v0::IVFIndex::BuildParams build_params; + build_params.num_centroids = 10; + build_params.num_iterations = 5; + + std::vector labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + + svs::runtime::v0::Status status = svs::runtime::v0::DynamicIVFIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + test_n, + test_data.data(), + labels.data(), + build_params + ); + CATCH_REQUIRE(status.ok()); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + + // 10% selectivity: accept only IDs 0-9 out of 100 + const size_t min_id = 0; + const size_t max_id = test_n / 10; + test_utils::IDFilterRange filter(min_id, max_id); + + std::vector distances(nq * k); + std::vector result_labels(nq * k); + + svs::runtime::v0::IVFIndex::SearchParams sp; + sp.n_probes = 10; + status = index->search(nq, xq, k, distances.data(), result_labels.data(), &sp, &filter); + CATCH_REQUIRE(status.ok()); + + for (int i = 0; i < nq * k; ++i) { + if (svs::runtime::v0::is_specified(result_labels[i])) { + CATCH_REQUIRE(result_labels[i] >= min_id); + CATCH_REQUIRE(result_labels[i] < max_id); + } + } + + svs::runtime::v0::DynamicIVFIndex::destroy(index); +} + +CATCH_TEST_CASE("DynamicIVFIndexFilterStopEarlyExit", "[runtime][ivf][filtered_search]") { + std::cout << "[IVF] Running DynamicIVFIndexFilterStopEarlyExit..." << std::endl; + const auto& test_data = get_test_data(); + + svs::runtime::v0::DynamicIVFIndex* index = nullptr; + svs::runtime::v0::IVFIndex::BuildParams build_params; + build_params.num_centroids = 10; + build_params.num_iterations = 5; + + std::vector labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + + svs::runtime::v0::Status status = svs::runtime::v0::DynamicIVFIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + test_n, + test_data.data(), + labels.data(), + build_params + ); + CATCH_REQUIRE(status.ok()); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + + // 10% selectivity: accept only IDs 0-9 out of 100 + const size_t min_id = 0; + const size_t max_id = test_n / 10; + test_utils::IDFilterRange filter(min_id, max_id); + + std::vector distances(nq * k); + std::vector result_labels(nq * k); + + // Set filter_stop = 0.5 (50%). With ~10% hit rate, search should give up + // and return unspecified results. + svs::runtime::v0::IVFIndex::SearchParams sp; + sp.n_probes = 10; + sp.filter_stop = 0.5f; + + status = index->search(nq, xq, k, distances.data(), result_labels.data(), &sp, &filter); + CATCH_REQUIRE(status.ok()); + + for (int i = 0; i < nq * k; ++i) { + CATCH_REQUIRE(!svs::runtime::v0::is_specified(result_labels[i])); + } + + // Now search without filter_stop — should find valid results in the filter range. + std::vector distances2(nq * k); + std::vector result_labels2(nq * k); + + svs::runtime::v0::IVFIndex::SearchParams sp_no_stop; + sp_no_stop.n_probes = 10; + status = index->search( + nq, xq, k, distances2.data(), result_labels2.data(), &sp_no_stop, &filter + ); + CATCH_REQUIRE(status.ok()); + + for (int i = 0; i < nq * k; ++i) { + if (svs::runtime::v0::is_specified(result_labels2[i])) { + CATCH_REQUIRE(result_labels2[i] >= min_id); + CATCH_REQUIRE(result_labels2[i] < max_id); + } + } + + svs::runtime::v0::DynamicIVFIndex::destroy(index); +} diff --git a/include/svs/index/ivf/dynamic_ivf.h b/include/svs/index/ivf/dynamic_ivf.h index e3259e87..b60ee534 100644 --- a/include/svs/index/ivf/dynamic_ivf.h +++ b/include/svs/index/ivf/dynamic_ivf.h @@ -109,7 +109,7 @@ class DynamicIVFIndex { // Threading infrastructure (same as static IVF) InterQueryThreadPool inter_query_threadpool_; - const size_t intra_query_thread_count_; + size_t intra_query_thread_count_; mutable std::vector intra_query_threadpools_; // Search infrastructure (same as static IVF) @@ -278,6 +278,20 @@ class DynamicIVFIndex { // Re-initialize per-thread search buffers for the new thread count matmul_results_.clear(); initialize_search_buffers(); + // Re-initialize intra-query thread pools to match new inter-query pool size + intra_query_threadpools_.clear(); + initialize_thread_pools(); + } + + /// @brief Set the number of threads used for intra-query (cluster-level) + /// parallelism. Re-creates the per-query intra-query thread pools. + void set_num_intra_query_threads(size_t count) { + if (count < 1) { + throw std::invalid_argument("Intra-query thread count must be at least 1"); + } + intra_query_thread_count_ = count; + intra_query_threadpools_.clear(); + initialize_thread_pools(); } /// @brief Get threadpool handle diff --git a/include/svs/index/ivf/index.h b/include/svs/index/ivf/index.h index 00f45735..26959c47 100644 --- a/include/svs/index/ivf/index.h +++ b/include/svs/index/ivf/index.h @@ -200,6 +200,20 @@ class IVFIndex { // Re-initialize per-thread search buffers for the new thread count matmul_results_.clear(); initialize_search_buffers(); + // Re-initialize intra-query thread pools to match new inter-query pool size + intra_query_threadpools_.clear(); + initialize_thread_pools(); + } + + /// @brief Set the number of threads used for intra-query (cluster-level) + /// parallelism. Re-creates the per-query intra-query thread pools. + void set_num_intra_query_threads(size_t count) { + if (count < 1) { + throw std::invalid_argument("Intra-query thread count must be at least 1"); + } + intra_query_thread_count_ = count; + intra_query_threadpools_.clear(); + initialize_thread_pools(); } /// @brief Get the thread pool handle for inter-query parallelism @@ -544,7 +558,7 @@ class IVFIndex { ///// Threading Infrastructure ///// InterQueryThreadPool inter_query_threadpool_; // Handles parallelism across queries - const size_t intra_query_thread_count_; // Number of threads per query processing + size_t intra_query_thread_count_; // Number of threads per query processing mutable std::vector intra_query_threadpools_; // Per-query parallel cluster exploration diff --git a/include/svs/orchestrators/dynamic_ivf.h b/include/svs/orchestrators/dynamic_ivf.h index f3ffe260..aeafdbf3 100644 --- a/include/svs/orchestrators/dynamic_ivf.h +++ b/include/svs/orchestrators/dynamic_ivf.h @@ -190,6 +190,14 @@ class DynamicIVF : public manager::IndexManager { return impl_->experimental_backend_string(); } + // Intra-query (cluster-level) threading + size_t get_num_intra_query_threads() const { + return impl_->get_num_intra_query_threads(); + } + void set_num_intra_query_threads(size_t count) { + impl_->set_num_intra_query_threads(count); + } + // ID Inspection /// diff --git a/include/svs/orchestrators/ivf.h b/include/svs/orchestrators/ivf.h index 87cb1b22..c86d7bbf 100644 --- a/include/svs/orchestrators/ivf.h +++ b/include/svs/orchestrators/ivf.h @@ -29,6 +29,10 @@ class IVFInterface { ///// Backend information interface virtual std::string experimental_backend_string() const = 0; + ///// Intra-query (cluster-level) parallelism + virtual size_t get_num_intra_query_threads() const = 0; + virtual void set_num_intra_query_threads(size_t count) = 0; + ///// Distance calculation virtual double get_distance(size_t id, const AnonymousArray<1>& query) const = 0; @@ -73,6 +77,14 @@ class IVFImpl : public manager::ManagerImpl { return std::string{typename_impl.begin(), typename_impl.end() - 1}; } + ///// Intra-query (cluster-level) parallelism + [[nodiscard]] size_t get_num_intra_query_threads() const override { + return impl().get_num_intra_query_threads(); + } + void set_num_intra_query_threads(size_t count) override { + impl().set_num_intra_query_threads(count); + } + ///// Distance Calculation [[nodiscard]] double get_distance(size_t id, const AnonymousArray<1>& query) const override { @@ -146,6 +158,14 @@ class IVF : public manager::IndexManager { return impl_->experimental_backend_string(); } + ///// Intra-query (cluster-level) threading + size_t get_num_intra_query_threads() const { + return impl_->get_num_intra_query_threads(); + } + void set_num_intra_query_threads(size_t count) { + impl_->set_num_intra_query_threads(count); + } + ///// Distance Calculation template double get_distance(size_t id, const QueryType& query) const {