Skip to content

Commit

Permalink
Merge pull request openvinotoolkit#56 from bszmelcz/add_aten_roll
Browse files Browse the repository at this point in the history
Add support for aten::roll
  • Loading branch information
slyalin authored Dec 6, 2022
2 parents 7056c28 + 6e409e0 commit 88b12d9
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 0 deletions.
37 changes: 37 additions & 0 deletions src/frontends/pytorch/src/op/roll.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset8.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

// Converts aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor
// into OpenVINO opset8::Roll.
//
// PyTorch allows the `dims` argument to be omitted (empty); in that case the
// tensor is flattened, rolled along axis 0, and restored to its original
// shape — which is exactly what the fallback branch below implements.
OutputVector translate_roll(NodeContext& context) {
    const auto data = context.get_input(0);
    const auto shifts = context.get_input(1);
    const auto axes = context.get_input(2);
    const auto shifts_pshape = shifts.get_partial_shape();
    const auto axes_pshape = axes.get_partial_shape();
    // If shifts and dims shapes cannot match, assume the "no dims" form:
    // flatten -> roll along axis 0 -> reshape back to the original shape.
    const auto match_dims = axes_pshape.compatible(shifts_pshape);
    if (!match_dims) {
        const auto const_minus_1 = opset8::Constant::create(element::i32, Shape{1}, {-1});
        const auto axis_0 = opset8::Constant::create(element::i32, Shape{1}, {0});
        // Reshape to 1-D; special_zero=false so -1 means "infer full size".
        const auto flat = std::make_shared<opset8::Reshape>(data, const_minus_1, false);
        const auto roll = std::make_shared<opset8::Roll>(flat, shifts, axis_0);
        const auto shape_of_data = std::make_shared<opset8::ShapeOf>(data);
        const auto reshape = std::make_shared<opset8::Reshape>(roll, shape_of_data, false);
        // Mark every node created here (axis_0 was previously missing).
        context.mark_nodes({const_minus_1, axis_0, flat, roll, shape_of_data, reshape});
        return {reshape};
    }
    return {context.mark_node(std::make_shared<opset8::Roll>(data, shifts, axes))};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ OP_CONVERTER(translate_relu6);
OP_CONVERTER(translate_reshape);
OP_CONVERTER(translate_reshape_as);
OP_CONVERTER(translate_rsub);
OP_CONVERTER(translate_roll);
OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_select);
OP_CONVERTER(translate_size);
Expand Down Expand Up @@ -155,6 +156,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::reshape", op::translate_reshape},
{"aten::reshape_as", op::translate_reshape_as},
{"aten::rsub", op::translate_rsub},
{"aten::roll", op::translate_roll},
{"aten::rsqrt", op::translate_rsqrt},
{"aten::select", op::translate_select},
{"aten::sigmoid", op::translate_1to1_match_1_inputs<opset8::Sigmoid>},
Expand Down
41 changes: 41 additions & 0 deletions tests/layer_tests/pytorch_tests/test_roll.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import numpy as np
from pytorch_layer_test_class import PytorchLayerTest


class TestRoll(PytorchLayerTest):
    """Layer test for the aten::roll conversion."""

    def _prepare_input(self):
        # Random float32 tensor of shape (2, 3, 4) in [0, 50).
        return (np.random.uniform(0, 50, (2, 3, 4)).astype(np.float32),)

    def create_model(self, shifts, dim):
        """Build a module calling torch.roll with the given shifts/dim.

        :param shifts: int or tuple of ints — number of places to roll.
        :param dim: int, tuple of ints, or None — axes to roll along;
                    None exercises the flatten-roll-reshape fallback.
        :return: (model, reference net, traced op name) triple.
        """
        import torch

        class aten_roll(torch.nn.Module):
            def __init__(self, shifts, dim=None):
                super(aten_roll, self).__init__()
                self.dim = dim
                # Fixed typo: was `self.shits`.
                self.shifts = shifts

            def forward(self, x):
                if self.dim is not None:
                    return torch.roll(x, self.shifts, self.dim)
                return torch.roll(x, self.shifts)

        ref_net = None

        return aten_roll(shifts, dim), ref_net, "aten::roll"

    @pytest.mark.parametrize(("shifts", "dim"), [
        [(2, 1), (0, 1)],
        [1, 0],
        [-1, 0],
        [1, None],
    ])
    @pytest.mark.nightly
    def test_roll(self, shifts, dim, ie_device, precision, ir_version):
        self._test(*self.create_model(shifts, dim), ie_device, precision, ir_version)

0 comments on commit 88b12d9

Please sign in to comment.