Skip to content

Commit

Permalink
Support dynamic seq lengths in ConvertSequenceToTensorIterator transf…
Browse files Browse the repository at this point in the history
…ormation
  • Loading branch information
itikhono committed Oct 24, 2023
1 parent 6bec4fc commit 48d9f75
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ TRANSFORMATIONS_API bool check_for_broadcast(const PartialShape& ref_shape, cons

TRANSFORMATIONS_API std::shared_ptr<Node> activation(const std::string& activation_name, const Output<Node>& apply_to);

TRANSFORMATIONS_API bool is_seq_len_provided(const std::shared_ptr<Node>& seq_len_input, int64_t max_seq_len);
TRANSFORMATIONS_API bool is_seq_len_provided(const std::shared_ptr<Node>& X,
const std::shared_ptr<Node>& seq_len_input);

TRANSFORMATIONS_API std::shared_ptr<Node> try_fold_unary_output(const std::shared_ptr<Node>& node);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,11 @@ bool convert_sequence_to_ti(const std::shared_ptr<ov::Node>& sequence,
const ov::Output<ov::Node>& B,
const ov::op::RecurrentSequenceDirection& direction) {
auto X_pshape = X.get_partial_shape();
if (X_pshape.size() < 2 || X_pshape[1].is_dynamic()) {
if (X_pshape.size() < 2) {
return false;
}

auto max_seq_len = X_pshape[1].get_length();
bool enable_mask = ov::op::util::is_seq_len_provided(seq_lengths.get_node_shared_ptr(), max_seq_len);
bool enable_mask = ov::op::util::is_seq_len_provided(X.get_node_shared_ptr(), seq_lengths.get_node_shared_ptr());

const bool is_reverse = direction == ov::op::RecurrentSequenceDirection::REVERSE;
std::shared_ptr<ov::Node> reverse_seq_before;
Expand Down
51 changes: 48 additions & 3 deletions src/common/transformations/src/transformations/utils/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,56 @@ std::shared_ptr<ov::Node> activation(const std::string& activation_name, const o
}
}

bool is_seq_len_provided(const std::shared_ptr<Node>& seq_len_input, int64_t max_seq_len) {
bool is_seq_len_provided(const std::shared_ptr<Node>& X, const std::shared_ptr<Node>& seq_len_input) {
auto max_seq_dim = X->get_output_partial_shape(0)[1];
if (max_seq_dim.is_dynamic()) {
// if values in seq_len input are equal to max_seq_len dim in X input
// then we don't need to insert Select operations
// supported seq_len_input:
// X -> ShapeOf -> Gather (max_seq_dim) -> Optional (Broadcast)
std::shared_ptr<Node> input = seq_len_input;
auto broadcast = ov::as_type_ptr<ov::op::v3::Broadcast>(input);
if (broadcast) {
input = seq_len_input->input_value(0).get_node_shared_ptr();
}

auto gather = ov::as_type_ptr<ov::op::util::GatherBase>(input);
bool valid_gather = false;
if (gather) {
auto indices = gather->input_value(1).get_node_shared_ptr();
auto axis = gather->input_value(2).get_node_shared_ptr();
auto indices_const = ov::as_type_ptr<ov::op::v0::Constant>(indices);
auto axis_const = ov::as_type_ptr<ov::op::v0::Constant>(axis);
if (indices_const && axis_const) {
auto ind_values = indices_const->cast_vector<int64_t>();
auto axis_values = axis_const->cast_vector<int64_t>();
if (ind_values.size() == 1 && ind_values[0] == 1 && axis_values.size() == 1 && axis_values[0] == 0) {
valid_gather = true;
}
}
}

if (!valid_gather) {
return true;
}

auto shape_of = ov::as_type_ptr<ov::op::util::ShapeOfBase>(gather->input_value(0).get_node_shared_ptr());
if (!shape_of) {
return true;
}

if (shape_of->input_value(0).get_node_shared_ptr() != X) {
return true;
}

return false;
}

auto max_seq_len_val = max_seq_dim.get_length();
if (const auto& seq_len_const = std::dynamic_pointer_cast<op::v0::Constant>(seq_len_input)) {
const auto& seq_len_values = seq_len_const->cast_vector<int64_t>();
return std::any_of(seq_len_values.begin(), seq_len_values.end(), [max_seq_len](const int64_t val) {
return val != max_seq_len;
return std::any_of(seq_len_values.begin(), seq_len_values.end(), [max_seq_len_val](const int64_t val) {
return val != max_seq_len_val;
});
}
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -798,3 +798,115 @@ TEST(TransformationTests, ConvertQuantizedGRUSequenceToTensorIterator) {
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

// Verifies that ConvertLSTMSequenceToTensorIterator handles a dynamic sequence-length
// dimension: X has shape {1, -1, 16} (dim 1 dynamic) and seq_lengths is derived from
// X via ShapeOf -> Gather(axis 0, index 1). In this pattern every sequence length equals
// the max sequence length by construction, so the transformation can convert the
// LSTMSequence to a TensorIterator without inserting Select-based masking.
TEST(TransformationTests, ConvertLSTMSequenceWithDynSeqLenToTensorIterator) {
std::shared_ptr<ov::Model> f(nullptr), f_ref(nullptr);
{
// Model under test: LSTMSequence whose seq_lengths input is ShapeOf(X) -> Gather,
// i.e. the (dynamic) max sequence length taken directly from X's shape.
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{1, -1, 16});
auto Y = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 128});
auto Z = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 128});
auto shape_of = std::make_shared<opset5::ShapeOf>(X);
// Gather dim 1 (the sequence-length dimension) of X's shape along axis 0.
auto indices = opset5::Constant::create(element::i32, {1}, {1});
auto axis = opset5::Constant::create(element::i32, {}, {0});
auto seq_lengths = std::make_shared<opset5::Gather>(shape_of, indices, axis);

// Zero-filled weights: 512 = 4 gates * 128 hidden size.
auto w_val = std::vector<float>(512 * 16, 0);
auto r_val = std::vector<float>(512 * 128, 0);
auto b_val = std::vector<float>(512, 0);
auto W = opset5::Constant::create(element::f32, Shape{1, 512, 16}, w_val);
auto R = opset5::Constant::create(element::f32, Shape{1, 512, 128}, r_val);
auto B = opset5::Constant::create(element::f32, Shape{1, 512}, b_val);

auto rnn_sequence = std::make_shared<opset5::LSTMSequence>(X,
Y,
Z,
seq_lengths,
W,
R,
B,
128,
op::RecurrentSequenceDirection::FORWARD);
auto Y_out = std::make_shared<opset5::Result>(rnn_sequence->output(0));
auto Ho = std::make_shared<opset5::Result>(rnn_sequence->output(1));
auto Co = std::make_shared<opset5::Result>(rnn_sequence->output(2));
Y_out->set_friendly_name("Y_out");
Ho->set_friendly_name("Ho");
Co->set_friendly_name("Co");

f = std::make_shared<ov::Model>(NodeVector{Y_out, Ho, Co}, ParameterVector{X, Y, Z});

// Run the transformation under test on the model.
pass::Manager m;
m.register_pass<ov::pass::InitNodeInfo>();
m.register_pass<ov::pass::ConvertLSTMSequenceToTensorIterator>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

{
// Reference model: the expected TensorIterator-based graph. Note there are no
// Select ops in the body — masking is expected to be skipped for this pattern.
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{1, -1, 16});
auto Y = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 128});
auto Z = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 128});
// Squeeze the num_directions dim (axis 1) from the initial hidden/cell states.
auto squeeze_pattern = opset5::Constant::create(element::i64, Shape{1}, {1});
auto squeeze_y = std::make_shared<opset5::Squeeze>(Y, squeeze_pattern);
auto squeeze_z = std::make_shared<opset5::Squeeze>(Z, squeeze_pattern);

