Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

additional tests for HandleTransposesAroundMatmul were added #1

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include <numeric>

namespace handle_transpose_before_matmul {

std::shared_ptr<ngraph::Function> CreateTransposeMatmulFunction(const ngraph::Shape& input_shape,
const ngraph::Shape& new_shape, const ngraph::Shape& const_shape) {
Expand All @@ -21,7 +24,10 @@ std::shared_ptr<ngraph::Function> CreateTransposeMatmulFunction(const ngraph::Sh

auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0});
auto transpose = std::make_shared<ngraph::opset7::Transpose>(reshape, transpose_order);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{const_shape.size()}, const_shape);

std::vector<size_t> data(ngraph::shape_size(const_shape));
std::iota(std::begin(data), std::end(data), 1);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, const_shape, data);
auto matmul = std::make_shared<ngraph::opset7::MatMul>(transpose, constant);

auto result = std::make_shared<ngraph::opset7::Result>(matmul);
Expand All @@ -41,27 +47,94 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(const ngraph::Shape& inpu
new_shape_after_transpose,
false);

auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{const_shape.size()}, const_shape);
std::vector<size_t> data(ngraph::shape_size(const_shape));
std::iota(std::begin(data), std::end(data), 1);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, const_shape, data);
auto matmul = std::make_shared<ngraph::opset7::MatMul>(reshape_after_transpose, constant);

auto result = std::make_shared<ngraph::opset7::Result>(matmul);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
}

} // namespace handle_transpose_before_matmul

namespace handle_transpose_after_matmul {

std::shared_ptr<ngraph::Function> CreateTransposeMatmulFunction(const ngraph::Shape& input_shape,
const ngraph::Shape& matmul_shape, const ngraph::Shape& reshape_shape, bool create_reshape_after_transpose) {
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);

std::vector<size_t> data(ngraph::shape_size(matmul_shape));
std::iota(std::begin(data), std::end(data), 1);
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
auto matmul = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
const auto matmul_output_shape = matmul->get_output_shape(0);
std::cout << "matmul=[" << matmul_output_shape[0] << ", " << matmul_output_shape[1] << "]\n";

auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0});
auto transpose = std::make_shared<ngraph::opset7::Transpose>(matmul, transpose_order);
const auto transpose_output_shape = transpose->get_output_shape(0);
std::cout << "transpose=[" << transpose_output_shape[0] << ", " << transpose_output_shape[1] << "]\n";

std::shared_ptr<ngraph::opset7::Reshape> reshape;
auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
if (create_reshape_after_transpose) {
const auto matmul_output_shape = matmul->get_output_shape(0);
auto reshape_after_transpose_const = ngraph::opset7::Constant::create(ngraph::element::i64,
ngraph::Shape{matmul_output_shape.size()}, matmul_output_shape);
auto reshape_after_transpose = std::make_shared<ngraph::opset7::Reshape>(transpose, reshape_after_transpose_const, false);
reshape = std::make_shared<ngraph::opset7::Reshape>(reshape_after_transpose, shape_const, false);
} else {
reshape = std::make_shared<ngraph::opset7::Reshape>(transpose, shape_const, false);
const auto reshape_output_shape = reshape->get_output_shape(0);
std::cout << "reshape=[" << reshape_output_shape[0] << ", " << reshape_output_shape[1] << "]\n";
}

auto result = std::make_shared<ngraph::opset7::Result>(reshape);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
}

std::shared_ptr<ngraph::Function> CreateMatmulFunction(const ngraph::Shape& input_shape,
const ngraph::Shape& matmul_shape, const ngraph::Shape& reshape_shape, bool create_reshape_instead_of_transpose) {
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);

std::vector<size_t> data(ngraph::shape_size(matmul_shape));
std::iota(std::begin(data), std::end(data), 1);
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
auto matmul = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);

std::shared_ptr<ngraph::opset7::Reshape> reshape;
auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
if (create_reshape_instead_of_transpose) {
const auto matmul_output_shape = matmul->get_output_shape(0);
auto reshape_instead_of_transpose_const = ngraph::opset7::Constant::create(ngraph::element::i64,
ngraph::Shape{matmul_output_shape.size()}, {matmul_output_shape[1], matmul_output_shape[0]});
auto reshape_instead_of_transpose = std::make_shared<ngraph::opset7::Reshape>(matmul, reshape_instead_of_transpose_const, false);
reshape = std::make_shared<ngraph::opset7::Reshape>(reshape_instead_of_transpose, shape_const, false);
} else {
reshape = std::make_shared<ngraph::opset7::Reshape>(matmul, shape_const, false);
}

auto result = std::make_shared<ngraph::opset7::Result>(reshape);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params});
}

} // namespace handle_transpose_after_matmul

TEST(TransformationTests, RemoveTransposeBeforeMatmulTest) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
const ngraph::Shape data_shape{1, 8};

{
func = CreateTransposeMatmulFunction(data_shape, {2, 4}, {2, 1});
func = handle_transpose_before_matmul::CreateTransposeMatmulFunction(data_shape, {2, 4}, {2, 1});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::HandleTransposesAroundMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = CreateMatmulFunction(data_shape, {2, 4}, {2, 1});
reference_func = handle_transpose_before_matmul::CreateMatmulFunction(data_shape, {2, 4}, {2, 1});

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
Expand All @@ -73,17 +146,93 @@ TEST(TransformationTests, RemoveTransposeBeforeMatmulTestReshapeInOutEq) {
const ngraph::Shape data_shape{2, 8};

{
func = CreateTransposeMatmulFunction(data_shape, {2, 8}, {8, 1});
func = handle_transpose_before_matmul::CreateTransposeMatmulFunction(data_shape, {2, 8}, {2, 5});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::HandleTransposesAroundMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = handle_transpose_before_matmul::CreateTransposeMatmulFunction(data_shape, {2, 8}, {2, 5});

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);

{
func = handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {2, 16}, false);
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::HandleTransposesAroundMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = handle_transpose_after_matmul::CreateTransposeMatmulFunction({4, 1}, {1, 8}, {2, 16}, true);

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

TEST(TransformationTests, RemoveTransposeAfterMatmulTest) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);

{
func = handle_transpose_after_matmul::CreateTransposeMatmulFunction({4, 1}, {1, 8}, {2, 16}, false);
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::HandleTransposesAroundMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {2, 16}, true);

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

TEST(TransformationTests, RemoveTransposeAfterMatmulTestReshapeInOutEq) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);

{
func = handle_transpose_after_matmul::CreateTransposeMatmulFunction({4, 1}, {1, 8}, {8, 4}, false);
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::HandleTransposesAroundMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = handle_transpose_after_matmul::CreateTransposeMatmulFunction({4, 1}, {1, 8}, {8, 4}, false);

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}

TEST(TransformationTests, InsertTransposeAfterMatmulTestReshapeInOutEq) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);

{
func = handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {4, 8}, false);
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::HandleTransposesAroundMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}

reference_func = CreateTransposeMatmulFunction(data_shape, {2, 8}, {8, 1});
reference_func = handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {4, 8}, false);

const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}
}