Skip to content

Commit

Permalink
Unit tests to check fuse_transpose_matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanNovoselov committed Dec 1, 2022
1 parent be9ac81 commit 38bd8cc
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 10 deletions.
3 changes: 3 additions & 0 deletions src/common/snippets/include/snippets/snippets_isa_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

// SnippetS dialect
NGRAPH_OP(Load, ngraph::snippets::op)
NGRAPH_OP(LoopBegin, ngraph::snippets::op)
NGRAPH_OP(LoopEnd, ngraph::snippets::op)
NGRAPH_OP(Brgemm, ngraph::snippets::op)
NGRAPH_OP(BroadcastLoad, ngraph::snippets::op)

NGRAPH_OP(Store, ngraph::snippets::op)
Expand Down
32 changes: 32 additions & 0 deletions src/common/snippets/tests/include/pass/fuse_transpose_brgemm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "lowering_utils.hpp"
#include "snippets_helpers.hpp"

/* The main purpose is to test that FuseTransposeBrgemm properly fuses 0213 Transposes on both inputs, as well as on output
*/

namespace ov {
namespace test {
namespace snippets {

// Parameters for the FuseTransposeBrgemm lowering test.
typedef std::tuple<
std::vector<PartialShape>, // Input shapes
size_t // Transpose position: 0 - before input 0, 1 - before input 1, 2 - after the MatMul output
> fuseTransposeBrgemmParams;

// Lowers a MatMul + 0213-Transpose subgraph and checks that FuseTransposeBrgemm
// folds the Transpose into the Brgemm op (represented as a "Layout" rt_info attribute
// in the lowered reference — see the matching .cpp for details).
class FuseTransposeBrgemmTests : public LoweringTests, public testing::WithParamInterface<fuseTransposeBrgemmParams> {
public:
// Produces a readable test name from the input shapes and transpose position.
static std::string getTestCaseName(testing::TestParamInfo<fuseTransposeBrgemmParams> obj);
protected:
void SetUp() override;
// Builder that supplies both the original model and the expected lowered reference.
std::shared_ptr<SnippetsFunctionBase> snippets_function;
};

} // namespace snippets
} // namespace test
} // namespace ov
1 change: 1 addition & 0 deletions src/common/snippets/tests/src/lowering_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ DummyTargetMachine::DummyTargetMachine() {
jitters[ngraph::snippets::op::Kernel::get_type_info_static()] = dummy_functor;
jitters[ngraph::snippets::op::LoopBegin::get_type_info_static()] = dummy_functor;
jitters[ngraph::snippets::op::LoopEnd::get_type_info_static()] = dummy_functor;
jitters[ngraph::snippets::op::Brgemm::get_type_info_static()] = dummy_functor;
}

