Skip to content

Commit

Permalink
[CPU] MVN node migration on nGraph. (openvinotoolkit#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
nshchego authored Feb 25, 2021
1 parent 17a78bf commit b247886
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 170 deletions.
2 changes: 1 addition & 1 deletion inference-engine/src/mkldnn_plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ set(LAYERS
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_split_node.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_tensoriterator_node.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_tile_node.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_mvn_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_mvn_node.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_normalize_node.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_scatter_update_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_interpolate_node.cpp
Expand Down
132 changes: 62 additions & 70 deletions inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,8 @@ void MKLDNNGraphOptimizer::ApplyCommonGraphOptimizations(MKLDNNGraph &graph) {
// FuseFullyConnectedAndSimpleOperation(graph);
// graph.RemoveDroppedNodes();

// TODO [NM]: transformation should be implemented w/o using of CNNLayer
// FuseMVNAndSimpleOperation(graph);
// graph.RemoveDroppedNodes();
FuseMVNAndSimpleOperation(graph);
graph.RemoveDroppedNodes();

FuseInterpolateAndSimpleOperation(graph);
graph.RemoveDroppedNodes();
Expand Down Expand Up @@ -1351,73 +1350,66 @@ void MKLDNNGraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(MKLDNNG
}

void MKLDNNGraphOptimizer::FuseMVNAndSimpleOperation(MKLDNNGraph &graph) {
// auto& graphNodes = graph.GetNodes();
//
// auto isSutableParentNode = [](MKLDNNNodePtr node) {
// bool isSutableMVN = (node->getType() == MVN) && (node->inDims[0].ndims() == 4 || node->inDims[0].ndims() == 5);
//
// if (isSutableMVN) {
// auto *mvnLayer = dynamic_cast<MVNLayer *>(node->getCnnLayer().get());
// if (mvnLayer == nullptr)
// THROW_IE_EXCEPTION << "Cannot get MVN layer " << node->getName();
//
// return node->getChildEdges().size() == 1 && mvnLayer->across_channels == 0 && mvnLayer->normalize == 1;
// } else {
// return false;
// }
// };
//
// auto isSutableChildNode = [](MKLDNNNodePtr node) {
// if (!node->getCnnLayer())
// return false;
//
// if (node->getType() == Quantize) {
// auto* quantizeNode = dynamic_cast<MKLDNNQuantizeNode*>(node.get());
// if (quantizeNode == nullptr)
// THROW_IE_EXCEPTION << "Cannot get quantize layer " << node->getName();
// return !quantizeNode->isBinarization();
// } else if (node->getType() == Eltwise) {
// auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(node.get());
// if (eltwiseNode == nullptr)
// THROW_IE_EXCEPTION << "Cannot get eltwise node " << node->getName();
//
// return ((eltwiseNode->getOpType() == MulAdd) ||
// (eltwiseNode->getOpType() == Prelu) ||
// eltwiseNode->getOpType() == Relu);
// }
//
// return false;
// };
//
// auto parent = graphNodes.begin();
// while (parent != graphNodes.end()) {
// auto parentNode = *parent;
// if (!isSutableParentNode(parentNode)) {
// parent++;
// continue;
// }
//
// auto childNode = parentNode->getChildEdgeAt(0)->getChild();
// if (!isSutableChildNode(childNode)) {
// parent++;
// continue;
// }
//
// parentNode->fuseWith(childNode);
//
// if (childNode->getType() == Quantize || childNode->getType() == Eltwise) {
// auto parentEdges = childNode->parentEdges;
// for (auto &parentEdge : parentEdges) {
// auto p_edge = parentEdge.lock();
// if (p_edge->getParent()->getType() == MVN)
// continue;
//
// removeEdge(graph, p_edge);
// }
// }
//
// graph.DropNode(childNode);
// }
auto& graphNodes = graph.GetNodes();

auto isSutableParentNode = [](MKLDNNNodePtr node) {
bool isSutableMVN = (node->getType() == MVN) && (node->inDims[0].ndims() == 4 || node->inDims[0].ndims() == 5);

if (isSutableMVN) {
auto mvnNode = std::dynamic_pointer_cast<MKLDNNMVNNode>(node);
if (mvnNode == nullptr)
THROW_IE_EXCEPTION << "CPU node with name '" << node->getName() << "' is not a MVN node.";

return mvnNode->getChildEdges().size() == 1 && !mvnNode->getAcrossChannels() && mvnNode->getNormalizeVariance();
} else {
return false;
}
};

auto isSutableChildNode = [](MKLDNNNodePtr node) {
if (node->getType() == Quantize) {
auto* quantizeNode = dynamic_cast<MKLDNNQuantizeNode*>(node.get());
if (quantizeNode == nullptr)
THROW_IE_EXCEPTION << "CPU node with name '" << node->getName() << "' is not a Quantize node.";
return !quantizeNode->isBinarization();
} else if (node->getType() == Eltwise) {
return ((node->getAlgorithm() == EltwiseMulAdd) ||
(node->getAlgorithm() == EltwisePrelu) ||
(node->getAlgorithm() == EltwiseRelu));
}

return false;
};

auto parent = graphNodes.begin();
while (parent != graphNodes.end()) {
auto parentNode = *parent;
if (!isSutableParentNode(parentNode)) {
parent++;
continue;
}

auto childNode = parentNode->getChildEdgeAt(0)->getChild();
if (!isSutableChildNode(childNode)) {
parent++;
continue;
}

parentNode->fuseWith(childNode);

if (childNode->getType() == Quantize || childNode->getType() == Eltwise) {
auto parentEdges = childNode->parentEdges;
for (auto &parentEdge : parentEdges) {
auto p_edge = parentEdge.lock();
if (p_edge->getParent()->getType() == MVN)
continue;

removeEdge(graph, p_edge);
}
}

graph.DropNode(childNode);
}
}

void MKLDNNGraphOptimizer::FuseInterpolateAndSimpleOperation(MKLDNNGraph &graph) {
Expand Down
2 changes: 1 addition & 1 deletion inference-engine/src/mkldnn_plugin/mkldnn_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ static const InferenceEngine::details::caseless_unordered_map<std::string, Type>
// { "MemoryInput", MemoryInput}, // for construction from name ctor, arbitrary name is used
// { "Memory", MemoryOutput }, // for construction from layer ctor
// { "Convert", Convert },
// { "MVN", MVN},
{ "MVN", MVN},
// { "Normalize", Normalize},
// { "ScatterUpdate", ScatterUpdate},
// { "ScatterElementsUpdate", ScatterElementsUpdate},
Expand Down
Loading

0 comments on commit b247886

Please sign in to comment.