Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SIMD Histogram Subroutines on aarch64 #2447

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 21 additions & 25 deletions faiss/utils/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ template uint16_t partition_fuzzy<CMax<uint16_t, int>>(
* Histogram subroutines
******************************************************************/

#ifdef __AVX2__
#if defined(__AVX2__) || defined(__aarch64__)
/// FIXME when MSB of uint16 is set
// this code does not compile properly with GCC 7.4.0

Expand All @@ -833,7 +833,7 @@ simd32uint8 accu4to8(simd16uint16 a4) {
simd16uint16 a8_0 = a4 & mask4;
simd16uint16 a8_1 = (a4 >> 4) & mask4;

return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
return simd32uint8(hadd(a8_0, a8_1));
}

simd16uint16 accu8to16(simd32uint8 a8) {
Expand All @@ -842,10 +842,10 @@ simd16uint16 accu8to16(simd32uint8 a8) {
simd16uint16 a8_0 = simd16uint16(a8) & mask8;
simd16uint16 a8_1 = (simd16uint16(a8) >> 8) & mask8;

return simd16uint16(_mm256_hadd_epi16(a8_0.i, a8_1.i));
return hadd(a8_0, a8_1);
}

static const simd32uint8 shifts(_mm256_setr_epi8(
static const simd32uint8 shifts = simd32uint8::create<
1,
16,
0,
Expand Down Expand Up @@ -877,7 +877,7 @@ static const simd32uint8 shifts(_mm256_setr_epi8(
0,
0,
4,
64));
64>();

// 2-bit accumulator: we can add only up to 3 elements
// on output we return 2*4-bit results
Expand Down Expand Up @@ -937,7 +937,7 @@ simd16uint16 histogram_8(const uint16_t* data, Preproc pp, size_t n_in) {
simd16uint16 a16lo = accu8to16(a8lo);
simd16uint16 a16hi = accu8to16(a8hi);

simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));
simd16uint16 a16 = hadd(a16lo, a16hi);

// the 2 lanes must still be combined
return a16;
Expand All @@ -947,51 +947,44 @@ simd16uint16 histogram_8(const uint16_t* data, Preproc pp, size_t n_in) {
* 16 bins
************************************************************/

static const simd32uint8 shifts2(_mm256_setr_epi8(
static const simd32uint8 shifts2 = simd32uint8::create<
1,
2,
4,
8,
16,
32,
64,
(char)128,
128,
1,
2,
4,
8,
16,
32,
64,
(char)128,
128,
1,
2,
4,
8,
16,
32,
64,
(char)128,
128,
1,
2,
4,
8,
16,
32,
64,
(char)128));
128>();

simd32uint8 shiftr_16(simd32uint8 x, int n) {
return simd32uint8(simd16uint16(x) >> n);
}

inline simd32uint8 combine_2x2(simd32uint8 a, simd32uint8 b) {
__m256i a1b0 = _mm256_permute2f128_si256(a.i, b.i, 0x21);
__m256i a0b1 = _mm256_blend_epi32(a.i, b.i, 0xF0);

return simd32uint8(a1b0) + simd32uint8(a0b1);
}

// 2-bit accumulator: we can add only up to 3 elements
// on output we return 2*4-bit results
template <int N, class Preproc>
Expand All @@ -1018,7 +1011,7 @@ void compute_accu2_16(
// contains 0s for out-of-bounds elements

simd16uint16 lt8 = (v >> 3) == simd16uint16(0);
lt8.i = _mm256_xor_si256(lt8.i, _mm256_set1_epi16(0xff00));
lt8 = lt8 ^ simd16uint16(0xff00);

a1 = a1 & lt8;

Expand All @@ -1036,11 +1029,15 @@ void compute_accu2_16(
simd32uint8 accu4to8_2(simd32uint8 a4_0, simd32uint8 a4_1) {
simd32uint8 mask4(0x0f);

simd32uint8 a8_0 = combine_2x2(a4_0 & mask4, shiftr_16(a4_0, 4) & mask4);
simd16uint16 a8_0 = combine2x2(
(simd16uint16)(a4_0 & mask4),
(simd16uint16)(shiftr_16(a4_0, 4) & mask4));

simd32uint8 a8_1 = combine_2x2(a4_1 & mask4, shiftr_16(a4_1, 4) & mask4);
simd16uint16 a8_1 = combine2x2(
(simd16uint16)(a4_1 & mask4),
(simd16uint16)(shiftr_16(a4_1, 4) & mask4));

return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
return simd32uint8(hadd(a8_0, a8_1));
}

template <class Preproc>
Expand Down Expand Up @@ -1079,10 +1076,9 @@ simd16uint16 histogram_16(const uint16_t* data, Preproc pp, size_t n_in) {
simd16uint16 a16lo = accu8to16(a8lo);
simd16uint16 a16hi = accu8to16(a8hi);

simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));
simd16uint16 a16 = hadd(a16lo, a16hi);

__m256i perm32 = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
a16.i = _mm256_permutevar8x32_epi32(a16.i, perm32);
a16 = simd16uint16{simd8uint32{a16}.unzip()};

return a16;
}
Expand Down
82 changes: 82 additions & 0 deletions faiss/utils/simdlib_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ struct simd16uint16 : simd256bit {
return simd16uint16(_mm256_or_si256(i, other.i));
}

simd16uint16 operator^(simd256bit other) const {
return simd16uint16(_mm256_xor_si256(i, other.i));
}

// returns binary masks
friend simd16uint16 operator==(const simd256bit lhs, const simd256bit rhs) {
return simd16uint16(_mm256_cmpeq_epi16(lhs.i, rhs.i));
Expand Down Expand Up @@ -255,6 +259,10 @@ inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
return ge;
}

inline simd16uint16 hadd(const simd16uint16& a, const simd16uint16& b) {
return simd16uint16(_mm256_hadd_epi16(a.i, b.i));
}

// vector of 32 unsigned 8-bit integers
struct simd32uint8 : simd256bit {
simd32uint8() {}
Expand All @@ -265,6 +273,75 @@ struct simd32uint8 : simd256bit {

explicit simd32uint8(uint8_t x) : simd256bit(_mm256_set1_epi8(x)) {}

template <
uint8_t _0,
uint8_t _1,
uint8_t _2,
uint8_t _3,
uint8_t _4,
uint8_t _5,
uint8_t _6,
uint8_t _7,
uint8_t _8,
uint8_t _9,
uint8_t _10,
uint8_t _11,
uint8_t _12,
uint8_t _13,
uint8_t _14,
uint8_t _15,
uint8_t _16,
uint8_t _17,
uint8_t _18,
uint8_t _19,
uint8_t _20,
uint8_t _21,
uint8_t _22,
uint8_t _23,
uint8_t _24,
uint8_t _25,
uint8_t _26,
uint8_t _27,
uint8_t _28,
uint8_t _29,
uint8_t _30,
uint8_t _31>
static simd32uint8 create() {
return simd32uint8(_mm256_setr_epi8(
(char)_0,
(char)_1,
(char)_2,
(char)_3,
(char)_4,
(char)_5,
(char)_6,
(char)_7,
(char)_8,
(char)_9,
(char)_10,
(char)_11,
(char)_12,
(char)_13,
(char)_14,
(char)_15,
(char)_16,
(char)_17,
(char)_18,
(char)_19,
(char)_20,
(char)_21,
(char)_22,
(char)_23,
(char)_24,
(char)_25,
(char)_26,
(char)_27,
(char)_28,
(char)_29,
(char)_30,
(char)_31));
}

explicit simd32uint8(simd256bit x) : simd256bit(x) {}

explicit simd32uint8(const uint8_t* x) : simd256bit((const void*)x) {}
Expand Down Expand Up @@ -412,6 +489,11 @@ struct simd8uint32 : simd256bit {
void set1(uint32_t x) {
i = _mm256_set1_epi32((int)x);
}

simd8uint32 unzip() const {
return simd8uint32(_mm256_permutevar8x32_epi32(
i, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)));
}
};

struct simd8float32 : simd256bit {
Expand Down
106 changes: 106 additions & 0 deletions faiss/utils/simdlib_emulated.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,13 @@ struct simd16uint16 : simd256bit {
});
}

simd16uint16 operator^(const simd256bit& other) const {
return binary_func(
*this, simd16uint16(other), [](uint16_t a, uint16_t b) {
return a ^ b;
});
}

// returns binary masks
simd16uint16 operator==(const simd16uint16& other) const {
return binary_func(*this, other, [](uint16_t a, uint16_t b) {
Expand Down Expand Up @@ -288,6 +295,30 @@ inline uint32_t cmp_le32(
return gem;
}

// hadd does not cross lanes
inline simd16uint16 hadd(const simd16uint16& a, const simd16uint16& b) {
simd16uint16 c;
c.u16[0] = a.u16[0] + a.u16[1];
c.u16[1] = a.u16[2] + a.u16[3];
c.u16[2] = a.u16[4] + a.u16[5];
c.u16[3] = a.u16[6] + a.u16[7];
c.u16[4] = b.u16[0] + b.u16[1];
c.u16[5] = b.u16[2] + b.u16[3];
c.u16[6] = b.u16[4] + b.u16[5];
c.u16[7] = b.u16[6] + b.u16[7];

c.u16[8] = a.u16[8] + a.u16[9];
c.u16[9] = a.u16[10] + a.u16[11];
c.u16[10] = a.u16[12] + a.u16[13];
c.u16[11] = a.u16[14] + a.u16[15];
c.u16[12] = b.u16[8] + b.u16[9];
c.u16[13] = b.u16[10] + b.u16[11];
c.u16[14] = b.u16[12] + b.u16[13];
c.u16[15] = b.u16[14] + b.u16[15];

return c;
}

// vector of 32 unsigned 8-bit integers
struct simd32uint8 : simd256bit {
simd32uint8() {}
Expand All @@ -299,6 +330,75 @@ struct simd32uint8 : simd256bit {
explicit simd32uint8(uint8_t x) {
set1(x);
}
template <
uint8_t _0,
uint8_t _1,
uint8_t _2,
uint8_t _3,
uint8_t _4,
uint8_t _5,
uint8_t _6,
uint8_t _7,
uint8_t _8,
uint8_t _9,
uint8_t _10,
uint8_t _11,
uint8_t _12,
uint8_t _13,
uint8_t _14,
uint8_t _15,
uint8_t _16,
uint8_t _17,
uint8_t _18,
uint8_t _19,
uint8_t _20,
uint8_t _21,
uint8_t _22,
uint8_t _23,
uint8_t _24,
uint8_t _25,
uint8_t _26,
uint8_t _27,
uint8_t _28,
uint8_t _29,
uint8_t _30,
uint8_t _31>
static simd32uint8 create() {
simd32uint8 ret;
ret.u8[0] = _0;
ret.u8[1] = _1;
ret.u8[2] = _2;
ret.u8[3] = _3;
ret.u8[4] = _4;
ret.u8[5] = _5;
ret.u8[6] = _6;
ret.u8[7] = _7;
ret.u8[8] = _8;
ret.u8[9] = _9;
ret.u8[10] = _10;
ret.u8[11] = _11;
ret.u8[12] = _12;
ret.u8[13] = _13;
ret.u8[14] = _14;
ret.u8[15] = _15;
ret.u8[16] = _16;
ret.u8[17] = _17;
ret.u8[18] = _18;
ret.u8[19] = _19;
ret.u8[20] = _20;
ret.u8[21] = _21;
ret.u8[22] = _22;
ret.u8[23] = _23;
ret.u8[24] = _24;
ret.u8[25] = _25;
ret.u8[26] = _26;
ret.u8[27] = _27;
ret.u8[28] = _28;
ret.u8[29] = _29;
ret.u8[30] = _30;
ret.u8[31] = _31;
return ret;
}

explicit simd32uint8(const simd256bit& x) : simd256bit(x) {}

Expand Down Expand Up @@ -512,6 +612,12 @@ struct simd8uint32 : simd256bit {
u32[i] = x;
}
}

simd8uint32 unzip() const {
const uint32_t ret[] = {
u32[0], u32[2], u32[4], u32[6], u32[1], u32[3], u32[5], u32[7]};
return simd8uint32{ret};
}
};

struct simd8float32 : simd256bit {
Expand Down
Loading