diff --git a/src/frontends/pytorch/src/op/roll.cpp b/src/frontends/pytorch/src/op/roll.cpp
new file mode 100644
--- /dev/null
+++ b/src/frontends/pytorch/src/op/roll.cpp
@@ -0,0 +1,40 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/opsets/opset8.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+// Converts aten::roll(Tensor self, int[] shifts, int[] dims) to opset8::Roll.
+// When the shapes of `shifts` and `dims` are not compatible (the torch.roll
+// overload called without dims), PyTorch rolls the flattened tensor, so we
+// flatten the input, roll along axis 0, then restore the original shape.
+OutputVector translate_roll(NodeContext& context) {
+    const auto data = context.get_input(0);
+    const auto shifts = context.get_input(1);
+    const auto axes = context.get_input(2);
+    const auto shifts_pshape = shifts.get_partial_shape();
+    const auto axes_pshape = axes.get_partial_shape();
+    const auto match_dims = axes_pshape.compatible(shifts_pshape);
+    if (!match_dims) {
+        const auto const_minus_1 = opset8::Constant::create(element::i32, Shape{1}, {-1});
+        const auto axis_0 = opset8::Constant::create(element::i32, Shape{1}, {0});
+        const auto flat = std::make_shared<opset8::Reshape>(data, const_minus_1, false);
+        const auto roll = std::make_shared<opset8::Roll>(flat, shifts, axis_0);
+        const auto shape_of_data = std::make_shared<opset8::ShapeOf>(data);
+        const auto reshape = std::make_shared<opset8::Reshape>(roll, shape_of_data, false);
+        // NOTE(review): axis_0 was missing from mark_nodes in the original patch;
+        // every node created here must be marked for the frontend bookkeeping.
+        context.mark_nodes({const_minus_1, axis_0, flat, roll, shape_of_data, reshape});
+        return {reshape};
+    }
+    return {context.mark_node(std::make_shared<opset8::Roll>(data, shifts, axes))};
+}
+
+} // namespace op
+} // namespace pytorch
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 4d16b18ee8744a..2ceeb928cd56e1 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -59,6 +59,7 @@ OP_CONVERTER(translate_relu6);
 OP_CONVERTER(translate_reshape);
 OP_CONVERTER(translate_reshape_as);
 OP_CONVERTER(translate_rsub);
+OP_CONVERTER(translate_roll);
 OP_CONVERTER(translate_rsqrt);
 OP_CONVERTER(translate_select);
 OP_CONVERTER(translate_size);
@@ -155,6 +156,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::reshape", op::translate_reshape},
         {"aten::reshape_as", op::translate_reshape_as},
         {"aten::rsub", op::translate_rsub},
+        {"aten::roll", op::translate_roll},
         {"aten::rsqrt", op::translate_rsqrt},
         {"aten::select", op::translate_select},
         {"aten::sigmoid", op::translate_1to1_match_1_inputs<opset8::Sigmoid>},
diff --git a/tests/layer_tests/pytorch_tests/test_roll.py b/tests/layer_tests/pytorch_tests/test_roll.py
new file mode 100644
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_roll.py
@@ -0,0 +1,41 @@
+# Copyright (C) 2018-2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import numpy as np
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestRoll(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.random.uniform(0, 50, (2, 3, 4)).astype(np.float32),)
+
+    def create_model(self, shifts, dim):
+
+        import torch
+
+        class aten_roll(torch.nn.Module):
+            def __init__(self, shifts, dim=None):
+                super(aten_roll, self).__init__()
+                self.dim = dim
+                self.shifts = shifts
+
+            def forward(self, x):
+                if self.dim is not None:
+                    return torch.roll(x, self.shifts, self.dim)
+                return torch.roll(x, self.shifts)
+
+        ref_net = None
+
+        return aten_roll(shifts, dim), ref_net, "aten::roll"
+
+    @pytest.mark.parametrize(("shifts", "dim"), [
+        [(2, 1), (0, 1)],
+        [1, 0],
+        [-1, 0],
+        [1, None],
+    ])
+    @pytest.mark.nightly
+    def test_roll(self, shifts, dim, ie_device, precision, ir_version):
+        self._test(*self.create_model(shifts, dim), ie_device, precision, ir_version)
+