diff --git a/src/common/transformations/include/transformations/utils/utils.hpp b/src/common/transformations/include/transformations/utils/utils.hpp
index 9a2036fff1b20d..1961c35ef16594 100644
--- a/src/common/transformations/include/transformations/utils/utils.hpp
+++ b/src/common/transformations/include/transformations/utils/utils.hpp
@@ -182,7 +182,8 @@ TRANSFORMATIONS_API bool check_for_broadcast(const PartialShape& ref_shape, cons
 
 TRANSFORMATIONS_API std::shared_ptr<Node> activation(const std::string& activation_name, const Output<Node>& apply_to);
 
-TRANSFORMATIONS_API bool is_seq_len_provided(const std::shared_ptr<Node>& seq_len_input, int64_t max_seq_len);
+TRANSFORMATIONS_API bool is_seq_len_provided(const std::shared_ptr<Node>& X,
+                                             const std::shared_ptr<Node>& seq_len_input);
 
 TRANSFORMATIONS_API std::shared_ptr<Node> try_fold_unary_output(const std::shared_ptr<Node>& node);
diff --git a/src/common/transformations/src/transformations/op_conversions/convert_sequences_to_tensor_iterator.cpp b/src/common/transformations/src/transformations/op_conversions/convert_sequences_to_tensor_iterator.cpp
index 7d7cc8049883d0..a7e7b3c1ae1880 100644
--- a/src/common/transformations/src/transformations/op_conversions/convert_sequences_to_tensor_iterator.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/convert_sequences_to_tensor_iterator.cpp
@@ -88,12 +88,11 @@ bool convert_sequence_to_ti(const std::shared_ptr<ov::Node>& sequence,
                             const ov::Output<ov::Node>& B,
                             const ov::op::RecurrentSequenceDirection& direction) {
     auto X_pshape = X.get_partial_shape();
-    if (X_pshape.size() < 2 || X_pshape[1].is_dynamic()) {
+    if (X_pshape.size() < 2) {
         return false;
     }
 
-    auto max_seq_len = X_pshape[1].get_length();
-    bool enable_mask = ov::op::util::is_seq_len_provided(seq_lengths.get_node_shared_ptr(), max_seq_len);
+    bool enable_mask = ov::op::util::is_seq_len_provided(X.get_node_shared_ptr(), seq_lengths.get_node_shared_ptr());
 
     const bool is_reverse = direction == ov::op::RecurrentSequenceDirection::REVERSE;
     std::shared_ptr<ov::Node> reverse_seq_before;
diff --git a/src/common/transformations/src/transformations/utils/utils.cpp b/src/common/transformations/src/transformations/utils/utils.cpp
index 62b1765e7ba275..6d77ecc0dbea94 100644
--- a/src/common/transformations/src/transformations/utils/utils.cpp
+++ b/src/common/transformations/src/transformations/utils/utils.cpp
@@ -128,11 +128,56 @@ std::shared_ptr<Node> activation(const std::string& activation_name, const o
     }
 }
 
-bool is_seq_len_provided(const std::shared_ptr<Node>& seq_len_input, int64_t max_seq_len) {
+bool is_seq_len_provided(const std::shared_ptr<Node>& X, const std::shared_ptr<Node>& seq_len_input) {
+    auto max_seq_dim = X->get_output_partial_shape(0)[1];
+    if (max_seq_dim.is_dynamic()) {
+        // if values in seq_len input are equal to max_seq_len dim in X input
+        // then we don't need to insert Select operations
+        // supported seq_len_input:
+        // X -> ShapeOf -> Gather (max_seq_dim) -> Optional (Broadcast)
+        std::shared_ptr<Node> input = seq_len_input;
+        auto broadcast = ov::as_type_ptr<ov::op::v3::Broadcast>(input);
+        if (broadcast) {
+            input = seq_len_input->input_value(0).get_node_shared_ptr();
+        }
+
+        auto gather = ov::as_type_ptr<ov::op::util::GatherBase>(input);
+        bool valid_gather = false;
+        if (gather) {
+            auto indices = gather->input_value(1).get_node_shared_ptr();
+            auto axis = gather->input_value(2).get_node_shared_ptr();
+            auto indices_const = ov::as_type_ptr<ov::op::v0::Constant>(indices);
+            auto axis_const = ov::as_type_ptr<ov::op::v0::Constant>(axis);
+            if (indices_const && axis_const) {
+                auto ind_values = indices_const->cast_vector<int64_t>();
+                auto axis_values = axis_const->cast_vector<int64_t>();
+                if (ind_values.size() == 1 && ind_values[0] == 1 && axis_values.size() == 1 && axis_values[0] == 0) {
+                    valid_gather = true;
+                }
+            }
+        }
+
+        if (!valid_gather) {
+            return true;
+        }
+
+        auto shape_of = ov::as_type_ptr<ov::op::util::ShapeOfBase>(gather->input_value(0).get_node_shared_ptr());
+        if (!shape_of) {
+            return true;
+        }
+
+        if (shape_of->input_value(0).get_node_shared_ptr() != X) {
+            return true;
+        }
+
+        return false;
+    }
+
+    auto max_seq_len_val = max_seq_dim.get_length();
     if (const auto& seq_len_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(seq_len_input)) {
         const auto& seq_len_values = seq_len_const->cast_vector<int64_t>();
-        return std::any_of(seq_len_values.begin(), seq_len_values.end(), [max_seq_len](const int64_t val) {
-            return val != max_seq_len;
+        return std::any_of(seq_len_values.begin(), seq_len_values.end(), [max_seq_len_val](const int64_t val) {
+            return val != max_seq_len_val;
         });
     }
     return true;
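Note on the dynamic branch above: masking can only be skipped when `seq_len_input` is provably equal to the dynamic `max_seq_len` dimension, which the check pins down to the `X -> ShapeOf -> Gather(indices={1}, axis=0)` subgraph, optionally followed by a `Broadcast`. Below is a minimal sketch of constructing such a trivially-full `seq_len_input` with the public opset API; the helper name `make_full_seq_lengths` is illustrative and not part of this patch.

```cpp
#include <memory>

#include <openvino/opsets/opset5.hpp>

// Builds the one seq_len_input pattern the dynamic branch accepts:
// X -> ShapeOf -> Gather(indices={1}, axis=0), i.e. the max_seq_len dim of X,
// so every batch element implicitly runs for the full sequence length.
std::shared_ptr<ov::Node> make_full_seq_lengths(const ov::Output<ov::Node>& X) {
    auto shape_of = std::make_shared<ov::opset5::ShapeOf>(X);
    auto indices = ov::opset5::Constant::create(ov::element::i32, ov::Shape{1}, {1});
    auto axis = ov::opset5::Constant::create(ov::element::i32, ov::Shape{}, {0});
    return std::make_shared<ov::opset5::Gather>(shape_of, indices, axis);
}
```

Any other producer of `seq_len_input` makes the helper return `true`, i.e. the conversion keeps the `Select` masking path.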
diff --git a/src/common/transformations/tests/op_conversions/convert_sequences_to_ti_test.cpp b/src/common/transformations/tests/op_conversions/convert_sequences_to_ti_test.cpp
index e140087c2dd2e8..7220157efbf781 100644
--- a/src/common/transformations/tests/op_conversions/convert_sequences_to_ti_test.cpp
+++ b/src/common/transformations/tests/op_conversions/convert_sequences_to_ti_test.cpp
@@ -798,3 +798,115 @@ TEST(TransformationTests, ConvertQuantizedGRUSequenceToTensorIterator) {
     auto res = compare_functions(f, f_ref);
     ASSERT_TRUE(res.first) << res.second;
 }
+
+TEST(TransformationTests, ConvertLSTMSequenceWithDynSeqLenToTensorIterator) {
+    std::shared_ptr<ov::Model> f(nullptr), f_ref(nullptr);
+    {
+        auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{1, -1, 16});
+        auto Y = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 128});
+        auto Z = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 128});
+        auto shape_of = std::make_shared<opset5::ShapeOf>(X);
+        auto indices = opset5::Constant::create(element::i32, {1}, {1});
+        auto axis = opset5::Constant::create(element::i32, {}, {0});
+        auto seq_lengths = std::make_shared<opset5::Gather>(shape_of, indices, axis);
+
+        auto w_val = std::vector<float>(512 * 16, 0);
+        auto r_val = std::vector<float>(512 * 128, 0);
+        auto b_val = std::vector<float>(512, 0);
+        auto W = opset5::Constant::create(element::f32, Shape{1, 512, 16}, w_val);
+        auto R = opset5::Constant::create(element::f32, Shape{1, 512, 128}, r_val);
+        auto B = opset5::Constant::create(element::f32, Shape{1, 512}, b_val);
+
+        auto rnn_sequence = std::make_shared<opset5::LSTMSequence>(X,
+                                                                   Y,
+                                                                   Z,
+                                                                   seq_lengths,
+                                                                   W,
+                                                                   R,
+                                                                   B,
+                                                                   128,
+                                                                   op::RecurrentSequenceDirection::FORWARD);
+        auto Y_out = std::make_shared<opset5::Result>(rnn_sequence->output(0));
+        auto Ho = std::make_shared<opset5::Result>(rnn_sequence->output(1));
+        auto Co = std::make_shared<opset5::Result>(rnn_sequence->output(2));
+        Y_out->set_friendly_name("Y_out");
+        Ho->set_friendly_name("Ho");
+        Co->set_friendly_name("Co");
+
+        f = std::make_shared<ov::Model>(NodeVector{Y_out, Ho, Co}, ParameterVector{X, Y, Z});
+
+        pass::Manager m;
+        m.register_pass<pass::InitNodeInfo>();
+        m.register_pass<pass::ConvertLSTMSequenceToTensorIterator>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{1, -1, 16});
+        auto Y = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 128});
+        auto Z = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 128});
+        auto squeeze_pattern = opset5::Constant::create(element::i64, Shape{1}, {1});
+        auto squeeze_y = std::make_shared<opset5::Squeeze>(Y, squeeze_pattern);
+        auto squeeze_z = std::make_shared<opset5::Squeeze>(Z, squeeze_pattern);
+
+        auto Xi = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 16});
+        auto Yi = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 128});
+        auto Zi = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 128});
+        auto seq_body_param = std::make_shared<opset5::Parameter>(element::i32, PartialShape{1});
+
+        // Body
+        auto squeeze_x = std::make_shared<opset5::Squeeze>(Xi, squeeze_pattern);
+
+        auto w_val = std::vector<float>(512 * 16, 0);
+        auto r_val = std::vector<float>(512 * 128, 0);
+        auto b_val = std::vector<float>(512, 0);
+        auto W = opset5::Constant::create(element::f32, Shape{512, 16}, w_val);
+        auto R = opset5::Constant::create(element::f32, Shape{512, 128}, r_val);
+        auto B = opset5::Constant::create(element::f32, Shape{512}, b_val);
+
+        auto rnn_cell = std::make_shared<opset5::LSTMCell>(squeeze_x, Yi, Zi, W, R, B, 128);
+
+        auto unsqueeze_pattern = opset5::Constant::create(element::i64, Shape{1}, {1});
+        auto Ho = std::make_shared<opset5::Result>(rnn_cell->output(0));
+
+        auto Co = std::make_shared<opset5::Result>(rnn_cell->output(1));
+
+        auto unsqueeze_y = std::make_shared<opset5::Unsqueeze>(rnn_cell->output(0), unsqueeze_pattern);
+        auto Y_out = std::make_shared<opset5::Result>(unsqueeze_y);
+
+        auto body = std::make_shared<ov::Model>(OutputVector{Y_out, Ho, Co}, ParameterVector{Xi, Yi, Zi, seq_body_param});
+
+        auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
+        tensor_iterator->set_body(body);
+
+        tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
+        tensor_iterator->get_concatenated_slices(Y_out, 0, 1, 1, -1, 1);
+
+        tensor_iterator->set_merged_input(Yi, squeeze_y, Ho);
+        tensor_iterator->set_merged_input(Zi, squeeze_z, Co);
+
+        auto shape_of = std::make_shared<opset5::ShapeOf>(X);
+        auto indices = opset5::Constant::create(element::i32, {1}, {1});
+        auto axis = opset5::Constant::create(element::i32, {}, {0});
+        auto seq_lengths = std::make_shared<opset5::Gather>(shape_of, indices, axis);
+        tensor_iterator->set_invariant_input(seq_body_param, seq_lengths);
+
+        tensor_iterator->get_iter_value(Ho);
+        tensor_iterator->get_iter_value(Co);
+
+        auto res_ti_Y = std::make_shared<opset5::Result>(
+            std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(0), unsqueeze_pattern));
+        auto res_ti_H = std::make_shared<opset5::Result>(
+            std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(1), unsqueeze_pattern));
+        auto res_ti_C = std::make_shared<opset5::Result>(
+            std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(2), unsqueeze_pattern));
+        res_ti_Y->set_friendly_name("Y_out");
+        res_ti_H->set_friendly_name("Ho");
+        res_ti_C->set_friendly_name("Co");
+        f_ref = std::make_shared<ov::Model>(NodeVector{res_ti_Y, res_ti_H, res_ti_C}, ParameterVector{X, Y, Z});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
\ No newline at end of file
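The new test covers the dynamic case; the static branch keeps its old semantics, comparing constant lengths against the now locally derived `max_seq_len` dimension. A small sketch of that behavior, with illustrative shapes and the include paths assumed from this patch:

```cpp
#include <memory>

#include <openvino/opsets/opset5.hpp>
#include <transformations/utils/utils.hpp>

// Static branch: X has max_seq_len == 10; a constant seq_lengths of {10}
// matches it, so no Select masking is required and the call returns false.
// A constant such as {7} would make it return true instead.
bool static_example_needs_mask() {
    auto X = std::make_shared<ov::opset5::Parameter>(ov::element::f32, ov::Shape{1, 10, 16});
    auto seq_lengths = ov::opset5::Constant::create(ov::element::i32, ov::Shape{1}, {10});
    return ov::op::util::is_seq_len_provided(X, seq_lengths);  // false for {10}
}
```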
diff --git a/src/plugins/intel_cpu/src/nodes/rnn.cpp b/src/plugins/intel_cpu/src/nodes/rnn.cpp
index f453b7a5a51e0b..158b1f65967215 100644
--- a/src/plugins/intel_cpu/src/nodes/rnn.cpp
+++ b/src/plugins/intel_cpu/src/nodes/rnn.cpp
@@ -318,8 +318,9 @@ bool RNN::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::s
             errorMessage = "Max sequence length dimension is dynamic";
             return false;
         }
-        auto maxSeqLen = data_pshape[maxSeqLenDimIdx].get_length();
-        if (ov::op::util::is_seq_len_provided(op->get_input_node_shared_ptr(seqLenIdx), maxSeqLen)) {
+
+        if (ov::op::util::is_seq_len_provided(op->get_input_node_shared_ptr(0),
+                                              op->get_input_node_shared_ptr(seqLenIdx))) {
             errorMessage = "Unsupported sequence length.";
             return false;
         }
diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
index 10275dae95d729..60976fecafccce 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
@@ -376,8 +376,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
                 return lstm_seq->get_clip() == 0.0f &&
                        lstm_seq->get_activations() == std::vector<std::string>{"sigmoid", "tanh", "tanh"} &&
                        max_seq_len < 16 &&
-                       !ov::op::util::is_seq_len_provided(lstm_seq->get_input_node_shared_ptr(3),
-                                                          max_seq_len);
+                       !ov::op::util::is_seq_len_provided(lstm_seq->get_input_node_shared_ptr(0),
+                                                          lstm_seq->get_input_node_shared_ptr(3));
             }
             return false;
         };
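Both plugin call sites now pass the data input `X` alongside the lengths input, letting the helper reason about a dynamic `max_seq_len` itself. An illustrative wrapper in the same style as the changes above, assuming the v5 `LSTMSequence` input order used here (input 0 is `X`, input 3 is `sequence_lengths`); the name `needs_seq_len_mask` is not part of the patch:

```cpp
#include <memory>

#include <openvino/op/lstm_sequence.hpp>
#include <transformations/utils/utils.hpp>

// True when per-batch lengths may differ from X's max_seq_len dimension,
// i.e. the TI conversion would need Select masking and the CPU/GPU fast
// paths above must bail out; false when the lengths are provably full.
bool needs_seq_len_mask(const std::shared_ptr<ov::op::v5::LSTMSequence>& lstm_seq) {
    return ov::op::util::is_seq_len_provided(lstm_seq->get_input_node_shared_ptr(0),
                                             lstm_seq->get_input_node_shared_ptr(3));
}
```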