diff --git a/bindings/c/CMakeLists.txt b/bindings/c/CMakeLists.txt index c9991168..2bddac78 100644 --- a/bindings/c/CMakeLists.txt +++ b/bindings/c/CMakeLists.txt @@ -192,8 +192,9 @@ install(FILES ) # Build tests if requested -# if(SVS_BUILD_C_API_TESTS) -# add_subdirectory(tests) -# endif() +option(SVS_BUILD_C_API_TESTS "Build C API tests" ON) +if(SVS_BUILD_C_API_TESTS) + add_subdirectory(tests) +endif() add_subdirectory(samples) diff --git a/bindings/c/include/svs/c_api/svs_c.h b/bindings/c/include/svs/c_api/svs_c.h index 65cf7fc4..bff8195b 100644 --- a/bindings/c/include/svs/c_api/svs_c.h +++ b/bindings/c/include/svs/c_api/svs_c.h @@ -32,6 +32,7 @@ enum svs_error_code { SVS_ERROR_NOT_IMPLEMENTED = 5, SVS_ERROR_UNSUPPORTED_HW = 6, SVS_ERROR_RUNTIME = 7, + SVS_ERROR_INVALID_OPERATION = 8, SVS_ERROR_UNKNOWN = 1000 }; @@ -501,6 +502,32 @@ SVS_API bool svs_index_dynamic_compact( svs_index_h index, size_t batchsize /*=0*/, svs_error_h out_err /*=NULL*/ ); +/// @brief Get number of threads used for search in the index's thread pool +/// @param index The index handle +/// @param out_num_threads Pointer to store the retrieved number of threads +/// @param out_err An optional error handle to capture errors +/// @return true on success, false on failure +SVS_API bool svs_index_get_num_threads( + svs_index_h index, size_t* out_num_threads, svs_error_h out_err /*=NULL*/ +); + +/// @brief Set number of threads for search in the index's thread pool +/// @param index The index handle +/// @param num_threads The number of threads to set +/// @param out_err An optional error handle to capture errors +/// @return true on success, false on failure +/// @remarks This function is only supported for indices built with threadpool kinds +/// SVS_THREADPOOL_KIND_NATIVE or SVS_THREADPOOL_KIND_OMP. Attempting to call this +/// function on indices built with SVS_THREADPOOL_KIND_CUSTOM or +/// SVS_THREADPOOL_KIND_SINGLE_THREAD will fail and return false. +/// @error On failure, if out_err is provided, it will contain: +/// - SVS_ERROR_INVALID_OPERATION if the index was built with an unsupported threadpool kind +/// - SVS_ERROR_INVALID_ARGUMENT if num_threads is invalid or zero +/// - SVS_ERROR_RUNTIME for other runtime failures +SVS_API bool svs_index_set_num_threads( + svs_index_h index, size_t num_threads, svs_error_h out_err /*=NULL*/ +); + #ifdef __cplusplus } #endif diff --git a/bindings/c/src/error.hpp b/bindings/c/src/error.hpp index 48f40f80..1c183dd3 100644 --- a/bindings/c/src/error.hpp +++ b/bindings/c/src/error.hpp @@ -87,6 +87,11 @@ class not_implemented : public std::logic_error { using std::logic_error::logic_error; }; +class invalid_operation : public std::logic_error { + public: + using std::logic_error::logic_error; +}; + class unsupported_hw : public std::runtime_error { public: using std::runtime_error::runtime_error; @@ -104,6 +109,9 @@ Result wrap_exceptions(Callable&& func, svs_error_h err, Result err_res = {}) no } catch (const svs::c_runtime::not_implemented& ex) { SET_ERROR(err, SVS_ERROR_NOT_IMPLEMENTED, ex.what()); return err_res; + } catch (const svs::c_runtime::invalid_operation& ex) { + SET_ERROR(err, SVS_ERROR_INVALID_OPERATION, ex.what()); + return err_res; } catch (const svs::c_runtime::unsupported_hw& ex) { SET_ERROR(err, SVS_ERROR_UNSUPPORTED_HW, ex.what()); return err_res; diff --git a/bindings/c/src/index.hpp b/bindings/c/src/index.hpp index 42547c1d..86df78c9 100644 --- a/bindings/c/src/index.hpp +++ b/bindings/c/src/index.hpp @@ -18,6 +18,7 @@ #include "svs/c_api/svs_c.h" #include "algorithm.hpp" +#include "threadpool.hpp" #include #include @@ -32,8 +33,10 @@ namespace svs::c_runtime { struct Index { svs_algorithm_type algorithm; - Index(svs_algorithm_type algorithm) - : algorithm(algorithm) {} + ThreadPoolBuilder pool_builder; + Index(svs_algorithm_type algorithm, ThreadPoolBuilder pool_builder) + : algorithm(algorithm) + , pool_builder(pool_builder) {} virtual ~Index() = default; virtual svs::QueryResult search( svs::data::ConstSimpleDataView queries, @@ -45,11 +48,13 @@ struct Index { virtual float get_distance(size_t id, std::span query) const = 0; virtual void reconstruct_at(svs::data::SimpleDataView dst, std::span ids) = 0; + virtual size_t get_num_threads() { return pool_builder.get_threads_num(); }; + virtual void set_num_threads(size_t num_threads) = 0; }; struct DynamicIndex : public Index { - DynamicIndex(svs_algorithm_type algorithm) - : Index(algorithm) {} + DynamicIndex(svs_algorithm_type algorithm, ThreadPoolBuilder pool_builder) + : Index(algorithm, pool_builder) {} ~DynamicIndex() = default; virtual size_t add_points( @@ -63,8 +68,8 @@ struct DynamicIndex : public Index { struct IndexVamana : public Index { svs::Vamana index; - IndexVamana(svs::Vamana&& index) - : Index{SVS_ALGORITHM_TYPE_VAMANA} + IndexVamana(svs::Vamana&& index, ThreadPoolBuilder pool_builder) + : Index{SVS_ALGORITHM_TYPE_VAMANA, pool_builder} , index(std::move(index)) {} ~IndexVamana() = default; svs::QueryResult search( @@ -99,12 +104,17 @@ struct IndexVamana : public Index { override { index.reconstruct_at(dst, ids); } + + void set_num_threads(size_t num_threads) override { + pool_builder.resize(num_threads); + index.set_threadpool(pool_builder.build()); + } }; struct DynamicIndexVamana : public DynamicIndex { svs::DynamicVamana index; - DynamicIndexVamana(svs::DynamicVamana&& index) - : DynamicIndex(SVS_ALGORITHM_TYPE_VAMANA) + DynamicIndexVamana(svs::DynamicVamana&& index, ThreadPoolBuilder pool_builder) + : DynamicIndex(SVS_ALGORITHM_TYPE_VAMANA, pool_builder) , index(std::move(index)) {} ~DynamicIndexVamana() = default; @@ -170,5 +180,10 @@ struct DynamicIndexVamana : public DynamicIndex { index.compact(batchsize); } } + + void set_num_threads(size_t num_threads) override { + pool_builder.resize(num_threads); + index.set_threadpool(pool_builder.build()); + } }; } // namespace svs::c_runtime diff --git a/bindings/c/src/index_builder.hpp b/bindings/c/src/index_builder.hpp index 288d9efa..4b221397 100644 --- a/bindings/c/src/index_builder.hpp +++ b/bindings/c/src/index_builder.hpp @@ -69,13 +69,16 @@ struct IndexBuilder { if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) { auto vamana_algorithm = std::static_pointer_cast(algorithm); - auto index = std::make_shared(dispatch_vamana_index_build( - vamana_algorithm->build_parameters(), - data, - storage.get(), - to_distance_type(distance_metric), - pool_builder.build() - )); + auto index = std::make_shared( + dispatch_vamana_index_build( + vamana_algorithm->build_parameters(), + data, + storage.get(), + to_distance_type(distance_metric), + pool_builder.build() + ), + pool_builder + ); return index; } @@ -86,13 +89,16 @@ struct IndexBuilder { if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) { auto vamana_algorithm = std::static_pointer_cast(algorithm); - auto index = std::make_shared(dispatch_vamana_index_load( - vamana_algorithm->build_parameters(), - directory, - storage.get(), - to_distance_type(distance_metric), - pool_builder.build() - )); + auto index = std::make_shared( + dispatch_vamana_index_load( + vamana_algorithm->build_parameters(), + directory, + storage.get(), + to_distance_type(distance_metric), + pool_builder.build() + ), + pool_builder + ); return index; } @@ -107,8 +113,8 @@ struct IndexBuilder { if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) { auto vamana_algorithm = std::static_pointer_cast(algorithm); - auto index = - std::make_shared(dispatch_dynamic_vamana_index_build( + auto index = std::make_shared( + dispatch_dynamic_vamana_index_build( vamana_algorithm->build_parameters(), data, ids, @@ -116,7 +122,9 @@ struct IndexBuilder { to_distance_type(distance_metric), pool_builder.build(), blocksize_bytes - )); + ), + pool_builder + ); return index; } @@ -128,15 +136,17 @@ struct IndexBuilder { if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) { auto vamana_algorithm = std::static_pointer_cast(algorithm); - auto index = - std::make_shared(dispatch_dynamic_vamana_index_load( + auto index = std::make_shared( + dispatch_dynamic_vamana_index_load( vamana_algorithm->build_parameters(), directory, storage.get(), to_distance_type(distance_metric), pool_builder.build(), blocksize_bytes - )); + ), + pool_builder + ); return index; } diff --git a/bindings/c/src/svs_c.cpp b/bindings/c/src/svs_c.cpp index 4f48a802..d67523dc 100644 --- a/bindings/c/src/svs_c.cpp +++ b/bindings/c/src/svs_c.cpp @@ -661,20 +661,32 @@ extern "C" size_t svs_index_dynamic_delete_points( svs_index_h index, const size_t* ids, size_t num_ids, svs_error_h out_err ) { using namespace svs::c_runtime; - return wrap_exceptions( + std::shared_ptr dynamic_index_ptr; + auto result = wrap_exceptions( [&]() { EXPECT_ARG_NOT_NULL(index); EXPECT_ARG_NOT_NULL(ids); EXPECT_ARG_GT_THAN(num_ids, 0); - auto dynamic_index_ptr = std::dynamic_pointer_cast(index->impl); + dynamic_index_ptr = std::dynamic_pointer_cast(index->impl); INVALID_ARGUMENT_IF( dynamic_index_ptr == nullptr, "Index does not support dynamic updates" ); - return dynamic_index_ptr->delete_points(std::span(ids, num_ids)); + return 0; // return 0 for success, actual deletion happens in the next + // wrap_exceptions call }, out_err, static_cast(-1) ); + if (result != 0) { + return result; + } + // Call delete_points in a separate wrap_exceptions to return 0 if no entries are + // deleted. + return wrap_exceptions( + [&]() { return dynamic_index_ptr->delete_points(std::span(ids, num_ids)); }, + out_err, + 0 + ); } extern "C" bool svs_index_dynamic_has_id( @@ -787,3 +799,37 @@ svs_index_dynamic_compact(svs_index_h index, size_t batchsize, svs_error_h out_e false ); } + +extern "C" bool +svs_index_get_num_threads(svs_index_h index, size_t* out_num_threads, svs_error_h out_err) { + using namespace svs::c_runtime; + return wrap_exceptions( + [&]() { + EXPECT_ARG_NOT_NULL(index); + EXPECT_ARG_NOT_NULL(out_num_threads); + auto& index_ptr = index->impl; + INVALID_ARGUMENT_IF(index_ptr == nullptr, "Invalid index handle"); + *out_num_threads = index_ptr->get_num_threads(); + return true; + }, + out_err, + false + ); +} + +extern "C" bool +svs_index_set_num_threads(svs_index_h index, size_t num_threads, svs_error_h out_err) { + using namespace svs::c_runtime; + return wrap_exceptions( + [&]() { + EXPECT_ARG_NOT_NULL(index); + EXPECT_ARG_GT_THAN(num_threads, 0); + auto& index_ptr = index->impl; + INVALID_ARGUMENT_IF(index_ptr == nullptr, "Invalid index handle"); + index_ptr->set_num_threads(num_threads); + return true; + }, + out_err, + false + ); +} diff --git a/bindings/c/src/threadpool.hpp b/bindings/c/src/threadpool.hpp index d2bef66b..9c529ec2 100644 --- a/bindings/c/src/threadpool.hpp +++ b/bindings/c/src/threadpool.hpp @@ -17,6 +17,7 @@ #include "svs/c_api/svs_c.h" +#include "error.hpp" #include "types_support.hpp" #include @@ -74,7 +75,8 @@ class ThreadPoolBuilder { ThreadPoolBuilder(svs_threadpool_kind kind, size_t num_threads) : kind(kind) - , num_threads(num_threads) { + , num_threads(kind == SVS_THREADPOOL_KIND_SINGLE_THREAD ? 1 : num_threads) + , user_threadpool(nullptr) { if (kind == SVS_THREADPOOL_KIND_CUSTOM) { throw std::invalid_argument( "SVS_THREADPOOL_KIND_CUSTOM cannot be built automatically." @@ -91,6 +93,31 @@ class ThreadPoolBuilder { return std::max(size_t{1}, size_t{std::thread::hardware_concurrency()}); } + svs_threadpool_kind get_kind() const { return kind; } + svs_threadpool_i get_user_threadpool() const { return user_threadpool; } + + size_t get_threads_num() const { + if (kind == SVS_THREADPOOL_KIND_CUSTOM) { + return user_threadpool->ops.size(user_threadpool->self); + } + return num_threads; + } + + void resize(size_t new_num_threads) { + if (new_num_threads == 0) { + throw std::invalid_argument("Number of threads must be greater than zero."); + } + if (kind == SVS_THREADPOOL_KIND_SINGLE_THREAD) { + throw svs::c_runtime::invalid_operation( + "Cannot resize a single-threaded threadpool." + ); + } + if (kind == SVS_THREADPOOL_KIND_CUSTOM) { + throw svs::c_runtime::invalid_operation("Cannot resize a custom threadpool."); + } + num_threads = new_num_threads; + } + svs::threads::ThreadPoolHandle build() const { using namespace svs::threads; switch (kind) { diff --git a/bindings/c/tests/CMakeLists.txt b/bindings/c/tests/CMakeLists.txt new file mode 100644 index 00000000..8d6c1309 --- /dev/null +++ b/bindings/c/tests/CMakeLists.txt @@ -0,0 +1,105 @@ +# Copyright 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set(TARGET_NAME svs_c_api_test) + +# Check if Catch2 is available +find_package(Catch2 3 QUIET) + +if(NOT Catch2_FOUND) + message(STATUS "Catch2 not found, fetching from GitHub...") + include(FetchContent) + + # Do wide printing for the console logger for Catch2 + set(CATCH_CONFIG_CONSOLE_WIDTH "100" CACHE STRING "" FORCE) + set(CATCH_BUILD_TESTING OFF CACHE BOOL "" FORCE) + set(CATCH_CONFIG_ENABLE_BENCHMARKING OFF CACHE BOOL "" FORCE) + set(CATCH_CONFIG_FAST_COMPILE OFF CACHE BOOL "" FORCE) + set(CATCH_CONFIG_PREFIX_ALL ON CACHE BOOL "" FORCE) + + set(PRESET_CMAKE_CXX_STANDARD ${CMAKE_CXX_STANDARD}) + set(CMAKE_CXX_STANDARD 20) + FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2.git + GIT_TAG v3.11.0 + ) + FetchContent_MakeAvailable(Catch2) + set(CMAKE_CXX_STANDARD ${PRESET_CMAKE_CXX_STANDARD}) +endif() + +# Define test sources +set(C_API_TEST_SOURCES + c_api_error.cpp + c_api_algorithm.cpp + c_api_storage.cpp + c_api_search_params.cpp + c_api_index_builder.cpp + c_api_index.cpp + c_api_dynamic_index.cpp +) + +# Create test executable +add_executable(${TARGET_NAME} ${C_API_TEST_SOURCES}) + +# Link with C API library and Catch2 +target_link_libraries(${TARGET_NAME} PRIVATE + svs_c_api + Catch2::Catch2WithMain +) + +# Set C++ standard +target_compile_features(${TARGET_NAME} PRIVATE cxx_std_20) +set_target_properties(${TARGET_NAME} PROPERTIES + CXX_STANDARD 20 + CXX_STANDARD_REQUIRED ON + CXX_EXTENSIONS OFF +) + +# Include directories +target_include_directories(${TARGET_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../include +) + +# Add test to CTest +include(CTest) +enable_testing() + +# Add the test to CTest +add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME}) + +# Set test properties +set_tests_properties(${TARGET_NAME} PROPERTIES + LABELS "c_api" +) + +# Add Catch2 CMake module path +if(NOT Catch2_FOUND) + # Catch2 was fetched, use its source directory + list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) +else() + # Catch2 was found via find_package, use its module directory + list(APPEND CMAKE_MODULE_PATH ${Catch2_DIR}) +endif() + +include(Catch) +catch_discover_tests(${TARGET_NAME}) + +# Add a custom target to run tests +add_custom_target(run_c_api_tests + COMMAND ${TARGET_NAME} + DEPENDS ${TARGET_NAME} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + COMMENT "Running C API tests..." +) diff --git a/bindings/c/tests/README.md b/bindings/c/tests/README.md new file mode 100644 index 00000000..0a1cac58 --- /dev/null +++ b/bindings/c/tests/README.md @@ -0,0 +1,164 @@ +# C API Tests + +This directory contains comprehensive tests for the SVS C API using the Catch2 testing framework. + +## Test Structure + +The tests are organized into separate files by functionality: + +- **c_api_error.cpp**: Tests for error handling functionality +- **c_api_algorithm.cpp**: Tests for algorithm creation and configuration (Vamana) +- **c_api_storage.cpp**: Tests for storage configurations (Simple, LeanVec, LVQ, SQ) +- **c_api_search_params.cpp**: Tests for search parameter creation and configuration +- **c_api_index_builder.cpp**: Tests for index builder creation and configuration +- **c_api_index.cpp**: Tests for index building, searching, and basic operations +- **c_api_dynamic_index.cpp**: Tests for dynamic index operations (add, delete, consolidate, compact) + +Note: The main() function is provided by Catch2::Catch2WithMain automatically. + +## Building the Tests + +The tests are built as part of the C API build process. To build them: + +```bash +# From the build directory +cmake -DSVS_BUILD_C_API_TESTS=ON .. +make svs_c_api_tests +``` + +To disable building tests: + +```bash +cmake -DSVS_BUILD_C_API_TESTS=OFF .. +``` + +## Running the Tests + +### Run all tests + +```bash +./svs_c_api_tests +``` + +### Run specific test cases + +```bash +# Run error handling tests only +./svs_c_api_tests "[c_api][error]" + +# Run algorithm tests only +./svs_c_api_tests "[c_api][algorithm]" + +# Run all index tests +./svs_c_api_tests "[c_api][index]" + +# Run dynamic index tests +./svs_c_api_tests "[c_api][dynamic]" +``` + +### Run with verbose output + +```bash +./svs_c_api_tests -s +``` + +### List all available tests + +```bash +./svs_c_api_tests --list-tests +``` + +### Run with CTest + +```bash +ctest -R svs_c_api_tests +``` + +## Test Coverage + +The tests cover the following aspects of the C API: + +### Error Handling + +- Error handle creation and cleanup +- Error state checking +- Error codes and messages +- NULL error handle support + +### Algorithm Configuration + +- Vamana algorithm creation +- Parameter getters and setters (graph_degree, build_window_size, alpha, search_history) +- Invalid parameter handling + +### Storage Configuration + +- Simple storage (Float32, Float16, Int8, Uint8) +- LeanVec storage (various primary/secondary combinations) +- LVQ storage (with and without residual) +- Scalar Quantization storage + +### Search Parameters + +- Vamana search parameter creation +- Various window sizes + +### Index Builder + +- Index builder creation with different metrics (Euclidean, Cosine, Dot Product) +- Storage configuration +- Thread pool configuration (Native, OMP, Custom) + +### Index Operations + +- Index building from data +- Searching with queries +- Different K values +- Distance calculation +- Vector reconstruction +- Thread count management + +### Dynamic Index Operations + +- Dynamic index building with/without explicit IDs +- Adding points +- Deleting points +- ID existence checking +- Index consolidation +- Index compaction +- Search after modifications + +## Test Patterns + +The tests follow the patterns established in the SVS project: + +1. Use `CATCH_TEST_CASE` for test case definitions +2. Use `CATCH_SECTION` for test subsections +3. Use `CATCH_REQUIRE` for assertions +4. Clean up all resources (free handles) after each test +5. Test both success and error paths +6. Test with and without NULL error handles + +## Adding New Tests + +When adding new tests: + +1. Create a new `.cpp` file or add to an existing one +2. Follow the existing structure and naming conventions +3. Include proper copyright header +4. Use appropriate test tags: `[c_api][functionality]` +5. Add the new test file to `CMakeLists.txt` if needed +6. Clean up all allocated resources +7. Test both success and error conditions + +## Dependencies + +- Catch2 v3.x (automatically fetched if not found) +- SVS C API library +- C++17 or later compiler + +## Notes + +- Tests use a simple sequential thread pool for deterministic behavior +- Test data is generated programmatically for repeatability +- Some tests may be skipped if optional features are not enabled (e.g., LVQ/LeanVec) diff --git a/bindings/c/tests/c_api_algorithm.cpp b/bindings/c/tests/c_api_algorithm.cpp new file mode 100644 index 00000000..e72f044a --- /dev/null +++ b/bindings/c/tests/c_api_algorithm.cpp @@ -0,0 +1,199 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// C API +#include "svs/c_api/svs_c.h" + +// catch2 +#include "catch2/catch_test_macros.hpp" + +// Standard library +#include + +CATCH_TEST_CASE("C API Vamana Algorithm", "[c_api][algorithm][vamana]") { + CATCH_SECTION("Vamana Algorithm Creation") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Vamana Algorithm Get Graph Degree") { + svs_error_h error = svs_error_create(); + size_t expected_degree = 64; + + svs_algorithm_h algorithm = + svs_algorithm_create_vamana(expected_degree, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + size_t actual_degree = 0; + bool success = + svs_algorithm_vamana_get_graph_degree(algorithm, &actual_degree, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(svs_error_ok(error) == true); + CATCH_REQUIRE(actual_degree == expected_degree); + + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Vamana Algorithm Set Graph Degree") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + size_t new_degree = 96; + bool success = svs_algorithm_vamana_set_graph_degree(algorithm, new_degree, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(svs_error_ok(error) == true); + + size_t actual_degree = 0; + success = svs_algorithm_vamana_get_graph_degree(algorithm, &actual_degree, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(actual_degree == new_degree); + + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Vamana Algorithm Get Build Window Size") { + svs_error_h error = svs_error_create(); + size_t expected_window = 128; + + svs_algorithm_h algorithm = + svs_algorithm_create_vamana(64, expected_window, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + size_t actual_window = 0; + bool success = + svs_algorithm_vamana_get_build_window_size(algorithm, &actual_window, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(svs_error_ok(error) == true); + CATCH_REQUIRE(actual_window == expected_window); + + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Vamana Algorithm Set Build Window Size") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + size_t new_window = 256; + bool success = + svs_algorithm_vamana_set_build_window_size(algorithm, new_window, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(svs_error_ok(error) == true); + + size_t actual_window = 0; + success = + svs_algorithm_vamana_get_build_window_size(algorithm, &actual_window, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(actual_window == new_window); + + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Vamana Algorithm Get/Set Alpha") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + // Get default alpha + float alpha = 0.0f; + bool success = svs_algorithm_vamana_get_alpha(algorithm, &alpha, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(svs_error_ok(error) == true); + CATCH_REQUIRE(alpha > 0.0f); + + // Set new alpha + float new_alpha = 1.5f; + success = svs_algorithm_vamana_set_alpha(algorithm, new_alpha, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(svs_error_ok(error) == true); + + // Verify the change + float actual_alpha = 0.0f; + success = svs_algorithm_vamana_get_alpha(algorithm, &actual_alpha, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(std::abs(actual_alpha - new_alpha) < 1e-6f); + + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Vamana Algorithm Get/Set Search History") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + // Get default search history setting + bool use_history = false; + bool success = + svs_algorithm_vamana_get_use_search_history(algorithm, &use_history, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(svs_error_ok(error) == true); + + // Set search history + bool new_value = !use_history; + success = svs_algorithm_vamana_set_use_search_history(algorithm, new_value, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(svs_error_ok(error) == true); + + // Verify the change + bool actual_value = false; + success = + svs_algorithm_vamana_get_use_search_history(algorithm, &actual_value, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(actual_value == new_value); + + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Vamana Algorithm with NULL Error") { + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, nullptr); + CATCH_REQUIRE(algorithm != nullptr); + + size_t degree = 0; + bool success = svs_algorithm_vamana_get_graph_degree(algorithm, °ree, nullptr); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(degree == 64); + + svs_algorithm_free(algorithm); + } + + CATCH_SECTION("Vamana Algorithm Invalid Parameters") { + svs_error_h error = svs_error_create(); + + // Try to create with invalid parameters + svs_algorithm_h algorithm = svs_algorithm_create_vamana(0, 0, 0, error); + CATCH_REQUIRE(algorithm == nullptr); + CATCH_REQUIRE(svs_error_ok(error) == false); + + svs_error_free(error); + } +} diff --git a/bindings/c/tests/c_api_dynamic_index.cpp b/bindings/c/tests/c_api_dynamic_index.cpp new file mode 100644 index 00000000..e629bbf2 --- /dev/null +++ b/bindings/c/tests/c_api_dynamic_index.cpp @@ -0,0 +1,419 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// C API +#include "svs/c_api/svs_c.h" + +// catch2 +#include "catch2/catch_test_macros.hpp" + +// Standard library +#include +#include + +namespace { + +// Helper function to generate test data +void generate_test_data(std::vector& data, size_t num_vectors, size_t dimension) { + data.resize(num_vectors * dimension); + for (size_t i = 0; i < data.size(); ++i) { + data[i] = static_cast((i * 13) % 100) / 100.0f; + } +} + +// Sequential threadpool for testing +size_t sequential_tp_size(void* /*self*/) { return 1; } + +void sequential_tp_parallel_for( + void* /*self*/, void (*func)(void*, size_t), void* svs_param, size_t n +) { + for (size_t i = 0; i < n; ++i) { + func(svs_param, i); + } +} + +} // namespace + +CATCH_TEST_CASE("C API Dynamic Index", "[c_api][index][dynamic]") { + const size_t NUM_VECTORS = 50; + const size_t DIMENSION = 32; + const size_t K = 5; + + std::vector data; + std::vector ids(NUM_VECTORS); + generate_test_data(data, NUM_VECTORS, DIMENSION); + + // Generate sequential IDs + for (size_t i = 0; i < NUM_VECTORS; ++i) { + ids[i] = i; + } + + CATCH_SECTION("Dynamic Index Build with IDs") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + CATCH_REQUIRE(algorithm != nullptr); + + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + CATCH_REQUIRE(builder != nullptr); + + // Build dynamic index with explicit IDs + svs_index_h index = svs_index_build_dynamic( + builder, data.data(), ids.data(), NUM_VECTORS, 0, error + ); + CATCH_REQUIRE(index != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Dynamic Index Build without IDs") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + CATCH_REQUIRE(algorithm != nullptr); + + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + CATCH_REQUIRE(builder != nullptr); + + // Build dynamic index without explicit IDs (auto-generated) + svs_index_h index = + svs_index_build_dynamic(builder, data.data(), nullptr, NUM_VECTORS, 0, error); + CATCH_REQUIRE(index != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Dynamic Index Has ID") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build_dynamic( + builder, data.data(), ids.data(), NUM_VECTORS, 0, error + ); + CATCH_REQUIRE(index != nullptr); + + // Check for existing IDs + for (size_t i = 0; i < 5; ++i) { + bool has_id = false; + bool success = svs_index_dynamic_has_id(index, ids[i], &has_id, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + CATCH_REQUIRE(has_id == true); + } + + // Check for non-existing ID + bool has_id = false; + bool success = svs_index_dynamic_has_id(index, NUM_VECTORS + 100, &has_id, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + CATCH_REQUIRE(has_id == false); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Dynamic Index Add Points") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build_dynamic( + builder, data.data(), ids.data(), NUM_VECTORS, 0, error + ); + CATCH_REQUIRE(index != nullptr); + + // Add new points + size_t num_new_points = 5; + std::vector new_data; + std::vector new_ids(num_new_points); + generate_test_data(new_data, num_new_points, DIMENSION); + + for (size_t i = 0; i < num_new_points; ++i) { + new_ids[i] = NUM_VECTORS + i; + } + + size_t added_count = svs_index_dynamic_add_points( + index, new_data.data(), new_ids.data(), num_new_points, error + ); + CATCH_REQUIRE(added_count == num_new_points); + CATCH_REQUIRE(svs_error_ok(error)); + + // Verify new IDs exist + for (size_t i = 0; i < num_new_points; ++i) { + bool has_id = false; + bool success = svs_index_dynamic_has_id(index, new_ids[i], &has_id, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(has_id == true); + } + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Dynamic Index Delete Points") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build_dynamic( + builder, data.data(), ids.data(), NUM_VECTORS, 0, error + ); + CATCH_REQUIRE(index != nullptr); + + // Delete some points + size_t ids_to_delete[] = {0, 5, 10}; + size_t num_to_delete = 3; + + size_t deleted_count = + svs_index_dynamic_delete_points(index, ids_to_delete, num_to_delete, error); + CATCH_REQUIRE(deleted_count == num_to_delete); + CATCH_REQUIRE(svs_error_ok(error)); + + // Verify deleted IDs don't exist + for (size_t i = 0; i < num_to_delete; ++i) { + bool has_id = false; + bool success = + svs_index_dynamic_has_id(index, ids_to_delete[i], &has_id, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(has_id == false); + } + + // Verify other IDs still exist + bool has_id = false; + bool success = svs_index_dynamic_has_id(index, 1, &has_id, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(has_id == true); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Dynamic Index Add and Delete") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build_dynamic( + builder, data.data(), ids.data(), NUM_VECTORS, 0, error + ); + CATCH_REQUIRE(index != nullptr); + + // Delete some points + size_t ids_to_delete[] = {0, 1}; + svs_index_dynamic_delete_points(index, ids_to_delete, 2, error); + CATCH_REQUIRE(svs_error_ok(error)); + + // Add new points with the deleted IDs + std::vector new_data; + generate_test_data(new_data, 2, DIMENSION); + + size_t added_count = + svs_index_dynamic_add_points(index, new_data.data(), ids_to_delete, 2, error); + CATCH_REQUIRE(added_count == 2); + CATCH_REQUIRE(svs_error_ok(error)); + + // Verify IDs exist again + for (size_t i = 0; i < 2; ++i) { + bool has_id = false; + bool success = + svs_index_dynamic_has_id(index, ids_to_delete[i], &has_id, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(has_id == true); + } + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Dynamic Index Consolidate") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build_dynamic( + builder, data.data(), ids.data(), NUM_VECTORS, 0, error + ); + CATCH_REQUIRE(index != nullptr); + + // Add and delete some points + std::vector new_data; + std::vector new_ids = {NUM_VECTORS, NUM_VECTORS + 1}; + generate_test_data(new_data, 2, DIMENSION); + + svs_index_dynamic_add_points(index, new_data.data(), new_ids.data(), 2, error); + + size_t ids_to_delete[] = {0, 1}; + svs_index_dynamic_delete_points(index, ids_to_delete, 2, error); + + // Consolidate the index + bool success = svs_index_dynamic_consolidate(index, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Dynamic Index Compact") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build_dynamic( + builder, data.data(), ids.data(), NUM_VECTORS, 0, error + ); + CATCH_REQUIRE(index != nullptr); + + // Compact the index + bool success = svs_index_dynamic_compact(index, 0, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + // Delete some points + size_t ids_to_delete[] = {0, 1, 2}; + svs_index_dynamic_delete_points(index, ids_to_delete, 3, error); + CATCH_REQUIRE(svs_error_ok(error)); + + // Consolidate the index + success = svs_index_dynamic_consolidate(index, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + // Compact the index + success = svs_index_dynamic_compact(index, 0, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Dynamic Index Search After Modifications") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build_dynamic( + builder, data.data(), ids.data(), NUM_VECTORS, 0, error + ); + CATCH_REQUIRE(index != nullptr); + + // Add some points + std::vector new_data; + std::vector new_ids = {NUM_VECTORS, NUM_VECTORS + 1, NUM_VECTORS + 2}; + generate_test_data(new_data, 3, DIMENSION); + svs_index_dynamic_add_points(index, new_data.data(), new_ids.data(), 3, error); + + // Delete some points + size_t ids_to_delete[] = {0, 1}; + svs_index_dynamic_delete_points(index, ids_to_delete, 2, error); + + // Perform search + std::vector queries; + generate_test_data(queries, 2, DIMENSION); + + svs_search_results_t results = + svs_index_search(index, queries.data(), 2, K, nullptr, error); + CATCH_REQUIRE(results != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + CATCH_REQUIRE(results->num_queries == 2); + + // Verify deleted IDs don't appear in results + for (size_t i = 0; i < results->num_queries * K; ++i) { + size_t result_id = results->indices[i]; + CATCH_REQUIRE(result_id != 0); + CATCH_REQUIRE(result_id != 1); + } + + svs_search_results_free(results); + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Dynamic Index Delete Non-existing ID") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build_dynamic( + builder, data.data(), ids.data(), NUM_VECTORS, 0, error + ); + CATCH_REQUIRE(index != nullptr); + + // Try to delete non-existing ID + size_t non_existing_id = NUM_VECTORS + 1000; + size_t deleted_count = + svs_index_dynamic_delete_points(index, &non_existing_id, 1, error); + // Should return 0 for non-existing ID + CATCH_REQUIRE(deleted_count == 0); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } +} diff --git a/bindings/c/tests/c_api_error.cpp b/bindings/c/tests/c_api_error.cpp new file mode 100644 index 00000000..e62c8888 --- /dev/null +++ b/bindings/c/tests/c_api_error.cpp @@ -0,0 +1,79 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// C API +#include "svs/c_api/svs_c.h" + +// catch2 +#include "catch2/catch_test_macros.hpp" + +CATCH_TEST_CASE("C API Error Handling", "[c_api][error]") { + CATCH_SECTION("Error Creation and Cleanup") { + svs_error_h error = svs_error_create(); + CATCH_REQUIRE(error != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + CATCH_REQUIRE(svs_error_get_code(error) == SVS_OK); + CATCH_REQUIRE(svs_error_get_message(error) != nullptr); + svs_error_free(error); + } + + CATCH_SECTION("Error State After API Call") { + svs_error_h error = svs_error_create(); + CATCH_REQUIRE(error != nullptr); + + // Create a valid algorithm - should not set error + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + CATCH_REQUIRE(svs_error_get_code(error) == SVS_OK); + + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Error State After Invalid API Call") { + svs_error_h error = svs_error_create(); + CATCH_REQUIRE(error != nullptr); + + // Try to create algorithm with invalid parameters (e.g., 0 graph degree) + svs_algorithm_h algorithm = svs_algorithm_create_vamana(0, 0, 0, error); + CATCH_REQUIRE(algorithm == nullptr); + CATCH_REQUIRE(svs_error_ok(error) == false); + CATCH_REQUIRE(svs_error_get_code(error) != SVS_OK); + CATCH_REQUIRE(svs_error_get_message(error) != nullptr); + + svs_error_free(error); + } + + CATCH_SECTION("Multiple Error Handles") { + svs_error_h error1 = svs_error_create(); + svs_error_h error2 = svs_error_create(); + + CATCH_REQUIRE(error1 != nullptr); + CATCH_REQUIRE(error2 != nullptr); + CATCH_REQUIRE(error1 != error2); + + svs_error_free(error1); + svs_error_free(error2); + } + + CATCH_SECTION("NULL Error Handle") { + // API calls should work with NULL error handle + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, nullptr); + CATCH_REQUIRE(algorithm != nullptr); + svs_algorithm_free(algorithm); + } +} diff --git a/bindings/c/tests/c_api_index.cpp b/bindings/c/tests/c_api_index.cpp new file mode 100644 index 00000000..eed471d8 --- /dev/null +++ b/bindings/c/tests/c_api_index.cpp @@ -0,0 +1,583 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// C API +#include "svs/c_api/svs_c.h" + +// catch2 +#include "catch2/catch_test_macros.hpp" + +// Standard library +#include +#include +#include + +namespace { + +// Helper function to generate test data +void generate_test_data(std::vector& data, size_t num_vectors, size_t dimension) { + data.resize(num_vectors * dimension); + for (size_t i = 0; i < data.size(); ++i) { + data[i] = static_cast((i * 7) % 100) / 100.0f; + } +} + +// Helper to calculate Euclidean distance +float euclidean_distance(const float* a, const float* b, size_t dim) { + float sum = 0.0f; + for (size_t i = 0; i < dim; ++i) { + float diff = a[i] - b[i]; + sum += diff * diff; + } + return std::sqrt(sum); +} + +// Sequential threadpool for testing +size_t sequential_tp_size(void* /*self*/) { return 1; } + +void sequential_tp_parallel_for( + void* /*self*/, void (*func)(void*, size_t), void* svs_param, size_t n +) { + for (size_t i = 0; i < n; ++i) { + func(svs_param, i); + } +} + +} // namespace + +CATCH_TEST_CASE("C API Index Build and Search", "[c_api][index][build][search]") { + const size_t NUM_VECTORS = 100; + const size_t NUM_QUERIES = 5; + const size_t DIMENSION = 32; + const size_t K = 10; + + std::vector data; + std::vector queries; + generate_test_data(data, NUM_VECTORS, DIMENSION); + generate_test_data(queries, NUM_QUERIES, DIMENSION); + + CATCH_SECTION("Basic Index Build and Search") { + svs_error_h error = svs_error_create(); + + // Create algorithm + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + CATCH_REQUIRE(algorithm != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + // Create builder + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + CATCH_REQUIRE(builder != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + // Build index with default threadpool + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + // Create search parameters + svs_search_params_h search_params = svs_search_params_create_vamana(50, error); + CATCH_REQUIRE(search_params != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + // Perform search + svs_search_results_t results = + svs_index_search(index, queries.data(), NUM_QUERIES, K, search_params, error); + CATCH_REQUIRE(results != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + // Validate results structure + CATCH_REQUIRE(results->num_queries == NUM_QUERIES); + CATCH_REQUIRE(results->results_per_query != nullptr); + CATCH_REQUIRE(results->indices != nullptr); + CATCH_REQUIRE(results->distances != nullptr); + + // Check that each query returned K results + for (size_t i = 0; i < NUM_QUERIES; ++i) { + CATCH_REQUIRE(results->results_per_query[i] == K); + } + + // Check that indices are within valid range + for (size_t i = 0; i < NUM_QUERIES * K; ++i) { + CATCH_REQUIRE(results->indices[i] < NUM_VECTORS); + } + + // Check that distances are non-negative + for (size_t i = 0; i < NUM_QUERIES * K; ++i) { + CATCH_REQUIRE(results->distances[i] >= 0.0f); + } + + // Cleanup + svs_search_results_free(results); + svs_search_params_free(search_params); + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index Search without Search Parameters") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + CATCH_REQUIRE(algorithm != nullptr); + + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + CATCH_REQUIRE(builder != nullptr); + + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + + // Search without explicit search parameters (uses defaults) + svs_search_results_t results = + svs_index_search(index, queries.data(), NUM_QUERIES, K, nullptr, error); + CATCH_REQUIRE(results != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + CATCH_REQUIRE(results->num_queries == NUM_QUERIES); + + svs_search_results_free(results); + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index with Different Storage Types") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + CATCH_REQUIRE(algorithm != nullptr); + + // Test with Float16 storage + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + CATCH_REQUIRE(builder != nullptr); + + svs_storage_h storage = svs_storage_create_simple(SVS_DATA_TYPE_FLOAT16, error); + CATCH_REQUIRE(storage != nullptr); + + bool success = svs_index_builder_set_storage(builder, storage, error); + CATCH_REQUIRE(success); + + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + + svs_search_results_t results = + svs_index_search(index, queries.data(), NUM_QUERIES, K, nullptr, error); + CATCH_REQUIRE(results != nullptr); + CATCH_REQUIRE(results->num_queries == NUM_QUERIES); + + svs_search_results_free(results); + svs_index_free(index); + svs_storage_free(storage); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index with Custom Threadpool") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + // Set custom threadpool + struct svs_threadpool_interface custom_pool = { + {sequential_tp_size, sequential_tp_parallel_for}, nullptr}; + bool success = + svs_index_builder_set_threadpool_custom(builder, &custom_pool, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + // Verify index works with custom threadpool + svs_search_results_t results = + svs_index_search(index, queries.data(), NUM_QUERIES, K, nullptr, error); + CATCH_REQUIRE(results != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + svs_search_results_free(results); + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index Get Distance") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + + // Get distance from first vector to first query + float distance = -1.0f; + bool success = svs_index_get_distance(index, 0, queries.data(), &distance, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + CATCH_REQUIRE(distance >= 0.0f); + + // Verify distance is approximately correct + float expected_distance = + euclidean_distance(data.data(), queries.data(), DIMENSION); + CATCH_REQUIRE(std::abs(distance - expected_distance) < 0.1f); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index Reconstruct") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + + // Reconstruct first 3 vectors + size_t ids[] = {0, 5, 10}; + size_t num_ids = 3; + std::vector reconstructed(num_ids * DIMENSION); + + bool success = svs_index_reconstruct( + index, ids, num_ids, reconstructed.data(), DIMENSION, error + ); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + // Verify reconstructed data is close to original + for (size_t i = 0; i < num_ids; ++i) { + size_t id = ids[i]; + const float* original = &data[id * DIMENSION]; + const float* recon = &reconstructed[i * DIMENSION]; + + float distance = euclidean_distance(original, recon, DIMENSION); + CATCH_REQUIRE(distance < 1.0f); // Allow some reconstruction error + } + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index Search with Different K Values") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + + // Test with different K values + size_t k_values[] = {1, 5, 10, 20}; + for (size_t i = 0; i < sizeof(k_values) / sizeof(k_values[0]); ++i) { + size_t k = k_values[i]; + svs_search_results_t results = + svs_index_search(index, queries.data(), NUM_QUERIES, k, nullptr, error); + CATCH_REQUIRE(results != nullptr); + CATCH_REQUIRE(results->num_queries == NUM_QUERIES); + + for (size_t q = 0; q < NUM_QUERIES; ++q) { + CATCH_REQUIRE(results->results_per_query[q] == k); + } + + svs_search_results_free(results); + } + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Multiple Searches on Same Index") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + + // Perform multiple searches + for (size_t i = 0; i < 3; ++i) { + svs_search_results_t results = + svs_index_search(index, queries.data(), NUM_QUERIES, K, nullptr, error); + CATCH_REQUIRE(results != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + CATCH_REQUIRE(results->num_queries == NUM_QUERIES); + svs_search_results_free(results); + } + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } +} + +CATCH_TEST_CASE("C API Threadpool Management", "[c_api][index][threadpool]") { + const size_t NUM_VECTORS = 100; + const size_t DIMENSION = 32; + + std::vector data; + generate_test_data(data, NUM_VECTORS, DIMENSION); + + CATCH_SECTION("Native Threadpool Get/Set Num Threads") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + // Set native threadpool + bool success = + svs_index_builder_set_threadpool(builder, SVS_THREADPOOL_KIND_NATIVE, 2, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + // Get current number of threads + size_t num_threads = 0; + success = svs_index_get_num_threads(index, &num_threads, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + CATCH_REQUIRE(num_threads == 2); + + // Set to different number of threads + success = svs_index_set_num_threads(index, 4, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + // Verify the change + success = svs_index_get_num_threads(index, &num_threads, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(num_threads == 4); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("OMP Threadpool Get/Set Num Threads") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + // Set OMP threadpool + bool success = + svs_index_builder_set_threadpool(builder, SVS_THREADPOOL_KIND_OMP, 3, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + // Get current number of threads + size_t num_threads = 0; + success = svs_index_get_num_threads(index, &num_threads, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + CATCH_REQUIRE(num_threads == 3); + + // Set to different number of threads + success = svs_index_set_num_threads(index, 5, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + // Verify the change + success = svs_index_get_num_threads(index, &num_threads, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(num_threads == 5); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Custom Threadpool Get/Set Num Threads") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + // Set custom threadpool + struct svs_threadpool_interface custom_pool = { + {sequential_tp_size, sequential_tp_parallel_for}, nullptr}; + bool success = + svs_index_builder_set_threadpool_custom(builder, &custom_pool, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + // Get number of threads from custom threadpool + size_t num_threads = 0; + success = svs_index_get_num_threads(index, &num_threads, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + CATCH_REQUIRE(num_threads == 1); // Sequential threadpool reports size 1 + + // Setting num_threads on custom threadpool should fail with + // SVS_ERROR_INVALID_OPERATION + success = svs_index_set_num_threads(index, 2, error); + CATCH_REQUIRE_FALSE(success); + CATCH_REQUIRE_FALSE(svs_error_ok(error)); + CATCH_REQUIRE(svs_error_get_code(error) == SVS_ERROR_INVALID_OPERATION); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Single Thread Threadpool") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + // Set single thread threadpool + bool success = svs_index_builder_set_threadpool( + builder, SVS_THREADPOOL_KIND_SINGLE_THREAD, 1, error + ); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + // Get number of threads + size_t num_threads = 0; + success = svs_index_get_num_threads(index, &num_threads, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + CATCH_REQUIRE(num_threads == 1); + + // Try to set number of threads (should fail with SVS_ERROR_INVALID_OPERATION since + // it's single thread) + success = svs_index_set_num_threads(index, 2, error); + CATCH_REQUIRE_FALSE(success); + CATCH_REQUIRE_FALSE(svs_error_ok(error)); + CATCH_REQUIRE(svs_error_get_code(error) == SVS_ERROR_INVALID_OPERATION); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Default Threadpool") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + // Don't set any threadpool - use default + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + CATCH_REQUIRE(svs_error_ok(error)); + + // Get number of threads from default threadpool + size_t num_threads = 0; + bool success = svs_index_get_num_threads(index, &num_threads, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + CATCH_REQUIRE(num_threads > 0); // Should have at least 1 thread + + // Try to set number of threads + success = svs_index_set_num_threads(index, 2, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(svs_error_ok(error)); + + // Verify the change + success = svs_index_get_num_threads(index, &num_threads, error); + CATCH_REQUIRE(success); + CATCH_REQUIRE(num_threads == 2); + CATCH_REQUIRE(svs_error_ok(error)); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Invalid Set Num Threads") { + svs_error_h error = svs_error_create(); + + svs_algorithm_h algorithm = svs_algorithm_create_vamana(16, 32, 50, error); + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, DIMENSION, algorithm, error + ); + + svs_index_h index = svs_index_build(builder, data.data(), NUM_VECTORS, error); + CATCH_REQUIRE(index != nullptr); + + // Try to set to 0 threads (invalid) - should fail with SVS_ERROR_INVALID_ARGUMENT + bool success = svs_index_set_num_threads(index, 0, error); + CATCH_REQUIRE(success == false); + CATCH_REQUIRE(svs_error_ok(error) == false); + CATCH_REQUIRE(svs_error_get_code(error) == SVS_ERROR_INVALID_ARGUMENT); + + svs_index_free(index); + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } +} diff --git a/bindings/c/tests/c_api_index_builder.cpp b/bindings/c/tests/c_api_index_builder.cpp new file mode 100644 index 00000000..5aaa8d13 --- /dev/null +++ b/bindings/c/tests/c_api_index_builder.cpp @@ -0,0 +1,226 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// C API +#include "svs/c_api/svs_c.h" + +// catch2 +#include "catch2/catch_test_macros.hpp" + +// Standard library +#include + +namespace { + +// Helper function to generate random test data +void generate_test_data(std::vector& data, size_t num_vectors, size_t dimension) { + data.resize(num_vectors * dimension); + for (size_t i = 0; i < data.size(); ++i) { + data[i] = static_cast(i % 100) / 100.0f; + } +} + +// Sequential threadpool implementation for testing +size_t sequential_tp_size(void* /*self*/) { return 1; } + +void sequential_tp_parallel_for( + void* /*self*/, void (*func)(void*, size_t), void* svs_param, size_t n +) { + for (size_t i = 0; i < n; ++i) { + func(svs_param, i); + } +} + +} // namespace + +CATCH_TEST_CASE("C API Index Builder", "[c_api][index_builder]") { + CATCH_SECTION("Index Builder Creation") { + svs_error_h error = svs_error_create(); + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + svs_index_builder_h builder = + svs_index_builder_create(SVS_DISTANCE_METRIC_EUCLIDEAN, 128, algorithm, error); + CATCH_REQUIRE(builder != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index Builder with Different Metrics") { + svs_error_h error = svs_error_create(); + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + // Euclidean + svs_index_builder_h builder1 = + svs_index_builder_create(SVS_DISTANCE_METRIC_EUCLIDEAN, 128, algorithm, error); + CATCH_REQUIRE(builder1 != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + // Cosine + svs_index_builder_h builder2 = + svs_index_builder_create(SVS_DISTANCE_METRIC_COSINE, 128, algorithm, error); + CATCH_REQUIRE(builder2 != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + // Dot Product + svs_index_builder_h builder3 = svs_index_builder_create( + SVS_DISTANCE_METRIC_DOT_PRODUCT, 128, algorithm, error + ); + CATCH_REQUIRE(builder3 != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_index_builder_free(builder1); + svs_index_builder_free(builder2); + svs_index_builder_free(builder3); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index Builder Set Storage") { + svs_error_h error = svs_error_create(); + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + svs_index_builder_h builder = + svs_index_builder_create(SVS_DISTANCE_METRIC_EUCLIDEAN, 128, algorithm, error); + CATCH_REQUIRE(builder != nullptr); + + svs_storage_h storage = svs_storage_create_simple(SVS_DATA_TYPE_FLOAT32, error); + CATCH_REQUIRE(storage != nullptr); + + bool success = svs_index_builder_set_storage(builder, storage, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_index_builder_free(builder); + svs_storage_free(storage); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index Builder Set Threadpool Native") { + svs_error_h error = svs_error_create(); + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + svs_index_builder_h builder = + svs_index_builder_create(SVS_DISTANCE_METRIC_EUCLIDEAN, 128, algorithm, error); + CATCH_REQUIRE(builder != nullptr); + + bool success = + svs_index_builder_set_threadpool(builder, SVS_THREADPOOL_KIND_NATIVE, 2, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index Builder Set Threadpool OMP") { + svs_error_h error = svs_error_create(); + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + svs_index_builder_h builder = + svs_index_builder_create(SVS_DISTANCE_METRIC_EUCLIDEAN, 128, algorithm, error); + CATCH_REQUIRE(builder != nullptr); + + bool success = + svs_index_builder_set_threadpool(builder, SVS_THREADPOOL_KIND_OMP, 2, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index Builder Set Custom Threadpool") { + svs_error_h error = svs_error_create(); + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + svs_index_builder_h builder = + svs_index_builder_create(SVS_DISTANCE_METRIC_EUCLIDEAN, 128, algorithm, error); + CATCH_REQUIRE(builder != nullptr); + + struct svs_threadpool_interface custom_pool = { + {sequential_tp_size, sequential_tp_parallel_for}, nullptr}; + + bool success = + svs_index_builder_set_threadpool_custom(builder, &custom_pool, error); + CATCH_REQUIRE(success == true); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index Builder with NULL Error") { + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, nullptr); + CATCH_REQUIRE(algorithm != nullptr); + + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, 128, algorithm, nullptr + ); + CATCH_REQUIRE(builder != nullptr); + + svs_index_builder_free(builder); + svs_algorithm_free(algorithm); + } + + CATCH_SECTION("Index Builder with Various Dimensions") { + svs_error_h error = svs_error_create(); + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + size_t dimensions[] = {32, 64, 128, 256, 384, 512, 768, 1024}; + for (size_t i = 0; i < sizeof(dimensions) / sizeof(dimensions[0]); ++i) { + svs_index_builder_h builder = svs_index_builder_create( + SVS_DISTANCE_METRIC_EUCLIDEAN, dimensions[i], algorithm, error + ); + CATCH_REQUIRE(builder != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + svs_index_builder_free(builder); + } + + svs_algorithm_free(algorithm); + svs_error_free(error); + } + + CATCH_SECTION("Index Builder Invalid Parameters") { + svs_error_h error = svs_error_create(); + svs_algorithm_h algorithm = svs_algorithm_create_vamana(64, 128, 100, error); + CATCH_REQUIRE(algorithm != nullptr); + + // Try to create with 0 dimension + svs_index_builder_h builder = + svs_index_builder_create(SVS_DISTANCE_METRIC_EUCLIDEAN, 0, algorithm, error); + // Behavior depends on implementation + if (builder != nullptr) { + svs_index_builder_free(builder); + } + + svs_algorithm_free(algorithm); + svs_error_free(error); + } +} diff --git a/bindings/c/tests/c_api_search_params.cpp b/bindings/c/tests/c_api_search_params.cpp new file mode 100644 index 00000000..0ea7130c --- /dev/null +++ b/bindings/c/tests/c_api_search_params.cpp @@ -0,0 +1,88 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// C API +#include "svs/c_api/svs_c.h" + +// catch2 +#include "catch2/catch_test_macros.hpp" + +CATCH_TEST_CASE("C API Search Parameters", "[c_api][search_params]") { + CATCH_SECTION("Vamana Search Parameters Creation") { + svs_error_h error = svs_error_create(); + + size_t search_window_size = 100; + svs_search_params_h params = + svs_search_params_create_vamana(search_window_size, error); + CATCH_REQUIRE(params != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_search_params_free(params); + svs_error_free(error); + } + + CATCH_SECTION("Vamana Search Parameters Various Sizes") { + svs_error_h error = svs_error_create(); + + size_t sizes[] = {10, 50, 100, 200, 500, 1000}; + for (size_t i = 0; i < sizeof(sizes) / sizeof(sizes[0]); ++i) { + svs_search_params_h params = svs_search_params_create_vamana(sizes[i], error); + CATCH_REQUIRE(params != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + svs_search_params_free(params); + } + + svs_error_free(error); + } + + CATCH_SECTION("Search Parameters with NULL Error") { + svs_search_params_h params = svs_search_params_create_vamana(100, nullptr); + CATCH_REQUIRE(params != nullptr); + + svs_search_params_free(params); + } + + CATCH_SECTION("Multiple Search Parameters Handles") { + svs_error_h error = svs_error_create(); + + svs_search_params_h params1 = svs_search_params_create_vamana(50, error); + svs_search_params_h params2 = svs_search_params_create_vamana(100, error); + svs_search_params_h params3 = svs_search_params_create_vamana(200, error); + + CATCH_REQUIRE(params1 != nullptr); + CATCH_REQUIRE(params2 != nullptr); + CATCH_REQUIRE(params3 != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_search_params_free(params1); + svs_search_params_free(params2); + svs_search_params_free(params3); + svs_error_free(error); + } + + CATCH_SECTION("Search Parameters with Invalid Size") { + svs_error_h error = svs_error_create(); + + // Try to create with size 0 + svs_search_params_h params = svs_search_params_create_vamana(0, error); + // Behavior depends on implementation - either nullptr or valid handle + if (params != nullptr) { + svs_search_params_free(params); + } + + svs_error_free(error); + } +} diff --git a/bindings/c/tests/c_api_storage.cpp b/bindings/c/tests/c_api_storage.cpp new file mode 100644 index 00000000..b11f4d72 --- /dev/null +++ b/bindings/c/tests/c_api_storage.cpp @@ -0,0 +1,181 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// C API +#include "svs/c_api/svs_c.h" + +// catch2 +#include "catch2/catch_test_macros.hpp" + +CATCH_TEST_CASE("C API Storage", "[c_api][storage]") { + CATCH_SECTION("Simple Storage Float32") { + svs_error_h error = svs_error_create(); + + svs_storage_h storage = svs_storage_create_simple(SVS_DATA_TYPE_FLOAT32, error); + CATCH_REQUIRE(storage != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_storage_free(storage); + svs_error_free(error); + } + + CATCH_SECTION("Simple Storage Float16") { + svs_error_h error = svs_error_create(); + + svs_storage_h storage = svs_storage_create_simple(SVS_DATA_TYPE_FLOAT16, error); + CATCH_REQUIRE(storage != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_storage_free(storage); + svs_error_free(error); + } + + CATCH_SECTION("Simple Storage INT8") { + svs_error_h error = svs_error_create(); + + svs_storage_h storage = svs_storage_create_simple(SVS_DATA_TYPE_INT8, error); + CATCH_REQUIRE(storage == nullptr); + CATCH_REQUIRE_FALSE(svs_error_ok(error)); + CATCH_REQUIRE(svs_error_get_code(error) == SVS_ERROR_INVALID_ARGUMENT); + + svs_storage_free(storage); + svs_error_free(error); + } + + CATCH_SECTION("Simple Storage UINT8") { + svs_error_h error = svs_error_create(); + + svs_storage_h storage = svs_storage_create_simple(SVS_DATA_TYPE_UINT8, error); + CATCH_REQUIRE(storage == nullptr); + CATCH_REQUIRE_FALSE(svs_error_ok(error)); + CATCH_REQUIRE(svs_error_get_code(error) == SVS_ERROR_INVALID_ARGUMENT); + + svs_storage_free(storage); + svs_error_free(error); + } + + CATCH_SECTION("LeanVec Storage") { + svs_error_h error = svs_error_create(); + + size_t leanvec_dims = 64; + svs_storage_h storage = svs_storage_create_leanvec( + leanvec_dims, SVS_DATA_TYPE_UINT8, SVS_DATA_TYPE_UINT8, error + ); + CATCH_REQUIRE(storage != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_storage_free(storage); + svs_error_free(error); + } + + CATCH_SECTION("LeanVec Storage UINT4") { + svs_error_h error = svs_error_create(); + + size_t leanvec_dims = 64; + svs_storage_h storage = svs_storage_create_leanvec( + leanvec_dims, SVS_DATA_TYPE_UINT4, SVS_DATA_TYPE_UINT4, error + ); + CATCH_REQUIRE(storage != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_storage_free(storage); + svs_error_free(error); + } + + CATCH_SECTION("LVQ Storage UINT4") { + svs_error_h error = svs_error_create(); + + svs_storage_h storage = + svs_storage_create_lvq(SVS_DATA_TYPE_UINT4, SVS_DATA_TYPE_VOID, error); + CATCH_REQUIRE(storage != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_storage_free(storage); + svs_error_free(error); + } + + CATCH_SECTION("LVQ Storage UINT8") { + svs_error_h error = svs_error_create(); + + svs_storage_h storage = + svs_storage_create_lvq(SVS_DATA_TYPE_UINT8, SVS_DATA_TYPE_VOID, error); + CATCH_REQUIRE(storage != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_storage_free(storage); + svs_error_free(error); + } + + CATCH_SECTION("LVQ Storage with Residual") { + svs_error_h error = svs_error_create(); + + svs_storage_h storage = + svs_storage_create_lvq(SVS_DATA_TYPE_UINT4, SVS_DATA_TYPE_UINT8, error); + CATCH_REQUIRE(storage != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_storage_free(storage); + svs_error_free(error); + } + + CATCH_SECTION("Scalar Quantization Storage UINT8") { + svs_error_h error = svs_error_create(); + + svs_storage_h storage = svs_storage_create_sq(SVS_DATA_TYPE_UINT8, error); + CATCH_REQUIRE(storage != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_storage_free(storage); + svs_error_free(error); + } + + CATCH_SECTION("Scalar Quantization Storage INT8") { + svs_error_h error = svs_error_create(); + + svs_storage_h storage = svs_storage_create_sq(SVS_DATA_TYPE_INT8, error); + CATCH_REQUIRE(storage != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_storage_free(storage); + svs_error_free(error); + } + + CATCH_SECTION("Storage with NULL Error") { + svs_storage_h storage = svs_storage_create_simple(SVS_DATA_TYPE_FLOAT32, nullptr); + CATCH_REQUIRE(storage != nullptr); + + svs_storage_free(storage); + } + + CATCH_SECTION("Multiple Storage Handles") { + svs_error_h error = svs_error_create(); + + svs_storage_h storage1 = svs_storage_create_simple(SVS_DATA_TYPE_FLOAT32, error); + svs_storage_h storage2 = svs_storage_create_simple(SVS_DATA_TYPE_FLOAT16, error); + svs_storage_h storage3 = + svs_storage_create_leanvec(64, SVS_DATA_TYPE_UINT8, SVS_DATA_TYPE_UINT8, error); + + CATCH_REQUIRE(storage1 != nullptr); + CATCH_REQUIRE(storage2 != nullptr); + CATCH_REQUIRE(storage3 != nullptr); + CATCH_REQUIRE(svs_error_ok(error) == true); + + svs_storage_free(storage1); + svs_storage_free(storage2); + svs_storage_free(storage3); + svs_error_free(error); + } +}