Skip to content

Commit

Permalink
returned Reshape in condition
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jan 31, 2022
1 parent 7810194 commit b1ae35d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {
return false;

if (ov::as_type_ptr<opset3::Squeeze>(input_node) ||
ov::as_type_ptr<opset3::Unsqueeze>(input_node)) {
ov::as_type_ptr<opset3::Unsqueeze>(input_node) ||
ov::as_type_ptr<opset3::Reshape>(input_node)) {
auto shape = node->get_output_shape(0);

// remove interchangeable nodes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,13 @@ TEST(nop_elimination, squeeze_reshape_elimination_check_info) {
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);

bool reshape_is_missing = true;
bool movement_are_missing = true;
for (auto node : f->get_ops()) {
if (node->get_friendly_name() == "reshape") {
reshape_is_missing = false;
ASSERT_TRUE(std::dynamic_pointer_cast<opset4::Reshape>(node));
auto original_names = ngraph::getFusedNamesVector(node);
sort(original_names.begin(), original_names.end());
ASSERT_EQ(original_names, std::vector<std::string>({"reshape", "squeeze"}));
if (node->get_friendly_name() == "reshape" || node->get_friendly_name() == "squeeze") {
movement_are_missing = false;
}
}
ASSERT_FALSE(reshape_is_missing);
ASSERT_TRUE(movement_are_missing);
}

TEST(nop_elimination, squeeze_unsqueeze_elimination) {
Expand Down

0 comments on commit b1ae35d

Please sign in to comment.