diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp index c3fd996cfe0f3a..f63681f387c834 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp @@ -158,12 +158,11 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndBias(MKLDNNGraph &graph) { }; auto isSutableChildNode = [&](MKLDNNNodePtr parentNode, MKLDNNNodePtr childNode) { - if ((parentNode->isConstant() && !childNode->isConstant()) || childNode->getAlgorithm() != EltwiseAdd || !childNode->getFusedWith().empty() || - childNode->getParentEdges().size() != 2) + if (childNode->getAlgorithm() != EltwiseAdd || !childNode->getFusedWith().empty() || childNode->getParentEdges().size() != 2) return false; auto biasNode = childNode->getParentEdgesAtPort(1)[0]->getParent(); - if (biasNode->getChildEdges().size() != 1) + if (biasNode->getType() != Input || !biasNode->isConstant() || biasNode->getChildEdges().size() != 1) return false; auto convOutDims = parentNode->getChildEdgesAtPort(0)[0]->getDims().ToSizeVector(); @@ -302,6 +301,8 @@ void MKLDNNGraphOptimizer::FuseMultiplyAndAdd(MKLDNNGraph &graph) { auto& graphNodes = graph.GetNodes(); auto isSutableSecondInput = [](MKLDNNNodePtr node, MKLDNNDims dataDims) { + if (node->getType() != Input || !node->isConstant()) + return false; auto secondInputDims = node->outDims[0]; if (secondInputDims.ndims() != dataDims.ndims() || secondInputDims.ndims() < 2) return false; @@ -326,8 +327,7 @@ void MKLDNNGraphOptimizer::FuseMultiplyAndAdd(MKLDNNGraph &graph) { }; auto isSutableChildNode = [&](MKLDNNNodePtr parentNode, MKLDNNNodePtr childNode) { - if ((parentNode->isConstant() && !childNode->isConstant()) || childNode->getAlgorithm() != EltwiseAdd || !childNode->getFusedWith().empty() || - childNode->getParentEdges().size() != 2) + if (childNode->getAlgorithm() != EltwiseAdd || !childNode->getFusedWith().empty() || childNode->getParentEdges().size() != 2) return false; return isSutableSecondInput(childNode->getParentEdgesAtPort(1)[0]->getParent(), childNode->getParentEdgesAtPort(0)[0]->getDims()); @@ -1518,9 +1518,9 @@ void MKLDNNGraphOptimizer::FusePerformedAsScaleShiftAndFakeQuantize(MKLDNNGraph auto& graphNodes = graph.GetNodes(); auto getConstPort = [](const MKLDNNNodePtr node) -> int { - if (node->getParentEdgeAt(0)->getParent()->isConstant() && node->getParentEdgeAt(0)->getParent()->getType() == Input) { + if (node->getParentEdgeAt(0)->getParent()->getType() == Input && node->getParentEdgeAt(0)->getParent()->isConstant()) { return 0; - } else if (node->getParentEdgeAt(1)->getParent()->isConstant() && node->getParentEdgeAt(1)->getParent()->getType() == Input) { + } else if (node->getParentEdgeAt(1)->getParent()->getType() == Input && node->getParentEdgeAt(1)->getParent()->isConstant()) { return 1; } else { return -1; diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp index 65c662b15c0e49..e2e2a3276b8c78 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp @@ -1296,7 +1296,7 @@ bool MKLDNNNode::canBePerformedAsScaleShift(const MKLDNNNode *parentNode) const fusingPort = i; continue; } - if (!node->isConstant() || node->getType() != Input) { + if (node->getType() != Input || !node->isConstant()) { return false; } }