Skip to content

Commit

Permalink
[CPU] prohibit fusing if dropped node contain > 1 child edges (#6705)
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxim Andronov authored Jul 30, 2021
1 parent e3e2ee4 commit 2b871eb
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 8 deletions.
17 changes: 13 additions & 4 deletions inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndZeroPoints(MKLDNNGraph &graph) {
return false;

auto arg0 = parent0->getParentEdgesAtPort(1)[0]->getParent();
if (arg0->getType() == Input && arg0->isConstant()) {
if (arg0->getChildEdges().size() == 1 && arg0->getType() == Input && arg0->isConstant()) {
if (arg0->getOriginalOutputPrecisionAtPort(0) != Precision::U8)
return false;

Expand Down Expand Up @@ -864,7 +864,11 @@ void MKLDNNGraphOptimizer::FusePoolingAndFakeQuantize(MKLDNNGraph &graph) {
};

auto isSutableChildNode = [](MKLDNNNodePtr node) {
return node->getType() == FakeQuantize && node->getAlgorithm() != Algorithm::FQBinarization;
bool ret = node->getType() == FakeQuantize && node->getAlgorithm() != Algorithm::FQBinarization;
for (size_t i = 1; i < node->getParentEdges().size(); i++) {
ret &= node->getParentEdgesAtPort(i)[0]->getParent()->getChildEdges().size() == 1;
}
return ret;
};

for (int i = 0; i < graphNodes.size(); i++) {
Expand Down Expand Up @@ -1440,12 +1444,17 @@ void MKLDNNGraphOptimizer::FuseBroadcastAndEltwise(MKLDNNGraph &graph) {
std::vector<MKLDNNNodePtr>& graphNodes = graph.GetNodes();

for (auto &graphNode : graphNodes) {
if (graphNode->getType() != Generic
|| graphNode->getTypeStr() != "Broadcast"
if (graphNode->getType() != Broadcast
|| graphNode->getChildEdges().size() != 1lu
|| graphNode->getChildEdgeAt(0)->getChild()->getType() != Eltwise)
continue;

bool ret = true;
for (size_t i = 1; i < graphNode->getParentEdges().size(); i++) {
ret &= graphNode->getParentEdgesAtPort(i)[0]->getParent()->getChildEdges().size() == 1;
}
if (!ret) continue;

MKLDNNNodePtr& broadcastNode = graphNode;
MKLDNNNodePtr eltwiseNode = broadcastNode->getChildEdgeAt(0)->getChild();
eltwiseNode->inDims[broadcastNode->getChildEdgeAt(0)->getOutputNum()]
Expand Down
8 changes: 6 additions & 2 deletions inference-engine/src/mkldnn_plugin/mkldnn_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1330,7 +1330,7 @@ bool MKLDNNNode::canBePerformedAsScaleShift(const MKLDNNNode *parentNode) const
if (i == fusingPort)
continue;
auto weightShape = getParentEdgeAt(i)->getDims().ToSizeVector();
if (!isPerTensorOrPerChannelBroadcastable(dataShape, weightShape))
if (getParentEdgesAtPort(i)[0]->getParent()->getChildEdges().size() != 1 || !isPerTensorOrPerChannelBroadcastable(dataShape, weightShape))
return false;
}
return true;
Expand All @@ -1353,7 +1353,11 @@ bool MKLDNNNode::canBePerformedAsScaleShift(const MKLDNNNode *parentNode) const

bool MKLDNNNode::canFuseSimpleOperation(const MKLDNNNodePtr& node) const {
if (node->getType() == FakeQuantize) {
return node->getAlgorithm() != FQBinarization;
bool ret = node->getAlgorithm() != FQBinarization;
for (size_t i = 1; i < node->getParentEdges().size(); i++) {
ret &= node->getParentEdgesAtPort(i)[0]->getParent()->getChildEdges().size() == 1;
}
return ret;
} else if (node->getType() == Eltwise) {
return one_of(node->getAlgorithm(), EltwiseRelu, EltwiseGelu, EltwiseElu, EltwiseSigmoid, EltwiseClamp, EltwiseTanh,
EltwiseSwish, EltwiseHswish, EltwiseMish, EltwiseHsigmoid, EltwiseRoundHalfToEven,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,11 @@ bool MKLDNNBinaryConvolutionNode::canFuse(const MKLDNNNodePtr& node) const {
return false;

if (node->getType() == FakeQuantize) {
return node->getAlgorithm() == FQBinarization;
bool ret = node->getAlgorithm() == FQBinarization;
for (size_t i = 1; i < node->getParentEdges().size(); i++) {
ret &= node->getParentEdgesAtPort(i)[0]->getParent()->getChildEdges().size() == 1;
}
return ret;
} else {
return canFuseSimpleOperation(node);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
}

if (getParentEdges().size() != expectedInputEdgesNum)
IE_THROW() << "Incorrect number of input edges for layer " << getName();
IE_THROW() << "Incorrect number of input edges for layer " << getName() << ", expected: " << expectedInputEdgesNum
<< " actual: " << getParentEdges().size();
if (getChildEdges().empty())
IE_THROW() << "Incorrect number of output edges for layer " << getName();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ngraph_functions/builders.hpp"
#include "test_utils/cpu_test_utils.hpp"

using namespace ngraph;
using ngraph::helpers::EltwiseTypes;

namespace SubgraphTestsDefinitions {

class NotFusedConvSimpleOp : public LayerTestsUtils::LayerTestsCommon {
protected:
void SetUp() override {
targetDevice = CommonTestUtils::DEVICE_CPU;

auto inputParams = builder::makeParams(element::f32, {{1, 3, 12, 9}, {1, 16, 12, 9}});
auto paramOuts = helpers::convert2OutputVector(helpers::castOps2Nodes<op::Parameter>(inputParams));

std::shared_ptr<Node> conv;
{
const std::vector<size_t> kernelSize = {3, 3};
const std::vector<size_t> strides = {1, 1};
const std::vector<ptrdiff_t> padBegin = {0, 0};
const std::vector<ptrdiff_t> padEnd = {0, 0};
const std::vector<size_t> dilation = {1, 1};
const size_t numOutChannels = 16;
const op::PadType paddingType = op::PadType::EXPLICIT;
conv = builder::makeConvolution(paramOuts[0], element::f32, kernelSize, strides, padBegin, padEnd, dilation, paddingType, numOutChannels);
}
const auto sharedNode = builder::makeConstant(element::f32, {1, 16, 1, 1}, std::vector<float>{}, true);
const auto postOpCandidate = builder::makeEltwise(conv, sharedNode, EltwiseTypes::ADD);

const auto secondConsumpt = builder::makeEltwise(paramOuts[1], sharedNode, EltwiseTypes::ADD);

NodeVector results{postOpCandidate, secondConsumpt};
function = std::make_shared<ngraph::Function>(results, inputParams, "NotFusedConvSimpleOp");
}
};

TEST_F(NotFusedConvSimpleOp, smoke_CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()

Run();
}

} // namespace SubgraphTestsDefinitions

0 comments on commit 2b871eb

Please sign in to comment.