Skip to content

Commit

Permalink
Add support for concatenation in Loop (openvinotoolkit#15899)
Browse files Browse the repository at this point in the history
* Add support for concatenation in Loop

* Apply suggestions from code review

* Fix win build

* Fix issues with propagation shapes and types in Loop

* Fix einsum

* Set type and shape of count in frontend
  • Loading branch information
mvafin authored and andrei-cv committed Mar 21, 2023
1 parent c9f1d2d commit 586d309
Show file tree
Hide file tree
Showing 10 changed files with 159 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/core/src/op/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ void op::v7::Einsum::validate_and_infer_types() {
for (size_t input_idx = 1; input_idx < num_inputs; ++input_idx) {
const auto& input_type_i = get_input_element_type(input_idx);
NODE_VALIDATION_CHECK(this,
input_type_0 == input_type_i,
input_type_0.compatible(input_type_i),
"Inputs to Einsum operation must have the same type.");
}

Expand Down
8 changes: 6 additions & 2 deletions src/core/src/op/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ void op::v5::Loop::validate_and_infer_types() {
if (auto slice_input_description = ov::as_type_ptr<SliceInputDescription>(input_description)) {
auto body_parameter = m_bodies[0]->get_parameters().at(slice_input_description->m_body_parameter_index);
const auto& input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
const auto& input_type = inputs().at(index).get_source_output().get_element_type();
body_parameter->set_element_type(input_type);
if (input_partial_shape.rank().is_dynamic()) {
body_parameter->set_partial_shape(ov::PartialShape::dynamic());
} else {
Expand All @@ -176,19 +178,21 @@ void op::v5::Loop::validate_and_infer_types() {

auto body_parameter = m_bodies[0]->get_parameters().at(merged_input_description->m_body_parameter_index);

auto body_param_partial_shape = body_parameter->get_partial_shape();
auto input_partial_shape = input(index).get_partial_shape();
auto input_type = input(index).get_element_type();

body_parameter->set_partial_shape(input_partial_shape);
body_parameter->set_element_type(input_type);
back_edges[merged_input_description->m_body_value_index] = merged_input_description->m_body_parameter_index;
} else if (auto invariant_input_description =
ov::as_type_ptr<v0::TensorIterator::InvariantInputDescription>(input_description)) {
auto body_parameter = m_bodies[0]->get_parameters().at(invariant_input_description->m_body_parameter_index);

auto body_param_partial_shape = body_parameter->get_partial_shape();
auto input_partial_shape = input(index).get_partial_shape();
auto input_type = input(index).get_element_type();

body_parameter->set_partial_shape(input_partial_shape);
body_parameter->set_element_type(input_type);
}
}

Expand Down
28 changes: 28 additions & 0 deletions src/frontends/pytorch/src/op/cat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_cat(NodeContext& context) {
// This translator is only needed to get axis as constant from external scope
num_inputs_check(context, 2, 2);
auto fw_node = std::make_shared<PtFrameworkNode>(context.get_decoder(), OutputVector{context.get_input(0)}, 1);
auto attrs = fw_node->get_attrs();
// If this fails it means axis is dynamic and aten::cat will be converted to fw node in regular pipeline
attrs["axis"] = std::to_string(context.const_input<int64_t>(1));
fw_node->set_attrs(attrs);
return {context.mark_node(std::dynamic_pointer_cast<Node>(fw_node))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
4 changes: 2 additions & 2 deletions src/frontends/pytorch/src/op/list_construct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ OutputVector translate_list_construct(NodeContext& context) {
consts.push_back(unsqueezed_c_node);
}
}
auto list_construct = std::make_shared<v0::Concat>(consts, 0);
auto list_construct = context.mark_node(std::make_shared<v0::Concat>(consts, 0));
if (list_construct->has_evaluate()) {
OutputVector replacements(list_construct->get_output_size());

if (list_construct->constant_fold(replacements, list_construct->input_values())) {
return replacements;
}
}
return {context.mark_output(list_construct)};
return {list_construct};
};

} // namespace op
Expand Down
8 changes: 6 additions & 2 deletions src/frontends/pytorch/src/op/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ OutputVector translate_loop(NodeContext& context) {
loop->set_special_body_ports(spec_ports);

auto body_parameters = body->get_parameters();
// #0 body parameter is counter; #0 loop input is counter, #1 loop input is condition
// #0 body parameter is counter;
FRONT_END_OP_CONVERSION_CHECK(body_parameters.size() > 0, "At least one input to Loop body is required");
// Set counter type and shape
body_parameters[0]->set_element_type(element::i32);
body_parameters[0]->set_partial_shape(PartialShape{});
// #0 loop input is trip_count, #1 loop input is condition
// Connect other inputs
for (size_t i = 2; i < inputs.size(); i++) {
loop->set_invariant_inputs(inputs[i], {body_parameters[i - 1]});
Expand All @@ -39,7 +44,6 @@ OutputVector translate_loop(NodeContext& context) {
auto external_output = context.get_tensor_from_model_or_create_input(input_idx);
loop->set_invariant_inputs(external_output, {param});
}
// TODO: Connect back edges (merged inputs)
auto body_results = body->get_results();
FRONT_END_OP_CONVERSION_CHECK(body_results.size() > 0, "At least one output from loop is required - condition.");
std::set<size_t> output_idxs;
Expand Down
3 changes: 2 additions & 1 deletion src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_avg_poolnd);
OP_CONVERTER(translate_bool);
OP_CONVERTER(translate_batch_norm);
OP_CONVERTER(translate_cat);
OP_CONVERTER(translate_clamp);
OP_CONVERTER(translate_constant);
OP_CONVERTER(translate_conv_transposend);
Expand Down Expand Up @@ -160,7 +161,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
{"aten::batch_norm", op::translate_batch_norm},
{"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::Bool", op::translate_bool},
// {"aten::cat", done as transformation},
{"aten::cat", op::translate_cat},
{"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
{"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
{"aten::clamp", op::translate_clamp},
Expand Down
69 changes: 60 additions & 9 deletions src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "openvino/core/rt_info.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/loop.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
Expand Down Expand Up @@ -37,17 +38,67 @@ AtenCatToConcat::AtenCatToConcat() {
if (!cat)
return false;

auto axis_node = cat->input(1).get_source_output().get_node_shared_ptr();
auto axis_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(axis_node);
if (!axis_const)
return false;
auto axis = axis_const->cast_vector<int64_t>();
if (axis.size() != 1)
return false;
int64_t axis;
if (cat->get_input_size() > 1) {
auto axis_node = cat->get_input_node_shared_ptr(1);
auto axis_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(axis_node);
if (!axis_const)
return false;
auto _axis = axis_const->cast_vector<int64_t>();
if (_axis.size() != 1)
return false;
axis = _axis[0];
} else {
const auto& attrs = cat->get_attrs();
if (attrs.find("axis") == attrs.end())
return false;
axis = std::stoll(attrs.at("axis"));
}

std::shared_ptr<Node> input_node = cat->get_input_node_shared_ptr(0);
if (auto loop = std::dynamic_pointer_cast<ov::op::v5::Loop>(input_node)) {
// case when concatenation is done inside the Loop
auto body = loop->get_function();
auto output_index = cat->input(0).get_source_output().get_index();
int64_t body_result_index = -1;
for (auto out_desc : loop->get_output_descriptions()) {
if (out_desc->m_output_index == output_index) {
body_result_index = static_cast<int64_t>(out_desc->m_body_value_index);
break;
}
}
FRONT_END_GENERAL_CHECK(body_result_index >= 0, "Couldn't find descriptor for output.");
auto body_result = body->get_results()[body_result_index];
auto append = cast_fw_node(body_result->get_input_node_shared_ptr(0), "aten::append");
if (!append)
return false;
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(append->get_input_node_shared_ptr(0));
if (!param)
return false;
auto body_param_index = body->get_parameter_index(param);
FRONT_END_GENERAL_CHECK(body_param_index >= 0, "Couldn't find parameter in body parameters.");
int64_t input_index = -1;
for (auto in_desc : loop->get_input_descriptions()) {
if (in_desc->m_body_parameter_index == static_cast<size_t>(body_param_index)) {
input_index = static_cast<int64_t>(in_desc->m_input_index);
break;
}
}
FRONT_END_GENERAL_CHECK(input_index >= 0, "Couldn't find descriptor for input.");
auto list_construct = cast_fw_node(loop->get_input_node_shared_ptr(input_index), "prim::ListConstruct");
if (!list_construct || list_construct->get_input_size() > 0)
return false;
// TODO: Is unsqueeze needed?
auto new_result = std::make_shared<ov::op::v0::Result>(append->input_value(1));
body->add_results({new_result});
auto new_output = loop->get_concatenated_slices(new_result, 0, 1, 1, -1, axis);
copy_runtime_info(cat, loop);
cat->output(0).replace(new_output);
return true;
}

OutputVector tmp_inputs;
NodeVector rt_copy_from{cat};
std::shared_ptr<Node> input_node = cat->input(0).get_source_output().get_node_shared_ptr();
while (const auto& input_fw_node = cast_fw_node(input_node, "aten::append")) {
rt_copy_from.push_back(input_fw_node);
tmp_inputs.push_back(input_fw_node->input(1).get_source_output());
Expand All @@ -62,7 +113,7 @@ AtenCatToConcat::AtenCatToConcat() {
inputs.push_back(input.get_source_output());
}
inputs.insert(inputs.end(), tmp_inputs.rbegin(), tmp_inputs.rend());
auto result = std::make_shared<ov::op::v0::Concat>(inputs, axis[0]);
auto result = std::make_shared<ov::op::v0::Concat>(inputs, axis);
copy_runtime_info(rt_copy_from, result);
replace_node(cat, result);

Expand Down
3 changes: 2 additions & 1 deletion tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def _test(self, model, ref_net, kind, ie_device, precision, ir_version, infer_ti
else:
torch_inputs = [torch.from_numpy(inp) for inp in inputs]
model = torch.jit.trace(model, torch_inputs)
model = torch.jit.freeze(model)
if kwargs.get('freeze_model', True):
model = torch.jit.freeze(model)
graph = model.inlined_graph
print(graph)

Expand Down
2 changes: 1 addition & 1 deletion tests/layer_tests/pytorch_tests/test_bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ def forward_scalar(self, x:int):
@pytest.mark.parametrize("input_type", ["tensor", "scalar"])
@pytest.mark.nightly
@pytest.mark.precommit
def test_ceil(self, ie_device, precision, ir_version, input_type):
def test_bool(self, ie_device, precision, ir_version, input_type):
self._test(*self.create_model(input_type), ie_device, precision, ir_version)
51 changes: 51 additions & 0 deletions tests/layer_tests/pytorch_tests/test_cat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class aten_cat(torch.nn.Module):
def forward(self, x):
return torch.cat([x, x], 1)


class aten_append_cat(torch.nn.Module):
def forward(self, x):
list = []
list.append(x)
list.append(x)
return torch.cat(list, 1)

class aten_loop_append_cat(torch.nn.Module):
def forward(self, x):
list = []
for i in range(3):
list.append(x)
return torch.cat(list, 1)


class TestCat(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(2, 1, 3),)

@pytest.mark.nightly
@pytest.mark.precommit
def test_cat(self, ie_device, precision, ir_version):
self._test(aten_cat(), None, ["aten::cat", "prim::ListConstruct"],
ie_device, precision, ir_version)

@pytest.mark.nightly
@pytest.mark.precommit
def test_append_cat(self, ie_device, precision, ir_version):
self._test(aten_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct"],
ie_device, precision, ir_version)

@pytest.mark.nightly
@pytest.mark.precommit
def test_loop_append_cat(self, ie_device, precision, ir_version):
self._test(aten_loop_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
ie_device, precision, ir_version, freeze_model=False)

0 comments on commit 586d309

Please sign in to comment.