Skip to content

Commit

Permalink
[TF FE][MOC] Fuse Keras LSTM to LSTMSequence and Optimize TF While wi…
Browse files Browse the repository at this point in the history
…th TensorList ops (openvinotoolkit#25170)

**Details:** Fuse Keras LSTM to LSTMSequence and Optimize TF While with
TensorList ops.
Loop operations with TensorListSetItem transformed to ConcatOutput
outputs.
Loop operations with TensorListGetItem transformed to SlicedInput
inputs.

It helps to fuse six loops with LSTMCell to six LSTM sequence model. It
reduces customer model size by twice and increase throughput by 1.97x.

**Tickets:** TBD

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored and AsyaPronina committed Jul 1, 2024
1 parent 38f476c commit 5253d2e
Show file tree
Hide file tree
Showing 8 changed files with 1,134 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class TRANSFORMATIONS_API ConvertTensorIteratorToRNNSequence;
class TRANSFORMATIONS_API ConvertTensorIteratorToGRUSequence;
class TRANSFORMATIONS_API ConvertTensorIteratorToSequence;

class TRANSFORMATIONS_API ConvertLoopWithSlicedInputConcatOutputToLSTMSequence;
class TRANSFORMATIONS_API ConvertLoopWithScatterUpdateToLSTMSequence;
class TRANSFORMATIONS_API ConvertLoopToLSTMSequence;
class TRANSFORMATIONS_API FuseReverseLSTMSequence;

Expand Down Expand Up @@ -68,14 +70,29 @@ class ov::pass::ConvertTensorIteratorToSequence : public GraphRewrite {
ConvertTensorIteratorToSequence();
};

class ov::pass::ConvertLoopWithSlicedInputConcatOutputToLSTMSequence : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ConvertLoopWithSlicedInputConcatOutputToLSTMSequence", "0");
ConvertLoopWithSlicedInputConcatOutputToLSTMSequence();
};

class ov::pass::ConvertLoopWithScatterUpdateToLSTMSequence : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ConvertLoopWithScatterUpdateToLSTMSequence", "0");
ConvertLoopWithScatterUpdateToLSTMSequence();
};

/**
* @ingroup ov_transformation_common_api
* @brief Replaces Loop with LSTMCell inside to LSTMSequence
*/
class ov::pass::ConvertLoopToLSTMSequence : public ov::pass::MatcherPass {
class ov::pass::ConvertLoopToLSTMSequence : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("ConvertLoopToLSTMSequence", "0");
ConvertLoopToLSTMSequence();
ConvertLoopToLSTMSequence() {
add_matcher<ov::pass::ConvertLoopWithScatterUpdateToLSTMSequence>();
add_matcher<ov::pass::ConvertLoopWithSlicedInputConcatOutputToLSTMSequence>();
}
};

/**
Expand Down
11 changes: 11 additions & 0 deletions src/common/transformations/include/transformations/utils/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,17 @@ inline std::string get_ie_output_name(const Output<Node>& output) {
*/
float cast_eps_to_float(double eps_d);

template <typename T>
bool get_constant_value(const std::shared_ptr<ov::Node>& node, T& value) {
auto constant = ov::as_type_ptr<ov::op::v0::Constant>(node);
if (!constant)
return false;
if (shape_size(constant->get_shape()) != 1)
return false;
value = constant->cast_vector<T>()[0];
return true;
}

