From 02d7cb959a19300e57a44a41a29a219a37d56ff2 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Thu, 1 Feb 2024 16:50:05 +0100
Subject: [PATCH] [PT FE] Add support for PackedSequence for aten::lstm
 (#22586)

* [PT FE] Add support for PackedSequence for aten::lstm

* Apply suggestions from code review
---
 src/frontends/pytorch/src/frontend.cpp        |   3 +
 .../src/helper_ops/packed_sequence.hpp        |  43 ++++++
 src/frontends/pytorch/src/op/lstm.cpp         |  44 +++++-
 .../pytorch/src/op/pack_sequence.cpp          |  52 +++++++
 src/frontends/pytorch/src/op_table.cpp        |   4 +
 .../src/transforms/remove_packing_ops.cpp     | 135 ++++++++++++++++++
 .../src/transforms/remove_packing_ops.hpp     |  36 +++++
 tests/layer_tests/pytorch_tests/test_lstm.py  |  60 ++++++++
 8 files changed, 374 insertions(+), 3 deletions(-)
 create mode 100644 src/frontends/pytorch/src/helper_ops/packed_sequence.hpp
 create mode 100644 src/frontends/pytorch/src/op/pack_sequence.cpp
 create mode 100644 src/frontends/pytorch/src/transforms/remove_packing_ops.cpp
 create mode 100644 src/frontends/pytorch/src/transforms/remove_packing_ops.hpp

diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index fa00b94a4c6158..03835b72935327 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -39,6 +39,7 @@
 #include "transforms/prim_list_unpack_replacer.hpp"
 #include "transforms/prim_unpack_parameter_replacer.hpp"
 #include "transforms/quantized_node_remover.hpp"
+#include "transforms/remove_packing_ops.hpp"
 #include "transforms/reverseprop_resolver.hpp"
 #include "transforms/rfftn_complex_replacer.hpp"
 #include "transforms/softmax_reshape_elimination.hpp"
@@ -213,6 +214,8 @@ void FrontEnd::normalize(const std::shared_ptr<Model>& model) const {
     manager.register_pass();
     manager.register_pass();
     manager.register_pass();
+    manager.register_pass<ov::frontend::pytorch::pass::MovePackThroughLstm>();
+    manager.register_pass<ov::frontend::pytorch::pass::RemovePackingOps>();
     manager.register_pass();
     manager.register_pass();
     // Second pass of AlignTypesRemoval after all converting transformations
diff --git a/src/frontends/pytorch/src/helper_ops/packed_sequence.hpp b/src/frontends/pytorch/src/helper_ops/packed_sequence.hpp
new file mode 100644
index 00000000000000..30e1a37c9d1d96
--- /dev/null
+++ b/src/frontends/pytorch/src/helper_ops/packed_sequence.hpp
@@ -0,0 +1,43 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "internal_op.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+
+class PackPadded : public InternalOperation {
+public:
+    OPENVINO_OP("PackPadded", "util", ov::op::util::FrameworkNode);
+    PackPadded(const Output<Node>& input, const Output<Node>& lengths)
+        : InternalOperation("prim::PackPadded", {input, lengths}, 2, "This is PackedSequence pack operation.") {
+        validate_and_infer_types();
+    }
+
+    void validate_and_infer_types() override {
+        set_output_type(0, get_input_element_type(0), PartialShape({-1, -1, -1}));
+        set_output_type(1, get_input_element_type(1), PartialShape::dynamic());
+    }
+};
+
+class PadPacked : public InternalOperation {
+public:
+    OPENVINO_OP("PadPacked", "util", ov::op::util::FrameworkNode);
+    PadPacked(const Output<Node>& input, const Output<Node>& lengths)
+        : InternalOperation("prim::PadPacked", {input, lengths}, 2, "This is PackedSequence unpack operation.") {
+        validate_and_infer_types();
+    }
+
+    void validate_and_infer_types() override {
+        set_output_type(0, get_input_element_type(0), PartialShape({-1, -1, -1}));
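+        // output 0 is the padded data tensor: rank 3 with all dimensions dynamic;
+        // output 1 forwards the lengths tensor type and shape unchanged.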
+        set_output_type(1, get_input_element_type(1), get_input_partial_shape(1));
+    }
+};
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op/lstm.cpp b/src/frontends/pytorch/src/op/lstm.cpp
index 0ea42e8bfa1799..1ec859e5e7b8c5 100644
--- a/src/frontends/pytorch/src/op/lstm.cpp
+++ b/src/frontends/pytorch/src/op/lstm.cpp
@@ -1,6 +1,7 @@
 // Copyright (C) 2018-2023 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
+#include "helper_ops/packed_sequence.hpp"
 #include "openvino/frontend/pytorch/node_context.hpp"
 #include "openvino/op/add.hpp"
 #include "openvino/op/broadcast.hpp"
@@ -248,12 +249,23 @@ OutputVector generic_rnn(ov::pass::NodeRegistry& rg,
     }
     if (!batch_first)
         prev_output = rg.make<v1::Transpose>(prev_output, order_102);
-    Output<Node> h_res = rg.make<v0::Concat>(h_outs, 1);
+    Output<Node> h_res;
+    if (h_outs.size() == 1) {
+        h_res = h_outs[0];
+    } else {
+        h_res = rg.make<v0::Concat>(h_outs, 1);
+    }
+    h_res = rg.make<v1::Transpose>(h_res, order_102);
     if (variant == RnnVariant::RNN || variant == RnnVariant::GRU) {
         return {prev_output, h_res};
     } else if (variant == RnnVariant::LSTM) {
-        Output<Node> c_res = rg.make<v0::Concat>(c_outs, 1);
+        Output<Node> c_res;
+        if (c_outs.size() == 1) {
+            c_res = c_outs[0];
+        } else {
+            c_res = rg.make<v0::Concat>(c_outs, 1);
+        }
         c_res = rg.make<v1::Transpose>(c_res, order_102);
         return {prev_output, h_res, c_res};
     }
@@ -267,7 +279,33 @@ OutputVector translate_lstm(const NodeContext& context) {
     ov::pass::NodeRegistry rg;
     if (context.get_input_type(3).is<type::List>()) {
         // lstm packed
-        FRONT_END_OP_CONVERSION_CHECK(false, "Unsupported lstm variant.");
+        // aten::lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int
+        // num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor)
+        const auto data = context.get_input(0);
+        const auto batch_sizes = context.get_input(1);
+        const auto hx = context.get_input(2);
+        const auto params = context.get_input(3);
+        const auto has_bias = context.const_input<bool>(4);
+        const auto num_layers = context.const_input<int64_t>(5);
+        // const auto dropout = context.const_input<float>(6); - skip
+        const auto train = context.const_input<bool>(7);
+        FRONT_END_OP_CONVERSION_CHECK(!train, "LSTM in train mode is not supported.");
+        const auto bidirectional = context.const_input<bool>(8);
+
+        const auto initial_states = get_list_as_outputs(hx);
+        const auto all_weights = get_list_as_outputs(params);
+        const auto res = generic_rnn(rg,
+                                     RnnVariant::LSTM,
+                                     data,
+                                     initial_states,
+                                     all_weights,
+                                     has_bias,
+                                     num_layers,
+                                     bidirectional,
+                                     false,
+                                     batch_sizes);
+        context.mark_nodes(rg.get());
+        return res;
     } else {
         // aten::lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout,
         // bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)
diff --git a/src/frontends/pytorch/src/op/pack_sequence.cpp b/src/frontends/pytorch/src/op/pack_sequence.cpp
new file mode 100644
index 00000000000000..2e0fd92b50c231
--- /dev/null
+++ b/src/frontends/pytorch/src/op/pack_sequence.cpp
@@ -0,0 +1,52 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "helper_ops/packed_sequence.hpp"
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/transpose.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
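+// aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)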
+OutputVector translate_pack_padded_sequence(const NodeContext& context) {
+    num_inputs_check(context, 3, 3);
+    auto seq = context.get_input(0);
+    auto lengths = context.get_input(1);
+    const auto batch_first = context.const_input<bool>(2);
+
+    const auto order_102 = v0::Constant::create(element::i32, Shape{3}, {1, 0, 2});
+    if (batch_first)
+        seq = context.mark_node(std::make_shared<v1::Transpose>(seq, order_102));
+    if (lengths.get_element_type() != element::i32)
+        lengths = context.mark_node(std::make_shared<v0::Convert>(lengths, element::i32));
+    return context.mark_node(std::make_shared<PackPadded>(seq, lengths))->outputs();
+};
+
+OutputVector translate_pad_packed_sequence(const NodeContext& context) {
+    // aten::_pad_packed_sequence with schema aten::_pad_packed_sequence(Tensor data, Tensor batch_sizes, bool
+    // batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor)
+    num_inputs_check(context, 3, 5);
+    auto seq = context.get_input(0);
+    auto lengths = context.get_input(1);
+    const auto batch_first = context.const_input<bool>(2);
+    auto pad_packed = context.mark_node(std::make_shared<PadPacked>(seq, lengths));
+    seq = pad_packed->output(0);
+    lengths = pad_packed->output(1);
+    const auto order_102 = v0::Constant::create(element::i32, Shape{3}, {1, 0, 2});
+    if (batch_first)
+        seq = context.mark_node(std::make_shared<v1::Transpose>(seq, order_102));
+    return {seq, lengths};
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index c829099c4acee9..7eb8875787ef01 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -149,7 +149,9 @@ OP_CONVERTER(translate_ones_like);
 OP_CONVERTER(translate_or);
 OP_CONVERTER(translate_bitwise_xor);
 OP_CONVERTER(translate_outer);
+OP_CONVERTER(translate_pack_padded_sequence);
 OP_CONVERTER(translate_pad);
+OP_CONVERTER(translate_pad_packed_sequence);
 OP_CONVERTER(translate_pairwise_distance);
 OP_CONVERTER(translate_pixel_shuffle);
 OP_CONVERTER(translate_pixel_unshuffle);
@@ -278,6 +280,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
     {"aten::_convolution", op::translate_convolution},
     {"aten::_convolution_mode", op::translate_convolution_mode},
     {"aten::_native_multi_head_attention", op::translate_native_multi_head_attention},
+    {"aten::_pack_padded_sequence", op::translate_pack_padded_sequence},
+    {"aten::_pad_packed_sequence", op::translate_pad_packed_sequence},
     {"aten::_set_item", op::translate_set_item},
     {"aten::_shape_as_tensor", op::translate_shape_as_tensor},
     {"aten::_upsample_bicubic2d_aa", op::translate_upsample_bicubic2d_aa},
diff --git a/src/frontends/pytorch/src/transforms/remove_packing_ops.cpp b/src/frontends/pytorch/src/transforms/remove_packing_ops.cpp
new file mode 100644
index 00000000000000..d5e13c6d31a21d
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/remove_packing_ops.cpp
@@ -0,0 +1,135 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "remove_packing_ops.hpp"
+
+#include "helper_ops/packed_sequence.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/gru_sequence.hpp"
+#include "openvino/op/lstm_sequence.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/rnn_sequence.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/squeeze.hpp"
+#include "openvino/op/transpose.hpp"
+#include "openvino/pass/pattern/matcher.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+using namespace ov::pass;
+using namespace ov::op;
+
+namespace {
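+// Returns true if the node is one of the OpenVINO sequence RNN ops (RNN/GRU/LSTM).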
+bool is_rnn(std::shared_ptr<Node> node) {
+    if (as_type_ptr<v5::GRUSequence>(node) || as_type_ptr<v5::LSTMSequence>(node) ||
+        as_type_ptr<v5::RNNSequence>(node)) {
+        return true;
+    }
+    return false;
+}
+}  // namespace
+
+MovePackThroughLstm::MovePackThroughLstm() {
+    auto pack_op = pattern::wrap_type<PackPadded>();
+
+    ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
+        auto pack = m.get_match_root();
+
+        auto targets = pack->output(0).get_target_inputs();
+        if (targets.size() != 1)
+            return false;
+        auto rnn = targets.begin()->get_node()->shared_from_this();
+        // Input to rnn may be transposed, skipping Transpose
+        if (as_type_ptr<v1::Transpose>(rnn))
+            rnn = rnn->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
+        if (!is_rnn(rnn))
+            return false;
+        targets = rnn->output(0).get_target_inputs();
+        if (targets.size() != 1)
+            return false;
+
+        // The rnn is followed by a transpose and a reshape (if bidirectional), or by a squeeze (if unidirectional).
+        auto next = targets.begin()->get_node()->shared_from_this();
+        if (as_type_ptr<v1::Transpose>(next)) {
+            next = next->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
+            if (!as_type_ptr<v1::Reshape>(next)) {
+                return false;
+            }
+        } else if (!as_type_ptr<v0::Squeeze>(next)) {
+            return false;
+        }
+
+        // remove PackPadded from in front of the RNN
+        pack->output(0).replace(pack->input_value(0));
+
+        auto batch_sizes = pack->output(1);
+        for (auto node_input : batch_sizes.get_target_inputs()) {
+            auto user = node_input.get_node()->shared_from_this();
+            // Make calculation of max_batch_size not depend on batch_sizes.
+            // This looks for a pattern generated by code such as
+            // https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815.
+            //
+            // Replace Gather[axis=0](batch_sizes, 0)
+            // with Gather[axis=0](ShapeOf(rnn_input), 0)
+            if (const auto gather = as_type_ptr<v8::Gather>(user)) {
+                if (gather->get_axis() != 0)
+                    continue;
+                auto rnn_shape = std::make_shared<v3::ShapeOf>(rnn->input_value(0), element::i32);
+                auto indx_1 = v0::Constant::create(element::i32, Shape{}, {0});
+                auto new_gather = std::make_shared<v8::Gather>(rnn_shape, indx_1, gather->input_value(2));
+                copy_runtime_info_and_name(gather, {new_gather, rnn_shape, indx_1});
+                replace_node(gather, new_gather);
+            } else if (user == rnn) {
+                node_input.replace_source_output(pack->input_value(1));
+            }
+        }
+        // and insert new PackPadded after the RNN
+        auto next_target_inputs = next->output(0).get_target_inputs();
+        auto newPackPadded = std::make_shared<PackPadded>(next->output(0), pack->input_value(1));
+        register_new_node(newPackPadded);
+
+        // make things consume from the new PackPadded
+        for (auto& input : next_target_inputs)
+            input.replace_source_output(newPackPadded->output(0));
+        pack->output(1).replace(newPackPadded->output(1));
+
+        return true;
+    };
+
+    auto m = std::make_shared<pattern::Matcher>(pack_op, "ov::frontend::pytorch::pass::MovePackThroughLstm");
+    this->register_matcher(m, callback);
+};
+
+RemovePackingOps::RemovePackingOps() {
+    auto unpack_op = pattern::wrap_type<PadPacked>();
+
+    ov::matcher_pass_callback callback = [](pattern::Matcher& m) {
+        const auto& unpack = m.get_match_root();
+        auto pack_node = unpack->input_value(0).get_node_shared_ptr();
+        if (as_type_ptr<v1::Transpose>(pack_node))
+            pack_node = std::dynamic_pointer_cast<PackPadded>(pack_node->input_value(0).get_node_shared_ptr());
+        if (!pack_node)
+            return false;
+
+        pack_node->output(0).replace(pack_node->input_value(0));
+        pack_node->output(1).replace(pack_node->input_value(1));
+        unpack->output(0).replace(unpack->input_value(0));
+        unpack->output(1).replace(unpack->input_value(1));
+
+        return true;
+    };
+
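+    // By this point MovePackThroughLstm has already run, so the matched PadPacked
+    // is expected to be fed by a PackPadded (possibly through a Transpose) and the
+    // pair can be dropped together.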
"ov::frontend::pytorch::pass::RemovePackingOps"); + this->register_matcher(m, callback); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/transforms/remove_packing_ops.hpp b/src/frontends/pytorch/src/transforms/remove_packing_ops.hpp new file mode 100644 index 00000000000000..773100dfc35af9 --- /dev/null +++ b/src/frontends/pytorch/src/transforms/remove_packing_ops.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +/** + * Move PackPadded through RNN ops, because RNN(PackPadded(x)) == PackPadded(RNN(x)). + */ +class MovePackThroughLstm : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::frontend::pytorch::pass::MovePackThroughLstm"); + MovePackThroughLstm(); +}; + +/** + * Remove PackPadded -> PadPacked ops. + */ +class RemovePackingOps : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::frontend::pytorch::pass::RemovePackingOps"); + RemovePackingOps(); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/tests/layer_tests/pytorch_tests/test_lstm.py b/tests/layer_tests/pytorch_tests/test_lstm.py index 3fef1b1e761d25..8ffc3c73afe111 100644 --- a/tests/layer_tests/pytorch_tests/test_lstm.py +++ b/tests/layer_tests/pytorch_tests/test_lstm.py @@ -22,6 +22,29 @@ def forward(self, input_tensor, h0, c0): return self.lstm(input_tensor, (h0, c0)) +class aten_lstm_packed(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, batch_first): + torch.nn.Module.__init__(self) + self.rnn = torch.nn.LSTM(input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=batch_first, + bidirectional=bidirectional, + bias=has_bias, + ) + self.batch_first = batch_first + + def forward(self, seq, lengths): + seq1 = torch.nn.utils.rnn.pack_padded_sequence(seq, + lengths, + batch_first=self.batch_first) + seq2, hid2 = self.rnn(seq1) + seq = torch.nn.utils.rnn.pad_packed_sequence(seq2, + batch_first=self.batch_first)[0] + + return seq, hid2 + + class aten_gru(torch.nn.Module): def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, batch_first): torch.nn.Module.__init__(self) @@ -81,6 +104,43 @@ def test_lstm(self, input_size, hidden_size, num_layers, has_bias, bidirectional ie_device, precision, ir_version, trace_model=True) +class TestLSTMPacked(PytorchLayerTest): + def _prepare_input(self): + batch = 15 + if self.batch_first: + input = np.random.randn( + batch, 50, self.input_size).astype(np.float32) + else: + input = np.random.randn( + 50, batch, self.input_size).astype(np.float32) + lengths = np.array(list(sorted(np.random.randint( + 1, 50, [batch - 1]).tolist() + [50], reverse=True)), dtype=np.int32) + return (input, lengths) + + @pytest.mark.parametrize("input_size,hidden_size", [(10, 20),]) + @pytest.mark.parametrize("num_layers", [1, 2, 7]) + @pytest.mark.parametrize("has_bias", [True, False]) + @pytest.mark.parametrize("bidirectional", [True, False]) + @pytest.mark.parametrize("batch_first", [True, False]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_lstm_packed(self, input_size, hidden_size, num_layers, has_bias, bidirectional, batch_first, ie_device, precision, ir_version): + self.input_size = 
+        lengths = np.array(list(sorted(np.random.randint(
+            1, 50, [batch - 1]).tolist() + [50], reverse=True)), dtype=np.int32)
+        return (input, lengths)
+
+    @pytest.mark.parametrize("input_size,hidden_size", [(10, 20),])
+    @pytest.mark.parametrize("num_layers", [1, 2, 7])
+    @pytest.mark.parametrize("has_bias", [True, False])
+    @pytest.mark.parametrize("bidirectional", [True, False])
+    @pytest.mark.parametrize("batch_first", [True, False])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_lstm_packed(self, input_size, hidden_size, num_layers, has_bias, bidirectional, batch_first, ie_device, precision, ir_version):
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bidirectional = bidirectional
+        self.batch_first = batch_first
+        self._test(aten_lstm_packed(input_size, hidden_size, num_layers, has_bias, bidirectional, batch_first),
+                   None,
+                   "aten::lstm",
+                   ie_device,
+                   precision,
+                   ir_version,
+                   trace_model=True,
+                   dynamic_shapes=False  # ticket 131432
+                   )
+
+
 class TestGRU(PytorchLayerTest):
     def _prepare_input(self):
         n = self.num_layers