[CPU] MVN node migration on nGraph. (#12)
nshchego authored and mandrono committed May 3, 2021
1 parent bd78e71 commit 708e050
Showing 5 changed files with 252 additions and 214 deletions.
2 changes: 1 addition & 1 deletion inference-engine/src/mkldnn_plugin/CMakeLists.txt
@@ -41,7 +41,7 @@ set(LAYERS
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_split_node.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_tensoriterator_node.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_tile_node.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_mvn_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_mvn_node.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_normalize_node.cpp
# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_scatter_update_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_interpolate_node.cpp
133 changes: 63 additions & 70 deletions inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp
@@ -138,9 +138,9 @@ void MKLDNNGraphOptimizer::ApplyCommonGraphOptimizations(MKLDNNGraph &graph) {
// FuseFullyConnectedAndSimpleOperation(graph);
// graph.RemoveDroppedNodes();

// OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseMVNAndSimpleOperation");
// FuseMVNAndSimpleOperation(graph);
// graph.RemoveDroppedNodes();
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseMVNAndSimpleOperation");
FuseMVNAndSimpleOperation(graph);
graph.RemoveDroppedNodes();

// OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseInterpolateAndSimpleOperation");
// FuseInterpolateAndSimpleOperation(graph);
@@ -1344,73 +1344,66 @@ void MKLDNNGraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(MKLDNNG
}

void MKLDNNGraphOptimizer::FuseMVNAndSimpleOperation(MKLDNNGraph &graph) {
// auto& graphNodes = graph.GetNodes();
//
// auto isSutableParentNode = [](MKLDNNNodePtr node) {
// bool isSutableMVN = (node->getType() == MVN) && (node->inDims[0].ndims() == 4 || node->inDims[0].ndims() == 5);
//
// if (isSutableMVN) {
// auto *mvnLayer = dynamic_cast<MVNLayer *>(node->getCnnLayer().get());
// if (mvnLayer == nullptr)
// IE_THROW() << "Cannot get MVN layer " << node->getName();
//
// return node->getChildEdges().size() == 1 && mvnLayer->across_channels == 0 && mvnLayer->normalize == 1;
// } else {
// return false;
// }
// };
//
// auto isSutableChildNode = [](MKLDNNNodePtr node) {
// if (!node->getCnnLayer())
// return false;
//
// if (node->getType() == Quantize) {
// auto* quantizeNode = dynamic_cast<MKLDNNQuantizeNode*>(node.get());
// if (quantizeNode == nullptr)
// IE_THROW() << "Cannot get quantize layer " << node->getName();
// return !quantizeNode->isBinarization();
// } else if (node->getType() == Eltwise) {
// auto* eltwiseNode = dynamic_cast<MKLDNNEltwiseNode *>(node.get());
// if (eltwiseNode == nullptr)
// IE_THROW() << "Cannot get eltwise node " << node->getName();
//
// return ((eltwiseNode->getOpType() == MulAdd) ||
// (eltwiseNode->getOpType() == Prelu) ||
// eltwiseNode->getOpType() == Relu);
// }
//
// return false;
// };
//
// auto parent = graphNodes.begin();
// while (parent != graphNodes.end()) {
// auto parentNode = *parent;
// if (!isSutableParentNode(parentNode)) {
// parent++;
// continue;
// }
//
// auto childNode = parentNode->getChildEdgeAt(0)->getChild();
// if (!isSutableChildNode(childNode)) {
// parent++;
// continue;
// }
//
// parentNode->fuseWith(childNode);
//
// if (childNode->getType() == Quantize || childNode->getType() == Eltwise) {
// auto parentEdges = childNode->parentEdges;
// for (auto &parentEdge : parentEdges) {
// auto p_edge = parentEdge.lock();
// if (p_edge->getParent()->getType() == MVN)
// continue;
//
// removeEdge(graph, p_edge);
// }
// }
//
// graph.DropNode(childNode);
// }
auto& graphNodes = graph.GetNodes();

auto isSutableParentNode = [](MKLDNNNodePtr node) {
bool isSutableMVN = (node->getType() == MVN) && (node->inDims[0].ndims() == 4 || node->inDims[0].ndims() == 5);

if (isSutableMVN) {
auto mvnNode = std::dynamic_pointer_cast<MKLDNNMVNNode>(node);
if (mvnNode == nullptr)
IE_THROW() << "CPU node with name '" << node->getName() << "' is not a MVN node.";

return mvnNode->getChildEdges().size() == 1 && !mvnNode->getAcrossChannels() && mvnNode->getNormalizeVariance();
} else {
return false;
}
};

auto isSutableChildNode = [](MKLDNNNodePtr node) {
if (node->getType() == Quantize) {
auto* quantizeNode = dynamic_cast<MKLDNNQuantizeNode*>(node.get());
if (quantizeNode == nullptr)
IE_THROW() << "CPU node with name '" << node->getName() << "' is not a Quantize node.";
return !quantizeNode->isBinarization();
} else if (node->getType() == Eltwise) {
return ((node->getAlgorithm() == EltwiseMulAdd) ||
(node->getAlgorithm() == EltwisePrelu) ||
(node->getAlgorithm() == EltwiseRelu));
}

return false;
};

auto parent = graphNodes.begin();
while (parent != graphNodes.end()) {
auto parentNode = *parent;
if (!isSutableParentNode(parentNode)) {
parent++;
continue;
}

auto childNode = parentNode->getChildEdgeAt(0)->getChild();
if (!isSutableChildNode(childNode)) {
parent++;
continue;
}

parentNode->fuseWith(childNode);

if (childNode->getType() == Quantize || childNode->getType() == Eltwise) {
auto parentEdges = childNode->parentEdges;
for (auto &parentEdge : parentEdges) {
auto p_edge = parentEdge.lock();
if (p_edge->getParent()->getType() == MVN)
continue;

removeEdge(graph, p_edge);
}
}

graph.DropNode(childNode);
}
}

void MKLDNNGraphOptimizer::FuseInterpolateAndSimpleOperation(MKLDNNGraph &graph) {
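Note on the rewritten pass: the parent-side eligibility check now queries the nGraph-based MKLDNNMVNNode through its accessors instead of casting the legacy MVNLayer obtained from getCnnLayer(). A minimal sketch of that check, restated from the diff above (the surrounding graph walk, child check, and error handling are omitted; this is not a verbatim extract):

// Sketch derived from FuseMVNAndSimpleOperation above.
// Fusion is considered only for a 4D/5D MVN node with a single consumer that
// normalizes variance without across-channel reduction.
auto mvnNode = std::dynamic_pointer_cast<MKLDNNMVNNode>(node);
bool fusableParent = mvnNode != nullptr
        && (node->inDims[0].ndims() == 4 || node->inDims[0].ndims() == 5)
        && mvnNode->getChildEdges().size() == 1
        && !mvnNode->getAcrossChannels()      // per-channel normalization only
        && mvnNode->getNormalizeVariance();   // variance normalization must be enabled

The child side accepts a non-binarization Quantize node or an Eltwise node whose algorithm is EltwiseMulAdd, EltwisePrelu, or EltwiseRelu; when both checks pass, the child is merged into the MVN node via fuseWith and then dropped from the graph.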
130 changes: 84 additions & 46 deletions inference-engine/src/mkldnn_plugin/mkldnn_node.cpp
@@ -120,53 +120,91 @@ static const InferenceEngine::details::caseless_unordered_map<std::string, Type>
{ "Mod", Eltwise },
{ "Power", Eltwise },
{ "Reshape", Reshape },
{ "Tile", Tile },
{ "SimplerNMS", SimplerNMS },
{ "ROIAlign", ROIAlign },
{ "ROIPooling", ROIPooling },
{ "BatchNormalization", BatchNormalization },
{ "DepthToSpace", DepthToSpace },
{ "Flatten", Flatten },
{ "Pad", Pad },
{ "Permute", Permute },
{ "SpaceToDepth", SpaceToDepth },
{ "StridedSlice", StridedSlice },
{ "Copy", Copy },
{ "LSTMCell", RNNCell },
{ "GRUCell", RNNCell },
{ "RNNCell", RNNCell },
{ "LSTMSequence", RNNSeq },
{ "GRUSequence", RNNSeq },
{ "RNNSequence", RNNSeq },
{ "Quantize", Quantize },
{ "FakeQuantize", Quantize },
{ "BinaryConvolution", BinaryConvolution },
{ "DeformableConvolution", DeformableConvolution },
{ "TensorIterator", TensorIterator },
{ "Loop", TensorIterator },
{ "MemoryInput", MemoryInput}, // for construction from name ctor, arbitrary name is used
{ "Memory", MemoryOutput }, // for construction from layer ctor
{ "Convert", Convert },
{ "MVN", MVN},
{ "Normalize", Normalize},
{ "ScatterUpdate", ScatterUpdate},
{ "ScatterElementsUpdate", ScatterElementsUpdate},
{ "ScatterNDUpdate", ScatterNDUpdate},
{ "Interpolate", Interpolate},
{ "ReduceAnd", ReduceAnd},
{ "ReduceL1", ReduceL1},
{ "ReduceL2", ReduceL2},
{ "ReduceLogSum", ReduceLogSum},
{ "ReduceLogSumExp", ReduceLogSumExp},
{ "ReduceMax", ReduceMax},
{ "ReduceMean", ReduceMean},
{ "ReduceMin", ReduceMin},
{ "ReduceOr", ReduceOr},
{ "ReduceProd", ReduceProd},
{ "ReduceSum", ReduceSum},
{ "ReduceSumSquare", ReduceSumSquare},
{ "Erf", Eltwise },
{ "Softmax", Softmax },
{ "Reorder", Reorder },
{ "Roll", Roll },

// { "Unknown", Unknown },
// { "Input", Input },
// { "Reorder", Reorder },
// { "Convolution", Convolution },
// { "ReLU", Eltwise },
// { "GELU", Eltwise },
// { "ELU", Eltwise },
// { "Sigmoid", Eltwise },
// { "Logistic", Eltwise },
// { "TanH", Eltwise },
// { "ReLU6", Eltwise },
// { "Exp", Eltwise },
// { "Not", Eltwise },
// { "Activation", Eltwise },
// { "Clamp", Eltwise },
// { "Swish", Eltwise },
// { "HSwish", Eltwise },
// { "Mish", Eltwise },
// { "HSigmoid", Eltwise },
// { "Round", Eltwise },
// { "ScaleShift", Eltwise },
// { "PReLU", Eltwise },
// { "Norm", Lrn },
// { "LRN", Lrn },
// { "Pooling", Pooling },
// { "FullyConnected", FullyConnected },
// { "InnerProduct", FullyConnected },
// { "Gemm", Gemm },
// { "Softmax", SoftMax },
// { "SoftMax", SoftMax },
// { "Split", Split },
// { "Slice", Split },
// { "Concat", Concatenation },
// { "Deconvolution", Deconvolution },
// { "Eltwise", Eltwise },
// { "Mod", Eltwise },
// { "Power", Eltwise },
// { "Crop", Crop },
// { "Reshape", Reshape },
// { "Tile", Tile },
// { "SimplerNMS", SimplerNMS },
// { "ROIAlign", ROIAlign },
// { "ROIPooling", ROIPooling },
// { "BatchNormalization", BatchNormalization },
// { "Flatten", Flatten },
// { "Pad", Pad },
// { "Permute", Permute },
// { "Copy", Copy },
// { "LSTMCell", RNNCell },
// { "GRUCell", RNNCell },
// { "RNNCell", RNNCell },
// { "LSTMSequence", RNNSeq },
// { "GRUSequence", RNNSeq },
// { "RNNSequence", RNNSeq },
// { "Quantize", Quantize },
// { "FakeQuantize", Quantize },
// { "BinaryConvolution", BinaryConvolution },
// { "DeformableConvolution", DeformableConvolution },
// { "TensorIterator", TensorIterator },
// { "Loop", TensorIterator },
// { "MemoryInput", MemoryInput}, // for construction from name ctor, arbitrary name is used
// { "Memory", MemoryOutput }, // for construction from layer ctor
// { "Convert", Convert },
{ "MVN", MVN},
// { "Normalize", Normalize},
// { "ScatterUpdate", ScatterUpdate},
// { "ScatterElementsUpdate", ScatterElementsUpdate},
// { "ScatterNDUpdate", ScatterNDUpdate},
// { "Interpolate", Interpolate},
// { "ReduceAnd", ReduceAnd},
// { "ReduceL1", ReduceL1},
// { "ReduceL2", ReduceL2},
// { "ReduceLogSum", ReduceLogSum},
// { "ReduceLogSumExp", ReduceLogSumExp},
// { "ReduceMax", ReduceMax},
// { "ReduceMean", ReduceMean},
// { "ReduceMin", ReduceMin},
// { "ReduceOr", ReduceOr},
// { "ReduceProd", ReduceProd},
// { "ReduceSum", ReduceSum},
// { "ReduceSumSquare", ReduceSumSquare},
};

Type TypeFromName(const std::string type) {
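The table above is an InferenceEngine::details::caseless_unordered_map<std::string, Type> mapping layer type names to the plugin's internal Type enum, and it backs the TypeFromName lookup whose signature is the last line shown. The function body is outside this hunk; a minimal sketch of how such a lookup typically resolves, assuming the table is named type_to_name_tbl (an identifier chosen here for illustration only):

// Sketch only; the real body of TypeFromName is not part of the shown diff.
Type TypeFromName(const std::string type) {
    auto itType = type_to_name_tbl.find(type);   // case-insensitive lookup, table name assumed
    if (itType != type_to_name_tbl.end())
        return itType->second;
    return Unknown;                              // unmapped names fall back to Unknown
}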