[PT FE] Add support for PackedSequence for aten::lstm (openvinotoolkit#22586)

* [PT FE] Add support for PackedSequence for aten::lstm

* Apply suggestions from code review
mvafin authored Feb 1, 2024
1 parent ff6795e commit 02d7cb9
Showing 8 changed files with 374 additions and 3 deletions.
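
For orientation: the PyTorch pattern this commit makes convertible looks roughly like the sketch below (hedged, not taken from the commit; module name, sizes, and lengths are illustrative). `pack_padded_sequence` lowers to `aten::_pack_padded_sequence`, calling the LSTM on the packed input lowers to the `aten::lstm.data` overload, and `pad_packed_sequence` lowers to `aten::_pad_packed_sequence`: the three ops handled by the changes below.

```python
# Hypothetical example of the newly supported pattern; names and shapes are illustrative.
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class PackedLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

    def forward(self, x, lengths):
        packed = pack_padded_sequence(x, lengths, batch_first=True)    # -> aten::_pack_padded_sequence
        out, (h, c) = self.lstm(packed)                                # -> aten::lstm.data
        out, out_lengths = pad_packed_sequence(out, batch_first=True)  # -> aten::_pad_packed_sequence
        return out, h, c

x = torch.randn(4, 10, 8)              # batch of 4, max length 10
lengths = torch.tensor([10, 9, 7, 3])  # sorted in decreasing order, as packing requires
traced = torch.jit.trace(PackedLSTM(), (x, lengths))
```
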
3 changes: 3 additions & 0 deletions src/frontends/pytorch/src/frontend.cpp
@@ -39,6 +39,7 @@
#include "transforms/prim_list_unpack_replacer.hpp"
#include "transforms/prim_unpack_parameter_replacer.hpp"
#include "transforms/quantized_node_remover.hpp"
#include "transforms/remove_packing_ops.hpp"
#include "transforms/reverseprop_resolver.hpp"
#include "transforms/rfftn_complex_replacer.hpp"
#include "transforms/softmax_reshape_elimination.hpp"
@@ -213,6 +214,8 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();
manager.register_pass<ov::frontend::pytorch::pass::U4BlockRepack>();
manager.register_pass<ov::frontend::pytorch::pass::ReversepropResolver>();
manager.register_pass<ov::frontend::pytorch::pass::MovePackThroughLstm>();
manager.register_pass<ov::frontend::pytorch::pass::RemovePackingOps>();
manager.register_pass<ov::pass::RemoveMultiSubGraphOpDanglingParamsResults>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
// Second pass of AlignTypesRemoval after all converting transformations
43 changes: 43 additions & 0 deletions src/frontends/pytorch/src/helper_ops/packed_sequence.hpp
@@ -0,0 +1,43 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "internal_op.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {

class PackPadded : public InternalOperation {
public:
OPENVINO_OP("PackPadded", "util", ov::op::util::FrameworkNode);
PackPadded(const Output<Node>& input, const Output<Node>& lengths)
: InternalOperation("prim::PackPadded", {input, lengths}, 2, "This is PackedSequence pack operation.") {
validate_and_infer_types();
}

void validate_and_infer_types() override {
set_output_type(0, get_input_element_type(0), PartialShape({-1, -1, -1}));
set_output_type(1, get_input_element_type(1), PartialShape::dynamic());
}
};

class PadPacked : public InternalOperation {
public:
OPENVINO_OP("PadPacked", "util", ov::op::util::FrameworkNode);
PadPacked(const Output<Node>& input, const Output<Node>& lengths)
: InternalOperation("prim::PadPacked", {input, lengths}, 2, "This is PackedSequence unpack operation.") {
validate_and_infer_types();
}

void validate_and_infer_types() override {
set_output_type(0, get_input_element_type(0), PartialShape({-1, -1, -1}));
set_output_type(1, get_input_element_type(1), get_input_partial_shape(1));
}
};
} // namespace pytorch
} // namespace frontend
} // namespace ov
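
Why each helper op carries two outputs: a PyTorch `PackedSequence` is essentially a pair of tensors, the flattened valid time steps (`data`) and the number of still-active sequences per step (`batch_sizes`). A small illustration, assuming PyTorch (not part of this commit):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

x = torch.randn(3, 5, 2)  # (batch, time, features)
packed = pack_padded_sequence(x, torch.tensor([5, 3, 1]), batch_first=True)
print(packed.data.shape)   # torch.Size([9, 2]): 5 + 3 + 1 valid steps
print(packed.batch_sizes)  # tensor([3, 2, 2, 1, 1])
```
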
44 changes: 41 additions & 3 deletions src/frontends/pytorch/src/op/lstm.cpp
@@ -1,6 +1,7 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "helper_ops/packed_sequence.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
@@ -248,12 +249,23 @@ OutputVector generic_rnn(ov::pass::NodeRegistry& rg,
}
if (!batch_first)
prev_output = rg.make<v1::Transpose>(prev_output, order_102);
Output<Node> h_res = rg.make<v0::Concat>(h_outs, 1);
Output<Node> h_res;
if (h_outs.size() == 1) {
h_res = h_outs[0];
} else {
h_res = rg.make<v0::Concat>(h_outs, 1);
}

h_res = rg.make<v1::Transpose>(h_res, order_102);
if (variant == RnnVariant::RNN || variant == RnnVariant::GRU) {
return {prev_output, h_res};
} else if (variant == RnnVariant::LSTM) {
Output<Node> c_res = rg.make<v0::Concat>(c_outs, 1);
Output<Node> c_res;
if (c_outs.size() == 1) {
c_res = c_outs[0];
} else {
c_res = rg.make<v0::Concat>(c_outs, 1);
}
c_res = rg.make<v1::Transpose>(c_res, order_102);
return {prev_output, h_res, c_res};
}
@@ -267,7 +279,33 @@ OutputVector translate_lstm(const NodeContext& context) {
ov::pass::NodeRegistry rg;
if (context.get_input_type(3).is<type::List>()) {
// lstm packed
-FRONT_END_OP_CONVERSION_CHECK(false, "Unsupported lstm variant.");
// aten::lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int
// num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor)
const auto data = context.get_input(0);
const auto batch_sizes = context.get_input(1);
const auto hx = context.get_input(2);
const auto params = context.get_input(3);
const auto has_bias = context.const_input<bool>(4);
const auto num_layers = context.const_input<int64_t>(5);
// const auto dropout = context.const_input<float>(6); - skip
const auto train = context.const_input<bool>(7);
FRONT_END_OP_CONVERSION_CHECK(!train, "LSTM in train mode is not supported.");
const auto bidirectional = context.const_input<bool>(8);

const auto initial_states = get_list_as_outputs(hx);
const auto all_weights = get_list_as_outputs(params);
const auto res = generic_rnn(rg,
RnnVariant::LSTM,
data,
initial_states,
all_weights,
has_bias,
num_layers,
bidirectional,
false,
batch_sizes);
context.mark_nodes(rg.get());
return res;
} else {
// aten::lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout,
// bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)
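
In the packed branch of translate_lstm above, `batch_sizes` is what carries the sequence-length information into `generic_rnn`. Per-sequence lengths are recoverable from it; a hedged sketch of the relationship (illustrative values, not code from the commit):

```python
import torch

# batch_sizes[t] = number of sequences still active at time step t
batch_sizes = torch.tensor([3, 2, 2, 1, 1])
# sequence j is active at step t iff batch_sizes[t] > j,
# so its length is the number of steps at which it stays active
max_batch = int(batch_sizes[0])
lengths = (batch_sizes.unsqueeze(1) > torch.arange(max_batch)).sum(0)
print(lengths)  # tensor([5, 3, 1])
```
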
52 changes: 52 additions & 0 deletions src/frontends/pytorch/src/op/pack_sequence.cpp
@@ -0,0 +1,52 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "helper_ops/packed_sequence.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/transpose.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_pack_padded_sequence(const NodeContext& context) {
num_inputs_check(context, 3, 3);
auto seq = context.get_input(0);
auto lengths = context.get_input(1);
const auto batch_first = context.const_input<bool>(2);

const auto order_102 = v0::Constant::create(element::i32, Shape{3}, {1, 0, 2});
if (batch_first)
seq = context.mark_node(std::make_shared<v1::Transpose>(seq, order_102));
if (lengths.get_element_type() != element::i32)
lengths = context.mark_node(std::make_shared<v0::Convert>(lengths, element::i32));
return context.mark_node(std::make_shared<PackPadded>(seq, lengths))->outputs();
};

OutputVector translate_pad_packed_sequence(const NodeContext& context) {
// aten::_pad_packed_sequence with schema aten::_pad_packed_sequence(Tensor data, Tensor batch_sizes, bool
// batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor)
num_inputs_check(context, 3, 5);
auto seq = context.get_input(0);
auto lengths = context.get_input(1);
const auto batch_first = context.const_input<bool>(2);
auto pad_packed = context.mark_node(std::make_shared<PadPacked>(seq, lengths));
seq = pad_packed->output(0);
lengths = pad_packed->output(1);
const auto order_102 = v0::Constant::create(element::i32, Shape{3}, {1, 0, 2});
if (batch_first)
seq = context.mark_node(std::make_shared<v1::Transpose>(seq, order_102));
return {seq, lengths};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
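
Both translators normalize to the time-major layout that packing assumes: with `batch_first` set, a `Transpose` with order `{1, 0, 2}` swaps the batch and time dimensions, and lengths are converted to `i32` for the downstream sequence ops. A minimal sketch of that layout convention (PyTorch, illustrative):

```python
import torch

x_bf = torch.randn(4, 10, 8)  # batch-first: (batch, time, features)
x_tm = x_bf.permute(1, 0, 2)  # the {1, 0, 2} transpose: (time, batch, features)
assert x_tm.shape == (10, 4, 8)

lengths = torch.tensor([10, 9, 7, 3]).to(torch.int32)  # i32, as the converter emits
```
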
4 changes: 4 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -149,7 +149,9 @@ OP_CONVERTER(translate_ones_like);
OP_CONVERTER(translate_or);
OP_CONVERTER(translate_bitwise_xor);
OP_CONVERTER(translate_outer);
OP_CONVERTER(translate_pack_padded_sequence);
OP_CONVERTER(translate_pad);
OP_CONVERTER(translate_pad_packed_sequence);
OP_CONVERTER(translate_pairwise_distance);
OP_CONVERTER(translate_pixel_shuffle);
OP_CONVERTER(translate_pixel_unshuffle);
@@ -278,6 +280,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::_convolution", op::translate_convolution},
{"aten::_convolution_mode", op::translate_convolution_mode},
{"aten::_native_multi_head_attention", op::translate_native_multi_head_attention},
{"aten::_pack_padded_sequence", op::translate_pack_padded_sequence},
{"aten::_pad_packed_sequence", op::translate_pad_packed_sequence},
{"aten::_set_item", op::translate_set_item},
{"aten::_shape_as_tensor", op::translate_shape_as_tensor},
{"aten::_upsample_bicubic2d_aa", op::translate_upsample_bicubic2d_aa},
135 changes: 135 additions & 0 deletions src/frontends/pytorch/src/transforms/remove_packing_ops.cpp
@@ -0,0 +1,135 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "remove_packing_ops.hpp"

#include "helper_ops/packed_sequence.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/gru_sequence.hpp"
#include "openvino/op/lstm_sequence.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/rnn_sequence.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {

using namespace ov::pass;
using namespace ov::op;

namespace {
bool is_rnn(std::shared_ptr<Node> node) {
if (as_type_ptr<v5::LSTMSequence>(node) || as_type_ptr<v5::RNNSequence>(node) ||
as_type_ptr<v5::GRUSequence>(node)) {
return true;
}
return false;
}
} // namespace

MovePackThroughLstm::MovePackThroughLstm() {
auto pack_op = pattern::wrap_type<PackPadded>();

ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto pack = m.get_match_root();

auto targets = pack->output(0).get_target_inputs();
if (targets.size() != 1)
return false;
auto rnn = targets.begin()->get_node()->shared_from_this();
// The input to the RNN may be transposed; if so, skip over the Transpose
if (as_type_ptr<v1::Transpose>(rnn))
rnn = rnn->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
if (!is_rnn(rnn))
return false;
targets = rnn->output(0).get_target_inputs();
if (targets.size() != 1)
return false;

// The rnn is followed by a transpose and a reshape (if bidirectional), or by a squeeze (if unidirectional).
auto next = targets.begin()->get_node()->shared_from_this();
if (as_type_ptr<v1::Transpose>(next)) {
next = next->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
if (!as_type_ptr<v1::Reshape>(next)) {
return false;
}
} else if (!as_type_ptr<v0::Squeeze>(next)) {
return false;
}

// remove PackPadded from in front of the RNN
pack->output(0).replace(pack->input_value(0));

auto batch_sizes = pack->output(1);
for (auto node_input : batch_sizes.get_target_inputs()) {
auto user = node_input.get_node()->shared_from_this();
// Make calculation of max_batch_size not depend on batch_sizes.
// This looks for a pattern generated by code such as
// https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815.
//
// Replace Gather[axis=0](batch_sizes, 0)
// with Gather[axis=0](ShapeOf(rnn_input), 0)
if (const auto gather = as_type_ptr<v8::Gather>(user)) {
if (gather->get_axis() != 0)
continue;
auto rnn_shape = std::make_shared<v3::ShapeOf>(rnn->input_value(0), element::i32);
auto indx_1 = v0::Constant::create(element::i32, Shape{}, {0});
auto new_gather = std::make_shared<v8::Gather>(rnn_shape, indx_1, gather->input_value(2));
copy_runtime_info_and_name(gather, {new_gather, rnn_shape, indx_1});
replace_node(gather, new_gather);
} else if (user == rnn) {
node_input.replace_source_output(pack->input_value(1));
}
}
// and insert new PackPadded after the RNN
auto next_target_inputs = next->output(0).get_target_inputs();
auto newPackPadded = std::make_shared<PackPadded>(next->output(0), pack->input_value(1));
register_new_node(newPackPadded);

// make things consume from the new PackPadded
for (auto& input : next_target_inputs)
input.replace_source_output(newPackPadded->output(0));
pack->output(1).replace(newPackPadded->output(1));

return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(pack_op, "ov::frontend::pytorch::pass::MovePackThroughLstm");
this->register_matcher(m, callback);
};

RemovePackingOps::RemovePackingOps() {
auto unpack_op = pattern::wrap_type<PadPacked>();

ov::matcher_pass_callback callback = [](pattern::Matcher& m) {
const auto& unpack = m.get_match_root();
auto pack_node = unpack->input_value(0).get_node_shared_ptr();
if (as_type_ptr<v1::Transpose>(pack_node))
pack_node = std::dynamic_pointer_cast<PackPadded>(pack_node->input_value(0).get_node_shared_ptr());
if (!pack_node)
return false;

pack_node->output(0).replace(pack_node->input_value(0));
pack_node->output(1).replace(pack_node->input_value(1));
unpack->output(0).replace(unpack->input_value(0));
unpack->output(1).replace(unpack->input_value(1));

return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(unpack_op, "ov::frontend::pytorch::pass::RemovePackingOps");
this->register_matcher(m, callback);
};

} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov
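
The identity behind `MovePackThroughLstm`, namely `RNN(PackPadded(x)) == PackPadded(RNN(x))`, can be checked numerically in PyTorch. A hedged sketch (a plain `nn.LSTM` does no length masking, so the padded tail is masked manually here; OpenVINO's `LSTMSequence` gets the same effect from the forwarded sequence lengths):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = torch.nn.LSTM(input_size=2, hidden_size=3)  # time-major by default
x = torch.randn(5, 3, 2)                           # (time, batch, features)
lengths = torch.tensor([5, 3, 1])

# pack first, then run the RNN
packed_out, _ = lstm(pack_padded_sequence(x, lengths))
a, _ = pad_packed_sequence(packed_out)

# run the RNN first, then mask the padded tail (what packing would zero out)
b, _ = lstm(x)
mask = (torch.arange(5).unsqueeze(1) < lengths).unsqueeze(-1)
assert torch.allclose(a, b * mask, atol=1e-6)
```
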
36 changes: 36 additions & 0 deletions src/frontends/pytorch/src/transforms/remove_packing_ops.hpp
@@ -0,0 +1,36 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {

/**
* Move PackPadded through RNN ops, because RNN(PackPadded(x)) == PackPadded(RNN(x)).
*/
class MovePackThroughLstm : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::MovePackThroughLstm");
MovePackThroughLstm();
};

/**
* Remove PackPadded -> PadPacked ops.
*/
class RemovePackingOps : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::RemovePackingOps");
RemovePackingOps();
};

} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov