diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.cpp index 7f12b67d54aae4..5e25f6c270d4e1 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.cpp @@ -7,10 +7,11 @@ #include "snippets/op/subgraph.hpp" #include "snippets/utils.hpp" -#include -#include "utils/general_utils.h" -#include +#include "transformations/utils/utils.hpp" #include "transformations/utils.hpp" +#include "utils/general_utils.h" +#include "utils/cpu_utils.hpp" +#include "cpu/x64/cpu_isa_traits.hpp" #include "itt.hpp" @@ -58,6 +59,19 @@ int getNumNonConstInputs(const std::shared_ptr &node) { } return num_non_const_inputs; } +bool isFullyConnected(const std::shared_ptr& node) { + if (!ov::is_type(node)) + return false; + const auto out_activations = node->input_value(0); + const auto out_weights = node->input_value(1); + const auto rank_a = out_activations.get_partial_shape().rank(); + const auto rank_w = out_weights.get_partial_shape().rank(); + return out_weights.get_partial_shape().is_static() && + rank_a.is_static() && rank_w.is_static() && + rank_a.get_length() != 1 && rank_w.get_length() != 1 && + rank_a.get_length() <= 3 && rank_w.get_length() <= 3 && + ov::op::util::is_on_constant_path(out_weights); +} bool SupportsFusingWithConvolution_SumActivation(const std::shared_ptr &node) { // todo: Do all PReLUs are fused? Not sure about round and softRelu // EltwiseRoundHalfToEven, EltwiseRoundHalfAwayFromZero, EltwiseSoftRelu @@ -247,99 +261,97 @@ bool isSuitableChildForFusingSimple(const std::shared_ptr &node, con // Note: Fusing child is allowed to have several users, but that must be the end of the chain return SupportsFusingWithConvolution_Simple(node, channelAxis) && getNumNonConstInputs(node) == 1; } + bool isSuitableChildForFusingMatMul(const std::shared_ptr &node, const bool canMatMulBeExecutedInI8, NodeFusingType &updatedChainType, int& fusingAxis) { - int num_non_const_inputs = 0; - bool can_be_converted_to_FC = false; - ov::PartialShape bias_shape; - ov::PartialShape matmul_shape; - for (const auto &parent_out : node->input_values()) { - const auto parent = parent_out.get_node_shared_ptr(); - if (ov::op::util::is_on_constant_path(parent_out)) { - bias_shape = parent_out.get_shape(); - num_non_const_inputs++; - } else { - matmul_shape = parent_out.get_partial_shape(); - if (matmul_shape.size() == 0) - return false; - const auto& grandparents = parent->input_values(); - // first check that weights are constant and both activations and weights have static shape - if (grandparents.size() == 2 && - grandparents[1].get_partial_shape().is_static() && - (ov::op::util::is_on_constant_path(grandparents[1]))) { - auto rank_a = grandparents[0].get_partial_shape().rank().get_length(); - auto rank_w = grandparents[1].get_partial_shape().rank().get_length(); - if (rank_a != 1 && rank_w != 1 && rank_a <= 3 && rank_w <= 3) - can_be_converted_to_FC = true; + // Firsly check for Bias and DQScales fusion + const bool is_bias = ov::is_type(node); + const bool is_dq_scales = ov::is_type(node) && canMatMulBeExecutedInI8; + if (is_bias || is_dq_scales) { + for (const auto &in : node->inputs()) { + const auto& parent_out = in.get_source_output(); + const auto& parent = parent_out.get_node_shared_ptr(); + const auto& parent_pshape = parent_out.get_partial_shape(); + if (ov::is_type(parent) && parent_pshape.rank().is_static()) { + if (parent->get_output_target_inputs(0).size() > 1) + break; + const auto bias_port = 1 - in.get_index(); + const auto bias_out = node->input_value(bias_port); + if ((bias_out.get_target_inputs().size() > 1) || !ov::op::util::is_on_constant_path(bias_out)) + break; + const auto& bias_pshape = bias_out.get_partial_shape(); + if (bias_pshape.is_dynamic()) + break; + auto getNormalizedPShape = [](const ov::PartialShape &dims, size_t ndims) -> ov::PartialShape { + if (dims.size() >= ndims) + return dims; + ov::PartialShape pshape(std::vector(ndims, 1)); + std::copy(dims.rbegin(), dims.rend(), pshape.rbegin()); + return pshape; + }; + const auto bias_pshape_norm = getNormalizedPShape(bias_pshape, parent_pshape.size()); + if (fusingAxis >= static_cast(bias_pshape_norm.size()) || fusingAxis >= static_cast(parent_pshape.size()) || + bias_pshape_norm.size() != parent_pshape.size() || bias_pshape_norm.size() < 2) + break; + if (((bias_pshape_norm[fusingAxis] == parent_pshape[fusingAxis]) || (is_dq_scales && bias_pshape_norm[fusingAxis] == 1)) && + (bias_pshape_norm[fusingAxis] == static_cast(shape_size(bias_pshape_norm.get_shape())))) + return true; } } } - if (num_non_const_inputs != 1) - return false; - - // Matmul / FC bias fusion - if (ov::is_type(node) && - bias_shape.is_static() && matmul_shape.rbegin()->is_static() && - bias_shape.rbegin()->get_length() == matmul_shape.rbegin()->get_length() && - bias_shape.rbegin()->get_length() == static_cast(shape_size(bias_shape.get_shape()))) { - return true; - } - - // FuseMatMulAndSimpleOperation or FuseFullyConnectedAndSimpleOperation - // Invoke SupportsFusingWithConvolution_Simple directly instead of isSuitableChildForFusingSimple to - // eliminate getNumNonConstInputs() check - fusingAxis = can_be_converted_to_FC ? (matmul_shape.size() == 3 ? 2 : 1) : matmul_shape.size() - 1; - if (SupportsFusingWithConvolution_Simple(node, fusingAxis)) { - updatedChainType = NodeFusingType::FusedWithMisc; - return true; - } // MatMul specific checks from ::canFuse() - if (!can_be_converted_to_FC) { - // can with rank() > 2 - // Algorithm::EltwisePowerStatic is ignored + if (one_of(updatedChainType, NodeFusingType::FusedWithMatMul, NodeFusingType::FusedWithMatMulI8)) { + const auto is_binary_eltwise = + ov::is_type(node) || ov::is_type(node) || ov::is_type(node) || + ov::is_type(node) || ov::is_type(node); const auto rank = node->get_output_partial_shape(0).rank(); - if (rank.is_static() && rank.get_length() > 2) { - if (ov::is_type(node) || - ov::is_type(node) || - ov::is_type(node) || - ov::is_type(node) || - ov::is_type(node)) { - const auto const1 = ov::is_type(node->get_input_node_shared_ptr(0)); - const auto const2 = ov::is_type(node->get_input_node_shared_ptr(1)); - int constPort = -1; - if (const2) { - constPort = 1; - } else if (const1) { - constPort = 0; - } + if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) && rank.is_static() && is_binary_eltwise) { + const auto const1 = ov::is_type(node->get_input_node_shared_ptr(0)); + const auto const2 = ov::is_type(node->get_input_node_shared_ptr(1)); + int constPort = -1; + if (const2) { + constPort = 1; + } else if (const1) { + constPort = 0; + } - if (constPort != -1) { - auto const_shape = node->get_input_shape(constPort); - if (ov::shape_size(const_shape) != 1) { - return false; - } - } - } else if (ov::is_type(node)) { - const bool is_per_tensor_broadcasting = snippets::utils::is_scalar_constant(node->get_input_node_shared_ptr(1)) && - snippets::utils::is_scalar_constant(node->get_input_node_shared_ptr(2)) && - snippets::utils::is_scalar_constant(node->get_input_node_shared_ptr(3)) && - snippets::utils::is_scalar_constant(node->get_input_node_shared_ptr(4)); - if (!is_per_tensor_broadcasting) { + if (constPort != -1) { + auto const_shape = node->get_input_shape(constPort); + if (ov::shape_size(const_shape) != 1 && rank.get_length() > 4) { return false; } } } - // specific case for FQ if (ov::is_type(node)) { - if (one_of(node->get_output_element_type(0), ov::element::i8, ov::element::u8) && canMatMulBeExecutedInI8) { + if (one_of(node->get_output_element_type(0), ov::element::i8, ov::element::u8) && !canMatMulBeExecutedInI8) return false; - } } } - return true; + // FuseMatMulAndSimpleOperation or FuseFullyConnectedAndSimpleOperation + // Invoke SupportsFusingWithConvolution_Simple directly instead of isSuitableChildForFusingSimple to + // eliminate getNumNonConstInputs() check + if (SupportsFusingWithConvolution_Simple(node, fusingAxis)) { + size_t num_non_const_inputs = 0; + size_t num_mm_inputs = 0; + for (const auto &parent_out : node->input_values()) { + // To avoid endless check `is_on_constant_path` for MatMul branch + if (one_of(GetNodeFusingType(parent_out.get_node_shared_ptr()), NodeFusingType::FusedWithMatMul, NodeFusingType::FusedWithMatMulI8, + NodeFusingType::FusedWithFC, NodeFusingType::FusedWithFCI8)) + num_mm_inputs++; + else if (!ov::op::util::is_on_constant_path(parent_out)) + num_non_const_inputs++; + } + if (num_non_const_inputs + num_mm_inputs != 1) + return false; + + updatedChainType = NodeFusingType::FusedWithMisc; + return true; + } + + return false; } bool isSuitableParentForFusingSumActivation(const std::shared_ptr &node) { if (!ov::is_type(node)) @@ -508,11 +520,16 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr &m) { } SetNodeFusingType(node, NodeFusingType::FusedWithMisc); } else if (isSuitableMatMulParent(node)) { - if (canBeMatMulExecutedInInt8(node->get_input_element_type(0), node->get_input_element_type(1))) - SetNodeFusingType(node, NodeFusingType::FusedWithMatMulI8); - else - SetNodeFusingType(node, NodeFusingType::FusedWithMatMul); - channelAxis = DEFAULT_AXIS; + const bool is_fc = isFullyConnected(node); + const bool is_i8 = canBeMatMulExecutedInInt8(node->get_input_element_type(0), node->get_input_element_type(1)); + const auto out_rank = node->get_output_partial_shape(0).rank(); + if (is_fc) { + SetNodeFusingType(node, is_i8 ? NodeFusingType::FusedWithFCI8 : NodeFusingType::FusedWithFC); + channelAxis = out_rank.is_static() ? (out_rank.get_length() == 3 ? 2 : 1) : DEFAULT_AXIS; + } else { + SetNodeFusingType(node, is_i8 ? NodeFusingType::FusedWithMatMulI8 : NodeFusingType::FusedWithMatMul); + channelAxis = out_rank.is_static() ? out_rank.get_length() - 1 : DEFAULT_AXIS; + } } else if (isSuitableSubtractAsZeroPointsParent(node)) { SetSnippetsNodeType(node, snippets::pass::SnippetsNodeType::SkippedByPlugin); channelAxis = DEFAULT_AXIS; @@ -542,9 +559,9 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr &m) { // Todo: Chain could be converted from FusedWithBinaryConvolution to FusedWithConvolution at this point // Set FusedWithConvolution, so the fusing chain could be propagated PropagateIfHasOnlyChild(node, NodeFusingType::FusedWithConvolution); - } else if (fusingChainType == NodeFusingType::FusedWithMatMul || - fusingChainType == NodeFusingType::FusedWithMatMulI8) { - const bool isExecutedInINT8 = fusingChainType == NodeFusingType::FusedWithMatMulI8; + } else if (one_of(fusingChainType, NodeFusingType::FusedWithMatMul, NodeFusingType::FusedWithMatMulI8, + NodeFusingType::FusedWithFC, NodeFusingType::FusedWithFCI8)) { + const bool isExecutedInINT8 = one_of(fusingChainType, NodeFusingType::FusedWithMatMulI8, NodeFusingType::FusedWithFCI8); // Handle fusings for both MatMul and FullyConnected NodeFusingType updatedChainType = fusingChainType; if (isSuitableChildForFusingMatMul(node, isExecutedInINT8, updatedChainType, channelAxis)) diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.hpp index 5ce0b04c836255..e5e04b23a41717 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.hpp @@ -39,7 +39,7 @@ enum class NodeFusingType : int64_t { NotSet, FusedTerminator, FusedWithConvolution, FusedWithBinaryConvolution, FusedWithConvolutionSumActivation, - FusedWithMatMul, FusedWithMatMulI8, FusedWithReduce, FusedWithMisc}; + FusedWithMatMul, FusedWithFC, FusedWithMatMulI8, FusedWithFCI8, FusedWithReduce, FusedWithMisc}; } // namespace intel_cpu } // namespace ov