Skip to content

Commit

Permalink
Optimize ShapeOf sub-graphs (#6627)
Browse files Browse the repository at this point in the history
1. Include shape sub-graphs optimization in the model optimizer nGraph pipeline
2. Extend shape sub-graph optimizations with useless Concat and useless Gather optimization
  • Loading branch information
Evgenya Stepyreva authored Jul 13, 2021
1 parent 3d13424 commit 3710c0e
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <transformations/common_optimizations/hswish_fusion.hpp>
#include <transformations/common_optimizations/convert_quantize_dequantize.hpp>
#include <transformations/common_optimizations/pad_fusion.hpp>
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);

Expand All @@ -35,6 +36,7 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>();
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();

auto common_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
common_fusions->add_matcher<ngraph::pass::SoftPlusFusion>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace pass {
class TRANSFORMATIONS_API SimplifyShapeOfSubGraph;
class TRANSFORMATIONS_API SharedShapeOf;
class TRANSFORMATIONS_API GroupedGatherElimination;
class TRANSFORMATIONS_API GatherNopElimination;

} // namespace pass
} // namespace ngraph
Expand Down Expand Up @@ -58,3 +59,13 @@ class ngraph::pass::SimplifyShapeOfSubGraph: public ngraph::pass::FunctionPass {
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};

/**
* @ingroup ie_transformation_common_api
* @brief GatherNopElimination transformation optimizes out useless Gather operations
*/
class ngraph::pass::GatherNopElimination: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
GatherNopElimination();
};
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
#include <transformations/common_optimizations/eliminate_unsqueeze_gather.hpp>
#include <transformations/utils/utils.hpp>
#include <numeric>

NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedShapeOf, "SharedShapeOf", 0);

Expand Down Expand Up @@ -49,7 +50,7 @@ ngraph::pass::GroupedGatherElimination::GroupedGatherElimination() {
MATCHER_SCOPE(GroupedGatherElimination);
auto concat_label = ngraph::pattern::wrap_type<ngraph::opset1::Concat>(pattern::rank_equals(1));

ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto concat = m.get_match_root();
OutputVector inputs = concat->input_values();
NodeVector new_ops;
Expand All @@ -62,19 +63,27 @@ ngraph::pass::GroupedGatherElimination::GroupedGatherElimination() {
++i;
continue;
} // curr and next are the same type of gather which takes data from the same source
bool is_opset1 = is_type<opset1::Gather>(curr);
auto joint_indices = ngraph::op::util::make_try_fold<opset1::Concat>(OutputVector{curr->input_value(1), next->input_value(1)}, 0);
auto new_gather = curr->clone_with_new_inputs(
{curr->input_value(0), joint_indices, ngraph::opset1::Constant::create(element::i64, {}, {0})});
std::shared_ptr<Node> new_gather;
if (is_opset1)
new_gather = register_new_node<ngraph::opset1::Gather>(
curr->input_value(0), joint_indices->output(0), ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0));
else
new_gather = register_new_node<ngraph::opset7::Gather>(
curr->input_value(0), joint_indices->output(0), ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0));
new_ops.push_back(joint_indices);
new_ops.push_back(new_gather);
inputs.erase(inputs.begin() + i);
inputs[i] = new_gather->output(0);
}
ngraph::copy_runtime_info(concat, new_ops);
if (inputs.size() == 1) // we can optimize out concat
return replace_output_update_name(concat->output(0), inputs[0]);
if (original_inputs_size > inputs.size()) {
auto new_concat = std::make_shared<opset1::Concat>(inputs, 0);
new_ops.push_back(new_concat);
new_concat->set_friendly_name(concat->get_friendly_name());
ngraph::copy_runtime_info(concat, new_ops);
ngraph::copy_runtime_info(concat, new_concat);
ngraph::replace_node(concat, new_concat);
return true;
}
Expand All @@ -85,17 +94,43 @@ ngraph::pass::GroupedGatherElimination::GroupedGatherElimination() {
this->register_matcher(m, callback);
}

NGRAPH_RTTI_DEFINITION(ngraph::pass::GatherNopElimination, "GatherNopElimination", 0);

ngraph::pass::GatherNopElimination::GatherNopElimination() {
MATCHER_SCOPE(GatherNopElimination);
const auto gather_label = ngraph::pattern::wrap_type<ngraph::op::util::GatherBase>(
{ngraph::pattern::any_input(pattern::has_static_shape()),
ngraph::pattern::wrap_type<ngraph::op::Constant>(),
ngraph::pattern::wrap_type<ngraph::op::Constant>()});

ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto gather = m.get_match_root();
const auto& number_of_indices = shape_size(gather->get_input_shape(1));
if (gather->get_input_shape(0) != gather->get_output_shape(0) || shape_size(gather->get_input_shape(2)) != 1 || number_of_indices > 10)
return false;
std::vector<int64_t> expected_vector(number_of_indices);
std::iota(expected_vector.begin(), expected_vector.end(), 0);
if (const auto& indices = get_constant_from_source(gather->input_value(1))) {
const auto& indices_values = indices->cast_vector<int64_t>();
if (indices_values != expected_vector)
return false;
}
return replace_output_update_name(gather->output(0), gather->input_value(0));
};
auto m = std::make_shared<ngraph::pattern::Matcher>(gather_label, matcher_name);
this->register_matcher(m, callback);
}


NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifyShapeOfSubGraph, "SimplifyShapeOfSubGraph", 0);

bool ngraph::pass::SimplifyShapeOfSubGraph::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(GroupedGatherElimination);
ngraph::pass::Manager manager;
manager.set_per_pass_validation(false);
manager.register_pass<ngraph::pass::EliminateGatherUnsqueeze>();
manager.register_pass<ngraph::pass::SharedShapeOf>();
manager.register_pass<ngraph::pass::GroupedGatherElimination>();
manager.register_pass<ngraph::pass::Validate>();
manager.register_pass<ngraph::pass::GatherNopElimination>();
manager.run_passes(f);
return false;
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,41 @@ TEST(TransformationTests, ShapeSubGraphTest) {
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}

TEST(TransformationTests, ShapeNopSubGraphTest) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);

PartialShape data_shape{-1, -1};
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);

auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
auto gather_1 = gather(shape_op_1, {0}, true);
auto unsqueeze_1 = std::make_shared<opset7::Unsqueeze>(
gather_1, opset7::Constant::create(element::i64, {1}, {0}));

auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
auto gather_2 = gather(shape_op_2, {1}, true);
auto unsqueeze_2 = std::make_shared<opset7::Unsqueeze>(
gather_2, opset7::Constant::create(element::i64, {1}, {0}));

auto concat = std::make_shared<opset7::Concat>(OutputVector{unsqueeze_1, unsqueeze_2}, 0);

auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
f = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifyShapeOfSubGraph>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
auto reshape = std::make_shared<opset7::Reshape>(data, shape_op_1, false);
f_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
}

auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}

0 comments on commit 3710c0e

Please sign in to comment.