Skip to content

Commit

Permalink
[CPU] prohibit fusing if droped node contain > 1 child edges
Browse files Browse the repository at this point in the history
  • Loading branch information
mandrono committed Jul 20, 2021
1 parent 960ba48 commit d877c3f
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 8 deletions.
17 changes: 13 additions & 4 deletions inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndZeroPoints(MKLDNNGraph &graph) {
return false;

auto arg0 = parent0->getParentEdgesAtPort(1)[0]->getParent();
if (arg0->getType() == Input && arg0->isConstant()) {
if (arg0->getChildEdges().size() == 1 && arg0->getType() == Input && arg0->isConstant()) {
if (arg0->getOriginalOutputPrecisionAtPort(0) != Precision::U8)
return false;

Expand Down Expand Up @@ -864,7 +864,11 @@ void MKLDNNGraphOptimizer::FusePoolingAndFakeQuantize(MKLDNNGraph &graph) {
};

auto isSutableChildNode = [](MKLDNNNodePtr node) {
return node->getType() == FakeQuantize && node->getAlgorithm() != Algorithm::FQBinarization;
bool ret = node->getType() == FakeQuantize && node->getAlgorithm() != Algorithm::FQBinarization;
for (size_t i = 1; i < node->getParentEdges().size(); i++) {
ret &= node->getParentEdgesAtPort(i)[0]->getParent()->getChildEdges().size() == 1;
}
return ret;
};

for (int i = 0; i < graphNodes.size(); i++) {
Expand Down Expand Up @@ -1440,12 +1444,17 @@ void MKLDNNGraphOptimizer::FuseBroadcastAndEltwise(MKLDNNGraph &graph) {
std::vector<MKLDNNNodePtr>& graphNodes = graph.GetNodes();

for (auto &graphNode : graphNodes) {
if (graphNode->getType() != Generic
|| graphNode->getTypeStr() != "Broadcast"
if (graphNode->getType() != Broadcast
|| graphNode->getChildEdges().size() != 1lu
|| graphNode->getChildEdgeAt(0)->getChild()->getType() != Eltwise)
continue;

bool ret = true;
for (size_t i = 1; i < graphNode->getParentEdges().size(); i++) {
ret &= graphNode->getParentEdgesAtPort(i)[0]->getParent()->getChildEdges().size() == 1;
}
if (!ret) continue;

MKLDNNNodePtr& broadcastNode = graphNode;
MKLDNNNodePtr eltwiseNode = broadcastNode->getChildEdgeAt(0)->getChild();
eltwiseNode->inDims[broadcastNode->getChildEdgeAt(0)->getOutputNum()]
Expand Down
8 changes: 6 additions & 2 deletions inference-engine/src/mkldnn_plugin/mkldnn_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1328,7 +1328,7 @@ bool MKLDNNNode::canBePerformedAsScaleShift(const MKLDNNNode *parentNode) const
if (i == fusingPort)
continue;
auto weightShape = getParentEdgeAt(i)->getDims().ToSizeVector();
if (!isPerTensorOrPerChannelBroadcastable(dataShape, weightShape))
if (getParentEdgesAtPort(i)[0]->getParent()->getChildEdges().size() != 1 || !isPerTensorOrPerChannelBroadcastable(dataShape, weightShape))
return false;
}
return true;
Expand All @@ -1351,7 +1351,11 @@ bool MKLDNNNode::canBePerformedAsScaleShift(const MKLDNNNode *parentNode) const

bool MKLDNNNode::canFuseSimpleOperation(const MKLDNNNodePtr& node) const {
if (node->getType() == FakeQuantize) {
return node->getAlgorithm() != FQBinarization;
bool ret = node->getAlgorithm() != FQBinarization;
for (size_t i = 1; i < node->getParentEdges().size(); i++) {
ret &= node->getParentEdgesAtPort(i)[0]->getParent()->getChildEdges().size() == 1;
}
return ret;
} else if (node->getType() == Eltwise) {
return one_of(node->getAlgorithm(), EltwiseRelu, EltwiseGelu, EltwiseElu, EltwiseSigmoid, EltwiseClamp, EltwiseTanh,
EltwiseSwish, EltwiseHswish, EltwiseMish, EltwiseHsigmoid, EltwiseRoundHalfToEven,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,11 @@ bool MKLDNNBinaryConvolutionNode::canFuse(const MKLDNNNodePtr& node) const {
return false;

if (node->getType() == FakeQuantize) {
return node->getAlgorithm() == FQBinarization;
bool ret = node->getAlgorithm() == FQBinarization;
for (size_t i = 1; i < node->getParentEdges().size(); i++) {
ret &= node->getParentEdgesAtPort(i)[0]->getParent()->getChildEdges().size() == 1;
}
return ret;
} else {
return canFuseSimpleOperation(node);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
}

if (getParentEdges().size() != expectedInputEdgesNum)
IE_THROW() << "Incorrect number of input edges for layer " << getName();
IE_THROW() << "Incorrect number of input edges for layer " << getName() << ", expected: " << expectedInputEdgesNum
<< " actual: " << getParentEdges().size();
if (getChildEdges().empty())
IE_THROW() << "Incorrect number of output edges for layer " << getName();

Expand Down

0 comments on commit d877c3f

Please sign in to comment.