Skip to content

Commit

Permalink
Separate Transpose layer support into native and reference (openvinot…
Browse files Browse the repository at this point in the history
  • Loading branch information
l-bat authored Mar 27, 2021
1 parent 5c84617 commit dce274a
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 8 deletions.
3 changes: 2 additions & 1 deletion modules/arm_plugin/src/arm_converter/arm_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Converter::Converter(const std::shared_ptr<const ngraph::Function> function, boo
Register<opset::Clamp>();
Register<opset::Sqrt>();
Register<opset::Elu>();
Register<opset::Transpose>();
Register<opset::ArmTranspose>();
Register<opset::Softmax>();
Register<opset::Split>();
Register<opset::LRN>();
Expand Down Expand Up @@ -112,6 +112,7 @@ Converter::Converter(const std::shared_ptr<const ngraph::Function> function, boo
Register<opset::NormalizeL2>();
Register<opset::Interpolate>();
Register<opset::Concat>();
Register<opset::Transpose>();
Register<opset::ROIPooling>();
Register<opset::PSROIPooling>();
Register<opset::TopK>();
Expand Down
70 changes: 69 additions & 1 deletion modules/arm_plugin/src/arm_converter/arm_converter_transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,20 @@


#include <arm_compute/runtime/NEON/functions/NEPermute.h>
#include <ngraph/runtime/reference/transpose.hpp>
#include "arm_converter/arm_converter.hpp"

namespace ArmPlugin {
template<> Converter::Conversion::Ptr Converter::Convert(const opset::Transpose& node) {
template<> Converter::Conversion::Ptr Converter::Convert(const opset::ArmTranspose& node) {
enum {Data, Order};
auto&& inputOrder = std::dynamic_pointer_cast<ngraph::op::Constant>(
node.input_value(Order).get_node_shared_ptr())->cast_vector<size_t>();

if (inputOrder.empty()) {
inputOrder.resize(node.get_input_shape(0).size());
std::iota(inputOrder.begin(), inputOrder.end(), 0);
std::reverse(inputOrder.begin(), inputOrder.end());
}
arm_compute::PermutationVector order;
const auto maxSupportedNumOfDimensions = (inputOrder.size() < 4) ? 3u : 4u;
for (unsigned int i = 0; i < maxSupportedNumOfDimensions; ++i) {
Expand All @@ -20,4 +27,65 @@ template<> Converter::Conversion::Ptr Converter::Convert(const opset::Transpose&
}
return MakeConversion<arm_compute::NEPermute>(node.input(0), node.output(0), order);
}

template<> Converter::Conversion::Ptr Converter::Convert(const opset::Transpose& node) {
auto make = [&] (auto refFunction) {
if (ngraph::shape_size(node.get_input_shape(1)) == 0) {
return MakeConversion(refFunction,
node.input(0),
node.output(0),
node.get_input_shape(0),
nullptr);
}
return MakeConversion(refFunction,
node.input(0),
node.output(0),
node.get_input_shape(0),
node.input(1));
};

switch (node.get_input_element_type(0)) {
case ngraph::element::Type_t::u8 :
if (node.get_input_element_type(1) == ngraph::element::i32) {
return make(ngraph::runtime::reference::transpose<std::uint8_t, std::int32_t>);
}
return make(ngraph::runtime::reference::transpose<std::uint8_t, std::int64_t>);
case ngraph::element::Type_t::i16 :
if (node.get_input_element_type(1) == ngraph::element::i32) {
return make(ngraph::runtime::reference::transpose<std::int16_t, std::int32_t>);
}
return make(ngraph::runtime::reference::transpose<std::int16_t, std::int64_t>);
case ngraph::element::Type_t::u16 :
if (node.get_input_element_type(1) == ngraph::element::i32) {
return make(ngraph::runtime::reference::transpose<std::uint16_t, std::int32_t>);
}
return make(ngraph::runtime::reference::transpose<std::uint16_t, std::int64_t>);
case ngraph::element::Type_t::u32 :
if (node.get_input_element_type(1) == ngraph::element::i32) {
return make(ngraph::runtime::reference::transpose<std::uint32_t, std::int32_t>);
}
return make(ngraph::runtime::reference::transpose<std::uint32_t, std::int64_t>);
case ngraph::element::Type_t::i32 :
if (node.get_input_element_type(1) == ngraph::element::i32) {
return make(ngraph::runtime::reference::transpose<std::int32_t, std::int32_t>);
}
return make(ngraph::runtime::reference::transpose<std::int32_t, std::int64_t>);
case ngraph::element::Type_t::i64 :
if (node.get_input_element_type(1) == ngraph::element::i32) {
return make(ngraph::runtime::reference::transpose<std::int64_t, std::int32_t>);
}
return make(ngraph::runtime::reference::transpose<std::int64_t, std::int64_t>);
case ngraph::element::Type_t::f16 :
if (node.get_input_element_type(1) == ngraph::element::i32) {
return make(ngraph::runtime::reference::transpose<ngraph::float16, std::int32_t>);
}
return make(ngraph::runtime::reference::transpose<ngraph::float16, std::int64_t>);
case ngraph::element::Type_t::f32 :
if (node.get_input_element_type(1) == ngraph::element::i32) {
return make(ngraph::runtime::reference::transpose<float, std::int32_t>);
}
return make(ngraph::runtime::reference::transpose<float, std::int64_t>);
default: IE_THROW() << "Unsupported Type: " << node.get_input_element_type(0); return {};
}
}
} // namespace ArmPlugin
1 change: 1 addition & 0 deletions modules/arm_plugin/src/opset/opset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
#include "matmul_bias.hpp"
#include "mvn_arm.hpp"
#include "normalizel2_arm.hpp"
#include "transpose_arm.hpp"
#include "ngraph_opset.hpp"
#include "utils.hpp"
26 changes: 26 additions & 0 deletions modules/arm_plugin/src/opset/transpose_arm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transpose_arm.hpp"

using namespace ngraph;
using namespace ArmPlugin;

constexpr NodeTypeInfo opset::ArmTranspose::type_info;

opset::ArmTranspose::~ArmTranspose() {}

opset::ArmTranspose::ArmTranspose(const ngraph::Output<ngraph::Node>& arg, const ngraph::Output<ngraph::Node>& input_order)
: Transpose{arg, input_order} {
constructor_validate_and_infer_types();
}

std::shared_ptr<ngraph::Node> ArmPlugin::opset::ArmTranspose::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
auto num_args = new_args.size();
if (num_args == 2) {
return std::make_shared<ArmTranspose>(new_args.at(0), new_args.at(1));
} else {
throw ngraph_error("Unsupported number of arguments for ArmTranspose operation");
}
}
25 changes: 25 additions & 0 deletions modules/arm_plugin/src/opset/transpose_arm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph_opset.hpp"
#include "utils.hpp"

namespace ArmPlugin {
namespace opset {

class ArmTranspose : public Transpose {
public:
static constexpr ngraph::NodeTypeInfo type_info{"ArmTranspose", 0};
const ngraph::NodeTypeInfo& get_type_info() const override { return type_info; }
ArmTranspose() = default;
~ArmTranspose() override;

ArmTranspose(const ngraph::Output<ngraph::Node>& arg, const ngraph::Output<ngraph::Node>& input_order);

std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override;
};
} // namespace opset
} // namespace ArmPlugin
3 changes: 3 additions & 0 deletions modules/arm_plugin/src/transformations/arm_optimizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "decompose_swish.hpp"
#include "convert_shuffle_channels.hpp"
#include "convert_tile_to_concats.hpp"
#include "convert_transpose_arm.hpp"
#include "convert_prelu.hpp"
#include "convert_mvn_arm.hpp"
#include "convert_reduce_multi_axis.hpp"
Expand Down Expand Up @@ -118,6 +119,8 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<pass::BroadcastPRelu>();
manager.register_pass<pass::ConvertLogical>();
manager.register_pass<pass::ConvertComparison>();
manager.register_pass<pass::ConvertTranspose>();
manager.register_pass<ngraph::pass::ConstantFolding>();

manager.register_pass<pass::ConvertRound>();
manager.register_pass<pass::ConvertSign>();
Expand Down
33 changes: 33 additions & 0 deletions modules/arm_plugin/src/transformations/convert_transpose_arm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0


#include "transformations/convert_transpose_arm.hpp"
#include "opset/opset.hpp"
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>


ArmPlugin::pass::ConvertTranspose::ConvertTranspose() {
auto transpose = ngraph::pattern::wrap_type<opset::Transpose>();

ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
auto transpose = std::dynamic_pointer_cast<opset::Transpose>(m.get_match_root());
if (!transpose) {
return false;
}

if (transpose->get_shape().size() > 4) {
return false;
}

auto arm_transpose = std::make_shared<opset::ArmTranspose>(transpose->input_value(0), transpose->input_value(1));
arm_transpose->set_friendly_name(transpose->get_friendly_name());
ngraph::copy_runtime_info(transpose, arm_transpose);
ngraph::replace_node(transpose, arm_transpose);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(transpose, "ConvertTranspose");
register_matcher(m, callback);
}
15 changes: 15 additions & 0 deletions modules/arm_plugin/src/transformations/convert_transpose_arm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <ngraph/pass/graph_rewrite.hpp>

namespace ArmPlugin {
namespace pass {

struct ConvertTranspose: public ngraph::pass::MatcherPass {
ConvertTranspose();
};
} // namespace pass
} // namespace ArmPlugin
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP16
};

// Empty order is not supported yet: CVS-32756
std::vector<std::vector<size_t>> inputShape2D = {{2, 10}, {10, 2}, {10, 10}};
std::vector<std::vector<size_t>> order2D = {{0, 1}, {1, 0}, /*{}*/};
std::vector<std::vector<size_t>> order2D = {{0, 1}, {1, 0}, {}};

INSTANTIATE_TEST_CASE_P(smoke_Transpose2D, TransposeLayerTest,
::testing::Combine(
Expand All @@ -31,11 +30,9 @@ INSTANTIATE_TEST_CASE_P(smoke_Transpose2D, TransposeLayerTest,
::testing::Values(CommonTestUtils::DEVICE_CPU)),
TransposeLayerTest::getTestCaseName);

// TODO: fix Transpose for tensors with equal dimensions
std::vector<std::vector<size_t>> inputShape4D = {/*{2, 2, 2, 2},*/ {1, 10, 2, 3}, {2, 3, 4, 5}};
std::vector<std::vector<size_t>> inputShape4D = {{2, 2, 2, 2}, {1, 10, 2, 3}, {2, 3, 4, 5}};
std::vector<std::vector<size_t>> order4D = {
// {}
{0, 1, 2, 3}, {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1},
{}, {0, 1, 2, 3}, {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1},
{1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0},
{2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0},
{3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0}
Expand All @@ -52,4 +49,22 @@ INSTANTIATE_TEST_CASE_P(smoke_Transpose4D, TransposeLayerTest,
::testing::ValuesIn(inputShape4D),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
TransposeLayerTest::getTestCaseName);

std::vector<std::vector<size_t>> inputShape5D = {{2, 2, 2, 2, 2}, {1, 10, 2, 3, 4}, {2, 3, 4, 5, 6}};
std::vector<std::vector<size_t>> order5D = {
{}, {0, 1, 2, 3, 4}, {1, 0, 2, 3, 4}, {4, 3, 2, 1, 0}, {0, 2, 3, 4, 1},
{1, 4, 2, 3, 0}, {2, 4, 1, 0, 3}, {3, 0, 2, 1, 4}, {4, 1, 0, 3, 2}
};

INSTANTIATE_TEST_CASE_P(smoke_Transpose5D, TransposeLayerTest,
::testing::Combine(
::testing::ValuesIn(order5D),
::testing::ValuesIn(netPrecisions),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::ValuesIn(inputShape5D),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
TransposeLayerTest::getTestCaseName);
} // namespace

0 comments on commit dce274a

Please sign in to comment.