Add support for concatenation in Loop #15899

Merged (7 commits) on Feb 28, 2023
2 changes: 1 addition & 1 deletion src/core/src/op/einsum.cpp
@@ -186,7 +186,7 @@ void op::v7::Einsum::validate_and_infer_types() {
for (size_t input_idx = 1; input_idx < num_inputs; ++input_idx) {
const auto& input_type_i = get_input_element_type(input_idx);
NODE_VALIDATION_CHECK(this,
input_type_0 == input_type_i,
input_type_0.compatible(input_type_i),
"Inputs to Einsum operation must have the same type.");
}
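The relaxation from strict equality to compatible() is presumably needed because, while a Loop body is still being type-propagated, an Einsum operand coming out of the body can temporarily have a dynamic element type; a strict == would then reject a graph that is in fact well-typed. A hypothetical PyTorch-level sketch of such a graph (illustration only, not taken from this PR):

import torch

class LoopCatEinsum(torch.nn.Module):
    def forward(self, x):
        parts = []
        for _ in range(3):
            parts.append(x)                 # list filled inside a loop
        y = torch.cat(parts, 1)             # concatenated loop-carried slices
        # the operand types here are only fully known once the Loop above is typed
        return torch.einsum("bij,bjk->bik", y, y.transpose(1, 2))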

8 changes: 6 additions & 2 deletions src/core/src/op/loop.cpp
@@ -162,6 +162,8 @@ void op::v5::Loop::validate_and_infer_types() {
if (auto slice_input_description = ov::as_type_ptr<SliceInputDescription>(input_description)) {
auto body_parameter = m_bodies[0]->get_parameters().at(slice_input_description->m_body_parameter_index);
const auto& input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
const auto& input_type = inputs().at(index).get_source_output().get_element_type();
body_parameter->set_element_type(input_type);
if (input_partial_shape.rank().is_dynamic()) {
body_parameter->set_partial_shape(ov::PartialShape::dynamic());
} else {
@@ -176,19 +178,21 @@

auto body_parameter = m_bodies[0]->get_parameters().at(merged_input_description->m_body_parameter_index);

auto body_param_partial_shape = body_parameter->get_partial_shape();
auto input_partial_shape = input(index).get_partial_shape();
auto input_type = input(index).get_element_type();

body_parameter->set_partial_shape(input_partial_shape);
body_parameter->set_element_type(input_type);
back_edges[merged_input_description->m_body_value_index] = merged_input_description->m_body_parameter_index;
} else if (auto invariant_input_description =
ov::as_type_ptr<v0::TensorIterator::InvariantInputDescription>(input_description)) {
auto body_parameter = m_bodies[0]->get_parameters().at(invariant_input_description->m_body_parameter_index);

auto body_param_partial_shape = body_parameter->get_partial_shape();
auto input_partial_shape = input(index).get_partial_shape();
auto input_type = input(index).get_element_type();

body_parameter->set_partial_shape(input_partial_shape);
body_parameter->set_element_type(input_type);
}
}

28 changes: 28 additions & 0 deletions src/frontends/pytorch/src/op/cat.cpp
@@ -0,0 +1,28 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_cat(NodeContext& context) {
// This translator is only needed to get axis as constant from external scope
num_inputs_check(context, 2, 2);
auto fw_node = std::make_shared<PtFrameworkNode>(context.get_decoder(), OutputVector{context.get_input(0)}, 1);
auto attrs = fw_node->get_attrs();
// If this fails it means axis is dynamic and aten::cat will be converted to fw node in regular pipeline
attrs["axis"] = std::to_string(context.const_input<int64_t>(1));
fw_node->set_attrs(attrs);
return {context.mark_node(std::dynamic_pointer_cast<Node>(fw_node))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
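For context, this translator only records the concatenation axis when it is available as a constant; a minimal sketch of the two cases, assuming the usual TorchScript conversion flow:

import torch

class StaticAxisCat(torch.nn.Module):
    def forward(self, x):
        # axis is a literal, so context.const_input<int64_t>(1) succeeds and the
        # value is stored in the framework node's "axis" attribute
        return torch.cat([x, x], 1)

class DynamicAxisCat(torch.nn.Module):
    def forward(self, x, axis: int):
        # axis is only known at run time; const_input fails and aten::cat is
        # left to be converted as a framework node in the regular pipeline
        return torch.cat([x, x], axis)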
4 changes: 2 additions & 2 deletions src/frontends/pytorch/src/op/list_construct.cpp
@@ -31,15 +31,15 @@ OutputVector translate_list_construct(NodeContext& context) {
consts.push_back(unsqueezed_c_node);
}
}
auto list_construct = std::make_shared<v0::Concat>(consts, 0);
auto list_construct = context.mark_node(std::make_shared<v0::Concat>(consts, 0));
if (list_construct->has_evaluate()) {
OutputVector replacements(list_construct->get_output_size());

if (list_construct->constant_fold(replacements, list_construct->input_values())) {
return replacements;
}
}
return {context.mark_output(list_construct)};
return {list_construct};
};

} // namespace op
8 changes: 6 additions & 2 deletions src/frontends/pytorch/src/op/loop.cpp
@@ -26,7 +26,12 @@ OutputVector translate_loop(NodeContext& context) {
loop->set_special_body_ports(spec_ports);

auto body_parameters = body->get_parameters();
// #0 body parameter is counter; #0 loop input is counter, #1 loop input is condition
// #0 body parameter is counter;
FRONT_END_OP_CONVERSION_CHECK(body_parameters.size() > 0, "At least one input to Loop body is required");
// Set counter type and shape
body_parameters[0]->set_element_type(element::i32);
body_parameters[0]->set_partial_shape(PartialShape{});
// #0 loop input is trip_count, #1 loop input is condition
// Connect other inputs
for (size_t i = 2; i < inputs.size(); i++) {
loop->set_invariant_inputs(inputs[i], {body_parameters[i - 1]});
@@ -39,7 +44,6 @@ OutputVector translate_loop(NodeContext& context) {
auto external_output = context.get_tensor_from_model_or_create_input(input_idx);
loop->set_invariant_inputs(external_output, {param});
}
// TODO: Connect back edges (merged inputs)
auto body_results = body->get_results();
FRONT_END_OP_CONVERSION_CHECK(body_results.size() > 0, "At least one output from loop is required - condition.");
std::set<size_t> output_idxs;
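As a reference for the numbering used in the comments above, a small scripted loop of the kind this translator handles (a sketch; assumes torch.jit.script so the loop is kept as prim::Loop rather than unrolled by tracing):

import torch

class CountedLoop(torch.nn.Module):
    def forward(self, x):
        y = x
        for i in range(3):   # prim::Loop: input 0 is the trip count, input 1 the condition
            y = y + i        # i corresponds to body parameter 0 (the counter)
        return y

scripted = torch.jit.script(CountedLoop())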
3 changes: 2 additions & 1 deletion src/frontends/pytorch/src/op_table.cpp
@@ -24,6 +24,7 @@ OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_avg_poolnd);
OP_CONVERTER(translate_bool);
OP_CONVERTER(translate_batch_norm);
OP_CONVERTER(translate_cat);
OP_CONVERTER(translate_clamp);
OP_CONVERTER(translate_constant);
OP_CONVERTER(translate_conv_transposend);
@@ -160,7 +161,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
{"aten::batch_norm", op::translate_batch_norm},
{"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::Bool", op::translate_bool},
// {"aten::cat", done as transformation},
{"aten::cat", op::translate_cat},
{"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
{"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
{"aten::clamp", op::translate_clamp},
69 changes: 60 additions & 9 deletions src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
@@ -10,6 +10,7 @@
#include "openvino/core/rt_info.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/loop.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
@@ -37,17 +38,67 @@ AtenCatToConcat::AtenCatToConcat() {
if (!cat)
return false;

auto axis_node = cat->input(1).get_source_output().get_node_shared_ptr();
auto axis_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(axis_node);
if (!axis_const)
return false;
auto axis = axis_const->cast_vector<int64_t>();
if (axis.size() != 1)
return false;
int64_t axis;
if (cat->get_input_size() > 1) {
auto axis_node = cat->get_input_node_shared_ptr(1);
auto axis_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(axis_node);
if (!axis_const)
return false;
auto _axis = axis_const->cast_vector<int64_t>();
if (_axis.size() != 1)
return false;
axis = _axis[0];
} else {
const auto& attrs = cat->get_attrs();
if (attrs.find("axis") == attrs.end())
return false;
axis = std::stoll(attrs.at("axis"));
}

std::shared_ptr<Node> input_node = cat->get_input_node_shared_ptr(0);
if (auto loop = std::dynamic_pointer_cast<ov::op::v5::Loop>(input_node)) {
// case when concatenation is done inside the Loop
auto body = loop->get_function();
auto output_index = cat->input(0).get_source_output().get_index();
int64_t body_result_index = -1;
for (auto out_desc : loop->get_output_descriptions()) {
if (out_desc->m_output_index == output_index) {
body_result_index = static_cast<int64_t>(out_desc->m_body_value_index);
break;
}
}
FRONT_END_GENERAL_CHECK(body_result_index >= 0, "Couldn't find descriptor for output.");
auto body_result = body->get_results()[body_result_index];
auto append = cast_fw_node(body_result->get_input_node_shared_ptr(0), "aten::append");
if (!append)
return false;
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(append->get_input_node_shared_ptr(0));
if (!param)
return false;
auto body_param_index = body->get_parameter_index(param);
FRONT_END_GENERAL_CHECK(body_param_index >= 0, "Couldn't find parameter in body parameters.");
int64_t input_index = -1;
for (auto in_desc : loop->get_input_descriptions()) {
if (in_desc->m_body_parameter_index == static_cast<size_t>(body_param_index)) {
input_index = static_cast<int64_t>(in_desc->m_input_index);
break;
}
}
FRONT_END_GENERAL_CHECK(input_index >= 0, "Couldn't find descriptor for input.");
auto list_construct = cast_fw_node(loop->get_input_node_shared_ptr(input_index), "prim::ListConstruct");
if (!list_construct || list_construct->get_input_size() > 0)
return false;
// TODO: Is unsqueeze needed?
auto new_result = std::make_shared<ov::op::v0::Result>(append->input_value(1));
body->add_results({new_result});
auto new_output = loop->get_concatenated_slices(new_result, 0, 1, 1, -1, axis);
copy_runtime_info(cat, loop);
cat->output(0).replace(new_output);
return true;
}

OutputVector tmp_inputs;
NodeVector rt_copy_from{cat};
std::shared_ptr<Node> input_node = cat->input(0).get_source_output().get_node_shared_ptr();
while (const auto& input_fw_node = cast_fw_node(input_node, "aten::append")) {
rt_copy_from.push_back(input_fw_node);
tmp_inputs.push_back(input_fw_node->input(1).get_source_output());
@@ -62,7 +113,7 @@ AtenCatToConcat::AtenCatToConcat() {
inputs.push_back(input.get_source_output());
}
inputs.insert(inputs.end(), tmp_inputs.rbegin(), tmp_inputs.rend());
auto result = std::make_shared<ov::op::v0::Concat>(inputs, axis[0]);
auto result = std::make_shared<ov::op::v0::Concat>(inputs, axis);
copy_runtime_info(rt_copy_from, result);
replace_node(cat, result);
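In PyTorch terms, the Loop branch added above targets graphs where an empty list is filled by aten::append inside a prim::Loop and concatenated afterwards; a minimal sketch (the same shape of graph as the aten_loop_append_cat layer test added below):

import torch

class LoopAppendCat(torch.nn.Module):
    def forward(self, x):
        parts = []
        for i in range(3):            # becomes prim::Loop after scripting
            parts.append(x)           # aten::append inside the loop body
        return torch.cat(parts, 1)    # rewired to the Loop's concatenated-slices output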

@@ -52,7 +52,8 @@ def _test(self, model, ref_net, kind, ie_device, precision, ir_version, infer_ti
else:
torch_inputs = [torch.from_numpy(inp) for inp in inputs]
model = torch.jit.trace(model, torch_inputs)
model = torch.jit.freeze(model)
if kwargs.get('freeze_model', True):
model = torch.jit.freeze(model)
graph = model.inlined_graph
print(graph)

2 changes: 1 addition & 1 deletion tests/layer_tests/pytorch_tests/test_bool.py
@@ -32,5 +32,5 @@ def forward_scalar(self, x:int):
@pytest.mark.parametrize("input_type", ["tensor", "scalar"])
@pytest.mark.nightly
@pytest.mark.precommit
def test_ceil(self, ie_device, precision, ir_version, input_type):
def test_bool(self, ie_device, precision, ir_version, input_type):
self._test(*self.create_model(input_type), ie_device, precision, ir_version)
51 changes: 51 additions & 0 deletions tests/layer_tests/pytorch_tests/test_cat.py
@@ -0,0 +1,51 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class aten_cat(torch.nn.Module):
def forward(self, x):
return torch.cat([x, x], 1)


class aten_append_cat(torch.nn.Module):
def forward(self, x):
list = []
list.append(x)
list.append(x)
return torch.cat(list, 1)

class aten_loop_append_cat(torch.nn.Module):
def forward(self, x):
list = []
for i in range(3):
list.append(x)
return torch.cat(list, 1)


class TestCat(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(2, 1, 3),)

@pytest.mark.nightly
@pytest.mark.precommit
def test_cat(self, ie_device, precision, ir_version):
self._test(aten_cat(), None, ["aten::cat", "prim::ListConstruct"],
ie_device, precision, ir_version)

@pytest.mark.nightly
@pytest.mark.precommit
def test_append_cat(self, ie_device, precision, ir_version):
self._test(aten_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct"],
ie_device, precision, ir_version)

@pytest.mark.nightly
@pytest.mark.precommit
def test_loop_append_cat(self, ie_device, precision, ir_version):
self._test(aten_loop_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
ie_device, precision, ir_version, freeze_model=False)