diff --git a/cmake/developer_package/compile_flags/os_flags.cmake b/cmake/developer_package/compile_flags/os_flags.cmake
index fdfd7211c8e815..660fd6160893ae 100644
--- a/cmake/developer_package/compile_flags/os_flags.cmake
+++ b/cmake/developer_package/compile_flags/os_flags.cmake
@@ -4,6 +4,7 @@
 
 include(ProcessorCount)
 include(CheckCXXCompilerFlag)
+include(CheckCXXSourceCompiles)
 
 #
 # ov_disable_deprecated_warnings()
@@ -91,6 +92,49 @@ macro(ov_dev_package_no_errors)
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${ov_c_cxx_dev_no_errors}")
 endmacro()
 
+#
+# ov_check_compiler_supports_sve(flags)
+#
+# Checks whether the CXX compiler supports SVE code compilation with the passed flags
+#
+macro(ov_check_compiler_supports_sve flags)
+    # Code to compile
+    set(SVE_CODE "
+    #include <arm_sve.h>
+    int main() {
+        svfloat64_t a;
+        a = svdup_n_f64(0);
+        return 0;
+    }")
+
+    # Save the current state of required flags
+    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
+
+    # Set the flags necessary for compiling the test code with SVE support
+    set(CMAKE_REQUIRED_FLAGS "${CMAKE_CXX_FLAGS_INIT} ${flags}")
+
+    # Check if the source code compiles with the given flags for C++
+    CHECK_CXX_SOURCE_COMPILES("${SVE_CODE}" CXX_HAS_SVE)
+
+    # If the compilation test is successful, set appropriate variables indicating support
+    if(CXX_HAS_SVE)
+        set(CXX_SVE_FOUND TRUE CACHE BOOL "CXX SVE support")
+        set(CXX_SVE_FLAGS "${flags}" CACHE STRING "CXX SVE flags")
+    endif()
+
+    # Restore the original state of required flags
+    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
+
+    # If the compilation test fails, indicate that the support is not found
+    if(NOT CXX_SVE_FOUND)
+        set(CXX_SVE_FOUND FALSE CACHE BOOL "CXX SVE support")
+        set(CXX_SVE_FLAGS "" CACHE STRING "CXX SVE flags")
+    endif()
+
+    # Mark the variables as advanced to hide them in the default CMake GUI
+    mark_as_advanced(CXX_SVE_FOUND CXX_SVE_FLAGS)
+endmacro()
+
 #
 # ov_sse42_optimization_flags()
 #
@@ -208,6 +252,49 @@ macro(ov_arm_neon_fp16_optimization_flags flags)
     endif()
 endmacro()
 
+#
+# ov_arm_sve_optimization_flags()
+#
+macro(ov_arm_sve_optimization_flags flags)
+    # Check for compiler SVE support
+    ov_check_compiler_supports_sve("-march=armv8-a+sve")
+
+    if(OV_COMPILER_IS_INTEL_LLVM)
+        message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}")
+    elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
+        # nothing should be required here
+    elseif(ANDROID)
+        if(ANDROID_ABI STREQUAL "arm64-v8a")
+            set(${flags} -Wno-unused-command-line-argument)
+            if(CXX_SVE_FOUND)
+                list(APPEND ${flags} -march=armv8-a+sve)
+            else()
+                message(WARNING "SVE is not supported by the CXX compiler for Android ABI: ${ANDROID_ABI}")
+            endif()
+        else()
+            message(WARNING "SVE is not supported on this Android ABI: ${ANDROID_ABI}")
+        endif()
+    else()
+        if(AARCH64)
+            set(${flags} -O2)
+
+            # Add flag for SVE if supported
+            if(CXX_SVE_FOUND)
+                list(APPEND ${flags} -march=armv8-a+sve)
+            endif()
+            if(NOT CMAKE_CL_64)
+                list(APPEND ${flags} -ftree-vectorize)
+            endif()
+
+            set(${flags} ${${flags}})
+        elseif(ARM)
+            message(WARNING "SVE is not supported on 32-bit ARM architectures.")
+        else()
+            message(WARNING "SVE is not supported by architecture ${CMAKE_SYSTEM_PROCESSOR}")
+        endif()
+    endif()
+endmacro()
+
 #
 # ov_disable_all_warnings()
 #
