Skip to content

Commit

Permalink
EliminatePad and PadFusion support Pad12 positive indexes (#18278)
Browse files Browse the repository at this point in the history
* update opset5::Pad -> PadBase

* rewrite unit tests

* refactor unit tests

* add unit test NegativePadElimination

* fix add destructor

* clang fixes

* fix unit tests

* add unit tests

* bug fix

---------

Co-authored-by: Ivan Tikhonov <[email protected]>
  • Loading branch information
evkotov and itikhono authored Jul 5, 2023
1 parent 8a76f4e commit f313dde
Show file tree
Hide file tree
Showing 3 changed files with 689 additions and 449 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <ngraph/util.hpp>
#include <numeric>
#include <openvino/core/validation_util.hpp>
#include <openvino/op/util/pad_base.hpp>
#include <openvino/opsets/opset3.hpp>
#include <openvino/opsets/opset7.hpp>
#include <openvino/opsets/opset8.hpp>
Expand Down Expand Up @@ -315,7 +316,7 @@ SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, simplify_gather, opset3::Gather,

pass::EliminatePad::EliminatePad() {
MATCHER_SCOPE(EliminatePad);
auto pad_node_pattern = pattern::wrap_type<opset8::Pad>();
auto pad_node_pattern = pattern::wrap_type<op::util::PadBase>();

matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto pad = m.get_match_root();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/validation_util.hpp>
#include <openvino/op/util/pad_base.hpp>
#include <openvino/opsets/opset5.hpp>
#include <vector>

Expand All @@ -17,7 +18,7 @@
using namespace ov;

template <typename T>
static bool can_be_fused(const std::shared_ptr<opset5::Pad>& pad,
static bool can_be_fused(const std::shared_ptr<op::util::PadBase>& pad,
const std::shared_ptr<T>& node,
const std::shared_ptr<Node>& pad_value_node,
const std::shared_ptr<opset5::Constant>& pads_begin,
Expand Down Expand Up @@ -96,14 +97,14 @@ pass::PadFusionAvgPool::PadFusionAvgPool() {
auto pads_end_pattern = pattern::wrap_type<opset5::Constant>();
auto pad_value_pattern = pattern::any_input();
auto pad_node_pattern =
pattern::wrap_type<opset5::Pad>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
pattern::consumers_count(1));
pattern::wrap_type<op::util::PadBase>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
pattern::consumers_count(1));
auto avg_pool_pattern = pattern::wrap_type<opset5::AvgPool>({pad_node_pattern});

matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_value_map();
auto data = pattern_map[data_pattern];
auto pad = std::dynamic_pointer_cast<opset5::Pad>(pattern_map[pad_node_pattern].get_node_shared_ptr());
auto pad = std::dynamic_pointer_cast<op::util::PadBase>(pattern_map[pad_node_pattern].get_node_shared_ptr());
auto pad_value = pattern_map[pad_value_pattern].get_node_shared_ptr();
auto pads_begin =
std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pads_begin_pattern].get_node_shared_ptr());
Expand Down Expand Up @@ -196,15 +197,16 @@ pass::PadFusionConvolution::PadFusionConvolution() {
auto pads_end_pattern = pattern::wrap_type<opset5::Constant>();
auto pad_value_pattern = pattern::any_input();
auto pad_node_pattern =
pattern::wrap_type<opset5::Pad>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
pattern::consumers_count(1));
pattern::wrap_type<op::util::PadBase>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
pattern::consumers_count(1));
auto conv_pattern = pattern::wrap_type<opset5::Convolution>({pad_node_pattern, filter_pattern});

