diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt
index b76b198d30c7ab..490f8f5cd39a01 100644
--- a/src/plugins/intel_cpu/CMakeLists.txt
+++ b/src/plugins/intel_cpu/CMakeLists.txt
@@ -128,7 +128,8 @@ file(GLOB_RECURSE HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/src/*.h ${CMAKE_CURRENT_SOURCE_DIR}/src/*.hpp)
 
 if(NOT OV_CPU_WITH_ACL)
-    list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/executors/acl/*)
+    list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/executors/acl/*
+                              ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/acl/*)
 endif()
 
 if(NOT X86_64)
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.cpp
new file mode 100644
index 00000000000000..afaae43850bc16
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.cpp
@@ -0,0 +1,107 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "gemm_kernel.hpp"
+
+#define THROW_ERROR(...) OPENVINO_THROW("ACL gemm executor Init Failure '", __VA_ARGS__, "'")
+
+namespace ov {
+namespace intel_cpu {
+GemmKernel::GemmKernel(size_t M,
+                       size_t N,
+                       size_t K,
+                       bool b_transposed,
+                       ov::element::Type inType)
+    : M(M),
+      N(N),
+      K(K),
+      b_transposed(b_transposed) {
+    if (!one_of(inType, ov::element::f32, ov::element::f16, ov::element::bf16))
+        THROW_ERROR("ACL gemm kernel only supports bf16, f16 and f32");
+
+    if (inType == ov::element::f32)
+        format = arm_compute::Format::F32;
+    else if (inType == ov::element::f16)
+        format = arm_compute::Format::F16;
+    else if (inType == ov::element::bf16)
+        format = arm_compute::Format::BFLOAT16;
+
+    aclGemmKernel = std::make_unique<arm_compute::NEGEMM>();
+}
+
+arm_compute::Status GemmKernel::executeGemm(void* a,
+                                            void* b,
+                                            arm_compute::TensorInfo& dstInfo,
+                                            arm_compute::Tensor& dstTensor,
+                                            arm_compute::Strides aStrides,
+                                            arm_compute::Strides bStrides,
+                                            void* c,
+                                            float alpha,
+                                            float beta,
+                                            arm_compute::Strides* outStrides,
+                                            void* out) {
+    aInfo.init(
+        shapeCast({M, N}),
+        format,
+        aStrides,
+        size_t(0),
+        (size_t)(M * N * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format))));
+
+    arm_compute::TensorShape bShape;
+    if (b_transposed)
+        bShape = shapeCast({K, N});
+    else
+        bShape = shapeCast({N, K});
+
+    bInfo.init(
+        bShape,
+        format,
+        bStrides,
+        size_t(0),
+        (size_t)(K * N * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format))));
+
+    aTensor.allocator()->init(aInfo);
+    bTensor.allocator()->init(bInfo);
+
+    if (c != nullptr) {
+        cInfo.init(shapeCast({M, K}), format);
+        cTensor.allocator()->init(cInfo);
+    }
+
+    if (outStrides != nullptr)
+        dstInfo.init(
+            shapeCast({M, K}),
+            format,
+            *outStrides,
+            size_t(0),
+            (size_t)(M * K * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format))));
+    else
+        dstInfo.init(shapeCast({M, K}), format);
+
+    dstTensor.allocator()->init(dstInfo);
+
+    aTensor.allocator()->import_memory(reinterpret_cast<void*>(a));
+    bTensor.allocator()->import_memory(reinterpret_cast<void*>(b));
+    cTensor.allocator()->import_memory(reinterpret_cast<void*>(c));
+
+    if (out == nullptr)
+        dstTensor.allocator()->allocate();
+    else
+        dstTensor.allocator()->import_memory(out);
+
+    if (b_transposed)
+        aclGemmInfo.set_pretranspose_B(true);
+
+    auto status = aclGemmKernel->validate(&aInfo, &bInfo, &cInfo, &dstInfo, 1.0, 0.0, aclGemmInfo);
+
+    if (c == nullptr)
+        aclGemmKernel->configure(&aTensor, &bTensor, nullptr, &dstTensor, alpha, beta, aclGemmInfo);
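+    // NEGEMM computes dst = alpha * a * b + beta * c; with no bias tensor the beta term drops out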
+    else
+        aclGemmKernel->configure(&aTensor, &bTensor, &cTensor, &dstTensor, alpha, beta, aclGemmInfo);
+    aclGemmKernel->run();
+
+    return status;
+}
+}  // namespace intel_cpu
+}  // namespace ov
\ No newline at end of file
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.hpp
new file mode 100644
index 00000000000000..620f42f239cbbb
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.hpp
@@ -0,0 +1,52 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+#include <cstddef>
+#include <openvino/core/type/element_type.hpp>
+#include "nodes/executors/acl/acl_utils.hpp"
+#include "utils/general_utils.h"
+
+#include "arm_compute/runtime/NEON/NEFunctions.h"
+#include "arm_compute/core/Types.h"
+
+namespace ov {
+namespace intel_cpu {
+class GemmKernel {
+public:
+    GemmKernel(size_t M,
+               size_t N,
+               size_t K,
+               bool b_transposed = false,
+               ov::element::Type inType = ov::element::f32);
+
+    arm_compute::Status executeGemm(void* a,
+                                    void* b,
+                                    arm_compute::TensorInfo& dstInfo,
+                                    arm_compute::Tensor& dstTensor,
+                                    arm_compute::Strides aStrides,
+                                    arm_compute::Strides bStrides,
+                                    void* c = nullptr,
+                                    float alpha = 1.0f,
+                                    float beta = 0.0f,
+                                    arm_compute::Strides* outStrides = nullptr,
+                                    void* out = nullptr);
+
+private:
+    size_t M = 0;
+    size_t N = 0, K = 0;
+    bool b_transposed = false;
+    arm_compute::Format format;
+    arm_compute::TensorInfo aInfo;
+    arm_compute::TensorInfo bInfo;
+    arm_compute::TensorInfo cInfo;
+    arm_compute::Tensor aTensor;
+    arm_compute::Tensor bTensor;
+    arm_compute::Tensor cTensor;
+    arm_compute::Tensor dTensor;
+    std::unique_ptr<arm_compute::NEGEMM> aclGemmKernel;
+    arm_compute::GEMMInfo aclGemmInfo;
+};
+
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
index 326353157a3a1b..7471ebc33c9809 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -23,6 +23,10 @@
 #    include "mlas/sgemm.hpp"
 #endif
 
+#ifdef OV_CPU_WITH_ACL
+#    include "kernels/acl/gemm_kernel.hpp"
+#endif
+
 #include "utils/plain_tensor.hpp"
 #include "kernels/scaled_attn/softmax.hpp"
 #include "kernels/scaled_attn/mha_single_token.hpp"
@@ -505,6 +509,151 @@ struct MHAKernel {
     }
 };
 
+#ifdef OV_CPU_WITH_ACL
+template <>
+struct MHAKernel<ScaledDotProductAttention::KT_ACL, float> {
+    const GraphContext::CPtr context;
+    size_t m_block_size;
+
+    MHAKernel() = delete;
+    explicit MHAKernel(GraphContext::CPtr ctx) : context(ctx) {
+        m_block_size = 512;
+        select_nfltmax_at_0 = false;
+    }
+
+    PlainTensor causal_mask;
+    bool select_nfltmax_at_0;  // set attn_score to -FLT_MAX when causal_mask[...] equals this
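+
+    // optional element-wise causal mask; when set, it is forwarded to attn_softmax() in operator() below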
+    void set_causal_mask(PlainTensor mask, bool _select_nfltmax_at_0) {
+        causal_mask = mask;
+        select_nfltmax_at_0 = _select_nfltmax_at_0;
+    }
+
+    // Q, K, V are ready, do attention
+    // query          [B, H, q_len, S]
+    // present_key    [B, H, kv_len, S]  stride of last dim may be > 1
+    // present_value  [B, H, kv_len, S]
+    // attention_mask [B, 1, q_len, kv_len]
+    // alibi
+    // output_emb     [B, L1, H*S]
+    void operator()(dnnl::stream strm,
+                    PlainTensor& query,
+                    PlainTensor& present_key,
+                    PlainTensor& present_value,
+                    const PlainTensor& alibi_mask,
+                    const PlainTensor& attention_mask,
+                    PlainTensor& output_emb,
+                    bool has_out_transpose,
+                    bool auto_causal,
+                    float d_scale = 0.0f) {
+        auto B = query.size(0);
+        auto H = query.size(1);
+        auto q_len = query.size(2);
+        auto head_size = query.size(3);
+        auto kv_len = present_key.size(2);
+        auto h_group_num = present_key.size(1);
+        size_t h_each_group_len = H / h_group_num;
+
+        if (d_scale == 0.0f)
+            d_scale = 1.0f / sqrt(head_size);
+        auto k_stride_s = present_key.stride(3);
+
+        auto m_blocks = (q_len + m_block_size - 1) / m_block_size;
+
+        parallel_for3d(B, H, m_blocks, [&](size_t b, size_t h, size_t m_blk) {
+            auto m_start = m_blk * m_block_size;
+            auto m_end = std::min(m_start + m_block_size, q_len);
+            auto m_cnt = m_end - m_start;
+
+            float* q_ptr = &query.at<float>({b, h, m_start, 0});
+            float* k_ptr = &present_key.at<float>({b, h / h_each_group_len, 0, 0});
+            float* v_ptr = &present_value.at<float>({b, h / h_each_group_len, 0, 0});
+
+            float* alibi_ptr = nullptr;
+            auto alibi_stride = 0;
+            if (alibi_mask) {
+                alibi_ptr = &alibi_mask.at<float>({b, h, 0, 0}, true);
+                if (alibi_mask.size(2) > 1)
+                    alibi_stride = alibi_mask.stride(2);
+            }
+            uint8_t* attn_mask_ptr = nullptr;
+            auto attn_mask_stride = 0;
+            if (attention_mask) {
+                attn_mask_ptr = reinterpret_cast<uint8_t*>(&attention_mask.at<float>({b, h, 0, 0}, true));
+                if (attention_mask.size(2) > 1)
+                    attn_mask_stride = attention_mask.stride(2) * sizeof(float);
+            }
+            uint8_t* cmask_ptr = nullptr;
+            auto cmask_stride = 0;
+            if (causal_mask) {
+                cmask_ptr = &causal_mask.at<uint8_t>({b, h, 0, 0}, true);
+                if (causal_mask.size(2) > 1)
+                    cmask_stride = causal_mask.stride(2);
+            }
+
+            arm_compute::Tensor qkTensor;
+            arm_compute::TensorInfo qkInfo;
+
+            bool b_transpose = false;
+            if (k_stride_s == 1)
+                b_transpose = true;
+            GemmKernel qk_gemm(m_cnt, head_size, kv_len, b_transpose);
+
+            arm_compute::Strides qStrides({query.stride_bytes(3), query.stride_bytes(2)});
+            arm_compute::Strides kStrides({present_key.stride_bytes(3), present_key.stride_bytes(2)});
+            qk_gemm.executeGemm(reinterpret_cast<void*>(q_ptr),
+                                reinterpret_cast<void*>(k_ptr),
+                                qkInfo,
+                                qkTensor,
+                                qStrides,
+                                kStrides);
+
+            auto qk = reinterpret_cast<float*>(qkTensor.buffer());
+
+            for (size_t m = m_start; m < m_end; m++) {
+                // apply attention mask & softmax
+                auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len;
+                attn_softmax(qk + (m - m_start) * kv_len,
+                             qk + (m - m_start) * kv_len,
+                             d_scale,
+                             alibi_ptr + m * alibi_stride,
+                             attn_mask_ptr + m * attn_mask_stride,
+                             cmask_ptr + m * cmask_stride,
+                             select_nfltmax_at_0,
+                             ncausal,
+                             kv_len,
+                             ov::element::f32,
+                             ov::element::f32);
+            }
+            arm_compute::TensorInfo outInfo;
+            arm_compute::Tensor outTensor;
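+            // second GEMM: softmax scores [m_cnt, kv_len] x V [kv_len, head_size] -> [m_cnt, head_size],
+            // written directly into output_emb through the caller-provided strides and out pointer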
+            auto out = has_out_transpose ? &output_emb.at<float>({b, m_start, h * head_size})
+                                         : &output_emb.at<float>({b, h, m_start});
+            auto strides = arm_compute::Strides({output_emb.stride_bytes(1), output_emb.stride_bytes(2)});
+            GemmKernel out_gemm(m_cnt, kv_len, head_size);
+
+            arm_compute::Strides vStrides({present_value.stride_bytes(3), present_value.stride_bytes(2)});
+            out_gemm.executeGemm(qkTensor.buffer(),
+                                 reinterpret_cast<void*>(v_ptr),
+                                 outInfo,
+                                 outTensor,
+                                 qkInfo.strides_in_bytes(),
+                                 vStrides,
+                                 nullptr,
+                                 1.0,
+                                 0.0,
+                                 &strides,
+                                 reinterpret_cast<void*>(out));
+            qkTensor.allocator()->free();
+        });
+    }
+};
+#endif
+
 #ifdef OV_CPU_WITH_MLAS
 template <>
 struct MHAKernel<ScaledDotProductAttention::KT_MLAS, float> {
@@ -935,7 +1080,9 @@ void ScaledDotProductAttention::createPrimitive() {
         executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(context);
 #endif
     } else {
-#ifdef OV_CPU_WITH_MLAS
+#ifdef OV_CPU_WITH_ACL
+        executor = std::make_shared<AttentionExecutor<KT_ACL, float>>(context);
+#elif defined(OV_CPU_WITH_MLAS)
         executor = std::make_shared<AttentionExecutor<KT_MLAS, float>>(context);
 #elif defined(OPENVINO_ARCH_X86_64)
         if (with_cpu_x86_avx512_core()) {
diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h
index 67568eda8d06e6..bbf12727478e43 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h
@@ -36,7 +36,7 @@ class ScaledDotProductAttention : public Node {
     void createPrimitive() override;
     static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
 
-    enum KernelTypes { KT_REF, KT_ONEDNN, KT_MLAS};
+    enum KernelTypes { KT_REF, KT_ONEDNN, KT_MLAS, KT_ACL};
 
     void assignState(const std::shared_ptr<VariableState>& state, int idx);