[GSOC][CPU][ARM] ACL scaled attention (openvinotoolkit#25183)
### Details:
 - This PR adds an Arm Compute Library (ACL) based implementation of scaled dot-product attention for the CPU plugin: a `GemmKernel` wrapper around `arm_compute::NEGEMM` plus an `MHAKernel<KT_ACL, float>` specialization that uses it (see the sketch below).
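At a high level, the new kernel splits the query dimension into blocks and, per batch/head/block, runs Q·K^T through ACL's `NEGEMM` (via the new `GemmKernel` wrapper), applies the existing `attn_softmax` kernel, and multiplies the result with V. As an orientation aid only, the plain C++ sketch below shows the math computed for a single head, with masks, grouped KV heads, and query blocking omitted; it is a reference illustration, not the PR's code, and `d_scale` defaults to `1/sqrt(head_size)` when the caller passes 0.

```cpp
// Reference-style sketch (not the PR's code) of the computation implemented by the ACL path:
// out = softmax(d_scale * Q * K^T) * V for one batch/head, row-major buffers.
// The PR itself runs the two matmuls through arm_compute::NEGEMM (via GemmKernel)
// and the softmax through the existing attn_softmax kernel, blocked over the query dimension.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void sdpa_reference_single_head(const float* Q, const float* K, const float* V, float* out,
                                size_t q_len, size_t kv_len, size_t head_size, float d_scale) {
    std::vector<float> scores(kv_len);
    for (size_t m = 0; m < q_len; ++m) {
        // scores[n] = d_scale * dot(Q[m, :], K[n, :])
        for (size_t n = 0; n < kv_len; ++n) {
            float acc = 0.0f;
            for (size_t s = 0; s < head_size; ++s)
                acc += Q[m * head_size + s] * K[n * head_size + s];
            scores[n] = d_scale * acc;
        }
        // numerically stable softmax over the kv dimension
        const float max_v = *std::max_element(scores.begin(), scores.end());
        float sum = 0.0f;
        for (auto& v : scores) {
            v = std::exp(v - max_v);
            sum += v;
        }
        // out[m, :] = sum_n softmax(scores)[n] * V[n, :]
        for (size_t s = 0; s < head_size; ++s) {
            float acc = 0.0f;
            for (size_t n = 0; n < kv_len; ++n)
                acc += (scores[n] / sum) * V[n * head_size + s];
            out[m * head_size + s] = acc;
        }
    }
}
```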
mory91 authored Jul 16, 2024
1 parent f579633 commit b8ba903
Showing 5 changed files with 309 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/plugins/intel_cpu/CMakeLists.txt
@@ -128,7 +128,8 @@ file(GLOB_RECURSE HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/src/*.h
                           ${CMAKE_CURRENT_SOURCE_DIR}/src/*.hpp)

 if(NOT OV_CPU_WITH_ACL)
-    list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/executors/acl/*)
+    list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/executors/acl/*
+                              ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/acl/*)
 endif()

 if(NOT X86_64)
106 changes: 106 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.cpp
@@ -0,0 +1,106 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gemm_kernel.hpp"
#define THROW_ERROR(...) OPENVINO_THROW("ACL gemm executor Init Failure '", __VA_ARGS__)

namespace ov {
namespace intel_cpu {
GemmKernel::GemmKernel(size_t M,
                       size_t N,
                       size_t K,
                       bool b_transposed,
                       ov::element::Type inType)
    : M(M),
      N(N),
      K(K),
      b_transposed(b_transposed) {
    if (!one_of(inType, ov::element::f32, ov::element::f16, ov::element::bf16))
        THROW_ERROR("brgemm kernel only supports bf16, f16 and f32");

    if (inType == ov::element::f32)
        format = arm_compute::Format::F32;
    else if (inType == ov::element::f16)
        format = arm_compute::Format::F16;
    else if (inType == ov::element::bf16)
        format = arm_compute::Format::BFLOAT16;

    aclGemmKernel = std::make_unique<arm_compute::NEGEMM>();
}

arm_compute::Status GemmKernel::executeGemm(void *a,
                                            void *b,
                                            arm_compute::TensorInfo& dstInfo,
                                            arm_compute::Tensor& dstTensor,
                                            arm_compute::Strides aStrides,
                                            arm_compute::Strides bStrides,
                                            void *c,
                                            float alpha,
                                            float beta,
                                            arm_compute::Strides* outStrides,
                                            void* out) {
    aInfo.init(
        shapeCast({M, N}),
        format,
        aStrides,
        size_t(0),
        (size_t)(M * N * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format))));

    arm_compute::TensorShape bShape;
    if (b_transposed)
        bShape = shapeCast({K, N});
    else
        bShape = shapeCast({N, K});

    bInfo.init(
        bShape,
        format,
        bStrides,
        size_t(0),
        (size_t)(K * N * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format))));

    aTensor.allocator()->init(aInfo);
    bTensor.allocator()->init(bInfo);

    if (c != nullptr) {
        cInfo.init(shapeCast({M, K}), format);
        cTensor.allocator()->init(cInfo);
    }

    if (outStrides != nullptr)
        dstInfo.init(
            shapeCast({M, K}),
            format,
            *outStrides,
            size_t(0),
            (size_t)(M * K * arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format))));
    else
        dstInfo.init(shapeCast({M, K}), format);

    dstTensor.allocator()->init(dstInfo);

    aTensor.allocator()->import_memory(reinterpret_cast<void *>(a));
    bTensor.allocator()->import_memory(reinterpret_cast<void *>(b));
    cTensor.allocator()->import_memory(reinterpret_cast<void *>(c));

    if (out == nullptr)
        dstTensor.allocator()->allocate();
    else
        dstTensor.allocator()->import_memory(out);

    if (b_transposed)
        aclGemmInfo.set_pretranspose_B(true);

    auto status = aclGemmKernel->validate(&aInfo, &bInfo, &cInfo, &dstInfo, 1.0, 0.0, aclGemmInfo);

    if (c == nullptr)
        aclGemmKernel->configure(&aTensor, &bTensor, nullptr, &dstTensor, alpha, beta, aclGemmInfo);
    else
        aclGemmKernel->configure(&aTensor, &bTensor, &cTensor, &dstTensor, alpha, beta, aclGemmInfo);
    aclGemmKernel->run();

    return status;
}
} // namespace intel_cpu
} // namespace ov
52 changes: 52 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.hpp
@@ -0,0 +1,52 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cstddef>
#include <openvino/core/type/element_type.hpp>
#include "nodes/executors/acl/acl_utils.hpp"
#include "utils/general_utils.h"

#include "arm_compute/runtime/NEON/NEFunctions.h"
#include "arm_compute/core/Types.h"

namespace ov {
namespace intel_cpu {
class GemmKernel {
public:
    GemmKernel(size_t M,
               size_t N,
               size_t K,
               bool b_transposed = false,
               ov::element::Type inType = ov::element::f32);

    arm_compute::Status executeGemm(void* a,
                                    void* b,
                                    arm_compute::TensorInfo& dstInfo,
                                    arm_compute::Tensor& dstTensor,
                                    arm_compute::Strides aStrides,
                                    arm_compute::Strides bStrides,
                                    void* c = nullptr,
                                    float alpha = 1.0f,
                                    float beta = 0.0f,
                                    arm_compute::Strides* outStrides = nullptr,
                                    void* out = nullptr);

private:
    size_t M = 0;
    size_t N = 0, K = 0;
    bool b_transposed = false;
    arm_compute::Format format;
    arm_compute::TensorInfo aInfo;
    arm_compute::TensorInfo bInfo;
    arm_compute::TensorInfo cInfo;
    arm_compute::Tensor aTensor;
    arm_compute::Tensor bTensor;
    arm_compute::Tensor cTensor;
    arm_compute::Tensor dTensor;
    std::unique_ptr<arm_compute::NEGEMM> aclGemmKernel;
    arm_compute::GEMMInfo aclGemmInfo;
};

} // namespace intel_cpu
} // namespace ov
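For orientation, a hypothetical standalone use of the new wrapper might look like the sketch below: it computes `dst[M, K] = a[M, N] * b[N, K]` in f32 over caller-owned row-major buffers, passing byte strides the same way `scaled_attn.cpp` does. The function name and stride values are illustrative assumptions, not part of this PR; in the PR the kernel is only driven from the ACL `MHAKernel` specialization further down.

```cpp
// Hypothetical usage sketch of GemmKernel (names and shapes are illustrative only).
// Computes dst[M, K] = a[M, N] * b[N, K] in f32; strides are byte strides, mirroring
// how the new scaled_attn.cpp code calls executeGemm.
#include "kernels/acl/gemm_kernel.hpp"

void run_f32_gemm(float* a, float* b, float* dst, size_t M, size_t N, size_t K) {
    ov::intel_cpu::GemmKernel gemm(M, N, K, /*b_transposed=*/false, ov::element::f32);

    arm_compute::TensorInfo dstInfo;
    arm_compute::Tensor dstTensor;
    arm_compute::Strides aStrides({sizeof(float), N * sizeof(float)});
    arm_compute::Strides bStrides({sizeof(float), K * sizeof(float)});
    arm_compute::Strides dstStrides({sizeof(float), K * sizeof(float)});

    // Returns the NEGEMM validation status; configure() and run() happen inside executeGemm.
    arm_compute::Status status = gemm.executeGemm(a, b, dstInfo, dstTensor,
                                                  aStrides, bStrides,
                                                  /*c=*/nullptr,
                                                  /*alpha=*/1.0f,
                                                  /*beta=*/0.0f,
                                                  &dstStrides,
                                                  dst);
    (void)status;
}
```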
149 changes: 148 additions & 1 deletion src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -23,6 +23,10 @@
 # include "mlas/sgemm.hpp"
 #endif

+#ifdef OV_CPU_WITH_ACL
+# include "kernels/acl/gemm_kernel.hpp"
+#endif
+
 #include "utils/plain_tensor.hpp"
 #include "kernels/scaled_attn/softmax.hpp"
 #include "kernels/scaled_attn/mha_single_token.hpp"
@@ -505,6 +509,147 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
    }
};

#ifdef OV_CPU_WITH_ACL
template <>
struct MHAKernel<ScaledDotProductAttention::KT_ACL, float> {
    const GraphContext::CPtr context;
    size_t m_block_size;

    MHAKernel() = delete;
    explicit MHAKernel(GraphContext::CPtr ctx): context(ctx) {
        m_block_size = 512;
        select_nfltmax_at_0 = false;
    }

    PlainTensor causal_mask;
    bool select_nfltmax_at_0;  // set attn_score to -FLT_MAX when causal_mask[...] equal to this
    void set_causal_mask(PlainTensor mask, bool _select_nfltmax_at_0) {
        causal_mask = mask;
        select_nfltmax_at_0 = _select_nfltmax_at_0;
    }

    // Q, K, V are ready, do attention
    // query         [B, H, q_len, S]
    // present_key   [B, H, kv_len, S]  stride of last dim may be > 1
    // present_value [B, H, kv_len, S]
    // attention_mask [B, 1, q_len, kv_len]
    // alibi
    // output_emb    [B, L1, H*S]
    void operator()(dnnl::stream strm,
                    PlainTensor& query,
                    PlainTensor& present_key,
                    PlainTensor& present_value,
                    const PlainTensor& alibi_mask,
                    const PlainTensor& attention_mask,
                    PlainTensor& output_emb,
                    bool has_out_transpose,
                    bool auto_causal,
                    float d_scale = 0.0f) {
        auto B = query.size(0);
        auto H = query.size(1);
        auto q_len = query.size(2);
        auto head_size = query.size(3);
        auto kv_len = present_key.size(2);
        auto h_group_num = present_key.size(1);
        size_t h_each_group_len = H / h_group_num;

        if (d_scale == 0.0f)
            d_scale = 1.0f / sqrt(head_size);
        auto k_stride_s = present_key.stride(3);

        auto m_blocks = (q_len + m_block_size - 1) / m_block_size;

        parallel_for3d(B, H, m_blocks, [&](size_t b, size_t h, size_t m_blk) {
            auto m_start = m_blk * m_block_size;
            auto m_end = std::min(m_start + m_block_size, q_len);
            auto m_cnt = m_end - m_start;

            float* q_ptr = &query.at<float>({b, h, m_start, 0});
            float* k_ptr = &present_key.at<float>({b, h / h_each_group_len, 0, 0});
            float* v_ptr = &present_value.at<float>({b, h / h_each_group_len, 0, 0});

            float* alibi_ptr = nullptr;
            auto alibi_stride = 0;
            if (alibi_mask) {
                alibi_ptr = &alibi_mask.at<float>({b, h, 0, 0}, true);
                if (alibi_mask.size(2) > 1)
                    alibi_stride = alibi_mask.stride(2);
            }
            uint8_t* attn_mask_ptr = nullptr;
            auto attn_mask_stride = 0;
            if (attention_mask) {
                attn_mask_ptr = reinterpret_cast<uint8_t*>(&attention_mask.at<float>({b, h, 0, 0}, true));
                if (attention_mask.size(2) > 1)
                    attn_mask_stride = attention_mask.stride(2) * sizeof(float);
            }
            uint8_t* cmask_ptr = nullptr;
            auto cmask_stride = 0;
            if (causal_mask) {
                cmask_ptr = &causal_mask.at<uint8_t>({b, h, 0, 0}, true);
                if (causal_mask.size(2) > 1)
                    cmask_stride = causal_mask.stride(2);
            }

            arm_compute::Tensor qkTensor;
            arm_compute::TensorInfo qkInfo;

            bool b_transpose = false;
            if (k_stride_s == 1)
                b_transpose = true;
            GemmKernel qk_gemm(m_cnt, head_size, kv_len, b_transpose);

            arm_compute::Strides qStrides({query.stride_bytes(3), query.stride_bytes(2)});
            arm_compute::Strides kStrides({present_key.stride_bytes(3), present_key.stride_bytes(2)});
            qk_gemm.executeGemm(reinterpret_cast<void *>(q_ptr),
                                reinterpret_cast<void *>(k_ptr),
                                qkInfo,
                                qkTensor,
                                qStrides,
                                kStrides);

            auto qk = reinterpret_cast<float*>(qkTensor.buffer());

            for (size_t m = m_start; m < m_end; m++) {
                // apply attention mask & softmax
                auto ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len;
                attn_softmax(qk + (m - m_start) * kv_len,
                             qk + (m - m_start) * kv_len,
                             d_scale,
                             alibi_ptr + m * alibi_stride,
                             attn_mask_ptr + m * attn_mask_stride,
                             cmask_ptr + m * cmask_stride,
                             select_nfltmax_at_0,
                             ncausal,
                             kv_len,
                             ov::element::f32,
                             ov::element::f32);
            }
            arm_compute::TensorInfo outInfo;
            arm_compute::Tensor outTensor;

            auto out = has_out_transpose ? &output_emb.at<float>({b, m_start, h * head_size})
                                         : &output_emb.at<float>({b, h, m_start});
            auto strides = arm_compute::Strides({output_emb.stride_bytes(1), output_emb.stride_bytes(2)});
            GemmKernel out_gemm(m_cnt, kv_len, head_size);

            arm_compute::Strides vStrides({present_value.stride_bytes(3), present_value.stride_bytes(2)});
            out_gemm.executeGemm(qkTensor.buffer(),
                                 reinterpret_cast<void *>(v_ptr),
                                 outInfo,
                                 outTensor,
                                 qkInfo.strides_in_bytes(),
                                 vStrides,
                                 nullptr,
                                 1.0,
                                 0.0,
                                 &strides,
                                 reinterpret_cast<void*>(out));
            qkTensor.allocator()->free();
        });
    }
};
#endif

#ifdef OV_CPU_WITH_MLAS
template <>
struct MHAKernel<ScaledDotProductAttention::KT_MLAS, float> {
@@ -935,7 +1080,9 @@ void ScaledDotProductAttention::createPrimitive() {
         executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(context);
 #endif
     } else {
-#ifdef OV_CPU_WITH_MLAS
+#ifdef OV_CPU_WITH_ACL
+        executor = std::make_shared<AttentionExecutor<KT_ACL, float>>(context);
+#elif defined(OV_CPU_WITH_MLAS)
         executor = std::make_shared<AttentionExecutor<KT_MLAS, float>>(context);
 #elif defined(OPENVINO_ARCH_X86_64)
         if (with_cpu_x86_avx512_core()) {
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/scaled_attn.h
@@ -36,7 +36,7 @@ class ScaledDotProductAttention : public Node {
     void createPrimitive() override;
     static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;

-    enum KernelTypes { KT_REF, KT_ONEDNN, KT_MLAS};
+    enum KernelTypes { KT_REF, KT_ONEDNN, KT_MLAS, KT_ACL};

     void assignState(const std::shared_ptr<VariableStateKVcache>& state, int idx);

