Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion bindings/cpp/include/svs/runtime/ivf_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,28 @@ struct SVS_RUNTIME_API IVFIndex {
size_t n_probes = Unspecify<size_t>();
/// Level of reordering/reranking done when using compressed datasets (multiplier)
float k_reorder = Unspecify<float>();
/// Minimum filter hit rate to continue filtered search.
/// If the hit rate after the first round falls below this threshold,
/// stop and return empty results (caller can fall back to exact search).
/// Default unspecified means never give up (treated as 0).
float filter_stop = Unspecify<float>();
/// Enable pre-search filter sampling to estimate hit rate before
/// cluster traversal. Uses a random sample of IDs to set the initial
/// batch size and trigger early exit.
OptionalBool filter_estimate_batch = Unspecify<bool>();
};

/// @brief Perform k-NN search on the index.
/// @param filter Optional ID filter; when non-null, only IDs satisfying
/// ``filter->is_member(id)`` are returned.
virtual Status search(
size_t n,
const float* x,
size_t k,
float* distances,
size_t* labels,
const SearchParams* params = nullptr
const SearchParams* params = nullptr,
IDFilter* filter = nullptr
) const noexcept = 0;

/// @brief Utility function to check storage kind support.
Expand Down Expand Up @@ -108,6 +120,13 @@ struct SVS_RUNTIME_API IVFIndex {
/// @brief Get the number of threads used for index operations.
virtual Status get_num_threads(size_t* num_threads) const noexcept = 0;

/// @brief Set the number of intra-query (cluster-level) threads.
/// Recreates the per-query intra-query thread pools.
virtual Status set_intra_query_threads(size_t intra_query_threads) noexcept = 0;

/// @brief Get the current number of intra-query (cluster-level) threads.
virtual Status get_intra_query_threads(size_t* intra_query_threads) const noexcept = 0;

/// @brief Load an IVF index from a stream.
static Status load(
IVFIndex** index,
Expand Down
17 changes: 15 additions & 2 deletions bindings/cpp/src/dynamic_ivf_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ struct DynamicIVFIndexManager : public DynamicIVFIndex {
size_t k,
float* distances,
size_t* labels,
const SearchParams* params = nullptr
const SearchParams* params = nullptr,
IDFilter* filter = nullptr
) const noexcept override {
return runtime_error_wrapper([&] {
auto result = svs::QueryResultView<size_t>{
svs::MatrixView<size_t>{svs::make_dims(n, k), labels},
svs::MatrixView<float>{svs::make_dims(n, k), distances}};
auto queries = svs::data::ConstSimpleDataView<float>(x, n, impl_->dimensions());
impl_->search(result, queries, params);
impl_->search(result, queries, params, filter);
});
}

Expand Down Expand Up @@ -112,6 +113,18 @@ struct DynamicIVFIndexManager : public DynamicIVFIndex {
Status get_num_threads(size_t* num_threads) const noexcept override {
return runtime_error_wrapper([&] { *num_threads = impl_->get_num_threads(); });
}

Status set_intra_query_threads(size_t intra_query_threads) noexcept override {
return runtime_error_wrapper([&] {
impl_->set_intra_query_threads(intra_query_threads);
});
}

Status get_intra_query_threads(size_t* intra_query_threads) const noexcept override {
return runtime_error_wrapper([&] {
*intra_query_threads = impl_->get_intra_query_threads();
});
}
};

} // namespace
Expand Down
100 changes: 98 additions & 2 deletions bindings/cpp/src/dynamic_ivf_index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,32 +122,110 @@
impl_->compact(batchsize);
}

