diff --git a/src/common/transformations/include/transformations/op_conversions/convert_ti_to_sequences.hpp b/src/common/transformations/include/transformations/op_conversions/convert_ti_to_sequences.hpp index 4d8199c28709d2..e729d735c652d1 100644 --- a/src/common/transformations/include/transformations/op_conversions/convert_ti_to_sequences.hpp +++ b/src/common/transformations/include/transformations/op_conversions/convert_ti_to_sequences.hpp @@ -18,6 +18,8 @@ class TRANSFORMATIONS_API ConvertTensorIteratorToRNNSequence; class TRANSFORMATIONS_API ConvertTensorIteratorToGRUSequence; class TRANSFORMATIONS_API ConvertTensorIteratorToSequence; +class TRANSFORMATIONS_API ConvertLoopWithSlicedInputConcatOutputToLSTMSequence; +class TRANSFORMATIONS_API ConvertLoopWithScatterUpdateToLSTMSequence; class TRANSFORMATIONS_API ConvertLoopToLSTMSequence; class TRANSFORMATIONS_API FuseReverseLSTMSequence; @@ -68,14 +70,29 @@ class ov::pass::ConvertTensorIteratorToSequence : public GraphRewrite { ConvertTensorIteratorToSequence(); }; +class ov::pass::ConvertLoopWithSlicedInputConcatOutputToLSTMSequence : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ConvertLoopWithSlicedInputConcatOutputToLSTMSequence", "0"); + ConvertLoopWithSlicedInputConcatOutputToLSTMSequence(); +}; + +class ov::pass::ConvertLoopWithScatterUpdateToLSTMSequence : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ConvertLoopWithScatterUpdateToLSTMSequence", "0"); + ConvertLoopWithScatterUpdateToLSTMSequence(); +}; + /** * @ingroup ov_transformation_common_api * @brief Replaces Loop with LSTMCell inside to LSTMSequence */ -class ov::pass::ConvertLoopToLSTMSequence : public ov::pass::MatcherPass { +class ov::pass::ConvertLoopToLSTMSequence : public ov::pass::GraphRewrite { public: OPENVINO_RTTI("ConvertLoopToLSTMSequence", "0"); - ConvertLoopToLSTMSequence(); + ConvertLoopToLSTMSequence() { + add_matcher(); + add_matcher(); + } }; /** diff --git a/src/common/transformations/include/transformations/utils/utils.hpp b/src/common/transformations/include/transformations/utils/utils.hpp index a8818c7440e63c..a6114e8e07172a 100644 --- a/src/common/transformations/include/transformations/utils/utils.hpp +++ b/src/common/transformations/include/transformations/utils/utils.hpp @@ -114,6 +114,17 @@ inline std::string get_ie_output_name(const Output& output) { */ float cast_eps_to_float(double eps_d); +template +bool get_constant_value(const std::shared_ptr& node, T& value) { + auto constant = ov::as_type_ptr(node); + if (!constant) + return false; + if (shape_size(constant->get_shape()) != 1) + return false; + value = constant->cast_vector()[0]; + return true; +} + template bool has_constant_value(const std::shared_ptr& node, const T value, diff --git a/src/common/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp b/src/common/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp index d2cd621ab24c89..f901399287738b 100644 --- a/src/common/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp +++ b/src/common/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp @@ -35,6 +35,7 @@ #include "openvino/op/scatter_update.hpp" #include "openvino/op/shape_of.hpp" #include "openvino/op/squeeze.hpp" +#include "openvino/op/strided_slice.hpp" #include "openvino/op/tensor_iterator.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" @@ -43,6 +44,10 @@ #include "openvino/pass/pattern/op/wrap_type.hpp" #include 
"transformations/utils/utils.hpp" +using namespace ov; +using namespace ov::pass; +using namespace ov::op::util; + namespace { bool convertTensorIteratorToSequence(const std::shared_ptr& ti, const std::shared_ptr& found_cell, @@ -242,6 +247,233 @@ bool convertTensorIteratorToSequence(const std::shared_ptr& cond_result, + const std::shared_ptr& loop, + const ov::ParameterVector& body_params, + const ov::ResultVector& body_results, + const ov::op::util::MultiSubGraphOp::MultiSubgraphInputDescriptionVector& input_descriptions, + Output& num_iters_output) { + // return a number of iterations if pattern matches + // pattern for condition sub-graph in Loop operarion + auto num_iter_const_label = pattern::wrap_type(); + auto num_iter_param_label = pattern::wrap_type(); + auto num_iterations_label = + std::make_shared(OutputVector{num_iter_const_label, num_iter_param_label}); + auto counter_label = pattern::wrap_type(); + auto counter_step_label = pattern::wrap_type(); + auto updated_counter_label = pattern::wrap_type({counter_label, counter_step_label}); + auto less_label = pattern::wrap_type({updated_counter_label, num_iterations_label}); + auto condition_label = pattern::wrap_type({less_label}); + + // check a pattern of condition graph and get a number of iterations + ov::pass::pattern::Matcher condition_matcher(condition_label); + if (!condition_matcher.match(cond_result->output(0))) { + return false; + } + const auto& condition_map = condition_matcher.get_pattern_value_map(); + int64_t counter_step = -1; + if (!get_constant_value(condition_map.at(counter_step_label).get_node_shared_ptr(), counter_step) || + counter_step != 1) { + return false; + } + // get initial value of counter + int64_t initial_counter = -1; + auto counter_param = ov::as_type_ptr(condition_map.at(counter_label).get_node_shared_ptr()); + if (!counter_param) { + return false; + } + for (const auto& input_desc : input_descriptions) { + auto body_param_idx = input_desc->m_body_parameter_index; + auto input_idx = input_desc->m_input_index; + if (body_params[body_param_idx] != counter_param) { + continue; + } + // it must be merged input and incremented each iteration + auto merged_desc = ov::as_type_ptr(input_desc); + if (!merged_desc) { + return false; + } + auto result_idx = merged_desc->m_body_value_index; + auto update_counter = body_results[result_idx]->input_value(0).get_node_shared_ptr(); + if (update_counter != condition_map.at(updated_counter_label).get_node_shared_ptr()) { + return false; + } + // get initial value of counter + if (!get_constant_value(loop->input_value(input_idx).get_node_shared_ptr(), initial_counter)) { + return false; + } + // suitable counter-parameter is found and checked + break; + } + if (initial_counter != 0) { + return false; + } + // retrieve a number of iterations + int64_t num_iters = -1; + if (condition_map.count(num_iter_param_label)) { + auto num_iter_param = condition_map.at(num_iter_param_label).get_node_shared_ptr(); + for (const auto& input_desc : input_descriptions) { + auto body_param_idx = input_desc->m_body_parameter_index; + auto input_idx = input_desc->m_input_index; + if (body_params[body_param_idx] != num_iter_param) { + continue; + } + num_iters_output = loop->input_value(input_idx); + break; + } + } else if (condition_map.count(num_iter_const_label)) { + if (!get_constant_value(condition_map.at(num_iter_const_label).get_node_shared_ptr(), num_iters) || + num_iters < 1) { + return false; + } + num_iters_output = std::make_shared(element::i64, Shape{}, num_iters); + } else 
{ + return false; + } + + return true; +} + +bool check_condition_true_pattern(const std::shared_ptr& cond_result, + const std::shared_ptr& loop, + Output& num_iters_output) { + // return a number of iterations if pattern matches + auto cond_const_label = pattern::wrap_type(); + auto condition_label = pattern::wrap_type({cond_const_label}); + + // check a pattern of condition graph and get a number of iterations + ov::pass::pattern::Matcher condition_matcher(condition_label); + if (!condition_matcher.match(cond_result->output(0))) { + return false; + } + const auto& condition_map = condition_matcher.get_pattern_value_map(); + const auto& cond_const = + ov::as_type_ptr(condition_map.at(cond_const_label).get_node_shared_ptr()); + if (!cond_const) { + return false; + } + if (ov::shape_size(cond_const->get_shape()) != 1) + return false; + const auto& type = cond_const->get_output_element_type(0); + if (type != ov::element::boolean) { + return false; + } + bool cond_value = cond_const->cast_vector()[0]; + if (!cond_value) { + return false; + } + + // number of iteration is retrieve from the first input port + num_iters_output = loop->input_value(0); + + return true; +} + +bool check_lstm_cell_pattern( + const std::shared_ptr& loop, + const ov::ParameterVector& body_params, + const ov::ResultVector& body_results, + const ov::op::util::MultiSubGraphOp::MultiSubgraphInputDescriptionVector& input_descriptions, + ov::Output& init_hidden_state, + ov::Output& init_cell_state, + ov::Output& x, + ov::Output& W, + ov::Output& R, + ov::Output& B, + std::shared_ptr& result_hidden_state, + std::shared_ptr& lstm_cell_node) { + // check pattern with LSMCell and return key points + // required for fusion to LSTMSequence + // pattern for LSTMCell in the body + auto xi_label = pattern::wrap_type(); + auto squeeze_axis_label = pattern::wrap_type(); + auto xi_reshape_label = pattern::wrap_type({xi_label, squeeze_axis_label}); + auto init_hidden_state_i_label = pattern::wrap_type(); + auto init_cell_state_i_label = pattern::wrap_type(); + auto W_label = pattern::wrap_type(); + auto R_label = pattern::wrap_type(); + auto B_label = pattern::wrap_type(); + + auto lstm_cell_label = pattern::wrap_type( + {xi_reshape_label, init_hidden_state_i_label, init_cell_state_i_label, W_label, R_label, B_label}); + auto unsqueeze_axis_label = pattern::wrap_type(); + auto unsqueeze_hidden_state_label = pattern::wrap_type({lstm_cell_label, unsqueeze_axis_label}); + auto result_hidden_state_label = pattern::wrap_type({unsqueeze_hidden_state_label}); + + // check that body-graph contains a pattern corresponding LSTMCell + ov::pass::pattern::Matcher lstm_cell_matcher(result_hidden_state_label); + int64_t unsqueeze_hidden_state_result_idx = -1; + for (int64_t result_idx = 0; result_idx < static_cast(body_results.size()); ++result_idx) { + if (lstm_cell_matcher.match(body_results[result_idx]->output(0))) { + unsqueeze_hidden_state_result_idx = result_idx; + auto reshape_node = + lstm_cell_matcher.get_pattern_value_map()[unsqueeze_hidden_state_label].get_node_shared_ptr(); + break; + } + } + if (unsqueeze_hidden_state_result_idx < 0) { + // not found LSTMCell inside Loop operation + return false; + } + const auto& lstm_cell_map = lstm_cell_matcher.get_pattern_value_map(); + + // check that Results with hidden and cell states connected with corresponding + // Parameter nodes with back edges and found initial hidden and cell states + lstm_cell_node = ov::as_type_ptr(lstm_cell_map.at(lstm_cell_label).get_node_shared_ptr()); + if 
(!lstm_cell_node) { + return false; + } + result_hidden_state = + ov::as_type_ptr(lstm_cell_map.at(result_hidden_state_label).get_node_shared_ptr()); + if (!result_hidden_state) { + return false; + } + for (const auto& input_desc : input_descriptions) { + auto param_idx = input_desc->m_body_parameter_index; + auto input_idx = input_desc->m_input_index; + if (body_params[param_idx] == lstm_cell_map.at(init_hidden_state_i_label).get_node_shared_ptr()) { + // hidden state Parameter node + auto merged_desc = ov::as_type_ptr(input_desc); + if (!merged_desc) { + return false; + } + const auto& hidden_state_result = body_results[merged_desc->m_body_value_index]; + if ((hidden_state_result->get_input_node_shared_ptr(0) != lstm_cell_node) || + (hidden_state_result->input_value(0).get_index() != 0)) { + return false; + } + init_hidden_state = loop->input_value(input_idx); + } else if (body_params[param_idx] == lstm_cell_map.at(init_cell_state_i_label).get_node_shared_ptr()) { + // cell state Parameter node + auto merged_desc = ov::as_type_ptr(input_desc); + if (!merged_desc) { + return false; + } + const auto& cell_state_result = body_results[merged_desc->m_body_value_index]; + if ((cell_state_result->get_input_node_shared_ptr(0) != lstm_cell_node) || + (cell_state_result->input_value(0).get_index() != 1)) { + return false; + } + init_cell_state = loop->input_value(input_idx); + } else if (body_params[param_idx] == lstm_cell_map.at(xi_label).get_node_shared_ptr()) { + // input data Parameter node + auto sliced_desc = ov::as_type_ptr(input_desc); + if (!sliced_desc || (sliced_desc->m_axis != 0) || (sliced_desc->m_start != 0) || + (sliced_desc->m_stride != 1) || (sliced_desc->m_end != -1)) { + return false; + } + x = loop->input_value(input_idx); + } + } + + W = lstm_cell_map.at(W_label); + R = lstm_cell_map.at(R_label); + B = lstm_cell_map.at(B_label); + + return true; +} } // namespace ov::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSequence() { @@ -423,19 +655,6 @@ ov::pass::ConvertTensorIteratorToGRUSequence::ConvertTensorIteratorToGRUSequence register_matcher(m, callback); } -static bool get_scalar_constant_value(const ov::Output& node, int64_t& output_value) { - auto constant = ov::as_type(node.get_node()); - if (!constant) - return false; - if (ov::shape_size(constant->get_shape()) != 1) - return false; - const auto& type = constant->get_output_element_type(0); - if (type != ov::element::i32 && type != ov::element::i64) - return false; - output_value = constant->cast_vector()[0]; - return true; -} - // clang-format off /* @@ -511,8 +730,8 @@ static bool get_scalar_constant_value(const ov::Output& node, int64_t& */ // clang-format on -ov::pass::ConvertLoopToLSTMSequence::ConvertLoopToLSTMSequence() { - MATCHER_SCOPE(ConvertLoopToLSTMSequence); +ov::pass::ConvertLoopWithScatterUpdateToLSTMSequence::ConvertLoopWithScatterUpdateToLSTMSequence() { + MATCHER_SCOPE(ConvertLoopWithScatterUpdateToLSTMSequence); auto input_label = pattern::any_input(pattern::rank_equals(3)); auto input_transpose_const_label = pattern::wrap_type(); auto input_transpose_label = @@ -612,28 +831,34 @@ ov::pass::ConvertLoopToLSTMSequence::ConvertLoopToLSTMSequence() { const auto& loop_output_map = loop_output_matcher.get_pattern_value_map(); int64_t iteration_counter_step = -1; - if (!get_scalar_constant_value(loop_condition_map.at(iteration_counter_step_label), iteration_counter_step) || + if (!get_constant_value(loop_condition_map.at(iteration_counter_step_label).get_node_shared_ptr(), + 
iteration_counter_step) || iteration_counter_step != 1) return false; int64_t sequence_index_step = -1; - if (!get_scalar_constant_value(loop_condition_map.at(sequence_index_step_label), sequence_index_step) || + if (!get_constant_value(loop_condition_map.at(sequence_index_step_label).get_node_shared_ptr(), + sequence_index_step) || sequence_index_step != 1) return false; int64_t iteration_counter_limit = -1; - if (!get_scalar_constant_value(loop_condition_map.at(iteration_counter_limit_label), iteration_counter_limit)) + if (!get_constant_value(loop_condition_map.at(iteration_counter_limit_label).get_node_shared_ptr(), + iteration_counter_limit)) return false; int64_t sequence_index_limit = -1; - if (!get_scalar_constant_value(loop_condition_map.at(sequence_index_limit_label), sequence_index_limit)) + if (!get_constant_value(loop_condition_map.at(sequence_index_limit_label).get_node_shared_ptr(), + sequence_index_limit)) return false; if (iteration_counter_limit != sequence_index_limit) return false; int64_t gather_axis = -1; - if (!get_scalar_constant_value(loop_output_map.at(gather_axis_label), gather_axis) || gather_axis != 0) + if (!get_constant_value(loop_output_map.at(gather_axis_label).get_node_shared_ptr(), gather_axis) || + gather_axis != 0) return false; int64_t scatter_axis = -1; - if (!get_scalar_constant_value(loop_output_map.at(scatter_axis_label), scatter_axis) || scatter_axis != 0) + if (!get_constant_value(loop_output_map.at(scatter_axis_label).get_node_shared_ptr(), scatter_axis) || + scatter_axis != 0) return false; const auto& sequence_index = loop_condition_map.at(sequence_index_label).get_node_shared_ptr(); @@ -830,6 +1055,219 @@ ov::pass::ConvertLoopToLSTMSequence::ConvertLoopToLSTMSequence() { register_matcher(m, callback); } +// clang-format off +/* + + Following subgraph in Loop is fused into LSTMSequence + + +------------------------------+ + | X | +------+ + | (sliced input) | | axis | + | [1, batch, input_size] | | {0} | + +--------------+---------------+ +--+---+ + | | + | | + +---+ | + | +---------------------------+ + | | + | | + v v +----------------------+ +----------------------+ + +---+------+------+---+ | H | | C | + | Squeeze | | (merged with H_out) | | (merged with C_out) | +-----+ +-----+ +-----+ + | [batch, input_size] | | [batch, hidden_size] | | [batch, hidden_size] | | W | | R | | B | + +----------+----------+ +----------+-----------+ +----------+-----------+ +--+--+ +--+--+ +--+--+ + | | | | | | + | | +---------------+ | | | + | | | | | | + | | | +------------------------------+ | | + | | | | | | + | +------+ | | +------------------------------------+ | + +----------------------------+ | | | | +------------------------------------------+ + | | | | | | + v v v v v v + +---+----+----+----+----+----+---+ + | LSTMCell | + +--------+-------------------+---+ + | | + | | + +----------+---------------+ | + | | +---------------------+ + | | | + v v v + +------------+------------+ +---------+------------+ +--------+--------+ + | Unsqueeze, axis=0 | | H_out | | C_out | + | [batch, 1, hidden_size] | | (merged with H) | | (merged with C) | + | (concat output) | | [batch, hidden_size] | | | + +-------------------------+ +----------------------+ +-----------------+ + +*/ +// clang-format on + +ov::pass::ConvertLoopWithSlicedInputConcatOutputToLSTMSequence::ConvertLoopWithSlicedInputConcatOutputToLSTMSequence() { + MATCHER_SCOPE(ConvertLoopWithSlicedInputConcatOutputToLSTMSequence); + + auto loop_label = pattern::wrap_type(); + + matcher_pass_callback callback 
= [=](pattern::Matcher& m) { + auto loop = ov::as_type_ptr(m.get_match_root()); + if (!loop) { + return false; + } + + const auto& body_graph = loop->get_function(); + const auto& body_params = body_graph->get_parameters(); + const auto& body_results = body_graph->get_results(); + const auto& special_body_ports = loop->get_special_body_ports(); + const auto& input_descriptions = loop->get_input_descriptions(); + const auto& output_descriptions = loop->get_output_descriptions(); + if (special_body_ports.body_condition_output_idx < 0) { + return false; + } + + // check condition pattern and retrieve a number of iterations + ov::Output num_iters_output; + if (!check_condition_increment_pattern(body_results[special_body_ports.body_condition_output_idx], + loop, + body_params, + body_results, + input_descriptions, + num_iters_output) && + !check_condition_true_pattern(body_results[special_body_ports.body_condition_output_idx], + loop, + num_iters_output)) { + return false; + } + + // check body pattern with LSTMCell + // and extract key points for LSTMSequence creation + ov::Output init_hidden_state, init_cell_state, x, W, R, B; + std::shared_ptr result_hidden_state; + std::shared_ptr lstm_cell_node; + if (!check_lstm_cell_pattern(loop, + body_params, + body_results, + input_descriptions, + init_hidden_state, + init_cell_state, + x, + W, + R, + B, + result_hidden_state, + lstm_cell_node)) { + return false; + } + + // check that only output with concatenated hidden states from Loop is needed + // or get hidden state from last iteration + std::vector last_iter_hidden_state_output_ids; + std::vector all_hidden_states_output_ids; + for (const auto& output_desc : output_descriptions) { + auto result_idx = output_desc->m_body_value_index; + auto output_idx = output_desc->m_output_index; + if (body_results[result_idx] != result_hidden_state) { + if (loop->get_output_target_inputs(output_idx).size() > 0) { + // some other Result node value from body graph is required by real consumers + return false; + } + continue; + } + auto concat_desc = ov::as_type_ptr(output_desc); + if (concat_desc) { + // all hidden states are concatenated + if ((concat_desc->m_axis != 0) || (concat_desc->m_start != 0) || (concat_desc->m_end != -1) || + (concat_desc->m_stride != 1)) { + return false; + } + all_hidden_states_output_ids.push_back(output_idx); + continue; + } + last_iter_hidden_state_output_ids.push_back(output_idx); + } + + // all input and output ports are almost ready + // preparation of the right format of inputs and outputs is remained before reconnection + NodeRegistry rg; + // x is in a format [seq_len, batch_size, input_size] + auto tr_order = rg.make(element::i32, Shape{3}, std::vector{1, 0, 2}); + auto x_prep = rg.make(x, tr_order); + // initial hidden and cell states are of shape [batch_size, hidden_size] + // needs to add num_directions size + auto unsqueeze_axis = rg.make(element::i32, Shape{1}, std::vector{1}); + auto init_hidden_state_prep = rg.make(init_hidden_state, unsqueeze_axis); + auto init_cell_state_prep = rg.make(init_cell_state, unsqueeze_axis); + // sequence length is currently scalar and it needs broadcasting to shape [batch_size] + auto x_shape = rg.make(x, element::i64); + auto ss_start = rg.make(element::i64, Shape{1}, 1); + auto ss_stop = rg.make(element::i64, Shape{1}, 2); + auto ss_step = rg.make(element::i64, Shape{1}, 1); + auto batch_size = rg.make(x_shape, + ss_start, + ss_stop, + ss_step, + std::vector{0}, + std::vector{0}); + auto seq_lens_prep = rg.make(num_iters_output, 
batch_size); + // prepare W, R, B to add num_directions dimension + unsqueeze_axis = rg.make(element::i32, Shape{1}, std::vector{0}); + auto W_prep = rg.make(W, unsqueeze_axis); + auto R_prep = rg.make(R, unsqueeze_axis); + auto B_prep = rg.make(B, unsqueeze_axis); + + // create LSTMSequence operation since all inputs are prepared + // extract hidden size that should be static as a requirement of OpenVINO LSTMCell operation + int64_t hidden_size = static_cast(lstm_cell_node->get_hidden_size()); + auto lstm_sequence = rg.make(x_prep, + init_hidden_state_prep, + init_cell_state_prep, + seq_lens_prep, + W_prep, + R_prep, + B_prep, + hidden_size, + ov::op::RecurrentSequenceDirection::FORWARD, + lstm_cell_node->get_activations_alpha(), + lstm_cell_node->get_activations_beta(), + lstm_cell_node->get_activations(), + lstm_cell_node->get_clip()); + + if (transformation_callback(lstm_sequence)) + return false; + + // prepare outputs of LSTMSequence + // output with concatenated hidden states must be in a format [seq_len, batch_size, hidden_size] + // LSTMSequence generates it in a format [batch_size, num_directions, seq_len, hidden_size] + auto squeeze_axis = rg.make(element::i32, Shape{1}, std::vector{1}); + ov::Output all_hidden_states_prep = rg.make(lstm_sequence->output(0), squeeze_axis); + all_hidden_states_prep = rg.make(all_hidden_states_prep, tr_order); + // prepare output with last hidden state + // LSTMSequence operation outputs it in a format [batch_size, num_directions, hidden_size] + auto tr_order2 = rg.make(element::i32, Shape{3}, std::vector{1, 0, 2}); + ov::Output last_hidden_state_prep = + rg.make(lstm_sequence->output(1), tr_order2); + + // reconnect all consumers for hidden states + for (auto output_idx : all_hidden_states_output_ids) { + loop->output(output_idx).replace(all_hidden_states_prep); + } + for (auto output_idx : last_iter_hidden_state_output_ids) { + loop->output(output_idx).replace(last_hidden_state_prep); + } + + copy_runtime_info(m.get_matched_nodes(), rg.get()); + if ((all_hidden_states_output_ids.size() > 0) && (last_iter_hidden_state_output_ids.size() == 0)) { + all_hidden_states_prep.get_node_shared_ptr()->set_friendly_name(loop->get_friendly_name()); + } else if ((all_hidden_states_output_ids.size() == 0) && (last_iter_hidden_state_output_ids.size() > 0)) { + last_hidden_state_prep.get_node_shared_ptr()->set_friendly_name(loop->get_friendly_name()); + } + + return true; + }; + + auto m = std::make_shared(loop_label, matcher_name); + register_matcher(m, callback); +} + class EliminateGatherWithRange : public ov::pass::MatcherPass { public: EliminateGatherWithRange() { @@ -868,9 +1306,9 @@ class EliminateGatherWithRange : public ov::pass::MatcherPass { const auto shapeof_gather2 = pattern_map.at(shapeof_gather2_label).get_node_shared_ptr(); int64_t shapeof_gather2_index = -1; int64_t shapeof_gather2_axis = -1; - if (!get_scalar_constant_value(shapeof_gather2->get_input_node_shared_ptr(1), shapeof_gather2_index)) + if (!get_constant_value(shapeof_gather2->get_input_node_shared_ptr(1), shapeof_gather2_index)) return false; - if (!get_scalar_constant_value(shapeof_gather2->get_input_node_shared_ptr(2), shapeof_gather2_axis) || + if (!get_constant_value(shapeof_gather2->get_input_node_shared_ptr(2), shapeof_gather2_axis) || shapeof_gather2_axis != 0) return false; const auto reshape = pattern_map.at(reshape_label).get_node_shared_ptr(); @@ -880,13 +1318,13 @@ class EliminateGatherWithRange : public ov::pass::MatcherPass { const auto range = 
pattern_map.at(range_label).get_node_shared_ptr(); int64_t range_start = -1; int64_t range_step = -1; - if (!get_scalar_constant_value(range->get_input_node_shared_ptr(0), range_start) || range_start != 0) + if (!get_constant_value(range->get_input_node_shared_ptr(0), range_start) || range_start != 0) return false; - if (!get_scalar_constant_value(range->get_input_node_shared_ptr(2), range_step) || range_step != 1) + if (!get_constant_value(range->get_input_node_shared_ptr(2), range_step) || range_step != 1) return false; int64_t gather_axis = -1; - if (!get_scalar_constant_value(gather->get_input_node_shared_ptr(2), gather_axis) || + if (!get_constant_value(gather->get_input_node_shared_ptr(2), gather_axis) || gather_axis != shapeof_gather_indexes[shapeof_gather2_index]) return false; @@ -1005,7 +1443,7 @@ ov::pass::FuseReverseLSTMSequence::FuseReverseLSTMSequence() { if (squeeze->input_value(0) != lstm->output(0)) return false; int64_t squeeze_axis = -1; - if (!get_scalar_constant_value(squeeze->get_input_node_shared_ptr(1), squeeze_axis) || squeeze_axis != 1) + if (!get_constant_value(squeeze->get_input_node_shared_ptr(1), squeeze_axis) || squeeze_axis != 1) return false; auto new_squeeze = node_registry.make(new_lstm->output(0), squeeze->input_value(1)); const auto match_root = m.get_match_root(); @@ -1122,7 +1560,7 @@ ov::pass::FuseLSTMSequencesToBidirectionalLSTMSequence::FuseLSTMSequencesToBidir if (squeeze_forward->input_value(0) != lstm_forward->output(0)) return false; int64_t squeeze_forward_axis = -1; - if (!get_scalar_constant_value(squeeze_forward->get_input_node_shared_ptr(1), squeeze_forward_axis) || + if (!get_constant_value(squeeze_forward->get_input_node_shared_ptr(1), squeeze_forward_axis) || squeeze_forward_axis != 1) return false; @@ -1130,7 +1568,7 @@ ov::pass::FuseLSTMSequencesToBidirectionalLSTMSequence::FuseLSTMSequencesToBidir if (squeeze_reverse->input_value(0) != lstm_reverse->output(0)) return false; int64_t squeeze_reverse_axis = -1; - if (!get_scalar_constant_value(squeeze_reverse->get_input_node_shared_ptr(1), squeeze_reverse_axis) || + if (!get_constant_value(squeeze_reverse->get_input_node_shared_ptr(1), squeeze_reverse_axis) || squeeze_reverse_axis != 1) return false; @@ -1180,21 +1618,17 @@ ov::pass::FuseLSTMSequencesToBidirectionalLSTMSequence::FuseLSTMSequencesToBidir auto gather_forward = pattern_map.at(gather_forward_label); int64_t gather_index = -1; int64_t gather_axis = -1; - if (!get_scalar_constant_value(gather_forward->get_input_node_shared_ptr(1), gather_index) || - gather_index != 0) + if (!get_constant_value(gather_forward->get_input_node_shared_ptr(1), gather_index) || gather_index != 0) return false; - if (!get_scalar_constant_value(gather_forward->get_input_node_shared_ptr(2), gather_axis) || - gather_axis != 0) + if (!get_constant_value(gather_forward->get_input_node_shared_ptr(2), gather_axis) || gather_axis != 0) return false; auto gather_reverse = pattern_map.at(gather_reverse_label); gather_index = -1; gather_axis = -1; - if (!get_scalar_constant_value(gather_reverse->get_input_node_shared_ptr(1), gather_index) || - gather_index != 0) + if (!get_constant_value(gather_reverse->get_input_node_shared_ptr(1), gather_index) || gather_index != 0) return false; - if (!get_scalar_constant_value(gather_reverse->get_input_node_shared_ptr(2), gather_axis) || - gather_axis != 0) + if (!get_constant_value(gather_reverse->get_input_node_shared_ptr(2), gather_axis) || gather_axis != 0) return false; from.push_back(max_sequence_len_forward); diff 
--git a/src/common/transformations/tests/op_conversions/convert_ti_to_sequences_test.cpp b/src/common/transformations/tests/op_conversions/convert_ti_to_sequences_test.cpp index e03f821c55584a..0115abc6779091 100644 --- a/src/common/transformations/tests/op_conversions/convert_ti_to_sequences_test.cpp +++ b/src/common/transformations/tests/op_conversions/convert_ti_to_sequences_test.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include "common_test_utils/ov_test_utils.hpp" @@ -1265,3 +1266,197 @@ TEST_P(FuseLSTMSequencesToBidirectionalLSTMSequenceTest, FusionTest) { INSTANTIATE_TEST_SUITE_P(FuseLSTMSequencesToBidirectionalLSTMSequence, FuseLSTMSequencesToBidirectionalLSTMSequenceTest, testing::Combine(testing::Values(false, true), testing::Values(false, true))); + +using LoopWithLSTMCellToLSTMSequenceFusionParam = std::tuple; // hidden size + +class LoopWithLSTMCellToLSTMSequenceFusionTest + : public testing::WithParamInterface, + public TransformationTestsF {}; + +namespace { +void generate_weights_value(std::vector& weights_value, const Shape& weights_shape) { + weights_value.resize(shape_size(weights_shape)); + std::mt19937 rng(9812); + std::uniform_real_distribution distribution(-300, 300); + for (size_t i = 0; i < weights_value.size(); ++i) { + weights_value[i] = distribution(rng); + } +} +} // namespace + +TEST_P(LoopWithLSTMCellToLSTMSequenceFusionTest, FusionTest) { + const auto& param = GetParam(); + const std::string& f_activation = std::get<0>(param); + const std::string& g_activation = std::get<1>(param); + const std::string& h_activation = std::get<2>(param); + size_t input_size = std::get<3>(param); + size_t hidden_size = std::get<4>(param); + size_t batch_size = 2; + size_t time_len = 10; + + // generate weights values + // w must be of a shape [input_size, hidden_size] + // r must be of a shape [hidden_size, hidden_size] + // b must be of a shape [hidden_size] + Shape w_shape({4 * hidden_size, input_size}); + Shape r_shape({4 * hidden_size, hidden_size}); + Shape b_shape({4 * hidden_size}); + std::vector w, r, b; + generate_weights_value(w, w_shape); + generate_weights_value(r, r_shape); + generate_weights_value(b, b_shape); + + { + // create body graph with LSTMCell + auto xi = std::make_shared(element::f32, Shape{1, batch_size, input_size}); + auto squeeze_axis = std::make_shared(element::i64, Shape{}, 0); + auto xi_squeeze = std::make_shared(xi, squeeze_axis); + auto init_hidden_state = std::make_shared(element::f32, Shape{batch_size, hidden_size}); + auto init_cell_state = std::make_shared(element::f32, Shape{batch_size, hidden_size}); + auto w_const = op::v0::Constant::create(element::f32, w_shape, w); + auto r_const = op::v0::Constant::create(element::f32, r_shape, r); + auto b_const = op::v0::Constant::create(element::f32, b_shape, b); + auto lstm_cell = + std::make_shared(xi_squeeze, + init_hidden_state, + init_cell_state, + w_const, + r_const, + b_const, + hidden_size, + std::vector{f_activation, g_activation, h_activation}); + + auto hidden_state_res = std::make_shared(lstm_cell->output(0)); + auto cell_state_res = std::make_shared(lstm_cell->output(1)); + auto unsqueeze_axis = std::make_shared(element::i64, Shape{}, 0); + auto unsqueeze_hidden_state = std::make_shared(lstm_cell->output(0), unsqueeze_axis); + auto unsqueeze_hidden_state_res = std::make_shared(unsqueeze_hidden_state); + + // conditional graph + auto num_iters = std::make_shared(element::i32, Shape{1}); + auto counter = std::make_shared(element::i32, Shape{1}); + auto increment = 
std::make_shared(element::i32, Shape{}, 1); + auto add = std::make_shared(counter, increment); + auto updated_counter = std::make_shared(add); + auto less = std::make_shared(add, num_iters); + auto less_res = std::make_shared(less); + + auto body_graph = std::make_shared( + ResultVector{hidden_state_res, cell_state_res, unsqueeze_hidden_state_res, less_res, updated_counter}, + ParameterVector{xi, init_hidden_state, init_cell_state, num_iters, counter}); + + // create main graph with Loop + auto x = std::make_shared(element::f32, Shape{time_len, batch_size, input_size}); + auto h_init = std::make_shared(element::f32, Shape{batch_size, hidden_size}); + auto c_init = std::make_shared(element::f32, Shape{batch_size, hidden_size}); + auto execution_cond = std::make_shared(ov::element::boolean, ov::Shape{}, true); + auto max_iter = std::make_shared(ov::element::i32, ov::Shape{1}, -1); + auto num_iter_const = + std::make_shared(ov::element::i32, ov::Shape{1}, static_cast(time_len)); + auto counter_const = std::make_shared(ov::element::i32, ov::Shape{1}, 0); + + auto loop_node = std::make_shared(max_iter, execution_cond); + + loop_node->set_function(body_graph); + loop_node->set_special_body_ports(ov::op::v5::Loop::SpecialBodyPorts{-1, 3}); + + // set inputs for Loop + // x input will be sliced for each time step + loop_node->set_sliced_input(xi, x, 0, 1, 1, -1, 0); + // set back edges for cell and hidden states + // since they are changing through timeline + loop_node->set_merged_input(init_hidden_state, h_init, hidden_state_res); + loop_node->set_merged_input(init_cell_state, c_init, cell_state_res); + loop_node->set_invariant_input(num_iters, num_iter_const); + loop_node->set_merged_input(counter, counter_const, updated_counter); + + // set external outputs for Loop node + // concatenated cell and hidden states from all time steps + auto hs = loop_node->get_concatenated_slices(unsqueeze_hidden_state_res, 0, 1, 1, -1, 0); + auto hs_res = std::make_shared(hs); + + model = std::make_shared(ResultVector{hs_res}, ParameterVector{x, h_init, c_init}); + manager.register_pass(); + } + + { + auto x = std::make_shared(element::f32, Shape{time_len, batch_size, input_size}); + auto h_init = std::make_shared(element::f32, Shape{batch_size, hidden_size}); + auto c_init = std::make_shared(element::f32, Shape{batch_size, hidden_size}); + + // transpose x since LSTMSequence expects x in a format [batch_size, time_len, input_size] + auto tr_order = + std::make_shared(ov::element::i32, ov::Shape{3}, std::vector{1, 0, 2}); + auto tr_x = std::make_shared(x, tr_order); + // prepare init hidden and cell states to have a format [batch_size, num_directions, hidden_size] + // where num_directions equals one + auto unsqueeze_axis = + std::make_shared(ov::element::i32, ov::Shape{1}, std::vector{1}); + auto h_init_unsqueeze = std::make_shared(h_init, unsqueeze_axis); + auto c_init_unsqueeze = std::make_shared(c_init, unsqueeze_axis); + // prepare seq_lens + auto batch_size = std::make_shared(x, element::i64)->output(0); + auto begin = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{1}); + auto end = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{2}); + auto stride = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{1}); + batch_size = std::make_shared(batch_size, + begin, + end, + stride, + std::vector{0}, + std::vector{0}); + auto num_iter_const = + std::make_shared(ov::element::i32, ov::Shape{1}, static_cast(time_len)); + auto seq_lens = std::make_shared(num_iter_const, batch_size); 
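+        // For reference (illustrative comment, values taken from this test): with
+        // time_len = 10 and batch_size = 2, the StridedSlice above extracts the batch
+        // dimension {2} from the x shape [10, 2, input_size], and Broadcast expands the
+        // scalar iteration count to seq_lens = [10, 10], one full-length entry per
+        // batch element, which is exactly what the fused pass builds for LSTMSequence.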
+ // prepare W, R, B weights to a format with num_directions dimension + auto w_const = op::v0::Constant::create(element::f32, w_shape, w); + auto r_const = op::v0::Constant::create(element::f32, r_shape, r); + auto b_const = op::v0::Constant::create(element::f32, b_shape, b); + auto unsqueeze_axis2 = + std::make_shared(ov::element::i32, ov::Shape{1}, std::vector{0}); + auto w = std::make_shared(w_const, unsqueeze_axis2); + auto r = std::make_shared(r_const, unsqueeze_axis2); + auto b = std::make_shared(b_const, unsqueeze_axis2); + + // create LSTMSequence + auto lstm_sequence = std::make_shared( + tr_x, + h_init_unsqueeze, + c_init_unsqueeze, + seq_lens, + w, + r, + b, + hidden_size, + ov::op::RecurrentSequenceDirection::FORWARD, + std::vector{}, + std::vector{}, + std::vector{f_activation, g_activation, h_activation}, + 0.0f); + + // prepare output + auto squeeze_axis = std::make_shared(ov::element::i32, ov::Shape{1}, 1); + auto squeeze_output_hs = std::make_shared(lstm_sequence->output(0), squeeze_axis); + auto tr_order2 = + std::make_shared(ov::element::i32, ov::Shape{3}, std::vector{1, 0, 2}); + auto tr_squeeze_output_hs = std::make_shared(squeeze_output_hs, tr_order2); + auto output_hs_res = std::make_shared(tr_squeeze_output_hs); + model_ref = std::make_shared(ResultVector{output_hs_res}, ParameterVector{x, h_init, c_init}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); +} + +INSTANTIATE_TEST_SUITE_P(LoopWithLSTMCellToLSTMSequenceFusion, + LoopWithLSTMCellToLSTMSequenceFusionTest, + testing::Combine(testing::Values("sigmoid", "tanh"), + testing::Values("sigmoid", "relu"), + testing::Values("tanh", "relu"), + testing::Values(2, 3), + testing::Values(3, 4))); diff --git a/src/frontends/tensorflow/src/frontend.cpp b/src/frontends/tensorflow/src/frontend.cpp index 2c49cad537c1e5..412245a7edc53a 100644 --- a/src/frontends/tensorflow/src/frontend.cpp +++ b/src/frontends/tensorflow/src/frontend.cpp @@ -27,6 +27,7 @@ #include "openvino/util/file_util.hpp" #include "openvino/util/log.hpp" #include "tf_framework_node.hpp" +#include "transformations/common_optimizations/eliminate_loop_inputs_outputs.hpp" #include "transformations/common_optimizations/remove_concat_zero_dim_input.hpp" #include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp" #include "transformations/control_flow/unroll_if.hpp" @@ -568,7 +569,16 @@ void FrontEnd::normalize(const std::shared_ptr& model) const { manager.register_pass(); manager.register_pass(); manager.register_pass(); - manager.register_pass(); + + // apply EliminateLoopInputsOutputs to avoid extra Results + // that output the same value as receiving on input + // it is needed for applying TensorListInLoopOptimization + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/src/frontends/tensorflow/src/op/block_lstm.cpp b/src/frontends/tensorflow/src/op/block_lstm.cpp index 663f50794476c5..7498f6e2702316 100644 --- a/src/frontends/tensorflow/src/op/block_lstm.cpp +++ b/src/frontends/tensorflow/src/op/block_lstm.cpp @@ -71,7 +71,7 @@ void create_decomposed_block_lstm(const Output& x, auto squeeze_axis = std::make_shared(element::i32, Shape{1}, 0); auto xi = std::make_shared(xi_param, squeeze_axis); - auto lstm_cell = 
std::make_shared(xi, + auto lstm_cell = std::make_shared(xi, h_prev_param, c_prev_param, w_param, diff --git a/src/frontends/tensorflow_common/include/helper_transforms/tensor_list_ops_resolver.hpp b/src/frontends/tensorflow_common/include/helper_transforms/tensor_list_ops_resolver.hpp index b7a3bc5b5a9891..2ef83a6c3682f5 100644 --- a/src/frontends/tensorflow_common/include/helper_transforms/tensor_list_ops_resolver.hpp +++ b/src/frontends/tensorflow_common/include/helper_transforms/tensor_list_ops_resolver.hpp @@ -4,9 +4,6 @@ #pragma once -#include -#include - #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pass.hpp" @@ -36,16 +33,12 @@ class TensorListGetItemReplacer : public ov::pass::MatcherPass { TensorListGetItemReplacer(); }; -// Replace and optimize sub-graphs with TensorList operations such as TensorListReserve, -// TensorListSetItem, TensorListGetItem -class TensorListOperationsResolver : public ov::pass::GraphRewrite { +// Optimize sub-graphs with TensorList operations in Loop body graph +// Replace TensorListSetItem and TensorListGetItem with ConcatOutput and SlicedInput +class TensorListInLoopOptimization : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("TensorListOperationsResolver", "0"); - TensorListOperationsResolver() { - add_matcher(); - add_matcher(); - add_matcher(); - } + OPENVINO_RTTI("ov::frontend::tensorflow::pass::TensorListInLoopOptimization"); + TensorListInLoopOptimization(); }; } // namespace pass diff --git a/src/frontends/tensorflow_common/src/helper_transforms/tensor_list_ops_resolver.cpp b/src/frontends/tensorflow_common/src/helper_transforms/tensor_list_ops_resolver.cpp index 520b93fdfc817a..09b7257eb68be7 100644 --- a/src/frontends/tensorflow_common/src/helper_transforms/tensor_list_ops_resolver.cpp +++ b/src/frontends/tensorflow_common/src/helper_transforms/tensor_list_ops_resolver.cpp @@ -5,16 +5,20 @@ #include "helper_transforms/tensor_list_ops_resolver.hpp" #include "helper_ops/tensor_list_ops.hpp" +#include "openvino/op/add.hpp" #include "openvino/op/broadcast.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/convert.hpp" #include "openvino/op/gather.hpp" +#include "openvino/op/less.hpp" +#include "openvino/op/loop.hpp" #include "openvino/op/parameter.hpp" #include "openvino/op/reshape.hpp" #include "openvino/op/scatter_update.hpp" #include "openvino/op/shape_of.hpp" #include "openvino/op/slice.hpp" +#include "openvino/op/squeeze.hpp" #include "openvino/op/unsqueeze.hpp" #include "openvino/op/util/multi_subgraph_base.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" @@ -25,6 +29,13 @@ using namespace ov; using namespace ov::pass; using namespace ov::op; using namespace ov::frontend; +using namespace ov::op::util; + +using InvariantD = ov::op::util::MultiSubGraphOp::InvariantInputDescription; +using SlicedD = ov::op::util::MultiSubGraphOp::SliceInputDescription; +using MergedD = ov::op::util::MultiSubGraphOp::MergedInputDescription; +using OutputD = ov::op::util::MultiSubGraphOp::BodyOutputDescription; +using ConcatD = ov::op::util::MultiSubGraphOp::ConcatOutputDescription; namespace { ov::Rank find_element_rank(const ov::Input& target_input) { @@ -64,6 +75,115 @@ ov::Rank find_element_rank(const ov::Input& target_input) { return ov::Rank::dynamic(); } + +bool find_input_description(const ov::op::util::InputDescriptionVector& input_descriptions, + uint64_t param_idx, + ov::op::util::SubGraphOp::InputDescription::Ptr& found_desc) { + for (const auto& input_desc : 
input_descriptions) { + if (input_desc->m_body_parameter_index == param_idx) { + found_desc = input_desc; + return true; + } + } + + return false; +} + +void update_parameter_to_slice_input(const std::shared_ptr& node, + const std::shared_ptr& body_graph, + const ov::op::util::InputDescriptionVector& input_descriptions, + std::vector& update_param_ids) { + // select only TensorListGetItem that accepts a tensor list from Parameter node + // value of Parameter node is unchanged from one iteration to another one in Loop + auto tensor_list_get_item = std::dynamic_pointer_cast(node); + if (!tensor_list_get_item) { + return; + } + auto tensor_list = ov::as_type_ptr(tensor_list_get_item->get_input_node_shared_ptr(0)); + if (!tensor_list) { + return; + } + if (tensor_list->get_output_target_inputs(0).size() != 1) { + return; + } + + // tensor list must be invariant through iterations + int64_t param_idx = body_graph->get_parameter_index(tensor_list); + if (param_idx < 0) { + return; + } + ov::op::util::SubGraphOp::InputDescription::Ptr input_desc = nullptr; + if (!find_input_description(input_descriptions, static_cast(param_idx), input_desc) || !input_desc) { + return; + } + auto invariant_input_desc = ov::as_type_ptr(input_desc); + if (!invariant_input_desc) { + return; + } + + update_param_ids.push_back(static_cast(param_idx)); +} + +void update_result_to_concat_output(const std::shared_ptr& node, + const std::shared_ptr& body_graph, + const ov::ResultVector& results, + const ov::op::util::InputDescriptionVector& input_descriptions, + std::vector& update_result_ids, + std::vector& remove_param_ids) { + // select only TensorListSetItem that accepts a tensor list from Parameter node + // output of TensorListSetItem goes to Result that is connected with the tensor list by a back edge + auto tensor_list_set_item = std::dynamic_pointer_cast(node); + if (!tensor_list_set_item) { + return; + } + auto tensor_list = ov::as_type_ptr(tensor_list_set_item->get_input_node_shared_ptr(0)); + if (!tensor_list) { + return; + } + if (tensor_list->get_output_target_inputs(0).size() != 1) { + return; + } + + int64_t param_idx = body_graph->get_parameter_index(tensor_list); + if (param_idx < 0) { + return; + } + ov::op::util::SubGraphOp::InputDescription::Ptr input_desc = nullptr; + if (!find_input_description(input_descriptions, static_cast(param_idx), input_desc) || !input_desc) { + return; + } + auto merged_input_desc = ov::as_type_ptr(input_desc); + if (!merged_input_desc) { + return; + } + + uint64_t result_idx = merged_input_desc->m_body_value_index; + if (results[result_idx]->get_input_node_shared_ptr(0) != tensor_list_set_item) { + return; + } + + update_result_ids.push_back(result_idx); + remove_param_ids.push_back(static_cast(param_idx)); +} + +uint64_t get_new_param_idx(const std::vector& remove_parameter_idxs, uint64_t old_idx) { + // compute a number of Parameters nodes standing before old_idx that will be removed + uint64_t num_removed = 0; + for (auto remove_idx : remove_parameter_idxs) { + FRONT_END_GENERAL_CHECK(old_idx != remove_idx, + "[TensorFlow Frontend] internal error: incorrect old_idx for " + "TensorListSliceInputAndConcatOutputReplacer transformation"); + if (remove_idx < old_idx) { + ++num_removed; + } + } + + // compute shifted index + FRONT_END_GENERAL_CHECK(num_removed <= old_idx, + "[TensorFlow Frontend] internal error: incorrect new parameter index computation " + "TensorListSliceInputAndConcatOutputReplacer transformation"); + return old_idx - num_removed; +} } // namespace 
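+// Worked example for get_new_param_idx (illustrative, not part of the patch):
+// with remove_parameter_idxs = {1, 3} and old_idx = 4, two removed Parameters
+// precede index 4, so the shifted index is 4 - 2 = 2; passing an old_idx that
+// is itself scheduled for removal trips the FRONT_END_GENERAL_CHECK above.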
ov::frontend::tensorflow::pass::TensorListReplacer::TensorListReplacer() { @@ -209,3 +329,304 @@ ov::frontend::tensorflow::pass::TensorListGetItemReplacer::TensorListGetItemRepl "ov::frontend::tensorflow::pass::TensorListGetItemReplacer"); register_matcher(m, callback); } + +ov::frontend::tensorflow::pass::TensorListInLoopOptimization::TensorListInLoopOptimization() { + auto loop_label = pattern::wrap_type(); + + // pattern for condition sub-graph in Loop operarion + auto num_iterations_label = pattern::wrap_type(); + auto counter_label = pattern::wrap_type(); + auto counter_step_label = pattern::wrap_type(); + auto updated_counter_label = pattern::wrap_type({counter_label, counter_step_label}); + auto less_label = pattern::wrap_type({updated_counter_label, num_iterations_label}); + auto condition_label = pattern::wrap_type({less_label}); + + matcher_pass_callback callback = [=](pattern::Matcher& m) { + NodeRegistry rg; + + auto loop = ov::as_type_ptr(m.get_match_root()); + if (!loop) { + return false; + } + + // check that condition sub-graph of the required form: + // counter with zero initial value and increments each iteration + // loop continues until counter is less than given number + ov::pass::pattern::Matcher condition_matcher(condition_label); + + const auto& body = loop->get_function(); + const auto& body_params = body->get_parameters(); + const auto& body_results = body->get_results(); + const auto& special_body_ports = loop->get_special_body_ports(); + const auto& input_descriptions = loop->get_input_descriptions(); + const auto& output_descriptions = loop->get_output_descriptions(); + + if (!condition_matcher.match(body_results[special_body_ports.body_condition_output_idx]->output(0))) { + return false; + } + + const auto& condition_map = condition_matcher.get_pattern_value_map(); + int64_t counter_step = -1; + if (!get_constant_value(condition_map.at(counter_step_label).get_node_shared_ptr(), counter_step) || + counter_step != 1) { + return false; + } + + // get initial value of counter + int64_t initial_counter = 0; + auto counter_param = ov::as_type_ptr(condition_map.at(counter_label).get_node_shared_ptr()); + if (!counter_param) { + return false; + } + + for (const auto& input_desc : loop->get_input_descriptions()) { + auto body_param_idx = input_desc->m_body_parameter_index; + auto input_idx = input_desc->m_input_index; + if (body_params[body_param_idx] != counter_param) { + continue; + } + + // it must be merged input and incremented each iteration + auto merged_desc = ov::as_type_ptr(input_desc); + if (!merged_desc) { + return false; + } + auto result_idx = merged_desc->m_body_value_index; + auto update_counter = body_results[result_idx]->input_value(0).get_node_shared_ptr(); + if (update_counter != condition_map.at(updated_counter_label).get_node_shared_ptr()) { + return false; + } + + // get initial value of counter + if (!get_constant_value(loop->get_input_node_shared_ptr(input_idx), initial_counter)) { + return false; + } + + // suitable counter-parameter is found and checked + break; + } + + if (initial_counter != 0) { + return false; + } + + // collect vector of updated Parameter indices (they will be converted to SlicedInput parameters) + // and Result nodes (they will be converted to ConcatOutput results) + // also, some parameters and results can be removed and they are directly connected + // to updated Parameter/Result nodes + std::vector remove_parameter_ids; + std::vector update_parameter_ids, update_result_ids; + for (const auto& target_input : 
counter_param->get_output_target_inputs(0)) { + update_parameter_to_slice_input(target_input.get_node()->shared_from_this(), + body, + input_descriptions, + update_parameter_ids); + update_result_to_concat_output(target_input.get_node()->shared_from_this(), + body, + body_results, + input_descriptions, + update_result_ids, + remove_parameter_ids); + } + + // avoid TensorListSetItem that overrides tensor list with data computed for each iteration + // by index zero that is equivalent to Loop returning data only from last iteration + // TensorListSetItem accepts constant index equal to zero + // TensorListSetItem is not connected with counter node + std::vector update_result_last_iter_ids; + for (uint64_t result_idx = 0; result_idx < body_results.size(); ++result_idx) { + const auto& result = body_results[result_idx]; + auto tensor_list_set_item = + std::dynamic_pointer_cast(result->get_input_node_shared_ptr(0)); + if (!tensor_list_set_item) { + continue; + } + int64_t index_value = -1; + if (!get_constant_value(tensor_list_set_item->get_input_node_shared_ptr(1), index_value) || + (index_value != 0)) { + continue; + } + + auto tensor_list = ov::as_type_ptr(tensor_list_set_item->get_input_node_shared_ptr(0)); + if (!tensor_list) { + continue; + } + int64_t param_idx = body->get_parameter_index(tensor_list); + if (param_idx < 0) { + continue; + } + + update_result_last_iter_ids.push_back(result_idx); + remove_parameter_ids.push_back(static_cast(param_idx)); + } + + // nothing to update + if (update_parameter_ids.size() == 0 && update_result_ids.size() == 0 && + update_result_last_iter_ids.size() == 0) { + return false; + } + + // build a new body_graph + auto new_body_results = body_results; + std::vector all_update_result_ids = update_result_ids; + all_update_result_ids.insert(all_update_result_ids.end(), + update_result_last_iter_ids.begin(), + update_result_last_iter_ids.end()); + for (auto update_result_idx : all_update_result_ids) { + const auto& body_result = body_results[update_result_idx]; + auto tensor_list_set_item = + std::dynamic_pointer_cast(body_result->get_input_node_shared_ptr(0)); + FRONT_END_GENERAL_CHECK(tensor_list_set_item, + "[TensorFlow Frontend] internal error: tensor_list_set_item is nullptr in " + "TensorListSliceInputAndConcatOutputReplacer"); + // unsqueeze newly generated data at this iteration + // that will be concatenated + auto new_data = tensor_list_set_item->input_value(2); + auto axis = std::make_shared(ov::element::i32, ov::Shape{1}, 0); + auto unsqueeze_new_data = std::make_shared(new_data, axis); + auto new_body_result = std::make_shared(unsqueeze_new_data); + new_body_results[update_result_idx] = new_body_result; + } + auto new_body_params = ParameterVector{}; + for (uint64_t param_idx = 0; param_idx < static_cast(body_params.size()); ++param_idx) { + // skip Parameter nodes from remove_parameter_ids list + if (std::find(remove_parameter_ids.begin(), remove_parameter_ids.end(), param_idx) != + remove_parameter_ids.end()) { + continue; + } + + // use updated Parameter node if needed + if (std::find(update_parameter_ids.begin(), update_parameter_ids.end(), param_idx) != + update_parameter_ids.end()) { + const auto& body_param = body_params[param_idx]; + FRONT_END_GENERAL_CHECK(body_param->get_output_target_inputs(0).size() == 1, + "[TensorFlow Frontend] internal error: tensor list must have only consumer " + "TensorListGetItem operation in TensorListSliceInputAndConcatOutputReplacer"); + auto target_input = 
*(body_param->get_output_target_inputs(0).begin()); + auto tensor_list_get_item = + std::dynamic_pointer_cast(target_input.get_node()->shared_from_this()); + FRONT_END_GENERAL_CHECK(tensor_list_get_item, + "[TensorFlow Frontend] internal error: tensor list must have only consumer " + "TensorListGetItem operation in TensorListSliceInputAndConcatOutputReplacer"); + + auto new_shape = body_param->get_output_partial_shape(0); + if (new_shape.rank().is_static() && new_shape.rank().get_length() > 0) { + // set a static dimension equal to 1 since it is sliced by axis 0 + new_shape[0] = 1; + } + auto new_param = std::make_shared(body_param->get_output_element_type(0), new_shape); + new_param->set_friendly_name(body_param->get_friendly_name()); + + // adjust new_param since it comes after slicing and sliced input needs to be squeezed + auto squeeze_axis = std::make_shared(element::i32, Shape{1}, 0); + auto squeeze_param = std::make_shared(new_param, squeeze_axis); + + // replace data producer for all consumers of TensorListGetItem + tensor_list_get_item->output(0).replace(squeeze_param->output(0)); + new_body_params.push_back(new_param); + continue; + } + + new_body_params.push_back(body_params[param_idx]); + } + auto new_body_graph = std::make_shared(new_body_results, new_body_params); + + // eventually, only some Parameter nodes can be removed + // so indices of Parameters can be changed + // a number of Result nodes and their indices leave unchanged + // create new Loop operation and set input and output descriptions + const auto& trip_count = loop->input_value(0); + const auto& exec_cond = loop->input_value(1); + auto new_loop = rg.make(trip_count, exec_cond); + new_loop->set_special_body_ports(special_body_ports); + new_loop->set_function(new_body_graph); + + // update current_iteration_input_idx since some Parameters can be removed + // Result nodes are not removed so body_condition_output_idx leaves unchanged + auto current_iteration_input_idx = special_body_ports.current_iteration_input_idx; + if (current_iteration_input_idx > 0) { + auto new_idx = get_new_param_idx(remove_parameter_ids, static_cast(current_iteration_input_idx)); + auto new_special_body_ports = special_body_ports; + new_special_body_ports.current_iteration_input_idx = static_cast(new_idx); + new_loop->set_special_body_ports(new_special_body_ports); + } + + // set inputs for new Loop operation + for (const auto& input_desc : input_descriptions) { + // skip already removed body Parameters + auto param_idx = input_desc->m_body_parameter_index; + auto input_index = input_desc->m_input_index; + if (std::find(remove_parameter_ids.begin(), remove_parameter_ids.end(), param_idx) != + remove_parameter_ids.end()) { + continue; + } + + auto new_param_idx = get_new_param_idx(remove_parameter_ids, param_idx); + const auto& new_body_param = new_body_params[new_param_idx]; + const auto& init_value = loop->input_value(input_index); + if (std::find(update_parameter_ids.begin(), update_parameter_ids.end(), param_idx) != + update_parameter_ids.end()) { + // set this input as sliced input + new_loop->set_sliced_input(new_body_param, init_value, initial_counter, counter_step, 1, -1, 0); + } else if (const auto& invariant_input = ov::as_type_ptr(input_desc)) { + new_loop->set_invariant_input(new_body_param, init_value); + } else if (const auto& merged_input = ov::as_type_ptr(input_desc)) { + const auto& body_res = new_body_results[merged_input->m_body_value_index]; + new_loop->set_merged_input(new_body_param, init_value, body_res); + } else if 
(const auto& sliced_input = ov::as_type_ptr(input_desc)) { + new_loop->set_sliced_input(new_body_param, + init_value, + sliced_input->m_start, + sliced_input->m_stride, + sliced_input->m_part_size, + sliced_input->m_end, + sliced_input->m_axis); + } else { + // unknown type of input + // transformation is not applicable + return false; + } + } + + // set outputs for new Loop operation + std::unordered_map> idx_to_new_output; + for (const auto& output_desc : output_descriptions) { + auto result_idx = output_desc->m_body_value_index; + auto output_index = output_desc->m_output_index; + auto new_body_result = new_body_results[result_idx]; + + if (std::find(update_result_ids.begin(), update_result_ids.end(), result_idx) != update_result_ids.end()) { + idx_to_new_output[output_index] = + new_loop->get_concatenated_slices(new_body_result, initial_counter, counter_step, 1, -1, 0); + } else if (std::find(update_result_last_iter_ids.begin(), update_result_last_iter_ids.end(), result_idx) != + update_result_last_iter_ids.end()) { + idx_to_new_output[output_index] = new_loop->get_iter_value(new_body_result, -1); + } else if (const auto& concat_output = ov::as_type_ptr(output_desc)) { + idx_to_new_output[output_index] = new_loop->get_concatenated_slices(new_body_result, + concat_output->m_start, + concat_output->m_stride, + concat_output->m_part_size, + concat_output->m_end, + concat_output->m_axis); + } else if (const auto& iter_output = ov::as_type_ptr(output_desc)) { + idx_to_new_output[output_index] = new_loop->get_iter_value(new_body_result, iter_output->m_iteration); + } else { + // unknown type of output + return false; + } + } + + auto loop_outputs = loop->outputs(); + for (size_t i = 0; i < loop_outputs.size(); ++i) { + loop_outputs[i].replace(idx_to_new_output[i]); + } + copy_runtime_info(loop, rg.get()); + new_loop->set_friendly_name(loop->get_friendly_name()); + + return true; + }; + + auto m = + std::make_shared(loop_label, "ov::frontend::tensorflow::pass::TensorListInLoopOptimization"); + register_matcher(m, callback); +}
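// ---
// Minimal usage sketch (editor's illustration, not part of the patch; names are
// taken from the headers changed above). ConvertLoopToLSTMSequence is now a
// GraphRewrite bundling both new matchers, so a single registration covers both
// the ScatterUpdate-based and the sliced-input/concat-output Loop flavors:
//
//     ov::pass::Manager manager;
//     manager.register_pass<ov::pass::ConvertLoopToLSTMSequence>();
//     manager.run_passes(model);  // model: an ov::Model that may contain such Loops
//
// The templated get_constant_value<T> added to transformations/utils/utils.hpp
// generalizes the removed static get_scalar_constant_value: it reads a
// single-element Constant of any element type into a caller-provided variable,
// mirroring the call sites in this diff:
//
//     int64_t axis = -1;
//     if (ov::op::util::get_constant_value(gather->get_input_node_shared_ptr(2), axis) &&
//         axis == 0) {
//         // the input is a one-element Constant holding the value 0
//     }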