From 0849f6cbf8a1e4858f2f5906afdf0c82d490cf66 Mon Sep 17 00:00:00 2001 From: Maxim Andronov Date: Mon, 19 Apr 2021 10:01:55 +0300 Subject: [PATCH] Performance problems fixes. Part 2 (#50) * Performance problems fixes. Part 2 * additional fixes * dw fixes * int8 pooling fusing fix * moved transformation to ngraph * [CPU] Select node migration on nGraph * [CPU] DepthToSpace nodes migration on nGraph * [CPU] SpaceToDepth nodes migration on nGraph * added check that op is supported --- .../src/mkldnn_plugin/CMakeLists.txt | 6 +- .../emitters/jit_eltwise_emitters.cpp | 14 +- .../mkldnn_plugin/mkldnn_graph_optimizer.cpp | 202 +++++++------- .../src/mkldnn_plugin/mkldnn_node.cpp | 8 +- .../convert_to_cpu_specific_opset.hpp | 4 + .../convert_to_leaky_relu.cpp | 38 +++ .../convert_to_leaky_relu.hpp | 17 ++ .../convert_to_power_static.cpp | 124 +++++++++ .../convert_to_power_static.hpp | 17 ++ .../ngraph_transformations/op/leaky_relu.cpp | 31 +++ .../ngraph_transformations/op/leaky_relu.hpp | 33 +++ .../op/power_static.cpp | 35 +++ .../op/power_static.hpp | 34 +++ .../mkldnn_plugin/nodes/depth_to_space.cpp | 71 +++-- .../src/mkldnn_plugin/nodes/list_tbl.hpp | 6 +- .../mkldnn_plugin/nodes/mkldnn_conv_node.cpp | 254 ++++++------------ .../mkldnn_plugin/nodes/mkldnn_conv_node.h | 22 +- .../nodes/mkldnn_eltwise_node.cpp | 20 +- .../mkldnn_plugin/nodes/mkldnn_eltwise_node.h | 1 + .../nodes/mkldnn_fake_quantize_node.cpp | 32 +-- .../nodes/mkldnn_fake_quantize_node.h | 3 + .../src/mkldnn_plugin/nodes/select.cpp | 147 +++++----- .../mkldnn_plugin/nodes/space_to_depth.cpp | 72 +++-- .../skip_tests_config.cpp | 3 - .../convert_to_plugin_specific_node.cpp | 120 +++++++++ ngraph/core/src/op/depth_to_space.cpp | 2 +- ngraph/core/src/op/space_to_depth.cpp | 2 +- 27 files changed, 902 insertions(+), 416 deletions(-) create mode 100644 inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_leaky_relu.cpp create mode 100644 inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_leaky_relu.hpp create mode 100644 inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_power_static.cpp create mode 100644 inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_power_static.hpp create mode 100644 inference-engine/src/mkldnn_plugin/ngraph_transformations/op/leaky_relu.cpp create mode 100644 inference-engine/src/mkldnn_plugin/ngraph_transformations/op/leaky_relu.hpp create mode 100644 inference-engine/src/mkldnn_plugin/ngraph_transformations/op/power_static.cpp create mode 100644 inference-engine/src/mkldnn_plugin/ngraph_transformations/op/power_static.hpp create mode 100644 inference-engine/tests/functional/plugin/cpu/single_layer_tests/convert_to_plugin_specific_node.cpp diff --git a/inference-engine/src/mkldnn_plugin/CMakeLists.txt b/inference-engine/src/mkldnn_plugin/CMakeLists.txt index 7df0735bd7175a..505c267aad80a8 100644 --- a/inference-engine/src/mkldnn_plugin/CMakeLists.txt +++ b/inference-engine/src/mkldnn_plugin/CMakeLists.txt @@ -54,7 +54,7 @@ set(LAYERS ${CMAKE_CURRENT_SOURCE_DIR}/nodes/ctc_greedy_decoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/ctc_greedy_decoder_seq_len.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/ctc_loss.cpp -# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/depth_to_space.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/nodes/depth_to_space.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/detectionoutput.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/detectionoutput_onnx.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/embedding_bag_offset_sum.cpp @@ -83,11 +83,11 @@ set(LAYERS 
${CMAKE_CURRENT_SOURCE_DIR}/nodes/reorg_yolo.cpp # ${CMAKE_CURRENT_SOURCE_DIR}/nodes/reverse_sequence.cpp # ${CMAKE_CURRENT_SOURCE_DIR}/nodes/roifeatureextractor_onnx.cpp -# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/select.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/nodes/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/nodes/shuffle_channels.cpp # ${CMAKE_CURRENT_SOURCE_DIR}/nodes/simplernms.cpp # ${CMAKE_CURRENT_SOURCE_DIR}/nodes/space_to_batch.cpp -# ${CMAKE_CURRENT_SOURCE_DIR}/nodes/space_to_depth.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/nodes/space_to_depth.cpp # ${CMAKE_CURRENT_SOURCE_DIR}/nodes/sparse_fill_empty_rows.cpp # ${CMAKE_CURRENT_SOURCE_DIR}/nodes/sparse_segment_reduce.cpp # ${CMAKE_CURRENT_SOURCE_DIR}/nodes/sparse_weighted_reduce.cpp diff --git a/inference-engine/src/mkldnn_plugin/emitters/jit_eltwise_emitters.cpp b/inference-engine/src/mkldnn_plugin/emitters/jit_eltwise_emitters.cpp index 6d87d548788a25..29c17d3f172ba1 100644 --- a/inference-engine/src/mkldnn_plugin/emitters/jit_eltwise_emitters.cpp +++ b/inference-engine/src/mkldnn_plugin/emitters/jit_eltwise_emitters.cpp @@ -5,6 +5,7 @@ #include "jit_eltwise_emitters.hpp" #include #include +#include using namespace InferenceEngine; using namespace mkldnn::impl::utils; @@ -1303,13 +1304,16 @@ jit_power_static_emitter::jit_power_static_emitter(jit_generator *host, cpu_isa_ prepare_table(); } + jit_power_static_emitter::jit_power_static_emitter(jit_generator *host, cpu_isa_t host_isa, const MKLDNNNode* node, Precision exec_prc) : jit_emitter(host, host_isa, node, exec_prc) { - IE_THROW() << "[NM] Not implemented"; - -// power = powerLayer->power; -// scale = powerLayer->scale; -// shift = powerLayer->offset; + const MKLDNNEltwiseNode *powerNode = dynamic_cast(node); + if (powerNode == nullptr) { + IE_THROW() << "Can't cast to MKLDNNEltwiseNode"; + } + power = powerNode->getAlpha(); + scale = powerNode->getBeta(); + shift = powerNode->getGamma(); prepare_table(); } diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp index fa3d028ae003fd..92a466ffeb5bed 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp @@ -93,9 +93,8 @@ void MKLDNNGraphOptimizer::ApplyCommonGraphOptimizations(MKLDNNGraph &graph) { graph.SortTopologically(); graph.RemoveDroppedEdges(); -// TODO [NM]: transformation should be implemented w/o using of CNNLayer -// FuseConvolutionAndDWConvolution(graph); -// graph.RemoveDroppedNodes(); + FuseConvolutionAndDWConvolution(graph); + graph.RemoveDroppedNodes(); FuseBinaryConvolutionAndFakeQuantize(graph); graph.RemoveDroppedNodes(); @@ -777,134 +776,120 @@ void MKLDNNGraphOptimizer::FuseConvolutionAndDepthwise(MKLDNNGraph &graph) { } void MKLDNNGraphOptimizer::FuseConvolutionAndDWConvolution(MKLDNNGraph &graph) { - // auto& graphNodes = graph.GetNodes(); - - // auto isConvolutionNode = [](MKLDNNNodePtr node) { - // return node->getType() == Convolution; - // }; - - // auto is1x1Convolution = [](ConvolutionLayer* layer) { - // return layer->_kernel[X_AXIS] == 1 && layer->_kernel[Y_AXIS] == 1; - // }; - - // auto isSutableParentConvolution = [&](MKLDNNNodePtr node) { - // auto *layer = dynamic_cast(node->getCnnLayer().get()); - // if (layer == nullptr) - // IE_THROW() << "Cannot get convolution layer " << node->getName(); - - // auto* parentConvolutionNode = dynamic_cast(node.get()); - // if (parentConvolutionNode == nullptr) - // IE_THROW() << "Cannot get convolution node " 
<< node->getName(); - - // if (!parentConvolutionNode->weightsZeroPoints.empty()) - // return false; - - // // TODO [oneDNN]: is it still valide constrain on conv to fuse in? - // bool isSupportedParams = layer->_group == 1 && - // is1x1Convolution(layer) && // TODO [oneDNN] : fusing is permitted only with 1x1 convolutions - // everyone_is(1, layer->_stride[X_AXIS], layer->_stride[Y_AXIS]) && - // everyone_is(Precision::FP32, layer->insData[0].lock()->getPrecision(), layer->outData[0].get()->getPrecision()) && - // node->getChildEdgeAt(0)->getDims().ndims() == 4; - // if (!isSupportedParams) return false; - - // return node->getChildEdges().size() == 1 && isConvolutionNode(node->getChildEdgeAt(0)->getChild()); - // }; + auto& graphNodes = graph.GetNodes(); - // auto isSutableChildConvolution = [&](MKLDNNNodePtr parentNode, MKLDNNNodePtr childNode) { - // auto* childLayer = dynamic_cast(childNode->getCnnLayer().get()); - // if (childLayer == nullptr) - // IE_THROW() << "Cannot get convolution layer " << childNode->getName(); + auto isConvolutionNode = [](const MKLDNNNodePtr &node) { + return node->getType() == Convolution; + }; - // auto* parentLayer = dynamic_cast(parentNode->getCnnLayer().get()); - // if (parentLayer == nullptr) - // IE_THROW() << "Cannot get convolution layer " << parentNode->getName(); + auto is1x1Convolution = [](const std::shared_ptr &conv) { + const auto weightRank = conv->getWeightDims().size(); + return conv->getWeightDims()[weightRank - 1] == 1 && conv->getWeightDims()[weightRank - 2] == 1; + }; - // if (!everyone_is(Precision::FP32, parentLayer->outData[0].get()->getPrecision(), childLayer->insData[0].lock()->getPrecision(), - // childLayer->outData[0].get()->getPrecision())) - // return false; + auto isSutableParentConvolution = [&](MKLDNNNodePtr node) { + const auto conv = std::dynamic_pointer_cast(node); + if (conv == nullptr) + IE_THROW() << "Cannot cast to convolution node " << node->getName(); - // if (!everyone_is(Precision::FP32, parentLayer->precision, childLayer->precision)) - // return false; + if (!conv->weightsZeroPoints.empty()) + return false; - // auto parentOutputPrecision = !parentNode->fusedWith.empty() - // ? parentNode->fusedWith[parentNode->fusedWith.size() - 1]->getCnnLayer()->outData[0].get()->getPrecision() - // : parentNode->getCnnLayer()->outData[0].get()->getPrecision(); + const auto &strides = conv->getStride(); + bool isSupportedParams = conv->getGroupNum() == 1 && + is1x1Convolution(conv) && // TODO [oneDNN] : fusing is permitted only with 1x1 convolutions + everyone_is(1, strides[strides.size() - 1], strides[strides.size() - 2]) && + everyone_is(Precision::FP32, conv->getOriginalInputPrecisionAtPort(0), conv->getOriginalOutputPrecisionAtPort(0)) && + node->getChildEdgeAt(0)->getDims().ndims() == 4; + if (!isSupportedParams) return false; - // auto childOutputPrecision = !childNode->fusedWith.empty() - // ? 
childNode->fusedWith[childNode->fusedWith.size() - 1]->getCnnLayer()->outData[0].get()->getPrecision() - // : childNode->getCnnLayer()->outData[0].get()->getPrecision(); + return node->getChildEdges().size() == 1 && isConvolutionNode(node->getChildEdgeAt(0)->getChild()); + }; - // if (!everyone_is(Precision::FP32, parentOutputPrecision, childOutputPrecision)) - // return false; + auto isSutableChildConvolution = [&](const MKLDNNNodePtr &parentNode, const MKLDNNNodePtr &childNode) { + const auto convChild = std::dynamic_pointer_cast(childNode); + if (convChild == nullptr) + IE_THROW() << "Cannot cast to convolution node " << childNode->getName(); - // auto* childConvolutionNode = dynamic_cast(childNode.get()); - // if (childConvolutionNode == nullptr) - // IE_THROW() << "Cannot get convolution node " << childNode->getName(); + const auto convParent = std::dynamic_pointer_cast(parentNode); + if (convParent == nullptr) + IE_THROW() << "Cannot cast to convolution node " << parentNode->getName(); - // if (!childConvolutionNode->inputZeroPoints.empty() || !childConvolutionNode->weightsZeroPoints.empty()) - // return false; + if (!everyone_is(Precision::FP32, convParent->getOriginalOutputPrecisionAtPort(0), convChild->getOriginalInputPrecisionAtPort(0), + convChild->getOriginalOutputPrecisionAtPort(0))) + return false; - // bool withBias = (childLayer->_biases != nullptr && childLayer->_biases->size() != 0) || - // childConvolutionNode->getBaseIntputsNumber() == 3; + auto parentOutputPrecision = !parentNode->fusedWith.empty() + ? parentNode->fusedWith[parentNode->fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0) + : parentNode->getOriginalOutputPrecisionAtPort(0); - // auto allPads = getPaddings(*childLayer); + auto childOutputPrecision = !childNode->fusedWith.empty() + ? 
childNode->fusedWith[childNode->fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0) + : childNode->getOriginalOutputPrecisionAtPort(0); - // bool isSupportedParams = childLayer->_out_depth == childLayer->_group && - // childLayer->_out_depth != 1 && - // everyone_is(3, childLayer->_kernel[X_AXIS], childLayer->_kernel[Y_AXIS]) && - // everyone_is(1, allPads.begin[X_AXIS], allPads.begin[Y_AXIS]) && - // everyone_is(1, allPads.end[X_AXIS], allPads.end[Y_AXIS]) && - // everyone_is(1, childLayer->_dilation[X_AXIS], childLayer->_dilation[Y_AXIS]) && - // childLayer->_stride[X_AXIS] == childLayer->_stride[Y_AXIS] && - // withBias && - // one_of(childLayer->_stride[X_AXIS], 1, 2) && - // childNode->getChildEdgeAt(0)->getDims().ndims() == 4; + if (!everyone_is(Precision::FP32, parentOutputPrecision, childOutputPrecision)) + return false; - // return isSupportedParams; - // }; + if (!convChild->inputZeroPoints.empty() || !convChild->weightsZeroPoints.empty()) + return false; - // auto isFusingWorthwhile = [&](MKLDNNNodePtr parentNode, MKLDNNNodePtr childNode) { - // auto layer = std::dynamic_pointer_cast(childNode->getCnnLayer()); - // if (layer == nullptr) - // IE_THROW() << "Cannot get convolution layer " << childNode->getName(); + bool withBias = convChild->getOriginalInputPrecisions().size() == 3; + + const auto weightRank = convChild->getWeightDims().size(); + const auto stridesSize = convChild->getStride().size(); + bool isSupportedParams = convChild->outDims[0][1] == convChild->getGroupNum() && + convChild->outDims[0][1] != 1 && + everyone_is(3, convChild->getWeightDims()[weightRank - 1], convChild->getWeightDims()[weightRank - 2]) && + everyone_is(1, convChild->getPaddingL()[stridesSize - 1], convChild->getPaddingL()[stridesSize - 2]) && + everyone_is(1, convChild->getPaddingR()[stridesSize - 1], convChild->getPaddingR()[stridesSize - 2]) && + everyone_is(1, convChild->getDilation()[stridesSize - 1] + 1, convChild->getDilation()[stridesSize - 2] + 1) && + convChild->getStride()[stridesSize - 1] == convChild->getStride()[stridesSize - 2] && + withBias && + one_of(convChild->getStride()[stridesSize - 1], 1, 2) && + childNode->getChildEdgeAt(0)->getDims().ndims() == 4; + + return isSupportedParams; + }; - // auto inDims = childNode->inDims[0]; - // auto outDims = childNode->outDims[0]; - // int elemSize = layer->precision.size(); + auto isFusingWorthwhile = [&](const MKLDNNNodePtr &parentNode, const MKLDNNNodePtr &childNode) { + auto inDims = childNode->inDims[0]; + auto outDims = childNode->outDims[0]; + int elemSize = childNode->getOriginalOutputPrecisionAtPort(0).size(); - // int L3_cache_size = utils::get_cache_size(3, false); - // int dw_conv_input_size = inDims[0] * inDims[1] * inDims[2] * inDims[3] * elemSize; - // int dw_conv_output_size = outDims[0] * outDims[1]* outDims[2] * outDims[3] * elemSize; + int L3_cache_size = utils::get_cache_size(3, false); + int dw_conv_input_size = inDims[0] * inDims[1] * inDims[2] * inDims[3] * elemSize; + int dw_conv_output_size = outDims[0] * outDims[1]* outDims[2] * outDims[3] * elemSize; - // auto parentConvolutionNode = std::dynamic_pointer_cast(parentNode); - // if (parentConvolutionNode == nullptr) - // IE_THROW() << "Cannot get convolution node " << parentNode->getName(); + auto parentConvolutionNode = std::dynamic_pointer_cast(parentNode); + if (parentConvolutionNode == nullptr) + IE_THROW() << "Cannot get convolution node " << parentNode->getName(); - // if (!impl::cpu::x64::mayiuse(impl::cpu::x64::avx2) || 
impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_common)) - // return false; + if (!impl::cpu::x64::mayiuse(impl::cpu::x64::avx2) || impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_common)) + return false; - // return (dw_conv_input_size + dw_conv_output_size > L3_cache_size / 2); - // }; + return (dw_conv_input_size + dw_conv_output_size > L3_cache_size / 2); + }; - // for (int i = 0; i < graphNodes.size(); i++) { - // if (!isConvolutionNode(graphNodes[i])) continue; + for (int i = 0; i < graphNodes.size(); i++) { + if (!isConvolutionNode(graphNodes[i])) continue; - // auto parentConvNode = graphNodes[i]; - // if (!isSutableParentConvolution(parentConvNode)) continue; + auto parentConvNode = graphNodes[i]; + if (!isSutableParentConvolution(parentConvNode)) continue; - // auto childConvNode = parentConvNode->getChildEdgeAt(0)->getChild(); - // if (!isSutableChildConvolution(parentConvNode, childConvNode)) continue; + auto childConvNode = parentConvNode->getChildEdgeAt(0)->getChild(); + if (!isSutableChildConvolution(parentConvNode, childConvNode)) continue; - // if (!isFusingWorthwhile(parentConvNode, childConvNode)) continue; + if (!isFusingWorthwhile(parentConvNode, childConvNode)) continue; - // parentConvNode->fuseWith(childConvNode); + parentConvNode->addFusedNode(childConvNode); - // for (auto node : childConvNode->getFusedWith()) - // parentConvNode->fuseWith(node); - // childConvNode->clearFusedWith(); + for (auto node : childConvNode->getFusedWith()) { + parentConvNode->addFusedNode(node); + } + childConvNode->clearFusedWith(); - // graph.DropDWConvNode(childConvNode); - // } + graph.DropDWConvNode(childConvNode); + } } // TODO: mandrono: unite with FuseConvolutionAndSimpleOperation @@ -1039,7 +1024,12 @@ void MKLDNNGraphOptimizer::FusePoolingAndFakeQuantize(MKLDNNGraph &graph) { auto& graphNodes = graph.GetNodes(); auto isSutableParentNode = [](MKLDNNNodePtr node) { - return node->getType() == Pooling && node->getChildEdges().size() == 1 && node->getAlgorithm() == Algorithm::PoolingAvg; + if (node->getType() == Pooling) { + if (!one_of(node->getOriginalInputPrecisionAtPort(0), Precision::U8, Precision::I8)) + return false; + return node->getChildEdges().size() == 1 && node->getAlgorithm() == Algorithm::PoolingAvg; + } + return false; }; auto isSutableChildNode = [](MKLDNNNodePtr node) { diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp index 0edb7a4dd7139f..9992d62b95292a 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp @@ -78,6 +78,7 @@ static const InferenceEngine::details::caseless_unordered_map { "Mod", Eltwise }, { "FloorMod", Eltwise }, { "Power", Eltwise }, + { "PowerStatic", Eltwise }, { "Equal", Eltwise }, { "NotEqual", Eltwise }, { "Greater", Eltwise }, @@ -89,6 +90,7 @@ static const InferenceEngine::details::caseless_unordered_map { "LogicalXor", Eltwise }, { "LogicalNot", Eltwise }, { "Relu", Eltwise }, + { "LeakyRelu", Eltwise }, { "Gelu", Eltwise }, { "Elu", Eltwise }, { "Tanh", Eltwise }, @@ -222,7 +224,8 @@ MKLDNNNode::MKLDNNNode(const std::shared_ptr& op, const mkldnn::en } for (size_t i = 0; i < op->get_input_size(); i++) { - inDims.emplace_back(op->get_input_shape(i)); + const auto &shape = op->get_input_shape(i); + inDims.emplace_back(ngraph::is_scalar(shape) ? 
ngraph::Shape{1} : shape); originalInputPrecisions.emplace_back(details::convertPrecision(op->get_input_element_type(i))); } @@ -231,7 +234,8 @@ MKLDNNNode::MKLDNNNode(const std::shared_ptr& op, const mkldnn::en IE_THROW() << "Node with type '" << typeStr << "' and name '" << name << "' does not have any outputs."; } for (size_t i = 0; i < op->get_output_size(); i++) { - outDims.emplace_back(op->get_output_shape(i)); + const auto &shape = op->get_output_shape(i); + outDims.emplace_back(ngraph::is_scalar(shape) ? ngraph::Shape{1} : shape); originalOutputPrecisions.emplace_back(details::convertPrecision(op->get_output_element_type(i))); } } diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_cpu_specific_opset.hpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_cpu_specific_opset.hpp index 63fe847afa1de7..e3feb60e70d18c 100644 --- a/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_cpu_specific_opset.hpp +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_cpu_specific_opset.hpp @@ -10,6 +10,8 @@ #include "convert_broadcast_to_tiles.hpp" #include "convert_tile_to_seq_tiles.hpp" #include "reshape_1d_ops.hpp" +#include "convert_to_power_static.hpp" +#include "convert_to_leaky_relu.hpp" namespace MKLDNNPlugin { @@ -25,6 +27,8 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr &nGraphF manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); + manager.register_pass(); if (!ngraph::op::util::has_op_with_type(nGraphFunc)) { manager.register_pass(); } diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_leaky_relu.cpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_leaky_relu.cpp new file mode 100644 index 00000000000000..73d469c652c5ad --- /dev/null +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_leaky_relu.cpp @@ -0,0 +1,38 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "convert_to_leaky_relu.hpp" + +#include +#include +#include +#include "op/leaky_relu.hpp" + +NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ConvertToLeakyRelu, "ConvertToLeakyRelu", 0); + +MKLDNNPlugin::ConvertToLeakyRelu::ConvertToLeakyRelu() { + auto prelu = ngraph::pattern::wrap_type({ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), + ngraph::pattern::any_input(ngraph::pattern::has_static_shape())}); + + ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) { + auto prelu = std::dynamic_pointer_cast(m.get_match_root()); + if (!prelu) { + return false; + } + auto slopeNode = std::dynamic_pointer_cast(prelu->get_input_node_shared_ptr(1)); + if (slopeNode != nullptr && ngraph::shape_size(prelu->get_input_shape(1)) == 1) { + const float slope = slopeNode->cast_vector()[0]; + const auto leakyRelu = std::make_shared(prelu->input(0).get_source_output(), slope, + prelu->output(0).get_element_type()); + leakyRelu->set_friendly_name(prelu->get_friendly_name()); + ngraph::copy_runtime_info(prelu, leakyRelu); + ngraph::replace_node(prelu, leakyRelu); + return true; + } + return false; + }; + + auto m = std::make_shared(prelu, "ConvertToLeakyRelu"); + this->register_matcher(m, callback); +} diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_leaky_relu.hpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_leaky_relu.hpp new file mode 100644 index 00000000000000..6e5eff2937c2e1 --- /dev/null 
+++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_leaky_relu.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace MKLDNNPlugin { + +class ConvertToLeakyRelu: public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + ConvertToLeakyRelu(); +}; + +} // namespace MKLDNNPlugin \ No newline at end of file diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_power_static.cpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_power_static.cpp new file mode 100644 index 00000000000000..0e5ed24e6d196e --- /dev/null +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_power_static.cpp @@ -0,0 +1,124 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "convert_to_power_static.hpp" + +#include +#include +#include +#include +#include +#include +#include "op/power_static.hpp" +#include "op/fully_connected.hpp" +#include "utils/general_utils.h" + +int getConstPort(const std::shared_ptr &node) { + const auto const1 = std::dynamic_pointer_cast(node->get_input_node_shared_ptr(0)); + const auto const2 = std::dynamic_pointer_cast(node->get_input_node_shared_ptr(1)); + int constPort = -1; + if (const2) { + constPort = 1; + } else if (const1) { + constPort = 0; + } + return constPort; +} + +template +bool isConvertableToPowerStatic(const std::shared_ptr &node) { + const int constPort = getConstPort(node); + if ((!node->get_input_element_type(0).is_real() && !node->get_input_element_type(1).is_real()) || !node->get_output_element_type(0).is_real() || + constPort == -1) { + return false; + } + + const int nonConstPort = 1 - constPort; + const auto constNode = std::dynamic_pointer_cast(node->get_input_node_shared_ptr(constPort)); + + return ngraph::shape_size(node->get_input_shape(constPort)) == 1 && + node->get_input_shape(nonConstPort).size() >= node->get_input_shape(constPort).size() && + !MKLDNNPlugin::one_of(node->get_input_node_shared_ptr(nonConstPort)->get_type_info(), ngraph::opset1::NormalizeL2::type_info, + ngraph::opset4::Interpolate::type_info, + ngraph::opset1::Convolution::type_info, + ngraph::opset1::GroupConvolution::type_info, + ngraph::opset1::ConvolutionBackpropData::type_info, + ngraph::opset1::GroupConvolutionBackpropData::type_info, + MKLDNNPlugin::FullyConnectedNode::type_info, + ngraph::op::v0::MVN::type_info, + ngraph::opset6::MVN::type_info); +} + +template <> +bool isConvertableToPowerStatic(const std::shared_ptr &node) { + return std::dynamic_pointer_cast(node->get_input_node_shared_ptr(1)) != nullptr && + node->get_input_shape(0).size() >= node->get_input_shape(1).size() && ngraph::shape_size(node->get_input_shape(1)) == 1; +} + +template +std::shared_ptr convert(const std::shared_ptr &node) { + const int constPort = getConstPort(node); + const int nonConstPort = 1 - constPort; + std::shared_ptr powerNode = std::dynamic_pointer_cast(node->get_input_node_shared_ptr(constPort)); + const float value = powerNode->cast_vector()[0]; + if (std::is_same::value) { + return std::make_shared(node->input(nonConstPort).get_source_output(), value, 1.0f, 0.0f, + node->output(0).get_element_type()); + } else if (std::is_same::value) { + return std::make_shared(node->input(nonConstPort).get_source_output(), 1.0f, 1.0f, value, + node->output(0).get_element_type()); + } else if (std::is_same::value) { + return 
std::make_shared(node->input(nonConstPort).get_source_output(), 1.0f, 1.0f, (-1.0f * value), + node->output(0).get_element_type()); + } else if (std::is_same::value) { + return std::make_shared(node->input(nonConstPort).get_source_output(), 1.f, value, 0.0f, + node->output(0).get_element_type()); + } else { + throw ngraph::ngraph_error("ConvertToPowerStatic: op type is not supported"); + } +} + +NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ConvertToPowerStatic, "ConvertToPowerStatic", 0); + +MKLDNNPlugin::ConvertToPowerStatic::ConvertToPowerStatic() { + ngraph::OutputVector twoInputs = {ngraph::pattern::any_input(ngraph::pattern::has_static_shape()), + ngraph::pattern::any_input(ngraph::pattern::has_static_shape())}; + auto power = ngraph::pattern::wrap_type(twoInputs); + auto add = ngraph::pattern::wrap_type(twoInputs); + auto sub = ngraph::pattern::wrap_type(twoInputs); + auto mult = ngraph::pattern::wrap_type(twoInputs); + const auto candidate = std::make_shared(ngraph::OutputVector{power, add, sub, mult}); + + ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) { + auto node = m.get_match_root(); + + std::shared_ptr toReplace = node; + if (auto power = std::dynamic_pointer_cast(node)) { + if (!isConvertableToPowerStatic(power)) + return false; + toReplace = convert(power); + } else if (auto add = std::dynamic_pointer_cast(node)) { + if (!isConvertableToPowerStatic(add)) + return false; + toReplace = convert(add); + } else if (auto sub = std::dynamic_pointer_cast(node)) { + if (!isConvertableToPowerStatic(sub)) + return false; + toReplace = convert(sub); + } else if (auto mult = std::dynamic_pointer_cast(node)) { + if (!isConvertableToPowerStatic(mult)) + return false; + toReplace = convert(mult); + } else { + throw ngraph::ngraph_error("ConvertToPowerStatic: op type is not supported"); + } + toReplace->set_friendly_name(node->get_friendly_name()); + ngraph::copy_runtime_info(node, toReplace); + ngraph::replace_node(node, toReplace); + return true; + }; + + auto m = std::make_shared(candidate, "ConvertToPowerStatic"); + this->register_matcher(m, callback); +} diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_power_static.hpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_power_static.hpp new file mode 100644 index 00000000000000..9fefa3a9ba55c8 --- /dev/null +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/convert_to_power_static.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace MKLDNNPlugin { + +class ConvertToPowerStatic: public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + ConvertToPowerStatic(); +}; + +} // namespace MKLDNNPlugin diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/leaky_relu.cpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/leaky_relu.cpp new file mode 100644 index 00000000000000..4e943d4b517516 --- /dev/null +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/leaky_relu.cpp @@ -0,0 +1,31 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "leaky_relu.hpp" + +constexpr ngraph::NodeTypeInfo MKLDNNPlugin::LeakyReluNode::type_info; + +MKLDNNPlugin::LeakyReluNode::LeakyReluNode(const ngraph::Output &data, + const float &negative_slope, + const ngraph::element::Type output_type) + : Op({data}), m_negative_slope(negative_slope), m_output_type(output_type) { + 
constructor_validate_and_infer_types(); +} + +std::shared_ptr MKLDNNPlugin::LeakyReluNode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const { + check_new_args_count(this, new_args); + return std::make_shared(new_args.at(0), m_negative_slope, m_output_type); +} + +void MKLDNNPlugin::LeakyReluNode::validate_and_infer_types() { + set_output_type( + 0, + m_output_type == ngraph::element::undefined ? get_input_element_type(0) : m_output_type, + get_input_partial_shape(0)); +} + +bool MKLDNNPlugin::LeakyReluNode::visit_attributes(ngraph::AttributeVisitor &visitor) { + visitor.on_attribute("negative_slope", m_negative_slope); + return true; +} diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/leaky_relu.hpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/leaky_relu.hpp new file mode 100644 index 00000000000000..3465ffc75100c4 --- /dev/null +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/leaky_relu.hpp @@ -0,0 +1,33 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace MKLDNNPlugin { + +class LeakyReluNode : public ngraph::op::Op { +public: + static constexpr ngraph::NodeTypeInfo type_info{"LeakyRelu", 0}; + const ngraph::NodeTypeInfo& get_type_info() const override { return type_info; } + + LeakyReluNode(const ngraph::Output &data, const float &negative_slope, const ngraph::element::Type output_type); + + void validate_and_infer_types() override; + + bool visit_attributes(ngraph::AttributeVisitor &visitor) override; + + std::shared_ptr clone_with_new_inputs(const ngraph::OutputVector &new_args) const override; + + float get_slope() { return m_negative_slope; } + + ngraph::element::Type get_output_type() const { return m_output_type; } + +private: + float m_negative_slope; + ngraph::element::Type m_output_type; +}; + +} // namespace MKLDNNPlugin diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/power_static.cpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/power_static.cpp new file mode 100644 index 00000000000000..be1f23f9bb3183 --- /dev/null +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/power_static.cpp @@ -0,0 +1,35 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "power_static.hpp" + +constexpr ngraph::NodeTypeInfo MKLDNNPlugin::PowerStaticNode::type_info; + +MKLDNNPlugin::PowerStaticNode::PowerStaticNode(const ngraph::Output &data, + const float &power, + const float &scale, + const float &shift, + const ngraph::element::Type output_type) + : Op({data}), scale(scale), power(power), shift(shift), m_output_type(output_type) { + constructor_validate_and_infer_types(); +} + +std::shared_ptr MKLDNNPlugin::PowerStaticNode::clone_with_new_inputs(const ngraph::OutputVector &new_args) const { + if (new_args.size() != 1) { + throw ngraph::ngraph_error("Incorrect number of new arguments"); + } + + return std::make_shared(new_args.at(0), this->power, this->scale, this->shift, this->m_output_type); +} + +void MKLDNNPlugin::PowerStaticNode::validate_and_infer_types() { + set_output_type(0, m_output_type == ngraph::element::undefined ? 
get_input_element_type(0) : m_output_type, get_input_partial_shape(0)); +} + +bool MKLDNNPlugin::PowerStaticNode::visit_attributes(ngraph::AttributeVisitor &visitor) { + visitor.on_attribute("scale", scale); + visitor.on_attribute("power", power); + visitor.on_attribute("shift", shift); + return true; +} diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/power_static.hpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/power_static.hpp new file mode 100644 index 00000000000000..e43a54c4e03acd --- /dev/null +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/power_static.hpp @@ -0,0 +1,34 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace MKLDNNPlugin { + +class PowerStaticNode : public ngraph::op::Op { +public: + static constexpr ngraph::NodeTypeInfo type_info{"PowerStatic", 0}; + const ngraph::NodeTypeInfo& get_type_info() const override { return type_info; } + + PowerStaticNode(const ngraph::Output &data, const float &power, const float &scale, const float &shift, + const ngraph::element::Type output_type = ngraph::element::undefined); + + void validate_and_infer_types() override; + + bool visit_attributes(ngraph::AttributeVisitor &visitor) override; + + std::shared_ptr clone_with_new_inputs(const ngraph::OutputVector &new_args) const override; + + float get_power() const { return power; } + float get_scale() const { return scale; } + float get_shift() const { return shift; } + +private: + float scale, power, shift; + ngraph::element::Type m_output_type; +}; + +} // namespace MKLDNNPlugin diff --git a/inference-engine/src/mkldnn_plugin/nodes/depth_to_space.cpp b/inference-engine/src/mkldnn_plugin/nodes/depth_to_space.cpp index 6c91ca30b7fe3a..5ac5872d1649ee 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/depth_to_space.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/depth_to_space.cpp @@ -10,6 +10,10 @@ #include #include #include "ie_parallel.hpp" +#include +#include + +using namespace MKLDNNPlugin; namespace InferenceEngine { namespace Extensions { @@ -21,54 +25,81 @@ class DepthToSpaceImpl: public ExtLayerBase { DEPTH_FIRST }; + std::string errorPrefix; + + bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { + try { + const auto depthToSpace = std::dynamic_pointer_cast(op); + if (!depthToSpace) { + errorMessage = "Only opset1 DepthToSpace operation is supported"; + return false; + } + const auto mode = depthToSpace->get_mode(); + if (!one_of(mode, ngraph::op::v0::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, ngraph::op::v0::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST)) { + errorMessage = "Does not support mode: " + ngraph::as_string(mode); + return false; + } + } catch (...) 
{ + return false; + } + return true; + } + public: - explicit DepthToSpaceImpl(const CNNLayer* layer) { + explicit DepthToSpaceImpl(const std::shared_ptr& op) { try { - if (layer->insData.empty() || layer->outData.empty()) - IE_THROW() << "DepthToSpace layer with name '" << layer->name << "' has incorrect number of input/output edges"; + std::string errorMessage; + if (!isSupportedOperation(op, errorMessage)) { + IE_THROW(NotImplemented) << errorMessage; + } + + errorPrefix = "DepthToSpace layer with name '" + op->get_friendly_name() + "'"; + const auto depthToSpace = std::dynamic_pointer_cast(op); + + if (op->get_input_size() != 1 || op->get_output_size() != 1) + IE_THROW() << errorPrefix << " has incorrect number of input/output edges"; - inDims = layer->insData[0].lock()->getTensorDesc().getDims(); + inDims = op->get_input_shape(0); if (inDims.size() < 3) - IE_THROW() << "DepthToSpace layer with name '" << layer->name << "' has incorrect number of input dimensions"; + IE_THROW() << errorPrefix << " has incorrect number of input dimensions"; if (inDims.size() > 5) - IE_THROW() << "DepthToSpace layer with name '" << layer->name << "' doesn't support dimensions with rank greater than 5"; + IE_THROW() << errorPrefix << " doesn't support dimensions with rank greater than 5"; - SizeVector outDims = layer->outData[0]->getTensorDesc().getDims(); + SizeVector outDims = op->get_output_shape(0); if (inDims.size() != outDims.size()) - IE_THROW() << "DepthToSpace layer with name '" << layer->name << "' has incorrect number of input/output dimensions"; + IE_THROW() << errorPrefix << " has incorrect number of input/output dimensions"; - std::string modeString = layer->GetParamAsString("mode"); - if (modeString == "blocks_first") { + const auto modeNgraph = depthToSpace->get_mode(); + if (modeNgraph == ngraph::op::v0::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST) { mode = DepthToSpaceMode::BLOCKS_FIRST; - } else if (modeString == "depth_first") { + } else if (modeNgraph == ngraph::op::v0::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST) { mode = DepthToSpaceMode::DEPTH_FIRST; } else { - IE_THROW() << "DepthToSpace layer with name '" << layer->name << "' doesn't support mode: " << modeString; + IE_THROW() << errorPrefix << " doesn't support mode: " << ngraph::as_string(modeNgraph); } - blockSize = layer->GetParamAsUInt("block_size", 1); + blockSize = depthToSpace->get_block_size(); if (blockSize == 0) - IE_THROW() << layer->name << " Incorrect blockSize parameter is zero!"; + IE_THROW() << errorPrefix << " has incorrect block_size = 0"; size_t numSpatialDims = inDims.size() - 2; blockStep = static_cast(std::pow(blockSize, numSpatialDims)); if (inDims[1] % blockStep) - IE_THROW() << "DepthToSpace layer with name '" << layer->name << - "' has block_size parameter which is incompatible with input tensor channels dimension size"; + IE_THROW() << errorPrefix << " has block_size parameter which is incompatible with input tensor channels dimension size"; if (inDims[1] / blockStep != outDims[1]) - IE_THROW() << "DepthToSpace layer with name '" << layer->name << " has incompatible input/output channels"; + IE_THROW() << errorPrefix << " has incompatible input/output channels"; for (int i = 0; i < numSpatialDims; i++) { if (inDims[i + 2] * blockSize != outDims[i + 2]) - IE_THROW() << "DepthToSpace layer with name '" << layer->name << " has incompatible spatial dims"; + IE_THROW() << errorPrefix << " has incompatible spatial dims"; } - auto computePrc = layer->insData[0].lock()->getTensorDesc().getPrecision(); + auto 
computePrc = details::convertPrecision(op->get_input_element_type(0)); const std::set supported_precision_sizes = {1, 2, 4, 8}; if (supported_precision_sizes.find(computePrc.size()) == supported_precision_sizes.end()) - IE_THROW() << "DepthToSpace layer with name '" << layer->name << " doesn't support precision: " << computePrc.name(); + IE_THROW() << errorPrefix << " doesn't support precision: " << computePrc.name(); if (inDims.size() == 4 || inDims.size() == 5) { diff --git a/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp b/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp index de15fd6f8e9bc4..11f5609704834d 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp +++ b/inference-engine/src/mkldnn_plugin/nodes/list_tbl.hpp @@ -51,7 +51,7 @@ MKLDNN_EXTENSION_NODE(ReorgYoloImpl, ReorgYolo); //MKLDNN_EXTENSION_NODE(UniqueImpl, Unique); MKLDNN_EXTENSION_NODE(PSROIPoolingImpl, PSROIPooling); MKLDNN_EXTENSION_NODE(PSROIPoolingImpl, DeformablePSROIPooling); -//MKLDNN_EXTENSION_NODE(DepthToSpaceImpl, DepthToSpace); +MKLDNN_EXTENSION_NODE(DepthToSpaceImpl, DepthToSpace); //MKLDNN_EXTENSION_NODE(OneHotImpl, OneHot); MKLDNN_EXTENSION_NODE(BroadcastImpl, Broadcast); //MKLDNN_EXTENSION_NODE(ExperimentalSparseWeightedReduceImpl, ExperimentalSparseWeightedSum); @@ -61,7 +61,7 @@ MKLDNN_EXTENSION_NODE(ExperimentalDetectronGenerateProposalsSingleImageImpl, Exp MKLDNN_EXTENSION_NODE(NonMaxSuppressionImpl, NonMaxSuppressionIEInternal); MKLDNN_EXTENSION_NODE(TopKImpl, TopK); MKLDNN_EXTENSION_NODE(ShuffleChannelsImpl, ShuffleChannels); -//MKLDNN_EXTENSION_NODE(SpaceToDepthImpl, SpaceToDepth); +MKLDNN_EXTENSION_NODE(SpaceToDepthImpl, SpaceToDepth); //MKLDNN_EXTENSION_NODE(PowerFileImpl, PowerFile); //MKLDNN_EXTENSION_NODE(BatchToSpaceImpl, BatchToSpace); //MKLDNN_EXTENSION_NODE(ExperimentalDetectronPriorGridGeneratorImpl, ExperimentalDetectronPriorGridGenerator); @@ -76,7 +76,7 @@ MKLDNN_EXTENSION_NODE(GatherElementsImpl, GatherElements); MKLDNN_EXTENSION_NODE(GatherNDImpl, GatherND); MKLDNN_EXTENSION_NODE(ProposalImpl, Proposal); //MKLDNN_EXTENSION_NODE(RangeImpl, Range); -//MKLDNN_EXTENSION_NODE(SelectImpl, Select); +MKLDNN_EXTENSION_NODE(SelectImpl, Select); MKLDNN_EXTENSION_NODE(GatherTreeImpl, GatherTree); //MKLDNN_EXTENSION_NODE(PriorBoxClusteredImpl, PriorBoxClustered); //MKLDNN_EXTENSION_NODE(SpaceToBatchImpl, SpaceToBatch); diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.cpp index 7ee1ec178e314a..4d9f0544e75b6e 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.cpp @@ -40,7 +40,7 @@ bool MKLDNNConvolutionNode::isSupportedOperation(const std::shared_ptr& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache), withBiases(false), withSum(false), withDWConv(false), - isGrouped(false), /* dw_conv_oc(0), dw_conv_ih(0), dw_conv_iw(0), dw_conv_in_dt(memory::data_type::undef), */ + isGrouped(false), dw_conv_oc(0), dw_conv_ih(0), dw_conv_iw(0), dw_conv_in_dt(memory::data_type::undef), groupNum(1lu), eltwisePrecision(Precision::FP32) { std::string errorMessage; if (!isSupportedOperation(op, errorMessage)) { @@ -187,43 +187,39 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() { withDWConv = isFusedWith(Convolution); - // TODO: fusing with Convolution is not ported yet -// for (int i = 0; i < fusedWith.size(); i++) { -// auto *convolutionNode = 
dynamic_cast(fusedWith[i].get()); -// if (convolutionNode) { -// auto *convLayer = reinterpret_cast(convolutionNode->getCnnLayer().get()); -// dw_conv_ih = convolutionNode->inDims[0][convolutionNode->inDims[0].ndims() - 2]; -// dw_conv_iw = convolutionNode->inDims[0][convolutionNode->inDims[0].ndims() - 1]; -// dw_conv_oc = convLayer->_out_depth; -// for (int j = 0; j < convLayer->_kernel.size(); j++) { -// dw_conv_kernel.push_back(convLayer->_kernel[j]); -// } -// for (int j = 0; j < convLayer->_stride.size(); j++) { -// dw_conv_strides.push_back(convLayer->_stride[j]); -// } -// -// if (canBeExecutedInInt8()) { -// if (i == 0) { -// dw_conv_in_dt = precisionToDataType(getCnnLayer()->outData[0]->getPrecision()); -// } else { -// dw_conv_in_dt = precisionToDataType(fusedWith[i - 1].get()->getCnnLayer()->outData[0]->getPrecision()); -// } -// } else { -// dw_conv_in_dt = memory::data_type::f32; -// } -// -// for (int j = 0; j < paddingR.size(); j++) { -// int with_group = (isGrouped || isMerged) ? 1 : 0; -// int krn = weightsDims[with_group + 2 + j]; -// int src = getParentEdgeAt(0)->getDims()[2 + j]; -// int dst = getChildEdgeAt(0)->getDims()[2 + j]; -// -// krn = (krn - 1)*(dilation[j] + 1) + 1; -// int calc_dst = (src - krn + paddingL[j]) / stride[j] + 1; -// paddingR[j] = (dst - calc_dst) * stride[j]; -// } -// } -// } + for (int i = 0; i < fusedWith.size(); i++) { + auto *convolutionNode = dynamic_cast(fusedWith[i].get()); + if (convolutionNode) { + dw_conv_ih = convolutionNode->inDims[0][convolutionNode->inDims[0].ndims() - 2]; + dw_conv_iw = convolutionNode->inDims[0][convolutionNode->inDims[0].ndims() - 1]; + dw_conv_oc = convolutionNode->outDims[0][1]; + const auto &dwWeightsDims = convolutionNode->inDims[1].ToSizeVector(); + dw_conv_kernel.push_back(dwWeightsDims[dwWeightsDims.size() - 1]); + dw_conv_kernel.push_back(dwWeightsDims[dwWeightsDims.size() - 2]); + dw_conv_strides = convolutionNode->getStride(); + + if (canBeExecutedInInt8()) { + if (i == 0) { + dw_conv_in_dt = MKLDNNExtensionUtils::IEPrecisionToDataType(getOriginalOutputPrecisionAtPort(0)); + } else { + dw_conv_in_dt = MKLDNNExtensionUtils::IEPrecisionToDataType(fusedWith[i - 1]->getOriginalOutputPrecisionAtPort(0)); + } + } else { + dw_conv_in_dt = memory::data_type::f32; + } + + for (int j = 0; j < paddingR.size(); j++) { + int with_group = isGrouped ? 
1 : 0; + int krn = weightsDims[with_group + 2 + j]; + int src = getParentEdgeAt(0)->getDims()[2 + j]; + int dst = getChildEdgeAt(0)->getDims()[2 + j]; + + krn = (krn - 1)*(dilation[j] + 1) + 1; + int calc_dst = (src - krn + paddingL[j]) / stride[j] + 1; + paddingR[j] = (dst - calc_dst) * stride[j]; + } + } + } MKLDNNMemoryDesc in_candidate, out_candidate; if (canBeExecutedInInt8()) { @@ -334,101 +330,27 @@ void MKLDNNConvolutionNode::setPostOps(mkldnn::primitive_attr &attr, bool initWe continue; } - // auto* convolutionNode = dynamic_cast(node.get()); - // if (convolutionNode) { - // if (initWeights) { - // if (convolutionNode->getBaseIntputsNumber() == 1) { - // auto* convLayer = reinterpret_cast(convolutionNode->getCnnLayer().get()); - - // auto weightsPrc = precisionToDataType(convLayer->precision); - // auto biasPrc = memory::data_type::s32; - - // PostOpsIntBlobMemory.push_back(MKLDNNMemoryPtr(new MKLDNNMemory(getEngine()))); - // MKLDNNDims dwWeightsDims({dw_conv_oc, (ptrdiff_t)1, (ptrdiff_t)1, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS]}); - // PostOpsIntBlobMemory[blob_idx]->Create(dwWeightsDims, weightsPrc, memory::format_tag::Goihw8g); - // PostOpsIntBlobMemory[blob_idx]->FillZero(); - - // Blob::Ptr weights = convLayer->blobs.find("weights")->second; - // Blob::Ptr biases = convLayer->blobs.find("biases")->second; - - // PostOpsIntBlobMemory[blob_idx]->SetData(weightsPrc, memory::format_tag::goihw, weights->buffer(), - // dwWeightsDims.size() * MKLDNNExtensionUtils::sizeOfDataType(weightsPrc)); - - // PostOpsIntBlobMemory.push_back(MKLDNNMemoryPtr(new MKLDNNMemory(getEngine()))); - // MKLDNNDims dwBiasesDims({dw_conv_oc}); - // PostOpsIntBlobMemory[blob_idx + 1]->Create(dwBiasesDims, biasPrc, memory::format_tag::x); - // PostOpsIntBlobMemory[blob_idx + 1]->FillZero(); - // PostOpsIntBlobMemory[blob_idx + 1]->SetData(biasPrc, memory::format_tag::x, biases->buffer(), - // dwBiasesDims.size() * MKLDNNExtensionUtils::sizeOfDataType(biasPrc)); - // // todo: rewrite onto append_dw_k3s2p1 - // ops.append_dw_conv(dw_conv_ih, dw_conv_iw, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS], - // dw_conv_strides[Y_AXIS], dw_conv_strides[X_AXIS], - // mkldnn::memory::convert_to_c(dw_conv_in_dt), - // static_cast(PostOpsIntBlobMemory[blob_idx]->GetData()), - // static_cast(PostOpsIntBlobMemory[blob_idx + 1]->GetData())); - - // blob_idx += 2; - // } else { - // // todo: rewrite onto append_dw_k3s2p1 - // ops.append_dw_conv(dw_conv_ih, dw_conv_iw, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS], - // dw_conv_strides[Y_AXIS], dw_conv_strides[X_AXIS], - // mkldnn::memory::convert_to_c(dw_conv_in_dt), - // static_cast(getParentEdgeAt( - // baseInputsNumber + 0)->getMemory().GetData()), - // static_cast(getParentEdgeAt( - // baseInputsNumber + 1)->getMemory().GetData())); - // } - // } else { - // // todo: rewrite onto append_dw_k3s2p1 - // ops.append_dw_conv(dw_conv_ih, dw_conv_iw, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS], - // dw_conv_strides[Y_AXIS], dw_conv_strides[X_AXIS], - // mkldnn::memory::convert_to_c(dw_conv_in_dt), - // nullptr, - // nullptr); - // } - - // if (convolutionNode->wScale != nullptr) { - // float* wScaleData = static_cast(convolutionNode->wScale->buffer()); - - // std::vector oScaleDataVector; - // std::vector oShiftDataVector; - // if (convolutionNode->getCnnLayer()->precision == Precision::I8 && - // convolutionNode->getCnnLayer()->outData[0]->getPrecision() != Precision::FP32) { - // float *oScaleData = static_cast(convolutionNode->oScale->buffer()); - - // for 
(size_t c = 0; c < convolutionNode->wScale->size(); c++) { - // oScaleDataVector.push_back(wScaleData[c] / oScaleData[c]); - // oShiftDataVector.push_back(0.f); - // } - // } else { - // for (size_t c = 0; c < convolutionNode->wScale->size(); c++) { - // oScaleDataVector.push_back(wScaleData[c]); - // oShiftDataVector.push_back(0.f); - // } - // } - - // MKLDNNDims oScaleDims({static_cast(rnd_up(biasesDims[0], 16))}); - - // PostOpsIntBlobMemory.push_back(MKLDNNMemoryPtr(new MKLDNNMemory(getEngine()))); - // PostOpsIntBlobMemory[blob_idx]->Create(oScaleDims, memory::data_type::f32, memory::format_tag::x); - // PostOpsIntBlobMemory[blob_idx]->FillZero(); - // PostOpsIntBlobMemory[blob_idx]->SetData(memory::data_type::f32, memory::format_tag::x, &oScaleDataVector[0], - // oScaleDataVector.size() * MKLDNNExtensionUtils::sizeOfDataType(memory::data_type::f32)); - - // PostOpsIntBlobMemory.push_back(MKLDNNMemoryPtr(new MKLDNNMemory(getEngine()))); - // PostOpsIntBlobMemory[blob_idx + 1]->Create(oScaleDims, memory::data_type::f32, memory::format_tag::x); - // PostOpsIntBlobMemory[blob_idx + 1]->FillZero(); - // PostOpsIntBlobMemory[blob_idx + 1]->SetData(memory::data_type::f32, memory::format_tag::x, &oShiftDataVector[0], - // oShiftDataVector.size() * MKLDNNExtensionUtils::sizeOfDataType(memory::data_type::f32)); - - // ops.append_depthwise(mkldnn::algorithm::depthwise_scale_shift, - // static_cast(PostOpsIntBlobMemory[blob_idx]->GetData()), - // static_cast(PostOpsIntBlobMemory[blob_idx + 1]->GetData())); - - // blob_idx += 2; - // } - // continue; - // } + auto* convolutionNode = dynamic_cast(node.get()); + if (convolutionNode) { + if (initWeights) { + // todo: rewrite onto append_dw_k3s2p1 + ops.append_dw_conv(dw_conv_ih, dw_conv_iw, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS], + dw_conv_strides[Y_AXIS], dw_conv_strides[X_AXIS], + mkldnn::memory::convert_to_c(dw_conv_in_dt), + static_cast(getParentEdgeAt( + getOriginalInputsNumber() + 0)->getMemory().GetData()), + static_cast(getParentEdgeAt( + getOriginalInputsNumber() + 1)->getMemory().GetData())); + } else { + // todo: rewrite onto append_dw_k3s2p1 + ops.append_dw_conv(dw_conv_ih, dw_conv_iw, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS], + dw_conv_strides[Y_AXIS], dw_conv_strides[X_AXIS], + mkldnn::memory::convert_to_c(dw_conv_in_dt), + nullptr, + nullptr); + } + continue; + } IE_THROW() << "Fusing of " << NameFromType(node->getType()) << " operation to " << NameFromType(this->getType()) << " node is not implemented"; } @@ -463,23 +385,22 @@ void MKLDNNConvolutionNode::initSupportedPrimitiveDescriptors() { config.inConfs.push_back(dataConfig); } -// TODO: fusing with Convolution is not ported yet -// if (withDWConv) { -// auto weightsPrc = precisionToDataType(dw_conv_in_dt == mkldnn_u8 ? Precision::I8 : Precision::FP32); -// auto biasPrc = memory::data_type::f32; -// -// MKLDNNDims dwWeightsDims({dw_conv_oc, (ptrdiff_t)1, (ptrdiff_t)1, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS]}); -// MKLDNNDims dwBiasesDims({dw_conv_oc}); -// -// InferenceEngine::DataConfig dataConfig; -// dataConfig.inPlace = -1; -// dataConfig.constant = false; -// dataConfig.desc = MKLDNNMemoryDesc(dwWeightsDims, weightsPrc, memory::format_tag::Goihw8g); -// config.inConfs.push_back(dataConfig); -// -// dataConfig.desc = MKLDNNMemoryDesc(dwBiasesDims, biasPrc, memory::format_tag::x); -// config.inConfs.push_back(dataConfig); -// } + if (withDWConv) { + auto weightsPrc = MKLDNNExtensionUtils::IEPrecisionToDataType(dw_conv_in_dt == mkldnn_u8 ? 
Precision::I8 : Precision::FP32); + auto biasPrc = memory::data_type::f32; + + MKLDNNDims dwWeightsDims({dw_conv_oc, (ptrdiff_t)1, (ptrdiff_t)1, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS]}); + MKLDNNDims dwBiasesDims({dw_conv_oc}); + + InferenceEngine::DataConfig dataConfig; + dataConfig.inPlace = -1; + dataConfig.constant = false; + dataConfig.desc = MKLDNNMemoryDesc(dwWeightsDims, weightsPrc, memory::format_tag::Goihw8g); + config.inConfs.push_back(dataConfig); + + dataConfig.desc = MKLDNNMemoryDesc(dwBiasesDims, biasPrc, memory::format_tag::x); + config.inConfs.push_back(dataConfig); + } for (size_t i = 0; i < descOutputNumbers(desc); i++) { InferenceEngine::DataConfig dataConfig; @@ -651,23 +572,22 @@ void MKLDNNConvolutionNode::initDescriptor(const InferenceEngine::LayerConfig& c cfg.inConfs.push_back(dataConfig); } - // TODO: fusing with Convolution is not ported yet -// if (withDWConv) { -// auto weightsPrc = precisionToDataType(dw_conv_in_dt == mkldnn_u8 ? Precision::I8 : Precision::FP32); -// auto biasPrc = memory::data_type::f32; -// -// MKLDNNDims dwWeightsDims({dw_conv_oc, (ptrdiff_t)1, (ptrdiff_t)1, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS]}); -// MKLDNNDims dwBiasesDims({dw_conv_oc}); -// -// InferenceEngine::DataConfig dataConfig; -// dataConfig.inPlace = -1; -// dataConfig.constant = false; -// dataConfig.desc = MKLDNNMemoryDesc(dwWeightsDims, weightsPrc, memory::format_tag::Goihw8g); -// cfg.inConfs.push_back(dataConfig); -// -// dataConfig.desc = MKLDNNMemoryDesc(dwBiasesDims, biasPrc, memory::format_tag::x); -// cfg.inConfs.push_back(dataConfig); -// } + if (withDWConv) { + auto weightsPrc = MKLDNNExtensionUtils::IEPrecisionToDataType(dw_conv_in_dt == mkldnn_u8 ? Precision::I8 : Precision::FP32); + auto biasPrc = memory::data_type::f32; + + MKLDNNDims dwWeightsDims({dw_conv_oc, (ptrdiff_t)1, (ptrdiff_t)1, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS]}); + MKLDNNDims dwBiasesDims({dw_conv_oc}); + + InferenceEngine::DataConfig dataConfig; + dataConfig.inPlace = -1; + dataConfig.constant = false; + dataConfig.desc = MKLDNNMemoryDesc(dwWeightsDims, weightsPrc, memory::format_tag::Goihw8g); + cfg.inConfs.push_back(dataConfig); + + dataConfig.desc = MKLDNNMemoryDesc(dwBiasesDims, biasPrc, memory::format_tag::x); + cfg.inConfs.push_back(dataConfig); + } for (size_t j = 0; j < descOutputNumbers(desc); j++) { InferenceEngine::DataConfig dataConfig; diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.h index 95793af25b4404..58527200da3a9a 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.h +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.h @@ -44,6 +44,12 @@ class MKLDNNConvolutionNode : public MKLDNNNode { std::vector weightsZeroPoints; std::vector outputCompensation; + const InferenceEngine::SizeVector &getWeightDims() { return weightDims; } + const std::vector &getStride() { return stride; } + const std::vector &getDilation() { return dilation; } + const std::vector &getPaddingL() { return paddingL; } + const std::vector &getPaddingR() { return paddingR; } + protected: InferenceEngine::Precision fusedEltwisePrecision(const MKLDNNNodePtr& fusingNode) const; @@ -65,13 +71,12 @@ class MKLDNNConvolutionNode : public MKLDNNNode { InferenceEngine::SizeVector weightDims; InferenceEngine::SizeVector biasesDims; -// TODO: fusing with Convolution is not ported yet -// ptrdiff_t dw_conv_oc; -// ptrdiff_t dw_conv_ih; -// ptrdiff_t dw_conv_iw; -// std::vector 
dw_conv_kernel; -// std::vector dw_conv_strides; -// mkldnn::memory::data_type dw_conv_in_dt; + ptrdiff_t dw_conv_oc; + ptrdiff_t dw_conv_ih; + ptrdiff_t dw_conv_iw; + std::vector dw_conv_kernel; + std::vector dw_conv_strides; + mkldnn::memory::data_type dw_conv_in_dt; size_t groupNum; size_t IC; @@ -79,6 +84,9 @@ class MKLDNNConvolutionNode : public MKLDNNNode { size_t groupOC; InferenceEngine::Precision eltwisePrecision; + + const size_t X_AXIS = 0; + const size_t Y_AXIS = 1; }; } // namespace MKLDNNPlugin diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.cpp index b45e74d1f13d78..c8c61351649183 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.cpp @@ -26,6 +26,8 @@ #include "ngraph/ngraph.hpp" #include +#include "ngraph_transformations/op/power_static.hpp" +#include "ngraph_transformations/op/leaky_relu.hpp" #include #include @@ -800,6 +802,13 @@ std::map& op, MKLDNNEltwiseNode& node) { node.algorithm = EltwisePowerDynamic; }}, + {PowerStaticNode::type_info, [](const std::shared_ptr& op, MKLDNNEltwiseNode& node) { + auto powerStatic = getNgraphOpAs(op); + node.algorithm = EltwisePowerStatic; + node.alpha = powerStatic->get_power(); + node.beta = powerStatic->get_scale(); + node.gamma = powerStatic->get_shift(); + }}, {ngraph::op::v1::Equal::type_info, [](const std::shared_ptr& op, MKLDNNEltwiseNode& node) { node.algorithm = EltwiseEqual; }}, @@ -834,6 +843,13 @@ std::map& op, MKLDNNEltwiseNode& node) { + auto leakyRelu = getNgraphOpAs(op); + node.algorithm = EltwiseRelu; + node.mkldnnAlgorithm = mkldnn::algorithm::eltwise_relu; + node.alpha = leakyRelu->get_slope(); + node.beta = 0.0f; + }}, {ngraph::op::v0::Gelu::type_info, [](const std::shared_ptr& op, MKLDNNEltwiseNode& node) { node.algorithm = EltwiseGelu; node.mkldnnAlgorithm = mkldnn::algorithm::eltwise_gelu_erf; @@ -947,9 +963,9 @@ MKLDNNEltwiseNode::MKLDNNEltwiseNode(const std::shared_ptr& op, co size_t MKLDNNEltwiseNode::getOpInputsNum() const { switch (getAlgorithm()) { case EltwiseRelu: case EltwiseGelu: case EltwiseElu: case EltwiseTanh: case EltwiseSigmoid: case EltwiseSquare: case EltwiseAbs: case EltwiseSqrt: - case EltwisePowerStatic: case EltwiseLinear: case EltwiseBoundedRelu: case EltwiseSoftRelu: case EltwiseRelu6: case EltwiseExp: case EltwiseClamp: + case EltwiseLinear: case EltwiseBoundedRelu: case EltwiseSoftRelu: case EltwiseRelu6: case EltwiseExp: case EltwiseClamp: case EltwiseErf: case EltwiseSwish: case EltwiseHswish: case EltwiseMish: case EltwiseHsigmoid: case EltwiseRoundHalfToEven: case EltwiseRoundHalfAwayFromZero: - case EltwiseLogicalNot: case EltwiseErf: + case EltwiseLogicalNot: case EltwisePowerStatic: return 1; case EltwiseAdd: case EltwiseSubtract: case EltwiseMultiply: case EltwiseDivide: case EltwiseFloorMod: case EltwiseMod: case EltwiseMaximum: case EltwiseMinimum: case EltwiseSquaredDifference: case EltwisePowerDynamic: case EltwiseEqual: case EltwiseNotEqual: case EltwiseGreater: diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.h index 507bcc19303db2..8ae340004306a9 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.h +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.h @@ -77,6 +77,7 @@ class MKLDNNEltwiseNode : public MKLDNNNode { float getAlpha() const { return alpha; } float 
diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.h
index 507bcc19303db2..8ae340004306a9 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.h
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_eltwise_node.h
@@ -77,6 +77,7 @@ class MKLDNNEltwiseNode : public MKLDNNNode {
     float getAlpha() const { return alpha; }
     float getBeta() const { return beta; }
+    float getGamma() const { return gamma; }
     mkldnn::algorithm getMKLDNNAlgorithm() const { return mkldnnAlgorithm; }
     bool isWithBroadcast();
diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_fake_quantize_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_fake_quantize_node.cpp
index 7d7b69b1f07e01..0a5ad38507bda5 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_fake_quantize_node.cpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_fake_quantize_node.cpp
@@ -956,7 +956,7 @@ MKLDNNFakeQuantizeNode::MKLDNNFakeQuantizeNode(const std::shared_ptr(fq->get_input_node_shared_ptr(4));
         auto outputHighData = outputHighNode->cast_vector();
-        bool binarization = levels == 2;
+        binarization = levels == 2;
         if (binarization) {
             for (int i = 0; i < outputLowAxisSize; i++) {
@@ -1085,20 +1085,6 @@ MKLDNNFakeQuantizeNode::MKLDNNFakeQuantizeNode(const std::shared_ptr MKLDNNFakeQuantizeNode::getDataFormats()
     }
 }
+void MKLDNNFakeQuantizeNode::init() {
+    if (binarization) {
+        inputPrecision = Precision::FP32;
+        outputPrecision = Precision::BIN;
+    } else {
+        inputPrecision = getOriginalInputPrecisionAtPort(0);
+        outputPrecision = getOriginalOutputPrecisionAtPort(0);
+
+        if (inputPrecision != Precision::FP32 && inputPrecision != Precision::U8 && inputPrecision != Precision::I8)
+            inputPrecision = Precision::FP32;
+
+        if (outputPrecision != Precision::FP32 && outputPrecision != Precision::U8 && outputPrecision != Precision::I8)
+            outputPrecision = Precision::FP32;
+    }
+}
+
 void MKLDNNFakeQuantizeNode::getSupportedDescriptors() {
     if (getParentEdges().size() != 5)
         IE_THROW() << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size();
diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_fake_quantize_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_fake_quantize_node.h
index 28419eb9a10674..7e94afb1c4bed3 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_fake_quantize_node.h
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_fake_quantize_node.h
@@ -109,6 +109,7 @@ class MKLDNNFakeQuantizeNode : public MKLDNNNode {
     static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept;
 private:
+    void init() override;
     std::vector getDataFormats() const;
     void executeReference();
     void executeBinarization();
@@ -116,6 +117,8 @@ class MKLDNNFakeQuantizeNode : public MKLDNNNode {
     int levels = -1;
+    bool binarization = false;
+
     std::vector binarizationThresholds;
     std::vector binarizationOutputMask;
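The FakeQuantize change above turns `binarization` into a member so that the new `init()` override can choose input/output precisions after the op has been parsed: binarized quantization forces FP32 in and BIN out, otherwise the original precisions are kept but anything other than FP32/U8/I8 falls back to FP32. A standalone restatement of that rule (a sketch for clarity, not the node code itself):

    #include <ie_precision.hpp>
    #include <utility>

    using InferenceEngine::Precision;

    // Mirrors the precision selection performed in MKLDNNFakeQuantizeNode::init().
    std::pair<Precision, Precision> pickFakeQuantizePrecisions(bool binarization,
                                                               Precision in, Precision out) {
        if (binarization)
            return {Precision::FP32, Precision::BIN};

        auto clamp = [](Precision p) {
            return (p == Precision::FP32 || p == Precision::U8 || p == Precision::I8)
                       ? p : Precision::FP32;
        };
        return {clamp(in), clamp(out)};
    }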
diff --git a/inference-engine/src/mkldnn_plugin/nodes/select.cpp b/inference-engine/src/mkldnn_plugin/nodes/select.cpp
index e23b32ab3810ab..2de70f665425e9 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/select.cpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/select.cpp
@@ -7,6 +7,10 @@
 #include 
 #include 
 #include "ie_parallel.hpp"
+#include 
+#include 
+
+using namespace MKLDNNPlugin;
 namespace InferenceEngine {
 namespace Extensions {
@@ -15,74 +19,115 @@ namespace Cpu {
 class SelectImpl: public ExtLayerBase {
     enum { CONDITION, THEN, ELSE, numOfInputs };
     enum { N, C, D, H, W, numOfDims };
+    enum class SelectBroadcastType {
+        NONE,
+        NUMPY
+    };
-    std::string broadcast;
+    SelectBroadcastType broadcastType;
     std::vector resDims;
     std::vector resOffset;
     std::vector condOffset;
     std::vector thenOffset;
     std::vector elseOffset;
+    std::string errorPrefix;
+
+    bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept {
+        try {
+            const auto select = std::dynamic_pointer_cast(op);
+            if (!select) {
+                errorMessage = "Only opset1 Select operation is supported";
+                return false;
+            }
+            const auto broadcast = select->get_auto_broadcast();
+            if (!one_of(broadcast, ngraph::op::AutoBroadcastSpec::NONE, ngraph::op::AutoBroadcastSpec::NUMPY)) {
+                errorMessage = "Does not support broadcast type: " + ngraph::as_string(broadcast.m_type);
+                return false;
+            }
+        } catch (...) {
+            return false;
+        }
+        return true;
+    }
+
 public:
-    explicit SelectImpl(const CNNLayer* layer) {
+    explicit SelectImpl(const std::shared_ptr& op) {
         try {
-            if (layer->insData.size() != numOfInputs || layer->outData.size() != 1)
-                IE_THROW() << "Select layer with name '" << layer->name << "' has incorrect number of input/output edges!";
+            std::string errorMessage;
+            if (!isSupportedOperation(op, errorMessage)) {
+                IE_THROW(NotImplemented) << errorMessage;
+            }
+
+            errorPrefix = "Select layer with name '" + op->get_friendly_name() + "'";
+            const auto select = std::dynamic_pointer_cast(op);
+
+            if (op->get_input_size() != numOfInputs || op->get_output_size() != 1)
+                IE_THROW() << errorPrefix << " has incorrect number of input/output edges!";
-            broadcast = layer->GetParamAsString("auto_broadcast", "numpy");
+            const auto broadcast = select->get_auto_broadcast();
+            if (broadcast == ngraph::op::AutoBroadcastSpec::NONE) {
+                broadcastType = SelectBroadcastType::NONE;
+            } else if (broadcast == ngraph::op::AutoBroadcastSpec::NUMPY) {
+                broadcastType = SelectBroadcastType::NUMPY;
+            } else {
+                IE_THROW() << errorPrefix << " has unsupported broadcast type: " + ngraph::as_string(broadcast.m_type);
+            }
-            auto inputPrecision = layer->insData[THEN].lock()->getTensorDesc().getPrecision();
-            if (inputPrecision == Precision::BF16 || layer->insData[ELSE].lock()->getTensorDesc().getPrecision() == Precision::BF16) {
+            const auto inputThenPrecision = details::convertPrecision(op->get_input_element_type(THEN));
+            const auto inputElsePrecision = details::convertPrecision(op->get_input_element_type(ELSE));
+            auto inputPrecision = inputThenPrecision;
+            if (inputThenPrecision == Precision::BF16 || inputElsePrecision == Precision::BF16) {
                 inputPrecision = Precision::BF16;
-            } else if (layer->insData[THEN].lock()->getTensorDesc().getPrecision() != layer->insData[ELSE].lock()->getTensorDesc().getPrecision()) {
-                IE_THROW() << "Select layer with name '" << layer->name << "' has different precisions on 'Then' and 'Else' inputs ";
+            } else if (inputThenPrecision != inputElsePrecision) {
+                IE_THROW() << errorPrefix << " has different precisions on 'Then' and 'Else' inputs ";
             }
-            const auto& conditionPrecision = layer->insData[CONDITION].lock()->getTensorDesc().getPrecision();
+            const auto conditionPrecision = details::convertPrecision(op->get_input_element_type(CONDITION));
             if (conditionPrecision != Precision::BOOL && conditionPrecision != Precision::I32 && conditionPrecision != Precision::U8)
-                IE_THROW() << "Select layer with name '" << layer->name << "' has unsupported precision: " << conditionPrecision
-                           << " on 'Condition' input";
+                IE_THROW() << errorPrefix << " has unsupported precision: " << conditionPrecision << " on 'Condition' input";
-            const auto& inputPrecisionSize = layer->insData[THEN].lock()->getTensorDesc().getPrecision().size();
+            const auto inputPrecisionSize = inputPrecision.size();
             if (inputPrecisionSize != 1 && inputPrecisionSize != 2 && inputPrecisionSize != 4 && inputPrecisionSize != 8)
-                IE_THROW() << "Select layer with name '" << layer->name << "' has unsupported precision: " <<
-                    layer->insData[THEN].lock()->getTensorDesc().getPrecision() << " on 'Then' and 'Else' inputs";
-
-            const auto &conditionShapes = layer->insData[CONDITION].lock()->getTensorDesc().getDims();
-            const auto &thenShapes = layer->insData[THEN].lock()->getTensorDesc().getDims();
-            const auto &elseShapes = layer->insData[ELSE].lock()->getTensorDesc().getDims();
-            const auto &outputShapes = layer->outData[0]->getTensorDesc().getDims();
-
-            if (broadcast != "none" && broadcast != "numpy")
-                IE_THROW() << "Select layer with name '" << layer->name << "' has unsupported broadcast type: " << broadcast;
-
-            if (broadcast == "none" && ((conditionShapes != outputShapes) || (thenShapes != outputShapes) || (elseShapes != outputShapes)))
-                IE_THROW() << "Select layer with name '" << layer->name << "' and auto_broadcast='none' has input shapes mismatch";
-
-            if (broadcast == "numpy") {
+                IE_THROW() << errorPrefix << " has unsupported precision: " << inputPrecision << " on 'Then' and 'Else' inputs";
+
+            auto conditionShapes = op->get_input_shape(CONDITION);
+            if (ngraph::is_scalar(conditionShapes))
+                conditionShapes = ngraph::Shape{1};
+            auto thenShapes = op->get_input_shape(THEN);
+            if (ngraph::is_scalar(thenShapes))
+                thenShapes = ngraph::Shape{1};
+            auto elseShapes = op->get_input_shape(ELSE);
+            if (ngraph::is_scalar(elseShapes))
+                elseShapes = ngraph::Shape{1};
+            auto outputShapes = op->get_output_shape(0);
+            if (ngraph::is_scalar(outputShapes))
+                outputShapes = ngraph::Shape{1};
+
+            if (broadcastType == SelectBroadcastType::NONE && ((conditionShapes != outputShapes) || (thenShapes != outputShapes) ||
+                                                               (elseShapes != outputShapes)))
+                IE_THROW() << errorPrefix << " and auto_broadcast='none' has input shapes mismatch";
+
+            if (broadcastType == SelectBroadcastType::NUMPY) {
                 if (outputShapes.size() < conditionShapes.size() || outputShapes.size() < thenShapes.size() || outputShapes.size() < elseShapes.size())
-                    IE_THROW() << "Select layer with name '" << layer->name << "' and auto_broadcast='numpy' has incompatible input and output shapes";
+                    IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible input and output shapes";
                 for (int condIt = conditionShapes.size() - 1, outIt = outputShapes.size() - 1; condIt >= 0; condIt--, outIt--)
                     if (conditionShapes[condIt] != outputShapes[outIt] && conditionShapes[condIt] != 1)
-                        IE_THROW() << "Select layer with name '" << layer->name
-                                   << "' and auto_broadcast='numpy' has incompatible 'Condition' input and output shapes";
+                        IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Condition' input and output shapes";
                 for (int thenIt = thenShapes.size() - 1, outIt = outputShapes.size() - 1; thenIt >= 0; thenIt--, outIt--)
                     if (thenShapes[thenIt] != outputShapes[outIt] && thenShapes[thenIt] != 1)
-                        IE_THROW() << "Select layer with name '" << layer->name
-                                   << "' and auto_broadcast='numpy' has incompatible 'Then' input and output shapes";
-
+                        IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Then' input and output shapes";
                 for (int elseIt = elseShapes.size() - 1, outIt = outputShapes.size() - 1; elseIt >= 0; elseIt--, outIt--)
                     if (elseShapes[elseIt] != outputShapes[outIt] && elseShapes[elseIt] != 1)
-                        IE_THROW() << "Select layer with name '" << layer->name
-                                   << "' and auto_broadcast='numpy' has incompatible 'Else' input and output shapes";
+                        IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Else' input and output shapes";
             }
             resDims.resize(numOfDims, 1);
             std::copy(std::begin(outputShapes), std::end(outputShapes), std::begin(resDims) + (numOfDims - outputShapes.size()));
-            if (broadcast == "numpy") {
+            if (broadcastType == SelectBroadcastType::NUMPY) {
                 calcOutOffset(resOffset, resDims);
                 std::vector condDims(numOfDims, 1);
@@ -98,28 +143,10 @@ class SelectImpl: public ExtLayerBase {
                 calcInOffset(elseOffset, elseDims, resDims);
             }
-            LayerConfig config;
-            for (size_t i = 0; i < numOfInputs; i++) {
-                DataConfig inConfig;
-                inConfig.inPlace = -1;
-                inConfig.constant = false;
-
-                Precision inPrecision = i == CONDITION ? conditionPrecision : inputPrecision;
-                const SizeVector& inDims = layer->insData[i].lock()->getTensorDesc().getDims();
-                inConfig.desc = TensorDesc(inPrecision, inDims, InferenceEngine::TensorDesc::getLayoutByDims(inDims));
-
-                config.inConfs.push_back(inConfig);
-            }
-
-            DataConfig outConfig;
-            outConfig.inPlace = -1;
-            outConfig.constant = false;
-            const SizeVector& outDims = layer->outData[0]->getTensorDesc().getDims();
-            outConfig.desc = TensorDesc(inputPrecision, outDims, InferenceEngine::TensorDesc::getLayoutByDims(outDims));
-            config.outConfs.push_back(outConfig);
-
-            config.dynBatchSupport = false;
-            confs.push_back(config);
+            addConfig(op, {{TensorDescCreatorTypes::ncsp, conditionPrecision},
+                           {TensorDescCreatorTypes::ncsp, inputPrecision},
+                           {TensorDescCreatorTypes::ncsp, inputPrecision}},
+                          {{TensorDescCreatorTypes::ncsp, inputPrecision}});
         } catch (InferenceEngine::Exception &ex) {
             errorMsg = ex.what();
         }
@@ -204,7 +231,7 @@ class SelectImpl: public ExtLayerBase {
         auto *elseData = inputs[ELSE]->cbuffer().as() + inputs[ELSE]->getTensorDesc().getBlockingDesc().getOffsetPadding();
         auto *dstData = output->buffer().as() + output->getTensorDesc().getBlockingDesc().getOffsetPadding();
-        if (broadcast == "none") {
+        if (broadcastType == SelectBroadcastType::NONE) {
             size_t dstDataSize = std::accumulate(begin(resDims), end(resDims), 1, std::multiplies());
             parallel_for(dstDataSize, [&](size_t i) {
                 dstData[i] = conditionData[i] ? thenData[i] : elseData[i];
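In the NUMPY branch above the constructor precomputes per-tensor offset tables (calcOutOffset/calcInOffset, whose bodies lie outside this hunk) so that execute() can map every output element back to the right condition/then/else element even when some of their dimensions are 1. The helper below illustrates the same idea with right-aligned numpy broadcasting rules; it is an illustration only, not the plugin's implementation.

    #include <cstddef>
    #include <vector>

    // For an input shape already padded to the output rank, broadcast dimensions
    // (size 1) get stride 0, so they do not advance the input index.
    std::vector<size_t> broadcastStrides(const std::vector<size_t>& inDims) {
        std::vector<size_t> strides(inDims.size());
        size_t stride = 1;
        for (int d = static_cast<int>(inDims.size()) - 1; d >= 0; d--) {
            strides[d] = (inDims[d] == 1) ? 0 : stride;
            stride *= inDims[d];
        }
        return strides;
    }

    // Convert a flat output index into the flat index of a broadcast input.
    size_t broadcastInputIndex(size_t outIdx,
                               const std::vector<size_t>& outDims,
                               const std::vector<size_t>& inStrides) {
        size_t inIdx = 0;
        for (int d = static_cast<int>(outDims.size()) - 1; d >= 0; d--) {
            inIdx += (outIdx % outDims[d]) * inStrides[d];
            outIdx /= outDims[d];
        }
        return inIdx;
    }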
diff --git a/inference-engine/src/mkldnn_plugin/nodes/space_to_depth.cpp b/inference-engine/src/mkldnn_plugin/nodes/space_to_depth.cpp
index 59c18b4495b792..e0c4e6b4c6f303 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/space_to_depth.cpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/space_to_depth.cpp
@@ -10,6 +10,10 @@
 #include 
 #include 
 #include "ie_parallel.hpp"
+#include 
+#include 
+
+using namespace MKLDNNPlugin;
 namespace InferenceEngine {
 namespace Extensions {
@@ -21,55 +25,81 @@ class SpaceToDepthImpl: public ExtLayerBase {
         DEPTH_FIRST
     };
+    std::string errorPrefix;
+
+    bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept {
+        try {
+            const auto spaceToDepth = std::dynamic_pointer_cast(op);
+            if (!spaceToDepth) {
+                errorMessage = "Only opset1 SpaceToDepth operation is supported";
+                return false;
+            }
+            const auto mode = spaceToDepth->get_mode();
+            if (!one_of(mode, ngraph::op::v0::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST, ngraph::op::v0::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST)) {
+                errorMessage = "Does not support mode: " + ngraph::as_string(mode);
+                return false;
+            }
+        } catch (...) {
+            return false;
+        }
+        return true;
+    }
+
 public:
-    explicit SpaceToDepthImpl(const CNNLayer* layer) {
+    explicit SpaceToDepthImpl(const std::shared_ptr& op) {
         try {
-            if (layer->insData.empty() || layer->outData.empty())
-                IE_THROW() << "SpaceToDepth layer with name '" << layer->name << "' has incorrect number of input/output edges";
+            std::string errorMessage;
+            if (!isSupportedOperation(op, errorMessage)) {
+                IE_THROW(NotImplemented) << errorMessage;
+            }
-            SizeVector inDims = layer->insData[0].lock()->getTensorDesc().getDims();
+            errorPrefix = "SpaceToDepth layer with name '" + op->get_friendly_name() + "'";
+            const auto spaceToDepth = std::dynamic_pointer_cast(op);
+
+            if (op->get_input_size() != 1 || op->get_output_size() != 1)
+                IE_THROW() << errorPrefix << " has incorrect number of input/output edges";
+
+            SizeVector inDims = op->get_input_shape(0);
             if (inDims.size() < 3)
-                IE_THROW() << "SpaceToDepth layer with name '" << layer->name << "' has incorrect number of input dimensions";
+                IE_THROW() << errorPrefix << " has incorrect number of input dimensions";
             if (inDims.size() > 5)
-                IE_THROW() << "DepthToSpace layer with name '" << layer->name << "' doesn't support dimensions with rank greater than 5";
+                IE_THROW() << errorPrefix << " doesn't support dimensions with rank greater than 5";
-            outDims = layer->outData[0]->getTensorDesc().getDims();
+            outDims = op->get_output_shape(0);
             if (inDims.size() != outDims.size())
-                IE_THROW() << "SpaceToDepth layer with name '" << layer->name << "' has incorrect number of input/output dimensions";
+                IE_THROW() << errorPrefix << " has incorrect number of input/output dimensions";
-            std::string modeString = layer->GetParamAsString("mode");
-            if (modeString == "blocks_first") {
+            const auto modeNgraph = spaceToDepth->get_mode();
+            if (modeNgraph == ngraph::op::v0::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST) {
                 mode = SpaceToDepthMode::BLOCKS_FIRST;
-            } else if (modeString == "depth_first") {
+            } else if (modeNgraph == ngraph::op::v0::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST) {
                 mode = SpaceToDepthMode::DEPTH_FIRST;
             } else {
-                IE_THROW() << "SpaceToDepth layer with name '" << layer->name << "' doesn't support mode: " << modeString;
+                IE_THROW() << errorPrefix << " doesn't support mode: " << ngraph::as_string(modeNgraph);
             }
-            blockSize = layer->GetParamAsUInt("block_size", 1);
+            blockSize = spaceToDepth->get_block_size();
             if (blockSize == 0)
-                IE_THROW() << layer->name << " Incorrect blockSize parameter is zero!";
+                IE_THROW() << errorPrefix << " has incorrect block_size = 0";
             size_t numSpatialDims = inDims.size() - 2;
             blockStep = static_cast(std::pow(blockSize, numSpatialDims));
             if (outDims[1] % blockStep)
-                IE_THROW() << "SpaceToDepth layer with name '" << layer->name <<
-                    "' has block_size parameter which is incompatible with input tensor channels dimension size";
+                IE_THROW() << errorPrefix << " has block_size parameter which is incompatible with input tensor channels dimension size";
             if (inDims[1] != outDims[1] / blockStep)
-                IE_THROW() << "SpaceToDepth layer with name '" << layer->name << " has incompatible input/output channels";
+                IE_THROW() << errorPrefix << " has incompatible input/output channels";
             for (int i = 0; i < numSpatialDims; i++) {
                 if (inDims[i + 2] != outDims[i + 2] * blockSize)
-                    IE_THROW() << "SpaceToDepth layer with name '" << layer->name << " has incompatible spatial dims";
+                    IE_THROW() << errorPrefix << " has incompatible spatial dims";
             }
-            auto computePrc = layer->insData[0].lock()->getTensorDesc().getPrecision();
+            auto computePrc = details::convertPrecision(op->get_input_element_type(0));
             const std::set supported_precision_sizes = {1, 2, 4, 8};
             if (supported_precision_sizes.find(computePrc.size()) == supported_precision_sizes.end())
-                IE_THROW() << "SpaceToDepth layer with name '" << layer->name << " doesn't support precision: " << computePrc.name();
-
+                IE_THROW() << errorPrefix << " doesn't support precision: " << computePrc.name();
             if (inDims.size() == 4 || inDims.size() == 5) {
                 LayerConfig config;
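The checks above encode the SpaceToDepth shape contract: with numSpatialDims = rank - 2 and blockStep = blockSize^numSpatialDims, the output carries inChannels * blockStep channels while every spatial dimension shrinks by blockSize (DepthToSpace applies the inverse). A small worked helper, assuming a plain N, C, spatial... shape vector:

    #include <cassert>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // SpaceToDepth output shape; blockSize must divide every spatial dimension.
    // Example: {1, 4, 6, 6} with blockSize 2 -> {1, 16, 3, 3}.
    std::vector<size_t> spaceToDepthShape(const std::vector<size_t>& in, size_t blockSize) {
        const size_t numSpatialDims = in.size() - 2;
        const size_t blockStep = static_cast<size_t>(std::pow(blockSize, numSpatialDims));

        std::vector<size_t> out(in);
        out[1] = in[1] * blockStep;                  // channels grow by blockSize^spatialDims
        for (size_t i = 2; i < in.size(); i++) {
            assert(in[i] % blockSize == 0);
            out[i] = in[i] / blockSize;              // each spatial dim shrinks by blockSize
        }
        return out;
    }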
diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/skip_tests_config.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/skip_tests_config.cpp
index c18fb7b8c71038..e2e7a383fd3dad 100644
--- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/skip_tests_config.cpp
+++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/skip_tests_config.cpp
@@ -61,7 +61,6 @@ std::vector disabledTestPatterns() {
         // shared SLT test
         R"(.*BatchToSpaceLayerTest.*)",
         R"(.*BucketizeLayerTest.*)",
-        R"(.*DepthToSpaceLayerTest.*)",
         R"(.*ExtractImagePatchesTest.*)",
         R"(.*GRUCellTest.*)",
         R"(.*GRUSequenceTest.*)",
@@ -75,9 +74,7 @@ std::vector disabledTestPatterns() {
         R"(.*ReverseSequenceLayerTest.*)",
         R"(.*RNNCellTest.*)",
         R"(.*RNNSequenceTest.*)",
-        R"(.*SelectLayerTest.*)",
         R"(.*SpaceToBatchLayerTest.*)",
-        R"(.*SpaceToDepthLayerTest.*)",
         R"(.*TensorIteratorTest.*)",
         R"(.*VariadicSplitPad.*)",
diff --git a/inference-engine/tests/functional/plugin/cpu/single_layer_tests/convert_to_plugin_specific_node.cpp b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/convert_to_plugin_specific_node.cpp
new file mode 100644
index 00000000000000..a59b7f6fc3b5ce
--- /dev/null
+++ b/inference-engine/tests/functional/plugin/cpu/single_layer_tests/convert_to_plugin_specific_node.cpp
@@ -0,0 +1,120 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "test_utils/cpu_test_utils.hpp"
+#include "ngraph_functions/builders.hpp"
+
+using namespace ngraph;
+using namespace InferenceEngine;
+using namespace CPUTestUtils;
+
+namespace CPULayerTestsDefinitions {
+
+using ConvertToPluginSpecificNodeParams = std::tuple;  // expected number of constant node
+
+class ConvertToPluginSpecificNode : public testing::WithParamInterface,
+                                    public LayerTestsUtils::LayerTestsCommon {
+public:
+    static std::string getTestCaseName(testing::TestParamInfo obj) {
+        SizeVector nonConstShape, constShape;
+        Precision prc;
+        helpers::EltwiseTypes nodeType;
+        size_t port, constNodeNum;
+        std::tie(nonConstShape, constShape, prc, nodeType, port, constNodeNum) = obj.param;
+
+        std::ostringstream result;
+        result << "IS_NON_CONST=" << CommonTestUtils::vec2str(nonConstShape) << "_";
+        result << "IS_CONST=" << CommonTestUtils::vec2str(constShape) << "_";
+        result << "PRC=" << prc << "_";
+        result << "NODE=" << nodeType << "_";
+        result << "PORT=" << port << "_";
+        result << "CONST_NUM=" << constNodeNum;
+
+        return result.str();
+    }
+
+protected:
+    size_t constNodeNum;
+
+    void SetUp() override {
+        targetDevice = CommonTestUtils::DEVICE_CPU;
+
+        SizeVector nonConstShape, constShape;
+        Precision prc;
+        helpers::EltwiseTypes nodeType;
+        size_t port;
+
+        std::tie(nonConstShape, constShape, prc, nodeType, port, constNodeNum) = this->GetParam();
+        IE_ASSERT(shape_size(constShape) == 1);
+
+        const auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(prc);
+        const auto param = std::make_shared(ngPrc, ngraph::Shape(nonConstShape));
+        const auto constNode = builder::makeConstant(ngPrc, ngraph::Shape(constShape), std::vector{}, true);
+        OutputVector inputs(2);
+        inputs[port] = constNode;
+        inputs[1 - port] = param;
+
+        auto powerStatic = ngraph::builder::makeEltwise(inputs[0], inputs[1], nodeType);
+
+        function = std::make_shared(powerStatic, ParameterVector{param}, "ConvertToPluginSpecificNode");
+    }
+};
+
+TEST_P(ConvertToPluginSpecificNode, CompareWithRefs) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+
+    Run();
+    CheckNodeOfTypeCount(executableNetwork, "Const", constNodeNum);
+}
+
+namespace {
+
+const std::vector> nonConstIS = {
+    {3, 4, 5, 6}
+};
+
+const std::vector> constIS = {
+    {},
+    {1},
+    {1, 1},
+    {1, 1, 1},
+    {1, 1, 1, 1},
+};
+
+std::vector nodeTypes = {
+    ngraph::helpers::EltwiseTypes::ADD,
+    ngraph::helpers::EltwiseTypes::SUBTRACT,
+    ngraph::helpers::EltwiseTypes::MULTIPLY
+};
+
+const std::vector port = {
+    0, 1
+};
+
+const auto testParamsEltwise = ::testing::Combine(::testing::ValuesIn(nonConstIS),
+                                                  ::testing::ValuesIn(constIS),
+                                                  ::testing::Values(Precision::FP32),
+                                                  ::testing::ValuesIn(nodeTypes),
+                                                  ::testing::ValuesIn(port),
+                                                  ::testing::Values(0));
+
+INSTANTIATE_TEST_CASE_P(smoke_CheckEltwise, ConvertToPluginSpecificNode, testParamsEltwise, ConvertToPluginSpecificNode::getTestCaseName);
+
+const auto testParamsPower = ::testing::Combine(::testing::ValuesIn(nonConstIS),
+                                                ::testing::ValuesIn(constIS),
+                                                ::testing::Values(Precision::FP32),
+                                                ::testing::Values(ngraph::helpers::EltwiseTypes::POWER),
+                                                ::testing::Values(1),
+                                                ::testing::Values(0));
+
+INSTANTIATE_TEST_CASE_P(smoke_CheckPower, ConvertToPluginSpecificNode, testParamsPower, ConvertToPluginSpecificNode::getTestCaseName);
+
+} // namespace
+
+} // namespace CPULayerTestsDefinitions
diff --git a/ngraph/core/src/op/depth_to_space.cpp b/ngraph/core/src/op/depth_to_space.cpp
index 4e90fbf980e4ef..6892daa9352a88 100644
--- a/ngraph/core/src/op/depth_to_space.cpp
+++ b/ngraph/core/src/op/depth_to_space.cpp
@@ -238,7 +238,7 @@ bool op::DepthToSpace::evaluate(const HostTensorVector& outputs,
 namespace ngraph {
     template <>
-    EnumNames&
+    NGRAPH_API EnumNames&
         EnumNames::get()
     {
         static auto enum_names = EnumNames(
diff --git a/ngraph/core/src/op/space_to_depth.cpp b/ngraph/core/src/op/space_to_depth.cpp
index 9b8a5786765d5c..5a38fb771902f3 100644
--- a/ngraph/core/src/op/space_to_depth.cpp
+++ b/ngraph/core/src/op/space_to_depth.cpp
@@ -226,7 +226,7 @@ bool ngraph::op::v0::SpaceToDepth::evaluate(const HostTensorVector& outputs,
 namespace ngraph {
     template <>
-    EnumNames&
+    NGRAPH_API EnumNames&
         EnumNames::get()
     {
         static auto enum_names = EnumNames(