Skip to content

Commit

Permalink
[ONNX] Add support for BitShift operator (#4368)
Browse files Browse the repository at this point in the history
  • Loading branch information
postrational authored Feb 17, 2021
1 parent 45ae389 commit ec9b589
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 12 deletions.
67 changes: 67 additions & 0 deletions ngraph/frontend/onnx_import/src/op/bitshift.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "op/bitshift.hpp"
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/shape.hpp"

namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
OutputVector bitshift(const Node& node)
{
const Output<ngraph::Node> input_x = node.get_ng_inputs().at(0);
const Output<ngraph::Node> input_y = node.get_ng_inputs().at(1);

std::string direction = node.get_attribute_value<std::string>("direction", "");

CHECK_VALID_NODE(node,
!direction.empty(),
"Required attribute 'direction' is not specified.");

CHECK_VALID_NODE(node,
direction == "LEFT" || direction == "RIGHT",
"Only values 'LEFT' and 'RIGHT' are supported for 'direction' "
"attribute. Given: ",
direction);

auto shift = std::make_shared<default_opset::Power>(
default_opset::Constant::create(input_y.get_element_type(), Shape{1}, {2}),
input_y);

if (direction == "RIGHT")
{
return {std::make_shared<default_opset::Divide>(input_x, shift)};
}
else
{
return {std::make_shared<default_opset::Multiply>(input_x, shift)};
}
}

} // namespace set_1

} // namespace op

} // namespace onnx_import

} // namespace ngraph
40 changes: 40 additions & 0 deletions ngraph/frontend/onnx_import/src/op/bitshift.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <memory>

#include "ngraph/node.hpp"
#include "onnx_import/core/node.hpp"

namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
OutputVector bitshift(const Node& node);

} // namespace set_1

} // namespace op

} // namespace onnx_import

} // namespace ngraph
2 changes: 2 additions & 0 deletions ngraph/frontend/onnx_import/src/ops_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "op/atanh.hpp"
#include "op/average_pool.hpp"
#include "op/batch_norm.hpp"
#include "op/bitshift.hpp"
#include "op/cast.hpp"
#include "op/ceil.hpp"
#include "op/clip.hpp"
Expand Down Expand Up @@ -324,6 +325,7 @@ namespace ngraph
REGISTER_OPERATOR("Atanh", 1, atanh);
REGISTER_OPERATOR("AveragePool", 1, average_pool);
REGISTER_OPERATOR("BatchNormalization", 1, batch_norm);
REGISTER_OPERATOR("BitShift", 1, bitshift);
REGISTER_OPERATOR("Cast", 1, cast);
REGISTER_OPERATOR("Ceil", 1, ceil);
REGISTER_OPERATOR("Clip", 1, clip);
Expand Down
2 changes: 0 additions & 2 deletions ngraph/python/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True):
"MaxUnpool")
xfail_issue_33512 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
"Einsum")
xfail_issue_33515 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
"BitShift")
xfail_issue_33535 = xfail_test(reason="nGraph does not support the following ONNX operations:"
"DynamicQuantizeLinear")
xfail_issue_33538 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
Expand Down
10 changes: 0 additions & 10 deletions ngraph/python/tests/test_onnx/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from tests import (BACKEND_NAME,
xfail_issue_33488,
xfail_issue_33512,
xfail_issue_33515,
xfail_issue_33535,
xfail_issue_33538,
xfail_issue_33540,
Expand Down Expand Up @@ -562,15 +561,6 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None
"OnnxBackendNodeModelTest.test_compress_default_axis_cpu",
"OnnxBackendNodeModelTest.test_compress_1_cpu",
"OnnxBackendNodeModelTest.test_compress_0_cpu"),
(xfail_issue_33515,
"OnnxBackendNodeModelTest.test_bitshift_left_uint8_cpu",
"OnnxBackendNodeModelTest.test_bitshift_right_uint64_cpu",
"OnnxBackendNodeModelTest.test_bitshift_right_uint16_cpu",
"OnnxBackendNodeModelTest.test_bitshift_right_uint32_cpu",
"OnnxBackendNodeModelTest.test_bitshift_right_uint8_cpu",
"OnnxBackendNodeModelTest.test_bitshift_left_uint32_cpu",
"OnnxBackendNodeModelTest.test_bitshift_left_uint16_cpu",
"OnnxBackendNodeModelTest.test_bitshift_left_uint64_cpu"),
(xfail_issue_38732,
"OnnxBackendNodeModelTest.test_convinteger_with_padding_cpu",
"OnnxBackendNodeModelTest.test_basic_convinteger_cpu"),
Expand Down

0 comments on commit ec9b589

Please sign in to comment.