template <typename T>
bool has_constant_value(const std::shared_ptr<Node>& node,
const T value,
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <gtest/gtest.h>

#include <memory>
#include <random>
#include <string>

#include "common_test_utils/ov_test_utils.hpp"
Expand Down Expand Up @@ -1265,3 +1266,197 @@ TEST_P(FuseLSTMSequencesToBidirectionalLSTMSequenceTest, FusionTest) {
INSTANTIATE_TEST_SUITE_P(FuseLSTMSequencesToBidirectionalLSTMSequence,
FuseLSTMSequencesToBidirectionalLSTMSequenceTest,
testing::Combine(testing::Values(false, true), testing::Values(false, true)));

using LoopWithLSTMCellToLSTMSequenceFusionParam = std::tuple<std::string, // f activation function
std::string, // g activation function
std::string, // h activation function
size_t, // input size
size_t>; // hidden size

class LoopWithLSTMCellToLSTMSequenceFusionTest
: public testing::WithParamInterface<LoopWithLSTMCellToLSTMSequenceFusionParam>,
public TransformationTestsF {};

namespace {
void generate_weights_value(std::vector<float>& weights_value, const Shape& weights_shape) {
weights_value.resize(shape_size(weights_shape));
std::mt19937 rng(9812);
std::uniform_real_distribution<float> distribution(-300, 300);
for (size_t i = 0; i < weights_value.size(); ++i) {
weights_value[i] = distribution(rng);
}
}
} // namespace

TEST_P(LoopWithLSTMCellToLSTMSequenceFusionTest, FusionTest) {
const auto& param = GetParam();
const std::string& f_activation = std::get<0>(param);
const std::string& g_activation = std::get<1>(param);
const std::string& h_activation = std::get<2>(param);
size_t input_size = std::get<3>(param);
size_t hidden_size = std::get<4>(param);
size_t batch_size = 2;
size_t time_len = 10;

// generate weights values
// w must be of a shape [input_size, hidden_size]
// r must be of a shape [hidden_size, hidden_size]
// b must be of a shape [hidden_size]
Shape w_shape({4 * hidden_size, input_size});
Shape r_shape({4 * hidden_size, hidden_size});
Shape b_shape({4 * hidden_size});
std::vector<float> w, r, b;
generate_weights_value(w, w_shape);
generate_weights_value(r, r_shape);
generate_weights_value(b, b_shape);

{
// create body graph with LSTMCell
auto xi = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, batch_size, input_size});
auto squeeze_axis = std::make_shared<op::v0::Constant>(element::i64, Shape{}, 0);
auto xi_squeeze = std::make_shared<op::v0::Squeeze>(xi, squeeze_axis);
auto init_hidden_state = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto init_cell_state = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto w_const = op::v0::Constant::create(element::f32, w_shape, w);
auto r_const = op::v0::Constant::create(element::f32, r_shape, r);
auto b_const = op::v0::Constant::create(element::f32, b_shape, b);
auto lstm_cell =
std::make_shared<op::v4::LSTMCell>(xi_squeeze,
init_hidden_state,
init_cell_state,
w_const,
r_const,
b_const,
hidden_size,
std::vector<std::string>{f_activation, g_activation, h_activation});

auto hidden_state_res = std::make_shared<op::v0::Result>(lstm_cell->output(0));
auto cell_state_res = std::make_shared<op::v0::Result>(lstm_cell->output(1));
auto unsqueeze_axis = std::make_shared<op::v0::Constant>(element::i64, Shape{}, 0);
auto unsqueeze_hidden_state = std::make_shared<op::v0::Unsqueeze>(lstm_cell->output(0), unsqueeze_axis);
auto unsqueeze_hidden_state_res = std::make_shared<op::v0::Result>(unsqueeze_hidden_state);

// conditional graph
auto num_iters = std::make_shared<op::v0::Parameter>(element::i32, Shape{1});
auto counter = std::make_shared<op::v0::Parameter>(element::i32, Shape{1});
auto increment = std::make_shared<op::v0::Constant>(element::i32, Shape{}, 1);
auto add = std::make_shared<op::v1::Add>(counter, increment);
auto updated_counter = std::make_shared<op::v0::Result>(add);
auto less = std::make_shared<op::v1::Less>(add, num_iters);
auto less_res = std::make_shared<op::v0::Result>(less);

auto body_graph = std::make_shared<Model>(
ResultVector{hidden_state_res, cell_state_res, unsqueeze_hidden_state_res, less_res, updated_counter},
ParameterVector{xi, init_hidden_state, init_cell_state, num_iters, counter});

// create main graph with Loop
auto x = std::make_shared<op::v0::Parameter>(element::f32, Shape{time_len, batch_size, input_size});
auto h_init = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto c_init = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto execution_cond = std::make_shared<op::v0::Constant>(ov::element::boolean, ov::Shape{}, true);
auto max_iter = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, -1);
auto num_iter_const =
std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, static_cast<int32_t>(time_len));
auto counter_const = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, 0);

