Skip to content

Commit

Permalink
fix transformation; add unit test (#18314)
Browse files Browse the repository at this point in the history
Co-authored-by: Ivan Tikhonov <[email protected]>
  • Loading branch information
evkotov and itikhono authored Jul 5, 2023
1 parent f313dde commit 6a0c6a1
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "openvino/op/space_to_batch.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/op/util/pad_base.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
Expand All @@ -25,7 +26,7 @@ using namespace ov::pass::transpose_sinking::utils;
namespace {

std::vector<size_t> get_indices_by_op_type(const std::shared_ptr<Node>& main_node) {
if (as_type_ptr<ov::op::v1::Pad>(main_node)) {
if (as_type_ptr<ov::op::util::PadBase>(main_node)) {
return {1, 2};
} else if (as_type_ptr<ov::op::v1::BatchToSpace>(main_node) || as_type_ptr<ov::op::v1::SpaceToBatch>(main_node)) {
return {1, 2, 3};
Expand All @@ -38,7 +39,7 @@ std::vector<size_t> get_indices_by_op_type(const std::shared_ptr<Node>& main_nod

TSDataMovementForward::TSDataMovementForward() {
MATCHER_SCOPE(TSDataMovementForward);
create_pattern<ov::op::v1::Pad, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
create_pattern<op::util::PadBase, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
true,
{0});

Expand Down Expand Up @@ -74,7 +75,7 @@ TSDataMovementBackward::TSDataMovementBackward() {
MATCHER_SCOPE(TSDataMovementBackward);

auto main_node_label =
wrap_type<ov::op::v1::Pad, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
wrap_type<op::util::PadBase, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
[](const Output<Node>& output) -> bool {
return has_static_rank()(output) && CheckTransposeConsumers(output);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,30 @@ class PadFactory : public IFactory {
public:
explicit PadFactory(const std::string& type_name) : IFactory(type_name) {}
NodePtr create(const OutputVector& parent_nodes) const override {
return std::make_shared<Pad>(parent_nodes[0], parent_nodes[1], parent_nodes[2], ov::op::PadMode::CONSTANT);
return std::make_shared<ov::op::v1::Pad>(parent_nodes[0],
parent_nodes[1],
parent_nodes[2],
ov::op::PadMode::CONSTANT);
}
};
FactoryPtr CreatePadFactory(const std::string& type_name) {
return std::make_shared<PadFactory>(type_name);
}

class Pad12Factory : public IFactory {
public:
explicit Pad12Factory(const std::string& type_name) : IFactory(type_name) {}
NodePtr create(const OutputVector& parent_nodes) const override {
return std::make_shared<ov::op::v12::Pad>(parent_nodes[0],
parent_nodes[1],
parent_nodes[2],
ov::op::PadMode::CONSTANT);
}
};
FactoryPtr CreatePad12Factory(const std::string& type_name) {
return std::make_shared<Pad12Factory>(type_name);
}

class BatchToSpaceFactory : public IFactory {
public:
explicit BatchToSpaceFactory(const std::string& type_name) : IFactory(type_name) {}
Expand Down Expand Up @@ -253,6 +270,9 @@ FactoryPtr CreateFakeQuantizeFactory(const std::string& type_name) {
#undef CREATE_PAD_FACTORY
#define CREATE_PAD_FACTORY(type_name) CreatePadFactory(#type_name)

#undef CREATE_PAD12_FACTORY
#define CREATE_PAD12_FACTORY(type_name) CreatePad12Factory(#type_name)

#undef CREATE_BATCH_TO_SPACE_FACTORY
#define CREATE_BATCH_TO_SPACE_FACTORY(type_name) CreateBatchToSpaceFactory(#type_name)

Expand Down Expand Up @@ -538,6 +558,34 @@ auto test_forward_pad = []() {

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward, TSTestFixture, test_forward_pad());

auto test_negative_forward_pad = []() {
TestCase test_case;

// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward);
test_case.num_main_ops = {1, 2};
test_case.inputs_to_main = {
parameter(element::f32, {1, 3, 55, 55}),
constant<int64_t>(element::i32, {4}, {1, -2, -3, -4}),
constant<int64_t>(element::i32, {4}, {1, -2, -3, -4}),
};

// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_PAD12_FACTORY(Pad12)};
test_case.model.model_template = create_model;

// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2}}};
test_case.model_ref.main_op = {CREATE_PAD12_FACTORY(Pad12)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = create_model;

return wrapper(test_case);
};

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonNegativePad12Forward, TSTestFixture, test_negative_forward_pad());

auto test_forward_batch_to_space = []() {
TestCase test_case;

Expand Down

0 comments on commit 6a0c6a1

Please sign in to comment.