2 changes: 2 additions & 0 deletions .gitignore
@@ -33,9 +33,11 @@


.cache/
.vscode/
bin/
build/
data/
results/

*.pyc
.clang-*
4 changes: 2 additions & 2 deletions include/rabitqlib/index/estimator.hpp
@@ -199,13 +199,13 @@ inline void split_single_estdist(
) {
ConstBinDataMap<float> cur_bin(bin_data, padded_dim);

ip_x0_qr = warmup_ip_x0_q<SplitSingleQuery<float>::kNumBits>(
ip_x0_qr = warmup_ip_x0_q_512(
cur_bin.bin_code(),
q_obj.query_bin(),
q_obj.delta(),
q_obj.vl(),
padded_dim,
SplitSingleQuery<float>::kNumBits
q_obj.num_bits()
);

est_dist =
10 changes: 6 additions & 4 deletions include/rabitqlib/index/query.hpp
@@ -142,19 +142,21 @@ class SplitSingleQuery {

metric_type_ = (metric_type == METRIC_IP) ? METRIC_IP : METRIC_L2;

std::vector<uint16_t> quant_query = std::vector<uint16_t>(padded_dim);
std::vector<uint8_t> quant_query(padded_dim);

// quantize query by rabitq
quant::quantize_scalar<float, uint16_t>(
quant::quantize_scalar<float, uint8_t>(
rotated_query, padded_dim, kNumBits, quant_query.data(), delta_, vl_, config
);

// represent quantized query as u64
rabitqlib::new_transpose_bin(
rabitqlib::new_transpose_bin_512(
quant_query.data(), QueryBin_.data(), padded_dim, kNumBits
);
}

[[nodiscard]] size_t num_bits() const { return kNumBits; }

[[nodiscard]] const uint64_t* query_bin() const { return QueryBin_.data(); }

[[nodiscard]] const T* rotated_query() const { return rotated_query_; }
@@ -184,4 +186,4 @@
void set_g_error(T norm) { G_error_ = norm; }
};

} // namespace rabitqlib
} // namespace rabitqlib
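
Note (not part of the PR): the diff does not show `quant::quantize_scalar`, so the sketch below only illustrates the assumed contract behind the `uint16_t` → `uint8_t` switch — each query entry becomes a `kNumBits`-bit code `q_int[d]` with `rotated_query[d] ≈ vl + delta * q_int[d]`, which is consistent with how `delta_` and `vl_` are consumed later by `warmup_ip_x0_q_512`. The function name and the exact rounding rule here are illustrative assumptions.

```cpp
// Illustrative sketch only: a plain uniform scalar quantizer with the assumed
// reconstruction rotated_query[d] ~= vl + delta * q_int[d], q_int in [0, 2^num_bits - 1].
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

inline void uniform_scalar_quantize_sketch(
    const float* query, size_t padded_dim, size_t num_bits,
    uint8_t* out, float& delta, float& vl
) {
    const auto [lo_it, hi_it] = std::minmax_element(query, query + padded_dim);
    const float lo = *lo_it;
    const float hi = *hi_it;
    const float max_code = static_cast<float>((1u << num_bits) - 1u);

    vl = lo;                                    // value reconstructed for code 0
    delta = (hi > lo) ? (hi - lo) / max_code    // step between adjacent codes
                      : 1.0f;                   // degenerate all-equal query

    for (size_t d = 0; d < padded_dim; ++d) {
        float code = std::round((query[d] - vl) / delta);
        out[d] = static_cast<uint8_t>(std::clamp(code, 0.0f, max_code));
    }
}
```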
33 changes: 33 additions & 0 deletions include/rabitqlib/utils/space.hpp
@@ -904,6 +904,39 @@ static inline void new_transpose_bin(
#endif
}

static inline void new_transpose_bin_512(
const uint8_t* q, uint64_t* tq, size_t padded_dim, size_t b_query
) {
#if defined(__AVX512BW__)
// Keep full 512-dim blocks as 8 chunks, but store the tail as compact
// [b_query x num_chunks] so runtime can use maskz loads without query padding.
for (size_t i = 0; i < padded_dim;) {
size_t block_size = 512;
if (i + 512 > padded_dim) {
block_size = padded_dim - i;
}
size_t num_chunks = block_size / 64;

for (size_t k = 0; k < num_chunks; ++k) {
const uint8_t* current_q = q + i + k * 64;
__m512i vec = _mm512_loadu_si512(current_q);

for (size_t j = 0; j < b_query; ++j) {
int bit_idx = b_query - 1 - j;
__mmask64 m = _mm512_test_epi8_mask(vec, _mm512_set1_epi8(1 << bit_idx));
tq[(b_query - j - 1) * num_chunks + k] = static_cast<uint64_t>(m);
}
}

i += block_size;
tq += num_chunks * b_query;
}
#else
std::cerr << "AVX512BW is required for new_transpose_bin_512\n";
exit(1);
#endif
}

inline float mask_ip_x0_q_old(const float* query, const uint64_t* data, size_t padded_dim) {
auto num_blk = padded_dim / 64;
const auto* it_data = data;
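
For reference, a plain-C++ sketch (not part of the PR) of the layout `new_transpose_bin_512` appears to emit, derived from reading the intrinsics above: within each 512-dim block, bit plane j of 64 consecutive query codes (plane 0 = least significant bit) is packed into one `uint64_t`, and the block's words are laid out as `[b_query x num_chunks]`, including the compact tail. The function name below is illustrative.

```cpp
// Illustrative scalar reference for the assumed transposed bit-plane layout.
#include <algorithm>
#include <cstddef>
#include <cstdint>

inline void transpose_bin_512_scalar_ref(
    const uint8_t* q, uint64_t* tq, size_t padded_dim, size_t b_query
) {
    for (size_t i = 0; i < padded_dim;) {
        size_t block_size = std::min<size_t>(512, padded_dim - i);
        size_t num_chunks = block_size / 64;        // assumes padded_dim % 64 == 0

        for (size_t j = 0; j < b_query; ++j) {      // j-th bit plane of the codes
            for (size_t k = 0; k < num_chunks; ++k) {
                uint64_t word = 0;
                for (size_t d = 0; d < 64; ++d) {
                    uint64_t bit = (q[i + k * 64 + d] >> j) & 1u;
                    word |= bit << d;               // lane d -> bit d, as with test_epi8_mask
                }
                tq[j * num_chunks + k] = word;
            }
        }
        i += block_size;
        tq += num_chunks * b_query;
    }
}
```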
72 changes: 71 additions & 1 deletion include/rabitqlib/utils/warmup_space.hpp
@@ -34,6 +34,76 @@ inline __m256i popcount_avx2(__m256i v) {
#endif
}

inline float warmup_ip_x0_q_512(
const uint64_t* data,
const uint64_t* query,
float delta,
float vl,
size_t padded_dim,
size_t b_query
) {
#if defined(__AVX512VPOPCNTDQ__) && defined(__AVX512BW__)
size_t ip_scalar = 0;
size_t ppc_scalar = 0;

__m512i acc_ip = _mm512_setzero_si512();
__m512i acc_ppc = _mm512_setzero_si512();

size_t i = 0;
size_t dim_end_512 = (padded_dim / 512) * 512;

__m512i acc_bits[b_query];
for (size_t j = 0; j < b_query; ++j) {
acc_bits[j] = _mm512_setzero_si512();
}

for (; i < dim_end_512; i += 512) {
__m512i data_vec = _mm512_loadu_si512(data);
data += 8;

acc_ppc = _mm512_add_epi64(acc_ppc, _mm512_popcnt_epi64(data_vec));

for (size_t j = 0; j < b_query; ++j) {
__m512i query_vec = _mm512_loadu_si512(query);
query += 8;

__m512i pop = _mm512_popcnt_epi64(_mm512_and_si512(data_vec, query_vec));
acc_bits[j] = _mm512_add_epi64(acc_bits[j], pop);
}
}

size_t remaining_dim = padded_dim - i;
if (remaining_dim > 0) {
size_t num_chunks = remaining_dim / 64;
auto valid_mask = static_cast<__mmask8>((1u << num_chunks) - 1u);

__m512i data_vec = _mm512_maskz_loadu_epi64(valid_mask, data);
acc_ppc = _mm512_add_epi64(acc_ppc, _mm512_popcnt_epi64(data_vec));

for (size_t j = 0; j < b_query; ++j) {
__m512i query_vec = _mm512_maskz_loadu_epi64(valid_mask, query);
query += num_chunks;

__m512i pop = _mm512_popcnt_epi64(_mm512_and_si512(data_vec, query_vec));
acc_bits[j] = _mm512_add_epi64(acc_bits[j], pop);
}
}

for (size_t j = 0; j < b_query; ++j) {
__m128i shift = _mm_cvtsi32_si128(static_cast<int>(j));
acc_ip = _mm512_add_epi64(acc_ip, _mm512_sll_epi64(acc_bits[j], shift));
}

ip_scalar += static_cast<size_t>(_mm512_reduce_add_epi64(acc_ip));
ppc_scalar += static_cast<size_t>(_mm512_reduce_add_epi64(acc_ppc));

return (delta * static_cast<float>(ip_scalar)) + (vl * static_cast<float>(ppc_scalar));
#else
std::cerr << "AVX512 VPOPCNTDQ and AVX512BW are required for warmup_ip_x0_q_512\n";
exit(1);
#endif
}
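
A scalar reference (not part of the PR) for what `warmup_ip_x0_q_512` appears to compute over the transposed query layout: `ip = Σ_d data_bit[d] · q_int[d]` with `q_int` rebuilt from its bit planes, `ppc = popcount(data)`, and the result `delta·ip + vl·ppc`, i.e. the inner product of the binary data code with `vl + delta·q_int`. This is a sketch derived from the intrinsics above; the function name is illustrative.

```cpp
// Illustrative scalar reference for the assumed popcount-based inner product.
#include <algorithm>
#include <bit>
#include <cstddef>
#include <cstdint>

inline float warmup_ip_x0_q_512_scalar_ref(
    const uint64_t* data, const uint64_t* query,
    float delta, float vl, size_t padded_dim, size_t b_query
) {
    uint64_t ip = 0;
    uint64_t ppc = 0;
    for (size_t i = 0; i < padded_dim; i += 512) {
        size_t num_chunks = std::min<size_t>(512, padded_dim - i) / 64;
        for (size_t k = 0; k < num_chunks; ++k) {
            ppc += static_cast<uint64_t>(std::popcount(data[k]));
            for (size_t j = 0; j < b_query; ++j) {  // plane j holds bit j of q_int
                uint64_t hits = std::popcount(data[k] & query[j * num_chunks + k]);
                ip += hits << j;                    // weight plane j by 2^j
            }
        }
        data += num_chunks;
        query += num_chunks * b_query;
    }
    return delta * static_cast<float>(ip) + vl * static_cast<float>(ppc);
}
```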

template <uint32_t b_query>
inline float warmup_ip_x0_q(
const uint64_t* data, // pointer to data blocks (each 64 bits)
@@ -231,4 +301,4 @@ inline float warmup_ip_x0_q(
}

return (delta * static_cast<float>(ip)) + (vl * static_cast<float>(ppc));
}
}