Skip to content

Commit

Permalink
Applied comments #2
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jan 31, 2022
1 parent f5ee3bb commit 7810194
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class TRANSFORMATIONS_API EliminateSplit;
class TRANSFORMATIONS_API EliminateSqueeze;
class TRANSFORMATIONS_API EliminateTranspose;
class TRANSFORMATIONS_API EliminateEltwise;
class TRANSFORMATIONS_API EliminateReshape;
class TRANSFORMATIONS_API NopElimination;

} // namespace pass
Expand Down Expand Up @@ -109,16 +108,6 @@ class ngraph::pass::EliminateEltwise: public ngraph::pass::MatcherPass {
EliminateEltwise();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief EliminateReshape eliminates reshape that does nothing: it removes
 * identity Reshape ops (output shape equal to input shape) from the graph.
 */
class ngraph::pass::EliminateReshape: public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    EliminateReshape();
};

class ngraph::pass::NopElimination: public GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TRANSFORMATIONS_API ReshapeSequenceFusion;

/**
* @ingroup ie_transformation_common_api
* @brief ReshpaeSequenceFusion fuses sequence of Reshape operation into single Reshape
* @brief ReshapeSequenceFusion fuses a sequence of Reshape operations into a single Reshape or eliminates a fully redundant sequence
*/

class ngraph::pass::ReshapeSequenceFusion: public ngraph::pass::MatcherPass {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,46 @@ static bool eliminate_nop(const std::shared_ptr<Node>& node) {
return false;
}

// Removes identity Reshape ops and folds a Reshape that follows a
// Squeeze/Unsqueeze producer into a single node. Returns true when the
// graph was modified.
static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {
    const auto in_value = node->input_value(0);

    // The shape-equality checks below require fully static shapes on both sides.
    if (in_value.get_partial_shape().is_dynamic() || node->get_output_partial_shape(0).is_dynamic()) {
        NGRAPH_DEBUG << node << " has dynamic shapes.";
        return false;
    }

    // Identity reshape: output shape equals input shape — bypass the node.
    if (in_value.get_shape() == node->get_output_shape(0)) {
        return replace_output_update_name(node->output(0), in_value);
    }

    // Try folding with the producing node; only single-output producers qualify.
    const auto producer = in_value.get_node_shared_ptr();
    if (producer->get_output_size() != 1)
        return false;

    const bool producer_is_squeeze_like =
        ov::as_type_ptr<opset3::Squeeze>(producer) != nullptr ||
        ov::as_type_ptr<opset3::Unsqueeze>(producer) != nullptr;
    if (!producer_is_squeeze_like)
        return false;

    const auto target_shape = node->get_output_shape(0);

    // The pair cancels out when the producer's input already has the final shape.
    if (producer->get_input_partial_shape(0).is_static() && producer->get_input_shape(0) == target_shape) {
        return replace_output_update_name(node->output(0), producer->input_value(0));
    }

    // Otherwise collapse Squeeze/Unsqueeze + Reshape into one Reshape to the
    // final static shape.
    std::vector<int64_t> target_dims(target_shape.begin(), target_shape.end());
    auto shape_pattern = opset3::Constant::create<int64_t>(element::i64, Shape{target_dims.size()}, target_dims);
    auto folded_reshape = make_shared<opset3::Reshape>(producer->input_value(0), shape_pattern, false);
    folded_reshape->set_friendly_name(node->get_friendly_name());
    copy_runtime_info({producer, node}, folded_reshape);
    replace_node(node, folded_reshape);
    return true;
}

static size_t count_unknown_dims(const PartialShape& ps) {
size_t rc = 0;
if (ps.is_static()) {
Expand Down Expand Up @@ -254,6 +294,7 @@ NAME() { \
}; \
NGRAPH_RTTI_DEFINITION(NAME, STR(NAME), 0);

SIMPLE_MATCHER_PASS_DEFINITION(EliminateReshape, opset3::Reshape, eliminate_reshape_v1);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateUnsqueeze, opset3::Unsqueeze, eliminate_unsqueeze);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateBroadcast, op::v1::Broadcast, eliminate_nop);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, opset3::Gather, simplify_gather);
Expand Down Expand Up @@ -510,56 +551,6 @@ pass::EliminateEltwise::EliminateEltwise() {
this->register_matcher(m, callback);
}

// RTTI registration for the standalone EliminateReshape matcher pass.
NGRAPH_RTTI_DEFINITION(pass::EliminateReshape, "EliminateReshape", 0);

