Skip to content

Commit

Permalink
[PT FE] Support prim::TupleIndex operation (openvinotoolkit#19978)
Browse files Browse the repository at this point in the history
* [PT FE] Support prim::TupleIndex

* Update src/frontends/pytorch/src/op/tuple_index.cpp

* Update src/frontends/pytorch/src/op/tuple_index.cpp
  • Loading branch information
mvafin authored Sep 21, 2023
1 parent f790a3b commit 37d54bc
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 10 deletions.
39 changes: 39 additions & 0 deletions src/frontends/pytorch/src/op/tuple_index.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/gather.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_tuple_index(const NodeContext& context) {
// prim::TupleIndex(Any tup, int i) -> Any
num_inputs_check(context, 2, 2);
auto tuple = context.get_input(0).get_node_shared_ptr();
if (cast_fw_node(tuple, "prim::TupleConstruct")) {
// this case require index to be constant
auto index = context.const_input<int64_t>(1);
FRONT_END_OP_CONVERSION_CHECK(static_cast<size_t>(index) < tuple->get_input_size(),
"Index of TupleIndex operation is higher then number of tuple elements.");
return {tuple->get_input_source_output(index)};
} else {
// Assume this case is when tuple is represented as tensor
auto index = context.get_input(1);
auto zero = v0::Constant::create(element::i32, Shape{}, {0});
return {std::make_shared<v8::Gather>(context.get_input(0), index, zero)};
}
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ OP_CONVERTER(translate_topk);
OP_CONVERTER(translate_transpose);
OP_CONVERTER(translate_tril);
OP_CONVERTER(translate_triu);
OP_CONVERTER(translate_tuple_index);
OP_CONVERTER(translate_unflatten);
OP_CONVERTER(translate_unfold);
OP_CONVERTER(translate_upsample_bicubic2d);
Expand Down Expand Up @@ -479,6 +480,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"prim::requires_grad", op::return_false_scalar},
{"prim::PythonOp", op::translate_pythonop},
{"prim::type", op::skip_node}, // Used with prim::device, pass PtFrameworkNode.
{"prim::TupleIndex", op::translate_tuple_index},
{"quantized::add", op::translate_quantized_add},
{"quantized::add_relu", op::translate_quantized_add_relu},
{"quantized::cat", op::translate_quantized_cat},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ SoftmaxReshapeElimination::SoftmaxReshapeElimination() {

register_matcher(
std::make_shared<ov::pass::pattern::Matcher>(m_reshape1,
"ov::frontend::pytorch::pass::PrimTupleUnpackReplacer"),
"ov::frontend::pytorch::pass::SoftmaxReshapeElimination"),
[=](ov::pass::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto reshape0 = pattern_to_output[m_reshape0].get_node_shared_ptr();
Expand Down
38 changes: 29 additions & 9 deletions tests/layer_tests/pytorch_tests/test_tuple_construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,21 +79,21 @@ def forward(self, x):
def prepare_input(self, x):
return x, x + 2, None, x.reshape(-1), (x * 10).to(torch.int32)


ref_net = None

return prim_tuple_construct_tuple_unpack(), ref_net, ["prim::TupleConstruct", "prim::TupleUnpack"]

@pytest.mark.nightly
def test_tuple_construct_unpack(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
self._test(*self.create_model(), ie_device,
precision, ir_version, freeze_model=False)


class TestTupleUnpackParameterSingle(PytorchLayerTest):
def _prepare_input(self):
def tensor_gen():
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
return ( (tensor_gen(), tensor_gen()), )
return ((tensor_gen(), tensor_gen()), )

def create_model(self):
import torch
Expand All @@ -105,7 +105,6 @@ def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
x1, x2 = x
return x1, x2


return model(), None, ["prim::TupleUnpack"]

@pytest.mark.nightly
Expand All @@ -118,6 +117,7 @@ def _prepare_input(self):
def tensor_gen():
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
# generate tensor with a different shape for easier mismatch detection in case of mixed input order

def tensor_gen_2():
return np.random.uniform(0, 50, (2, 3)).astype(np.float32)
return (tensor_gen_2(), (tensor_gen(), tensor_gen()), tensor_gen_2())
Expand All @@ -132,7 +132,6 @@ def forward(self, y1, x: Tuple[torch.Tensor, torch.Tensor], y2):
x1, x2 = x
return x1, x2, y1, y2


return model(), None, ["prim::TupleUnpack"]

@pytest.mark.nightly
Expand All @@ -144,7 +143,7 @@ class TestTupleUnpackParameterNested(PytorchLayerTest):
def _prepare_input(self):
def tensor_gen():
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
return ( ((tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen())), )
return (((tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen())), )

def create_model(self):
import torch
Expand All @@ -158,7 +157,6 @@ def forward(self, x: Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor
y3, y4 = x2
return y1, y2, y3, y4


return model(), None, ["prim::TupleUnpack"]

@pytest.mark.nightly
Expand All @@ -170,7 +168,7 @@ class TestTupleUnpackParameterMultiple(PytorchLayerTest):
def _prepare_input(self):
def tensor_gen():
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
return ( (tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen()) )
return ((tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen()))

def create_model(self):
import torch
Expand All @@ -183,9 +181,31 @@ def forward(self, x: Tuple[torch.Tensor, torch.Tensor], y: Tuple[torch.Tensor, t
z3, z4 = y
return z1, z2, z3, z4


return model(), None, ["prim::TupleUnpack"]

@pytest.mark.nightly
def test(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version)


class TestTupleIndex(PytorchLayerTest):
def _prepare_input(self):
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)

def create_model(self):
import torch
from typing import Tuple

class model(torch.nn.Module):
def forward(self, x):
return self.some_func((x,x))

def some_func(self, x: Tuple[torch.Tensor, torch.Tensor]):
return x[1] * 2, x[0] * 3

return model(), None, "prim::TupleIndex"

@pytest.mark.nightly
def test(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision,
ir_version, trace_model=False, freeze_model=False)

0 comments on commit 37d54bc

Please sign in to comment.