diff --git a/cmake/developer_package/cross_compile/cross_compiled_disp_gen.cmake b/cmake/developer_package/cross_compile/cross_compiled_disp_gen.cmake
index c33d64635eb10b..fd534f3e600bfe 100644
--- a/cmake/developer_package/cross_compile/cross_compiled_disp_gen.cmake
+++ b/cmake/developer_package/cross_compile/cross_compiled_disp_gen.cmake
@@ -18,6 +18,7 @@ set(_CPU_CHECK_ANY "true")
 set(_CPU_CHECK_SSE42 "with_cpu_x86_sse42()")
 set(_CPU_CHECK_AVX "with_cpu_x86_avx()")
 set(_CPU_CHECK_NEON_FP16 "with_cpu_neon_fp16()")
+set(_CPU_CHECK_SVE "with_cpu_sve()")
 set(_CPU_CHECK_AVX2 "with_cpu_x86_avx2()")
 set(_CPU_CHECK_AVX512F "with_cpu_x86_avx512f()")
 
diff --git a/cmake/developer_package/cross_compile/cross_compiled_func.cmake b/cmake/developer_package/cross_compile/cross_compiled_func.cmake
index 1e92fe3bfdaf8c..962aa5d373a4db 100644
--- a/cmake/developer_package/cross_compile/cross_compiled_func.cmake
+++ b/cmake/developer_package/cross_compile/cross_compiled_func.cmake
@@ -3,7 +3,7 @@
 #
 
 ## list of available instruction sets
-set(_ARCH_LIST ANY SSE42 AVX AVX2 AVX512F NEON_FP16)
+set(_ARCH_LIST ANY SSE42 AVX AVX2 AVX512F NEON_FP16 SVE)
 
 set(_ACCEPTED_ARCHS_ANY "^(ANY)$")
 set(_ACCEPTED_ARCHS_SSE42 "^(ANY|SSE42)$")
@@ -11,6 +11,7 @@ set(_ACCEPTED_ARCHS_AVX "^(ANY|SSE42|AVX)$")
 set(_ACCEPTED_ARCHS_AVX2 "^(ANY|SSE42|AVX|AVX2)$")
 set(_ACCEPTED_ARCHS_AVX512F "^(ANY|SSE42|AVX|AVX2|AVX512F)$")
 set(_ACCEPTED_ARCHS_NEON_FP16 "^(ANY|NEON_FP16)$")
+set(_ACCEPTED_ARCHS_SVE "^(ANY|SVE)$")
 
 ## Arch specific definitions
 set(_DEFINE_ANY "")
@@ -19,12 +20,14 @@ set(_DEFINE_AVX "HAVE_AVX" ${_DEFINE_SSE42})
 set(_DEFINE_AVX2 "HAVE_AVX2" ${_DEFINE_AVX})
 set(_DEFINE_AVX512F "HAVE_AVX512F" ${_DEFINE_AVX2})
 set(_DEFINE_NEON_FP16 "HAVE_NEON_FP16" ${_DEFINE_ANY})
+set(_DEFINE_SVE "HAVE_SVE" ${_DEFINE_ANY})
 
 ## Arch specific compile options
 ov_avx512_optimization_flags(_FLAGS_AVX512F)
 ov_avx2_optimization_flags (_FLAGS_AVX2)
 ov_sse42_optimization_flags (_FLAGS_SSE42)
 ov_arm_neon_fp16_optimization_flags(_FLAGS_NEON_FP16)
+ov_arm_sve_optimization_flags(_FLAGS_SVE)
 
 set(_FLAGS_AVX "")  ## TBD is not defined for OV project yet
 set(_FLAGS_ANY "")  ##
@@ -185,6 +188,8 @@ function(_currently_requested_top_arch VAR)
     if(ENABLE_NEON_FP16)
         set(RES NEON_FP16)
+    elseif(ENABLE_SVE)
+        set(RES SVE)
     elseif(ENABLE_AVX512F)
         set(RES AVX512F)
     elseif(ENABLE_AVX2)
