From 0c598f4d91cf12915af25ca0fed7970595a095f9 Mon Sep 17 00:00:00 2001 From: Yuan Hu Date: Thu, 25 Jul 2024 12:02:18 +0800 Subject: [PATCH] [CPU] only allow per-oc or per-tensor FQ fusing into FC (#25530) ### Details: - Add a check to reject non-supported FakeQuantize from fusing into FC node, so they can run in standalone mode w/o causing exceptions when composing oneDNN postOps. - port from https://github.com/openvinotoolkit/openvino/pull/23009 - add test case ### Tickets: - *CVS-131890* --------- Signed-off-by: HU Yuan2 Co-authored-by: Li, Tingqian --- .../intel_cpu/src/nodes/fullyconnected.cpp | 21 ++++++++++ .../instances/x64/matmul.cpp | 39 +++++++++++++++++++ .../functional/utils/fusing_test_utils.cpp | 6 ++- .../functional/utils/fusing_test_utils.hpp | 21 ++++++++++ 4 files changed, 86 insertions(+), 1 deletion(-) diff --git a/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp b/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp index 76e41db1cd06c0..da3dcafa4750ef 100644 --- a/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp +++ b/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp @@ -25,6 +25,8 @@ #include "utils/debug_capabilities.h" #include "utils/general_utils.h" +#include "fake_quantize.h" + using namespace dnnl; using namespace ov::element; @@ -94,6 +96,25 @@ bool FullyConnected::canFuse(const NodePtr& node) const { #if defined(OV_CPU_WITH_SHL) return false; #endif + if (node->getType() == Type::FakeQuantize) { + auto* fq = dynamic_cast(node.get()); + if (fq->getBroadcastingPolicy() != FakeQuantize::BroadcastingPolicy::PerTensor) { + const auto& dstShape = getOutputShapeAtPort(0); + auto dataRanks = dstShape.getRank(); + // only per-OC or per-Tensor fakequantize can be postOps + if (fq->getAxis() != dataRanks - 1) { + DEBUG_LOG("reject FakeQuantize ", + fq->getName(), + "(axis=", + fq->getAxis(), + ") from fusing into ", + getName(), + " with dst shape ", + dstShape); + return false; + } + } + } return canFuseSimpleOperation(node); } diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64/matmul.cpp b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64/matmul.cpp index 3daa819cd4854d..83faa2c06ec6f6 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64/matmul.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/instances/x64/matmul.cpp @@ -1108,6 +1108,45 @@ INSTANTIATE_TEST_SUITE_P( testParamsDynamicFusingFullUndefShapes, MatMulLayerCPUTest::getTestCaseName); +class FCNotFuseFQCPUTest : public MatMulLayerCPUTest { + void SetUp() override { + MatMulLayerCPUTest::SetUp(); + expectPostOpsToBeFused = false; + } +}; + +TEST_P(FCNotFuseFQCPUTest, CompareWithRefs) { + run(); + CheckPluginRelatedResults(compiledModel, cpuNodeType); +} + +const std::vector& notFuseSmoke() { + static const std::vector params = { + {static_shapes_to_test_representation({{59, 1}, {1, 120}}), {false, true}}, + {static_shapes_to_test_representation({{59, 1}, {1, 120}}), {true, true}}, + + {static_shapes_to_test_representation({{59, 120}, {120, 1}}), {false, false}}, + {static_shapes_to_test_representation({{59, 120}, {120, 1}}), {true, true}}, + + {static_shapes_to_test_representation({{71, 128}, {128, 20}}), {true, false}}, + {static_shapes_to_test_representation({{71, 128}, {128, 20}}), {false, true}}, + }; + return params; +} + +const auto notFuseTestParamsSmoke = ::testing::Combine(::testing::Combine(::testing::ValuesIn(notFuseSmoke()), + ::testing::Values(ElementType::f32), + ::testing::Values(ElementType::undefined), + ::testing::Values(ElementType::undefined), + ::testing::Values(utils::InputLayerType::CONSTANT), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(emptyAdditionalConfig())), + ::testing::Values(MatMulNodeType::FullyConnected), + ::testing::ValuesIn({fusingFakeQuantizePerBatch, fusingFakeQuantizeFullTensor}), + ::testing::ValuesIn({CPUSpecificParams{{}, {}, {""}, "any_type"}})); + +INSTANTIATE_TEST_SUITE_P(smoke_FC, FCNotFuseFQCPUTest, notFuseTestParamsSmoke, FCNotFuseFQCPUTest::getTestCaseName); + } // namespace } // namespace MatMul } // namespace test diff --git a/src/plugins/intel_cpu/tests/functional/utils/fusing_test_utils.cpp b/src/plugins/intel_cpu/tests/functional/utils/fusing_test_utils.cpp index 39e60bdfe8a235..6f5e559201b30e 100644 --- a/src/plugins/intel_cpu/tests/functional/utils/fusing_test_utils.cpp +++ b/src/plugins/intel_cpu/tests/functional/utils/fusing_test_utils.cpp @@ -58,7 +58,11 @@ void CpuTestWithFusing::CheckFusingResults(const std::shared_ptr postOpMgrPtr; std::vector fusedOps; bool checkFusingPosition = true; + bool expectPostOpsToBeFused = true; }; static int getChannelAxis(const ov::AxisSet &axes, bool keep_dims) { @@ -304,6 +305,26 @@ const auto fusingFakeQuantizePerChannel = fusingSpecificParams{std::make_shared< return ov::test::utils::make_fake_quantize(cfg.input, localPrc, 256, newShape); }, "FakeQuantize(PerChannel)"}}), {"FakeQuantize"}}; +const auto fusingFakeQuantizePerBatch = fusingSpecificParams{std::make_shared(std::vector{ + {[](postNodeConfig& cfg){ + auto localPrc = cfg.input->get_element_type(); + const auto shape = cfg.input->get_output_partial_shape(0); + ov::Shape perBatchSize(shape.size(), 1); + perBatchSize[0] = shape[0].get_length(); + return ov::test::utils::make_fake_quantize(cfg.input, localPrc, 256, perBatchSize); + }, "FakeQuantize(PerBatch)"}}), {"FakeQuantize"}}; + +const auto fusingFakeQuantizeFullTensor = fusingSpecificParams{std::make_shared(std::vector{ + {[](postNodeConfig& cfg){ + auto localPrc = cfg.input->get_element_type(); + const auto shape = cfg.input->get_output_partial_shape(0); + ov::Shape fullTensorShape(shape.size(), 1); + for (size_t axis = 0; axis < shape.size(); axis++) { + fullTensorShape[axis] = shape[axis].get_length(); + } + return ov::test::utils::make_fake_quantize(cfg.input, localPrc, 256, fullTensorShape); + }, "FakeQuantize(FullTensor)"}}), {"FakeQuantize"}}; + const auto fusingFakeQuantizePerChannelRelu = fusingSpecificParams{std::make_shared(std::vector{ {[](postNodeConfig& cfg){ auto localPrc = cfg.input->get_element_type();