[Snippets][CPU] Fixed isSuitableChildForFusingMatMul (openvinotoolkit#23182)
### Details:
- *The check `isSuitableChildForFusingMatMul` first verified that the `node` has constant inputs, without checking the `node` type or who its parent is. This led to endless constant-path checks for nodes that cannot even be fused to `MatMul`. Thus, it makes no sense to check whether the parent `MatMul` branch is on a constant path.*
- *The PR refactors the check `isSuitableChildForFusingMatMul`:*
  - *Firstly, we check for a possible fusion of `MatMul (FC)` with `Bias` and `DQScales`, based on the node type;*
  - *Secondly, we add the specific checks from `MatMul::canFuse` for binary `Eltwise` and `FQ`. Moreover, this code has been updated for the first time since Snippets support was introduced!*
  - *Thirdly, we check that the node is supported for fusion with `MatMul` (it is at least an `Eltwise` op), and only after that (!) do we check the constant paths of the inputs that do not come from the `MatMul` branch, to avoid endless searches.*
- *Added an additional `NodeFusingType` for `FullyConnected` to make the code clearer: now we can initialize `channelAxis` correctly and separate the additional checks from `MatMul::canFuse` (see the small standalone sketch right after this list).*

### Tickets:
 - *CVS-134292*


### TODO:
- [x] *Performance Validation (Passed - the report is attached to the ticket)*
a-sidorova authored Mar 14, 2024
1 parent 3a3f323 commit 8ba1ae3
Showing 2 changed files with 103 additions and 86 deletions.
@@ -7,10 +7,11 @@
#include "snippets/op/subgraph.hpp"
#include "snippets/utils.hpp"

#include <transformations/utils/utils.hpp>
#include "utils/general_utils.h"
#include <utils/cpu_utils.hpp>
#include "transformations/utils/utils.hpp"
#include "transformations/utils.hpp"
#include "utils/general_utils.h"
#include "utils/cpu_utils.hpp"
#include "cpu/x64/cpu_isa_traits.hpp"

#include "itt.hpp"

@@ -58,6 +59,19 @@ int getNumNonConstInputs(const std::shared_ptr<const Node> &node) {
}
return num_non_const_inputs;
}
bool isFullyConnected(const std::shared_ptr<const ov::Node>& node) {
if (!ov::is_type<ov::op::v0::MatMul>(node))
return false;
const auto out_activations = node->input_value(0);
const auto out_weights = node->input_value(1);
const auto rank_a = out_activations.get_partial_shape().rank();
const auto rank_w = out_weights.get_partial_shape().rank();
return out_weights.get_partial_shape().is_static() &&
rank_a.is_static() && rank_w.is_static() &&
rank_a.get_length() != 1 && rank_w.get_length() != 1 &&
rank_a.get_length() <= 3 && rank_w.get_length() <= 3 &&
ov::op::util::is_on_constant_path(out_weights);
}
bool SupportsFusingWithConvolution_SumActivation(const std::shared_ptr<const Node> &node) {
// todo: Are all PReLUs fused? Not sure about round and softRelu
// EltwiseRoundHalfToEven, EltwiseRoundHalfAwayFromZero, EltwiseSoftRelu
@@ -247,99 +261,97 @@ bool isSuitableChildForFusingSimple(const std::shared_ptr<const Node> &node, con
// Note: Fusing child is allowed to have several users, but that must be the end of the chain
return SupportsFusingWithConvolution_Simple(node, channelAxis) && getNumNonConstInputs(node) == 1;
}

bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, const bool canMatMulBeExecutedInI8,
NodeFusingType &updatedChainType, int& fusingAxis) {
int num_non_const_inputs = 0;
bool can_be_converted_to_FC = false;
ov::PartialShape bias_shape;
ov::PartialShape matmul_shape;
for (const auto &parent_out : node->input_values()) {
const auto parent = parent_out.get_node_shared_ptr();
if (ov::op::util::is_on_constant_path(parent_out)) {
bias_shape = parent_out.get_shape();
num_non_const_inputs++;
} else {
matmul_shape = parent_out.get_partial_shape();
if (matmul_shape.size() == 0)
return false;
const auto& grandparents = parent->input_values();
// first check that weights are constant and both activations and weights have static shape
if (grandparents.size() == 2 &&
grandparents[1].get_partial_shape().is_static() &&
(ov::op::util::is_on_constant_path(grandparents[1]))) {
auto rank_a = grandparents[0].get_partial_shape().rank().get_length();
auto rank_w = grandparents[1].get_partial_shape().rank().get_length();
if (rank_a != 1 && rank_w != 1 && rank_a <= 3 && rank_w <= 3)
can_be_converted_to_FC = true;
// Firstly check for Bias and DQScales fusion
const bool is_bias = ov::is_type<ov::opset1::Add>(node);
const bool is_dq_scales = ov::is_type<ov::opset1::Multiply>(node) && canMatMulBeExecutedInI8;
if (is_bias || is_dq_scales) {
for (const auto &in : node->inputs()) {
const auto& parent_out = in.get_source_output();
const auto& parent = parent_out.get_node_shared_ptr();
const auto& parent_pshape = parent_out.get_partial_shape();
if (ov::is_type<ov::op::v0::MatMul>(parent) && parent_pshape.rank().is_static()) {
if (parent->get_output_target_inputs(0).size() > 1)
break;
const auto bias_port = 1 - in.get_index();
const auto bias_out = node->input_value(bias_port);
if ((bias_out.get_target_inputs().size() > 1) || !ov::op::util::is_on_constant_path(bias_out))
break;
const auto& bias_pshape = bias_out.get_partial_shape();
if (bias_pshape.is_dynamic())
break;
auto getNormalizedPShape = [](const ov::PartialShape &dims, size_t ndims) -> ov::PartialShape {
if (dims.size() >= ndims)
return dims;
ov::PartialShape pshape(std::vector<size_t>(ndims, 1));
std::copy(dims.rbegin(), dims.rend(), pshape.rbegin());
return pshape;
};
const auto bias_pshape_norm = getNormalizedPShape(bias_pshape, parent_pshape.size());
if (fusingAxis >= static_cast<int>(bias_pshape_norm.size()) || fusingAxis >= static_cast<int>(parent_pshape.size()) ||
bias_pshape_norm.size() != parent_pshape.size() || bias_pshape_norm.size() < 2)
break;
if (((bias_pshape_norm[fusingAxis] == parent_pshape[fusingAxis]) || (is_dq_scales && bias_pshape_norm[fusingAxis] == 1)) &&
(bias_pshape_norm[fusingAxis] == static_cast<int64_t>(shape_size(bias_pshape_norm.get_shape()))))
return true;
}
}
}
if (num_non_const_inputs != 1)
return false;

// Matmul / FC bias fusion
if (ov::is_type<ov::opset1::Add>(node) &&
bias_shape.is_static() && matmul_shape.rbegin()->is_static() &&
bias_shape.rbegin()->get_length() == matmul_shape.rbegin()->get_length() &&
bias_shape.rbegin()->get_length() == static_cast<int64_t>(shape_size(bias_shape.get_shape()))) {
return true;
}

// FuseMatMulAndSimpleOperation or FuseFullyConnectedAndSimpleOperation
// Invoke SupportsFusingWithConvolution_Simple directly instead of isSuitableChildForFusingSimple to
// eliminate getNumNonConstInputs() check
fusingAxis = can_be_converted_to_FC ? (matmul_shape.size() == 3 ? 2 : 1) : matmul_shape.size() - 1;
if (SupportsFusingWithConvolution_Simple(node, fusingAxis)) {
updatedChainType = NodeFusingType::FusedWithMisc;
return true;
}

// MatMul specific checks from ::canFuse()
if (!can_be_converted_to_FC) {
// can with rank() > 2
// Algorithm::EltwisePowerStatic is ignored
if (one_of(updatedChainType, NodeFusingType::FusedWithMatMul, NodeFusingType::FusedWithMatMulI8)) {
const auto is_binary_eltwise =
ov::is_type<ov::op::v1::Add>(node) || ov::is_type<ov::op::v1::Multiply>(node) || ov::is_type<ov::op::v1::Subtract>(node) ||
ov::is_type<ov::op::v1::Divide>(node) || ov::is_type<ov::op::v0::PRelu>(node);
const auto rank = node->get_output_partial_shape(0).rank();
if (rank.is_static() && rank.get_length() > 2) {
if (ov::is_type<ov::op::v1::Add>(node) ||
ov::is_type<ov::op::v1::Multiply>(node) ||
ov::is_type<ov::op::v1::Subtract>(node) ||
ov::is_type<ov::op::v1::Divide>(node) ||
ov::is_type<ov::op::v0::PRelu>(node)) {
const auto const1 = ov::is_type<ov::op::v0::Constant>(node->get_input_node_shared_ptr(0));
const auto const2 = ov::is_type<ov::op::v0::Constant>(node->get_input_node_shared_ptr(1));
int constPort = -1;
if (const2) {
constPort = 1;
} else if (const1) {
constPort = 0;
}
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) && rank.is_static() && is_binary_eltwise) {
const auto const1 = ov::is_type<ov::op::v0::Constant>(node->get_input_node_shared_ptr(0));
const auto const2 = ov::is_type<ov::op::v0::Constant>(node->get_input_node_shared_ptr(1));
int constPort = -1;
if (const2) {
constPort = 1;
} else if (const1) {
constPort = 0;
}

if (constPort != -1) {
auto const_shape = node->get_input_shape(constPort);
if (ov::shape_size(const_shape) != 1) {
return false;
}
}
} else if (ov::is_type<ov::op::v0::FakeQuantize>(node)) {
const bool is_per_tensor_broadcasting = snippets::utils::is_scalar_constant(node->get_input_node_shared_ptr(1)) &&
snippets::utils::is_scalar_constant(node->get_input_node_shared_ptr(2)) &&
snippets::utils::is_scalar_constant(node->get_input_node_shared_ptr(3)) &&
snippets::utils::is_scalar_constant(node->get_input_node_shared_ptr(4));
if (!is_per_tensor_broadcasting) {
if (constPort != -1) {
auto const_shape = node->get_input_shape(constPort);
if (ov::shape_size(const_shape) != 1 && rank.get_length() > 4) {
return false;
}
}
}

// specific case for FQ
if (ov::is_type<ov::op::v0::FakeQuantize>(node)) {
if (one_of(node->get_output_element_type(0), ov::element::i8, ov::element::u8) && canMatMulBeExecutedInI8) {
if (one_of(node->get_output_element_type(0), ov::element::i8, ov::element::u8) && !canMatMulBeExecutedInI8)
return false;
}
}
}

return true;
// FuseMatMulAndSimpleOperation or FuseFullyConnectedAndSimpleOperation
// Invoke SupportsFusingWithConvolution_Simple directly instead of isSuitableChildForFusingSimple to
// eliminate getNumNonConstInputs() check
if (SupportsFusingWithConvolution_Simple(node, fusingAxis)) {
size_t num_non_const_inputs = 0;
size_t num_mm_inputs = 0;
for (const auto &parent_out : node->input_values()) {
// To avoid an endless `is_on_constant_path` check on the MatMul branch
if (one_of(GetNodeFusingType(parent_out.get_node_shared_ptr()), NodeFusingType::FusedWithMatMul, NodeFusingType::FusedWithMatMulI8,
NodeFusingType::FusedWithFC, NodeFusingType::FusedWithFCI8))
num_mm_inputs++;
else if (!ov::op::util::is_on_constant_path(parent_out))
num_non_const_inputs++;
}
if (num_non_const_inputs + num_mm_inputs != 1)
return false;

updatedChainType = NodeFusingType::FusedWithMisc;
return true;
}

return false;
}
bool isSuitableParentForFusingSumActivation(const std::shared_ptr<const Node> &node) {
if (!ov::is_type<ov::op::v1::Add>(node))
@@ -508,11 +520,16 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
}
SetNodeFusingType(node, NodeFusingType::FusedWithMisc);
} else if (isSuitableMatMulParent(node)) {
if (canBeMatMulExecutedInInt8(node->get_input_element_type(0), node->get_input_element_type(1)))
SetNodeFusingType(node, NodeFusingType::FusedWithMatMulI8);
else
SetNodeFusingType(node, NodeFusingType::FusedWithMatMul);
channelAxis = DEFAULT_AXIS;
const bool is_fc = isFullyConnected(node);
const bool is_i8 = canBeMatMulExecutedInInt8(node->get_input_element_type(0), node->get_input_element_type(1));
const auto out_rank = node->get_output_partial_shape(0).rank();
if (is_fc) {
SetNodeFusingType(node, is_i8 ? NodeFusingType::FusedWithFCI8 : NodeFusingType::FusedWithFC);
channelAxis = out_rank.is_static() ? (out_rank.get_length() == 3 ? 2 : 1) : DEFAULT_AXIS;
} else {
SetNodeFusingType(node, is_i8 ? NodeFusingType::FusedWithMatMulI8 : NodeFusingType::FusedWithMatMul);
channelAxis = out_rank.is_static() ? out_rank.get_length() - 1 : DEFAULT_AXIS;
}
} else if (isSuitableSubtractAsZeroPointsParent(node)) {
SetSnippetsNodeType(node, snippets::pass::SnippetsNodeType::SkippedByPlugin);
channelAxis = DEFAULT_AXIS;
@@ -542,9 +559,9 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
// Todo: Chain could be converted from FusedWithBinaryConvolution to FusedWithConvolution at this point
// Set FusedWithConvolution, so the fusing chain could be propagated
PropagateIfHasOnlyChild(node, NodeFusingType::FusedWithConvolution);
} else if (fusingChainType == NodeFusingType::FusedWithMatMul ||
fusingChainType == NodeFusingType::FusedWithMatMulI8) {
const bool isExecutedInINT8 = fusingChainType == NodeFusingType::FusedWithMatMulI8;
} else if (one_of(fusingChainType, NodeFusingType::FusedWithMatMul, NodeFusingType::FusedWithMatMulI8,
NodeFusingType::FusedWithFC, NodeFusingType::FusedWithFCI8)) {
const bool isExecutedInINT8 = one_of(fusingChainType, NodeFusingType::FusedWithMatMulI8, NodeFusingType::FusedWithFCI8);
// Handle fusings for both MatMul and FullyConnected
NodeFusingType updatedChainType = fusingChainType;
if (isSuitableChildForFusingMatMul(node, isExecutedInINT8, updatedChainType, channelAxis))
@@ -39,7 +39,7 @@ enum class NodeFusingType : int64_t {
NotSet,
FusedTerminator,
FusedWithConvolution, FusedWithBinaryConvolution, FusedWithConvolutionSumActivation,
FusedWithMatMul, FusedWithMatMulI8, FusedWithReduce, FusedWithMisc};
FusedWithMatMul, FusedWithFC, FusedWithMatMulI8, FusedWithFCI8, FusedWithReduce, FusedWithMisc};

} // namespace intel_cpu
} // namespace ov
