Skip to content

Commit

Permalink
add transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
alvoron committed May 10, 2024
1 parent ca2738e commit 4138d73
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1983,8 +1983,8 @@ void Reduce::initSupportedPrimitiveDescriptors() {
if (axis < 0)
axis += static_cast<int>(getInputShapeAtPort(REDUCE_DATA).getRank());
}
pushDesc(LayoutType::ncsp, LayoutType::ncsp, input_prec, output_prec, impl_desc_type::undef, true);
pushDesc(LayoutType::nspc, LayoutType::nspc, input_prec, output_prec, impl_desc_type::undef, true);
pushDesc(LayoutType::ncsp, LayoutType::ncsp, input_prec, output_prec, impl_desc_type::undef, true);
canUseAclExecutor = !supportedPrimitiveDescriptors.empty();
if (canUseAclExecutor)
return;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (C) 2020-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0


#include "convert_reduce_no_keep_dims.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/opsets/opset8.hpp"

template <class T>
ov::matcher_pass_callback ov::intel_cpu::ConvertReduceNoKeepDimsBase::convert_reduce() {
return [&](ov::pass::pattern::Matcher& m) {
auto reduce = std::dynamic_pointer_cast<T>(m.get_match_root());
if (!reduce || reduce->is_dynamic() || reduce->get_keep_dims()) {
return false;
}

reduce->set_keep_dims(true);
const auto reduce_new = reduce->clone_with_new_inputs({reduce->input_value(0), reduce->input_value(1)});
std::shared_ptr<ov::Node> squeeze = std::make_shared<ov::op::v0::Squeeze>(reduce_new, reduce->input_value(1));
squeeze->set_friendly_name(reduce_new->get_friendly_name());
ov::copy_runtime_info(reduce, {reduce_new, squeeze});
ov::replace_node(reduce, squeeze);

return true;
};
}

ov::intel_cpu::ConvertArithmeticReduction::ConvertArithmeticReduction() {
auto m = std::make_shared<ov::pass::pattern::Matcher>(
ov::pass::pattern::wrap_type<ov::op::util::ArithmeticReductionKeepDims>({ov::pass::pattern::any_input(),
ov::pass::pattern::wrap_type<ov::opset8::Constant>()}), "ConvertArithmeticReduction");
register_matcher(m, convert_reduce<ov::op::util::ArithmeticReductionKeepDims>());
}

ov::intel_cpu::ConvertLogicalReduction::ConvertLogicalReduction() {
auto m = std::make_shared<ov::pass::pattern::Matcher>(
ov::pass::pattern::wrap_type<ov::op::util::LogicalReductionKeepDims>({ov::pass::pattern::any_input(),
ov::pass::pattern::wrap_type<ov::opset8::Constant>()}), "ConvertLogicalReduction");
register_matcher(m, convert_reduce<ov::op::util::LogicalReductionKeepDims>());
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (C) 2020-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/pass/graph_rewrite.hpp"

/*
* Description:
* ConvertReduceNoKeepDimsBase detects Reduce operations with keepDims = false.
* Such Reduce operation is replaced with Reduce operation with keepDims = true and Squeeze
* which removes undesired dimensions.
*
* Before:
*
* +--------------+ +-----------------+
* | Data | | Axes tensor |
* +-----------+--+ +-+---------------+
* | |
* +---------------------------+
* | Reduce (keepDims = false) |
* +---------------------------+
*
* After:
*
* +--------------+ +-----------------+
* | Data | | Axes tensor |
* +-----------+--+ +-+------------+--+
* | | |
* +---------------------------+ |
* | Reduce (keepDims = true) | |
* +-----------------------+---+ |
* | |
* +--------v------v-+
* | Squeeze |
* +-----------------+
*
*/

namespace ov {
namespace intel_cpu {

class ConvertReduceNoKeepDimsBase: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ConvertReduceNoKeepDims", "0");
template <class T>
ov::matcher_pass_callback convert_reduce();
};

class ConvertArithmeticReduction: public ConvertReduceNoKeepDimsBase {
public:
OPENVINO_RTTI("ConvertArithmeticReduction", "0");
ConvertArithmeticReduction();
};

class ConvertLogicalReduction: public ConvertReduceNoKeepDimsBase {
public:
OPENVINO_RTTI("ConvertLogicalReduction", "0");
ConvertLogicalReduction();
};

class ConvertReduceNoKeepDims: public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("ConvertReduceNoKeepDims", "0");
ConvertReduceNoKeepDims() {
add_matcher<ConvertArithmeticReduction>();
add_matcher<ConvertLogicalReduction>();
}
};

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
#include "transformations/cpu_opset/arm/pass/convert_group_conv1d.hpp"
#include "transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.hpp"
#include "transformations/cpu_opset/arm/pass/mish_decomposition.hpp"
#include "transformations/cpu_opset/common/pass/convert_reduce_no_keep_dims.hpp"
#include "transformations/cpu_opset/common/pass/decompose_integer_divide.hpp"
#include "transformations/cpu_opset/common/pass/convert_fq_rnn_to_quantized_rnn.hpp"
#include "transformations/cpu_opset/common/pass/insert_convert_after_extension.hpp"
Expand Down Expand Up @@ -414,6 +415,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis

CPU_REGISTER_PASS_COMMON(manager, ov::pass::EliminateConvert);
CPU_REGISTER_PASS_COMMON(manager, SwapConvertTranspose);
CPU_REGISTER_PASS_COMMON(manager, ConvertReduceNoKeepDims);
CPU_REGISTER_PASS_X64(manager, ConvertToInteraction);
CPU_REGISTER_PASS_X64(manager, ConvertInteractionInt8);
CPU_REGISTER_PASS_ARM(manager, ConvertReduceMultiAxis);
Expand Down

0 comments on commit 4138d73

Please sign in to comment.