diff --git a/src/plugins/intel_gpu/src/plugin/transformations/transpose_matmul_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/transpose_matmul_fusion.cpp index 5cdb0b75d6b787..fc085e3421764a 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/transpose_matmul_fusion.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/transpose_matmul_fusion.cpp @@ -107,8 +107,8 @@ TransposeMatMulMatcher::TransposeMatMulMatcher() { std::swap(*(order_b.end() - 1), *(order_b.end() - 2)); } - auto input_a = pattern_map.at(input_a_m).get_node_shared_ptr(); - auto input_b = pattern_map.at(input_b_m).get_node_shared_ptr(); + auto input_a = ov::Output(pattern_map.at(input_a_m).get_node_shared_ptr(), matmul->get_input_source_output(0).get_index()); + auto input_b = ov::Output(pattern_map.at(input_b_m).get_node_shared_ptr(), matmul->get_input_source_output(1).get_index()); auto gemm = std::make_shared(input_a, input_b, order_a, order_b, order_c); gemm->set_friendly_name(matmul->get_friendly_name()); @@ -175,8 +175,8 @@ TransposeMatMulTransposeMatcher::TransposeMatMulTransposeMatcher() { std::swap(*(order_b.end() - 1), *(order_b.end() - 2)); } - auto input_a = pattern_map.at(input_a_m).get_node_shared_ptr(); - auto input_b = pattern_map.at(input_b_m).get_node_shared_ptr(); + auto input_a = ov::Output(pattern_map.at(input_a_m).get_node_shared_ptr(), matmul->get_input_source_output(0).get_index()); + auto input_b = ov::Output(pattern_map.at(input_b_m).get_node_shared_ptr(), matmul->get_input_source_output(1).get_index()); auto gemm = std::make_shared(input_a, input_b, order_a, order_b, order_c); gemm->set_friendly_name(m.get_match_root()->get_friendly_name()); diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp index 96d046b7ba224f..fa56c83db19414 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp @@ -692,7 +692,7 @@ class gemm_gpu_tests: public ::testing::Test { } } - void test_transpose_matmul(bool is_caching_test) { + void test_transpose_matmul(size_t num_dims, bool is_input_dynamic, bool is_caching_test) { tests::random_generator rg; rg.set_seed(GET_SUITE_NAME); @@ -719,12 +719,43 @@ class gemm_gpu_tests: public ::testing::Test { }; auto& engine = get_test_engine(); - ov::Shape input0_shape = { BATCH_SIZE, K_SIZE, 1, M_SIZE }; - ov::Shape input1_shape = { N_SIZE, BATCH_SIZE, 1, K_SIZE }; - std::vector input0_order = {0, 2, 3, 1}; - std::vector input1_order = {1, 2, 3, 0}; - auto input0_layout = layout{ov::PartialShape::dynamic(input0_shape.size()), data_types::f32, format::bfyx}; - auto input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f32, format::bfyx}; + ov::Shape input0_shape; + ov::Shape input1_shape; + std::vector input0_order; + std::vector input1_order; + cldnn::layout input0_layout; + cldnn::layout input1_layout; + + if (num_dims == 1) { + input0_shape = { K_SIZE }; + input1_shape = { N_SIZE, K_SIZE }; + input0_order = { 0 }; + input1_order = { 1, 0 }; + } else if (num_dims == 2) { + input0_shape = { K_SIZE, M_SIZE }; + input1_shape = { N_SIZE, K_SIZE }; + input0_order = { 1, 0 }; + input1_order = { 1, 0 }; + } else if (num_dims == 3) { + input0_shape = { BATCH_SIZE, K_SIZE, M_SIZE }; + input1_shape = { N_SIZE, BATCH_SIZE, K_SIZE }; + input0_order = { 0, 2, 1 }; + input1_order = { 1, 2, 0 }; + } else if (num_dims == 4) { + input0_shape = { BATCH_SIZE, K_SIZE, 1, M_SIZE }; + input1_shape = { N_SIZE, BATCH_SIZE, 1, K_SIZE }; + input0_order = { 0, 2, 3, 1 }; + input1_order = { 1, 2, 3, 0 }; + } + + if (is_input_dynamic) { + input0_layout = layout{ov::PartialShape::dynamic(input0_shape.size()), data_types::f32, format::bfyx}; + input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f32, format::bfyx}; + } else { + input0_layout = layout{ov::PartialShape(input0_shape), data_types::f32, format::bfyx}; + input1_layout = layout{ov::PartialShape(input1_shape), data_types::f32, format::bfyx}; + } + auto input0_mem = engine.allocate_memory(layout{ov::PartialShape(input0_shape), data_types::f32, format::bfyx}); auto input1_mem = engine.allocate_memory(layout{ov::PartialShape(input1_shape), data_types::f32, format::bfyx}); @@ -750,16 +781,33 @@ class gemm_gpu_tests: public ::testing::Test { auto inst = network->get_primitive("gemm"); auto impl = inst->get_impl(); ASSERT_TRUE(impl != nullptr); - ASSERT_TRUE(impl->is_dynamic()); + ASSERT_TRUE(impl->is_dynamic() == is_input_dynamic); auto outputs = network->execute(); auto output_mem = outputs.at("gemm").get_memory(); cldnn::mem_lock output_ptr(output_mem, get_test_stream()); - ov::Shape ref_input0_shape = { BATCH_SIZE, 1, M_SIZE, K_SIZE }; - ov::Shape ref_input1_shape = { BATCH_SIZE, 1, K_SIZE, N_SIZE }; - ov::Shape ref_output_shape = { BATCH_SIZE, 1, M_SIZE, N_SIZE }; + ov::Shape ref_input0_shape; + ov::Shape ref_input1_shape; + ov::Shape ref_output_shape; + if (num_dims == 1) { + ref_input0_shape = { K_SIZE }; + ref_input1_shape = { K_SIZE, N_SIZE }; + ref_output_shape = { 1, N_SIZE }; + } else if (num_dims == 2) { + ref_input0_shape = { M_SIZE, K_SIZE }; + ref_input1_shape = { K_SIZE, N_SIZE }; + ref_output_shape = { M_SIZE, N_SIZE }; + } else if (num_dims == 3) { + ref_input0_shape = { BATCH_SIZE, M_SIZE, K_SIZE }; + ref_input1_shape = { BATCH_SIZE, K_SIZE, N_SIZE }; + ref_output_shape = { BATCH_SIZE, M_SIZE, N_SIZE }; + } else if (num_dims == 4) { + ref_input0_shape = { BATCH_SIZE, 1, M_SIZE, K_SIZE }; + ref_input1_shape = { BATCH_SIZE, 1, K_SIZE, N_SIZE }; + ref_output_shape = { BATCH_SIZE, 1, M_SIZE, N_SIZE }; + } std::vector ref_out_data; ref_out_data.resize(ov::shape_size(ref_output_shape)); @@ -798,7 +846,7 @@ class gemm_gpu_tests: public ::testing::Test { } } - void test_transpose_matmul_transpose(bool is_caching_test) { + void test_transpose_matmul_transpose(size_t num_dims, bool is_input_dynamic, bool is_caching_test) { tests::random_generator rg; rg.set_seed(GET_SUITE_NAME); @@ -825,13 +873,48 @@ class gemm_gpu_tests: public ::testing::Test { }; auto& engine = get_test_engine(); - ov::Shape input0_shape = { M_SIZE, K_SIZE, 1, BATCH_SIZE }; - ov::Shape input1_shape = { N_SIZE, 1, BATCH_SIZE, K_SIZE }; - std::vector input0_order = {3, 2, 0, 1}; - std::vector input1_order = {2, 1, 3, 0}; - std::vector output_order = {1, 0, 3, 2}; - auto input0_layout = layout{ov::PartialShape::dynamic(input0_shape.size()), data_types::f16, format::bfyx}; - auto input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f16, format::bfyx}; + ov::Shape input0_shape; + ov::Shape input1_shape; + std::vector input0_order; + std::vector input1_order; + std::vector output_order; + cldnn::layout input0_layout; + cldnn::layout input1_layout; + + if (num_dims == 1) { + input0_shape = { K_SIZE }; + input1_shape = { N_SIZE, K_SIZE }; + input0_order = { 0 }; + input1_order = { 1, 0 }; + output_order = { 0 }; + } else if (num_dims == 2) { + input0_shape = { K_SIZE, M_SIZE }; + input1_shape = { N_SIZE, K_SIZE }; + input0_order = { 1, 0 }; + input1_order = { 1, 0 }; + output_order = { 1, 0 }; + } else if (num_dims == 3) { + input0_shape = { BATCH_SIZE, K_SIZE, M_SIZE }; + input1_shape = { N_SIZE, BATCH_SIZE, K_SIZE }; + input0_order = { 0, 2, 1 }; + input1_order = { 1, 2, 0 }; + output_order = { 1, 0, 2 }; + } else if (num_dims == 4) { + input0_shape = { M_SIZE, K_SIZE, 1, BATCH_SIZE }; + input1_shape = { N_SIZE, 1, BATCH_SIZE, K_SIZE }; + input0_order = {3, 2, 0, 1}; + input1_order = {2, 1, 3, 0}; + output_order = {1, 0, 3, 2}; + } + + if (is_input_dynamic) { + input0_layout = layout{ov::PartialShape::dynamic(input0_shape.size()), data_types::f16, format::bfyx}; + input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f16, format::bfyx}; + } else { + input0_layout = layout{ov::PartialShape(input0_shape), data_types::f16, format::bfyx}; + input1_layout = layout{ov::PartialShape(input1_shape), data_types::f16, format::bfyx}; + } + auto input0_mem = engine.allocate_memory(layout{ov::PartialShape(input0_shape), data_types::f16, format::bfyx}); auto input1_mem = engine.allocate_memory(layout{ov::PartialShape(input1_shape), data_types::f16, format::bfyx}); @@ -857,17 +940,38 @@ class gemm_gpu_tests: public ::testing::Test { auto inst = network->get_primitive("gemm"); auto impl = inst->get_impl(); ASSERT_TRUE(impl != nullptr); - ASSERT_TRUE(impl->is_dynamic()); + ASSERT_TRUE(impl->is_dynamic() == is_input_dynamic); auto outputs = network->execute(); auto output_mem = outputs.at("gemm").get_memory(); cldnn::mem_lock output_ptr(output_mem, get_test_stream()); - ov::Shape ref_input0_shape = { BATCH_SIZE, 1, M_SIZE, K_SIZE }; - ov::Shape ref_input1_shape = { BATCH_SIZE, 1, K_SIZE, N_SIZE }; - ov::Shape ref_output_shape = { BATCH_SIZE, 1, M_SIZE, N_SIZE }; - ov::Shape transposed_output_shape = { 1, BATCH_SIZE, N_SIZE, M_SIZE }; + ov::Shape ref_input0_shape; + ov::Shape ref_input1_shape; + ov::Shape ref_output_shape; + ov::Shape transposed_output_shape; + if (num_dims == 1) { + ref_input0_shape = { K_SIZE }; + ref_input1_shape = { K_SIZE, N_SIZE }; + ref_output_shape = { 1, N_SIZE }; + transposed_output_shape = { N_SIZE, 1 }; + } else if (num_dims == 2) { + ref_input0_shape = { M_SIZE, K_SIZE }; + ref_input1_shape = { K_SIZE, N_SIZE }; + ref_output_shape = { M_SIZE, N_SIZE }; + transposed_output_shape = { N_SIZE, M_SIZE }; + } else if (num_dims == 3) { + ref_input0_shape = { BATCH_SIZE, M_SIZE, K_SIZE }; + ref_input1_shape = { BATCH_SIZE, K_SIZE, N_SIZE }; + ref_output_shape = { BATCH_SIZE, M_SIZE, N_SIZE }; + transposed_output_shape = { M_SIZE, BATCH_SIZE, N_SIZE }; + } else if (num_dims == 4) { + ref_input0_shape = { BATCH_SIZE, 1, M_SIZE, K_SIZE }; + ref_input1_shape = { BATCH_SIZE, 1, K_SIZE, N_SIZE }; + ref_output_shape = { BATCH_SIZE, 1, M_SIZE, N_SIZE }; + transposed_output_shape = { 1, BATCH_SIZE, N_SIZE, M_SIZE }; + } std::vector ref_out_data; ref_out_data.resize(ov::shape_size(ref_output_shape)); @@ -936,12 +1040,68 @@ TEST_F(gemm_gpu_tests, dynamic_multi_inference_different_shape) { this->test_dynamic_multi_inference_different_shape(false); } -TEST_F(gemm_gpu_tests, transpose_matmul) { - this->test_transpose_matmul(false); +TEST_F(gemm_gpu_tests, transpose_matmul_dynamic_1d) { + this->test_transpose_matmul(1, true, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_static_1d) { + this->test_transpose_matmul(1, false, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_dynamic_2d) { + this->test_transpose_matmul(2, true, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_static_2d) { + this->test_transpose_matmul(2, false, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_dynamic_3d) { + this->test_transpose_matmul(3, true, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_static_3d) { + this->test_transpose_matmul(3, false, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_dynamic_4d) { + this->test_transpose_matmul(4, true, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_static_4d) { + this->test_transpose_matmul(4, false, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_transpose_dynamic_1d) { + this->test_transpose_matmul_transpose(1, true, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_transpose_static_1d) { + this->test_transpose_matmul_transpose(1, false, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_transpose_dynamic_2d) { + this->test_transpose_matmul_transpose(2, true, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_transpose_static_2d) { + this->test_transpose_matmul_transpose(2, false, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_transpose_dynamic_3d) { + this->test_transpose_matmul_transpose(3, true, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_transpose_static_3d) { + this->test_transpose_matmul_transpose(3, false, false); +} + +TEST_F(gemm_gpu_tests, transpose_matmul_transpose_dynamic_4d) { + this->test_transpose_matmul_transpose(4, true, false); } -TEST_F(gemm_gpu_tests, transpose_matmul_transpose) { - this->test_transpose_matmul_transpose(false); +TEST_F(gemm_gpu_tests, transpose_matmul_transpose_static_4d) { + this->test_transpose_matmul_transpose(4, false, false); } INSTANTIATE_TEST_SUITE_P( @@ -2366,11 +2526,11 @@ TEST_F(gemm_gpu_tests, basic_bfyx_t2_inplace_crop_with_pad_cached) { this->test_basic_bfyx_t2_inplace_crop_with_pad(true); } -TEST_F(gemm_gpu_tests, transpose_matmul_cached) { - this->test_transpose_matmul(true); +TEST_F(gemm_gpu_tests, transpose_matmul_dynamic_4d_cached) { + this->test_transpose_matmul(4, true, true); } -TEST_F(gemm_gpu_tests, transpose_matmul_transpose_cached) { - this->test_transpose_matmul_transpose(true); +TEST_F(gemm_gpu_tests, transpose_matmul_transpose_dynamic_4d_cached) { + this->test_transpose_matmul_transpose(4, true, true); } } // namespace