diff --git a/bindings/cpp/include/svs/runtime/vamana_index.h b/bindings/cpp/include/svs/runtime/vamana_index.h index 98831952..2149c0d3 100644 --- a/bindings/cpp/include/svs/runtime/vamana_index.h +++ b/bindings/cpp/include/svs/runtime/vamana_index.h @@ -40,6 +40,14 @@ struct VamanaSearchParameters { size_t search_buffer_capacity = Unspecify(); size_t prefetch_lookahead = Unspecify(); size_t prefetch_step = Unspecify(); + // Minimum filter hit rate to continue filtered search. + // If the hit rate after the first round falls below this threshold, + // stop and return empty results (caller can fall back to exact search). + // Default unspecified means never give up (treated as 0). + float filter_stop = Unspecify(); + // Enable pre-search filter sampling to estimate hit rate before graph traversal. + // Uses a random sample of IDs to set initial batch size and trigger early exit. + OptionalBool filter_estimate_batch = Unspecify(); }; } // namespace detail diff --git a/bindings/cpp/src/dynamic_vamana_index_impl.h b/bindings/cpp/src/dynamic_vamana_index_impl.h index e80f83c1..516a1653 100644 --- a/bindings/cpp/src/dynamic_vamana_index_impl.h +++ b/bindings/cpp/src/dynamic_vamana_index_impl.h @@ -118,15 +118,50 @@ class DynamicVamanaIndexImpl { // Selective search with IDSelector auto old_sp = impl_->get_search_parameters(); impl_->set_search_parameters(sp); + float filter_stop = 0.0f; + bool filter_estimate_batch = true; + if (params) { + set_if_specified(filter_stop, params->filter_stop); + set_if_specified(filter_estimate_batch, params->filter_estimate_batch); + } + const auto max_batch_size = impl_->size(); + + // Pre-search filter sampling: estimate hit rate before graph traversal. 
+ size_t sampled = 0; + size_t sample_hits = 0; + const auto sws = sp.buffer_config_.get_search_window_size(); + const auto initial_batch_hint = std::max(k, sws); + auto initial_batch_size = initial_batch_hint; + if (filter_estimate_batch) { + std::tie(sampled, sample_hits) = sample_filter_hits( + *filter, + max_batch_size, + [this](size_t id) { return impl_->has_id(id); }, + sample_size_for_filter_stop(filter_stop) + ); + if (should_stop_filtered_search(sampled, sample_hits, filter_stop)) { + pad_empty_results(result, queries.size(), k); + impl_->set_search_parameters(old_sp); + return; + } + initial_batch_size = predict_further_processing( + sampled, sample_hits, k, initial_batch_hint, max_batch_size + ); + } auto search_closure = [&](const auto& range, uint64_t SVS_UNUSED(tid)) { for (auto i : range) { - // For every query auto query = queries.get_datum(i); auto iterator = impl_->batch_iterator(query); size_t found = 0; + size_t total_checked = 0; + auto batch_size = initial_batch_size; do { - iterator.next(k); + batch_size = predict_further_processing( + total_checked, found, k, batch_size, max_batch_size + ); + iterator.next(batch_size); + total_checked += iterator.size(); for (auto& neighbor : iterator.results()) { if (filter->is_member(neighbor.id())) { result.set(neighbor, i, found); @@ -136,6 +171,10 @@ class DynamicVamanaIndexImpl { } } } + if (should_stop_filtered_search(total_checked, found, filter_stop)) { + found = 0; + break; + } } while (found < k && !iterator.done()); // Pad results if not enough neighbors found diff --git a/bindings/cpp/src/svs_runtime_utils.h b/bindings/cpp/src/svs_runtime_utils.h index 9e64d1fd..b081c91b 100644 --- a/bindings/cpp/src/svs_runtime_utils.h +++ b/bindings/cpp/src/svs_runtime_utils.h @@ -55,6 +55,7 @@ inline bool lvq_leanvec_enabled() { return false; } #include #include #include +#include #include #include #include @@ -403,6 +404,95 @@ auto dispatch_storage_kind(StorageKind kind, F&& f, Args&&... 
// NOTE(review): this span begins with the closing context of `dispatch_storage_kind`
// and `namespace storage` (original text: "args) { } } // namespace storage"). That
// definition starts outside this span, so it is only recorded here, not rewritten.

// Predict how many more candidates must be processed to collect `goal` hits, by
// extrapolating the hit rate observed so far (`hits` out of `processed`).
//
// - With no hits yet, or with the goal already reached, there is nothing to
//   extrapolate from and `hint` is returned.
// - Otherwise the estimate is (goal - hits) * processed / hits, rounded down and
//   raised to at least 1.
//
// The result never exceeds `max_value` (e.g., number of vectors in the index).
inline size_t predict_further_processing(
    size_t processed, size_t hits, size_t goal, size_t hint, size_t max_value
) {
    if (hits == 0 || hits >= goal) {
        return std::min(hint, max_value);
    }
    // Double keeps the counts exact up to 2^53; single precision starts rounding
    // at 2^24, which is well within the size of a large index.
    const double estimate = static_cast<double>(goal - hits) *
                            static_cast<double>(processed) / static_cast<double>(hits);
    // Clamp BEFORE converting: casting a floating-point value that does not fit
    // into size_t is undefined behavior.
    if (estimate >= static_cast<double>(max_value)) {
        return max_value;
    }
    return std::max(static_cast<size_t>(estimate), size_t{1});
}

// Decide whether a filtered search should give up based on the observed hit rate.
//
// Returns true when `found / total_checked` is strictly below `filter_stop`, or
// when nothing was found at all; the caller is then expected to abandon the search
// (e.g., to fall back to exact search).
// Returns false when the threshold is disabled (zero, negative, or NaN) or when
// nothing has been checked yet.
inline bool
should_stop_filtered_search(size_t total_checked, size_t found, float filter_stop) {
    // `!(x > 0)` rather than `x <= 0`: the former also treats NaN as "disabled".
    if (!(filter_stop > 0) || total_checked == 0) {
        return false;
    }
    if (found == 0) {
        return true;
    }
    const float hit_rate =
        static_cast<float>(found) / static_cast<float>(total_checked);
    return hit_rate < filter_stop;
}

// Default number of IDs to sample when estimating filter hit rate.
constexpr size_t kFilterSampleSize = 200;

// Estimate the filter hit rate over the ID range [0, total_ids).
//
// is_valid(id) is checked first; invalid IDs are skipped and do not count as
// checked (for dynamic indices where IDs may be deleted).
//
// - If the whole range is no larger than `sample_size`, every ID is visited once,
//   so the returned counts are exact.
// - Otherwise IDs are drawn at random (fixed seed, so repeated calls with the same
//   arguments draw the same IDs) until `sample_size` valid IDs were checked or
//   4 * sample_size draws were made.
//
// Returns (checked, hits) - fed directly to predict_further_processing() and
// should_stop_filtered_search().
template <typename IDFilter, typename IsValid>
inline std::pair<size_t, size_t> sample_filter_hits(
    const IDFilter& filter,
    size_t total_ids,
    IsValid is_valid,
    size_t sample_size = kFilterSampleSize
) {
    if (total_ids == 0 || sample_size == 0) {
        return {0, 0};
    }

    size_t hits = 0;
    size_t checked = 0;
    auto probe = [&](size_t id) {
        if (!is_valid(id)) {
            return;
        }
        ++checked;
        if (filter.is_member(id)) {
            ++hits;
        }
    };

    // Small ID space: an exhaustive scan costs no more than sampling and removes
    // the noise of drawing with replacement.
    if (total_ids <= sample_size) {
        for (size_t id = 0; id < total_ids; ++id) {
            probe(id);
        }
        return {checked, hits};
    }

    constexpr size_t kMaxSize = ~size_t{0};
    // Bound the number of draws so a mostly-invalid ID space cannot spin forever.
    const size_t max_tries = (sample_size > kMaxSize / 4) ? kMaxSize : sample_size * 4;
    std::mt19937 rng(42);
    std::uniform_int_distribution<size_t> dist(0, total_ids - 1);
    for (size_t tries = 0; checked < sample_size && tries < max_tries; ++tries) {
        probe(dist(rng));
    }
    return {checked, hits};
}

// Compute sample size for filter hit rate estimation based on filter_stop.
// Need at least 1/filter_stop samples to reliably distinguish hit rates around
// the threshold (below that, noise dominates - e.g., 0.1% vs 0.2% both look
// like 0 hits at sample_size=200).
// A disabled threshold (zero, negative, or NaN) yields kFilterSampleSize. When
// 1/filter_stop does not fit into size_t the result saturates at the largest
// size_t value.
inline size_t sample_size_for_filter_stop(float filter_stop) {
    if (!(filter_stop > 0)) {
        return kFilterSampleSize;
    }
    constexpr size_t kMaxSize = ~size_t{0};
    const float inverse = 1.0f / filter_stop;
    // Saturate before converting: an out-of-range (or infinite) floating-point
    // value converted to size_t is undefined behavior.
    if (!(inverse < static_cast<float>(kMaxSize))) {
        return kMaxSize;
    }
    return std::max(kFilterSampleSize, static_cast<size_t>(inverse));
}

// Fill all result slots with unspecified values.
// Required when early-exiting before search: the caller-allocated result buffer
// may contain uninitialized data, so we must write valid "no result" markers.
+inline void +pad_empty_results(svs::QueryResultView& result, size_t num_queries, size_t k) { + for (size_t i = 0; i < num_queries; ++i) { + for (size_t j = 0; j < k; ++j) { + result.set(Neighbor{Unspecify(), Unspecify()}, i, j); + } + } +} + inline svs::threads::ThreadPoolHandle default_threadpool() { return svs::threads::ThreadPoolHandle(svs::threads::OMPThreadPool(omp_get_max_threads()) ); diff --git a/bindings/cpp/src/vamana_index_impl.h b/bindings/cpp/src/vamana_index_impl.h index 45023b1d..550257d4 100644 --- a/bindings/cpp/src/vamana_index_impl.h +++ b/bindings/cpp/src/vamana_index_impl.h @@ -124,15 +124,49 @@ class VamanaIndexImpl { get_impl()->set_search_parameters(old_sp); }); get_impl()->set_search_parameters(sp); + float filter_stop = 0.0f; + bool filter_estimate_batch = true; + if (params) { + set_if_specified(filter_stop, params->filter_stop); + set_if_specified(filter_estimate_batch, params->filter_estimate_batch); + } + const auto max_batch_size = get_impl()->size(); + + // Pre-search filter sampling: estimate hit rate before graph traversal. 
+ size_t sampled = 0; + size_t sample_hits = 0; + const auto sws = sp.buffer_config_.get_search_window_size(); + const auto initial_batch_hint = std::max(k, sws); + auto initial_batch_size = initial_batch_hint; + if (filter_estimate_batch) { + std::tie(sampled, sample_hits) = sample_filter_hits( + *filter, + max_batch_size, + [](size_t) { return true; }, + sample_size_for_filter_stop(filter_stop) + ); + if (should_stop_filtered_search(sampled, sample_hits, filter_stop)) { + pad_empty_results(result, queries.size(), k); + return; + } + initial_batch_size = predict_further_processing( + sampled, sample_hits, k, initial_batch_hint, max_batch_size + ); + } auto search_closure = [&](const auto& range, uint64_t SVS_UNUSED(tid)) { for (auto i : range) { - // For every query auto query = queries.get_datum(i); auto iterator = get_impl()->batch_iterator(query); size_t found = 0; + size_t total_checked = 0; + auto batch_size = initial_batch_size; do { - iterator.next(k); + batch_size = predict_further_processing( + total_checked, found, k, batch_size, max_batch_size + ); + iterator.next(batch_size); + total_checked += iterator.size(); for (auto& neighbor : iterator.results()) { if (filter->is_member(neighbor.id())) { result.set(neighbor, i, found); @@ -142,6 +176,10 @@ class VamanaIndexImpl { } } } + if (should_stop_filtered_search(total_checked, found, filter_stop)) { + found = 0; + break; + } } while (found < k && !iterator.done()); // Pad results if not enough neighbors found diff --git a/bindings/cpp/tests/runtime_test.cpp b/bindings/cpp/tests/runtime_test.cpp index 201375d3..abd14296 100644 --- a/bindings/cpp/tests/runtime_test.cpp +++ b/bindings/cpp/tests/runtime_test.cpp @@ -501,6 +501,122 @@ CATCH_TEST_CASE("SearchWithIDFilter", "[runtime]") { svs::runtime::v0::DynamicVamanaIndex::destroy(index); } +CATCH_TEST_CASE("SearchWithRestrictiveFilter", "[runtime][filtered_search]") { + const auto& test_data = get_test_data(); + // Build index + 
svs::runtime::v0::DynamicVamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + svs::runtime::v0::Status status = svs::runtime::v0::DynamicVamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + // Add data + std::vector labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(test_n, labels.data(), test_data.data()); + CATCH_REQUIRE(status.ok()); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + + // 10% selectivity: accept only IDs 0-9 out of 100 + size_t min_id = 0; + size_t max_id = test_n / 10; + test_utils::IDFilterRange filter(min_id, max_id); + + std::vector distances(nq * k); + std::vector result_labels(nq * k); + + status = + index->search(nq, xq, k, distances.data(), result_labels.data(), nullptr, &filter); + CATCH_REQUIRE(status.ok()); + + // All returned labels must fall inside the filter range + for (int i = 0; i < nq * k; ++i) { + if (svs::runtime::v0::is_specified(result_labels[i])) { + CATCH_REQUIRE(result_labels[i] >= min_id); + CATCH_REQUIRE(result_labels[i] < max_id); + } + } + + svs::runtime::v0::DynamicVamanaIndex::destroy(index); +} + +CATCH_TEST_CASE("FilterStopEarlyExit", "[runtime][filtered_search]") { + const auto& test_data = get_test_data(); + // Build index + svs::runtime::v0::DynamicVamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + svs::runtime::v0::Status status = svs::runtime::v0::DynamicVamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + // Add data + std::vector labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(test_n, labels.data(), test_data.data()); + 
CATCH_REQUIRE(status.ok()); + + const int nq = 5; + const float* xq = test_data.data(); + const int k = 5; + + // 10% selectivity: accept only IDs 0-9 out of 100 + size_t min_id = 0; + size_t max_id = test_n / 10; + test_utils::IDFilterRange filter(min_id, max_id); + + std::vector distances(nq * k); + std::vector result_labels(nq * k); + + // Set filter_stop = 0.5 (50%). With ~10% hit rate, search should give up + // and return unspecified results. + svs::runtime::v0::VamanaIndex::SearchParams search_params; + search_params.filter_stop = 0.5f; + + status = index->search( + nq, xq, k, distances.data(), result_labels.data(), &search_params, &filter + ); + CATCH_REQUIRE(status.ok()); + + // All results should be unspecified (early exit returned empty) + for (int i = 0; i < nq * k; ++i) { + CATCH_REQUIRE(!svs::runtime::v0::is_specified(result_labels[i])); + } + + // Now search without filter_stop — should find valid results + std::vector distances2(nq * k); + std::vector result_labels2(nq * k); + + status = index->search( + nq, xq, k, distances2.data(), result_labels2.data(), nullptr, &filter + ); + CATCH_REQUIRE(status.ok()); + + // Should have valid results in the filter range + for (int i = 0; i < nq * k; ++i) { + if (svs::runtime::v0::is_specified(result_labels2[i])) { + CATCH_REQUIRE(result_labels2[i] >= min_id); + CATCH_REQUIRE(result_labels2[i] < max_id); + } + } + + svs::runtime::v0::DynamicVamanaIndex::destroy(index); +} + CATCH_TEST_CASE("RangeSearchFunctional", "[runtime]") { const auto& test_data = get_test_data(); // Build index