From 1aa80c38d1e14cc78ae30296bd6b874574fd4a34 Mon Sep 17 00:00:00 2001
From: Alexandr Guzhva
Date: Mon, 13 May 2024 13:11:05 -0400
Subject: [PATCH] [ARM NEON] Get rid of redundant instructions in ScalarQuantizer

Signed-off-by: Alexandr Guzhva
---
 faiss/impl/ScalarQuantizer.cpp | 58 +++++++++++++---------------------
 1 file changed, 22 insertions(+), 36 deletions(-)

diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp
index 07d77d5622..e3b29e621d 100644
--- a/faiss/impl/ScalarQuantizer.cpp
+++ b/faiss/impl/ScalarQuantizer.cpp
@@ -101,8 +101,7 @@ struct Codec8bit {
         }
         float32x4_t res1 = vld1q_f32(result);
         float32x4_t res2 = vld1q_f32(result + 4);
-        float32x4x2_t res = vzipq_f32(res1, res2);
-        return vuzpq_f32(res.val[0], res.val[1]);
+        return {res1, res2};
     }
 #endif
 };
@@ -153,8 +152,7 @@ struct Codec4bit {
         }
         float32x4_t res1 = vld1q_f32(result);
         float32x4_t res2 = vld1q_f32(result + 4);
-        float32x4x2_t res = vzipq_f32(res1, res2);
-        return vuzpq_f32(res.val[0], res.val[1]);
+        return {res1, res2};
     }
 #endif
 };
@@ -266,8 +264,7 @@ struct Codec6bit {
         }
         float32x4_t res1 = vld1q_f32(result);
         float32x4_t res2 = vld1q_f32(result + 4);
-        float32x4x2_t res = vzipq_f32(res1, res2);
-        return vuzpq_f32(res.val[0], res.val[1]);
+        return {res1, res2};
     }
 #endif
 };
@@ -345,16 +342,14 @@ struct QuantizerTemplate : QuantizerTemplate {
     FAISS_ALWAYS_INLINE float32x4x2_t
     reconstruct_8_components(const uint8_t* code, int i) const {
         float32x4x2_t xi = Codec::decode_8_components(code, i);
-        float32x4x2_t res = vzipq_f32(
-                vfmaq_f32(
+        return {vfmaq_f32(
                         vdupq_n_f32(this->vmin),
                         xi.val[0],
                         vdupq_n_f32(this->vdiff)),
                 vfmaq_f32(
                         vdupq_n_f32(this->vmin),
                         xi.val[1],
-                        vdupq_n_f32(this->vdiff)));
-        return vuzpq_f32(res.val[0], res.val[1]);
+                        vdupq_n_f32(this->vdiff))};
     }
 };
 
@@ -431,10 +426,8 @@ struct QuantizerTemplate : QuantizerTemplate {
         float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i);
         float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i);
 
-        float32x4x2_t res = vzipq_f32(
-                vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
-                vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1]));
-        return vuzpq_f32(res.val[0], res.val[1]);
+        return {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
+                vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])};
     }
 };
 
@@ -496,10 +489,9 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
 
     FAISS_ALWAYS_INLINE float32x4x2_t
     reconstruct_8_components(const uint8_t* code, int i) const {
-        uint16x4x2_t codei = vld2_u16((const uint16_t*)(code + 2 * i));
-        return vzipq_f32(
-                vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
-                vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1])));
+        uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
+        return {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
+                vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))};
     }
 };
 #endif
@@ -568,8 +560,7 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
         }
         float32x4_t res1 = vld1q_f32(result);
         float32x4_t res2 = vld1q_f32(result + 4);
-        float32x4x2_t res = vzipq_f32(res1, res2);
-        return vuzpq_f32(res.val[0], res.val[1]);
+        return {res1, res2};
     }
 };
 
@@ -868,7 +859,7 @@ struct SimilarityL2<8> {
     float32x4x2_t accu8;
 
     FAISS_ALWAYS_INLINE void begin_8() {
-        accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
+        accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
         yi = y;
     }
 
@@ -882,8 +873,7 @@ struct SimilarityL2<8> {
         float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
         float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
 
-        float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
-        accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
+        accu8 = {accu8_0, accu8_1};
     }
 
     FAISS_ALWAYS_INLINE void add_8_components_2(
@@ -895,8 +885,7 @@ struct SimilarityL2<8> {
         float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
         float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
 
-        float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
-        accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
+        accu8 = {accu8_0, accu8_1};
     }
 
     FAISS_ALWAYS_INLINE float result_8() {
@@ -996,7 +985,7 @@ struct SimilarityIP<8> {
     float32x4x2_t accu8;
 
     FAISS_ALWAYS_INLINE void begin_8() {
-        accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
+        accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
         yi = y;
     }
 
@@ -1006,8 +995,7 @@ struct SimilarityIP<8> {
 
         float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]);
         float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]);
-        float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
-        accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
+        accu8 = {accu8_0, accu8_1};
     }
 
     FAISS_ALWAYS_INLINE void add_8_components_2(
@@ -1015,19 +1003,17 @@ struct SimilarityIP<8> {
             float32x4x2_t x2) {
         float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]);
         float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]);
-        float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
-        accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
+        accu8 = {accu8_0, accu8_1};
     }
 
     FAISS_ALWAYS_INLINE float result_8() {
-        float32x4x2_t sum_tmp = vzipq_f32(
+        float32x4x2_t sum = {
                 vpaddq_f32(accu8.val[0], accu8.val[0]),
-                vpaddq_f32(accu8.val[1], accu8.val[1]));
-        float32x4x2_t sum = vuzpq_f32(sum_tmp.val[0], sum_tmp.val[1]);
-        float32x4x2_t sum2_tmp = vzipq_f32(
+                vpaddq_f32(accu8.val[1], accu8.val[1])};
+
+        float32x4x2_t sum2 = {
                 vpaddq_f32(sum.val[0], sum.val[0]),
-                vpaddq_f32(sum.val[1], sum.val[1]));
-        float32x4x2_t sum2 = vuzpq_f32(sum2_tmp.val[0], sum2_tmp.val[1]);
+                vpaddq_f32(sum.val[1], sum.val[1])};
         return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0);
     }
 };