void search(
svs::QueryResultView<size_t> result,
svs::data::ConstSimpleDataView<float> queries,
const IVFIndex::SearchParams* params = nullptr
const IVFIndex::SearchParams* params = nullptr,
IDFilter* filter = nullptr
) const {
if (!impl_) {
auto& dists = result.distances();
std::fill(dists.begin(), dists.end(), Unspecify<float>());
auto& inds = result.indices();
std::fill(inds.begin(), inds.end(), Unspecify<size_t>());
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}

if (queries.size() == 0) {
return;
}

const size_t k = result.n_neighbors();
if (k == 0) {
throw StatusException{ErrorCode::INVALID_ARGUMENT, "k must be greater than 0"};
}

auto sp = make_search_parameters(params);
impl_->search(result, queries, sp);

// Simple search
if (filter == nullptr) {
impl_->search(result, queries, sp);
return;
}

// Selective search with IDFilter: use batch iterator to over-fetch and
// filter per-neighbor, mirroring the Vamana approach.
auto old_sp = impl_->get_search_parameters();
auto sp_restore = svs::lib::make_scope_guard([&]() noexcept {
impl_->set_search_parameters(old_sp);
});
impl_->set_search_parameters(sp);

float filter_stop = 0.0f;
bool filter_estimate_batch = true;
if (params) {
set_if_specified(filter_stop, params->filter_stop);
set_if_specified(filter_estimate_batch, params->filter_estimate_batch);
}
const auto max_batch_size = impl_->size();

// Pre-search filter sampling: estimate hit rate before cluster traversal.
size_t sampled = 0;
size_t sample_hits = 0;
const auto initial_batch_hint = std::max(k, size_t{1});
auto initial_batch_size = initial_batch_hint;
if (filter_estimate_batch) {
std::tie(sampled, sample_hits) = sample_filter_hits(
*filter,
max_batch_size,
[this](size_t id) { return impl_->has_id(id); },
sample_size_for_filter_stop(filter_stop)
);
if (should_stop_filtered_search(sampled, sample_hits, filter_stop)) {
pad_empty_results(result, queries.size(), k);
return;
}
initial_batch_size = predict_further_processing(
sampled, sample_hits, k, initial_batch_hint, max_batch_size
);
}

for (size_t i = 0; i < queries.size(); ++i) {
auto query = queries.get_datum(i);
auto iterator =
impl_->batch_iterator(std::span<const float>(query.data(), query.size()));
size_t found = 0;
size_t total_checked = 0;
auto batch_size = initial_batch_size;
do {
batch_size = predict_further_processing(
total_checked, found, k, batch_size, max_batch_size
);
iterator.next(batch_size);
total_checked += iterator.size();
for (const auto& neighbor : iterator.results()) {
if (filter->is_member(neighbor.id())) {
result.set(neighbor, i, found);
++found;
if (found == k) {
break;
}
}
}
if (should_stop_filtered_search(total_checked, found, filter_stop)) {
found = 0;
break;
}
} while (found < k && !iterator.done());

for (size_t j = found; j < k; ++j) {
result.set(
svs::Neighbor<size_t>{Unspecify<size_t>(), Unspecify<float>()}, i, j
);
}
}
}

Check notice on line 228 in bindings/cpp/src/dynamic_ivf_index_impl.h

View check run for this annotation

codefactor.io / CodeFactor

bindings/cpp/src/dynamic_ivf_index_impl.h#L125-L228