auto loop_node = std::make_shared<op::v5::Loop>(max_iter, execution_cond);

loop_node->set_function(body_graph);
loop_node->set_special_body_ports(ov::op::v5::Loop::SpecialBodyPorts{-1, 3});

// set inputs for Loop
// x input will be sliced for each time step
loop_node->set_sliced_input(xi, x, 0, 1, 1, -1, 0);
// set back edges for cell and hidden states
// since they are changing through timeline
loop_node->set_merged_input(init_hidden_state, h_init, hidden_state_res);
loop_node->set_merged_input(init_cell_state, c_init, cell_state_res);
loop_node->set_invariant_input(num_iters, num_iter_const);
loop_node->set_merged_input(counter, counter_const, updated_counter);

// set external outputs for Loop node
// concatenated cell and hidden states from all time steps
auto hs = loop_node->get_concatenated_slices(unsqueeze_hidden_state_res, 0, 1, 1, -1, 0);
auto hs_res = std::make_shared<op::v0::Result>(hs);

model = std::make_shared<Model>(ResultVector{hs_res}, ParameterVector{x, h_init, c_init});
manager.register_pass<ov::pass::ConvertLoopWithSlicedInputConcatOutputToLSTMSequence>();
}

{
auto x = std::make_shared<op::v0::Parameter>(element::f32, Shape{time_len, batch_size, input_size});
auto h_init = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto c_init = std::make_shared<op::v0::Parameter>(element::f32, Shape{batch_size, hidden_size});

// transpose x since LSTMSequence expects x in a format [batch_size, time_len, input_size]
auto tr_order =
std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{1, 0, 2});
auto tr_x = std::make_shared<op::v1::Transpose>(x, tr_order);
// prepare init hidden and cell states to have a format [batch_size, num_directions, hidden_size]
// where num_directions equals one
auto unsqueeze_axis =
std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int32_t>{1});
auto h_init_unsqueeze = std::make_shared<op::v0::Unsqueeze>(h_init, unsqueeze_axis);
auto c_init_unsqueeze = std::make_shared<op::v0::Unsqueeze>(c_init, unsqueeze_axis);
// prepare seq_lens
auto batch_size = std::make_shared<op::v3::ShapeOf>(x, element::i64)->output(0);
auto begin = std::make_shared<op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int32_t>{1});
auto end = std::make_shared<op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int32_t>{2});
auto stride = std::make_shared<op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int32_t>{1});
batch_size = std::make_shared<op::v1::StridedSlice>(batch_size,
begin,
end,
stride,
std::vector<int64_t>{0},
std::vector<int64_t>{0});
auto num_iter_const =
std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, static_cast<int32_t>(time_len));
auto seq_lens = std::make_shared<op::v1::Broadcast>(num_iter_const, batch_size);
// prepare W, R, B weights to a format with num_directions dimension
auto w_const = op::v0::Constant::create(element::f32, w_shape, w);
auto r_const = op::v0::Constant::create(element::f32, r_shape, r);
auto b_const = op::v0::Constant::create(element::f32, b_shape, b);
auto unsqueeze_axis2 =
std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, std::vector<int32_t>{0});
auto w = std::make_shared<op::v0::Unsqueeze>(w_const, unsqueeze_axis2);
auto r = std::make_shared<op::v0::Unsqueeze>(r_const, unsqueeze_axis2);
auto b = std::make_shared<op::v0::Unsqueeze>(b_const, unsqueeze_axis2);

// create LSTMSequence
auto lstm_sequence = std::make_shared<ov::op::v5::LSTMSequence>(
tr_x,
h_init_unsqueeze,
c_init_unsqueeze,
seq_lens,
w,
r,
b,
hidden_size,
ov::op::RecurrentSequenceDirection::FORWARD,
std::vector<float>{},
std::vector<float>{},
std::vector<std::string>{f_activation, g_activation, h_activation},
0.0f);

