forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[GSOC][CPU][ARM] ACL scaled attention (openvinotoolkit#25183)
### Details: - This PR aims to add an ACL implementation for scaled attention
- Loading branch information
Showing
5 changed files
with
309 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
106 changes: 106 additions & 0 deletions
106
src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
#include "gemm_kernel.hpp" | ||
// Error helper for this executor: wraps the ACL failure details in quotes.
// Fix: the original opened a quote ("... Init Failure '") that was never closed.
#define THROW_ERROR(...) OPENVINO_THROW("ACL gemm executor Init Failure '", __VA_ARGS__, "'")
|
||
namespace ov { | ||
namespace intel_cpu { | ||
// Constructs a GEMM kernel wrapper around arm_compute::NEGEMM.
//
// @param M            number of rows of the LHS (A) matrix
// @param N            inner dimension as used by this wrapper (A is laid out M x N)
// @param K            number of columns of the output (dst is laid out M x K)
// @param b_transposed true when the RHS (B) buffer is supplied pre-transposed
// @param inType       element type of all operands; only f32/f16/bf16 are supported
// @throws ov::Exception for unsupported element types
GemmKernel::GemmKernel(size_t M,
                       size_t N,
                       size_t K,
                       bool b_transposed,
                       ov::element::Type inType)
    : M(M),
      N(N),
      K(K),
      b_transposed(b_transposed) {
    if (!one_of(inType, ov::element::f32, ov::element::f16, ov::element::bf16))
        // Fix: the message used to say "brgemm kernel", which is the x86/oneDNN kernel name.
        THROW_ERROR("ACL gemm kernel only supports bf16, f16 and f32");

    // Exhaustive mapping (validated above), so `format` is always assigned.
    if (inType == ov::element::f32)
        format = arm_compute::Format::F32;
    else if (inType == ov::element::f16)
        format = arm_compute::Format::F16;
    else
        format = arm_compute::Format::BFLOAT16;

    aclGemmKernel = std::make_unique<arm_compute::NEGEMM>();
}
|
||
// Runs dst = alpha * A * B + beta * C via arm_compute::NEGEMM on caller-owned buffers.
//
// @param a          pointer to the LHS buffer, interpreted as M x N with `aStrides`
// @param b          pointer to the RHS buffer; shape depends on `b_transposed`
// @param dstInfo    caller-owned tensor info, (re)initialized here to M x K
// @param dstTensor  caller-owned output tensor, bound to `out` or self-allocated
// @param aStrides   byte strides of the A buffer
// @param bStrides   byte strides of the B buffer
// @param c          optional bias/accumulator buffer (M x K); may be nullptr
// @param alpha      scaling factor applied to A*B
// @param beta       scaling factor applied to C
// @param outStrides optional byte strides for the output buffer
// @param out        optional output buffer; when nullptr, dstTensor allocates its own
// @return           the ACL validation status for this configuration
arm_compute::Status GemmKernel::executeGemm(void *a,
                                            void *b,
                                            arm_compute::TensorInfo& dstInfo,
                                            arm_compute::Tensor& dstTensor,
                                            arm_compute::Strides aStrides,
                                            arm_compute::Strides bStrides,
                                            void *c,
                                            float alpha,
                                            float beta,
                                            arm_compute::Strides* outStrides,
                                            void* out) {
    // Hoisted: element size is the same for every operand.
    const size_t elemSize =
        arm_compute::element_size_from_data_type(arm_compute::data_type_from_format(format));

    // LHS: M x N elements with caller-provided strides.
    aInfo.init(shapeCast({M, N}), format, aStrides, size_t(0), M * N * elemSize);

    // RHS shape depends on whether B is supplied pre-transposed.
    arm_compute::TensorShape bShape;
    if (b_transposed)
        bShape = shapeCast({K, N});
    else
        bShape = shapeCast({N, K});

    bInfo.init(bShape, format, bStrides, size_t(0), K * N * elemSize);

    aTensor.allocator()->init(aInfo);
    bTensor.allocator()->init(bInfo);

    if (c != nullptr) {
        cInfo.init(shapeCast({M, K}), format);
        cTensor.allocator()->init(cInfo);
    }

    if (outStrides != nullptr)
        dstInfo.init(shapeCast({M, K}), format, *outStrides, size_t(0), M * K * elemSize);
    else
        dstInfo.init(shapeCast({M, K}), format);

    dstTensor.allocator()->init(dstInfo);

    aTensor.allocator()->import_memory(a);
    bTensor.allocator()->import_memory(b);
    // Fix: the original imported `c` unconditionally, touching an uninitialized
    // cTensor with a null pointer when no C operand was supplied.
    if (c != nullptr)
        cTensor.allocator()->import_memory(c);

    if (out == nullptr)
        dstTensor.allocator()->allocate();
    else
        dstTensor.allocator()->import_memory(out);

    if (b_transposed)
        aclGemmInfo.set_pretranspose_B(true);

    // Fix: validate with the actual alpha/beta (was hard-coded 1.0/0.0) and pass a
    // null C info when there is no C operand, mirroring the configure() call below.
    auto status = aclGemmKernel->validate(&aInfo,
                                          &bInfo,
                                          c != nullptr ? &cInfo : nullptr,
                                          &dstInfo,
                                          alpha,
                                          beta,
                                          aclGemmInfo);
    // Fix: do not configure/run an invalid configuration; surface the status instead.
    if (!static_cast<bool>(status))
        return status;

    if (c == nullptr)
        aclGemmKernel->configure(&aTensor, &bTensor, nullptr, &dstTensor, alpha, beta, aclGemmInfo);
    else
        aclGemmKernel->configure(&aTensor, &bTensor, &cTensor, &dstTensor, alpha, beta, aclGemmInfo);
    aclGemmKernel->run();

    return status;
}
} // namespace intel_cpu | ||
} // namespace ov |
52 changes: 52 additions & 0 deletions
52
src/plugins/intel_cpu/src/nodes/kernels/acl/gemm_kernel.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
#pragma once

#include <cstddef>
#include <memory>

#include <openvino/core/type/element_type.hpp>

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/NEFunctions.h"

#include "nodes/executors/acl/acl_utils.hpp"
#include "utils/general_utils.h"
|
||
namespace ov { | ||
namespace intel_cpu { | ||
class GemmKernel { | ||
public: | ||
GemmKernel(size_t M, | ||
size_t N, | ||
size_t K, | ||
bool b_transposed = false, | ||
ov::element::Type inType = ov::element::f32); | ||
|
||
arm_compute::Status executeGemm(void* a, | ||
void* b, | ||
arm_compute::TensorInfo& dstInfo, | ||
arm_compute::Tensor& dstTensor, | ||
arm_compute::Strides aStrides, | ||
arm_compute::Strides bStrides, | ||
void* c = nullptr, | ||
float alpha = 1.0f, | ||
float beta = 0.0f, | ||
arm_compute::Strides* outStrides = nullptr, | ||
void* out = nullptr); | ||
|
||
private: | ||
size_t M = 0; | ||
size_t N = 0, K = 0; | ||
bool b_transposed = false; | ||
arm_compute::Format format; | ||
arm_compute::TensorInfo aInfo; | ||
arm_compute::TensorInfo bInfo; | ||
arm_compute::TensorInfo cInfo; | ||
arm_compute::Tensor aTensor; | ||
arm_compute::Tensor bTensor; | ||
arm_compute::Tensor cTensor; | ||
arm_compute::Tensor dTensor; | ||
std::unique_ptr<arm_compute::NEGEMM> aclGemmKernel; | ||
arm_compute::GEMMInfo aclGemmInfo; | ||
}; | ||
|
||
} // namespace intel_cpu | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters