From 15843ab855ccd3896d379df061a16ae335462977 Mon Sep 17 00:00:00 2001 From: Evgeny Kotov Date: Fri, 30 Jun 2023 13:28:48 +0200 Subject: [PATCH] fix transformation; add unit test --- .../transpose_sinking/ts_data_movement.cpp | 7 +-- .../transpose_sinking/ts_common_test.cpp | 50 ++++++++++++++++++- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp index 1f4aed8e724bdc..1f6986dcaff322 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_data_movement.cpp @@ -12,6 +12,7 @@ #include "openvino/op/space_to_batch.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/util/op_types.hpp" +#include "openvino/op/util/pad_base.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/util/common_util.hpp" #include "transformations/rt_info/transpose_sinking_attr.hpp" @@ -25,7 +26,7 @@ using namespace ov::pass::transpose_sinking::utils; namespace { std::vector get_indices_by_op_type(const std::shared_ptr& main_node) { - if (as_type_ptr(main_node)) { + if (as_type_ptr(main_node)) { return {1, 2}; } else if (as_type_ptr(main_node) || as_type_ptr(main_node)) { return {1, 2, 3}; @@ -38,7 +39,7 @@ std::vector get_indices_by_op_type(const std::shared_ptr& main_nod TSDataMovementForward::TSDataMovementForward() { MATCHER_SCOPE(TSDataMovementForward); - create_pattern( + create_pattern( true, {0}); @@ -74,7 +75,7 @@ TSDataMovementBackward::TSDataMovementBackward() { MATCHER_SCOPE(TSDataMovementBackward); auto main_node_label = - wrap_type( + wrap_type( [](const Output& output) -> bool { return has_static_rank()(output) && CheckTransposeConsumers(output); }); diff --git a/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp b/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp index 7a7a483c1d58e5..336a89462c57d0 100644 --- a/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp @@ -114,13 +114,30 @@ class PadFactory : public IFactory { public: explicit PadFactory(const std::string& type_name) : IFactory(type_name) {} NodePtr create(const OutputVector& parent_nodes) const override { - return std::make_shared(parent_nodes[0], parent_nodes[1], parent_nodes[2], ov::op::PadMode::CONSTANT); + return std::make_shared(parent_nodes[0], + parent_nodes[1], + parent_nodes[2], + ov::op::PadMode::CONSTANT); } }; FactoryPtr CreatePadFactory(const std::string& type_name) { return std::make_shared(type_name); } +class Pad12Factory : public IFactory { +public: + explicit Pad12Factory(const std::string& type_name) : IFactory(type_name) {} + NodePtr create(const OutputVector& parent_nodes) const override { + return std::make_shared(parent_nodes[0], + parent_nodes[1], + parent_nodes[2], + ov::op::PadMode::CONSTANT); + } +}; +FactoryPtr CreatePad12Factory(const std::string& type_name) { + return std::make_shared(type_name); +} + class BatchToSpaceFactory : public IFactory { public: explicit BatchToSpaceFactory(const std::string& type_name) : IFactory(type_name) {} @@ -253,6 +270,9 @@ FactoryPtr CreateFakeQuantizeFactory(const std::string& type_name) { #undef CREATE_PAD_FACTORY #define CREATE_PAD_FACTORY(type_name) CreatePadFactory(#type_name) +#undef CREATE_PAD12_FACTORY +#define CREATE_PAD12_FACTORY(type_name) CreatePad12Factory(#type_name) + #undef CREATE_BATCH_TO_SPACE_FACTORY #define CREATE_BATCH_TO_SPACE_FACTORY(type_name) CreateBatchToSpaceFactory(#type_name) @@ -538,6 +558,34 @@ auto test_forward_pad = []() { INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward, TSTestFixture, test_forward_pad()); +auto test_negative_forward_pad = []() { + TestCase test_case; + + // Initialize common attributes + test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward); + test_case.num_main_ops = {1, 2}; + test_case.inputs_to_main = { + parameter(element::f32, {1, 3, 55, 55}), + constant(element::i32, {4}, {1, -2, -3, -4}), + constant(element::i32, {4}, {1, -2, -3, -4}), + }; + + // Test model description: + test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}}; + test_case.model.main_op = {CREATE_PAD12_FACTORY(Pad12)}; + test_case.model.model_template = create_model; + + // Reference model description: + test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2}}}; + test_case.model_ref.main_op = {CREATE_PAD12_FACTORY(Pad12)}; + test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; + test_case.model_ref.model_template = create_model; + + return wrapper(test_case); +}; + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonNegativePad12Forward, TSTestFixture, test_negative_forward_pad()); + auto test_forward_batch_to_space = []() { TestCase test_case;