From 5d8f209df6a47edc3722ed2e4ea55ce3fac3f5b2 Mon Sep 17 00:00:00 2001 From: Aleksandr Pertovsky Date: Mon, 3 May 2021 15:01:05 +0300 Subject: [PATCH] [CPU] Add Roll support (#5112) --- .../src/mkldnn_plugin/mkldnn_node.cpp | 1 + .../src/mkldnn_plugin/mkldnn_node.h | 5 +- .../mkldnn_plugin/nodes/mkldnn_roll_node.cpp | 209 ++++++++++++++++++ .../mkldnn_plugin/nodes/mkldnn_roll_node.h | 41 ++++ .../single_layer_tests/roll.cpp | 96 ++++++++ .../include/single_layer_tests/roll.hpp | 15 ++ .../shared_test_classes/single_layer/roll.hpp | 30 +++ .../src/single_layer/roll.cpp | 46 ++++ .../include/ngraph_functions/builders.hpp | 4 + .../ngraph_functions/src/roll.cpp | 17 ++ ngraph/python/tests/test_ngraph/test_roll.py | 2 - 11 files changed, 463 insertions(+), 3 deletions(-) create mode 100644 inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.cpp create mode 100644 inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.h create mode 100644 inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/roll.cpp create mode 100644 inference-engine/tests/functional/plugin/shared/include/single_layer_tests/roll.hpp create mode 100644 inference-engine/tests/functional/shared_test_classes/include/shared_test_classes/single_layer/roll.hpp create mode 100644 inference-engine/tests/functional/shared_test_classes/src/single_layer/roll.cpp create mode 100644 inference-engine/tests/ngraph_helpers/ngraph_functions/src/roll.cpp diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp index d3af44347ad50b..f446c339d39989 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_node.cpp @@ -143,6 +143,7 @@ static const InferenceEngine::details::caseless_unordered_map { "ReduceSum", ReduceSum}, { "ReduceSumSquare", ReduceSumSquare}, { "Erf", Eltwise }, + { "Roll", Roll }, }; Type TypeFromName(const std::string type) { diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_node.h b/inference-engine/src/mkldnn_plugin/mkldnn_node.h index 169bde711c88a6..483c315e955e87 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_node.h +++ b/inference-engine/src/mkldnn_plugin/mkldnn_node.h @@ -87,7 +87,8 @@ enum Type { ReduceOr, ReduceProd, ReduceSum, - ReduceSumSquare + ReduceSumSquare, + Roll }; Type TypeFromName(const std::string type); @@ -206,6 +207,8 @@ static std::string NameFromType(Type type) { return "ReduceSum"; case ReduceSumSquare: return "ReduceSumSquare"; + case Roll: + return "Roll"; default: return "Unknown"; } diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.cpp new file mode 100644 index 00000000000000..aa1d6623463a24 --- /dev/null +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.cpp @@ -0,0 +1,209 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include + +#include "mkldnn_roll_node.h" +#include "ie_parallel.hpp" +#include "ie_precision.hpp" +#include "mkldnn/ie_mkldnn.h" +#include "utils/general_utils.h" +#include "common/cpu_memcpy.h" + +using namespace mkldnn; +using namespace MKLDNNPlugin; +using namespace InferenceEngine; + +MKLDNNRollNode::MKLDNNRollNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache) : + MKLDNNNode(layer, eng, cache) { + layerErrorPrefix = "Roll layer with name '" + layer->name + "'"; + if (layer->insData.size() != numberOfInputs) { + IE_THROW() << layerErrorPrefix << " has incorrect number of input/output edges!"; + } + + /* Data */ + auto data = layer->insData[DATA_INDEX].lock(); + if (data == nullptr) { + IE_THROW() << layerErrorPrefix << " has nullable data"; + } + + const auto &dataTensor = data->getTensorDesc(); + shape = dataTensor.getDims(); + const auto &dataPrecision = dataTensor.getPrecision(); + + if (std::find(supportedPrecisionSizes.begin(), supportedPrecisionSizes.end(), dataPrecision.size()) == supportedPrecisionSizes.end()) + IE_THROW() << layerErrorPrefix << "has unsupported precision: " << dataPrecision.name(); + + if (shape.size() < 1) { + IE_THROW() << layerErrorPrefix << " doesn't support 'data' input tensor with rank: " << shape.size(); + } + numOfDims = shape.size(); + + if (shape != layer->outData[0]->getTensorDesc().getDims()) { + IE_THROW() << layerErrorPrefix << " has different 'data' input and output dimensions"; + } + + /* Axes */ + auto axesData = layer->insData[AXES_INDEX].lock(); + if (axesData == nullptr) { + IE_THROW() << layerErrorPrefix << " has nullable 'axes' data"; + } + const auto& axesTensor = axesData->getTensorDesc(); + const auto& axesTensorPrec = axesData->getTensorDesc().getPrecision(); + if (axesTensorPrec != Precision::I32 && axesTensorPrec != Precision::I64) { + IE_THROW() << layerErrorPrefix << " has unsupported 'axes' input precision: " << axesTensorPrec.name(); + } + + const auto axesTensorRank = axesTensor.getDims().size(); + if (axesTensorRank > 1) { + IE_THROW() << layerErrorPrefix << " doesn't support 'axes' input tensor with rank: " << axesTensorRank; + } + + /* Shift */ + auto shiftData = layer->insData[SHIFT_INDEX].lock(); + if (shiftData == nullptr) { + IE_THROW() << layerErrorPrefix << " has nullable 'shift' data"; + } + const auto& shiftTensor = shiftData->getTensorDesc(); + const auto& shiftTensorPrec = shiftData->getTensorDesc().getPrecision(); + if (shiftTensorPrec != Precision::I32 && shiftTensorPrec != Precision::I64) { + IE_THROW() << layerErrorPrefix << " has unsupported 'shift' input precision: " << shiftTensorPrec.name(); + } + + const auto shiftTensorRank = shiftTensor.getDims().size(); + if (shiftTensorRank > 1) { + IE_THROW() << layerErrorPrefix << " doesn't support 'shift' input tensor with rank: " << shiftTensorRank; + } +} +void MKLDNNRollNode::getSupportedDescriptors() {} + +void MKLDNNRollNode::initSupportedPrimitiveDescriptors() { + if (!supportedPrimitiveDescriptors.empty()) + return; + + auto inputData = getCnnLayer()->insData[0].lock(); + + if (inputData == nullptr) { + IE_THROW() << layerErrorPrefix << " has nullable 'data'"; + } + + InferenceEngine::Precision precision = inputData->getPrecision(); + + auto dataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision); + + auto srcDims = getParentEdgeAt(0)->getDims(); + + auto dataMemoryFormat = MKLDNNMemory::GetPlainFormat(getParentEdgeAt(0)->getDims()); + InferenceEngine::LayerConfig config; + config.dynBatchSupport = false; + + auto createDataConfig = [](const MKLDNNDims& dims, memory::data_type dataType) -> InferenceEngine::DataConfig { + InferenceEngine::DataConfig dataConfig; + dataConfig.inPlace = -1; + dataConfig.constant = false; + dataConfig.desc = MKLDNNMemoryDesc(dims, dataType, MKLDNNMemory::GetPlainFormat(dims)); + return dataConfig; + }; + + config.inConfs.push_back(createDataConfig(getParentEdgeAt(0)->getDims(), dataType)); + config.inConfs.push_back(createDataConfig(getParentEdgeAt(1)->getDims(), memory::data_type::s32)); + config.inConfs.push_back(createDataConfig(getParentEdgeAt(2)->getDims(), memory::data_type::s32)); + + config.outConfs.push_back(createDataConfig(getChildEdgeAt(0)->getDims(), dataType)); + + supportedPrimitiveDescriptors.push_back({config, impl_desc_type::ref, dataMemoryFormat}); +} + + +void MKLDNNRollNode::execute(mkldnn::stream strm) { + const auto dataPrecision = getParentEdgeAt(DATA_INDEX)->getDesc().getPrecision(); + const auto& dataTypeSize = dataPrecision.size(); + switch (dataTypeSize) { + case sizeof(PrecisionTrait::value_type): { + rollImpl::value_type>(); + break; + } + case sizeof(PrecisionTrait::value_type): { + rollImpl::value_type>(); + break; + } + case sizeof(PrecisionTrait::value_type): { + rollImpl::value_type>(); + break; + } + default: + IE_THROW() << layerErrorPrefix << "has unsupported 'data' input precision: " << dataPrecision.name(); + } +} + +size_t MKLDNNRollNode::calculateShiftOffset(size_t dataOffset, size_t dimShift, size_t segmentSize, size_t dimSize) { + size_t pos = dataOffset / segmentSize % dimSize; + size_t shift = (pos + dimShift) % dimSize - pos; + return dataOffset + shift * segmentSize; +} + +template +void MKLDNNRollNode::rollImpl() { + const auto dataEdge = getParentEdgeAt(DATA_INDEX); + const auto axesEdge = getParentEdgeAt(AXES_INDEX); + const auto shiftsEdge = getParentEdgeAt(SHIFT_INDEX); + + const auto *axes = reinterpret_cast(axesEdge->getMemoryPtr()->GetPtr()); + const auto *shifts = reinterpret_cast(shiftsEdge->getMemoryPtr()->GetPtr()); + + const auto *input = reinterpret_cast(dataEdge->getMemoryPtr()->GetPtr()); + auto *output = reinterpret_cast(getChildEdgeAt(0)->getMemoryPtr()->GetPtr()); + std::vector shiftsVector(numOfDims, 0); + + const size_t axesLength = axesEdge->getDims()[0]; + for (size_t dim = 0; dim < axesLength ; ++dim) { + int32_t currentAxis = axes[dim] < 0 ? axes[dim] + numOfDims : axes[dim]; + int32_t shiftSum = shiftsVector[currentAxis] + shifts[dim]; + int32_t dimSize = shape[currentAxis]; + shiftsVector[currentAxis] = (shiftSum % dimSize + dimSize) % dimSize; + } + + const size_t blockSize = shape.back(); + const size_t totalElements = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + const size_t leftBlockSize = blockSize - shiftsVector.back(); + const size_t rightBlockSize = blockSize - leftBlockSize; + const size_t elementSize = sizeof(DataType); + + const size_t nIterations = totalElements / blockSize; + const auto strides = dataEdge->getDesc().getBlockingDesc().getStrides(); + parallel_for(nIterations, [&](size_t iter) { + size_t start = iter * blockSize; + size_t leftBlockStartOffset = start; + size_t rightBlockStartOffset = start + leftBlockSize; + + for (int dim = numOfDims - 1; dim >= 0; --dim) { + leftBlockStartOffset = calculateShiftOffset(leftBlockStartOffset, shiftsVector[dim], strides[dim], shape[dim]); + rightBlockStartOffset = calculateShiftOffset(rightBlockStartOffset, shiftsVector[dim], strides[dim], shape[dim]); + } + + if (leftBlockSize > 0) + cpu_memcpy(output + leftBlockStartOffset, + input + start, + leftBlockSize * elementSize); + + + if (rightBlockSize > 0) + cpu_memcpy(output + rightBlockStartOffset, + input + (start + leftBlockSize), + rightBlockSize * elementSize); + }); +} + +bool MKLDNNRollNode::created() const { + return getType() == Roll; +} + +void MKLDNNRollNode::createPrimitive() {} + +const std::vector MKLDNNRollNode::supportedPrecisionSizes = {1, 2, 4}; + +REG_MKLDNN_PRIM_FOR(MKLDNNRollNode, Roll) diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.h new file mode 100644 index 00000000000000..019d65f633299e --- /dev/null +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_roll_node.h @@ -0,0 +1,41 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace MKLDNNPlugin { + +class MKLDNNRollNode : public MKLDNNNode { +public: + MKLDNNRollNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache); + ~MKLDNNRollNode() override = default; + + void getSupportedDescriptors() override; + void initSupportedPrimitiveDescriptors() override; + void createPrimitive() override; + void execute(mkldnn::stream strm) override; + bool created() const override; + +private: + size_t calculateShiftOffset(size_t dataOffset, size_t dimShift, size_t segmentSize, size_t dimSize); + + template + void rollImpl(); + + std::vector shape; + const static std::vector supportedPrecisionSizes; + std::string layerErrorPrefix; + size_t numOfDims; + + const size_t DATA_INDEX = 0ul; + const size_t SHIFT_INDEX = 1ul; + const size_t AXES_INDEX = 2ul; + const size_t numberOfInputs = 3ul; +}; + +} // namespace MKLDNNPlugin diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/roll.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/roll.cpp new file mode 100644 index 00000000000000..f47d29704ded90 --- /dev/null +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/roll.cpp @@ -0,0 +1,96 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "single_layer_tests/roll.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; + +namespace { + +const std::vector inputPrecision = { + InferenceEngine::Precision::I8, + InferenceEngine::Precision::U8, + InferenceEngine::Precision::I16, + InferenceEngine::Precision::I32, + InferenceEngine::Precision::FP32, + InferenceEngine::Precision::BF16 +}; + +const auto testCase2DZeroShifts = ::testing::Combine( + ::testing::Values(std::vector{17, 19}), // Input shape + ::testing::ValuesIn(inputPrecision), // Precision + ::testing::Values(std::vector{0, 0}), // Shift + ::testing::Values(std::vector{0, 1}), // Axes + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +const auto testCase1D = ::testing::Combine( + ::testing::Values(std::vector{16}), // Input shape + ::testing::ValuesIn(inputPrecision), // Precision + ::testing::Values(std::vector{5}), // Shift + ::testing::Values(std::vector{0}), // Axes + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +const auto testCase2D = ::testing::Combine( + ::testing::Values(std::vector{600, 450}), // Input shape + ::testing::ValuesIn(inputPrecision), // Precision + ::testing::Values(std::vector{300, 250}), // Shift + ::testing::Values(std::vector{0, 1}), // Axes + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +const auto testCase3D = ::testing::Combine( + ::testing::Values(std::vector{2, 320, 320}), // Input shape + ::testing::ValuesIn(inputPrecision), // Precision + ::testing::Values(std::vector{160, 160}), // Shift + ::testing::Values(std::vector{1, 2}), // Axes + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +const auto testCaseNegativeUnorderedAxes4D = ::testing::Combine( + ::testing::Values(std::vector{3, 11, 6, 4}), // Input shape + ::testing::ValuesIn(inputPrecision), // Precision + ::testing::Values(std::vector{7, 3}), // Shift + ::testing::Values(std::vector{-3, -2}), // Axes + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +const auto testCaseRepeatingAxes5D = ::testing::Combine( + ::testing::Values(std::vector{2, 16, 32, 32}), // Input shape + ::testing::ValuesIn(inputPrecision), // Precision + ::testing::Values(std::vector{16, 15, 10, 2, 1, 7, 2, 8, 1, 1}), // Shift + ::testing::Values(std::vector{-1, -2, -3, 1, 0, 3, 3, 2, -2, -3}), // Axes + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +const auto testCaseNegativeShifts6D = ::testing::Combine( + ::testing::Values(std::vector{4, 16, 3, 6, 5, 2}), // Input shape + ::testing::ValuesIn(inputPrecision), // Precision + ::testing::Values(std::vector{-2, -15, -2, -1, -4, -1}), // Shift + ::testing::Values(std::vector{0, 1, 2, 3, 4, 5}), // Axes + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +const auto testCaseUnordNegAxesAndShifts10D = ::testing::Combine( + ::testing::Values(std::vector{2, 2, 4, 2, 3, 6, 3, 2, 3, 2}), // Input shape + ::testing::ValuesIn(inputPrecision), // Precision + ::testing::Values(std::vector{-2, -1, 1, 1, 1, -2}), // Shift + ::testing::Values(std::vector{-6, -4, -3, 1, -10, -2}), // Axes + ::testing::Values(CommonTestUtils::DEVICE_CPU) +); + +INSTANTIATE_TEST_CASE_P(smoke_MKLDNN_TestsRoll_2d_zero_shifts, RollLayerTest, testCase2DZeroShifts, RollLayerTest::getTestCaseName); +INSTANTIATE_TEST_CASE_P(smoke_MKLDNN_TestsRoll_1d, RollLayerTest, testCase1D, RollLayerTest::getTestCaseName); +INSTANTIATE_TEST_CASE_P(smoke_MKLDNN_TestsRoll_2d, RollLayerTest, testCase2D, RollLayerTest::getTestCaseName); +INSTANTIATE_TEST_CASE_P(smoke_MKLDNN_TestsRoll_3d, RollLayerTest, testCase3D, RollLayerTest::getTestCaseName); +INSTANTIATE_TEST_CASE_P(smoke_MKLDNN_TestsRoll_negative_unordered_axes_4d, RollLayerTest, testCaseNegativeUnorderedAxes4D, RollLayerTest::getTestCaseName); +INSTANTIATE_TEST_CASE_P(smoke_MKLDNN_TestsRoll_negative_unordered_axes_5d, RollLayerTest, testCaseRepeatingAxes5D, RollLayerTest::getTestCaseName); +INSTANTIATE_TEST_CASE_P(smoke_MKLDNN_TestsRoll_negative_shifts_6d, RollLayerTest, testCaseNegativeShifts6D, RollLayerTest::getTestCaseName); +INSTANTIATE_TEST_CASE_P(smoke_MKLDNN_TestsRoll_unord_neg_shifts_and_axes_10d, RollLayerTest, testCaseUnordNegAxesAndShifts10D, RollLayerTest::getTestCaseName); + +} // namespace diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/roll.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/roll.hpp new file mode 100644 index 00000000000000..cefc43f7a75b95 --- /dev/null +++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/roll.hpp @@ -0,0 +1,15 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "shared_test_classes/single_layer/roll.hpp" + +namespace LayerTestsDefinitions { + +TEST_P(RollLayerTest, CompareWithRefs) { + Run(); +}; + +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/shared_test_classes/include/shared_test_classes/single_layer/roll.hpp b/inference-engine/tests/functional/shared_test_classes/include/shared_test_classes/single_layer/roll.hpp new file mode 100644 index 00000000000000..97dfcdb7fbc52d --- /dev/null +++ b/inference-engine/tests/functional/shared_test_classes/include/shared_test_classes/single_layer/roll.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "shared_test_classes/base/layer_test_utils.hpp" +#include "ngraph_functions/builders.hpp" + +namespace LayerTestsDefinitions { + +typedef std::tuple< + InferenceEngine::SizeVector, // Input shapes + InferenceEngine::Precision, // Input precision + std::vector, // Shift + std::vector, // Axes + std::string> rollParams; // Device name + +class RollLayerTest : public testing::WithParamInterface, virtual public LayerTestsUtils::LayerTestsCommon { +public: + static std::string getTestCaseName(testing::TestParamInfo obj); + +protected: + void SetUp() override; +}; + +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/shared_test_classes/src/single_layer/roll.cpp b/inference-engine/tests/functional/shared_test_classes/src/single_layer/roll.cpp new file mode 100644 index 00000000000000..e54abc943d987d --- /dev/null +++ b/inference-engine/tests/functional/shared_test_classes/src/single_layer/roll.cpp @@ -0,0 +1,46 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "shared_test_classes/single_layer/roll.hpp" + +namespace LayerTestsDefinitions { + +std::string RollLayerTest::getTestCaseName(testing::TestParamInfo obj) { + InferenceEngine::SizeVector inputShapes; + InferenceEngine::Precision inputPrecision; + std::vector shift; + std::vector axes; + std::string targetDevice; + std::tie(inputShapes, inputPrecision, shift, axes, targetDevice) = obj.param; + + std::ostringstream result; + result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_"; + result << "Precision=" << inputPrecision.name() << "_"; + result << "Shift=" << CommonTestUtils::vec2str(shift) << "_"; + result << "Axes=" << CommonTestUtils::vec2str(axes) << "_"; + result << "TargetDevice=" << targetDevice; + return result.str(); +} + +void RollLayerTest::SetUp() { + InferenceEngine::SizeVector inputShapes; + InferenceEngine::Precision inputPrecision; + std::vector shift; + std::vector axes; + std::tie(inputShapes, inputPrecision, shift, axes, targetDevice) = this->GetParam(); + auto inType = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inputPrecision); + ngraph::ParameterVector paramVector; + auto paramData = std::make_shared(inType, ngraph::Shape(inputShapes)); + paramVector.push_back(paramData); + + auto shiftNode = std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{shift.size()}, shift)->output(0); + auto axesNode = std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{axes.size()}, axes)->output(0); + + auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(paramVector)); + auto roll = std::dynamic_pointer_cast(ngraph::builder::makeRoll(paramOuts[0], shiftNode, axesNode)); + + ngraph::ResultVector results{std::make_shared(roll)}; + function = std::make_shared(results, paramVector, "roll"); +} +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/ngraph_helpers/ngraph_functions/include/ngraph_functions/builders.hpp b/inference-engine/tests/ngraph_helpers/ngraph_functions/include/ngraph_functions/builders.hpp index 292776c307b12f..18e71981376053 100644 --- a/inference-engine/tests/ngraph_helpers/ngraph_functions/include/ngraph_functions/builders.hpp +++ b/inference-engine/tests/ngraph_helpers/ngraph_functions/include/ngraph_functions/builders.hpp @@ -505,5 +505,9 @@ std::shared_ptr makeOneHot(const ngraph::Output& indices, const float& off_val, const int64_t& axis); +std::shared_ptr makeRoll(const ngraph::Output& dataNode, + const ngraph::Output& shiftNode, + const ngraph::Output& axesNode); + } // namespace builder } // namespace ngraph diff --git a/inference-engine/tests/ngraph_helpers/ngraph_functions/src/roll.cpp b/inference-engine/tests/ngraph_helpers/ngraph_functions/src/roll.cpp new file mode 100644 index 00000000000000..9ebe0b6ecdf18b --- /dev/null +++ b/inference-engine/tests/ngraph_helpers/ngraph_functions/src/roll.cpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ngraph_functions/builders.hpp" + +namespace ngraph { +namespace builder { + +std::shared_ptr makeRoll(const ngraph::Output &in, + const ngraph::Output &shift, + const ngraph::Output &axes) { + return std::make_shared(in, shift, axes); +} + +} // namespace builder +} // namespace ngraph diff --git a/ngraph/python/tests/test_ngraph/test_roll.py b/ngraph/python/tests/test_ngraph/test_roll.py index 07426df0816fbd..877e22d098eb3f 100644 --- a/ngraph/python/tests/test_ngraph/test_roll.py +++ b/ngraph/python/tests/test_ngraph/test_roll.py @@ -1,10 +1,8 @@ import ngraph as ng import numpy as np -from tests import xfail_issue_49391 from tests.runtime import get_runtime -@xfail_issue_49391 def test_roll(): runtime = get_runtime() input = np.reshape(np.arange(10), (2, 5))