// prepare output
auto squeeze_axis = std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{1}, 1);
auto squeeze_output_hs = std::make_shared<op::v0::Squeeze>(lstm_sequence->output(0), squeeze_axis);
auto tr_order2 =
std::make_shared<op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{1, 0, 2});
auto tr_squeeze_output_hs = std::make_shared<op::v1::Transpose>(squeeze_output_hs, tr_order2);
auto output_hs_res = std::make_shared<op::v0::Result>(tr_squeeze_output_hs);
model_ref = std::make_shared<Model>(ResultVector{output_hs_res}, ParameterVector{x, h_init, c_init});
}

comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}

INSTANTIATE_TEST_SUITE_P(LoopWithLSTMCellToLSTMSequenceFusion,
LoopWithLSTMCellToLSTMSequenceFusionTest,
testing::Combine(testing::Values("sigmoid", "tanh"),
testing::Values("sigmoid", "relu"),
testing::Values("tanh", "relu"),
testing::Values(2, 3),
testing::Values(3, 4)));
12 changes: 11 additions & 1 deletion src/frontends/tensorflow/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "openvino/util/file_util.hpp"
#include "openvino/util/log.hpp"
#include "tf_framework_node.hpp"
#include "transformations/common_optimizations/eliminate_loop_inputs_outputs.hpp"
#include "transformations/common_optimizations/remove_concat_zero_dim_input.hpp"
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
#include "transformations/control_flow/unroll_if.hpp"
Expand Down Expand Up @@ -568,7 +569,16 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<pass::TensorArrayV3Replacer>();
manager.register_pass<pass::ConstToResultRemover>();
manager.register_pass<pass::SwitchMergeResolver>();
manager.register_pass<pass::TensorListOperationsResolver>();

// apply EliminateLoopInputsOutputs to avoid extra Results
// that output the same value as receiving on input
// it is needed for applying TensorListInLoopOptimization
manager.register_pass<ov::pass::EliminateLoopInputsOutputs>();
manager.register_pass<pass::TensorListReplacer>();
manager.register_pass<pass::TensorListInLoopOptimization>();
manager.register_pass<pass::TensorListSetItemReplacer>();
manager.register_pass<pass::TensorListGetItemReplacer>();

manager.register_pass<ov::pass::UnrollIf>();
manager.register_pass<ov::pass::RemoveConcatZeroDimInput>();
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/tensorflow/src/op/block_lstm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void create_decomposed_block_lstm(const Output<Node>& x,
auto squeeze_axis = std::make_shared<v0::Constant>(element::i32, Shape{1}, 0);
auto xi = std::make_shared<v0::Squeeze>(xi_param, squeeze_axis);

auto lstm_cell = std::make_shared<v0::LSTMCell>(xi,
auto lstm_cell = std::make_shared<v4::LSTMCell>(xi,
h_prev_param,
c_prev_param,
w_param,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@

#pragma once

#include <memory>
#include <utility>

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"

Expand Down Expand Up @@ -36,16 +33,12 @@ class TensorListGetItemReplacer : public ov::pass::MatcherPass {
TensorListGetItemReplacer();
};

// Replace and optimize sub-graphs with TensorList operations such as TensorListReserve,
// TensorListSetItem, TensorListGetItem
class TensorListOperationsResolver : public ov::pass::GraphRewrite {
// Optimize sub-graphs with TensorList operations in Loop body graph
// Replace TensorListSetItem and TensorListGetItem with ConcatOutput and SlicedInput
class TensorListInLoopOptimization : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TensorListOperationsResolver", "0");
TensorListOperationsResolver() {
add_matcher<TensorListReplacer>();
add_matcher<TensorListSetItemReplacer>();
add_matcher<TensorListGetItemReplacer>();
}
OPENVINO_RTTI("ov::frontend::tensorflow::pass::TensorListInLoopOptimization");
TensorListInLoopOptimization();
};

} // namespace pass
Expand Down
Loading

0 comments on commit 5253d2e

Please sign in to comment.