Skip to content

Commit

Permalink
applied comments #3
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Feb 9, 2022
1 parent 597375d commit 153fd5c
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ class TRANSFORMATIONS_API ReshapeSequenceFusion;
class ngraph::pass::ReshapeSequenceFusion: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ReshapeSequenceFusion();
ReshapeSequenceFusion(bool use_shape_for_elimination = true);
};
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@ static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {

// eliminate redundant reshape, squeeze, or unsqueeze
auto input_node = input.get_node_shared_ptr();
if (input_node->get_output_size() != 1)
return false;

if (ov::as_type_ptr<opset3::Squeeze>(input_node) ||
ov::as_type_ptr<opset3::Unsqueeze>(input_node) ||
ov::as_type_ptr<opset3::Reshape>(input_node)) {
if (input_node->get_output_target_inputs(0).size() != 1)
return false;

auto shape = node->get_output_shape(0);

// remove interchangeable nodes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ bool has_valid_pattern(const ov::Output<ov::Node>& node_out) {
}
} // namespace

ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() {
ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion(bool use_shape_for_elimination) {
MATCHER_SCOPE(ReshapeSequenceFusion);
auto reshape_input = pattern::any_input();
auto reshape_a_pattern = pattern::wrap_type<opset8::Constant>();
Expand Down Expand Up @@ -88,10 +88,14 @@ ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() {
}

// remove redundant reshapes
if (input.get_node_shared_ptr()->get_output_partial_shape(0).is_static() && reshape->get_output_partial_shape(0).is_static() &&
input.get_node_shared_ptr()->get_output_shape(0) == reshape->get_output_shape(0)) {
return replace_output_update_name(reshape->output(0), input);
} else {
bool replaced = false;
if (use_shape_for_elimination && input.get_partial_shape().is_static() && reshape->get_output_partial_shape(0).is_static() &&
input.get_shape() == reshape->get_output_shape(0)) {
// in case if elimination is not allowed we still can eliminate all transposes except last one
replaced = replace_output_update_name(reshape->output(0), input);
}

if (!replaced) {
reshape->input(0).replace_source_output(input);
copy_runtime_info(nodes, reshape);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,33 @@ TEST(nop_elimination, reshape_elimination_v1_dynamic) {
ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 1);
}

TEST(nop_elimination, reshape_elimination_v1_check_consumer_count) {
std::shared_ptr<Function> f;
{
auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});

auto reshape_1_shape = opset4::Constant::create(element::i64, Shape{2}, {128, 3});
auto reshape_1 = std::make_shared<opset4::Reshape>(arg, reshape_1_shape, false);
reshape_1->set_friendly_name("reshape_1");

auto reshape_2_shape = opset4::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3});
auto reshape_2 = std::make_shared<opset4::Reshape>(reshape_1, reshape_2_shape, false);
reshape_2->set_friendly_name("reshape_2");

auto relu = std::make_shared<opset4::Relu>(reshape_1);
relu->set_friendly_name("relu");

f = std::make_shared<Function>(NodeVector{reshape_2, relu}, ParameterVector{arg});
}

pass::Manager pass_manager;
pass_manager.register_pass<pass::InitNodeInfo>();
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);

ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 2);
}

TEST(nop_elimination, concat_elimination_single_node) {
int64_t a = 0;
auto A = make_shared<op::Parameter>(element::f32, Shape{2, 3});
Expand Down

0 comments on commit 153fd5c

Please sign in to comment.