Skip to content

Commit

Permalink
Unroll Tensor Iterator using ngraph pass (openvinotoolkit#7205)
Browse files Browse the repository at this point in the history
* use ngraph-based unroll-ti as default

* use isNgraphPassesUsed for legacy code

* code review fixes

* code review fixes
  • Loading branch information
evkotov authored and akuporos committed Sep 29, 2021
1 parent 9b18622 commit 84234bb
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
13 changes: 5 additions & 8 deletions inference-engine/src/gna_plugin/gna_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,11 +750,6 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
manager.register_pass<ngraph::pass::UnrollTensorIterator>();

const auto& pass_config = manager.get_pass_config();
pass_config->set_callback<ngraph::pass::UnrollTensorIterator>(
[](const std::shared_ptr<const ngraph::Node> &node) -> bool {
// UnrollTI transformation is disabled by default, is turned on by LowLatency transformation
return node->get_rt_info().count("UNROLL_TI") == 0;
});
pass_config->disable<ngraph::pass::FakeQuantizeMulFusion>();
pass_config->disable<ngraph::pass::FakeQuantizeReshapeFusion>();
pass_config->disable<ngraph::pass::PullTransposeThroughFQUp>();
Expand Down Expand Up @@ -798,10 +793,12 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
auto run_passes = [&] (const CNNNetwork& network, bool runBeforeCopy, bool lowPrecision) {
auto passes = make_shared<PassManager>(PassManagerSettings{runBeforeCopy, lowPrecision}, network);
passes->registerPass<RemoveConstPass>();
passes->registerPass<UnrollTIPass>();
passes->registerPass<RemoveConstPass>();
if (!isNgraphPassesUsed)
if (!isNgraphPassesUsed) {
passes->registerPass<UnrollTIPass>();
passes->registerPass<RemoveConstPass>();
passes->registerPass<UnrollLSTMCellPass>();
}

passes->registerPass<RemoveSingleInputConcatPass>();

// fake quantisation aware passes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ namespace SubgraphTestsDefinitions {
// TI construction
auto tensor_iterator = std::make_shared<TensorIterator>();
tensor_iterator->set_body(body);
tensor_iterator->set_invariant_input(X, permute_in);
tensor_iterator->set_sliced_input(X, permute_in, 0, 1, 1, -1, 0);
tensor_iterator->set_merged_input(H_t, hidden_memory_read, H_o);
tensor_iterator->set_merged_input(C_t, cell_memory_read, C_o);

Expand All @@ -130,6 +130,7 @@ namespace SubgraphTestsDefinitions {
SinkVector{cell_memory_write, hidden_memory_write},
input_parameter,
"TI_with_memory");
tensor_iterator->validate_and_infer_types();
}

void MemoryLSTMCellTest::switchToNgraphFriendlyModel() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,10 @@ void MultipleLSTMCellTest::SetUp() {
// TI construction
auto tensor_iterator = std::make_shared<TensorIterator>();
tensor_iterator->set_body(body);
tensor_iterator->set_invariant_input(X, permute_in);
tensor_iterator->set_sliced_input(X, permute_in, 0, 1, 1, -1, 0);
tensor_iterator->set_merged_input(H_t, hidden_memory_read, H_o);
tensor_iterator->set_merged_input(C_t, cell_memory_read, C_o);
tensor_iterator->validate_and_infer_types();

auto out_unsqueeze = tensor_iterator->get_iter_value(unsqueeze_o, -1);
auto out_hidden = tensor_iterator->get_iter_value(H_o, -1);
Expand Down Expand Up @@ -165,9 +166,10 @@ void MultipleLSTMCellTest::SetUp() {
// TI construction
auto tensor_iterator_2 = std::make_shared<TensorIterator>();
tensor_iterator_2->set_body(body_2);
tensor_iterator_2->set_invariant_input(X_2, inbetween_squeeze);
tensor_iterator_2->set_sliced_input(X_2, inbetween_squeeze, 0, 1, 1, -1, 0);
tensor_iterator_2->set_merged_input(H_t_2, hidden_memory_2_read, H_o_2);
tensor_iterator_2->set_merged_input(C_t_2, cell_memory_2_read, C_o_2);
tensor_iterator_2->validate_and_infer_types();

auto out_unsqueeze_2 = tensor_iterator_2->get_iter_value(unsqueeze_o_2, -1);
auto out_hidden_2 = tensor_iterator_2->get_iter_value(H_o_2, -1);
Expand Down

0 comments on commit 84234bb

Please sign in to comment.