diff --git a/inference-engine/tests/functional/plugin/gna/CMakeLists.txt b/inference-engine/tests/functional/plugin/gna/CMakeLists.txt index 2d86efcc770b09..168780ca8346e4 100644 --- a/inference-engine/tests/functional/plugin/gna/CMakeLists.txt +++ b/inference-engine/tests/functional/plugin/gna/CMakeLists.txt @@ -7,13 +7,10 @@ set(TARGET_NAME gnaFuncTests) addIeTargetTest( NAME ${TARGET_NAME} ROOT ${CMAKE_CURRENT_SOURCE_DIR} - INCLUDES - ${IE_MAIN_SOURCE_DIR}/src/gna_plugin/transformations DEPENDENCIES GNAPlugin LINK_LIBRARIES funcSharedTests - GNAPlugin_test_static ADD_CPPLINT LABELS GNA diff --git a/inference-engine/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp b/inference-engine/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp new file mode 100644 index 00000000000000..532bf70383c4b3 --- /dev/null +++ b/inference-engine/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp @@ -0,0 +1,89 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "transformations/handle_transposes_around_matmul.hpp" + +#include "common_test_utils/ngraph_test_utils.hpp" +#include +#include +#include +#include + +std::shared_ptr CreateTransposeMatmulFunction(const ngraph::Shape& input_shape, + const ngraph::Shape& new_shape, const ngraph::Shape& const_shape) { + auto input_params = std::make_shared(ngraph::element::i64, input_shape); + + auto new_shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{new_shape.size()}, new_shape); + auto reshape = std::make_shared(input_params, new_shape_const, false); + + auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0}); + auto transpose = std::make_shared(reshape, transpose_order); + auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{const_shape.size()}, const_shape); + auto matmul = std::make_shared(transpose, constant); + + auto result = std::make_shared(matmul); + return std::make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params}); +} + +std::shared_ptr CreateMatmulFunction(const ngraph::Shape& input_shape, + const ngraph::Shape& new_shape, const ngraph::Shape& const_shape) { + auto input_params = std::make_shared(ngraph::element::i64, input_shape); + + auto new_shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{new_shape.size()}, new_shape); + auto reshape = std::make_shared(input_params, new_shape_const, false); + + auto new_shape_after_transpose = ngraph::opset7::Constant::create(ngraph::element::i64, + ngraph::Shape{input_shape.size()}, {new_shape[1], new_shape[0]}); + auto reshape_after_transpose = std::make_shared(reshape, + new_shape_after_transpose, + false); + + auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{const_shape.size()}, const_shape); + auto matmul = std::make_shared(reshape_after_transpose, constant); + + auto result = std::make_shared(matmul); + return std::make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params}); +} + +TEST(TransformationTests, RemoveTransposeBeforeMatmulTest) { + std::shared_ptr func(nullptr), reference_func(nullptr); + const ngraph::Shape data_shape{1, 8}; + + { + func = CreateTransposeMatmulFunction(data_shape, {2, 4}, {2, 1}); + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + reference_func = CreateMatmulFunction(data_shape, {2, 4}, {2, 1}); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +TEST(TransformationTests, RemoveTransposeBeforeMatmulTestReshapeInOutEq) { + std::shared_ptr func(nullptr), reference_func(nullptr); + const ngraph::Shape data_shape{2, 8}; + + { + func = CreateTransposeMatmulFunction(data_shape, {2, 8}, {8, 1}); + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + reference_func = CreateTransposeMatmulFunction(data_shape, {2, 8}, {8, 1}); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} \ No newline at end of file