diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_matmul.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_matmul.cpp index 82b8f4fb7dd1f2..5f7dcf6c132181 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_matmul.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_matmul.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "acl_utils.hpp" #include "acl_matmul.hpp" namespace ov { @@ -9,25 +10,6 @@ namespace intel_cpu { using namespace arm_compute; -TensorShape shapeCast(const VectorDims& dims) { - arm_compute::TensorShape tensorShape; - for (std::size_t i = 0; i < dims.size(); ++i) { - tensorShape.set(dims.size() - i - 1, dims[i], false); - } - if (tensorShape.num_dimensions() == 0) { - tensorShape.set(0, 1, false); - tensorShape.set_num_dimensions(1); - } - return tensorShape; -} - -inline Dim vectorProduct(const VectorDims& vec, size_t size) { - Dim prod = 1; - for (size_t i = 0; i < size; ++i) - prod *= vec[i]; - return prod; -} - AclMatMulExecutor::AclMatMulExecutor(const ExecutorContext::CPtr context) : MatMulExecutor(context) {} bool AclMatMulExecutor::init(const MatMulAttrs& matmulAttrs, diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_matmul.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_matmul.hpp index f0c4e3db68c195..80bd3eaff68e0a 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_matmul.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_matmul.hpp @@ -10,8 +10,6 @@ namespace ov { namespace intel_cpu { -arm_compute::TensorShape shapeCast(const VectorDims& dims); - class AclMatMulExecutor : public MatMulExecutor { public: AclMatMulExecutor(const ExecutorContext::CPtr context); diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.cpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.cpp new file mode 100644 index 00000000000000..d12ddf1876a571 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.cpp @@ -0,0 +1,101 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "acl_utils.hpp" +#include "acl_reduce.hpp" + +namespace ov { +namespace intel_cpu { + +using namespace arm_compute; + +arm_compute::ReductionOperation getAclReductionOperationByAlgorithm(Algorithm algorithm) { + switch (algorithm) { + case Algorithm::ReduceMax: return arm_compute::ReductionOperation::MAX; + case Algorithm::ReduceMin: return arm_compute::ReductionOperation::MIN; + case Algorithm::ReduceSum: return arm_compute::ReductionOperation::SUM; + case Algorithm::ReduceProd: return arm_compute::ReductionOperation::PROD; + default: IE_THROW() << "Unsupported reduction operation: " << static_cast(algorithm); + } +} + +AclReduceExecutor::AclReduceExecutor(const ExecutorContext::CPtr context) : ReduceExecutor(context) {} + +bool AclReduceExecutor::init(const ReduceAttrs& reduceAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr &attr) { + if (reduceAttrs.operation != Algorithm::ReduceMax && + reduceAttrs.operation != Algorithm::ReduceMin && + reduceAttrs.operation != Algorithm::ReduceSum && + reduceAttrs.operation != Algorithm::ReduceProd && + reduceAttrs.operation != Algorithm::ReduceMean) { + return false; + } + + this->reduceAttrs = reduceAttrs; + + auto srcDims = srcDescs[0]->getShape().getStaticDims(); + auto dstDims = dstDescs[0]->getShape().getStaticDims(); + + TensorInfo srcTensorInfo = TensorInfo(shapeCast(srcDims), 1, + 
precisionToAclDataType(srcDescs[0]->getPrecision()), getAclDataLayoutByMemoryDesc(srcDescs[0])); + TensorInfo dstTensorInfo = TensorInfo(shapeCast(dstDims), 1, + precisionToAclDataType(dstDescs[0]->getPrecision()), getAclDataLayoutByMemoryDesc(dstDescs[0])); + + srcTensor.allocator()->init(srcTensorInfo); + dstTensor.allocator()->init(dstTensorInfo); + + switch (reduceAttrs.operation) { + case Algorithm::ReduceMean: + for (size_t i = 0; i < reduceAttrs.axes.size(); ++i) { + auto pos = axisCast(i, reduceAttrs.axes.size()); + axesMean.set(pos, reduceAttrs.axes[i]); + } + if (!arm_compute::NEReduceMean::validate(&srcTensorInfo, axesMean, reduceAttrs.keepDims, &dstTensorInfo)) { + return false; + } + exec_func = [this]{ + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensor, axesMean, this->reduceAttrs.keepDims, &dstTensor); + acl_op->run(); + }; + break; + case Algorithm::ReduceMax: + case Algorithm::ReduceMin: + case Algorithm::ReduceSum: + case Algorithm::ReduceProd: + if (reduceAttrs.axes.size() != 1) { + return false; + } + if (!arm_compute::NEReductionOperation::validate(&srcTensorInfo, &dstTensorInfo, axisCast(reduceAttrs.axes[0], srcDims.size()), + getAclReductionOperationByAlgorithm(reduceAttrs.operation), reduceAttrs.keepDims)) { + return false; + } + exec_func = [this, srcDims]{ + auto acl_op = std::make_unique(); + acl_op->configure(&srcTensor, &dstTensor, axisCast(this->reduceAttrs.axes[0], srcDims.size()), + getAclReductionOperationByAlgorithm(this->reduceAttrs.operation), this->reduceAttrs.keepDims); + acl_op->run(); + }; + break; + default: + IE_THROW() << "Unsupported operation type for ACL Reduce executor: " << static_cast(reduceAttrs.operation); + } + + return true; +} + +void AclReduceExecutor::exec(const std::vector& src, const std::vector& dst, std::unordered_map postOpsArgs) { + srcTensor.allocator()->import_memory(src[0]->GetPtr()); + dstTensor.allocator()->import_memory(dst[0]->GetPtr()); + + exec_func(); + + srcTensor.allocator()->free(); + dstTensor.allocator()->free(); +} + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.hpp new file mode 100644 index 00000000000000..2dad0b9f2242db --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_reduce.hpp @@ -0,0 +1,60 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +// TODO: remove relative path +#include "../reduce.hpp" +#include "arm_compute/runtime/NEON/NEFunctions.h" + +namespace ov { +namespace intel_cpu { + +class AclReduceExecutor : public ReduceExecutor { +public: + AclReduceExecutor(const ExecutorContext::CPtr context); + + bool init(const ReduceAttrs& reduceAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr &attr) override; + void exec(const std::vector& src, + const std::vector& dst, + std::unordered_map postOpsArgs) override; + + impl_desc_type getImplType() const override { + return implType; + } + +private: + std::function exec_func; + ReduceAttrs reduceAttrs; + impl_desc_type implType = impl_desc_type::acl; + + arm_compute::Coordinates axesMean; + arm_compute::Tensor srcTensor; + arm_compute::Tensor dstTensor; +}; + +class AclReduceExecutorBuilder : public ReduceExecutorBuilder { +public: + bool isSupported(const ReduceAttrs& reduceAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs) const override { + if 
(srcDescs[0]->getPrecision() != dstDescs[0]->getPrecision() || + (srcDescs[0]->getPrecision() != InferenceEngine::Precision::FP32 && + dstDescs[0]->getPrecision() != InferenceEngine::Precision::FP16 && + dstDescs[0]->getPrecision() != InferenceEngine::Precision::I32)) + return false; + + return true; + } + + ReduceExecutorPtr makeExecutor(const ExecutorContext::CPtr context) const override { + return std::make_shared(context); + } +}; + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp index f255a10d5c7d10..8cb60b02a67a16 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/acl/acl_utils.hpp @@ -3,9 +3,41 @@ // #pragma once +#include "ie_precision.hpp" +#include "memory_desc/cpu_memory_desc.h" +#include "arm_compute/core/Types.h" + namespace ov { namespace intel_cpu { +/** +* @brief Return ComputeLibrary TensorShape with reverted layout schema used in ACL +* @param dims vector of dimensions to convert +* @return ComputeLibrary TensorShape object +*/ +inline arm_compute::TensorShape shapeCast(const VectorDims& dims) { + arm_compute::TensorShape tensorShape; + for (std::size_t i = 0; i < dims.size(); ++i) { + tensorShape.set(dims.size() - i - 1, dims[i], false); + } + if (tensorShape.num_dimensions() == 0) { + tensorShape.set(0, 1, false); + tensorShape.set_num_dimensions(1); + } + return tensorShape; +} + +inline std::size_t axisCast(const std::size_t axis, const std::size_t shapeSize) { + return shapeSize - axis - 1; +} + +inline Dim vectorProduct(const VectorDims& vec, size_t size) { + Dim prod = 1; + for (size_t i = 0; i < size; ++i) + prod *= vec[i]; + return prod; +} + /** * @brief Return ComputeLibrary DataType that corresponds to the given precision * @param precision precision to be converted @@ -36,8 +68,8 @@ inline arm_compute::DataType precisionToAclDataType(InferenceEngine::Precision p inline arm_compute::DataLayout getAclDataLayoutByMemoryDesc(MemoryDescCPtr desc) { if (desc->hasLayoutType(LayoutType::ncsp)) { if (desc->getShape().getRank() == 4) return arm_compute::DataLayout::NCHW; - if (desc->getShape().getRank() == 5) return arm_compute::DataLayout::NCDHW; - } else if(desc->hasLayoutType(LayoutType::nspc)) { + if (desc->getShape().getRank() == 5) return arm_compute::DataLayout::NCDHW; + } else if (desc->hasLayoutType(LayoutType::nspc)) { if (desc->getShape().getRank() == 4) return arm_compute::DataLayout::NHWC; if (desc->getShape().getRank() == 5) return arm_compute::DataLayout::NDHWC; } diff --git a/src/plugins/intel_cpu/src/nodes/executors/common/ref_reduce.cpp b/src/plugins/intel_cpu/src/nodes/executors/common/ref_reduce.cpp new file mode 100644 index 00000000000000..15db8e97434007 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/common/ref_reduce.cpp @@ -0,0 +1,205 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ref_reduce.hpp" +#include "ie_parallel.hpp" +#include + +namespace ov { +namespace intel_cpu { + +RefReduceExecutor::RefReduceExecutor(const ExecutorContext::CPtr context) : ReduceExecutor(context) {} + +bool RefReduceExecutor::init(const ReduceAttrs& reduceAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr &attr) { + this->reduceAttrs = reduceAttrs; + + if (srcDescs[0]->getPrecision() != InferenceEngine::Precision::FP32 || 
+ dstDescs[0]->getPrecision() != InferenceEngine::Precision::FP32) + return false; + + if (!srcDescs[0]->hasLayoutType(LayoutType::ncsp) || + !dstDescs[0]->hasLayoutType(LayoutType::ncsp)) + return false; + + src_dims = srcDescs[0]->getShape().getDims(); + calc_process_dst_dims(dstDescs[0]->getShape().getStaticDims()); + return true; +} + +void RefReduceExecutor::exec(const std::vector& src, const std::vector& dst, std::unordered_map postOpsArgs) { + switch (reduceAttrs.operation) { + case Algorithm::ReduceAnd: + reduce_ref_process(src, dst, 1, [](float x, float y)->float { return x && y; }); + break; + case Algorithm::ReduceL1: + reduce_ref_process(src, dst, 0, [](float old, float y)->float { return old + (y >= 0 ? y : -y); }); + break; + case Algorithm::ReduceL2: + reduce_ref_process(src, dst, 0, [](float old, float y)->float { return old + y * y; }); + break; + case Algorithm::ReduceLogSum: + reduce_ref_process(src, dst, 0, [](float x, float y)->float { return x + y; }); + break; + case Algorithm::ReduceLogSumExp: + reduce_ref_process(src, dst, 0, [](float old, float y)->float { return old + expf(y); }); + break; + case Algorithm::ReduceMax: + reduce_ref_process(src, dst, std::numeric_limits::lowest(), + [](float x, float y)->float { return x > y ? x : y; }); + break; + case Algorithm::ReduceMean: + reduce_ref_process(src, dst, 0, [](float x, float y)->float { return x + y; }); + break; + case Algorithm::ReduceMin: + reduce_ref_process(src, dst, std::numeric_limits::max(), + [](float x, float y)->float { return x < y ? x : y; }); + break; + case Algorithm::ReduceOr: + reduce_ref_process(src, dst, 0, [](float x, float y)->float { return x || y; }); + break; + case Algorithm::ReduceProd: + reduce_ref_process(src, dst, 1, [](float x, float y)->float { return x * y; }); + break; + case Algorithm::ReduceSum: + reduce_ref_process(src, dst, 0, [](float x, float y)->float { return x + y; }); + break; + case Algorithm::ReduceSumSquare: + reduce_ref_process(src, dst, 0, [](float old, float y)->float { return old + y * y; }); + break; + default: + IE_THROW() << "Reduce node gets unsupported reduce mode."; + } +} + +inline void RefReduceExecutor::calc_process_dst_dims(const InferenceEngine::SizeVector &dst_dims) { + std::set axes; + InferenceEngine::SizeVector out_dims; + process_dst_dims.clear(); + axes_for_reduction.clear(); + for (auto &axis : reduceAttrs.axes) { + if (axis < 0) + axis += src_dims.size(); + if (static_cast(axis) > src_dims.size()) + IE_THROW() << "Reduce node " << axis << " " << src_dims.size() << " exceeds data tensor dimension on index to reduce"; + axes.insert(static_cast(axis)); + } + for (size_t i = 0; i < src_dims.size(); i++) { + bool found = false; + for (auto axis : axes) { + if (i == axis) { + found = true; + break; + } + } + if (found) { + if (reduceAttrs.keepDims) out_dims.push_back(1); + process_dst_dims.push_back(1); + axes_for_reduction.push_back(i); + } else { + out_dims.push_back(src_dims[i]); + process_dst_dims.push_back(src_dims[i]); + } + } + for (size_t i = 0; i < std::min(out_dims.size(), dst_dims.size()); i++) { + if (out_dims[i] != dst_dims[i]) + IE_THROW() << "Reduce node gets incorrect number of output dimensions!"; + } +} + +void RefReduceExecutor::reduce_ref_process(const std::vector& src, const std::vector& dst, + float init_value, std::function func) { + float *in_ptr = reinterpret_cast(src[0]->GetPtr()); + float *out_ptr = reinterpret_cast(dst[0]->GetPtr()); + + size_t work_amount_dst = 1, reduced_dims_work_amount = 1; + for (size_t i = 0; i < 
process_dst_dims.size(); i++) + work_amount_dst *= process_dst_dims[i]; + for (size_t i = 0; i < src_dims.size(); i++) + reduced_dims_work_amount *= src_dims[i]; + reduced_dims_work_amount /= work_amount_dst; + + InferenceEngine::SizeVector src_strides = src[0]->GetDescWithType()->getStrides(); + parallel_nt(0, [&](const int ithr, const int nthr) { + int j; + size_t i, start = 0, end = 0; + InferenceEngine::SizeVector dst_counters(process_dst_dims.size(), 0); + splitter(work_amount_dst, nthr, ithr, start, end); + for (j = process_dst_dims.size() - 1, i = start; j >= 0; j--) { + dst_counters[j] = i % process_dst_dims[j]; + i /= process_dst_dims[j]; + } + for (size_t src_idx = 0, dst_idx = start; dst_idx < end; ++dst_idx) { + float reduce_prod = init_value; + bool update_idx = true; + InferenceEngine::SizeVector src_counters = dst_counters; + for (i = 0; i < reduced_dims_work_amount; ++i) { + if (update_idx) { + src_idx = 0; + for (j = 0; j < static_cast(src_dims.size()); ++j) + src_idx += (src_counters[j] % src_dims[j]) * src_strides[j]; + update_idx = false; + } + reduce_prod = func(reduce_prod, in_ptr[src_idx]); + for (j = axes_for_reduction.size() - 1; j >= 0; j--) { + src_counters[axes_for_reduction[j]]++; + if (src_counters[axes_for_reduction[j]] < src_dims[axes_for_reduction[j]]) { + src_idx += src_strides[axes_for_reduction[j]]; + break; + } else { + src_counters[axes_for_reduction[j]] = 0; + update_idx = true; + } + } + } + out_ptr[dst_idx] = reduce_prod; + for (j = process_dst_dims.size() - 1; j >= 0; j--) { + dst_counters[j]++; + if (dst_counters[j] < process_dst_dims[j]) + break; + else + dst_counters[j] = 0; + } + } + }); + + reduce_ref_map(out_ptr, work_amount_dst, reduced_dims_work_amount); +} + +inline void RefReduceExecutor::reduce_ref_map(float *out_ptr, size_t work_amount_dst, size_t reduced_dims_work_amount) { + switch (reduceAttrs.operation) { + case Algorithm::ReduceAnd: + case Algorithm::ReduceL1: + case Algorithm::ReduceMax: + case Algorithm::ReduceMin: + case Algorithm::ReduceOr: + case Algorithm::ReduceProd: + case Algorithm::ReduceSum: + case Algorithm::ReduceSumSquare: + break; + case Algorithm::ReduceL2: + parallel_for(work_amount_dst, [&](size_t i) { + out_ptr[i] = std::sqrt(out_ptr[i]); + }); + break; + case Algorithm::ReduceLogSum: + case Algorithm::ReduceLogSumExp: + parallel_for(work_amount_dst, [&](size_t i) { + out_ptr[i] = logf(out_ptr[i]); + }); + break; + case Algorithm::ReduceMean: + parallel_for(work_amount_dst, [&](size_t i) { + out_ptr[i] /= reduced_dims_work_amount; + }); + break; + default: + IE_THROW() << "Reduce node gets unsupported reduce mode."; + } +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/common/ref_reduce.hpp b/src/plugins/intel_cpu/src/nodes/executors/common/ref_reduce.hpp new file mode 100644 index 00000000000000..7fa81a0ecf1b73 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/common/ref_reduce.hpp @@ -0,0 +1,51 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "nodes/executors/reduce.hpp" + +namespace ov { +namespace intel_cpu { + +class RefReduceExecutor : public ReduceExecutor { +public: + RefReduceExecutor(const ExecutorContext::CPtr context); + + bool init(const ReduceAttrs& reduceAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr &attr) override; + void exec(const std::vector& src, + const std::vector& dst, + std::unordered_map postOpsArgs) 
override; + + impl_desc_type getImplType() const override { + return implType; + } + +private: + void reduce_ref_process(const std::vector& src, const std::vector& dst, float init_value, std::function func); + inline void reduce_ref_map(float *out_ptr, size_t work_amount_dst, size_t reduced_dims_work_amount); + inline void calc_process_dst_dims(const InferenceEngine::SizeVector &dst_dim); + + impl_desc_type implType = impl_desc_type::ref; + InferenceEngine::SizeVector src_dims; + InferenceEngine::SizeVector process_dst_dims; + InferenceEngine::SizeVector axes_for_reduction; +}; + +class RefReduceExecutorBuilder : public ReduceExecutorBuilder { +public: + bool isSupported(const ReduceAttrs& reduceAttrs, const std::vector& srcDescs, const std::vector& dstDescs) const override { + return true; + } + + ReduceExecutorPtr makeExecutor(const ExecutorContext::CPtr context) const override { + return std::make_shared(context); + } +}; + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/reduce.cpp b/src/plugins/intel_cpu/src/nodes/executors/reduce.cpp new file mode 100644 index 00000000000000..a7906b6d20b1c8 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/reduce.cpp @@ -0,0 +1,15 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "reduce.hpp" + +namespace ov { +namespace intel_cpu { + +using namespace InferenceEngine; + +ReduceExecutor::ReduceExecutor(const ExecutorContext::CPtr context) : context(context) {} + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/reduce.hpp b/src/plugins/intel_cpu/src/nodes/executors/reduce.hpp new file mode 100644 index 00000000000000..a929e0c5c3d6d0 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/reduce.hpp @@ -0,0 +1,55 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "cpu_memory.h" +#include "onednn/iml_type_mapper.h" +#include "dnnl_scratch_pad.h" +#include "executor.hpp" + +namespace ov { +namespace intel_cpu { + +struct ReduceAttrs { + std::vector axes; + Algorithm operation; + bool keepDims; +}; + +class ReduceExecutor { +public: + ReduceExecutor(const ExecutorContext::CPtr context); + virtual bool init(const ReduceAttrs& reduceAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr &attr) = 0; + + virtual void exec(const std::vector& src, const std::vector& dst, std::unordered_map postOpsArgs) = 0; + virtual ~ReduceExecutor() = default; + + virtual impl_desc_type getImplType() const = 0; + +protected: + ReduceAttrs reduceAttrs; + const ExecutorContext::CPtr context; +}; + +using ReduceExecutorPtr = std::shared_ptr; +using ReduceExecutorCPtr = std::shared_ptr; + +class ReduceExecutorBuilder { +public: + ~ReduceExecutorBuilder() = default; + virtual bool isSupported(const ReduceAttrs& reduceAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs) const = 0; + virtual ReduceExecutorPtr makeExecutor(const ExecutorContext::CPtr context) const = 0; +}; + +using ReduceExecutorBuilderPtr = std::shared_ptr; +using ReduceExecutorBuilderCPtr = std::shared_ptr; + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/reduce_list.cpp b/src/plugins/intel_cpu/src/nodes/executors/reduce_list.cpp new file mode 100644 index 00000000000000..fed7a21aaf4860 --- 
/dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/reduce_list.cpp @@ -0,0 +1,21 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "reduce_list.hpp" + +namespace ov { +namespace intel_cpu { + +const std::vector& getReduceExecutorsList() { + static std::vector descs = { + //OV_CPU_INSTANCE_X64(ExecutorType::x64, std::make_shared()) + OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared()) + OV_CPU_INSTANCE_COMMON(ExecutorType::Common, std::make_shared()) + }; + + return descs; +} + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/reduce_list.hpp b/src/plugins/intel_cpu/src/nodes/executors/reduce_list.hpp new file mode 100644 index 00000000000000..af3b55d712161e --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/reduce_list.hpp @@ -0,0 +1,86 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "executor.hpp" + +#include "reduce.hpp" +#if defined(OV_CPU_WITH_ACL) +#include "acl/acl_reduce.hpp" +#endif +#include "common/ref_reduce.hpp" + +#include "onednn/iml_type_mapper.h" +#include "common/primitive_cache.hpp" + +namespace ov { +namespace intel_cpu { + +struct ReduceExecutorDesc { + ExecutorType executorType; + ReduceExecutorBuilderCPtr builder; +}; + +const std::vector& getReduceExecutorsList(); + +class ReduceExecutorFactory : public ExecutorFactory { +public: + ReduceExecutorFactory(const ReduceAttrs& reduceAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const ExecutorContext::CPtr context) : ExecutorFactory(context) { + for (auto& desc : getReduceExecutorsList()) { + if (desc.builder->isSupported(reduceAttrs, srcDescs, dstDescs)) { + supportedDescs.push_back(desc); + } + } + } + + ~ReduceExecutorFactory() = default; + virtual ReduceExecutorPtr makeExecutor(const ReduceAttrs& reduceAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr &attr) { + auto build = [&](const ReduceExecutorDesc* desc) { + switch (desc->executorType) { + default: { + auto executor = desc->builder->makeExecutor(context); + if (executor->init(reduceAttrs, srcDescs, dstDescs, attr)) { + return executor; + } + } break; + } + + ReduceExecutorPtr ptr = nullptr; + return ptr; + }; + + + if (chosenDesc) { + if (auto executor = build(chosenDesc)) { + return executor; + } + } + + for (const auto& sd : supportedDescs) { + if (auto executor = build(&sd)) { + chosenDesc = &sd; + return executor; + } + } + + IE_THROW() << "Supported executor is not found"; + } + +private: + std::vector supportedDescs; + const ReduceExecutorDesc* chosenDesc = nullptr; +}; + +using ReduceExecutorFactoryPtr = std::shared_ptr; +using ReduceExecutorFactoryCPtr = std::shared_ptr; + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/reduce.cpp b/src/plugins/intel_cpu/src/nodes/reduce.cpp index e6130922691bc5..2bd46c4092c3e1 100644 --- a/src/plugins/intel_cpu/src/nodes/reduce.cpp +++ b/src/plugins/intel_cpu/src/nodes/reduce.cpp @@ -107,6 +107,7 @@ static inline bool isFloatCompatible(memory::data_type type) { return memory::data_type::f32 == type || memory::data_type::bf16 == type; } +#if defined(OPENVINO_ARCH_X86_64) template struct jit_uni_reduce_kernel_f32 : public jit_uni_reduce_kernel, public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reduce_kernel_f32) @@ -1672,6 +1673,7 @@ struct 
jit_uni_reduce_post_kernel_f32 : public jit_uni_reduce_post_kernel, publi } } }; +#endif const std::map&, Reduce&)>> Reduce::initializers = { {ngraph::opset4::ReduceL1::get_type_info_static(), [](const std::shared_ptr& op, Reduce& node) { @@ -1759,6 +1761,9 @@ Reduce::Reduce(const std::shared_ptr& op, const GraphContext::CPtr } vec_reduceDH_prc.clear(); setJITBeyond5D(); + reduceAttrs.operation = algorithm; + reduceAttrs.axes = raw_axes; + reduceAttrs.keepDims = keep_dims; } else { IE_THROW(NotImplemented) << errorMessage; } @@ -1840,7 +1845,22 @@ void Reduce::initSupportedPrimitiveDescriptors() { config.inConfs[REDUCE_INDEXES].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(InferenceEngine::Precision::I32, getInputShapeAtPort(REDUCE_INDEXES))); config.outConfs[0].setMemDesc(creatorsMap.at(outFormat)->createSharedDesc(outPrecision, getOutputShapeAtPort(0))); + #if defined(OPENVINO_ARCH_X86_64) supportedPrimitiveDescriptors.push_back({config, impl_type}); + #else + std::vector srcMemoryDescs; + for (int i = 0; i < config.inConfs.size(); i++) { + srcMemoryDescs.push_back(config.inConfs[i].getMemDesc()); + } + std::vector dstMemoryDescs; + for (int i = 0; i < config.outConfs.size(); i++) { + dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()); + } + + auto factory = std::make_shared(reduceAttrs, srcMemoryDescs, dstMemoryDescs, + std::make_shared(context, getPrimitivesPriority())); + supportedPrimitiveDescriptors.push_back({config, impl_type, factory}); + #endif }; if (jit_mode) { @@ -1873,7 +1893,7 @@ void Reduce::initSupportedPrimitiveDescriptors() { } } } else { - pushDesc(LayoutType::ncsp, LayoutType::ncsp, InferenceEngine::Precision::FP32, InferenceEngine::Precision::FP32, impl_desc_type::ref); + pushDesc(LayoutType::ncsp, LayoutType::ncsp, InferenceEngine::Precision::FP32, InferenceEngine::Precision::FP32, impl_desc_type::undef); } } @@ -1882,6 +1902,9 @@ bool Reduce::isExecutable() const { } void Reduce::prepareParams() { + auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr(); + const SizeVector &dst_dims = dstMemPtr->getDesc().getShape().getDims(); +#if defined(OPENVINO_ARCH_X86_64) src_dims = getParentEdgesAtPort(REDUCE_DATA)[0]->getMemory().getDesc().getShape().getDims(); std::vector reduce_axes; if (jit_mode && jit_beyond_5D) { @@ -1890,8 +1913,6 @@ void Reduce::prepareParams() { reduce_axes = raw_axes; } - auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr(); - const SizeVector &dst_dims = dstMemPtr->getDesc().getShape().getDims(); dst_size = dstMemPtr->GetSize(); calc_process_dst_dims(reduce_axes, dst_dims); if (jit_mode) { @@ -1931,6 +1952,21 @@ void Reduce::prepareParams() { compile_post_kernel = false; } } +#endif + std::vector srcMemoryDescs; + for (int i = 0; i < getOriginalInputsNumber(); i++) { + srcMemoryDescs.push_back(getParentEdgeAt(i)->getMemoryPtr()->getDescPtr()); + } + std::vector dstMemoryDescs; + for (int i = 0; i < getOriginalOutputsNumber(); i++) { + dstMemoryDescs.push_back(getChildEdgeAt(i)->getMemoryPtr()->getDescPtr()); + } + dnnl::primitive_attr attr; + setPostOps(attr, dst_dims, true); + auto selectedPD = getSelectedPrimitiveDescriptor(); + + execPtr = selectedPD->getExecutorFactoryAs()->makeExecutor(reduceAttrs, srcMemoryDescs, dstMemoryDescs, attr); + selectedPD->setImplementationType(execPtr->getImplType()); } void Reduce::createPrimitive() { @@ -1982,7 +2018,7 @@ void Reduce::createPrimitive() { prepareParams(); updateLastInputDims(); } - +#if defined(OPENVINO_ARCH_X86_64) if (mayiuse(cpu::x64::avx512_core)) { reduce_kernel.reset(new 
jit_uni_reduce_kernel_f32(jcp)); } else if (mayiuse(cpu::x64::avx2)) { @@ -1990,6 +2026,7 @@ void Reduce::createPrimitive() { } else if (mayiuse(cpu::x64::sse41)) { reduce_kernel.reset(new jit_uni_reduce_kernel_f32(jcp)); } +#endif if (reduce_kernel) reduce_kernel->create_ker(); jit_mode = jit_mode && reduce_kernel; @@ -2000,6 +2037,7 @@ void Reduce::executeDynamicImpl(dnnl::stream strm) { } void Reduce::execute(dnnl::stream strm) { +#if defined(OPENVINO_ARCH_X86_64) auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr(); auto &srcMemPtr = getParentEdgeAt(REDUCE_DATA)->getMemoryPtr(); @@ -2020,6 +2058,22 @@ void Reduce::execute(dnnl::stream strm) { IE_THROW() << errorPrefix << " supports only plain layout on machine w/o sse42."; } } +#else + if (!execPtr) { + IE_THROW() << "Can't execute Reduce node. Executor is not created"; + } + + std::vector srcMemory; + for (int i = 0; i < getOriginalInputsNumber(); i++) { + srcMemory.push_back(getParentEdgeAt(i)->getMemoryPtr()); + } + std::vector dstMemory; + for (int i = 0; i < getOriginalOutputsNumber(); i++) { + dstMemory.push_back(getChildEdgeAt(i)->getMemoryPtr()); + } + + execPtr->exec(srcMemory, dstMemory, postOpsArgs); +#endif } void Reduce::reduce_type(const uint8_t *in_ptr, uint8_t *out_ptr, size_t dst_size) { diff --git a/src/plugins/intel_cpu/src/nodes/reduce.h b/src/plugins/intel_cpu/src/nodes/reduce.h index b4021d79b9f37d..fbfb85c63c2a9b 100644 --- a/src/plugins/intel_cpu/src/nodes/reduce.h +++ b/src/plugins/intel_cpu/src/nodes/reduce.h @@ -10,6 +10,8 @@ #include #include +#include "executors/reduce_list.hpp" + namespace ov { namespace intel_cpu { namespace node { @@ -166,6 +168,10 @@ class Reduce : public Node { static const std::map& op, Reduce& node)>> initializers; std::string errorPrefix; + + ReduceAttrs reduceAttrs; + + std::shared_ptr execPtr = nullptr; }; } // namespace node diff --git a/src/plugins/intel_cpu/src/nodes_factory.cpp b/src/plugins/intel_cpu/src/nodes_factory.cpp index e3ff96016e4cc0..8274891ecd6c10 100644 --- a/src/plugins/intel_cpu/src/nodes_factory.cpp +++ b/src/plugins/intel_cpu/src/nodes_factory.cpp @@ -172,6 +172,7 @@ Node::NodesFactory::NodesFactory() INTEL_CPU_NODE(PriorBoxClustered, Type::PriorBoxClustered); INTEL_CPU_NODE(Eye, Type::Eye); INTEL_CPU_NODE(Unique, Type::Unique); + INTEL_CPU_NODE(Reduce, Type::Reduce); #if defined(OPENVINO_ARCH_X86_64) INTEL_CPU_NODE(Gather, Type::Gather); INTEL_CPU_NODE(GridSample, Type::GridSample); @@ -182,7 +183,6 @@ Node::NodesFactory::NodesFactory() INTEL_CPU_NODE(ColorConvert, Type::ColorConvert); INTEL_CPU_NODE(NormalizeL2, Type::NormalizeL2); INTEL_CPU_NODE(BinaryConvolution, Type::BinaryConvolution); - INTEL_CPU_NODE(Reduce, Type::Reduce); INTEL_CPU_NODE(NonMaxSuppression, Type::NonMaxSuppression); INTEL_CPU_NODE(Interpolate, Type::Interpolate); INTEL_CPU_NODE(ROIPooling, Type::ROIPooling); diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.cpp new file mode 100644 index 00000000000000..e5bdf01d472e6b --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.cpp @@ -0,0 +1,77 @@ +// Copyright (C) 2020-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + + +#include "convert_reduce_multi_axis.hpp" + +#include + +#include +#include + +template +ngraph::matcher_pass_callback ov::intel_cpu::ConvertReduceMultiAxisBase::convert_reduce() { + return 
[&](ngraph::pattern::Matcher& m) { + auto reduce = m.get_match_root(); + if (!std::dynamic_pointer_cast(reduce)) { + return false; + } + if (ngraph::shape_size(reduce->input_value(1).get_shape()) <= 1) { + return false; + } + auto reduction_axes = std::dynamic_pointer_cast(reduce->input_value(1).get_node_shared_ptr()); + if (!reduction_axes) { + return false; + } + auto axes = reduction_axes->cast_vector(); + ngraph::NodeVector new_ops; + std::shared_ptr node = reduce->input_value(0).get_node_shared_ptr(); + for (auto axis : axes) { + auto reduction_axis = ov::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{}, {axis}); + node = std::make_shared(node, reduction_axis, true); + new_ops.push_back(node); + } + + auto out_shape = reduce->get_output_shape(0); + auto dst_shape = std::make_shared(ngraph::element::i64, ngraph::Shape{out_shape.size()}, + std::vector(out_shape.begin(), out_shape.end())); + auto reshape = std::make_shared(node, dst_shape, true); + + reshape->set_friendly_name(reduce->get_friendly_name()); + ngraph::copy_runtime_info(reduce, new_ops); + ngraph::replace_node(reduce, reshape); + return true; + }; +} + +ov::intel_cpu::ConvertReduceProd::ConvertReduceProd() { + auto m = std::make_shared( + ngraph::pattern::wrap_type({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), + ngraph::pattern::wrap_type()}, + ngraph::pattern::has_static_shape()), "ConvertReduceProd"); + register_matcher(m, convert_reduce()); +} + +ov::intel_cpu::ConvertReduceMin::ConvertReduceMin() { + auto m = std::make_shared( + ngraph::pattern::wrap_type({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), + ngraph::pattern::wrap_type()}, + ngraph::pattern::has_static_shape()), "ConvertReduceMin"); + register_matcher(m, convert_reduce()); +} + +ov::intel_cpu::ConvertReduceMax::ConvertReduceMax() { + auto m = std::make_shared( + ngraph::pattern::wrap_type({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), + ngraph::pattern::wrap_type()}, + ngraph::pattern::has_static_shape()), "ConvertReduceMax"); + register_matcher(m, convert_reduce()); +} + +ov::intel_cpu::ConvertReduceSum::ConvertReduceSum() { + auto m = std::make_shared( + ngraph::pattern::wrap_type({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), + ngraph::pattern::wrap_type()}, + ngraph::pattern::has_static_shape()), "ConvertReduceSum"); + register_matcher(m, convert_reduce()); +} diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.hpp new file mode 100644 index 00000000000000..1b2087e85945cf --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.hpp @@ -0,0 +1,55 @@ +// Copyright (C) 2020-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +namespace ov { +namespace intel_cpu { + +class ConvertReduceMultiAxisBase: public ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("ConvertReduceMultiAxisBase", "0"); + template + ngraph::matcher_pass_callback convert_reduce(); +}; + +class ConvertReduceProd: public ConvertReduceMultiAxisBase { +public: + OPENVINO_RTTI("ConvertReduceProd", "0"); + ConvertReduceProd(); +}; + +class ConvertReduceMin: public ConvertReduceMultiAxisBase { +public: + OPENVINO_RTTI("ConvertReduceMin", "0"); + ConvertReduceMin(); +}; + +class ConvertReduceMax: public ConvertReduceMultiAxisBase { +public: + 
OPENVINO_RTTI("ConvertReduceMax", "0"); + ConvertReduceMax(); +}; + +class ConvertReduceSum: public ConvertReduceMultiAxisBase { +public: + OPENVINO_RTTI("ConvertReduceSum", "0"); + ConvertReduceSum(); +}; + +class ConvertReduceMultiAxis: public ngraph::pass::GraphRewrite { +public: + OPENVINO_RTTI("ConvertReduceMultiAxis", "0"); + ConvertReduceMultiAxis() { + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); + } +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index c2efe7c091bdd0..e2ea4d03b3d32c 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -91,6 +91,7 @@ #include "transformations/snippets/x64/pass/snippets_mark_skipped.hpp" #include "transformations/cpu_opset/x64/pass/mha_fusion.hpp" #include "transformations/cpu_opset/x64/pass/convert_to_interaction.hpp" +#include "transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.hpp" #include "transformations/cpu_opset/common/pass/convert_fq_rnn_to_quantized_rnn.hpp" #include "transformations/cpu_opset/common/pass/move_eltwise_up_data_movement.hpp" #include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp" @@ -239,6 +240,7 @@ void Transformations::PreLpt(const std::vector& defaultPrecis CPU_REGISTER_PASS_COMMON(manager, SwapConvertTranspose); CPU_REGISTER_PASS_X64(manager, ConvertToInteraction); CPU_REGISTER_PASS_X64(manager, ConvertInteractionInt8); + CPU_REGISTER_PASS_ARM(manager, ConvertReduceMultiAxis); // SpaceToDepth/ DepthToSpace node implementation supports only equal input/output tensors with rank <= 5 CPU_SET_CALLBACK_COMMON(manager, @@ -384,8 +386,6 @@ void Transformations::PreLpt(const std::vector& defaultPrecis CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertShuffleChannels3); CPU_DISABLE_PASS_COMMON(manager, ov::pass::Gelu7Downgrade); CPU_DISABLE_PASS_COMMON(manager, ov::pass::HSwishDecomposition); - CPU_DISABLE_PASS_COMMON(manager, ov::pass::ReduceL1Decomposition); - CPU_DISABLE_PASS_COMMON(manager, ov::pass::ReduceL2Decomposition); CPU_DISABLE_PASS_COMMON(manager, ov::pass::SoftPlusDecomposition); CPU_DISABLE_PASS_COMMON(manager, ov::pass::HSigmoidDecomposition); CPU_DISABLE_PASS_COMMON(manager, ov::pass::ConvertMod); @@ -405,6 +405,9 @@ void Transformations::PreLpt(const std::vector& defaultPrecis CPU_DISABLE_PASS_COMMON(manager, ov::pass::SoftSignDecomposition); CPU_DISABLE_PASS_COMMON(manager, ov::pass::UniqueDecomposition); + CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL1Decomposition); + CPU_DISABLE_PASS_X64(manager, ov::pass::ReduceL2Decomposition); + CPU_ENABLE_PASS_COMMON(manager, ov::pass::NormalizeL2Decomposition); CPU_ENABLE_PASS_COMMON(manager, ov::pass::ConvertInterpolate1ToInterpolate4); CPU_ENABLE_PASS_COMMON(manager, ov::pass::ConvertGather1ToGather7); diff --git a/src/plugins/intel_cpu/tests/functional/CMakeLists.txt b/src/plugins/intel_cpu/tests/functional/CMakeLists.txt index 614ae445065ebd..ce27a3bd326a36 100644 --- a/src/plugins/intel_cpu/tests/functional/CMakeLists.txt +++ b/src/plugins/intel_cpu/tests/functional/CMakeLists.txt @@ -8,7 +8,7 @@ set(TARGET_NAME ov_cpu_func_tests) # is a specific version for debugging purpose, just set DEBUG_SRC_PATH # to the test case to be debugged and debug using cpuDebugFuncTests set(DEBUG_TARGET_NAME cpuDebugFuncTests) -set(DEBUG_SRC_PATH 
${CMAKE_CURRENT_SOURCE_DIR}/single_layer_tests/mvn.cpp) +set(DEBUG_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/single_layer_tests/reduce_ops.cpp) add_library(cpuSpecificRtInfo STATIC $/src/utils/rt_info/memory_formats_attribute.hpp $/src/utils/rt_info/memory_formats_attribute.cpp) diff --git a/src/plugins/intel_cpu/tests/functional/single_layer_tests/reduce_ops.cpp b/src/plugins/intel_cpu/tests/functional/single_layer_tests/reduce_ops.cpp index 7fa7871b0243ad..7fdc22dad04993 100644 --- a/src/plugins/intel_cpu/tests/functional/single_layer_tests/reduce_ops.cpp +++ b/src/plugins/intel_cpu/tests/functional/single_layer_tests/reduce_ops.cpp @@ -295,15 +295,19 @@ std::vector> inputShapes_Int32 = { }; std::vector cpuParams_4D = { +#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) CPUSpecificParams({nChw16c}, {nChw16c}, {}, {}), - CPUSpecificParams({nchw}, {nchw}, {}, {}), CPUSpecificParams({nhwc}, {nhwc}, {}, {}) +#endif + CPUSpecificParams({nchw}, {nchw}, {}, {}), }; std::vector cpuParams_5D = { +#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) CPUSpecificParams({nCdhw16c}, {nCdhw16c}, {}, {}), - CPUSpecificParams({ncdhw}, {ncdhw}, {}, {}), CPUSpecificParams({ndhwc}, {ndhwc}, {}, {}) +#endif + CPUSpecificParams({ncdhw}, {ncdhw}, {}, {}), }; std::vector cpuParams_HybridLayout_4D = { @@ -379,6 +383,7 @@ const auto params_MultiAxis_5D = testing::Combine( testing::ValuesIn(filterCPUSpecificParams(cpuParams_5D)), testing::Values(emptyFusingSpec)); +#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) const auto params_MultiAxis_4D_Hybrid = testing::Combine( testing::Combine( testing::ValuesIn(axesND), @@ -404,6 +409,7 @@ const auto params_MultiAxis_5D_Hybrid = testing::Combine( testing::ValuesIn(inputShapes_5D)), testing::ValuesIn(filterCPUSpecificParams(cpuParams_HybridLayout_5D)), testing::Values(emptyFusingSpec)); +#endif const auto params_MultiAxis_6D = testing::Combine( testing::Combine( @@ -452,6 +458,7 @@ INSTANTIATE_TEST_SUITE_P( ReduceCPULayerTest::getTestCaseName ); +#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) INSTANTIATE_TEST_SUITE_P( smoke_Reduce_MultiAxis_4D_Hybrid_CPU, ReduceCPULayerTest, @@ -465,6 +472,7 @@ INSTANTIATE_TEST_SUITE_P( params_MultiAxis_5D_Hybrid, ReduceCPULayerTest::getTestCaseName ); +#endif INSTANTIATE_TEST_SUITE_P( smoke_Reduce_MultiAxis_6D_CPU, @@ -520,6 +528,7 @@ const auto params_MultiAxis_5D_Logical = testing::Combine( testing::ValuesIn(filterCPUSpecificParams(cpuParams_5D)), testing::Values(emptyFusingSpec)); +#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) const auto params_MultiAxis_4D_Hybrid_Logical = testing::Combine( testing::Combine( testing::ValuesIn(axesND), @@ -545,6 +554,7 @@ const auto params_MultiAxis_5D_Hybrid_Logical = testing::Combine( testing::ValuesIn(inputShapes_5D)), testing::ValuesIn(filterCPUSpecificParams(cpuParams_HybridLayout_5D)), testing::Values(emptyFusingSpec)); +#endif const auto params_MultiAxis_6D_Logical = testing::Combine( testing::Combine( @@ -580,6 +590,7 @@ INSTANTIATE_TEST_SUITE_P( ReduceCPULayerTest::getTestCaseName ); +#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) INSTANTIATE_TEST_SUITE_P( smoke_Reduce_MultiAxis_4D_Hybrid_Logical_CPU, ReduceCPULayerTest, @@ -593,6 +604,7 @@ INSTANTIATE_TEST_SUITE_P( params_MultiAxis_5D_Hybrid_Logical, ReduceCPULayerTest::getTestCaseName ); +#endif INSTANTIATE_TEST_SUITE_P( smoke_Reduce_MultiAxis_6D_Logical_CPU, @@ -602,6 +614,7 @@ INSTANTIATE_TEST_SUITE_P( ); /* ================================ 2.1 Fusion 
- KeepDims ================================ */
+#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
 const auto params_OneAxis_fusing = testing::Combine(
         testing::Combine(
                 testing::ValuesIn(axes),
@@ -722,6 +735,8 @@ INSTANTIATE_TEST_SUITE_P(
     params_MultiAxis_5D_Hybrid_fusing_KeepNoDims,
     ReduceCPULayerTest::getTestCaseName
 );
+#endif
+
 } // namespace
 } // namespace CPULayerTestsDefinitions
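
The shapeCast()/axisCast() helpers consolidated into acl_utils.hpp encode Compute Library's reversed dimension order: shapeCast() writes OpenVINO's outermost-first VectorDims into an ACL TensorShape back to front, and axisCast() mirrors an axis index accordingly. The standalone sketch below (hypothetical names, no Compute Library dependency) illustrates that mapping under those assumptions:

// Standalone sketch of the index mirroring performed by shapeCast()/axisCast()
// in acl_utils.hpp: ACL's TensorShape uses the reverse of OpenVINO's
// outermost-first dimension order, so reduction axis indices must be mirrored too.
#include <cassert>
#include <cstddef>
#include <vector>

using VectorDims = std::vector<std::size_t>;

// Reverse the dimension order (OpenVINO NCHW-style dims -> ACL index order).
VectorDims shape_cast_sketch(const VectorDims& dims) {
    VectorDims acl(dims.empty() ? 1 : dims.size(), 1);
    for (std::size_t i = 0; i < dims.size(); ++i)
        acl[dims.size() - i - 1] = dims[i];
    return acl;
}

// Mirror an axis index into the reversed dimension order.
std::size_t axis_cast_sketch(std::size_t axis, std::size_t rank) {
    return rank - axis - 1;
}

int main() {
    const VectorDims nchw = {2, 3, 4, 5};            // N, C, H, W
    const VectorDims acl = shape_cast_sketch(nchw);  // W, H, C, N
    assert((acl == VectorDims{5, 4, 3, 2}));
    // Reducing over the channel axis (index 1 in OpenVINO order) maps to ACL axis 2,
    // which matches what AclReduceExecutor passes to the ACL reduction function.
    assert(axis_cast_sketch(1, nchw.size()) == 2);
    return 0;
}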
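
ConvertReduceMultiAxis rewrites a multi-axis ReduceSum/ReduceProd/ReduceMin/ReduceMax into a chain of single-axis reductions followed by a Reshape, because AclReduceExecutor only accepts a single reduction axis for those modes (reduceAttrs.axes.size() != 1 is rejected). The plain C++ check below (illustrative only, no ngraph) shows why chaining single-axis sums is equivalent to one multi-axis sum:

// Illustrative check of the rewrite done by ConvertReduceMultiAxis: a
// multi-axis ReduceSum over axes {1, 2} of a {2, 3, 4} tensor equals two
// chained single-axis reductions, so the pass can split the op before it
// reaches the single-axis ACL executor.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    // Shape {2, 3, 4}, each element holds its linear index.
    std::vector<int> data(2 * 3 * 4);
    for (std::size_t i = 0; i < data.size(); ++i) data[i] = static_cast<int>(i);

    // Direct multi-axis reduction over axes {1, 2}: one value per n.
    std::vector<int> direct(2, 0);
    for (int n = 0; n < 2; ++n)
        for (int c = 0; c < 3; ++c)
            for (int w = 0; w < 4; ++w)
                direct[n] += data[(n * 3 + c) * 4 + w];

    // Chained single-axis reductions: first over axis 2, then over axis 1.
    std::vector<int> step1(2 * 3, 0);  // shape {2, 3}
    for (int n = 0; n < 2; ++n)
        for (int c = 0; c < 3; ++c)
            for (int w = 0; w < 4; ++w)
                step1[n * 3 + c] += data[(n * 3 + c) * 4 + w];
    std::vector<int> chained(2, 0);    // shape {2}
    for (int n = 0; n < 2; ++n)
        for (int c = 0; c < 3; ++c)
            chained[n] += step1[n * 3 + c];

    assert(direct == chained);
    return 0;
}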
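
RefReduceExecutor splits every reduction into an accumulation pass (reduce_ref_process, driven by an initial value and a binary functor) and an optional per-output post step (reduce_ref_map: sqrt for ReduceL2, logf for ReduceLogSum/ReduceLogSumExp, division by the reduced element count for ReduceMean). A minimal sketch of that accumulate-then-map composition, assuming a hypothetical reduce_then_map helper:

// Minimal sketch of the accumulate-then-map split used by RefReduceExecutor:
// ReduceL2 accumulates a sum of squares (reduce_ref_process step) and then
// takes a square root over the result (reduce_ref_map step).
#include <cassert>
#include <cmath>
#include <functional>
#include <vector>

float reduce_then_map(const std::vector<float>& in,
                      float init,
                      const std::function<float(float, float)>& acc,
                      const std::function<float(float)>& post) {
    float r = init;
    for (float v : in)
        r = acc(r, v);
    return post(r);
}

int main() {
    const std::vector<float> v = {3.0f, 4.0f};
    const float l2 = reduce_then_map(
        v, 0.0f,
        [](float old, float y) { return old + y * y; },  // accumulation functor
        [](float x) { return std::sqrt(x); });            // post-map step
    assert(std::fabs(l2 - 5.0f) < 1e-6f);
    return 0;
}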