
Commit

[CPU] only allow per-oc or per-tensor FQ fusing into FC (openvinotoolkit#25530)

### Details:
 - Add a check that rejects unsupported FakeQuantize nodes from fusing into the
   FC node, so they run in standalone mode without causing exceptions when
   composing oneDNN postOps (see the sketch below).
 - Port from openvinotoolkit#23009.
 - Add a test case.
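
For illustration only (not part of this change): a minimal standalone sketch of the fusing rule the commit enforces. `BroadcastingPolicy`, `FakeQuantizeInfo`, and `canFuseIntoFC` below are simplified hypothetical stand-ins, not the plugin's actual types; the real check lives in `FullyConnected::canFuse` in the diff that follows.

```cpp
// Sketch: an FC output of shape {N, OC} can absorb a FakeQuantize as a oneDNN
// postOp only when its parameters broadcast per-tensor (scalar) or per-OC
// (along the last axis of the FC output). Anything else runs standalone.
#include <cstddef>
#include <iostream>
#include <vector>

enum class BroadcastingPolicy { PerTensor, PerChannel };

struct FakeQuantizeInfo {
    BroadcastingPolicy policy;
    std::size_t axis;  // axis that carries the per-channel scales/shifts
};

bool canFuseIntoFC(const FakeQuantizeInfo& fq, const std::vector<std::size_t>& fcDstShape) {
    if (fq.policy == BroadcastingPolicy::PerTensor)
        return true;  // scalar parameters always compose into postOps
    // only per-OC (last axis of the FC output) FakeQuantize can become a postOp
    return fq.axis == fcDstShape.size() - 1;
}

int main() {
    const std::vector<std::size_t> dst{59, 120};  // FC output: N = 59, OC = 120
    std::cout << canFuseIntoFC({BroadcastingPolicy::PerTensor, 0}, dst) << '\n';   // 1: fused
    std::cout << canFuseIntoFC({BroadcastingPolicy::PerChannel, 1}, dst) << '\n';  // 1: per-OC, fused
    std::cout << canFuseIntoFC({BroadcastingPolicy::PerChannel, 0}, dst) << '\n';  // 0: per-batch, runs standalone
}
```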

### Tickets:
 - *CVS-131890*

---------

Signed-off-by: HU Yuan2 <[email protected]>
Co-authored-by: Li, Tingqian <[email protected]>
tiger100256-hu and usstq authored Jul 25, 2024
1 parent b9d98cb commit 0c598f4
Showing 4 changed files with 86 additions and 1 deletion.
21 changes: 21 additions & 0 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
@@ -25,6 +25,8 @@
#include "utils/debug_capabilities.h"
#include "utils/general_utils.h"

#include "fake_quantize.h"

using namespace dnnl;
using namespace ov::element;

@@ -94,6 +96,25 @@ bool FullyConnected::canFuse(const NodePtr& node) const {
#if defined(OV_CPU_WITH_SHL)
    return false;
#endif
    if (node->getType() == Type::FakeQuantize) {
        auto* fq = dynamic_cast<FakeQuantize*>(node.get());
        if (fq->getBroadcastingPolicy() != FakeQuantize::BroadcastingPolicy::PerTensor) {
            const auto& dstShape = getOutputShapeAtPort(0);
            auto dataRanks = dstShape.getRank();
            // only per-OC or per-Tensor fakequantize can be postOps
            if (fq->getAxis() != dataRanks - 1) {
                DEBUG_LOG("reject FakeQuantize ",
                          fq->getName(),
                          "(axis=",
                          fq->getAxis(),
                          ") from fusing into ",
                          getName(),
                          " with dst shape ",
                          dstShape);
                return false;
            }
        }
    }
    return canFuseSimpleOperation(node);
}

@@ -1108,6 +1108,45 @@ INSTANTIATE_TEST_SUITE_P(
testParamsDynamicFusingFullUndefShapes,
MatMulLayerCPUTest::getTestCaseName);

class FCNotFuseFQCPUTest : public MatMulLayerCPUTest {
    void SetUp() override {
        MatMulLayerCPUTest::SetUp();
        expectPostOpsToBeFused = false;
    }
};

TEST_P(FCNotFuseFQCPUTest, CompareWithRefs) {
    run();
    CheckPluginRelatedResults(compiledModel, cpuNodeType);
}

const std::vector<ShapeRelatedParams>& notFuseSmoke() {
    static const std::vector<ShapeRelatedParams> params = {
        {static_shapes_to_test_representation({{59, 1}, {1, 120}}), {false, true}},
        {static_shapes_to_test_representation({{59, 1}, {1, 120}}), {true, true}},

        {static_shapes_to_test_representation({{59, 120}, {120, 1}}), {false, false}},
        {static_shapes_to_test_representation({{59, 120}, {120, 1}}), {true, true}},

        {static_shapes_to_test_representation({{71, 128}, {128, 20}}), {true, false}},
        {static_shapes_to_test_representation({{71, 128}, {128, 20}}), {false, true}},
    };
    return params;
}

const auto notFuseTestParamsSmoke = ::testing::Combine(::testing::Combine(::testing::ValuesIn(notFuseSmoke()),
                                                                          ::testing::Values(ElementType::f32),
                                                                          ::testing::Values(ElementType::undefined),
                                                                          ::testing::Values(ElementType::undefined),
                                                                          ::testing::Values(utils::InputLayerType::CONSTANT),
                                                                          ::testing::Values(ov::test::utils::DEVICE_CPU),
                                                                          ::testing::Values(emptyAdditionalConfig())),
                                                       ::testing::Values(MatMulNodeType::FullyConnected),
                                                       ::testing::ValuesIn({fusingFakeQuantizePerBatch, fusingFakeQuantizeFullTensor}),
                                                       ::testing::ValuesIn({CPUSpecificParams{{}, {}, {""}, "any_type"}}));

INSTANTIATE_TEST_SUITE_P(smoke_FC, FCNotFuseFQCPUTest, notFuseTestParamsSmoke, FCNotFuseFQCPUTest::getTestCaseName);

}  // namespace
}  // namespace MatMul
}  // namespace test
@@ -58,7 +58,11 @@ void CpuTestWithFusing::CheckFusingResults(const std::shared_ptr<const ov::Model
        size_t pos = 0;
        for (const auto& fusedOp : fusedOps) {
            pos = originalLayersNames.find(fusedOp, checkFusingPosition ? pos : 0);
            ASSERT_TRUE(pos != std::string::npos) << "Fused op " << fusedOp << " has not been found!";
            if (expectPostOpsToBeFused) {
                ASSERT_TRUE(pos != std::string::npos) << "Fused op " << fusedOp << " has not been found!";
            } else {
                ASSERT_TRUE(pos == std::string::npos) << "op" << fusedOp << " should not be fused!";
            }
        }
    }
}
21 changes: 21 additions & 0 deletions src/plugins/intel_cpu/tests/functional/utils/fusing_test_utils.hpp
@@ -93,6 +93,7 @@ class CpuTestWithFusing : public CPUTestsBase {
    std::shared_ptr<postOpMgr> postOpMgrPtr;
    std::vector<std::string> fusedOps;
    bool checkFusingPosition = true;
    bool expectPostOpsToBeFused = true;
};

static int getChannelAxis(const ov::AxisSet &axes, bool keep_dims) {
@@ -304,6 +305,26 @@ const auto fusingFakeQuantizePerChannel = fusingSpecificParams{std::make_shared<
            return ov::test::utils::make_fake_quantize(cfg.input, localPrc, 256, newShape);
        }, "FakeQuantize(PerChannel)"}}), {"FakeQuantize"}};

const auto fusingFakeQuantizePerBatch = fusingSpecificParams{std::make_shared<postNodesMgr>(std::vector<postNodeBuilder>{
        {[](postNodeConfig& cfg){
            auto localPrc = cfg.input->get_element_type();
            const auto shape = cfg.input->get_output_partial_shape(0);
            ov::Shape perBatchSize(shape.size(), 1);
            perBatchSize[0] = shape[0].get_length();
            return ov::test::utils::make_fake_quantize(cfg.input, localPrc, 256, perBatchSize);
        }, "FakeQuantize(PerBatch)"}}), {"FakeQuantize"}};

const auto fusingFakeQuantizeFullTensor = fusingSpecificParams{std::make_shared<postNodesMgr>(std::vector<postNodeBuilder>{
        {[](postNodeConfig& cfg){
            auto localPrc = cfg.input->get_element_type();
            const auto shape = cfg.input->get_output_partial_shape(0);
            ov::Shape fullTensorShape(shape.size(), 1);
            for (size_t axis = 0; axis < shape.size(); axis++) {
                fullTensorShape[axis] = shape[axis].get_length();
            }
            return ov::test::utils::make_fake_quantize(cfg.input, localPrc, 256, fullTensorShape);
        }, "FakeQuantize(FullTensor)"}}), {"FakeQuantize"}};

const auto fusingFakeQuantizePerChannelRelu = fusingSpecificParams{std::make_shared<postNodesMgr>(std::vector<postNodeBuilder>{
        {[](postNodeConfig& cfg){
            auto localPrc = cfg.input->get_element_type();