matcher_pass_callback callback = [=](pattern::Matcher& m) {
// FIXME(review): leftover debug print — remove before merging. None of the sibling
// PadFusion* callbacks in this file log anything, and std::endl forces a stream
// flush on every match attempt of this pass.
std::cout << "[EMUTEX DEBUG] CHECKPOINT PadFusionConvolution" << std::endl;
// Pull the matched nodes out of the pattern map. The pad node is downcast to the
// opset-agnostic op::util::PadBase (per this commit, so Pad-12 is matched too).
auto pattern_map = m.get_pattern_value_map();
auto data = pattern_map[data_pattern];
auto filter = pattern_map[filter_pattern];
auto pad = std::dynamic_pointer_cast<op::util::PadBase>(pattern_map[pad_node_pattern].get_node_shared_ptr());
auto pad_value = pattern_map[pad_value_pattern].get_node_shared_ptr();
auto pad_value = pattern_map[pad_value_pattern].get_node_shared_ptr();
auto pads_begin =
std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pads_begin_pattern].get_node_shared_ptr());
Expand Down Expand Up @@ -243,15 +245,15 @@ pass::PadFusionConvolutionBackpropData::PadFusionConvolutionBackpropData() {
auto pads_end_pattern = pattern::wrap_type<opset5::Constant>();
auto pad_value_pattern = pattern::any_input();
auto pad_node_pattern =
pattern::wrap_type<opset5::Pad>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
pattern::consumers_count(1));
pattern::wrap_type<op::util::PadBase>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
pattern::consumers_count(1));
auto conv_pattern = pattern::wrap_type<opset5::ConvolutionBackpropData>({pad_node_pattern, filter_pattern});

matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_value_map();
auto data = pattern_map[data_pattern];
auto filter = pattern_map[filter_pattern];
auto pad = std::dynamic_pointer_cast<opset5::Pad>(pattern_map[pad_node_pattern].get_node_shared_ptr());
auto pad = std::dynamic_pointer_cast<op::util::PadBase>(pattern_map[pad_node_pattern].get_node_shared_ptr());
auto pad_value = pattern_map[pad_value_pattern].get_node_shared_ptr();
auto pads_begin =
std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pads_begin_pattern].get_node_shared_ptr());
Expand Down Expand Up @@ -301,15 +303,15 @@ pass::PadFusionGroupConvolution::PadFusionGroupConvolution() {
auto pads_end_pattern = pattern::wrap_type<opset5::Constant>();
auto pad_value_pattern = pattern::any_input();
auto pad_node_pattern =
pattern::wrap_type<opset5::Pad>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
pattern::consumers_count(1));
pattern::wrap_type<op::util::PadBase>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
pattern::consumers_count(1));
auto conv_pattern = pattern::wrap_type<opset5::GroupConvolution>({pad_node_pattern, filter_pattern});

matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_value_map();
auto data = pattern_map[data_pattern];
auto filter = pattern_map[filter_pattern];
auto pad = std::dynamic_pointer_cast<opset5::Pad>(pattern_map[pad_node_pattern].get_node_shared_ptr());
auto pad = std::dynamic_pointer_cast<op::util::PadBase>(pattern_map[pad_node_pattern].get_node_shared_ptr());
auto pad_value = pattern_map[pad_value_pattern].get_node_shared_ptr();
auto pads_begin =
std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pads_begin_pattern].get_node_shared_ptr());
Expand Down Expand Up @@ -349,15 +351,15 @@ pass::PadFusionGroupConvolutionBackpropData::PadFusionGroupConvolutionBackpropDa
auto pads_end_pattern = pattern::wrap_type<opset5::Constant>();
auto pad_value_pattern = pattern::any_input();
auto pad_node_pattern =
pattern::wrap_type<opset5::Pad>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
pattern::consumers_count(1));
pattern::wrap_type<op::util::PadBase>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
pattern::consumers_count(1));
auto conv_pattern = pattern::wrap_type<opset5::GroupConvolutionBackpropData>({pad_node_pattern, filter_pattern});

matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_value_map();
auto data = pattern_map[data_pattern];
auto filter = pattern_map[filter_pattern];
auto pad = std::dynamic_pointer_cast<opset5::Pad>(pattern_map[pad_node_pattern].get_node_shared_ptr());
auto pad = std::dynamic_pointer_cast<op::util::PadBase>(pattern_map[pad_node_pattern].get_node_shared_ptr());
auto pad_value = pattern_map[pad_value_pattern].get_node_shared_ptr();
auto pads_begin =
std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pads_begin_pattern].get_node_shared_ptr());
Expand Down
Loading

0 comments on commit f313dde

Please sign in to comment.