Skip to content

Commit

Permalink
[CPU][ARM] Fixed cvt_copy fast path for mha_single_token_kernel (#28265)
Browse files Browse the repository at this point in the history
### Details:
- This PR fixes incorrect cvt_copy routine behavior inside the
mha_single_token kernel on ARM platforms. The issue occurs when
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC is defined on the system and the fp32
inference scalar code path is chosen.
- Additionally cvt_copy impl is refactored via template specialization
for better readability
- Follow-up after #28182
  • Loading branch information
dmitry-gorokhov authored Jan 8, 2025
1 parent 6c610e9 commit b27eefb
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,31 +62,50 @@ void cvt_copy(TA* dst, TB* src, size_t n) {
auto vb = mm256_uni_loadu_ps(src + i);
mm256_uni_storeu_ps(dst + i, vb);
}
#elif defined(OPENVINO_ARCH_ARM64)
#endif
for (; i < n; i++) {
dst[i] = src[i];
}
}

#if defined(OPENVINO_ARCH_ARM64)
# if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
if (std::is_same<TA, ov::float16>::value && std::is_same<TB, ov::float16>::value) {
# if defined(HAVE_SVE)
size_t inc = vec_len_f16_sve();
svbool_t pg = svptrue_b16();
template <>
void cvt_copy(ov::float16* dst, ov::float16* src, size_t n) {
size_t i = 0;
size_t inc = vec_len_f16_sve();
svbool_t pg = svptrue_b16();

while (i < n) {
if (n - i < vec_len_f16_sve()) {
inc = n - i;
pg = svwhilelt_b16(0, static_cast<int>(inc));
}
svfloat16_t b1 = svld1_f16(pg, reinterpret_cast<const float16_t*>(src + i));
svst1_f16(pg, reinterpret_cast<float16_t*>(dst + i), b1);
i += inc;
}
# else
for (; i + vec_len_f16_neon <= n; i += vec_len_f16_neon) {
auto vb1 = vld1q_f16(reinterpret_cast<const float16_t*>(src + i));
vst1q_f16(reinterpret_cast<float16_t*>(dst + i), vb1);
while (i < n) {
if (n - i < vec_len_f16_sve()) {
inc = n - i;
pg = svwhilelt_b16(0, static_cast<int>(inc));
}
# endif
svfloat16_t b1 = svld1_f16(pg, reinterpret_cast<const float16_t*>(src + i));
svst1_f16(pg, reinterpret_cast<float16_t*>(dst + i), b1);
i += inc;
}
# else
# if defined(HAVE_SVE)
}
# else // NEON
// fp16 -> fp16 specialization of cvt_copy for the NEON build.
// Copies n elements from src to dst without any type conversion.
template <>
void cvt_copy(ov::float16* dst, ov::float16* src, size_t n) {
    size_t pos = 0;
    // Bulk part: move vec_len_f16_neon elements per step with vld1q_f16 / vst1q_f16.
    while (pos + vec_len_f16_neon <= n) {
        const auto chunk = vld1q_f16(reinterpret_cast<const float16_t*>(src + pos));
        vst1q_f16(reinterpret_cast<float16_t*>(dst + pos), chunk);
        pos += vec_len_f16_neon;
    }
    // Remainder: fewer than vec_len_f16_neon elements are left, copy them one by one.
    while (pos < n) {
        dst[pos] = src[pos];
        ++pos;
    }
}
# endif // defined(HAVE_SVE)
# endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)

# if defined(HAVE_SVE)
template <>
void cvt_copy(float* dst, float* src, size_t n) {
size_t i = 0;
auto _dst = reinterpret_cast<float32_t*>(dst);
size_t inc = vec_len_f32_sve();
svbool_t pg = svptrue_b32();
Expand All @@ -100,20 +119,21 @@ void cvt_copy(TA* dst, TB* src, size_t n) {
svst1_f32(pg, _dst + i, b1);
i += inc;
}
# else
if (std::is_same<TA, float>::value && std::is_same<TB, float>::value) {
for (; i + vec_len_f32_neon <= n; i += vec_len_f32_neon) {
float32x4_t vb1 = __vld1q_f32(src + i);
__vst1q_f32(dst + i, vb1);
}
}
# else // NEON
// fp32 -> fp32 specialization of cvt_copy for the NEON (non-SVE) build:
// copies n floats from src to dst.
template <>
void cvt_copy(float* dst, float* src, size_t n) {
size_t i = 0;
// Vector loop: load and store vec_len_f32_neon floats per iteration.
for (; i + vec_len_f32_neon <= n; i += vec_len_f32_neon) {
float32x4_t vb1 = __vld1q_f32(src + i);
__vst1q_f32(dst + i, vb1);
}
# endif
# endif
#endif
// Scalar tail: copy the remaining elements that did not fill a whole vector.
for (; i < n; i++) {
dst[i] = src[i];
}
}
# endif // defined(HAVE_SVE)
#endif // defined(OPENVINO_ARCH_ARM64)

template <typename T>
static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scale, float* zp) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ const std::vector<std::vector<InputShape>> inputShapes = {
// B, H, L0, S
{{-1, 8, -1, 64}, {{4, 8, 0, 64}, {4, 8, 10, 64}, {4, 8, 11, 64}, {4, 8, 12, 64}, {4, 8, 13, 64}}},
},
// big batch to check cvt_copy fast-path inside mha_single_token_kernel
{
// B, H, L1, S
{{-1, 8, -1, 64}, {{129, 8, 10, 64}, {129, 8, 1, 64}, {129, 8, 1, 64}, {129, 8, 1, 64}, {129, 8, 1, 64}}},
// B, H, L0, S
{{-1, 8, -1, 64}, {{129, 8, 0, 64}, {129, 8, 10, 64}, {129, 8, 11, 64}, {129, 8, 12, 64}, {129, 8, 13, 64}}},
},
};

INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ const std::vector<std::vector<InputShape>> inputShapes = {
// B, H, L0, S
{{-1, 8, -1, 64}, {{4, 8, 0, 64}, {4, 8, 10, 64}, {4, 8, 11, 64}, {4, 8, 12, 64}, {4, 8, 13, 64}}},
},
// big batch to check cvt_copy fast-path inside mha_single_token_kernel
{
// B, H, L1, S
{{-1, 8, -1, 64}, {{129, 8, 10, 64}, {129, 8, 1, 64}, {129, 8, 1, 64}, {129, 8, 1, 64}, {129, 8, 1, 64}}},
// B, H, L0, S
{{-1, 8, -1, 64}, {{129, 8, 0, 64}, {129, 8, 10, 64}, {129, 8, 11, 64}, {129, 8, 12, 64}, {129, 8, 13, 64}}},
},
};

INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ const std::vector<std::vector<InputShape>> inputShapes = {
// B, H, L0, S
{{-1, 8, -1, 64}, {{4, 8, 0, 64}, {4, 8, 10, 64}, {4, 8, 11, 64}, {4, 8, 12, 64}, {4, 8, 13, 64}}},
},
// big batch to check cvt_copy fast-path inside mha_single_token_kernel
{
// B, H, L1, S
{{-1, 8, -1, 64}, {{129, 8, 10, 64}, {129, 8, 1, 64}, {129, 8, 1, 64}, {129, 8, 1, 64}, {129, 8, 1, 64}}},
// B, H, L0, S
{{-1, 8, -1, 64}, {{129, 8, 0, 64}, {129, 8, 10, 64}, {129, 8, 11, 64}, {129, 8, 12, 64}, {129, 8, 13, 64}}},
},
};

INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTest,
Expand Down

0 comments on commit b27eefb

Please sign in to comment.