Decompose/flatten tuple inputs #18092

Merged
Changes from 9 commits
29 commits:
96f12ca
prim::TupleUnpack and prim::ListUnpack removing transformation in PT …
slyalin Jun 16, 2023
be7de17
Merge remote-tracking branch 'origin/master' into decompose_tuple_inp…
slyalin Jun 16, 2023
38b2699
Enabled tuples and lists as items in example_inputs
slyalin Jun 16, 2023
1b75078
Merged from master
slyalin Jun 22, 2023
308260f
Merge remote-tracking branch 'origin/master' into decompose_tuple_inp…
slyalin Jun 22, 2023
b2ae3b3
Applied code style
slyalin Jun 23, 2023
f734188
Added tests for tuples as inputs and extended test infrastructure to …
slyalin Jun 23, 2023
72edbc3
Merge branch 'master' into decompose_tuple_inputs_pytorch
eaidova Jun 23, 2023
eeca8f0
Merge branch 'master' into decompose_tuple_inputs_pytorch
andrei-kochin Jun 23, 2023
382c75c
Negligible performance optimizations
slyalin Jun 23, 2023
f0ebcc6
Merge branch 'master' into decompose_tuple_inputs_pytorch
slyalin Jun 23, 2023
ed73810
Fixed duplicated names of test classes
slyalin Jun 24, 2023
10d5a6a
Merged from master
slyalin Jun 24, 2023
ec0a7c0
Merged from remote
slyalin Jun 24, 2023
b83293b
Added description for tuple flattening transformation
slyalin Jun 24, 2023
7e78d0f
Removed any support for list flattening on inputs; fixed layer tests
slyalin Jun 24, 2023
b145c6d
Fixed style
slyalin Jun 26, 2023
49edaaf
Merge remote-tracking branch 'origin/master' into decompose_tuple_inp…
slyalin Jun 26, 2023
6dca5a9
Merge branch 'decompose_tuple_inputs_pytorch' of https://github.com/s…
slyalin Jun 26, 2023
9d0fce1
Merge branch 'master' into decompose_tuple_inputs_pytorch
akladiev Jun 27, 2023
affbeba
Fixed order of new Parameters and Results while flattening tuples
slyalin Jun 28, 2023
08b44b7
Merge branch 'decompose_tuple_inputs_pytorch' of https://github.com/s…
slyalin Jun 28, 2023
b094fe3
Merge remote-tracking branch 'origin/master' into decompose_tuple_inp…
slyalin Jun 28, 2023
d077dd5
Fixed style
slyalin Jun 28, 2023
47e1616
Merge branch 'master' into decompose_tuple_inputs_pytorch
slyalin Jul 5, 2023
9656160
Merge remote-tracking branch 'origin/master' into decompose_tuple_inp…
slyalin Jul 5, 2023
a3ab899
Better diagnostics when not all prim::TupleUnpack ops after Parameter…
slyalin Jul 5, 2023
a3309d9
Small fix in diagnostics message
slyalin Jul 5, 2023
7a38d4d
Merge branch 'master' into decompose_tuple_inputs_pytorch
slyalin Jul 5, 2023
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/frontend.cpp
@@ -34,6 +34,7 @@
 #include "transforms/min_max_prim_list_construct_replacer.hpp"
 #include "transforms/prim_list_construct_pad.hpp"
 #include "transforms/prim_list_tuple_construct_replacer.hpp"
+#include "transforms/prim_list_tuple_unpack_replacer.hpp"
 #include "transforms/prim_list_unpack_replacer.hpp"
 #include "transforms/rfftn_complex_replacer.hpp"
 #include "transforms/string_equality_replacer.hpp"
@@ -175,6 +176,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
     manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
     manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();
     manager.register_pass<ov::frontend::pytorch::pass::StringEqualityReplacer>();
+    manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleParameters>();
     manager.register_pass<ov::frontend::pytorch::pass::RFFTNComplexReplacer>();
     manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
     manager.register_pass<ov::frontend::pytorch::pass::PrimTupleUnpackReplacer>();
98 changes: 98 additions & 0 deletions src/frontends/pytorch/src/transforms/prim_list_tuple_unpack_replacer.cpp
@@ -0,0 +1,98 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "prim_list_tuple_unpack_replacer.hpp"

#include <iostream>
#include <queue>
#include <set>

#include "openvino/frontend/pytorch/decoder.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {

bool DecomposeListTupleParameters::run_on_model(const std::shared_ptr<Model>& model) {
    bool at_least_one_decomposed = false;
    std::queue<std::shared_ptr<ov::op::v0::Parameter>> parameters;
    for (auto par : model->get_parameters()) {
        parameters.push(par);
    }
    while (!parameters.empty()) {
        auto parameter = parameters.front();
        parameters.pop();
        auto consumers = parameter->get_output_target_inputs(0);
        size_t num_outputs = 0;  // number of outputs in each unpack consumer should match
        bool all_unpacks = true;

        // collect all outputs per each consumer operation for this tuple/list Parameter
        std::vector<OutputVector> consumer_outputs;

        for (auto consumer : consumers) {
            auto node = consumer.get_node()->shared_from_this();
            auto tuple_unpack = cast_fw_node(node, "prim::TupleUnpack");
            auto list_unpack = cast_fw_node(node, "prim::ListUnpack");
            if (!tuple_unpack && !list_unpack) {
                all_unpacks = false;
                break;
            }
            if (num_outputs == 0) {
                num_outputs = node->get_output_size();
            } else if (num_outputs != node->get_output_size()) {
                std::cerr << "[ PT FE WARNING ] Unpack node " << node << ", one of the consumers of a tuple/list "
                          << "object, has " << node->get_output_size() << " outputs, which does not match the "
                          << num_outputs << " outputs of another consumer.\n";
                all_unpacks = false;
                break;
            }
            consumer_outputs.push_back(node->outputs());
        }

        if (!all_unpacks || consumer_outputs.empty()) {
            // If at least one consumer is not an unpack-like op, or the numbers of unpacked objects do not match,
            // we cannot replace any of the unpacks even if they exist, leaving unpack op(s) in the graph for this
            // Parameter.
            continue;
        }

        for (size_t i = 0; i < num_outputs; ++i) {
            // Merge partial shape and element type among all the consumers of the i-th result of the unpack ops
            PartialShape ps = PartialShape::dynamic();
            element::Type et = element::dynamic;
            std::set<Input<Node>> inputs;

            for (auto outputs : consumer_outputs) {
                auto output = outputs[i];
                OPENVINO_ASSERT(PartialShape::merge_into(ps, output.get_partial_shape()),
                                "Consumers for unpack op have incompatible shapes");
                OPENVINO_ASSERT(element::Type::merge(et, et, output.get_element_type()),
                                "Consumers for unpack op have incompatible types");
                auto target_inputs = output.get_target_inputs();
                inputs.insert(target_inputs.begin(), target_inputs.end());
            }

            auto new_parameter = std::make_shared<ov::op::v0::Parameter>(et, ps);

            for (auto input : inputs) {
                auto names = input.get_tensor().get_names();
                input.replace_source_output(new_parameter->output(0));
                new_parameter->output(0).add_names(names);
            }

            // TODO: Assign correct names
            model->add_parameters({new_parameter});
            parameters.push(new_parameter);
            model->remove_parameter(parameter);
            at_least_one_decomposed = true;
        }
    }

    return at_least_one_decomposed;
}

}  // namespace pass
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
24 changes: 24 additions & 0 deletions src/frontends/pytorch/src/transforms/prim_list_tuple_unpack_replacer.hpp
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {

class DecomposeListTupleParameters : public ov::pass::ModelPass {
public:
    OPENVINO_RTTI("ov::frontend::pytorch::pass::DecomposeListTupleParameters");
    bool run_on_model(const std::shared_ptr<Model>& model) override;
};

} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov
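For context, the sketch below illustrates what this transformation enables from the user's side. It is a minimal example under stated assumptions: the AddPair module is invented for illustration, and openvino.tools.mo.convert_model with example_input is assumed as the conversion entry point of this period; it is not code from this PR.

import torch
from typing import Tuple
from openvino.tools.mo import convert_model  # assumed conversion entry point

class AddPair(torch.nn.Module):
    # Unpacking a tuple argument is captured as prim::TupleUnpack by TorchScript
    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
        x1, x2 = x
        return x1 + x2

# A tuple item inside example_input, as enabled by this PR
ov_model = convert_model(AddPair(), example_input=((torch.rand(1, 3), torch.rand(1, 3)),))

# DecomposeListTupleParameters runs during FrontEnd::normalize(): the tuple
# Parameter feeding prim::TupleUnpack is replaced by one Parameter per element,
# so the converted model exposes two flat tensor inputs instead of one tuple.
assert len(ov_model.inputs) == 2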
29 changes: 19 additions & 10 deletions tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -50,8 +50,15 @@ def _test(self, model, ref_net, kind, ie_device, precision, ir_version, infer_ti
         else:
             inputs = self._prepare_input()

-        torch_inputs = [torch.from_numpy(inp) if isinstance(
-            inp, np.ndarray) else inp for inp in inputs]
+        def numpy_to_torch_recursively(x):
+            if isinstance(x, tuple):
+                return tuple(numpy_to_torch_recursively(y) for y in x)
+            elif isinstance(x, np.ndarray):
+                return torch.from_numpy(x)
+            else:
+                return x
+
+        torch_inputs = [numpy_to_torch_recursively(inp) for inp in inputs]

         if 'custom_eps' in kwargs and kwargs['custom_eps'] is not None:
             custom_eps = kwargs['custom_eps']
@@ -61,14 +68,16 @@ def use_ts_backend():
         def use_ts_backend():
             return(os.environ.get('USE_TS_BACKEND', False))

+        ov_inputs = flattenize(inputs)
+
         if use_ts_backend():
             self.ts_backend_test(model, torch_inputs, custom_eps)
         else:
             with torch.no_grad():
                 model.eval()
                 trace_model = kwargs.get('trace_model', False)
                 freeze_model = kwargs.get('freeze_model', True)
-                model, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, inputs, freeze_model)
+                model, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
                 graph = model.inlined_graph

             if kind is not None and not isinstance(kind, (tuple, list)):
@@ -80,7 +89,7 @@ def use_ts_backend():
             # OV infer:
             core = Core()
             compiled = core.compile_model(converted_model, ie_device)
-            infer_res = compiled(deepcopy(inputs))
+            infer_res = compiled(deepcopy(ov_inputs))

             if hasattr(self, 'skip_framework') and self.skip_framework:
                 warnings.warn('Framework is skipped')
@@ -96,7 +105,7 @@ def use_ts_backend():

             flatten_fw_res = []

-            flatten_fw_res = flattenize_outputs(fw_res)
+            flatten_fw_res = flattenize(fw_res)

             assert len(flatten_fw_res) == len(
                 output_list), f'number of outputs not equal, {len(flatten_fw_res)} != {len(output_list)}'
@@ -212,8 +221,8 @@ def ts_backend_test(self, model, inputs, custom_eps):
             ov_res = (ov_res,)

         flatten_fw_res, flatten_ov_res = [], []
-        flatten_fw_res = flattenize_outputs(fw_res)
-        flatten_ov_res = flattenize_outputs(ov_res)
+        flatten_fw_res = flattenize(fw_res)
+        flatten_ov_res = flattenize(ov_res)

         assert len(flatten_fw_res) == len(
             flatten_ov_res
@@ -268,18 +277,18 @@ def get_params(ie_device=None, precision=None):

 def flattenize_dict_outputs(res):
     if isinstance(res, dict):
-        return flattenize_outputs(res.values())
+        return flattenize(res.values())


-def flattenize_outputs(res):
+def flattenize(res):
     results = []
     for res_item in res:
         # if None is at output we skip it
         if res_item is None:
             continue
         # If input is list or tuple flattenize it
         if isinstance(res_item, (list, tuple)):
-            decomposed_res = flattenize_outputs(res_item)
+            decomposed_res = flattenize(res_item)
             results.extend(decomposed_res)
             continue
         if isinstance(res_item, dict):
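To make the helper semantics above concrete, here is a small standalone sketch of the two recursive helpers introduced in this file (simplified copies for illustration; the real flattenize also dispatches dict items through flattenize_dict_outputs):

import numpy as np
import torch

def numpy_to_torch_recursively(x):
    # Convert np.ndarray leaves to torch tensors while preserving tuple nesting,
    # so tuple-typed example inputs survive the conversion to torch inputs
    if isinstance(x, tuple):
        return tuple(numpy_to_torch_recursively(y) for y in x)
    elif isinstance(x, np.ndarray):
        return torch.from_numpy(x)
    else:
        return x

def flattenize(res):
    # Flatten arbitrarily nested lists/tuples into a flat list, skipping None items
    results = []
    for res_item in res:
        if res_item is None:
            continue
        if isinstance(res_item, (list, tuple)):
            results.extend(flattenize(res_item))
            continue
        results.append(res_item)
    return results

nested = ((np.zeros((1, 2, 10)), np.ones((1, 2, 10))),)  # one tuple-typed input
torch_inputs = [numpy_to_torch_recursively(inp) for inp in nested]  # nesting kept
ov_inputs = flattenize(nested)  # one flat entry per decomposed Parameter
assert len(ov_inputs) == 2 and isinstance(torch_inputs[0], tuple)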
77 changes: 76 additions & 1 deletion tests/layer_tests/pytorch_tests/test_tuple_construct.py
@@ -55,6 +55,81 @@ def forward(self, x):
     @pytest.mark.nightly
     def test_tuple_construct(self, case, ie_device, precision, ir_version):
         self._test(*self.create_model(case), ie_device, precision, ir_version)
+
+
+class TestTupleUnpackParameterSingle(PytorchLayerTest):
+    def _prepare_input(self):
+        def tensor_gen():
+            return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
+        # a single input that is itself a tuple of two tensors
+        return ((tensor_gen(), tensor_gen()),)
+
+    def create_model(self):
+        import torch
+        from typing import Tuple
+
+        class model(torch.nn.Module):
+            def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
+                x1, x2 = x
+                return x1, x2
+
+        return model(), None, ["prim::TupleUnpack"]
+
+    @pytest.mark.nightly
+    def test(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version)
+
+
+class TestTupleUnpackParameterNested(PytorchLayerTest):
+    def _prepare_input(self):
+        def tensor_gen():
+            return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
+        # a single input that is a tuple of two tuples of tensors
+        return (((tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen())),)
+
+    def create_model(self):
+        import torch
+        from typing import Tuple
+
+        class model(torch.nn.Module):
+            def forward(self, x: Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]):
+                x1, x2 = x
+                y1, y2 = x1
+                y3, y4 = x2
+                return y1, y2, y3, y4
+
+        return model(), None, ["prim::TupleUnpack"]
+
+    @pytest.mark.nightly
+    def test(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version)
+
+
+class TestTupleUnpackParameterMultiple(PytorchLayerTest):
+    def _prepare_input(self):
+        def tensor_gen():
+            return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
+        # two inputs, each a tuple of two tensors
+        return ((tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen()))
+
+    def create_model(self):
+        import torch
+        from typing import Tuple
+
+        class model(torch.nn.Module):
+            def forward(self, x: Tuple[torch.Tensor, torch.Tensor], y: Tuple[torch.Tensor, torch.Tensor]):
+                z1, z2 = x
+                z3, z4 = y
+                return z1, z2, z3, z4
+
+        return model(), None, ["prim::TupleUnpack"]
+
+    @pytest.mark.nightly
+    def test(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version)


 class TestTupleConstructTupleUnpack(PytorchLayerTest):
@@ -80,4 +155,4 @@ def prepare_input(self, x):

     @pytest.mark.nightly
     def test_tuple_construct_unpack(self, ie_device, precision, ir_version):
-        self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
\ No newline at end of file
+        self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
@@ -122,6 +122,11 @@ def to_torch_tensor(tensor):
         return torch.tensor(tensor.data)
     if isinstance(tensor, (float, int, bool)):
         return tensor
+    if isinstance(tensor, tuple):
+        # TODO: Function to_torch_tensor should be renamed as it handles not only a tensor
+        return tuple(to_torch_tensor(x) for x in tensor)
+    if isinstance(tensor, list):
+        return [to_torch_tensor(x) for x in tensor]
     else:
         raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
                     "Got {}".format(type(tensor)))
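Finally, a brief illustration of the extended to_torch_tensor behavior (a simplified standalone mirror; the real function also accepts ov.Tensor and raises the frontend's Error type):

import numpy as np
import torch

def to_torch_tensor_sketch(value):
    # Simplified mirror of the updated to_torch_tensor: tuples and lists in
    # example_input are now converted element-wise, preserving their structure
    if isinstance(value, torch.Tensor):
        return value
    if isinstance(value, np.ndarray):
        return torch.tensor(value)
    if isinstance(value, (float, int, bool)):
        return value
    if isinstance(value, tuple):
        return tuple(to_torch_tensor_sketch(x) for x in value)
    if isinstance(value, list):
        return [to_torch_tensor_sketch(x) for x in value]
    raise TypeError("Unexpected type of example_input: {}".format(type(value)))

example = (np.zeros((2, 2)), [np.ones((3,)), 1.5])
converted = to_torch_tensor_sketch(example)
assert isinstance(converted, tuple) and isinstance(converted[1], list)
assert isinstance(converted[0], torch.Tensor)  # np.ndarray became a tensor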