// Body parameters: one X slice per iteration, carried hidden/cell state,
// and the (invariant) sequence-length input.
auto Xi = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 16});
auto Yi = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 128});
auto Zi = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 128});
auto seq_body_param = std::make_shared<opset5::Parameter>(element::i32, PartialShape{1});

// Body
auto squeeze_x = std::make_shared<opset5::Squeeze>(Xi, squeeze_pattern);

// Per-cell weights: the leading num_directions dim is squeezed away in the body.
auto w_val = std::vector<float>(512 * 16, 0);
auto r_val = std::vector<float>(512 * 128, 0);
auto b_val = std::vector<float>(512, 0);
auto W = opset5::Constant::create(element::f32, Shape{512, 16}, w_val);
auto R = opset5::Constant::create(element::f32, Shape{512, 128}, r_val);
auto B = opset5::Constant::create(element::f32, Shape{512}, b_val);

auto rnn_cell = std::make_shared<opset5::LSTMCell>(squeeze_x, Yi, Zi, W, R, B, 128);

auto unsqueeze_pattern = opset5::Constant::create(element::i64, Shape{1}, {1});
auto Ho = std::make_shared<opset5::Result>(rnn_cell->output(0));

auto Co = std::make_shared<opset5::Result>(rnn_cell->output(1));

