Skip to content

Commit

Permalink
improve code_distance performance for small M
Browse files Browse the repository at this point in the history
  • Loading branch information
vorj committed Oct 11, 2024
1 parent 4418735 commit 72d008b
Showing 1 changed file with 84 additions and 0 deletions.
84 changes: 84 additions & 0 deletions faiss/impl/code_distance/code_distance-sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,28 @@ static inline void distance_codes_kernel(
partialSum = svadd_f32_m(pg, partialSum, collected);
}

static float distance_single_code_sve_for_small_m(
// the product quantizer
const size_t M,
// precomputed distances, layout (M, ksub)
const float* sim_table,
// codes
const uint8_t* __restrict code) {
constexpr size_t nbits = 8u;

const size_t ksub = 1 << nbits;

const auto offsets_0 = svindex_u32(0, static_cast<uint32_t>(ksub));

// loop
const auto pg = svwhilelt_b32_u64(0, M);

auto mm1 = svld1ub_u32(pg, code);
mm1 = svadd_u32_x(pg, mm1, offsets_0);
const auto collected0 = svld1_gather_u32index_f32(pg, sim_table, mm1);
return svaddv_f32(pg, collected0);
}

template <typename PQDecoderT>
std::enable_if_t<std::is_same_v<PQDecoderT, PQDecoder8>, float> inline distance_single_code_sve(
// the product quantizer
Expand All @@ -57,6 +79,9 @@ std::enable_if_t<std::is_same_v<PQDecoderT, PQDecoder8>, float> inline distance_
// precomputed distances, layout (M, ksub)
const float* sim_table,
const uint8_t* code) {
if (M <= svcntw())
return distance_single_code_sve_for_small_m(M, sim_table, code);

const float* tab = sim_table;

const size_t ksub = 1 << nbits;
Expand Down Expand Up @@ -171,6 +196,50 @@ distance_four_codes_sve(
result3);
}

static void distance_four_codes_sve_for_small_m(
// the product quantizer
const size_t M,
// precomputed distances, layout (M, ksub)
const float* sim_table,
// codes
const uint8_t* __restrict code0,
const uint8_t* __restrict code1,
const uint8_t* __restrict code2,
const uint8_t* __restrict code3,
// computed distances
float& result0,
float& result1,
float& result2,
float& result3) {
constexpr size_t nbits = 8u;

const size_t ksub = 1 << nbits;

const auto offsets_0 = svindex_u32(0, static_cast<uint32_t>(ksub));

const auto quad_lanes = svcntw();

// loop
const auto pg = svwhilelt_b32_u64(0, M);

auto mm10 = svld1ub_u32(pg, code0);
auto mm11 = svld1ub_u32(pg, code1);
auto mm12 = svld1ub_u32(pg, code2);
auto mm13 = svld1ub_u32(pg, code3);
mm10 = svadd_u32_x(pg, mm10, offsets_0);
mm11 = svadd_u32_x(pg, mm11, offsets_0);
mm12 = svadd_u32_x(pg, mm12, offsets_0);
mm13 = svadd_u32_x(pg, mm13, offsets_0);
const auto collected0 = svld1_gather_u32index_f32(pg, sim_table, mm10);
const auto collected1 = svld1_gather_u32index_f32(pg, sim_table, mm11);
const auto collected2 = svld1_gather_u32index_f32(pg, sim_table, mm12);
const auto collected3 = svld1_gather_u32index_f32(pg, sim_table, mm13);
result0 = svaddv_f32(pg, collected0);
result1 = svaddv_f32(pg, collected1);
result2 = svaddv_f32(pg, collected2);
result3 = svaddv_f32(pg, collected3);
}

// Combines 4 operations of distance_single_code()
template <typename PQDecoderT>
std::enable_if_t<std::is_same_v<PQDecoderT, PQDecoder8>, void>
Expand All @@ -191,6 +260,21 @@ distance_four_codes_sve(
float& result1,
float& result2,
float& result3) {
if (M <= svcntw()) {
distance_four_codes_sve_for_small_m(
M,
sim_table,
code0,
code1,
code2,
code3,
result0,
result1,
result2,
result3);
return;
}

const float* tab = sim_table;

const size_t ksub = 1 << nbits;
Expand Down

0 comments on commit 72d008b

Please sign in to comment.