From b6fe09b5fd2a3db42e1395f7a4a41cdcf42fcc1b Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Tue, 29 Nov 2022 13:34:25 +0400 Subject: [PATCH] Changes for MatMul --- .../include/snippets/op/matmul_cpu.hpp | 15 +++- src/common/snippets/src/op/buffer.cpp | 10 +++ src/common/snippets/src/op/matmul_cpu.cpp | 16 ++-- .../pass/fuse_transpose_and_matmul_cpu.cpp | 8 +- src/common/snippets/src/pass/insert_loops.cpp | 10 ++- .../src/emitters/jit_snippets_emitters.cpp | 40 +++++---- .../src/emitters/jit_snippets_emitters.hpp | 7 +- .../snippets/matmul.cpp | 38 ++++++++- .../shared_tests_instances/snippets/mha.cpp | 24 ++---- .../snippets/transpose_softmax.cpp | 42 ++++++++++ .../plugin/shared/include/snippets/matmul.hpp | 20 +++++ .../plugin/shared/include/snippets/mha.hpp | 15 +--- .../include/snippets/transpose_softmax.hpp | 40 +++++++++ .../plugin/shared/src/snippets/matmul.cpp | 82 ++++++++++++++++++- .../plugin/shared/src/snippets/mha.cpp | 46 ++--------- .../shared/src/snippets/transpose_softmax.cpp | 78 ++++++++++++++++++ .../include/subgraph_matmul.hpp | 38 +++++++++ .../include/subgraph_mha.hpp | 9 ++ .../src/subgraph_matmul.cpp | 51 ++++++++++++ .../src/subgraph_mha.cpp | 58 +++++++++++++ 20 files changed, 540 insertions(+), 107 deletions(-) create mode 100644 src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_softmax.cpp create mode 100644 src/tests/functional/plugin/shared/include/snippets/transpose_softmax.hpp create mode 100644 src/tests/functional/plugin/shared/src/snippets/transpose_softmax.cpp diff --git a/src/common/snippets/include/snippets/op/matmul_cpu.hpp b/src/common/snippets/include/snippets/op/matmul_cpu.hpp index f2cf941d5febdc..79e98ac03ffbcc 100644 --- a/src/common/snippets/include/snippets/op/matmul_cpu.hpp +++ b/src/common/snippets/include/snippets/op/matmul_cpu.hpp @@ -20,7 +20,7 @@ namespace op { class MatMulCPU : public ngraph::op::v0::MatMul { public: OPENVINO_OP("MatMulCPU", "SnippetsOpset", 
ngraph::op::v0::MatMul); - MatMulCPU(const Output& A, const Output& B); + MatMulCPU(const Output& A, const Output& B, size_t offset_a = 0, size_t offset_b = 0, size_t offset_c = 0); MatMulCPU() = default; bool visit_attributes(AttributeVisitor& visitor) override; @@ -29,9 +29,20 @@ class MatMulCPU : public ngraph::op::v0::MatMul { bool has_evaluate() const override { return false; } + size_t get_offset_a() const { return m_offset_a; } + size_t get_offset_b() const { return m_offset_b; } + size_t get_offset_c() const { return m_offset_c; } + + void set_offset_a(size_t offset) { m_offset_a = offset; } + void set_offset_b(size_t offset) { m_offset_b = offset; } + void set_offset_c(size_t offset) { m_offset_c = offset; } + private: - MatMulCPU(const Output& A, const Output& B, std::vector output_layout); + MatMulCPU(const Output& A, const Output& B, std::vector output_layout, size_t offset_a = 0, size_t offset_b = 0, size_t offset_c = 0); std::vector m_output_layout; + size_t m_offset_a = 0lu; + size_t m_offset_b = 0lu; + size_t m_offset_c = 0lu; }; } // namespace op diff --git a/src/common/snippets/src/op/buffer.cpp b/src/common/snippets/src/op/buffer.cpp index e207874cc18687..8a9b2ab428ed6f 100644 --- a/src/common/snippets/src/op/buffer.cpp +++ b/src/common/snippets/src/op/buffer.cpp @@ -56,6 +56,9 @@ void snippets::op::Buffer::set_offset(const size_t offset) { } if (auto store = std::dynamic_pointer_cast(parent)) { store->set_offset(m_offset); + } else if (auto matmul = std::dynamic_pointer_cast(parent)) { + // MatMul encapsulates work with Loops inside itself + matmul->set_offset_c(m_offset); } else { throw ngraph_error("Buffer::set_offset() was called when Buffer didn't have the corresponding Store op for offset propagation"); } @@ -73,6 +76,13 @@ void snippets::op::Buffer::set_offset(const size_t offset) { } } else if (const auto load = std::dynamic_pointer_cast(child)) { load->set_offset(m_offset); + } else if (auto matmul = std::dynamic_pointer_cast(child)) { 
+ // MatMul encapsulates work with Loops inside itself + if (target_input.get_index() == 0) { + matmul->set_offset_a(m_offset); + } else { + matmul->set_offset_b(m_offset); + } } else { throw ngraph_error("Buffer::set_offset() was called when Buffer didn't have the corresponding Load op for offset propagation"); } diff --git a/src/common/snippets/src/op/matmul_cpu.cpp b/src/common/snippets/src/op/matmul_cpu.cpp index 70e53623c67349..61524ad56ff1bb 100644 --- a/src/common/snippets/src/op/matmul_cpu.cpp +++ b/src/common/snippets/src/op/matmul_cpu.cpp @@ -12,14 +12,15 @@ namespace ngraph { namespace snippets { namespace op { -MatMulCPU::MatMulCPU(const Output& A, const Output& B) : MatMul(), m_output_layout({}) { +MatMulCPU::MatMulCPU(const Output& A, const Output& B, size_t offset_a, size_t offset_b, size_t offset_c) + : MatMul(), m_output_layout({}), m_offset_a(offset_a), m_offset_b(offset_b), m_offset_c(offset_c) { set_arguments({A, B}); set_output_size(1); constructor_validate_and_infer_types(); } -MatMulCPU::MatMulCPU(const Output& A, const Output& B, std::vector output_layout) - : MatMul(), m_output_layout(std::move(output_layout)) { +MatMulCPU::MatMulCPU(const Output& A, const Output& B, std::vector output_layout, size_t offset_a, size_t offset_b, size_t offset_c) + : MatMul(), m_output_layout(std::move(output_layout)), m_offset_a(offset_a), m_offset_b(offset_b), m_offset_c(offset_c) { set_arguments({A, B}); set_output_size(1); constructor_validate_and_infer_types(); @@ -27,6 +28,9 @@ MatMulCPU::MatMulCPU(const Output& A, const Output& B, std::vector MatMulCPU::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(MatMulCPU_clone_with_new_inputs); check_new_args_count(this, new_args); -// auto new_matmul = std::make_shared(new_args.at(0), new_args.at(1)); - return std::shared_ptr(new MatMulCPU(new_args.at(0), new_args.at(1), m_output_layout)); -// new_matmul->output_layout = output_layout; -// return new_matmul; -// return 
std::make_shared(new_args.at(0), new_args.at(1)); + return std::shared_ptr(new MatMulCPU(new_args.at(0), new_args.at(1), m_output_layout, m_offset_a, m_offset_b, m_offset_c)); } } // namespace op diff --git a/src/common/snippets/src/pass/fuse_transpose_and_matmul_cpu.cpp b/src/common/snippets/src/pass/fuse_transpose_and_matmul_cpu.cpp index 0843f2b86351fe..a3a62912c0e18a 100644 --- a/src/common/snippets/src/pass/fuse_transpose_and_matmul_cpu.cpp +++ b/src/common/snippets/src/pass/fuse_transpose_and_matmul_cpu.cpp @@ -49,7 +49,7 @@ FuseTransposeMatMulCPU::FuseTransposeMatMulCPU() { auto matmul_out0 = pattern::wrap_type({matmul_any, constant}); auto matmul_or_transpose = std::make_shared(OutputVector{matmul_in0, matmul_in1, matmul_out0}); - auto callback = [](pattern::Matcher& m) { + auto callback = [&transpose_is_supported](pattern::Matcher& m) { OV_ITT_SCOPED_TASK(pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::FuseTransposeMatMulCPU") auto set_layout_from_order = [](const std::shared_ptr& node, const ov::Output& port) { const auto& const_order = as_type_ptr(node->get_input_node_shared_ptr(1)); @@ -72,8 +72,10 @@ FuseTransposeMatMulCPU::FuseTransposeMatMulCPU() { for (int i = 0; i < matmul->get_input_size(); i++) { const auto& in_value = matmul->input_value(i); if (const auto& transpose = as_type_ptr(in_value.get_node_shared_ptr())) { - set_layout_from_order(transpose, transpose->input_value(0)); - matmul->set_argument(i, transpose->input_value(0)); + if (transpose_is_supported(transpose)) { + set_layout_from_order(transpose, transpose->input_value(0)); + matmul->set_argument(i, transpose->input_value(0)); + } } } // need to run validate_and_infer_types manually: either input shapes were updated or diff --git a/src/common/snippets/src/pass/insert_loops.cpp b/src/common/snippets/src/pass/insert_loops.cpp index 587922f183f15b..a481707bdcdbe5 100644 --- a/src/common/snippets/src/pass/insert_loops.cpp +++ b/src/common/snippets/src/pass/insert_loops.cpp @@ 
-48,8 +48,7 @@ std::vector InsertLoops::calculate_finalization_offsets(const ov::Parti return inner_finalization_offsets; } -void insert_explicitly_loops(const ov::NodeVector& ops, const ov::PartialShape& master_shape, - size_t inner_work_amount, size_t outer_work_amount, size_t vector_size) { +void insert_explicitly_loops(const ov::NodeVector& ops, const ov::PartialShape& master_shape, size_t vector_size) { ov::NodeVector body; ov::OutputVector body_parameters; std::vector> body_results; @@ -67,6 +66,11 @@ void insert_explicitly_loops(const ov::NodeVector& ops, const ov::PartialShape& auto apply_increments = InsertLoops::calculate_inner_apply_increments(master_shape, body_shapes); std::vector inner_finalization_offsets(body_shapes.size(), 0); + auto body_master_shape = body_shapes.front(); + for (const auto& shape : body_shapes) + PartialShape::broadcast_merge_into(body_master_shape, shape, ::ngraph::op::AutoBroadcastType::NUMPY); + const auto inner_work_amount = utils::get_inner_dim(body_master_shape).get_length(); + const auto outer_work_amount = utils::get_outer_dim(body_master_shape).get_length(); if (outer_work_amount > 1) { inner_finalization_offsets = InsertLoops::calculate_finalization_offsets(master_shape, body_shapes); } @@ -236,7 +240,7 @@ bool InsertLoops::run_on_model(const std::shared_ptr &model) { op::insertLoopEnd(commonResults, outer_loop_begin, 1lu, outer_work_amount, 1lu, apply_increments); } } else { - insert_explicitly_loops(ops, m_master_shape, inner_work_amount, outer_work_amount, m_vector_size); + insert_explicitly_loops(ops, m_master_shape, m_vector_size); } } diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp index 31ebc0858bab5b..864062a5dada48 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp @@ -1024,6 +1024,10 @@ 
MatMulEmitter::MatMulEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: } } } + + load_offset_a = matmul_node->get_offset_a(); + load_offset_b = matmul_node->get_offset_b(); + store_offset_c = matmul_node->get_offset_c(); } void MatMulEmitter::initBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, bool use_amx) const { @@ -1070,7 +1074,8 @@ void MatMulEmitter::emit_impl(const std::vector& in, template void MatMulEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, int bs, Reg64 addr_A, Reg64 addr_B, - const brgemm_batch_element_t *batch, Reg64 addr_C, void *scratch) const { + const brgemm_batch_element_t *batch, Reg64 addr_C, void *scratch, + const size_t in0_kernel_offset, const size_t in1_kernel_offset, const size_t out0_kernel_offset) const { using Vmm = typename dnnl::impl::utils::conditional3::type; size_t gpr_size = 8; Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax, @@ -1120,8 +1125,15 @@ void MatMulEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in // todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align. 
h->mov(abi_param1, reinterpret_cast(brgKernel)); h->mov(abi_param2, bs); - h->uni_vmovq(abi_param3, Xmm(0)); - h->uni_vmovq(abi_param4, Xmm(1)); + + const auto data_ptr = [&](Xmm xmm, Xbyak::Reg64 reg, size_t memory_bytes_offset, size_t kernel_bytes_offset) { + h->uni_vmovq(reg, xmm); + if (memory_bytes_offset) h->add(reg, memory_bytes_offset); + if (kernel_bytes_offset) h->add(reg, kernel_bytes_offset); + }; + data_ptr(Xmm(0), abi_param3, load_offset_a, in0_kernel_offset); + data_ptr(Xmm(1), abi_param4, load_offset_b, in1_kernel_offset); + size_t num_args_passed_on_stack = 1; #ifdef _WIN32 num_args_passed_on_stack = 3; @@ -1130,9 +1142,11 @@ void MatMulEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in h->mov(h->qword[h->rsp], reinterpret_cast(scratch)); h->mov(h->qword[h->rsp + gpr_size], reinterpret_cast(batch)); h->mov(h->qword[h->rsp + 2 * gpr_size], Xmm(2)); + if (store_offset_c) h->add(h->qword[h->rsp + 2 * gpr_size], store_offset_c); + if (out0_kernel_offset) h->add(h->qword[h->rsp + 2 * gpr_size], out0_kernel_offset); #else h->mov(abi_param5, reinterpret_cast(batch)); - h->uni_vmovq(abi_param6, Xmm(2)); + data_ptr(Xmm(2), abi_param6, store_offset_c, out0_kernel_offset); h->sub(h->rsp, gpr_size); h->mov(h->qword[h->rsp], reinterpret_cast(scratch)); #endif @@ -1194,25 +1208,17 @@ void MatMulEmitter::emit_isa(const std::vector &in, const std::vectoradd(input_0, in0_offset); - if (in1_offset != 0) - h->add(input_1, in1_offset); - if (out0_offset != 0) - h->add(output_0, out0_offset); + emit_brgemm_kernel_call(brgKernels0[getBrgIdx(mIdx, k, n)].get(), 1, input_0, input_1, nullptr, output_0, - nullptr); - if (in0_offset != 0) - h->sub(input_0, in0_offset); - if (in1_offset != 0) - h->sub(input_1, in1_offset); - if (out0_offset != 0) - h->sub(output_0, out0_offset); + nullptr, + in0_offset, + in1_offset, + out0_offset); } } } diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp 
b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp index 420c54c35f2bab..4fe823a57b4bae 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp @@ -479,7 +479,8 @@ class MatMulEmitter : public jit_emitter { template void emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, int bs, Reg64 addr_A, Reg64 addr_B, - const brgemm_batch_element_t *batch, Reg64 addr_C, void *scratch) const; + const brgemm_batch_element_t *batch, Reg64 addr_C, void *scratch, + const size_t in0_kernel_offset, const size_t in1_kernel_offset, const size_t out0_kernel_offset) const; static constexpr size_t MHA_BRGEMM_KERNELS_NUM = 8; static constexpr size_t matmulOptimalM = 32; @@ -490,6 +491,10 @@ class MatMulEmitter : public jit_emitter { size_t M, M_blk, M_tail; size_t K0, K0_blk, K0_tail, N0, N0_blk, N0_tail; size_t brg0VnniFactor; + + size_t load_offset_a = 0lu; + size_t load_offset_b = 0lu; + size_t store_offset_c = 0lu; }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp index 567b93bcd8fdb3..851969bbe91497 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp @@ -23,7 +23,7 @@ std::vector> input_shapes{ {{1, 2, 69, 43}, {2, 1, 43, 49}} }; std::vector precisions{element::f32}; -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, MatMul, +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMul, MatMul, ::testing::Combine( ::testing::ValuesIn(input_shapes), ::testing::ValuesIn(precisions), @@ -32,6 +32,42 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, MatMul, ::testing::Values(CommonTestUtils::DEVICE_CPU)), MatMul::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBias, MatMulBias, + ::testing::Combine( 
+ ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 1, 43, 49}, {1, 1, 69, 49}}), + ::testing::ValuesIn(precisions), + ::testing::Values(4), // Sinh * 3 + Subgraph; + ::testing::Values(1), // Tokenized MatMul+Bias + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MatMul::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ExplicitTransposeMatMul, ExplicitTransposeMatMul, + ::testing::Combine( + ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}}), + ::testing::ValuesIn(precisions), + ::testing::Values(3), // Sinh * 2 + Subgraph; + ::testing::Values(1), // Tokenized MatMul+Bias + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + ExplicitTransposeMatMul::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulBias, ExplicitTransposeMatMulBias, + ::testing::Combine( + ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}, {1, 1, 69, 49}}), + ::testing::ValuesIn(precisions), + ::testing::Values(4), // Sinh * 3 + Subgraph; + ::testing::Values(1), // Tokenized MatMul+Bias + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MatMul::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMulMatMulBias, ExplicitTransposeMulMatMulBias, + ::testing::Combine( + ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}, {1, 2, 1, 1}, {1, 1, 69, 49}}), + ::testing::ValuesIn(precisions), + ::testing::Values(5), // Sinh * 4 + Subgraph; + ::testing::Values(1), // Tokenized MatMul+Bias + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MatMul::getTestCaseName); + namespace transpose_zero_input { std::vector> transpose_input_shapes{ {{2, 69, 3, 43}, {2, 3, 43, 49}} diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index 5836d60fa161c6..6cb28f2c630a92 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ 
b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -12,29 +12,17 @@ namespace snippets { namespace { -const std::vector inputShape = { - ov::Shape{1, 128, 3, 16}, +const std::vector inputShapes = { + {1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}, }; -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeSoftmax, TransposeSoftmax, +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHA, ::testing::Combine( - ::testing::Values(inputShape), - ::testing::Values(std::vector{0, 2, 3, 1}), - ::testing::Values(-1), - ::testing::Values(2), // Subgraph + Sin + ::testing::Values(inputShapes), + ::testing::Values(5), // Subgraph + 4xSin ::testing::Values(1), ::testing::Values(CommonTestUtils::DEVICE_CPU)), - TransposeSoftmax::getTestCaseName); - -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeSoftmaxEltwise, TransposeSoftmaxEltwise, - ::testing::Combine( - ::testing::Values(inputShape), - ::testing::Values(std::vector{0, 2, 3, 1}), - ::testing::Values(-1), - ::testing::Values(2), // Subgraph + Sin - ::testing::Values(1), - ::testing::Values(CommonTestUtils::DEVICE_CPU)), - TransposeSoftmax::getTestCaseName); + MHA::getTestCaseName); } // namespace } // namespace snippets diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_softmax.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_softmax.cpp new file mode 100644 index 00000000000000..76dbb58f5b4644 --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_softmax.cpp @@ -0,0 +1,42 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/transpose_softmax.hpp" +#include "common_test_utils/test_constants.hpp" + +namespace ov { +namespace test { +namespace snippets { + + +namespace { + +const std::vector inputShape = { + ov::Shape{1, 128, 3, 16}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeSoftmax, 
TransposeSoftmax, + ::testing::Combine( + ::testing::Values(inputShape), + ::testing::Values(std::vector{0, 2, 3, 1}), + ::testing::Values(-1), + ::testing::Values(2), // Subgraph + Sin + ::testing::Values(1), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + TransposeSoftmax::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeSoftmaxEltwise, TransposeSoftmaxEltwise, + ::testing::Combine( + ::testing::Values(inputShape), + ::testing::Values(std::vector{0, 2, 3, 1}), + ::testing::Values(-1), + ::testing::Values(2), // Subgraph + Sin + ::testing::Values(1), + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + TransposeSoftmax::getTestCaseName); + +} // namespace +} // namespace snippets +} // namespace test +} // namespace ov \ No newline at end of file diff --git a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp index f187715eb2dc7b..9e9373132f587e 100644 --- a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp @@ -36,6 +36,26 @@ class MatMul : public testing::WithParamInterface, virtual public ov::test::SnippetsTestsCommon { public: diff --git a/src/tests/functional/plugin/shared/include/snippets/mha.hpp b/src/tests/functional/plugin/shared/include/snippets/mha.hpp index 952b7528a00375..dc89094661026f 100644 --- a/src/tests/functional/plugin/shared/include/snippets/mha.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/mha.hpp @@ -12,28 +12,21 @@ namespace snippets { typedef std::tuple< std::vector, // Input shapes - std::vector, // Transpose Order - int64_t, // Softmax Axis size_t, // Expected num nodes size_t, // Expected num subgraphs std::string // Target Device -> TransposeSoftmaxParams; +> MHAParams; -class TransposeSoftmax : public testing::WithParamInterface, - virtual public ov::test::SnippetsTestsCommon { +class MHA : public testing::WithParamInterface, + virtual 
public ov::test::SnippetsTestsCommon { public: - static std::string getTestCaseName(testing::TestParamInfo obj); + static std::string getTestCaseName(testing::TestParamInfo obj); protected: void SetUp() override; }; -class TransposeSoftmaxEltwise : public TransposeSoftmax { -protected: - void SetUp() override; -}; - } // namespace snippets } // namespace test diff --git a/src/tests/functional/plugin/shared/include/snippets/transpose_softmax.hpp b/src/tests/functional/plugin/shared/include/snippets/transpose_softmax.hpp new file mode 100644 index 00000000000000..952b7528a00375 --- /dev/null +++ b/src/tests/functional/plugin/shared/include/snippets/transpose_softmax.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "shared_test_classes/base/snippets_test_utils.hpp" + +namespace ov { +namespace test { +namespace snippets { + +typedef std::tuple< + std::vector, // Input shapes + std::vector, // Transpose Order + int64_t, // Softmax Axis + size_t, // Expected num nodes + size_t, // Expected num subgraphs + std::string // Target Device +> TransposeSoftmaxParams; + + +class TransposeSoftmax : public testing::WithParamInterface, + virtual public ov::test::SnippetsTestsCommon { +public: + static std::string getTestCaseName(testing::TestParamInfo obj); + +protected: + void SetUp() override; +}; + +class TransposeSoftmaxEltwise : public TransposeSoftmax { +protected: + void SetUp() override; +}; + + +} // namespace snippets +} // namespace test +} // namespace ov \ No newline at end of file diff --git a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp index c142d612423148..0edc41f94d068d 100644 --- a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp @@ -18,11 +18,9 @@ std::string MatMul::getTestCaseName(testing::TestParamInfo input_shapes; + 
ov::element::Type elem_type; + std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(dynamic_shapes_to_test_representation(input_shapes)); + + auto f = ov::test::snippets::MatMulBiasSinhFunction(input_shapes); + function = f.getOriginal(); + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE, + InferenceEngine::PluginConfigParams::YES}); + } +} + +void ExplicitTransposeMatMul::SetUp() { + std::vector input_shapes; + ov::element::Type elem_type; + std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(dynamic_shapes_to_test_representation(input_shapes)); + + auto f = ov::test::snippets::TransposeMatMulSinhFunction(input_shapes); + function = f.getOriginal(); + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE, + InferenceEngine::PluginConfigParams::YES}); + } +} + +void ExplicitTransposeMatMulBias::SetUp() { + std::vector input_shapes; + ov::element::Type elem_type; + std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(dynamic_shapes_to_test_representation(input_shapes)); + + auto f = ov::test::snippets::TransposeMatMulBiasSinhFunction(input_shapes); + function = f.getOriginal(); + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE, + InferenceEngine::PluginConfigParams::YES}); + } +} + +void ExplicitTransposeMulMatMulBias::SetUp() { + std::vector 
input_shapes; + ov::element::Type elem_type; + std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(dynamic_shapes_to_test_representation(input_shapes)); + + auto f = ov::test::snippets::TransposeMulMatMulBiasSinhFunction(input_shapes); + function = f.getOriginal(); + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE, + InferenceEngine::PluginConfigParams::YES}); + } +} + std::string TransposeMatMul::getTestCaseName(testing::TestParamInfo obj) { std::vector input_shapes; size_t transpose_position; @@ -84,6 +138,26 @@ TEST_P(MatMul, CompareWithRefImpl) { validateNumSubgraphs(); } +TEST_P(MatMulBias, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + +TEST_P(ExplicitTransposeMatMul, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + +TEST_P(ExplicitTransposeMatMulBias, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + +TEST_P(ExplicitTransposeMulMatMulBias, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + TEST_P(TransposeMatMul, CompareWithRefImpl) { run(); validateNumSubgraphs(); diff --git a/src/tests/functional/plugin/shared/src/snippets/mha.cpp b/src/tests/functional/plugin/shared/src/snippets/mha.cpp index 1d648c8a55f032..fa0dd1cbdb8111 100644 --- a/src/tests/functional/plugin/shared/src/snippets/mha.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/mha.cpp @@ -5,74 +5,42 @@ #include "common_test_utils/common_utils.hpp" #include "snippets/mha.hpp" #include "subgraph_mha.hpp" -#include "ngraph_functions/builders.hpp" #include "functional_test_utils/skip_tests_config.hpp" -#include "cpp_interfaces/interface/ie_internal_plugin_config.hpp" + namespace ov { namespace test { namespace snippets { -std::string TransposeSoftmax::getTestCaseName(testing::TestParamInfo obj) { 
+std::string MHA::getTestCaseName(testing::TestParamInfo obj) { std::vector inputShapes; - std::vector order; - int axis; std::string targetDevice; size_t num_nodes, num_subgraphs; - std::tie(inputShapes, order, axis, num_nodes, num_subgraphs, targetDevice) = obj.param; + std::tie(inputShapes, num_nodes, num_subgraphs, targetDevice) = obj.param; std::ostringstream result; for (size_t i = 0; i < inputShapes.size(); ++i) result << "IS[" << i << "]=" << CommonTestUtils::vec2str(inputShapes[i]) << "_"; - result << "TO=" << CommonTestUtils::vec2str(order) << "_"; - result << "Axis=" << axis << "_"; result << "#N=" << num_nodes << "_"; result << "#S=" << num_subgraphs << "_"; result << "targetDevice=" << targetDevice; return result.str(); } -void TransposeSoftmax::SetUp() { - std::vector inputShapes; - std::vector order; - int64_t axis; - std::tie(inputShapes, order, axis, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); - init_input_shapes(static_shapes_to_test_representation(inputShapes)); - - auto f = ov::test::snippets::TransposeSoftmaxFunction(inputDynamicShapes, order, axis); - function = f.getOriginal(); - - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE, InferenceEngine::PluginConfigParams::YES}); - } -} - -void TransposeSoftmaxEltwise::SetUp() { +void MHA::SetUp() { std::vector inputShapes; - std::vector order; - int64_t axis; - std::tie(inputShapes, order, axis, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::tie(inputShapes, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_shapes_to_test_representation(inputShapes)); - auto f = ov::test::snippets::TransposeSoftmaxEltwiseFunction(inputDynamicShapes, order, axis); + auto f = ov::test::snippets::MHASinFunction(inputDynamicShapes); function = 
f.getOriginal(); - - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE, InferenceEngine::PluginConfigParams::YES}); - } } -TEST_P(TransposeSoftmax, CompareWithRefImpl) { +TEST_P(MHA, CompareWithRefImpl) { run(); validateNumSubgraphs(); } -TEST_P(TransposeSoftmaxEltwise, CompareWithRefImpl) { - run(); - validateNumSubgraphs(); -} - - } // namespace snippets } // namespace test } // namespace ov diff --git a/src/tests/functional/plugin/shared/src/snippets/transpose_softmax.cpp b/src/tests/functional/plugin/shared/src/snippets/transpose_softmax.cpp new file mode 100644 index 00000000000000..c3c964a5eb5ea2 --- /dev/null +++ b/src/tests/functional/plugin/shared/src/snippets/transpose_softmax.cpp @@ -0,0 +1,78 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_test_utils/common_utils.hpp" +#include "snippets/transpose_softmax.hpp" +#include "subgraph_mha.hpp" +#include "ngraph_functions/builders.hpp" +#include "functional_test_utils/skip_tests_config.hpp" +#include "cpp_interfaces/interface/ie_internal_plugin_config.hpp" + +namespace ov { +namespace test { +namespace snippets { + +std::string TransposeSoftmax::getTestCaseName(testing::TestParamInfo obj) { + std::vector inputShapes; + std::vector order; + int axis; + std::string targetDevice; + size_t num_nodes, num_subgraphs; + std::tie(inputShapes, order, axis, num_nodes, num_subgraphs, targetDevice) = obj.param; + + std::ostringstream result; + for (size_t i = 0; i < inputShapes.size(); ++i) + result << "IS[" << i << "]=" << CommonTestUtils::vec2str(inputShapes[i]) << "_"; + result << "TO=" << CommonTestUtils::vec2str(order) << "_"; + result << "Axis=" << axis << "_"; + result << "#N=" << num_nodes << "_"; + result << "#S=" << num_subgraphs << "_"; + result << "targetDevice=" << 
targetDevice; + return result.str(); +} + +void TransposeSoftmax::SetUp() { + std::vector inputShapes; + std::vector order; + int64_t axis; + std::tie(inputShapes, order, axis, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_shapes_to_test_representation(inputShapes)); + + auto f = ov::test::snippets::TransposeSoftmaxFunction(inputDynamicShapes, order, axis); + function = f.getOriginal(); + + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE, InferenceEngine::PluginConfigParams::YES}); + } +} + +void TransposeSoftmaxEltwise::SetUp() { + std::vector inputShapes; + std::vector order; + int64_t axis; + std::tie(inputShapes, order, axis, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_shapes_to_test_representation(inputShapes)); + + auto f = ov::test::snippets::TransposeSoftmaxEltwiseFunction(inputDynamicShapes, order, axis); + function = f.getOriginal(); + + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MHA_OPS_TOKENIZATION_ENABLE, InferenceEngine::PluginConfigParams::YES}); + } +} + +TEST_P(TransposeSoftmax, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + +TEST_P(TransposeSoftmaxEltwise, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp index 4cced69e612bb4..4fa78ce0ed3928 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp +++ 
b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp @@ -34,6 +34,17 @@ class MatMulSinhFunction : public SnippetsFunctionBase { std::shared_ptr initReference() const override; }; +// Same as MatMulSinhFunction, but with a bias (third input) +class MatMulBiasSinhFunction : public SnippetsFunctionBase { +public: + explicit MatMulBiasSinhFunction(const std::vector& inputShapes) + : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; +}; + /// Minimal graph to test MatMul+Transpose combinations. Transpose location is specified via the position argument: /// 0 - before the first MatMul input; 1 - before the second MatMul input; 2 - after the MatMul output. /// Tokenized simply by starting subgraph, @@ -57,6 +68,33 @@ class Transpose0213MatMulSinhFunction : public SnippetsFunctionBase { size_t transpose_position; }; +class TransposeMatMulSinhFunction : public SnippetsFunctionBase { +public: + explicit TransposeMatMulSinhFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; +}; + +class TransposeMatMulBiasSinhFunction : public SnippetsFunctionBase { +public: + explicit TransposeMatMulBiasSinhFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; +}; + +class TransposeMulMatMulBiasSinhFunction : public SnippetsFunctionBase { +public: + explicit TransposeMulMatMulBiasSinhFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; +}; + } 
// namespace snippets } // namespace test } // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp index 0800350320d825..fba9824db3a5fb 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp @@ -23,6 +23,15 @@ class MHAFunction : public SnippetsFunctionBase { std::shared_ptr initReference() const override; }; +// Same as MHAFunction, but with Sin applied to the inputs +// TODO: Remove Sin once Snippets is able to tokenize right after Parameters +class MHASinFunction : public MHAFunction { +public: + explicit MHASinFunction(const std::vector& inputShapes) : MHAFunction(inputShapes) {} +protected: + std::shared_ptr initOriginal() const override; +}; + // TODO: Write Graph class MHAMatMul0TransposeFunction : public MHAFunction { public: diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp index e9159a0097025e..46843b0795d4b2 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp @@ -4,6 +4,7 @@ #include "subgraph_matmul.hpp" #include "common_test_utils/data_utils.hpp" +#include "ngraph_functions/builders.hpp" #include namespace ov { @@ -29,6 +30,17 @@ std::shared_ptr MatMulSinhFunction::initReference() const { ParameterVector{indata0, indata1})); return std::make_shared(NodeVector{matmul}, ParameterVector{data0, data1}); } +std::shared_ptr MatMulBiasSinhFunction::initOriginal() const { + auto data0 = std::make_shared(precision, input_shapes[0]); + auto sinh0 = std::make_shared(data0); + auto data1 = std::make_shared(precision, input_shapes[1]); + auto sinh1 = std::make_shared(data1); + auto matmul = 
std::make_shared(sinh0, sinh1); + auto data2 = std::make_shared(precision, input_shapes[2]); + auto sinh2 = std::make_shared(data2); + auto bias = std::make_shared(matmul, sinh2); + return std::make_shared(NodeVector{bias}, ParameterVector{data0, data1, data2}); +} std::shared_ptr Transpose0213MatMulSinhFunction::initOriginal() const { auto data0 = std::make_shared(precision, input_shapes[0]); auto sinh0 = std::make_shared(data0); @@ -53,6 +65,45 @@ std::shared_ptr Transpose0213MatMulSinhFunction::initOriginal() const } return std::make_shared(NodeVector{result}, ParameterVector{data0, data1}); } +std::shared_ptr TransposeMatMulSinhFunction::initOriginal() const { + auto data0 = std::make_shared(precision, input_shapes[0]); + auto data1 = std::make_shared(precision, input_shapes[1]); + auto sinh0 = std::make_shared(data0); + auto sinh1 = std::make_shared(data1); + auto const_order = std::make_shared(ov::element::i32, Shape {4}, std::vector{0, 2, 3, 1}); + auto transpose = std::make_shared(sinh1, const_order); + auto matmul = std::make_shared(sinh0, transpose); + return std::make_shared(NodeVector{matmul}, ParameterVector{data0, data1}); +} +std::shared_ptr TransposeMatMulBiasSinhFunction::initOriginal() const { + auto data0 = std::make_shared(precision, input_shapes[0]); + auto data1 = std::make_shared(precision, input_shapes[1]); + auto data2 = std::make_shared(precision, input_shapes[2]); + auto sinh0 = std::make_shared(data0); + auto sinh1 = std::make_shared(data1); + auto sinh2 = std::make_shared(data2); + auto const_order = std::make_shared(ov::element::i32, Shape {4}, std::vector{0, 2, 3, 1}); + auto transpose = std::make_shared(sinh1, const_order); + auto matmul = std::make_shared(sinh0, transpose); + auto bias = std::make_shared(matmul, sinh2); + return std::make_shared(NodeVector{bias}, ParameterVector{data0, data1, data2}); +} +std::shared_ptr TransposeMulMatMulBiasSinhFunction::initOriginal() const { + auto data0 = std::make_shared(precision, 
input_shapes[0]); + auto data1 = std::make_shared(precision, input_shapes[1]); + auto data2 = std::make_shared(precision, input_shapes[2]); + auto data3 = std::make_shared(precision, input_shapes[3]); + auto sinh0 = std::make_shared(data0); + auto sinh1 = std::make_shared(data1); + auto sinh2 = std::make_shared(data2); + auto sinh3 = std::make_shared(data3); + auto const_order = std::make_shared(ov::element::i32, Shape {4}, std::vector{0, 2, 3, 1}); + auto transpose = std::make_shared(sinh1, const_order); + auto mul = std::make_shared(transpose, sinh2); + auto matmul = std::make_shared(sinh0, mul); + auto bias = std::make_shared(matmul, sinh3); + return std::make_shared(NodeVector{bias}, ParameterVector{data0, data1, data2, data3}); +} } // namespace snippets } // namespace test } // namespace ov \ No newline at end of file diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp index 3e18c4f02aa84c..8202ef961cdca3 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp @@ -127,6 +127,64 @@ std::shared_ptr MHAFunction::initReference() const { return std::make_shared(NodeVector{subgraph}, ngraphParams); } +std::shared_ptr MHASinFunction::initOriginal() const { + auto transpose0Param = std::make_shared(precision, input_shapes[0]); + auto transpose1Param = std::make_shared(precision, input_shapes[1]); + auto addParam = std::make_shared(precision, input_shapes[2]); + auto transpose2Param = std::make_shared(precision, input_shapes[3]); + ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param}; + + auto sin0 = std::make_shared(transpose0Param); + auto sin1 = std::make_shared(transpose1Param); + auto sin2 = std::make_shared(addParam); + auto sin3 = std::make_shared(transpose2Param); + + std::vector constantShapes; + 
constantShapes.push_back(ov::Shape({input_shapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({input_shapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({1, input_shapes[1].get_shape()[2], 1, 1})); + constantShapes.push_back(ov::Shape({2})); + constantShapes.push_back(ov::Shape({4})); + constantShapes.push_back(ov::Shape({input_shapes[0].get_shape().size()})); + constantShapes.push_back(ov::Shape({input_shapes[0].get_shape().size()})); + + auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[0], std::vector{0, 2, 1, 3}); + auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[1], std::vector{0, 2, 3, 1}); + auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[5], std::vector{0, 2, 1, 3}); + auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[6], std::vector{0, 2, 1, 3}); + + std::vector mulConstData(ngraph::shape_size(constantShapes[2])); + auto mulConst = ngraph::builder::makeConstant(precision, constantShapes[2], mulConstData, true); + + std::vector reshape0ConstData = {static_cast(input_shapes[0].get_shape()[0] * + input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]), + -1}; + auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[3], reshape0ConstData); + + std::vector reshape1ConstData = {static_cast(input_shapes[0].get_shape()[0]), + static_cast(input_shapes[0].get_shape()[2]), + static_cast(input_shapes[0].get_shape()[1]), + static_cast(input_shapes[0].get_shape()[1])}; + auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[4], reshape1ConstData); + + float transA = false; + float transB = false; + const auto transpose0 = std::make_shared(sin0, transpose0Const); + const auto transpose1 = std::make_shared(sin1, transpose1Const); + const auto mul = std::make_shared(transpose1, mulConst); + const auto 
matMul0 = std::make_shared(transpose0, mul, transA, transB); + const auto add = std::make_shared(matMul0, sin2); + const auto reshape0 = std::make_shared(add, reshape0Const, true); + const auto softMax = std::make_shared(reshape0, 1); + const auto reshape1 = std::make_shared(softMax, reshape1Const, true); + const auto transpose2 = std::make_shared(sin3, transpose2Const); + const auto matMul1 = std::make_shared(reshape1, transpose2, transA, transB); + const auto transpose3 = std::make_shared(matMul1, transpose3Const); + + ngraph::ResultVector results{std::make_shared(transpose3)}; + return std::make_shared(results, ngraphParam, "mha"); +} + std::shared_ptr MHAMatMul0TransposeFunction::initOriginal() const { auto transpose0Param = std::make_shared(precision, input_shapes[0]); auto transpose1Param = std::make_shared(precision, input_shapes[1]);