Skip to content

Commit

Permalink
[ARM] [SDPA] SVE implementation of MHASingleToken for FP32 (#27273)
Browse files Browse the repository at this point in the history
### Details:
- Adds SVE FP32 implementations for functions called during execution of
`MHASingleToken` for SVE-128, SVE-256 and SVE-512 platforms.
- SVE implementations are compiled only if runtime support for SVE is
detected on the hardware, otherwise it falls back to Neon.
- Adds a new implementation for exponential function `exp_ps_<isa>`
using fewer FMA operations. Executes ~18% faster and has better output
precision.

**Note:** I am aware of the Neon FP16 implementation of SDPA added
recently. To accommodate for this, the current SVE changes will be used
only if the hardware does not have ARM FP16 support. I will follow up
with SVE FP16 implementations soon.

### [SVE] Benchmarking results
Below are the benchmarking results of execution time of each ported
function. Measurements were performed by running each function
individually on dummy inputs (128 fp32 elements) for 1,000,000
iterations and computing average time (in micro-seconds).


![image](https://github.com/user-attachments/assets/3f82238f-af7e-4b68-b4b1-259cf389e41a)

Execution time of `MHASingleToken` as a whole was also measured for two
LLMs, the results of which are shown below. For LlaMA-3-8B, the SVE-128
and SVE-512 systems at my disposal did not have enough memory, so only
SVE-256 results are shown. While there is an improvement overall, these
results could be contaminated with run-to-run variation due to the small
execution time of the kernel.

**Benchmarking details:** Prompt length of 108 tokens was used; total
time for generating 50 tokens was measured and average execution time
was computed.


![image](https://github.com/user-attachments/assets/893c1c46-085f-46af-ab5a-2c1481c75f68)

### New exponential implementation
It is based on the discussion in [these
slides](https://www.slideshare.net/slideshow/hpc-phys20201203/239717194#23)
(this is based on a past talk in Fujitsu hence the document is in
Japanese, sorry!). The algorithm followed is slightly different from the
current implementation, in that it uses `fexpa` instruction available on
ARM and requires only 3 Taylor expansion terms (2 FMA operations) to be
precise until the 8th decimal place.

Our benchmarking results showed this implementation to be 44%-58% faster
than the existing Neon implementation. It is ~18% faster than the SVE
implementation of the current algorithm in Neon.


![image](https://github.com/user-attachments/assets/117df21d-3977-499c-8ab8-8f4346286113)

In this PR, the new implementation is called by default. The SVE port of
the existing Neon implementation has also been retained, if needed.
  • Loading branch information
NishantPrabhuFujitsu authored Dec 17, 2024
1 parent 44f7ddb commit b543d0b
Show file tree
Hide file tree
Showing 10 changed files with 437 additions and 9 deletions.
88 changes: 88 additions & 0 deletions cmake/developer_package/compile_flags/os_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

include(ProcessorCount)
include(CheckCXXCompilerFlag)
include(CheckCXXSourceCompiles)

#
# ov_disable_deprecated_warnings()
Expand Down Expand Up @@ -91,6 +92,50 @@ macro(ov_dev_package_no_errors)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${ov_c_cxx_dev_no_errors}")
endmacro()

#
# ov_check_compiler_supports_sve(flags)
#
# Checks whether CXX compiler for passed language supports SVE code compilation
#
macro(ov_check_compiler_supports_sve flags)
# Code to compile
set(SVE_CODE "
#include <arm_sve.h>
int main() {
svfloat64_t a;
a = svdup_n_f64(0);
return 0;
}")

# Save the current state of required flags
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})

# Set the flags necessary for compiling the test code with SVE support
set(CMAKE_REQUIRED_FLAGS "${CMAKE_CXX_FLAGS_INIT} ${flags}")

# Check if the source code compiles with the given flags for C++
CHECK_CXX_SOURCE_COMPILES("${SVE_CODE}" CXX_HAS_SVE)

# If the compilation test is successful, set appropriate variables indicating support
if(CXX_HAS_SVE)
set(CXX_SVE_FOUND TRUE CACHE BOOL "SVE available on host")
set(CXX_SVE_FOUND TRUE CACHE BOOL "CXX SVE support")
set(CXX_SVE_FLAGS "${flags}" CACHE STRING "CXX SVE flags")
endif()

# Restore the original state of required flags
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})

# If the compilation test fails, indicate that the support is not found
if(NOT CXX_SVE_FOUND)
set(CXX_SVE_FOUND FALSE CACHE BOOL "CXX SVE support")
set(CXX_SVE_FLAGS "" CACHE STRING "CXX SVE flags")
endif()

# Mark the variables as advanced to hide them in the default CMake GUI
mark_as_advanced(CXX_SVE_FOUND CXX_SVE_FLAGS)
endmacro()

#
# ov_sse42_optimization_flags(<output flags>)
#
Expand Down Expand Up @@ -208,6 +253,49 @@ macro(ov_arm_neon_fp16_optimization_flags flags)
endif()
endmacro()

#
# ov_arm_sve_optimization_flags(<output flags>)
#
macro(ov_arm_sve_optimization_flags flags)
# Check for compiler SVE support
ov_check_compiler_supports_sve("-march=armv8-a+sve")

if(OV_COMPILER_IS_INTEL_LLVM)
message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}")
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
# nothing should be required here
elseif(ANDROID)
if(ANDROID_ABI STREQUAL "arm64-v8a")
set(${flags} -Wno-unused-command-line-argument)
if(CXX_SVE_FOUND)
list(APPEND ${flags} -march=armv8-a+sve)
else()
message(WARNING "SVE is not supported on this Android ABI: ${ANDROID_ABI}")
endif()
else()
message(WARNING "SVE is not supported on this Android ABI: ${ANDROID_ABI}")
endif()
else()
if(AARCH64)
set(${flags} -O2)

# Add flag for SVE if supported
if(CXX_SVE_FOUND)
list(APPEND ${flags} -march=armv8-a+sve)
endif()
if(NOT CMAKE_CL_64)
list(APPEND ${flags} -ftree-vectorize)
endif()

set(${flags} ${${flags}})
elseif(ARM)
message(WARNING "SVE is not supported on 32-bit ARM architectures.")
else()
message(WARNING "SVE is not supported by architecture ${CMAKE_SYSTEM_PROCESSOR}")
endif()
endif()
endmacro()

#
# ov_disable_all_warnings(<target1 [target2 target3 ...]>)
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ set(_CPU_CHECK_ANY "true")
set(_CPU_CHECK_SSE42 "with_cpu_x86_sse42()")
set(_CPU_CHECK_AVX "with_cpu_x86_avx()")
set(_CPU_CHECK_NEON_FP16 "with_cpu_neon_fp16()")
set(_CPU_CHECK_SVE "with_cpu_sve()")
set(_CPU_CHECK_AVX2 "with_cpu_x86_avx2()")
set(_CPU_CHECK_AVX512F "with_cpu_x86_avx512f()")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
#

## list of available instruction sets
set(_ARCH_LIST ANY SSE42 AVX AVX2 AVX512F NEON_FP16)
set(_ARCH_LIST ANY SSE42 AVX AVX2 AVX512F NEON_FP16 SVE)

set(_ACCEPTED_ARCHS_ANY "^(ANY)$")
set(_ACCEPTED_ARCHS_SSE42 "^(ANY|SSE42)$")
set(_ACCEPTED_ARCHS_AVX "^(ANY|SSE42|AVX)$")
set(_ACCEPTED_ARCHS_AVX2 "^(ANY|SSE42|AVX|AVX2)$")
set(_ACCEPTED_ARCHS_AVX512F "^(ANY|SSE42|AVX|AVX2|AVX512F)$")
set(_ACCEPTED_ARCHS_NEON_FP16 "^(ANY|NEON_FP16)$")
set(_ACCEPTED_ARCHS_SVE "^(ANY|SVE)$")

## Arch specific definitions
set(_DEFINE_ANY "")
Expand All @@ -19,12 +20,14 @@ set(_DEFINE_AVX "HAVE_AVX" ${_DEFINE_SSE42})
set(_DEFINE_AVX2 "HAVE_AVX2" ${_DEFINE_AVX})
set(_DEFINE_AVX512F "HAVE_AVX512F" ${_DEFINE_AVX2})
set(_DEFINE_NEON_FP16 "HAVE_NEON_FP16" ${_DEFINE_ANY})
set(_DEFINE_SVE "HAVE_SVE" ${_DEFINE_SVE})

## Arch specific compile options
ov_avx512_optimization_flags(_FLAGS_AVX512F)
ov_avx2_optimization_flags (_FLAGS_AVX2)
ov_sse42_optimization_flags (_FLAGS_SSE42)
ov_arm_neon_fp16_optimization_flags(_FLAGS_NEON_FP16)
ov_arm_sve_optimization_flags(_FLAGS_SVE)
set(_FLAGS_AVX "") ## TBD is not defined for OV project yet
set(_FLAGS_ANY "") ##

Expand Down Expand Up @@ -185,6 +188,8 @@ endfunction()
function(_currently_requested_top_arch VAR)
if(ENABLE_NEON_FP16)
set(RES NEON_FP16)
elseif(ENABLE_SVE)
set(RES SVE)
elseif(ENABLE_AVX512F)
set(RES AVX512F)
elseif(ENABLE_AVX2)
Expand Down
2 changes: 2 additions & 0 deletions cmake/developer_package/features.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ ov_dependent_option (ENABLE_AVX512F "Enable AVX512 optimizations" ON "X86_64 OR

ov_dependent_option(ENABLE_NEON_FP16 "Enable ARM FP16 optimizations" ON "AARCH64" OFF)

ov_dependent_option(ENABLE_SVE "Enable SVE optimizations" ON "AARCH64" OFF)

# Type of build, we add this as an explicit option to default it to ON
get_property(BUILD_SHARED_LIBS_DEFAULT GLOBAL PROPERTY TARGET_SUPPORTS_SHARED_LIBS)
ov_option (BUILD_SHARED_LIBS "Build as a shared library" ${BUILD_SHARED_LIBS_DEFAULT})
Expand Down
7 changes: 7 additions & 0 deletions src/inference/dev_api/openvino/runtime/system_conf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ OPENVINO_RUNTIME_API bool with_cpu_x86_sse42();
*/
OPENVINO_RUNTIME_API bool with_cpu_neon_fp16();

/**
* @brief Checks whether CPU supports ARM SVE capability
* @ingroup ov_dev_api_system_conf
* @return `True` if ARM SVE instructions are available, `false` otherwise
*/
OPENVINO_RUNTIME_API bool with_cpu_sve();

/**
* @brief Checks whether CPU supports AVX capability
* @ingroup ov_dev_api_system_conf
Expand Down
19 changes: 19 additions & 0 deletions src/inference/src/system_conf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
# include <sys/auxv.h>
# define ARM_COMPUTE_CPU_FEATURE_HWCAP_FPHP (1 << 9)
# define ARM_COMPUTE_CPU_FEATURE_HWCAP_ASIMDHP (1 << 10)
# define ARM_COMPUTE_CPU_FEATURE_HWCAP_SVE (1 << 24)
#elif defined(__APPLE__) && defined(__aarch64__)
# include <sys/sysctl.h>
# include <sys/types.h>
Expand Down Expand Up @@ -114,6 +115,10 @@ bool with_cpu_neon_fp16() {
return false;
}

bool with_cpu_sve() {
return false;
}

#else // OPENVINO_ARCH_X86 || OPENVINO_ARCH_X86_64

bool with_cpu_x86_sse42() {
Expand Down Expand Up @@ -173,6 +178,20 @@ bool with_cpu_neon_fp16() {
return false;
# endif
}
bool with_cpu_sve() {
# if !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \
!defined(__arm__) && defined(__aarch64__)
const uint32_t hwcaps = getauxval(AT_HWCAP);
return hwcaps & ARM_COMPUTE_CPU_FEATURE_HWCAP_SVE;
# elif !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \
!defined(__aarch64__) && defined(__arm__)
return false;
# elif defined(__aarch64__) && defined(__APPLE__)
return false;
# else
return false;
# endif
}
#endif // OPENVINO_ARCH_X86 || OPENVINO_ARCH_X86_64

bool check_open_mp_env_vars(bool include_omp_num_threads) {
Expand Down
22 changes: 20 additions & 2 deletions src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,24 @@ target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $<TARGET_PROPERTY:dnnl,
# is not (yet) needed.
target_include_directories(${TARGET_NAME} PRIVATE $<TARGET_PROPERTY:openvino::reference,INTERFACE_INCLUDE_DIRECTORIES>)

# ARCH lists for softmax.cpp and mha_single_token.cpp
# Based on result of above calls, decide whether to add SVE
set(SOFTMAX_ARCH_LIST AVX512F AVX2)
set(MHA_SINGLE_TOKEN_ARCH_LIST AVX512F AVX2)

if(ENABLE_NEON_FP16)
list(APPEND SOFTMAX_ARCH_LIST NEON_FP16)
list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST NEON_FP16)
endif()

if(ENABLE_SVE)
list(APPEND SOFTMAX_ARCH_LIST SVE)
list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST SVE)
endif()

list(APPEND SOFTMAX_ARCH_LIST ANY)
list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST ANY)

# Cross compiled function
# TODO: The same for proposal, proposalONNX, topk
cross_compiled_file(${TARGET_NAME}
Expand All @@ -288,14 +306,14 @@ cross_compiled_file(${TARGET_NAME}
NAMESPACE ov::Extensions::Cpu::XARCH
)
cross_compiled_file(${TARGET_NAME}
ARCH AVX512F AVX2 NEON_FP16 ANY
ARCH ${SOFTMAX_ARCH_LIST}
src/nodes/kernels/scaled_attn/softmax.cpp
API src/nodes/kernels/scaled_attn/softmax.hpp
NAME attn_softmax
NAMESPACE ov::Extensions::Cpu::XARCH
)
cross_compiled_file(${TARGET_NAME}
ARCH AVX512F AVX2 NEON_FP16 ANY
ARCH ${MHA_SINGLE_TOKEN_ARCH_LIST}
src/nodes/kernels/scaled_attn/mha_single_token.cpp
API src/nodes/kernels/scaled_attn/mha_single_token.hpp
NAME mha_single_token
Expand Down
80 changes: 80 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#endif

#if defined(OPENVINO_ARCH_ARM64)
# if defined(HAVE_SVE)
# include "arm_sve.h"
# endif
# include "arm_neon.h"
#endif

Expand All @@ -35,6 +38,10 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float);
static constexpr size_t vec_len_f32_neon = vec_len_neon / sizeof(float);
static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);

#if defined(HAVE_SVE)
static constexpr size_t vec_len_f32_sve = svcntw();
#endif

#ifdef HAVE_AVX512F
inline __m512 cvt_bf16_to_fp32(const __m256i src) {
__m512i y = _mm512_cvtepu16_epi32(src);
Expand Down Expand Up @@ -250,6 +257,79 @@ inline void hmin(__m256& x) {
#endif

#ifdef OPENVINO_ARCH_ARM64
# if defined(HAVE_SVE)
inline svfloat32_t exp_ps_sve(svbool_t& pg, svfloat32_t& src) {
// Constants
const auto log2_e = svdup_n_f32(1.4426950409f);
const auto ln2 = svdup_n_f32(0.6931473921f);
const auto half_ln2_sq = svdup_n_f32(0.2413862043f);
const auto not_mask17 = svdup_n_u32(~((1u << 17) - 1));
const auto one = svdup_n_f32(1.0f);

// Algorithm starts here
svfloat32_t t0 = svmul_f32_z(pg, src, log2_e); // y = x * log2(e)
svfloat32_t t1 = svrintm_f32_z(pg, t0); // rount to int (float)
svint32_t t2 = svcvt_s32_f32_z(pg, t1); // n

t1 = svsub_f32_z(pg, t0, t1); // a = y - floor(y)
t1 = svadd_f32_z(pg, t1, one); // b = a + 1

svuint32_t t3 = svlsr_n_u32_z(pg, svreinterpret_u32_f32(t1), 17); // v = b >> 17 (u32)
svfloat32_t t4 = svexpa_f32(t3); // c = fexpa(v)
t4 = svscale_f32_z(pg, t4, t2); // fexpa(v) * 2^(n)

// and_(t2.d, t1.d, not_mask17.d)
svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_z(pg, svreinterpret_u32_f32(t1), not_mask17));
t5 = svsub_f32_z(pg, t1, t5); // z
t0 = svmla_f32_z(pg, ln2, t5, half_ln2_sq); // ln2 + half_ln2_sq * z
t0 = svmla_f32_z(pg, one, t5, t0); // 1 + (ln2 * z) + (half_ln2_sq * z * z)
t0 = svmul_f32_z(pg, t0, t4); // Final result

return t0;
}
inline svfloat32_t exp_ps_sve_legacy(svbool_t& pg, svfloat32_t& src) {
const auto c1 = svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6));
const auto c2 = svreinterpret_f32_u32(svdup_n_u32(0x3efffedb));
const auto c3 = svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33));
const auto c4 = svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17));
const auto c5 = svreinterpret_f32_u32(svdup_n_u32(0x3c072010));

const auto shift = svreinterpret_f32_u32(svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f
const auto one = svdup_n_f32(1.0f); // 1
const auto two = svdup_n_f32(2.0f); // 2
const auto inv_ln2 = svreinterpret_f32_u32(svdup_n_u32(0x3fb8aa3b));
const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(0xbf317200));
const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(0xb5bfbe8e));

const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
const auto zero = svdup_n_f32(0.f);
const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)

const auto z = svmla_f32_z(pg, shift, src, inv_ln2);
auto n = svsub_f32_z(pg, z, shift);
n = svsub_f32_z(pg, n, one);
const auto scale = svreinterpret_f32_u32(svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n

const auto r_hi = svmla_f32_z(pg, src, n, neg_ln2_hi);
const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo);
const auto r2 = svmul_f32_z(pg, r, r);

const auto p1 = svmul_f32_z(pg, c1, r);
const auto p23 = svmla_f32_z(pg, c2, c3, r);
const auto p45 = svmla_f32_z(pg, c4, c5, r);
const auto p2345 = svmla_f32_z(pg, p23, p45, r2);
const auto p12345 = svmla_f32_z(pg, p1, p2345, r2);

auto poly = svmla_f32_z(pg, scale, p12345, scale);
poly = svmul_f32_z(pg, poly, two);

poly = svsel_f32(svcmplt_f32(pg, src, min_input), zero, poly);
poly = svsel_f32(svcmpgt_f32(pg, src, max_input), inf, poly);

return poly;
}
# endif
inline float32x4_t exp_ps_neon_f32(const float32x4_t& src) {
const auto c1 = vreinterpretq_f32_u32(vdupq_n_u32(0x3f7ffff6));
const auto c2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3efffedb));
Expand Down
Loading

0 comments on commit b543d0b

Please sign in to comment.