void LoweringTests::SetUp() {
Expand Down
56 changes: 56 additions & 0 deletions src/common/snippets/tests/src/pass/fuse_transpose_brgemm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>
#include "pass/fuse_transpose_brgemm.hpp"
#include "common_test_utils/common_utils.hpp"
#include "subgraph_matmul.hpp"
#include "subgraph_lowered.hpp"

namespace ov {
namespace test {
namespace snippets {

// Builds a human-readable test name from the parameter tuple:
// both input shapes plus the position of the fused Transpose.
std::string FuseTransposeBrgemmTests::getTestCaseName(testing::TestParamInfo<fuseTransposeBrgemmParams> obj) {
    const auto& [shapes, position] = obj.param;
    std::ostringstream name;
    name << "IS[0]=" << CommonTestUtils::partialShape2str({shapes[0]}) << "_"
         << "IS[1]=" << CommonTestUtils::partialShape2str({shapes[1]}) << "_"
         << "Pos=" << position << "_";
    return name.str();
}

// Per-test initialization: run the base fixture setup, then build the
// function pair (original + lowered reference) from the test parameters.
void FuseTransposeBrgemmTests::SetUp() {
    LoweringTests::SetUp();
    const auto& [shapes, position] = this->GetParam();
    snippets_function = std::make_shared<Transpose0213MatMulSinhLoweredFunction>(shapes, position);
}

// Lower the original Transpose+MatMul model through the dummy generator and
// compare the resulting subgraph body against the hand-written lowered
// reference (the comparison itself is performed by the fixture's TearDown).
TEST_P(FuseTransposeBrgemmTests, FuseTransposeMatmul) {
auto subgraph = getLoweredSubgraph(snippets_function->getOriginal(), master_shape);
function = subgraph->get_body();
function_ref = snippets_function->getLowered();
}

namespace FuseTransposeBrgemmTestsInstantiation {
using ov::Shape;
// Each entry: {input shapes}, transpose position
// (0 - Transpose on input 0, 1 - on input 1, 2 - on the MatMul output).
std::vector<fuseTransposeBrgemmParams> test_params{
{{{1, 49, 2, 23}, {2, 2, 23, 39}}, 0},
{{{1, 2, 49, 23}, {2, 23, 1, 39}}, 1},
{{{1, 2, 49, 23}, {2, 2, 23, 39}}, 2},
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_FuseTransposeMatMul, FuseTransposeBrgemmTests,
::testing::ValuesIn(test_params),
FuseTransposeBrgemmTests::getTestCaseName);

} // namespace FuseTransposeBrgemmTestsInstantiation
} // namespace snippets
} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
#include "snippets_helpers.hpp"
#include "subgraph_simple.hpp"
#include "subgraph_converts.hpp"
#include "subgraph_matmul.hpp"

/* This file provides lowered representations (after the generate() was calles) for some simple functions.
/* This file provides lowered representations (after the generate() was called) for some simple functions.
* This is required to test snippets lowering and optimization passes. All the functions are expected to be direct
* descendants of SnippetsFunctionCustomizable (defined here) and one of the SnippetsFunctionBase derived classes
* (declared in subgraph_simple.hpp). Note that the corresponding SnippetsFunctionBase child should use virtual inheritance
Expand Down Expand Up @@ -51,6 +52,16 @@ class EltwiseThreeInputsLoweredFunction : public EltwiseThreeInputsFunction {
std::vector<Shape> broadcast_shapes;
};

// Lowered counterpart of Transpose0213MatMulSinhFunction: provides the expected
// graph after FuseTransposeBrgemm has run. Constructed with insert_guard = false
// because the lowered body contains no Sinh guard ops after the inputs.
class Transpose0213MatMulSinhLoweredFunction : public Transpose0213MatMulSinhFunction {
public:
explicit Transpose0213MatMulSinhLoweredFunction(const std::vector<PartialShape>& inputShapes, size_t position = 0) :
Transpose0213MatMulSinhFunction(inputShapes, position, false) {
}

protected:
// Builds the reference lowered model (Brgemm with a "Layout" rt_info attribute).
std::shared_ptr<ov::Model> initLowered() const override;
};

} // namespace snippets
} // namespace test
} // namespace ov
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ class MatMulSinhFunction : public SnippetsFunctionBase {
// todo: remove Sinh once "no subgraph after input" limitation is relaxed
class Transpose0213MatMulSinhFunction : public SnippetsFunctionBase {
public:
explicit Transpose0213MatMulSinhFunction(const std::vector<PartialShape>& inputShapes, size_t position = 0)
: SnippetsFunctionBase(inputShapes), transpose_position(position) {
explicit Transpose0213MatMulSinhFunction(const std::vector<PartialShape>& inputShapes, size_t position = 0,
bool insert_guard = true)
: SnippetsFunctionBase(inputShapes), transpose_position(position), insert_guard(insert_guard) {
NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes");
NGRAPH_CHECK(input_shapes[0].rank().get_length() == 4 && input_shapes[1].rank().get_length() == 4,
"Only rank 4 input shapes are supported by this test");
Expand All @@ -55,6 +56,7 @@ class Transpose0213MatMulSinhFunction : public SnippetsFunctionBase {
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
size_t transpose_position;
bool insert_guard; // true if Sinh ops should be inserted after inputs
};

} // namespace snippets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,24 @@ std::shared_ptr<ov::Model> EltwiseThreeInputsLoweredFunction::initLowered() cons
}
return model;
}

// Builds the expected lowered model: a Brgemm whose fused 0213 Transpose is
// represented via the "Layout" runtime-info attribute on the transposed port.
std::shared_ptr<ov::Model> Transpose0213MatMulSinhLoweredFunction::initLowered() const {
    ParameterVector data{std::make_shared<op::v0::Parameter>(precision, input_shapes[0]),
                         std::make_shared<op::v0::Parameter>(precision, input_shapes[1])};
    std::vector<size_t> transposed_order{0, 2, 1, 3};
    // Note: validity of transpose_position values is checked in the Transpose0213MatMulSinhFunction constructor
    if (transpose_position < 2) {
        // Transpose fused on an input: mark the corresponding Parameter with the layout
        data[transpose_position]->get_rt_info()["Layout"] = transposed_order;
    }
    auto brgemm = std::make_shared<ngraph::snippets::op::Brgemm>(data[0], data[1]);
    if (transpose_position == 2) {
        // Transpose fused on the output: mark the Brgemm itself, then re-run shape
        // inference so the output shape reflects the new layout
        brgemm->get_rt_info()["Layout"] = transposed_order;
        brgemm->validate_and_infer_types();
    }
    return std::make_shared<ov::Model>(NodeVector{brgemm}, data);
}
} // namespace snippets
} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,22 @@ std::shared_ptr<ov::Model> MatMulSinhFunction::initReference() const {
}
std::shared_ptr<ov::Model> Transpose0213MatMulSinhFunction::initOriginal() const {
auto data0 = std::make_shared<op::v0::Parameter>(precision, input_shapes[0]);
auto sinh0 = std::make_shared<ov::op::v0::Sinh>(data0);
auto data0_guarded = insert_guard ? std::make_shared<ov::op::v0::Sinh>(data0)->output(0) : data0->output(0);
auto data1 = std::make_shared<op::v0::Parameter>(precision, input_shapes[1]);
auto sinh1 = std::make_shared<ov::op::v0::Sinh>(data1);
auto data1_guarded = insert_guard ? std::make_shared<ov::op::v0::Sinh>(data1)->output(0) : data1->output(0);
auto const_order = std::make_shared<op::v0::Constant>(ov::element::i32, Shape {4}, std::vector<int>{0, 2, 1, 3});
std::shared_ptr<Node> result;
switch (transpose_position) {
case 0: {
auto transpose = std::make_shared<op::v1::Transpose>(sinh0, const_order);
result = std::make_shared<op::v0::MatMul>(transpose, sinh1);
auto transpose = std::make_shared<op::v1::Transpose>(data0_guarded, const_order);
result = std::make_shared<op::v0::MatMul>(transpose, data1_guarded);
break;
} case 1: {
auto transpose = std::make_shared<op::v1::Transpose>(sinh1, const_order);
result = std::make_shared<op::v0::MatMul>(sinh0, transpose);
auto transpose = std::make_shared<op::v1::Transpose>(data1_guarded, const_order);
result = std::make_shared<op::v0::MatMul>(data0_guarded, transpose);
break;
} case 2: {
auto matmul = std::make_shared<op::v0::MatMul>(sinh0, sinh1);
auto matmul = std::make_shared<op::v0::MatMul>(data0_guarded, data1_guarded);
result = std::make_shared<op::v1::Transpose>(matmul, const_order);
break;
}
Expand Down

0 comments on commit 38bd8cc

Please sign in to comment.