Skip to content

Commit

Permalink
Added tests for transformation HandleTransposesAroundMatMul
Browse files Browse the repository at this point in the history
  • Loading branch information
elilobanova committed Jun 29, 2021
1 parent b02adb0 commit b7e2062
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 3 deletions.
3 changes: 0 additions & 3 deletions inference-engine/tests/functional/plugin/gna/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@ set(TARGET_NAME gnaFuncTests)
addIeTargetTest(
NAME ${TARGET_NAME}
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
INCLUDES
${IE_MAIN_SOURCE_DIR}/src/gna_plugin/transformations
DEPENDENCIES
GNAPlugin
LINK_LIBRARIES
funcSharedTests
GNAPlugin_test_static
ADD_CPPLINT
LABELS
GNA
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include "transformations/handle_transposes_around_matmul.hpp"

#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>

std::shared_ptr<ngraph::Function> CreateTransposeMatmulFunction(const ngraph::Shape& input_shape,
const ngraph::Shape& new_shape, const ngraph::Shape& const_shape) {
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);

auto new_shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{new_shape.size()}, new_shape);
auto reshape = std::make_shared<ngraph::opset7::Reshape>(input_params, new_shape_const, false);

auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0});
auto transpose = std::make_shared<ngraph::opset7::Transpose>(reshape, transpose_order);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{const_shape.size()}, const_shape);
auto matmul = std::make_shared<ngraph::opset7::MatMul>(transpose, constant);

auto result = std::make_shared<ngraph::opset7::Result>(matmul);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
}

std::shared_ptr<ngraph::Function> CreateMatmulFunction(const ngraph::Shape& input_shape,
const ngraph::Shape& new_shape, const ngraph::Shape& const_shape) {
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);

auto new_shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{new_shape.size()}, new_shape);
auto reshape = std::make_shared<ngraph::opset7::Reshape>(input_params, new_shape_const, false);

auto new_shape_after_transpose = ngraph::opset7::Constant::create(ngraph::element::i64,
ngraph::Shape{input_shape.size()}, {new_shape[1], new_shape[0]});
auto reshape_after_transpose = std::make_shared<ngraph::opset7::Reshape>(reshape,
new_shape_after_transpose,
false);

auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{const_shape.size()}, const_shape);
auto matmul = std::make_shared<ngraph::opset7::MatMul>(reshape_after_transpose, constant);

auto result = std::make_shared<ngraph::opset7::Result>(matmul);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
}

TEST(TransformationTests, RemoveTransposeBeforeMatmulTest) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
const ngraph::Shape data_shape{1, 8};

{
func = CreateTransposeMatmulFunction(data_shape, {2, 4}, {2, 1});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::HandleTransposesAroundMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = CreateMatmulFunction(data_shape, {2, 4}, {2, 1});

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

TEST(TransformationTests, RemoveTransposeBeforeMatmulTestReshapeInOutEq) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
const ngraph::Shape data_shape{2, 8};

{
func = CreateTransposeMatmulFunction(data_shape, {2, 8}, {8, 1});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::HandleTransposesAroundMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = CreateTransposeMatmulFunction(data_shape, {2, 8}, {8, 1});

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

0 comments on commit b7e2062

Please sign in to comment.