From 6a3452a18f00007cc0580acfcb75ec21c46dde00 Mon Sep 17 00:00:00 2001 From: Anton Voronov Date: Thu, 24 Jun 2021 00:09:49 +0300 Subject: [PATCH] [CPU] fixed conv + dw conv fusing --- .../src/mkldnn_plugin/mkldnn_graph_optimizer.cpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp index abcb5afab82115..bcd6e2affb25e5 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp @@ -641,6 +641,9 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndDWConvolution(MKLDNNGraph &graph) { }; auto isSutableParentConvolution = [&](MKLDNNNodePtr node) { + if (node->isDropped()) + return false; + const auto conv = std::dynamic_pointer_cast(node); if (conv == nullptr) IE_THROW() << "Cannot cast to convolution node " << node->getName(); @@ -649,17 +652,26 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndDWConvolution(MKLDNNGraph &graph) { return false; const auto &strides = conv->getStride(); + const auto &paddings = conv->getPaddingL(); + const auto &inDims = node->getParentEdgeAt(0)->getDims(); + const auto &outDims = node->getChildEdgeAt(0)->getDims(); bool isSupportedParams = conv->getGroupNum() == 1 && + inDims.ndims() == 4 && + inDims[inDims.ndims() - 1] == outDims[outDims.ndims() - 1] && + inDims[inDims.ndims() - 2] == outDims[outDims.ndims() - 2] && is1x1Convolution(conv) && // TODO [oneDNN] : fusing is permitted only with 1x1 convolutions everyone_is(1, strides[strides.size() - 1], strides[strides.size() - 2]) && - !conv->canBeExecutedInInt8() && - node->getChildEdgeAt(0)->getDims().ndims() == 4; + everyone_is(0, paddings[paddings.size() - 1], paddings[paddings.size() - 2]) && + !conv->canBeExecutedInInt8(); if (!isSupportedParams) return false; return node->getChildEdges().size() == 1 && isConvolutionNode(node->getChildEdgeAt(0)->getChild()); }; auto isSutableChildConvolution = [&](const MKLDNNNodePtr &parentNode, const MKLDNNNodePtr &childNode) { + if (parentNode->isDropped() || childNode->isDropped()) + return false; + const auto convChild = std::dynamic_pointer_cast(childNode); if (convChild == nullptr) IE_THROW() << "Cannot cast to convolution node " << childNode->getName();