diff --git a/cmake/developer_package/features.cmake b/cmake/developer_package/features.cmake
index 8d1f3696c6759c..ae5313cea8a8b4 100644
--- a/cmake/developer_package/features.cmake
+++ b/cmake/developer_package/features.cmake
@@ -51,6 +51,8 @@ ov_dependent_option (ENABLE_AVX512F "Enable AVX512 optimizations" ON "X86_64 OR
 
 ov_dependent_option(ENABLE_NEON_FP16 "Enable ARM FP16 optimizations" ON "AARCH64" OFF)
 
+ov_dependent_option(ENABLE_SVE "Enable SVE optimizations" ON "AARCH64" OFF)
+
 # Type of build, we add this as an explicit option to default it to ON
 get_property(BUILD_SHARED_LIBS_DEFAULT GLOBAL PROPERTY TARGET_SUPPORTS_SHARED_LIBS)
 ov_option (BUILD_SHARED_LIBS "Build as a shared library" ${BUILD_SHARED_LIBS_DEFAULT})
diff --git a/src/inference/dev_api/openvino/runtime/system_conf.hpp b/src/inference/dev_api/openvino/runtime/system_conf.hpp
index 59d56dfdd49d73..bebc2014ab8028 100644
--- a/src/inference/dev_api/openvino/runtime/system_conf.hpp
+++ b/src/inference/dev_api/openvino/runtime/system_conf.hpp
@@ -83,6 +83,13 @@ OPENVINO_RUNTIME_API bool with_cpu_x86_sse42();
  */
 OPENVINO_RUNTIME_API bool with_cpu_neon_fp16();
 
+/**
+ * @brief Checks whether CPU supports ARM SVE capability
+ * @ingroup ov_dev_api_system_conf
+ * @return `True` if ARM SVE instructions are available, `false` otherwise
+ */
+OPENVINO_RUNTIME_API bool with_cpu_sve();
+
 /**
  * @brief Checks whether CPU supports AVX capability
 * @ingroup ov_dev_api_system_conf
diff --git a/src/inference/src/system_conf.cpp b/src/inference/src/system_conf.cpp
index 27c671d07ad5c9..3227b1a3034903 100644
--- a/src/inference/src/system_conf.cpp
+++ b/src/inference/src/system_conf.cpp
@@ -22,6 +22,7 @@
 #    include
 #    define ARM_COMPUTE_CPU_FEATURE_HWCAP_FPHP (1 << 9)
 #    define ARM_COMPUTE_CPU_FEATURE_HWCAP_ASIMDHP (1 << 10)
+#    define ARM_COMPUTE_CPU_FEATURE_HWCAP_SVE (1 << 22)
 #elif defined(__APPLE__) && defined(__aarch64__)
 #    include
 #    include
@@ -114,6 +115,10 @@ bool with_cpu_neon_fp16() {
     return false;
 }
 
+bool with_cpu_sve() {
+    return false;
+}
+
 #else  // OPENVINO_ARCH_X86 || OPENVINO_ARCH_X86_64
 
 bool with_cpu_x86_sse42() {
@@ -173,6 +178,20 @@ bool with_cpu_neon_fp16() {
     return false;
 #    endif
 }
+bool with_cpu_sve() {
+#    if !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \
+        !defined(__arm__) && defined(__aarch64__)
+    const uint32_t hwcaps = getauxval(AT_HWCAP);
+    return hwcaps & ARM_COMPUTE_CPU_FEATURE_HWCAP_SVE;
+#    elif !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \
+        !defined(__aarch64__) && defined(__arm__)
+    return false;
+#    elif defined(__aarch64__) && defined(__APPLE__)
+    return false;
+#    else
+    return false;
+#    endif
+}
 
 #endif  // OPENVINO_ARCH_X86 || OPENVINO_ARCH_X86_64
 
 bool check_open_mp_env_vars(bool include_omp_num_threads) {
diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt
index eb56a3fb39503e..aa6ce49a051e00 100644
--- a/src/plugins/intel_cpu/CMakeLists.txt
+++ b/src/plugins/intel_cpu/CMakeLists.txt
@@ -278,6 +278,24 @@ target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $)
 
+# ARCH lists for softmax.cpp and mha_single_token.cpp
+# Based on result of above calls, decide whether to add SVE
+set(SOFTMAX_ARCH_LIST AVX512F AVX2)
+set(MHA_SINGLE_TOKEN_ARCH_LIST AVX512F AVX2)
+
+if(ENABLE_NEON_FP16)
+    list(APPEND SOFTMAX_ARCH_LIST NEON_FP16)
+    list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST NEON_FP16)
+endif()
+
+if(ENABLE_SVE)
+    list(APPEND SOFTMAX_ARCH_LIST SVE)
+    list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST SVE)
+endif()
+
+list(APPEND SOFTMAX_ARCH_LIST ANY)
+list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST ANY)
+
 # Cross compiled function
 # TODO: The same for proposal, proposalONNX, topk
 cross_compiled_file(${TARGET_NAME}
@@ -288,14 +306,14 @@ cross_compiled_file(${TARGET_NAME}
         NAMESPACE   ov::Extensions::Cpu::XARCH
 )
 cross_compiled_file(${TARGET_NAME}
-        ARCH        AVX512F AVX2 NEON_FP16 ANY
+        ARCH        ${SOFTMAX_ARCH_LIST}
                     src/nodes/kernels/scaled_attn/softmax.cpp
         API         src/nodes/kernels/scaled_attn/softmax.hpp
         NAME        attn_softmax
         NAMESPACE   ov::Extensions::Cpu::XARCH
 )
 cross_compiled_file(${TARGET_NAME}
-        ARCH        AVX512F AVX2 NEON_FP16 ANY
+        ARCH        ${MHA_SINGLE_TOKEN_ARCH_LIST}
                     src/nodes/kernels/scaled_attn/mha_single_token.cpp
         API         src/nodes/kernels/scaled_attn/mha_single_token.hpp
         NAME        mha_single_token
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
index 4e14cf5894b04d..63cbbb4464ee92 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
@@ -17,6 +17,9 @@
 #endif
 
 #if defined(OPENVINO_ARCH_ARM64)
+#    if defined(HAVE_SVE)
+#        include "arm_sve.h"
+#    endif
 #    include "arm_neon.h"
 #endif
 
@@ -35,6 +38,10 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float);
 static constexpr size_t vec_len_f32_neon = vec_len_neon / sizeof(float);
 static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);
+#if defined(HAVE_SVE)
+static constexpr size_t vec_len_f32_sve = svcntw();
+#endif
+
 #ifdef HAVE_AVX512F
 inline __m512 cvt_bf16_to_fp32(const __m256i src) {
     __m512i y = _mm512_cvtepu16_epi32(src);
@@ -250,6 +257,79 @@ inline void hmin(__m256& x) {
 #endif
 
 #ifdef OPENVINO_ARCH_ARM64
+#    if defined(HAVE_SVE)
+inline svfloat32_t exp_ps_sve(svbool_t& pg, svfloat32_t& src) {
+    // Constants
+    const auto log2_e = svdup_n_f32(1.4426950409f);
+    const auto ln2 = svdup_n_f32(0.6931473921f);
+    const auto half_ln2_sq = svdup_n_f32(0.2413862043f);
+    const auto not_mask17 = svdup_n_u32(~((1u << 17) - 1));
+    const auto one = svdup_n_f32(1.0f);
+
+    // Algorithm starts here
+    svfloat32_t t0 = svmul_f32_z(pg, src, log2_e);  // y = x * log2(e)
+    svfloat32_t t1 = svrintm_f32_z(pg, t0);         // round to int (float)
+    svint32_t t2 = svcvt_s32_f32_z(pg, t1);         // n
+
+    t1 = svsub_f32_z(pg, t0, t1);   // a = y - floor(y)
+    t1 = svadd_f32_z(pg, t1, one);  // b = a + 1
+
+    svuint32_t t3 = svlsr_n_u32_z(pg, svreinterpret_u32_f32(t1), 17);  // v = b >> 17 (u32)
+    svfloat32_t t4 = svexpa_f32(t3);                                   // c = fexpa(v)
+    t4 = svscale_f32_z(pg, t4, t2);                                    // fexpa(v) * 2^(n)
+
+    // and_(t2.d, t1.d, not_mask17.d)
+    svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_z(pg, svreinterpret_u32_f32(t1), not_mask17));
+    t5 = svsub_f32_z(pg, t1, t5);                // z
+    t0 = svmla_f32_z(pg, ln2, t5, half_ln2_sq);  // ln2 + half_ln2_sq * z
+    t0 = svmla_f32_z(pg, one, t5, t0);           // 1 + (ln2 * z) + (half_ln2_sq * z * z)
+    t0 = svmul_f32_z(pg, t0, t4);                // Final result
+
+    return t0;
+}
+inline svfloat32_t exp_ps_sve_legacy(svbool_t& pg, svfloat32_t& src) {
+    const auto c1 = svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6));
+    const auto c2 = svreinterpret_f32_u32(svdup_n_u32(0x3efffedb));
+    const auto c3 = svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33));
+    const auto c4 = svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17));
+    const auto c5 = svreinterpret_f32_u32(svdup_n_u32(0x3c072010));
+
+    const auto shift = svreinterpret_f32_u32(svdup_n_u32(0x4b00007f));  // 2^23 + 127 = 0x1.0000fep23f
+    const auto one = svdup_n_f32(1.0f);                                 // 1
+    const auto two = svdup_n_f32(2.0f);                                 // 2
+    const auto inv_ln2 = svreinterpret_f32_u32(svdup_n_u32(0x3fb8aa3b));
+    const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(0xbf317200));
+    const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(0xb5bfbe8e));
+
+    const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
+    const auto max_input = svdup_n_f32(88.37f);   // Approximately ln(2^127.5)
+    const auto zero = svdup_n_f32(0.f);
+    const auto min_input = svdup_n_f32(-86.64f);  // Approximately ln(2^-125)
+
+    const auto z = svmla_f32_z(pg, shift, src, inv_ln2);
+    auto n = svsub_f32_z(pg, z, shift);
+    n = svsub_f32_z(pg, n, one);
+    const auto scale = svreinterpret_f32_u32(svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23));  // 2^n
+
+    const auto r_hi = svmla_f32_z(pg, src, n, neg_ln2_hi);
+    const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
+    const auto r2 = svmul_f32_z(pg, r, r);
+
+    const auto p1 = svmul_f32_z(pg, c1, r);
+    const auto p23 = svmla_f32_z(pg, c2, c3, r);
+    const auto p45 = svmla_f32_z(pg, c4, c5, r);
+    const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
+    const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);
+
+    auto poly = svmla_f32_z(pg, scale, p12345, scale);
+    poly = svmul_f32_z(pg, poly, two);
+
+    poly = svsel_f32(svcmplt_f32(pg, src, min_input), zero, poly);
+    poly = svsel_f32(svcmpgt_f32(pg, src, max_input), inf, poly);
+
+    return poly;
+}
+#    endif
 inline float32x4_t exp_ps_neon_f32(const float32x4_t& src) {
     const auto c1 = vreinterpretq_f32_u32(vdupq_n_u32(0x3f7ffff6));
     const auto c2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3efffedb));
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
index f2180b5314cc07..5a6f0d66f1f221 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
@@ -4,6 +4,7 @@
 #include
 
 #include
+#include
 #include
 #include
 #include
@@ -20,6 +21,9 @@
 #include "softmax_kernel.hpp"
 
 #if defined(OPENVINO_ARCH_ARM64)
+#    if defined(HAVE_SVE)
+#        include <arm_sve.h>
+#    endif
 #    include <arm_neon.h>
 #endif
 
@@ -59,19 +63,35 @@ void cvt_copy(TA* dst, TB* src, size_t n) {
         mm256_uni_storeu_ps(dst + i, vb);
     }
 #elif defined(OPENVINO_ARCH_ARM64)
+#    if defined(HAVE_SVE)
+    auto _dst = reinterpret_cast<float*>(dst);
+    size_t inc = vec_len_f32_sve;
+    svbool_t pg = svptrue_b32();
+
+    while (i < n) {
+        if (n - i < vec_len_f32_sve) {
+            inc = n - i;
+            pg = svwhilelt_b32(0, static_cast<int>(inc));
+        }
+        svfloat32_t b1 = svld1_f32(pg, src + i);
+        svst1_f32(pg, _dst + i, b1);
+        i += inc;
+    }
+#    else
     if (std::is_same<TA, float>::value && std::is_same<TB, float>::value) {
         for (; i + vec_len_f32_neon <= n; i += vec_len_f32_neon) {
             float32x4_t vb1 = __vld1q_f32(src + i);
             __vst1q_f32(dst + i, vb1);
         }
     }
-#    if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#        if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
     if (std::is_same<TA, ov::float16>::value && std::is_same<TB, ov::float16>::value) {
         for (; i + vec_len_f16_neon <= n; i += vec_len_f16_neon) {
             auto vb1 = vld1q_f16(reinterpret_cast<const float16_t*>(src + i));
             vst1q_f16(reinterpret_cast<float16_t*>(dst + i), vb1);
         }
     }
+#        endif
 #    endif
 #endif
     for (; i < n; i++) {
@@ -99,6 +119,27 @@ static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scal
         mm256_uni_storeu_ps(out + i, v_out);
     }
 #elif defined(OPENVINO_ARCH_ARM64)
+#    if defined(HAVE_SVE)
+    auto _v = reinterpret_cast<float*>(v);
+    svfloat32_t attn_w_vec_fp32 = svdup_n_f32(weight);
+    size_t inc = vec_len_f32_sve;
+    svbool_t pg = svptrue_b32();
+
+    while (i < S) {
+        if (S - i < vec_len_f32_sve) {
+            inc = S - i;
+            pg = svwhilelt_b32(0, static_cast<int>(inc));
+        }
+        svfloat32_t v_value = svld1_f32(pg, _v + i);
+        svfloat32_t v_out = svld1_f32(pg, out + i);
+
+        // svmla with merging to preserve inactive lane values when there are
+        // fewer than vec_len elements left
+        v_out = svmla_f32_m(pg, v_out, attn_w_vec_fp32, v_value);
+        svst1_f32(pg, out + i, v_out);
+        i += inc;
+    }
+#    else
     float32x4_t attn_w_vec_fp32 = vdupq_n_f32(weight);
     for (; i + vec_len_f32_neon <= S; i += vec_len_f32_neon) {
         float32x4_t v_value = __vld1q_f32(v + i);
@@ -106,6 +147,7 @@
         v_out = vmlaq_f32(v_out, attn_w_vec_fp32, v_value);
         __vst1q_f32(out + i, v_out);
     }
+#    endif
 #endif
     for (; i < S; i++) {
         out[i] += weight * v[i];
     }
@@ -356,7 +398,50 @@ static float sum_q_head(T* a, size_t n) {
     hsum(vsum0);
     sum = _mm256_cvtss_f32(vsum0);
 #elif defined(OPENVINO_ARCH_ARM64)
-    size_t vec_len_f32_neon = 4;
+#    if defined(HAVE_SVE)
+    svfloat32_t sum0 = svdup_n_f32(0.0f);
+    svfloat32_t sum1 = svdup_n_f32(0.0f);
+    svfloat32_t sum2 = svdup_n_f32(0.0f);
+    svfloat32_t sum3 = svdup_n_f32(0.0f);
+    svbool_t pg = svptrue_b32();
+
+    for (; i + 4 * vec_len_f32_sve <= n; i += 4 * vec_len_f32_sve) {
+        svfloat32_t a0 = svld1_f32(pg, a + i);
+        svfloat32_t a1 = svld1_f32(pg, a + i + vec_len_f32_sve);
+        svfloat32_t a2 = svld1_f32(pg, a + i + vec_len_f32_sve * 2);
+        svfloat32_t a3 = svld1_f32(pg, a + i + vec_len_f32_sve * 3);
+
+        sum0 = svadd_f32_z(pg, a0, sum0);
+        sum1 = svadd_f32_z(pg, a1, sum1);
+        sum2 = svadd_f32_z(pg, a2, sum2);
+        sum3 = svadd_f32_z(pg, a3, sum3);
+    }
+    if (i + 2 * vec_len_f32_sve <= n) {
+        svfloat32_t a0 = svld1_f32(pg, a + i);
+        svfloat32_t a1 = svld1_f32(pg, a + i + vec_len_f32_sve);
+
+        sum0 = svadd_f32_z(pg, a0, sum0);
+        sum1 = svadd_f32_z(pg, a1, sum1);
+        i += 2 * vec_len_f32_sve;
+    }
+    if (i + vec_len_f32_sve <= n) {
+        svfloat32_t a0 = svld1_f32(pg, a + i);
+        sum0 = svadd_f32_z(pg, a0, sum0);
+        i += vec_len_f32_sve;
+    }
+    // Process tail elements in parallel as well (if any)
+    if (i != n) {
+        svbool_t pg_rem = svwhilelt_b32(0, static_cast<int>(n - i));
+        svfloat32_t a0 = svld1_f32(pg_rem, a + i);
+        sum0 = svadd_f32_m(pg_rem, sum0, a0);
+        i = n;
+    }
+    float32_t sum_0 = svaddv_f32(pg, sum0);
+    float32_t sum_1 = svaddv_f32(pg, sum1);
+    float32_t sum_2 = svaddv_f32(pg, sum2);
+    float32_t sum_3 = svaddv_f32(pg, sum3);
+    sum = static_cast<float>(sum_0 + sum_1 + sum_2 + sum_3);
+#    else
     float32x4_t vsum0 = vdupq_n_f32(0.0f);
     float32x4_t vsum1 = vdupq_n_f32(0.0f);
     float32x4_t vsum2 = vdupq_n_f32(0.0f);
@@ -396,8 +481,8 @@ static float sum_q_head(T* a, size_t n) {
     sum_low = vadd_f32(sum_low, sum_high);
     sum_low = vpadd_f32(sum_low, sum_low);
     sum = vget_lane_f32(sum_low, 0);
+#    endif
 #endif
-
     for (; i < n; i++) {
         float tmp = a[i];
         sum += tmp;
@@ -496,6 +581,63 @@ static float dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float*
     sum = _mm256_cvtss_f32(vsum0);
 #elif defined(OPENVINO_ARCH_ARM64)
+#    if defined(HAVE_SVE)
+    svbool_t pg = svptrue_b32();
+    svfloat32_t sum0 = svdup_n_f32(0.0f);
+    svfloat32_t sum1 = svdup_n_f32(0.0f);
+    svfloat32_t sum2 = svdup_n_f32(0.0f);
+    svfloat32_t sum3 = svdup_n_f32(0.0f);
+
+    auto _a = reinterpret_cast<float*>(a);
+    auto _b = reinterpret_cast<float*>(b);
+
+    for (; i + 4 * vec_len_f32_sve <= n; i += 4 * vec_len_f32_sve) {
+        svfloat32_t a0 = svld1_f32(pg, _a + i);
+        svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len_f32_sve);
+        svfloat32_t a2 = svld1_f32(pg, _a + i + vec_len_f32_sve * 2);
+        svfloat32_t a3 = svld1_f32(pg, _a + i + vec_len_f32_sve * 3);
+
+        svfloat32_t b0 = svld1_f32(pg, _b + i);
+        svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len_f32_sve);
+        svfloat32_t b2 = svld1_f32(pg, _b + i + vec_len_f32_sve * 2);
+        svfloat32_t b3 = svld1_f32(pg, _b + i + vec_len_f32_sve * 3);
+
+        sum0 = svmla_f32_z(pg, sum0, a0, b0);
+        sum1 = svmla_f32_z(pg, sum1, a1, b1);
+        sum2 = svmla_f32_z(pg, sum2, a2, b2);
+        sum3 = svmla_f32_z(pg, sum3, a3, b3);
+    }
+    if (i + 2 * vec_len_f32_sve <= n) {
+        svfloat32_t a0 = svld1_f32(pg, _a + i);
+        svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len_f32_sve);
+
+        svfloat32_t b0 = svld1_f32(pg, _b + i);
+        svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len_f32_sve);
+
+        sum0 = svmla_f32_z(pg, sum0, a0, b0);
+        sum1 = svmla_f32_z(pg, sum1, a1, b1);
+        i += 2 * vec_len_f32_sve;
+    }
+    if (i + vec_len_f32_sve <= n) {
+        svfloat32_t a0 = svld1_f32(pg, _a + i);
+        svfloat32_t b0 = svld1_f32(pg, _b + i);
+        sum0 = svmla_f32_z(pg, sum0, a0, b0);
+        i += vec_len_f32_sve;
+    }
+    // Process the tail elements in parallel as well (if any)
+    if (i != n) {
+        svbool_t pg_rem = svwhilelt_b32(0, static_cast<int>(n - i));
+        svfloat32_t a0 = svld1_f32(pg_rem, _a + i);
+        svfloat32_t b0 = svld1_f32(pg_rem, _b + i);
+        sum0 = svmla_f32_m(pg_rem, sum0, a0, b0);
+        i = n;
+    }
+    float32_t sum_0 = svaddv_f32(pg, sum0);
+    float32_t sum_1 = svaddv_f32(pg, sum1);
+    float32_t sum_2 = svaddv_f32(pg, sum2);
+    float32_t sum_3 = svaddv_f32(pg, sum3);
+    sum = static_cast<float>(sum_0 + sum_1 + sum_2 + sum_3);
+#    else
     float32x4_t vsum0 = vdupq_n_f32(0.0f);
     float32x4_t vsum1 = vdupq_n_f32(0.0f);
     float32x4_t vsum2 = vdupq_n_f32(0.0f);
@@ -542,8 +684,8 @@ static float dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float*
     float32x2_t temp_sum = vadd_f32(vget_low_f32(vsum0), vget_high_f32(vsum0));
     temp_sum = vpadd_f32(temp_sum, temp_sum);
     sum = vget_lane_f32(temp_sum, 0);
+#    endif
 #endif
-
     for (; i < n; i++) {
         sum += a[i] * b[i];
     }
@@ -794,6 +936,28 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str
         mm256_uni_storeu_ps(dst + i, result_vec_fp32);
     }
 #elif defined(OPENVINO_ARCH_ARM64)
+#    if defined(HAVE_SVE)
+    auto _dst = reinterpret_cast<float*>(dst);
+    size_t inc = vec_len_f32_sve;
+    svbool_t pg = svptrue_b32();
+
+    while (i < S) {
+        if (S - i < vec_len_f32_sve) {
+            inc = S - i;
+            pg = svwhilelt_b32(0, static_cast<int>(inc));
+        }
+        auto* src = temp + i;
+        auto result_vec_fp32 = svdup_n_f32(0.0f);
+
+        for (size_t m = 0; m < M; m++) {
+            auto o_vec_fp32 = svld1_f32(pg, src);
+            result_vec_fp32 = svadd_f32_m(pg, result_vec_fp32, o_vec_fp32);
+            src += temp_stride;
+        }
+        svst1_f32(pg, _dst + i, result_vec_fp32);
+        i += inc;
+    }
+#    else
     for (; i + vec_len_f32_neon <= S; i += vec_len_f32_neon) {
         auto* src = temp + i;
         auto result_vec_fp32 = vdupq_n_f32(0.0f);
@@ -804,6 +968,7 @@
         }
         __vst1q_f32(dst + i, result_vec_fp32);
     }
+#    endif
 #endif
     for (; i < S; i++) {
         auto* src = temp + i;
@@ -1262,6 +1427,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
         OPENVINO_THROW("Unsupported precision: ", query.get_precision());
     }
 }
+
 }  // namespace XARCH
 }  // namespace Cpu
 }  // namespace Extensions
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp
index 48b92b53fa2727..35aab5b59c7d0e 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp
@@ -12,6 +12,9 @@
 #include "openvino/core/type/element_type.hpp"
 
 #if defined(OPENVINO_ARCH_ARM64)
+#    if defined(HAVE_SVE)
+#        include "arm_sve.h"
+#    endif
 #    include "arm_neon.h"
 #endif
 
@@ -657,6 +660,28 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float&
     hsum(v_sum);
     sum = _mm256_cvtss_f32(v_sum);
 #elif defined(OPENVINO_ARCH_ARM64)
+#    if defined(HAVE_SVE)
+    svfloat32_t v_a;
+    svfloat32_t v_max = svdup_n_f32(max);
+    svfloat32_t v_sum = svdup_n_f32(0.0f);
+    size_t vec_len_f32_sve = svcntw();
+    size_t inc = vec_len_f32_sve;
+    svbool_t pg = svptrue_b32();
+
+    while (i < size) {
+        if (size - i < vec_len_f32_sve) {
+            inc = size - i;
+            pg = svwhilelt_b32(0, static_cast<int>(inc));
+        }
+        v_a = svld1_f32(pg, a + i);
+        v_a = svsub_f32_z(pg, v_a, v_max);
+        v_a = exp_ps_sve(pg, v_a);
+        v_sum = svadd_f32_m(pg, v_sum, v_a);
+        svst1_f32(pg, a + i, v_a);
+        i += inc;
+    }
+    sum = svaddv_f32(svptrue_b32(), v_sum);
+#    else
     float32x4_t v_a;
     float32x4_t v_max = vdupq_n_f32(max);
     float32x4_t v_sum = vdupq_n_f32(0.0f);
@@ -670,7 +695,7 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float&
         i += vec_len_f32_neon;
     }
     sum = vaddvq_f32(v_sum);
-
+#    endif
 #endif
     for (; i < size; i++) {
         a[i] = exp(a[i] - max);
@@ -781,6 +806,22 @@ inline void multiply_scalar(float* a, float* a_dst, const float val, const size_
         i += (size - i);
     }
 #elif defined(OPENVINO_ARCH_ARM64)
+#    if defined(HAVE_SVE)
+    svfloat32_t v_scale = svdup_n_f32(val);
+    size_t inc = vec_len_f32_sve;
+    svbool_t pg = svptrue_b32();
+
+    while (i < size) {
+        if (size - i < vec_len_f32_sve) {
+            inc = size - i;
+            pg = svwhilelt_b32(0, static_cast<int>(inc));
+        }
+        svfloat32_t v_a = svld1_f32(pg, a + i);
+        v_a = svmul_f32_z(pg, v_a, v_scale);
+        svst1_f32(pg, a_dst + i, v_a);
+        i += inc;
+    }
+#    else
     float32x4_t v_scale = vdupq_n_f32(val);
     while (i + vec_len_f32_neon <= size) {
         float32x4_t v_a = vld1q_f32(a + i);
@@ -788,6 +829,7 @@
         vst1q_f32(a_dst + i, v_a);
         i += vec_len_f32_neon;
     }
+#    endif
 #endif
     for (; i < size; i++) {
         a_dst[i] = a[i] * val;
@@ -972,7 +1014,7 @@ inline void attn_softmax_kernel(float* a,
         // divide sum
         float scalar = 1.0f / sum;
         if (dst_precision == ov::element::f32) {
-            multiply_scalar(a, static_cast<float*>(a_dst), scalar, len);
+            multiply_scalar(a, reinterpret_cast<float*>(a_dst), scalar, len);
             // apply causual mask to final result instead of attn_score
             if (total_size > len)
                 memset(static_cast<float*>(a_dst) + len, 0, sizeof(float) * (total_size - len));