Complex Method
void save(std::ostream& out) const {
if (!impl_) {
throw StatusException{
Expand All @@ -174,6 +252,24 @@
return num_threads_;
}

void set_intra_query_threads(size_t intra_query_threads) {
if (intra_query_threads == 0) {
throw StatusException{
ErrorCode::INVALID_ARGUMENT, "intra_query_threads must be at least 1"};
}
intra_query_threads_ = intra_query_threads;
if (impl_) {
impl_->set_num_intra_query_threads(intra_query_threads);
}
}

size_t get_intra_query_threads() const {
if (impl_) {
return impl_->get_num_intra_query_threads();
}
return intra_query_threads_;
}

static DynamicIVFIndexImpl* load(
std::istream& in,
MetricType metric,
Expand Down
17 changes: 15 additions & 2 deletions bindings/cpp/src/ivf_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ struct IVFIndexManager : public IVFIndex {
size_t k,
float* distances,
size_t* labels,
const SearchParams* params = nullptr
const SearchParams* params = nullptr,
IDFilter* filter = nullptr
) const noexcept override {
return runtime_error_wrapper([&] {
auto result = svs::QueryResultView<size_t>{
svs::MatrixView<size_t>{svs::make_dims(n, k), labels},
svs::MatrixView<float>{svs::make_dims(n, k), distances}};
auto queries = svs::data::ConstSimpleDataView<float>(x, n, impl_->dimensions());
impl_->search(result, queries, params);
impl_->search(result, queries, params, filter);
});
}

Expand All @@ -76,6 +77,18 @@ struct IVFIndexManager : public IVFIndex {
Status get_num_threads(size_t* num_threads) const noexcept override {
return runtime_error_wrapper([&] { *num_threads = impl_->get_num_threads(); });
}

Status set_intra_query_threads(size_t intra_query_threads) noexcept override {
return runtime_error_wrapper([&] {
impl_->set_intra_query_threads(intra_query_threads);
});
}

Status get_intra_query_threads(size_t* intra_query_threads) const noexcept override {
return runtime_error_wrapper([&] {
*intra_query_threads = impl_->get_intra_query_threads();
});
}
};

} // namespace
Expand Down
101 changes: 99 additions & 2 deletions bindings/cpp/src/ivf_index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,32 +360,111 @@
init_impl(data);
}

void search(
svs::QueryResultView<size_t> result,
svs::data::ConstSimpleDataView<float> queries,
const IVFIndex::SearchParams* params = nullptr
const IVFIndex::SearchParams* params = nullptr,
IDFilter* filter = nullptr
) const {
if (!impl_) {
auto& dists = result.distances();
std::fill(dists.begin(), dists.end(), Unspecify<float>());
auto& inds = result.indices();
std::fill(inds.begin(), inds.end(), Unspecify<size_t>());
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}

if (queries.size() == 0) {
return;
}

const size_t k = result.n_neighbors();
if (k == 0) {
throw StatusException{ErrorCode::INVALID_ARGUMENT, "k must be greater than 0"};
}

auto sp = make_search_parameters(params);
impl_->search(result, queries, sp);

// Simple search
if (filter == nullptr) {
impl_->search(result, queries, sp);
return;
}

// Selective search with IDFilter: use batch iterator to over-fetch and
// filter per-neighbor, mirroring the Vamana approach.
auto old_sp = impl_->get_search_parameters();
auto sp_restore = svs::lib::make_scope_guard([&]() noexcept {
impl_->set_search_parameters(old_sp);
});
impl_->set_search_parameters(sp);

float filter_stop = 0.0f;
bool filter_estimate_batch = true;
if (params) {
set_if_specified(filter_stop, params->filter_stop);
set_if_specified(filter_estimate_batch, params->filter_estimate_batch);
}
const auto max_batch_size = impl_->size();

// Pre-search filter sampling: estimate hit rate before cluster traversal.
size_t sampled = 0;
size_t sample_hits = 0;
const auto initial_batch_hint = std::max(k, size_t{1});
auto initial_batch_size = initial_batch_hint;
if (filter_estimate_batch) {
std::tie(sampled, sample_hits) = sample_filter_hits(
*filter,
max_batch_size,
[](size_t) { return true; },
sample_size_for_filter_stop(filter_stop)
);
if (should_stop_filtered_search(sampled, sample_hits, filter_stop)) {
pad_empty_results(result, queries.size(), k);
return;
}
initial_batch_size = predict_further_processing(
sampled, sample_hits, k, initial_batch_hint, max_batch_size
);
}

for (size_t i = 0; i < queries.size(); ++i) {
auto query = queries.get_datum(i);
auto iterator =
impl_->batch_iterator(std::span<const float>(query.data(), query.size()));
size_t found = 0;
size_t total_checked = 0;
auto batch_size = initial_batch_size;
do {
batch_size = predict_further_processing(
total_checked, found, k, batch_size, max_batch_size
);
iterator.next(batch_size);
total_checked += iterator.size();
for (const auto& neighbor : iterator.results()) {
if (filter->is_member(neighbor.id())) {
result.set(neighbor, i, found);
++found;
if (found == k) {
break;
}
}
}
if (should_stop_filtered_search(total_checked, found, filter_stop)) {
found = 0;
break;
}
} while (found < k && !iterator.done());

// Pad results if not enough neighbors found
for (size_t j = found; j < k; ++j) {
result.set(
svs::Neighbor<size_t>{Unspecify<size_t>(), Unspecify<float>()}, i, j
);
}
}
}

Check notice on line 467 in bindings/cpp/src/ivf_index_impl.h

View check run for this annotation

codefactor.io / CodeFactor

bindings/cpp/src/ivf_index_impl.h#L363-L467

Complex Method
void save(std::ostream& out) const {
if (!impl_) {
throw StatusException{
Expand All @@ -411,6 +490,24 @@
return num_threads_;
}

void set_intra_query_threads(size_t intra_query_threads) {
if (intra_query_threads == 0) {
throw StatusException{
ErrorCode::INVALID_ARGUMENT, "intra_query_threads must be at least 1"};
}
intra_query_threads_ = intra_query_threads;
if (impl_) {
impl_->set_num_intra_query_threads(intra_query_threads);
}
}

size_t get_intra_query_threads() const {
if (impl_) {
return impl_->get_num_intra_query_threads();
}
return intra_query_threads_;
}

static IVFIndexImpl* load(
std::istream& in,
MetricType metric,
Expand Down
Loading
Loading