// Matcher pass that removes identity Reshape ops and folds a Reshape preceded
// by a Squeeze/Unsqueeze/Reshape producer into a single Reshape.
pass::EliminateReshape::EliminateReshape() {
    MATCHER_SCOPE(EliminateReshape);
    // Match any opset8 Reshape node.
    auto reshape_node_pattern = pattern::wrap_type<opset8::Reshape>();

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto node = m.get_match_root();
        auto input = node->input_value(0);
        // Shape-equality checks below require fully static shapes on both sides.
        if (input.get_partial_shape().is_dynamic() || node->get_output_partial_shape(0).is_dynamic()) {
            NGRAPH_DEBUG << node << " has dynamic shapes.";
            return false;
        }
        // remove identity op: output shape equals input shape, so bypass the reshape
        if (input.get_shape() == node->get_output_shape(0)) {
            return replace_output_update_name(node->output(0), input);
        }
        // eliminate redundant reshape, squeeze, or unsqueeze feeding this reshape
        auto input_node = input.get_node_shared_ptr();
        if (ov::as_type_ptr<opset3::Squeeze>(input_node) ||
            ov::as_type_ptr<opset3::Unsqueeze>(input_node) ||
            ov::as_type_ptr<opset3::Reshape>(input_node)) {
            auto shape = node->get_output_shape(0);

            // remove interchangeable reshape, squeeze, or unsqueeze:
            // the producer's input already has the final shape, so the pair cancels
            if (input_node->get_input_partial_shape(0).is_static() && input_node->get_input_shape(0) == shape) {
                // NOTE(review): node->output(0).get_node_shared_ptr() is `node`
                // itself, so these two calls rename the node to its own name and
                // merge the producer's runtime info into the node being eliminated
                // on the next line — confirm this is the intended target.
                node->output(0).get_node_shared_ptr()->set_friendly_name(node->get_friendly_name());
                copy_runtime_info({input_node, node}, node->output(0).get_node_shared_ptr());
                return replace_output_update_name(node->output(0), input_node->input_value(0));
            } else {
                // Collapse producer + reshape into one Reshape to the final
                // static shape, preserving the matched node's friendly name.
                std::vector<int64_t> vi;
                vi.assign(shape.begin(), shape.end());
                auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
                auto new_reshape =
                    make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
                new_reshape->set_friendly_name(node->get_friendly_name());
                copy_runtime_info({input_node, node}, new_reshape);
                replace_node(node, new_reshape);
                return true;
            }
        }

        return false;
    };

    auto m = std::make_shared<pattern::Matcher>(reshape_node_pattern, matcher_name);
    this->register_matcher(m, callback);
}

NGRAPH_RTTI_DEFINITION(ngraph::pass::NopElimination, "NopElimination", 0);

ngraph::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,15 @@ ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() {
input = node->input_value(0);
}

reshape->input(0).replace_source_output(input);
copy_runtime_info(nodes, reshape);
// remove redundant reshapes
if (input.get_node_shared_ptr()->get_output_partial_shape(0).is_static() && reshape->get_output_partial_shape(0).is_static() &&
input.get_node_shared_ptr()->get_output_shape(0) == reshape->get_output_shape(0)) {
return replace_output_update_name(reshape->output(0), input);
} else {
reshape->input(0).replace_source_output(input);
copy_runtime_info(nodes, reshape);
}

return false;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "transformations/convert_precision.hpp"
#include "transformations/utils/utils.hpp"
#include "rnn_sequences_optimization.hpp"
#include "transformations/common_optimizations/nop_elimination.hpp"
#include "transformations/common_optimizations/reshape_sequence_fusion.hpp"

namespace MKLDNNPlugin {

Expand All @@ -35,7 +35,7 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ngraph::Function> &nGraphF
if (!ngraph::op::util::has_op_with_type<ngraph::op::FakeQuantize>(nGraphFunc)) {
manager.register_pass<ReshapeFullyConnectedFusion>();
}
manager.register_pass<ngraph::pass::EliminateReshape>();
manager.register_pass<ngraph::pass::ReshapeSequenceFusion>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,17 @@ TEST(nop_elimination, squeeze_reshape_elimination_check_info) {
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);

bool movement_are_missing = true;
bool reshape_is_missing = true;
for (auto node : f->get_ops()) {
if (node->get_friendly_name() == "reshape" || node->get_friendly_name() == "squeeze") {
movement_are_missing = false;
if (node->get_friendly_name() == "reshape") {
reshape_is_missing = false;
ASSERT_TRUE(std::dynamic_pointer_cast<opset4::Reshape>(node));
auto original_names = ngraph::getFusedNamesVector(node);
sort(original_names.begin(), original_names.end());
ASSERT_EQ(original_names, std::vector<std::string>({"reshape", "squeeze"}));
}
}
ASSERT_TRUE(movement_are_missing);
ASSERT_FALSE(reshape_is_missing);
}

TEST(nop_elimination, squeeze_unsqueeze_elimination) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,21 @@ TEST_F(TransformationTestsF, ReshapeSequenceFusionNeg4) {
manager.register_pass<pass::ReshapeSequenceFusion>();
}
}

// Two back-to-back Reshapes that restore the original {1, 2, 3} shape form a
// fully redundant sequence; ReshapeSequenceFusion is expected to drop both,
// leaving Relu connected directly to the result.
TEST_F(TransformationTestsF, ReshapeSequenceFusionEliminate) {
    {
        auto param = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 2, 3});
        auto activation = std::make_shared<opset6::Relu>(param);
        auto collapsed = reshape(activation, {2, 3});
        auto restored = reshape(collapsed, {1, 2, 3});
        function = std::make_shared<Function>(OutputVector{restored}, ParameterVector{param});

        manager.register_pass<pass::ReshapeSequenceFusion>();
    }

    {
        auto param = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 2, 3});
        auto activation = std::make_shared<opset6::Relu>(param);
        function_ref = std::make_shared<Function>(OutputVector{activation}, ParameterVector{param});
    }
}

0 comments on commit 7810194

Please sign in to comment.