From 5dfba630ac973a96b27aa294dc0e090d8fce8b4c Mon Sep 17 00:00:00 2001 From: Evgeny Kotov Date: Tue, 14 Feb 2023 14:16:15 +0100 Subject: [PATCH] wrote unit tests draft --- .../gather_sinking_transpose_reshape.cpp | 13 + .../gather_sinking_transpose_reshape.cpp | 224 +++++++++++++++++- 2 files changed, 229 insertions(+), 8 deletions(-) diff --git a/src/plugins/intel_gna/src/transformations/gather_sinking_transpose_reshape.cpp b/src/plugins/intel_gna/src/transformations/gather_sinking_transpose_reshape.cpp index 98bc72fb8c59d8..4aee4c244faf4e 100644 --- a/src/plugins/intel_gna/src/transformations/gather_sinking_transpose_reshape.cpp +++ b/src/plugins/intel_gna/src/transformations/gather_sinking_transpose_reshape.cpp @@ -3,6 +3,7 @@ // #include +#include #include "transformations/gather_sinking_transpose_reshape.hpp" @@ -45,6 +46,9 @@ NodePair SinkForward(NodePtr transpose, NodePtr reshape) { ov::replace_node(reshape, gather); + ov::copy_runtime_info({reshape}, {gather, gather_indices, gather_axis, reshape_new}); + gather->set_friendly_name(reshape->get_friendly_name()); + return std::make_pair(reshape_new, gather); } @@ -82,11 +86,17 @@ NodePair SinkBackward(NodePtr transpose, std::shared_ptr transpose_con ov::replace_node(transpose, reshape_new); + ov::copy_runtime_info({transpose}, {gather, gather_indices, gather_axis, reshape_new, reshape_const_new}); + reshape_new->set_friendly_name(transpose->get_friendly_name()); + return std::make_pair(transpose, reshape_new); } bool IsFlatten2D(const Output& output) { std::shared_ptr reshape_node = output.get_node_shared_ptr(); + if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() || + reshape_node->get_input_partial_shape(0).rank().is_dynamic()) + return false; const Shape& input_shape = reshape_node->get_input_shape(0); const Shape& output_shape = reshape_node->get_output_shape(0); return (input_shape.size() == 3 && @@ -97,6 +107,9 @@ bool IsFlatten2D(const Output& output) { bool IsUnflatten2D(const Output& output) { std::shared_ptr reshape_node = output.get_node_shared_ptr(); + if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() || + reshape_node->get_input_partial_shape(0).rank().is_dynamic()) + return false; const Shape& input_shape = reshape_node->get_input_shape(0); const Shape& output_shape = reshape_node->get_output_shape(0); return (input_shape.size() == 2 && diff --git a/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_transpose_reshape.cpp b/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_transpose_reshape.cpp index 3c847ec6a1827f..29a8236c61fe14 100644 --- a/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_transpose_reshape.cpp +++ b/src/plugins/intel_gna/tests/unit/transformations/gather_sinking_transpose_reshape.cpp @@ -79,6 +79,31 @@ void CompareOutput(std::shared_ptr function, std::shared_ptr +std::shared_ptr FindNode(std::shared_ptr model) { + for (auto op : model->get_ops()) { + auto node = as_type_ptr(op); + if (node) + return node; + } + return {}; +} + +void PrintConstant(std::shared_ptr node) { + auto constant = as_type_ptr(node); + if (!constant) + return; + auto value = constant->cast_vector(); + std::cout << "{ "; + for (int i = 0; i < value.size(); ++i) { + if (i) + std::cout << ", "; + std::cout << value[i]; + } + std::cout << " }" << std::endl; +} + } // namespace TEST(GatherSinkingTransposeReshape, ForwardSinking) { @@ -98,15 +123,47 @@ TEST(GatherSinkingTransposeReshape, ForwardSinking) { function = std::make_shared(OutputVector{result}, ParameterVector{input_params}); } - std::shared_ptr reference_function = function->clone(); + std::shared_ptr orig_function = function->clone(); ov::pass::Manager manager; manager.register_pass(); - manager.register_pass("./0before.png"); + //manager.register_pass("./0before.png"); manager.register_pass(); - manager.register_pass("./1after.png"); + //manager.register_pass("./1after.png"); manager.run_passes(function); + ASSERT_NO_THROW(check_rt_info(function)); + + CompareOutput(function, orig_function); - CompareOutput(function, reference_function); + std::shared_ptr reference_function; + { + auto input_params = std::make_shared(element::Type_t::f32, Shape{1, 3, 80}); + auto tanh0 = std::make_shared(input_params); + + auto reshape_const = std::make_shared(element::i64, Shape{2}, std::vector{1, -1}); + auto reshape = std::make_shared(tanh0, reshape_const, false); + + auto generate_indices = []() -> std::vector { + std::vector indices; + for (int i = 0; i < 80; ++i) { + indices.push_back(i); + indices.push_back(i + 80); + indices.push_back(i + 160); + } + return indices; + }; + auto gather_indices = generate_indices(); + auto gather_indices_const = std::make_shared(element::i64, Shape{gather_indices.size()}, gather_indices); + auto gather_axis_const = std::make_shared(element::i64, Shape{}, 1); + auto gather = std::make_shared(reshape, gather_indices_const, gather_axis_const); + + auto tanh1 = std::make_shared(gather); + const auto result = std::make_shared(tanh1); + reference_function = std::make_shared(OutputVector{result}, ParameterVector{input_params}); + } + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(function, reference_function); + ASSERT_TRUE(result.valid); } TEST(GatherSinkingTransposeReshape, BackwardSinking) { @@ -126,16 +183,167 @@ TEST(GatherSinkingTransposeReshape, BackwardSinking) { function = std::make_shared(OutputVector{result}, ParameterVector{input_params}); } - std::shared_ptr reference_function = function->clone(); + std::shared_ptr orig_function = function->clone(); ov::pass::Manager manager; manager.register_pass(); - manager.register_pass("./0before.png"); + //manager.register_pass("./0before.png"); manager.register_pass(); - manager.register_pass("./1after.png"); + //manager.register_pass("./1after.png"); + manager.run_passes(function); + ASSERT_NO_THROW(check_rt_info(function)); + + CompareOutput(function, orig_function); + + std::shared_ptr reference_function; + { + auto input_params = std::make_shared(element::Type_t::f32, Shape{1, 240}); + auto tanh0 = std::make_shared(input_params); + + auto generate_indices = []() -> std::vector { + std::vector indices; + for (int i = 0; i < 80; ++i) { + indices.push_back(i); + indices.push_back(i + 80); + indices.push_back(i + 160); + } + return indices; + }; + auto gather_indices = generate_indices(); + auto gather_indices_const = std::make_shared(element::i64, Shape{gather_indices.size()}, gather_indices); + auto gather_axis_const = std::make_shared(element::i64, Shape{}, 1); + auto gather = std::make_shared(tanh0, gather_indices_const, gather_axis_const); + + auto reshape_const = std::make_shared(element::i64, Shape{3}, std::vector{1, 80, 3}); + auto reshape = std::make_shared(gather, reshape_const, false); + + auto tanh1 = std::make_shared(reshape); + const auto result = std::make_shared(tanh1); + reference_function = std::make_shared(OutputVector{result}, ParameterVector{input_params}); + } + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(function, reference_function); + ASSERT_TRUE(result.valid) << result.message; +} + +TEST(GatherSinkingTransposeReshape, ForwardSinkingNoSink1) { + std::shared_ptr function; + { + auto input_params = std::make_shared(element::Type_t::f32, Shape{1, 3, 80}); + auto tanh0 = std::make_shared(input_params); + + auto transpose_order = std::make_shared(element::u64, Shape{3}, Shape{0, 2, 1}); + auto transpose = std::make_shared(tanh0, transpose_order); + + auto reshape_const = std::make_shared(element::i64, Shape{4}, std::vector{1, 3, 80, 1}); + auto reshape = std::make_shared(transpose, reshape_const, false); + + auto tanh1 = std::make_shared(reshape); + const auto result = std::make_shared(tanh1); + function = std::make_shared(OutputVector{result}, ParameterVector{input_params}); + } + + std::shared_ptr orig_function = function->clone(); + ov::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); manager.run_passes(function); + ASSERT_NO_THROW(check_rt_info(function)); - CompareOutput(function, reference_function); + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(function, orig_function); + ASSERT_TRUE(result.valid); } +TEST(GatherSinkingTransposeReshape, ForwardSinkingNoSink2) { + std::shared_ptr function; + { + auto input_params = std::make_shared(element::Type_t::f32, Shape{1, 4, 80}); + auto tanh0 = std::make_shared(input_params); + + auto transpose_order = std::make_shared(element::u64, Shape{3}, Shape{0, 2, 1}); + auto transpose = std::make_shared(tanh0, transpose_order); + + auto reshape_const = std::make_shared(element::i64, Shape{4}, std::vector{1, 2, 80, 2}); + auto reshape = std::make_shared(transpose, reshape_const, false); + + auto tanh1 = std::make_shared(reshape); + const auto result = std::make_shared(tanh1); + function = std::make_shared(OutputVector{result}, ParameterVector{input_params}); + } + + std::shared_ptr orig_function = function->clone(); + ov::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(function); + ASSERT_NO_THROW(check_rt_info(function)); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(function, orig_function); + ASSERT_TRUE(result.valid); +} + +TEST(GatherSinkingTransposeReshape, BackwardSinkingNoSink1) { + std::shared_ptr function; + { + auto input_params = std::make_shared(element::Type_t::f32, Shape{1, 240}); + auto tanh0 = std::make_shared(input_params); + + auto reshape_const = std::make_shared(element::i64, Shape{4}, std::vector{1, 3, 80, 1}); + auto reshape = std::make_shared(tanh0, reshape_const, false); + + auto transpose_order = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 1, 3}); + auto transpose = std::make_shared(reshape, transpose_order); + + auto tanh1 = std::make_shared(transpose); + const auto result = std::make_shared(tanh1); + function = std::make_shared(OutputVector{result}, ParameterVector{input_params}); + } + + std::shared_ptr orig_function = function->clone(); + ov::pass::Manager manager; + manager.register_pass(); + //manager.register_pass("./0before.png"); + manager.register_pass(); + //manager.register_pass("./1after.png"); + manager.run_passes(function); + ASSERT_NO_THROW(check_rt_info(function)); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(function, orig_function); + ASSERT_TRUE(result.valid) << result.message; +} + +TEST(GatherSinkingTransposeReshape, BackwardSinkingNoSink2) { + std::shared_ptr function; + { + auto input_params = std::make_shared(element::Type_t::f32, Shape{1, 320}); + auto tanh0 = std::make_shared(input_params); + + auto reshape_const = std::make_shared(element::i64, Shape{4}, std::vector{1, 2, 80, 2}); + auto reshape = std::make_shared(tanh0, reshape_const, false); + + auto transpose_order = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 1, 3}); + auto transpose = std::make_shared(reshape, transpose_order); + + auto tanh1 = std::make_shared(transpose); + const auto result = std::make_shared(tanh1); + function = std::make_shared(OutputVector{result}, ParameterVector{input_params}); + } + + std::shared_ptr orig_function = function->clone(); + ov::pass::Manager manager; + manager.register_pass(); + //manager.register_pass("./0before.png"); + manager.register_pass(); + //manager.register_pass("./1after.png"); + manager.run_passes(function); + ASSERT_NO_THROW(check_rt_info(function)); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(function, orig_function); + ASSERT_TRUE(result.valid) << result.message; +} } // namespace testing