diff --git a/.gitignore b/.gitignore index 6c2a39d..b07cbcd 100644 --- a/.gitignore +++ b/.gitignore @@ -33,9 +33,11 @@ .cache/ +.vscode/ bin/ build/ data/ +results/ *.pyc .clang-* diff --git a/include/rabitqlib/index/estimator.hpp b/include/rabitqlib/index/estimator.hpp index c8f4219..1f8baae 100644 --- a/include/rabitqlib/index/estimator.hpp +++ b/include/rabitqlib/index/estimator.hpp @@ -199,13 +199,13 @@ inline void split_single_estdist( ) { ConstBinDataMap cur_bin(bin_data, padded_dim); - ip_x0_qr = warmup_ip_x0_q<SplitSingleQuery<T>::kNumBits>( + ip_x0_qr = warmup_ip_x0_q_512( cur_bin.bin_code(), q_obj.query_bin(), q_obj.delta(), q_obj.vl(), padded_dim, - SplitSingleQuery<T>::kNumBits + q_obj.num_bits() ); est_dist = diff --git a/include/rabitqlib/index/query.hpp b/include/rabitqlib/index/query.hpp index fb98479..3aa426b 100644 --- a/include/rabitqlib/index/query.hpp +++ b/include/rabitqlib/index/query.hpp @@ -142,19 +142,21 @@ class SplitSingleQuery { metric_type_ = (metric_type == METRIC_IP) ? METRIC_IP : METRIC_L2; - std::vector<uint8_t> quant_query = std::vector<uint8_t>(padded_dim); + std::vector<uint8_t> quant_query(padded_dim); // quantize query by rabitq - quant::quantize_scalar( + quant::quantize_scalar( rotated_query, padded_dim, kNumBits, quant_query.data(), delta_, vl_, config ); // represent quantized query as u64 - rabitqlib::new_transpose_bin( + rabitqlib::new_transpose_bin_512( quant_query.data(), QueryBin_.data(), padded_dim, kNumBits ); } + [[nodiscard]] size_t num_bits() const { return kNumBits; } + [[nodiscard]] const uint64_t* query_bin() const { return QueryBin_.data(); } [[nodiscard]] const T* rotated_query() const { return rotated_query_; } @@ -184,4 +186,4 @@ class SplitSingleQuery { void set_g_error(T norm) { G_error_ = norm; } }; -} // namespace rabitqlib \ No newline at end of file +} // namespace rabitqlib diff --git a/include/rabitqlib/utils/space.hpp b/include/rabitqlib/utils/space.hpp index 446a0e2..120d8f5 100644 --- a/include/rabitqlib/utils/space.hpp +++ 
b/include/rabitqlib/utils/space.hpp @@ -904,6 +904,39 @@ static inline void new_transpose_bin( #endif } +static inline void new_transpose_bin_512( + const uint8_t* q, uint64_t* tq, size_t padded_dim, size_t b_query +) { +#if defined(__AVX512BW__) + // Keep full 512-dim blocks as 8 chunks, but store the tail as compact + // [b_query x num_chunks] so runtime can use maskz loads without query padding. + for (size_t i = 0; i < padded_dim;) { + size_t block_size = 512; + if (i + 512 > padded_dim) { + block_size = padded_dim - i; + } + size_t num_chunks = block_size / 64; + + for (size_t k = 0; k < num_chunks; ++k) { + const uint8_t* current_q = q + i + k * 64; + __m512i vec = _mm512_loadu_si512(current_q); + + for (size_t j = 0; j < b_query; ++j) { + int bit_idx = b_query - 1 - j; + __mmask64 m = _mm512_test_epi8_mask(vec, _mm512_set1_epi8(1 << bit_idx)); + tq[(b_query - j - 1) * num_chunks + k] = static_cast<uint64_t>(m); + } + } + + i += block_size; + tq += num_chunks * b_query; + } +#else + std::cerr << "AVX512BW is required for new_transpose_bin_512\n"; + exit(1); +#endif +} + inline float mask_ip_x0_q_old(const float* query, const uint64_t* data, size_t padded_dim) { auto num_blk = padded_dim / 64; const auto* it_data = data; diff --git a/include/rabitqlib/utils/warmup_space.hpp b/include/rabitqlib/utils/warmup_space.hpp index 7347bf6..978ad91 100644 --- a/include/rabitqlib/utils/warmup_space.hpp +++ b/include/rabitqlib/utils/warmup_space.hpp @@ -34,6 +34,76 @@ inline __m256i popcount_avx2(__m256i v) { #endif } +inline float warmup_ip_x0_q_512( + const uint64_t* data, + const uint64_t* query, + float delta, + float vl, + size_t padded_dim, + size_t b_query +) { +#if defined(__AVX512VPOPCNTDQ__) && defined(__AVX512BW__) + size_t ip_scalar = 0; + size_t ppc_scalar = 0; + + __m512i acc_ip = _mm512_setzero_si512(); + __m512i acc_ppc = _mm512_setzero_si512(); + + size_t i = 0; + size_t dim_end_512 = (padded_dim / 512) * 512; + + __m512i acc_bits[b_query]; + for (size_t j = 0; j < 
b_query; ++j) { + acc_bits[j] = _mm512_setzero_si512(); + } + + for (; i < dim_end_512; i += 512) { + __m512i data_vec = _mm512_loadu_si512(data); + data += 8; + + acc_ppc = _mm512_add_epi64(acc_ppc, _mm512_popcnt_epi64(data_vec)); + + for (size_t j = 0; j < b_query; ++j) { + __m512i query_vec = _mm512_loadu_si512(query); + query += 8; + + __m512i pop = _mm512_popcnt_epi64(_mm512_and_si512(data_vec, query_vec)); + acc_bits[j] = _mm512_add_epi64(acc_bits[j], pop); + } + } + + size_t remaining_dim = padded_dim - i; + if (remaining_dim > 0) { + size_t num_chunks = remaining_dim / 64; + auto valid_mask = static_cast<__mmask8>((1u << num_chunks) - 1u); + + __m512i data_vec = _mm512_maskz_loadu_epi64(valid_mask, data); + acc_ppc = _mm512_add_epi64(acc_ppc, _mm512_popcnt_epi64(data_vec)); + + for (size_t j = 0; j < b_query; ++j) { + __m512i query_vec = _mm512_maskz_loadu_epi64(valid_mask, query); + query += num_chunks; + + __m512i pop = _mm512_popcnt_epi64(_mm512_and_si512(data_vec, query_vec)); + acc_bits[j] = _mm512_add_epi64(acc_bits[j], pop); + } + } + + for (size_t j = 0; j < b_query; ++j) { + __m128i shift = _mm_cvtsi32_si128(static_cast<int>(j)); + acc_ip = _mm512_add_epi64(acc_ip, _mm512_sll_epi64(acc_bits[j], shift)); + } + + ip_scalar += static_cast<size_t>(_mm512_reduce_add_epi64(acc_ip)); + ppc_scalar += static_cast<size_t>(_mm512_reduce_add_epi64(acc_ppc)); + + return (delta * static_cast<float>(ip_scalar)) + (vl * static_cast<float>(ppc_scalar)); +#else + std::cerr << "AVX512 VPOPCNTDQ and AVX512BW are required for warmup_ip_x0_q_512\n"; + exit(1); +#endif +} + template inline float warmup_ip_x0_q( const uint64_t* data, // pointer to data blocks (each 64 bits) @@ -231,4 +301,4 @@ inline float warmup_ip_x0_q( } return (delta * static_cast<float>(ip)) + (vl * static_cast<float>(ppc)); -} \ No newline at end of file +}