forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Separate Transpose layer support into native and reference (openvinot…
- Loading branch information
Showing
9 changed files
with
195 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (C) 2020-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transpose_arm.hpp" | ||
|
||
using namespace ngraph; | ||
using namespace ArmPlugin; | ||
|
||
constexpr NodeTypeInfo opset::ArmTranspose::type_info; | ||
|
||
opset::ArmTranspose::~ArmTranspose() {} | ||
|
||
opset::ArmTranspose::ArmTranspose(const ngraph::Output<ngraph::Node>& arg, const ngraph::Output<ngraph::Node>& input_order) | ||
: Transpose{arg, input_order} { | ||
constructor_validate_and_infer_types(); | ||
} | ||
|
||
std::shared_ptr<ngraph::Node> ArmPlugin::opset::ArmTranspose::clone_with_new_inputs(const ngraph::OutputVector& new_args) const { | ||
auto num_args = new_args.size(); | ||
if (num_args == 2) { | ||
return std::make_shared<ArmTranspose>(new_args.at(0), new_args.at(1)); | ||
} else { | ||
throw ngraph_error("Unsupported number of arguments for ArmTranspose operation"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "ngraph_opset.hpp" | ||
#include "utils.hpp" | ||
|
||
namespace ArmPlugin { | ||
namespace opset { | ||
|
||
class ArmTranspose : public Transpose { | ||
public: | ||
static constexpr ngraph::NodeTypeInfo type_info{"ArmTranspose", 0}; | ||
const ngraph::NodeTypeInfo& get_type_info() const override { return type_info; } | ||
ArmTranspose() = default; | ||
~ArmTranspose() override; | ||
|
||
ArmTranspose(const ngraph::Output<ngraph::Node>& arg, const ngraph::Output<ngraph::Node>& input_order); | ||
|
||
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override; | ||
}; | ||
} // namespace opset | ||
} // namespace ArmPlugin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 33 additions & 0 deletions
33
modules/arm_plugin/src/transformations/convert_transpose_arm.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Copyright (C) 2020-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
#include "transformations/convert_transpose_arm.hpp" | ||
#include "opset/opset.hpp" | ||
#include <ngraph/rt_info.hpp> | ||
#include <ngraph/pattern/op/wrap_type.hpp> | ||
|
||
|
||
ArmPlugin::pass::ConvertTranspose::ConvertTranspose() { | ||
auto transpose = ngraph::pattern::wrap_type<opset::Transpose>(); | ||
|
||
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) { | ||
auto transpose = std::dynamic_pointer_cast<opset::Transpose>(m.get_match_root()); | ||
if (!transpose) { | ||
return false; | ||
} | ||
|
||
if (transpose->get_shape().size() > 4) { | ||
return false; | ||
} | ||
|
||
auto arm_transpose = std::make_shared<opset::ArmTranspose>(transpose->input_value(0), transpose->input_value(1)); | ||
arm_transpose->set_friendly_name(transpose->get_friendly_name()); | ||
ngraph::copy_runtime_info(transpose, arm_transpose); | ||
ngraph::replace_node(transpose, arm_transpose); | ||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose, "ConvertTranspose"); | ||
register_matcher(m, callback); | ||
} |
15 changes: 15 additions & 0 deletions
15
modules/arm_plugin/src/transformations/convert_transpose_arm.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// Copyright (C) 2020-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include <ngraph/pass/graph_rewrite.hpp> | ||
|
||
namespace ArmPlugin { | ||
namespace pass { | ||
|
||
struct ConvertTranspose: public ngraph::pass::MatcherPass { | ||
ConvertTranspose(); | ||
}; | ||
} // namespace pass | ||
} // namespace ArmPlugin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters