From b75de7108f390ac5e0c23f2e4abc0c678d978337 Mon Sep 17 00:00:00 2001
From: bszmelcz
Date: Thu, 1 Dec 2022 14:31:35 +0100
Subject: [PATCH 1/7] add aten::roll

---
 src/frontends/pytorch/src/op/roll.cpp        | 35 ++++++++++++++++
 src/frontends/pytorch/src/op_table.cpp       |  2 +
 tests/layer_tests/pytorch_tests/test_roll.py | 42 ++++++++++++++++++++
 3 files changed, 79 insertions(+)
 create mode 100644 src/frontends/pytorch/src/op/roll.cpp
 create mode 100644 tests/layer_tests/pytorch_tests/test_roll.py

diff --git a/src/frontends/pytorch/src/op/roll.cpp b/src/frontends/pytorch/src/op/roll.cpp
new file mode 100644
index 00000000000000..896dc2c328777a
--- /dev/null
+++ b/src/frontends/pytorch/src/op/roll.cpp
@@ -0,0 +1,35 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/opsets/opset8.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+OutputVector translate_roll(NodeContext& context) {
+    const auto data = context.get_input(0);
+    const auto shifts = context.get_input(1);
+    const auto axes = context.get_input(2);
+    // if axes was not set 
+    if (axes.get_shape() != shifts.get_shape()) {
+        const auto const_minus_1 = opset8::Constant::create(element::i32, Shape{1}, {-1});
+        const auto axis_0 = opset8::Constant::create(element::i32, Shape{1}, {0});
+        const auto flat = std::make_shared<opset8::Reshape>(data, const_minus_1, false);
+        const auto roll = std::make_shared<opset8::Roll>(flat, shifts, axis_0);
+        const auto shape_of_data = std::make_shared<opset8::ShapeOf>(data);
+        const auto reshape = std::make_shared<opset8::Reshape>(roll, shape_of_data, false);
+        context.mark_nodes({const_minus_1, flat, roll, shape_of_data, reshape});
+        return {reshape};
+    }
+    return {context.mark_node(std::make_shared<opset8::Roll>(data, shifts, axes))};
+};
+
+} // namespace op
+} // namespace pytorch
+} // namespace frontend
+} // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index a1a4f29824a3d0..885d9ffa351d96 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -43,6 +43,7 @@ OP_CONVERTER(translate_reciprocal);
 OP_CONVERTER(translate_relu6);
 OP_CONVERTER(translate_reshape);
 OP_CONVERTER(translate_rsub);
+OP_CONVERTER(translate_roll);
 OP_CONVERTER(translate_select);
 OP_CONVERTER(translate_size);
 OP_CONVERTER(translate_slice);
@@ -110,6 +111,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::relu6", op::translate_relu6},
         {"aten::reshape", op::translate_reshape},
         {"aten::rsub", op::translate_rsub},
+        {"aten::roll", op::translate_roll},
         {"aten::select", op::translate_select},
         {"aten::sigmoid", op::translate_1to1_match_1_inputs<opset8::Sigmoid>},
         {"aten::silu", op::translate_1to1_match_1_inputs<opset8::Swish>},
diff --git a/tests/layer_tests/pytorch_tests/test_roll.py b/tests/layer_tests/pytorch_tests/test_roll.py
new file mode 100644
index 00000000000000..9b98211b46409b
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_roll.py
@@ -0,0 +1,42 @@
+# Copyright (C) 2018-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import numpy as np
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestRoll(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.random.uniform(0, 50, (2, 3, 4)).astype(np.float32),)
+
+    def create_model(self, shifts, dim):
+
+        import torch
+
+        class aten_roll(torch.nn.Module):
+            def __init__(self, shifts, dim=None):
+                super(aten_roll, self).__init__()
+                self.dim = dim
+                self.shifts = shifts
+
+            def forward(self, x):
+                if self.dim is not None:
+                    return torch.roll(x, self.shifts, self.dim)
+                return torch.roll(x, self.shifts)
+
+        ref_net = None
+
+        return aten_roll(shifts, dim), ref_net, "aten::roll"
+
+    @pytest.mark.parametrize(("shifts", "dim"), [
+        [(2, 1), (0, 1)],
+        [1, 0],
+        [-1, 0],
+        [1, None],
+    ])
+    @pytest.mark.nightly
+    def test_roll(self, shifts, dim, ie_device, precision, ir_version):
+        if ie_device == "CPU":
+            self._test(*self.create_model(shifts, dim), ie_device, precision, ir_version)
+

From 446b4184feb42838f17b2dab1039d9c4b39c347a Mon Sep 17 00:00:00 2001
From: bszmelcz
Date: Thu, 1 Dec 2022 14:33:36 +0100
Subject: [PATCH 2/7] add empty line

---
 src/frontends/pytorch/src/op/roll.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/frontends/pytorch/src/op/roll.cpp b/src/frontends/pytorch/src/op/roll.cpp
index 896dc2c328777a..52398c7dfa7ef0 100644
--- a/src/frontends/pytorch/src/op/roll.cpp
+++ b/src/frontends/pytorch/src/op/roll.cpp
@@ -32,4 +32,4 @@ OutputVector translate_roll(NodeContext& context) {
 } // namespace op
 } // namespace pytorch
 } // namespace frontend
-} // namespace ov
\ No newline at end of file
+} // namespace ov

From 9a7e172fab1f8d90b0bb53ffbbee870c247aa66b Mon Sep 17 00:00:00 2001
From: bszmelcz
Date: Mon, 5 Dec 2022 11:41:04 +0100
Subject: [PATCH 3/7] fix style

---
 .../transforms/prim_list_unpack_replacer.cpp | 20 +----------
 tests/layer_tests/pytorch_tests/test_roll.py |  3 +-
 .../layer_tests/pytorch_tests/test_unbind.py | 35 -------------------
 3 files changed, 2 insertions(+), 56 deletions(-)
 delete mode 100644 tests/layer_tests/pytorch_tests/test_unbind.py

diff --git a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
index 1135694bd85822..7f2637774c9dec 100644
--- a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
@@ -84,24 +84,6 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
             return true;
         }
 
-        if (auto unbind = cast_fw_node(input_node, "aten::unbind")) {
-            const auto input = unbind->get_input_source_output(0);
-            const auto axis = unbind->get_input_source_output(1);
-            const auto num_splits = list_unpack->get_output_size();
-            auto split = std::make_shared<opset8::Split>(input, axis, num_splits);
-            NodeVector to_copy_rt{split};
-            OutputVector outputs;
-            for (auto output: split->outputs()) {
-                const auto squeeze = std::make_shared<opset8::Squeeze>(output, axis);
-                outputs.push_back(squeeze);
-                to_copy_rt.push_back(squeeze);
-            }
-            copy_runtime_info({list_unpack, input_node}, to_copy_rt);
-            replace_node(list_unpack, outputs);
-
-            return true;
-        }
-
         if (auto shape_of = std::dynamic_pointer_cast<opset8::ShapeOf>(input_node)) {
             // case aten::size as input
             // Number of ListUnpack outputs should be equal to rank of input shape.
@@ -133,4 +115,4 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
 } // namespace pass
 } // namespace pytorch
 } // namespace frontend
-} // namespace ov
+} // namespace ov
\ No newline at end of file
diff --git a/tests/layer_tests/pytorch_tests/test_roll.py b/tests/layer_tests/pytorch_tests/test_roll.py
index 9b98211b46409b..0acc914bdd6618 100644
--- a/tests/layer_tests/pytorch_tests/test_roll.py
+++ b/tests/layer_tests/pytorch_tests/test_roll.py
@@ -37,6 +37,5 @@ def forward(self, x):
     ])
     @pytest.mark.nightly
     def test_roll(self, shifts, dim, ie_device, precision, ir_version):
-        if ie_device == "CPU":
-            self._test(*self.create_model(shifts, dim), ie_device, precision, ir_version)
+        self._test(*self.create_model(shifts, dim), ie_device, precision, ir_version)
 
diff --git a/tests/layer_tests/pytorch_tests/test_unbind.py b/tests/layer_tests/pytorch_tests/test_unbind.py
deleted file mode 100644
index 1af5c1eb7a8958..00000000000000
--- a/tests/layer_tests/pytorch_tests/test_unbind.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (C) 2018-2022 Intel Corporation
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-import numpy as np
-from pytorch_layer_test_class import PytorchLayerTest
-
-
-class TestUnbind(PytorchLayerTest):
-    def _prepare_input(self):
-        return (np.random.uniform(0, 50, (3, 3, 3, 3)).astype(np.float32),)
-
-    def create_model(self, shape):
-
-        import torch
-
-        class aten_unbind(torch.nn.Module):
-            def __init__(self, dim):
-                super(aten_unbind, self).__init__()
-                self.dim = dim
-
-
-            def forward(self, x):
-                # Create aten::unbind -> ListUnpack
-                a, b, c = torch.unbind(x, self.dim)
-                return b
-
-        ref_net = None
-
-        return aten_unbind(shape), ref_net, "aten::unbind"
-
-    @pytest.mark.parametrize(("dim"), [0, 1, 2, 3])
-    @pytest.mark.nightly
-    def test_unbind(self, dim, ie_device, precision, ir_version):
-        self._test(*self.create_model(dim), ie_device, precision, ir_version)

From 55586040553b328737917058a49a3bed119d7fba Mon Sep 17 00:00:00 2001
From: bszmelcz
Date: Mon, 5 Dec 2022 12:21:05 +0100
Subject: [PATCH 4/7] fix style v2

---
 src/frontends/pytorch/src/op/roll.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/frontends/pytorch/src/op/roll.cpp b/src/frontends/pytorch/src/op/roll.cpp
index 52398c7dfa7ef0..1b862092e04dce 100644
--- a/src/frontends/pytorch/src/op/roll.cpp
+++ b/src/frontends/pytorch/src/op/roll.cpp
@@ -15,7 +15,7 @@ OutputVector translate_roll(NodeContext& context) {
     const auto data = context.get_input(0);
     const auto shifts = context.get_input(1);
     const auto axes = context.get_input(2);
-    // if axes was not set 
+    // if axes was not set
     if (axes.get_shape() != shifts.get_shape()) {
         const auto const_minus_1 = opset8::Constant::create(element::i32, Shape{1}, {-1});
         const auto axis_0 = opset8::Constant::create(element::i32, Shape{1}, {0});

From 90e9dd9b470907b00edf3288a9ce1bc799591fe7 Mon Sep 17 00:00:00 2001
From: bszmelcz
Date: Tue, 6 Dec 2022 12:05:59 +0100
Subject: [PATCH 5/7] add check for dynamic shapes

---
 src/frontends/pytorch/src/op/roll.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/frontends/pytorch/src/op/roll.cpp b/src/frontends/pytorch/src/op/roll.cpp
index 1b862092e04dce..cca9aec1dda857 100644
--- a/src/frontends/pytorch/src/op/roll.cpp
+++ b/src/frontends/pytorch/src/op/roll.cpp
@@ -15,8 +15,11 @@ OutputVector translate_roll(NodeContext& context) {
     const auto data = context.get_input(0);
     const auto shifts = context.get_input(1);
     const auto axes = context.get_input(2);
-    // if axes was not set
-    if (axes.get_shape() != shifts.get_shape()) {
+    const auto shifts_pshape = shifts.get_partial_shape();
+    const auto axes_pshape = axes.get_partial_shape();
+    const auto match_dims = axes_pshape.compatible(shifts_pshape);
+    const auto both_dynamic = (shifts_pshape.rank().is_dynamic() && axes_pshape.rank().is_dynamic());
+    if (!(both_dynamic || match_dims)) {
         const auto const_minus_1 = opset8::Constant::create(element::i32, Shape{1}, {-1});
         const auto axis_0 = opset8::Constant::create(element::i32, Shape{1}, {0});
         const auto flat = std::make_shared<opset8::Reshape>(data, const_minus_1, false);

From fb0c33ea98d0d74b1c56dd7a3b728930c3a5c09d Mon Sep 17 00:00:00 2001
From: Bartek Szmelczynski
Date: Tue, 6 Dec 2022 12:15:32 +0100
Subject: [PATCH 6/7] remove random change

---
 .../pytorch/src/transforms/prim_list_unpack_replacer.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
index a8b7019a760c85..1135694bd85822 100644
--- a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
@@ -133,4 +133,4 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
 } // namespace pass
 } // namespace pytorch
 } // namespace frontend
-} // namespace ov
\ No newline at end of file
+} // namespace ov

From ac4d42f8e534ccd9f752270c303dc3d9e6e485dc Mon Sep 17 00:00:00 2001
From: bszmelcz
Date: Tue, 6 Dec 2022 13:03:46 +0100
Subject: [PATCH 7/7] merge statements

---
 src/frontends/pytorch/src/op/roll.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/frontends/pytorch/src/op/roll.cpp b/src/frontends/pytorch/src/op/roll.cpp
index cca9aec1dda857..39b9a37e9f217e 100644
--- a/src/frontends/pytorch/src/op/roll.cpp
+++ b/src/frontends/pytorch/src/op/roll.cpp
@@ -18,8 +18,7 @@ OutputVector translate_roll(NodeContext& context) {
     const auto shifts_pshape = shifts.get_partial_shape();
     const auto axes_pshape = axes.get_partial_shape();
     const auto match_dims = axes_pshape.compatible(shifts_pshape);
-    const auto both_dynamic = (shifts_pshape.rank().is_dynamic() && axes_pshape.rank().is_dynamic());
-    if (!(both_dynamic || match_dims)) {
+    if (!match_dims) {
         const auto const_minus_1 = opset8::Constant::create(element::i32, Shape{1}, {-1});
         const auto axis_0 = opset8::Constant::create(element::i32, Shape{1}, {0});
         const auto flat = std::make_shared<opset8::Reshape>(data, const_minus_1, false);
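
For reference, the fallback branch in translate_roll mirrors the documented behaviour of torch.roll when dims is omitted: the tensor is flattened, rolled along axis 0, and restored to its original shape. A minimal sketch of that equivalence, assuming only a stock PyTorch install (the example tensor and shift value are arbitrary and not taken from the patches above):

    import torch

    # Arbitrary example input; any contiguous tensor works.
    x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
    rolled = torch.roll(x, shifts=2)  # no dims given: PyTorch flattens internally
    # Manual flatten -> roll along axis 0 -> restore original shape.
    manual = torch.roll(x.reshape(-1), shifts=2, dims=0).reshape(x.shape)
    assert torch.equal(rolled, manual)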