Skip to content

Commit

Permalink
Remove broadcasting (#3574)
Browse files Browse the repository at this point in the history
* Remove broadcusting

* Refactoring some code

* Add unit tests

* Update description

* Refactoring transformation

* Add is_broadcastable_shapes checks

* Update is_eliminate_broadcast func

* Add unit tests

* Update unit tests

* Add unit tests

* Add unit tests

* Remove unused include

* Add dynemic tests

* Update unit tests

* Fix code style

* Fix unit tests code style

* Fix code style

* Add one more case for elumenate broadcast

* Fix according to review

* Refactoring transformation code
  • Loading branch information
iimironov authored Dec 22, 2020
1 parent b6bba5d commit b17e0d4
Show file tree
Hide file tree
Showing 5 changed files with 377 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <transformations_visibility.hpp>

#include <ngraph/ngraph.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API BroadcastElementwiseFusion;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief Removing Broadcast OP before ElementWise if output shape of Broadcast
* are equal neighboring input shape of ElementWise.
*/

class ngraph::pass::BroadcastElementwiseFusion: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
BroadcastElementwiseFusion();
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"

#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::BroadcastElementwiseFusion, "BroadcastElementwiseFusion", 0);

bool is_eliminate_broadcast(const ngraph::PartialShape & input_shape, const ngraph::PartialShape & broadcast_shape) {
if (input_shape.rank().is_dynamic() || broadcast_shape.rank().is_dynamic()) {
return false;
}

const int64_t & input_shape_rank = input_shape.rank().get_length();
const int64_t & broadcast_shape_rank = broadcast_shape.rank().get_length();
if (broadcast_shape_rank > input_shape_rank) {
//We can not eliminate broadcast op because
//in the case input_shape will be broadcasted
return false;
}
for (int64_t i_dim = input_shape_rank - 1, b_dim = broadcast_shape_rank - 1; i_dim >= 0 && b_dim >=0; --i_dim, --b_dim) {
if (input_shape[i_dim].is_static() && broadcast_shape[b_dim].is_static()) {
const auto &input_shape_dim = input_shape[i_dim].get_length();
const auto &broadcast_shape_dim = broadcast_shape[b_dim].get_length();
if (input_shape_dim != broadcast_shape_dim && broadcast_shape_dim != 1) {
//We can not eliminate broadcast op because
//input_shape will be broadcast
return false;
}
} else if (input_shape[i_dim].is_dynamic() && broadcast_shape[i_dim].is_static() &&
broadcast_shape[i_dim].get_length() != 1) {
return false;
} else if (broadcast_shape[i_dim].is_dynamic() && input_shape[i_dim].is_static() &&
input_shape[i_dim].get_length() == 1) {
return false;
} else if (broadcast_shape[i_dim].is_dynamic() && input_shape[i_dim].is_dynamic()) {
return false;
}
}
return true;
}

ngraph::pass::BroadcastElementwiseFusion::BroadcastElementwiseFusion() {
auto broadcast_input = pattern::any_input();
auto broadcast = pattern::wrap_type<ngraph::opset5::Broadcast>({broadcast_input, pattern::any_input()});
auto eltwise_input = pattern::any_input();
auto eltwise = pattern::wrap_type<op::util::BinaryElementwiseArithmetic>({eltwise_input, broadcast});

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto & pattern_value = m.get_pattern_value_map();

const auto & m_eltwise_input = pattern_value.at(eltwise_input);
const auto & m_eltwise = pattern_value.at(eltwise_input);

const auto & m_broadcast_input = pattern_value.at(broadcast_input);
auto & m_broadcast = pattern_value.at(broadcast);

if (!is_eliminate_broadcast(m_eltwise_input.get_partial_shape(),
m_broadcast.get_partial_shape())) {
return false;
}

copy_runtime_info(m_broadcast.get_node_shared_ptr(), m_eltwise.get_node_shared_ptr());
m_broadcast.replace(m_broadcast_input);

return false;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(eltwise, "BroadcastElementwiseFusion");
register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "transformations/init_node_info.hpp"
#include "transformations/itt.hpp"
#include "transformations/common_optimizations/algebraic_simplification.hpp"
#include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"
#include "transformations/common_optimizations/nop_elimination.hpp"
#include "transformations/common_optimizations/common_optimizations.hpp"
#include "transformations/common_optimizations/conv_mul_fusion.hpp"
Expand Down Expand Up @@ -62,6 +63,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
manager.register_pass<ngraph::pass::AlgebraicSimplification>(); // may introduce fake dynamism
manager.register_pass<ngraph::pass::BroadcastElementwiseFusion>();
manager.register_pass<ngraph::pass::NopElimination>(); // may introduce fake dynamism
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF
Expand Down
Loading

0 comments on commit b17e0d4

Please sign in to comment.