// Re-add the squeezed dim so per-iteration outputs can be concatenated.
auto unsqueeze_y = std::make_shared<opset5::Unsqueeze>(rnn_cell->output(0), unsqueeze_pattern);
auto Y_out = std::make_shared<opset5::Result>(unsqueeze_y);

auto body = std::make_shared<Model>(OutputVector{Y_out, Ho, Co}, ParameterVector{Xi, Yi, Zi, seq_body_param});

auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
tensor_iterator->set_body(body);

// Slice X along the sequence axis (1), one element per iteration.
tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
tensor_iterator->get_concatenated_slices(Y_out, 0, 1, 1, -1, 1);

// Hidden/cell state is carried across iterations (back-edge).
tensor_iterator->set_merged_input(Yi, squeeze_y, Ho);
tensor_iterator->set_merged_input(Zi, squeeze_z, Co);

// seq_lengths is the same ShapeOf -> Gather subgraph as in the tested model,
// fed into the body as an invariant input.
auto shape_of = std::make_shared<opset5::ShapeOf>(X);
auto indices = opset5::Constant::create(element::i32, {1}, {1});
auto axis = opset5::Constant::create(element::i32, {}, {0});
auto seq_lengths = std::make_shared<opset5::Gather>(shape_of, indices, axis);
tensor_iterator->set_invariant_input(seq_body_param, seq_lengths);

tensor_iterator->get_iter_value(Ho);
tensor_iterator->get_iter_value(Co);

// Restore the num_directions dim on all three TI outputs to match
// the LSTMSequence output shapes.
auto res_ti_Y = std::make_shared<opset5::Result>(
std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(0), unsqueeze_pattern));
auto res_ti_H = std::make_shared<opset5::Result>(
std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(1), unsqueeze_pattern));
auto res_ti_C = std::make_shared<opset5::Result>(
std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(2), unsqueeze_pattern));
res_ti_Y->set_friendly_name("Y_out");
res_ti_H->set_friendly_name("Ho");
res_ti_C->set_friendly_name("Co");
f_ref = std::make_shared<ov::Model>(NodeVector{res_ti_Y, res_ti_H, res_ti_C}, ParameterVector{X, Y, Z});
}

// Structural comparison of transformed vs. reference graph.
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
5 changes: 3 additions & 2 deletions src/plugins/intel_cpu/src/nodes/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,9 @@ bool RNN::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::s
errorMessage = "Max sequence length dimension is dynamic";
return false;
}
auto maxSeqLen = data_pshape[maxSeqLenDimIdx].get_length();
if (ov::op::util::is_seq_len_provided(op->get_input_node_shared_ptr(seqLenIdx), maxSeqLen)) {

if (ov::op::util::is_seq_len_provided(op->get_input_node_shared_ptr(0),
op->get_input_node_shared_ptr(seqLenIdx))) {
errorMessage = "Unsupported sequence length.";
return false;
}
Expand Down
4 changes: 2 additions & 2 deletions src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
return lstm_seq->get_clip() == 0.0f &&
lstm_seq->get_activations() == std::vector<std::string>{"sigmoid", "tanh", "tanh"} &&
max_seq_len < 16 &&
!ov::op::util::is_seq_len_provided(lstm_seq->get_input_node_shared_ptr(3),
max_seq_len);
!ov::op::util::is_seq_len_provided(lstm_seq->get_input_node_shared_ptr(0),
lstm_seq->get_input_node_shared_ptr(3));
}
return false;
};
Expand Down

0 comments on commit 48d9f75

Please sign in to comment.