Skip to content

Commit

Permalink
[Transformations] Added interchangeable reshape elimination
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jan 31, 2022
1 parent b56fd07 commit f5ee3bb
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class TRANSFORMATIONS_API EliminateSplit;
class TRANSFORMATIONS_API EliminateSqueeze;
class TRANSFORMATIONS_API EliminateTranspose;
class TRANSFORMATIONS_API EliminateEltwise;
class TRANSFORMATIONS_API EliminateReshape;
class TRANSFORMATIONS_API NopElimination;

} // namespace pass
Expand Down Expand Up @@ -108,6 +109,16 @@ class ngraph::pass::EliminateEltwise: public ngraph::pass::MatcherPass {
EliminateEltwise();
};

/**
* @ingroup ie_transformation_common_api
* @brief EliminateReshape eliminates reshape that does nothing
*/
class ngraph::pass::EliminateReshape: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
EliminateReshape();
};

class ngraph::pass::NopElimination: public GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,37 +86,6 @@ static bool eliminate_nop(const std::shared_ptr<Node>& node) {
return false;
}

static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {
auto input = node->input_value(0);
// check if reshape is not identity op
if (input.get_partial_shape().is_dynamic() || node->get_output_partial_shape(0).is_dynamic()) {
NGRAPH_DEBUG << node << " has dynamic shapes.";
return false;
}
// remove identity op
if (input.get_shape() == node->get_output_shape(0)) {
return replace_output_update_name(node->output(0), input);
}
// eliminate redundant reshape, squeeze, or unsqueeze
auto input_node = input.get_node_shared_ptr();
if (ov::as_type_ptr<opset3::Squeeze>(input_node) ||
ov::as_type_ptr<opset3::Unsqueeze>(input_node) ||
ov::as_type_ptr<opset3::Reshape>(input_node)) {
auto shape = node->get_output_shape(0);
std::vector<int64_t> vi;
vi.assign(shape.begin(), shape.end());
auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
auto new_reshape =
make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
new_reshape->set_friendly_name(node->get_friendly_name());
copy_runtime_info({input_node, node}, new_reshape);
replace_node(node, new_reshape);
return true;
}

return false;
}

static size_t count_unknown_dims(const PartialShape& ps) {
size_t rc = 0;
if (ps.is_static()) {
Expand Down Expand Up @@ -285,7 +254,6 @@ NAME() { \
}; \
NGRAPH_RTTI_DEFINITION(NAME, STR(NAME), 0);

SIMPLE_MATCHER_PASS_DEFINITION(EliminateReshape, opset3::Reshape, eliminate_reshape_v1);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateUnsqueeze, opset3::Unsqueeze, eliminate_unsqueeze);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateBroadcast, op::v1::Broadcast, eliminate_nop);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, opset3::Gather, simplify_gather);
Expand Down Expand Up @@ -542,6 +510,56 @@ pass::EliminateEltwise::EliminateEltwise() {
this->register_matcher(m, callback);
}

NGRAPH_RTTI_DEFINITION(pass::EliminateReshape, "EliminateReshape", 0);

pass::EliminateReshape::EliminateReshape() {
MATCHER_SCOPE(EliminateReshape);
auto reshape_node_pattern = pattern::wrap_type<opset8::Reshape>();

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto node = m.get_match_root();
auto input = node->input_value(0);
// check if reshape is not identity op
if (input.get_partial_shape().is_dynamic() || node->get_output_partial_shape(0).is_dynamic()) {
NGRAPH_DEBUG << node << " has dynamic shapes.";
return false;
}
// remove identity op
if (input.get_shape() == node->get_output_shape(0)) {
return replace_output_update_name(node->output(0), input);
}
// eliminate redundant reshape, squeeze, or unsqueeze
auto input_node = input.get_node_shared_ptr();
if (ov::as_type_ptr<opset3::Squeeze>(input_node) ||
ov::as_type_ptr<opset3::Unsqueeze>(input_node) ||
ov::as_type_ptr<opset3::Reshape>(input_node)) {
auto shape = node->get_output_shape(0);

// remove interchangeable reshape, squeeze, or unsqueeze
if (input_node->get_input_partial_shape(0).is_static() && input_node->get_input_shape(0) == shape) {
node->output(0).get_node_shared_ptr()->set_friendly_name(node->get_friendly_name());
copy_runtime_info({input_node, node}, node->output(0).get_node_shared_ptr());
return replace_output_update_name(node->output(0), input_node->input_value(0));
} else {
std::vector<int64_t> vi;
vi.assign(shape.begin(), shape.end());
auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
auto new_reshape =
make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
new_reshape->set_friendly_name(node->get_friendly_name());
copy_runtime_info({input_node, node}, new_reshape);
replace_node(node, new_reshape);
return true;
}
}

return false;
};

auto m = std::make_shared<pattern::Matcher>(reshape_node_pattern, matcher_name);
this->register_matcher(m, callback);
}

NGRAPH_RTTI_DEFINITION(ngraph::pass::NopElimination, "NopElimination", 0);

ngraph::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "transformations/convert_precision.hpp"
#include "transformations/utils/utils.hpp"
#include "rnn_sequences_optimization.hpp"
#include "transformations/common_optimizations/nop_elimination.hpp"

namespace MKLDNNPlugin {

Expand All @@ -34,6 +35,7 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ngraph::Function> &nGraphF
if (!ngraph::op::util::has_op_with_type<ngraph::op::FakeQuantize>(nGraphFunc)) {
manager.register_pass<ReshapeFullyConnectedFusion>();
}
manager.register_pass<ngraph::pass::EliminateReshape>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,48 @@ TEST(nop_elimination, squeeze_reshape_elimination_check_info) {
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);

bool reshape_is_missing = true;
bool movement_are_missing = true;
for (auto node : f->get_ops()) {
if (node->get_friendly_name() == "reshape") {
reshape_is_missing = false;
ASSERT_TRUE(std::dynamic_pointer_cast<opset4::Reshape>(node));
auto original_names = ngraph::getFusedNamesVector(node);
sort(original_names.begin(), original_names.end());
ASSERT_EQ(original_names, std::vector<std::string>({"reshape", "squeeze"}));
if (node->get_friendly_name() == "reshape" || node->get_friendly_name() == "squeeze") {
movement_are_missing = false;
}
}
ASSERT_FALSE(reshape_is_missing);
ASSERT_TRUE(movement_are_missing);
}

TEST(nop_elimination, squeeze_unsqueeze_elimination) {
std::shared_ptr<Function> f;
{
auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});

auto relu = std::make_shared<opset4::Relu>(arg);
relu->set_friendly_name("relu");

auto squeeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
auto squeeze = std::make_shared<opset4::Squeeze>(relu, squeeze_axes);
squeeze->set_friendly_name("squeeze");

auto unsqueeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
auto unsqueeze = std::make_shared<opset4::Unsqueeze>(squeeze, unsqueeze_axes);
unsqueeze->set_friendly_name("unsqueeze");

auto abs = std::make_shared<opset4::Abs>(unsqueeze);

f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
}

pass::Manager pass_manager;
pass_manager.register_pass<pass::InitNodeInfo>();
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);

bool movement_are_missing = true;
for (auto node : f->get_ops()) {
if (node->get_friendly_name() == "squeeze" || node->get_friendly_name() == "unsqueeze") {
movement_are_missing = false;
}
}
ASSERT_TRUE(movement_are_missing);
}

TEST(nop_elimination, reshape_elimination_v1_dynamic) {
Expand Down

0 comments on commit f5ee3bb

Please sign in to comment.