[GPU] Use explicit output indexes in the tr_matmul_tr transformation pass (#22769)

### Details:
- Some layers (e.g. VariadicSplit) do not have a default output index. This PR fixes the tr_matmul_tr transformation pass to use explicit output indexes (see the sketch below).
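
For context, a minimal sketch of the failure mode (not code from this PR; it assumes the public OpenVINO core API, and the function and variable names are illustrative). Constructing an `ov::Output<Node>` from a bare `shared_ptr` goes through the node's default output, which multi-output ops such as VariadicSplit deliberately do not define, whereas an explicit index always selects the intended output:

```cpp
// Illustrative sketch, not code from this PR. Assumes the public
// OpenVINO core API (ov::Node, ov::Output).
#include <memory>
#include "openvino/core/node.hpp"
#include "openvino/core/node_output.hpp"

// `producer` is e.g. a VariadicSplit whose output 1 feeds input 0 of `matmul`.
ov::Output<ov::Node> pick_matmul_input(const std::shared_ptr<ov::Node>& producer,
                                       const std::shared_ptr<ov::Node>& matmul) {
    // Implicit conversion asks for Node::get_default_output(); ops like
    // VariadicSplit override the default output index to throw instead:
    // ov::Output<ov::Node> broken = producer;   // fails for VariadicSplit

    // The fix: recover the index of the output that actually feeds the
    // MatMul and construct the Output with that index explicitly.
    const size_t idx = matmul->get_input_source_output(0).get_index();
    return ov::Output<ov::Node>(producer, idx);
}
```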

### Tickets:
- 131014
- 131032
e-ddykim authored Feb 11, 2024
1 parent 6accdd7 commit e57c827
Showing 2 changed files with 196 additions and 36 deletions.
@@ -107,8 +107,8 @@ TransposeMatMulMatcher::TransposeMatMulMatcher() {
 std::swap(*(order_b.end() - 1), *(order_b.end() - 2));
 }

-auto input_a = pattern_map.at(input_a_m).get_node_shared_ptr();
-auto input_b = pattern_map.at(input_b_m).get_node_shared_ptr();
+auto input_a = ov::Output<Node>(pattern_map.at(input_a_m).get_node_shared_ptr(), matmul->get_input_source_output(0).get_index());
+auto input_b = ov::Output<Node>(pattern_map.at(input_b_m).get_node_shared_ptr(), matmul->get_input_source_output(1).get_index());

 auto gemm = std::make_shared<op::Gemm>(input_a, input_b, order_a, order_b, order_c);
 gemm->set_friendly_name(matmul->get_friendly_name());
@@ -175,8 +175,8 @@ TransposeMatMulTransposeMatcher::TransposeMatMulTransposeMatcher() {
 std::swap(*(order_b.end() - 1), *(order_b.end() - 2));
 }

-auto input_a = pattern_map.at(input_a_m).get_node_shared_ptr();
-auto input_b = pattern_map.at(input_b_m).get_node_shared_ptr();
+auto input_a = ov::Output<Node>(pattern_map.at(input_a_m).get_node_shared_ptr(), matmul->get_input_source_output(0).get_index());
+auto input_b = ov::Output<Node>(pattern_map.at(input_b_m).get_node_shared_ptr(), matmul->get_input_source_output(1).get_index());

 auto gemm = std::make_shared<op::Gemm>(input_a, input_b, order_a, order_b, order_c);
 gemm->set_friendly_name(m.get_match_root()->get_friendly_name());
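To make the scenario concrete, a hypothetical graph that exercises this path (plain OpenVINO opset calls; the shapes and names are made up for illustration, not taken from the PR): a non-zero output of VariadicSplit feeds the MatMul through a Transpose, so the matcher must preserve output index 1 when it rewires the inputs onto the fused Gemm.

```cpp
// Hypothetical reproducer, not part of this PR: MatMul consuming
// VariadicSplit output 1 through a Transpose.
#include <memory>
#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/variadic_split.hpp"

std::shared_ptr<ov::Model> make_repro_model() {
    auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 8, 4});
    auto axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
    auto lengths = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, {4, 4});
    auto split = std::make_shared<ov::op::v1::VariadicSplit>(data, axis, lengths);

    // Transpose the *second* split output; the matcher must keep index 1.
    auto order = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 2, 1});
    auto transpose = std::make_shared<ov::op::v1::Transpose>(split->output(1), order);

    auto weights = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 4, 4});
    auto matmul = std::make_shared<ov::op::v0::MatMul>(transpose, weights);
    return std::make_shared<ov::Model>(ov::OutputVector{matmul->output(0)},
                                       ov::ParameterVector{data, weights});
}
```

With the old code, rewiring such a graph via `get_node_shared_ptr()` alone would have asked VariadicSplit for a default output; carrying the index explicitly keeps the fused Gemm attached to `split->output(1)`.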
224 changes: 192 additions & 32 deletions src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp
@@ -692,7 +692,7 @@ class gemm_gpu_tests: public ::testing::Test {
 }
 }

-void test_transpose_matmul(bool is_caching_test) {
+void test_transpose_matmul(size_t num_dims, bool is_input_dynamic, bool is_caching_test) {
 tests::random_generator rg;
 rg.set_seed(GET_SUITE_NAME);

@@ -719,12 +719,43 @@
 };

 auto& engine = get_test_engine();
-ov::Shape input0_shape = { BATCH_SIZE, K_SIZE, 1, M_SIZE };
-ov::Shape input1_shape = { N_SIZE, BATCH_SIZE, 1, K_SIZE };
-std::vector<int64_t> input0_order = {0, 2, 3, 1};
-std::vector<int64_t> input1_order = {1, 2, 3, 0};
-auto input0_layout = layout{ov::PartialShape::dynamic(input0_shape.size()), data_types::f32, format::bfyx};
-auto input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f32, format::bfyx};
+ov::Shape input0_shape;
+ov::Shape input1_shape;
+std::vector<int64_t> input0_order;
+std::vector<int64_t> input1_order;
+cldnn::layout input0_layout;
+cldnn::layout input1_layout;
+
+if (num_dims == 1) {
+    input0_shape = { K_SIZE };
+    input1_shape = { N_SIZE, K_SIZE };
+    input0_order = { 0 };
+    input1_order = { 1, 0 };
+} else if (num_dims == 2) {
+    input0_shape = { K_SIZE, M_SIZE };
+    input1_shape = { N_SIZE, K_SIZE };
+    input0_order = { 1, 0 };
+    input1_order = { 1, 0 };
+} else if (num_dims == 3) {
+    input0_shape = { BATCH_SIZE, K_SIZE, M_SIZE };
+    input1_shape = { N_SIZE, BATCH_SIZE, K_SIZE };
+    input0_order = { 0, 2, 1 };
+    input1_order = { 1, 2, 0 };
+} else if (num_dims == 4) {
+    input0_shape = { BATCH_SIZE, K_SIZE, 1, M_SIZE };
+    input1_shape = { N_SIZE, BATCH_SIZE, 1, K_SIZE };
+    input0_order = { 0, 2, 3, 1 };
+    input1_order = { 1, 2, 3, 0 };
+}
+
+if (is_input_dynamic) {
+    input0_layout = layout{ov::PartialShape::dynamic(input0_shape.size()), data_types::f32, format::bfyx};
+    input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f32, format::bfyx};
+} else {
+    input0_layout = layout{ov::PartialShape(input0_shape), data_types::f32, format::bfyx};
+    input1_layout = layout{ov::PartialShape(input1_shape), data_types::f32, format::bfyx};
+}

 auto input0_mem = engine.allocate_memory(layout{ov::PartialShape(input0_shape), data_types::f32, format::bfyx});
 auto input1_mem = engine.allocate_memory(layout{ov::PartialShape(input1_shape), data_types::f32, format::bfyx});

@@ -750,16 +781,33 @@
 auto inst = network->get_primitive("gemm");
 auto impl = inst->get_impl();
 ASSERT_TRUE(impl != nullptr);
-ASSERT_TRUE(impl->is_dynamic());
+ASSERT_TRUE(impl->is_dynamic() == is_input_dynamic);

 auto outputs = network->execute();

 auto output_mem = outputs.at("gemm").get_memory();
 cldnn::mem_lock<float> output_ptr(output_mem, get_test_stream());

-ov::Shape ref_input0_shape = { BATCH_SIZE, 1, M_SIZE, K_SIZE };
-ov::Shape ref_input1_shape = { BATCH_SIZE, 1, K_SIZE, N_SIZE };
-ov::Shape ref_output_shape = { BATCH_SIZE, 1, M_SIZE, N_SIZE };
+ov::Shape ref_input0_shape;
+ov::Shape ref_input1_shape;
+ov::Shape ref_output_shape;
+if (num_dims == 1) {
+    ref_input0_shape = { K_SIZE };
+    ref_input1_shape = { K_SIZE, N_SIZE };
+    ref_output_shape = { 1, N_SIZE };
+} else if (num_dims == 2) {
+    ref_input0_shape = { M_SIZE, K_SIZE };
+    ref_input1_shape = { K_SIZE, N_SIZE };
+    ref_output_shape = { M_SIZE, N_SIZE };
+} else if (num_dims == 3) {
+    ref_input0_shape = { BATCH_SIZE, M_SIZE, K_SIZE };
+    ref_input1_shape = { BATCH_SIZE, K_SIZE, N_SIZE };
+    ref_output_shape = { BATCH_SIZE, M_SIZE, N_SIZE };
+} else if (num_dims == 4) {
+    ref_input0_shape = { BATCH_SIZE, 1, M_SIZE, K_SIZE };
+    ref_input1_shape = { BATCH_SIZE, 1, K_SIZE, N_SIZE };
+    ref_output_shape = { BATCH_SIZE, 1, M_SIZE, N_SIZE };
+}

 std::vector<float> ref_out_data;
 ref_out_data.resize(ov::shape_size(ref_output_shape));
@@ -798,7 +846,7 @@ class gemm_gpu_tests: public ::testing::Test {
 }
 }

-void test_transpose_matmul_transpose(bool is_caching_test) {
+void test_transpose_matmul_transpose(size_t num_dims, bool is_input_dynamic, bool is_caching_test) {
 tests::random_generator rg;
 rg.set_seed(GET_SUITE_NAME);

@@ -825,13 +873,48 @@
 };

 auto& engine = get_test_engine();
-ov::Shape input0_shape = { M_SIZE, K_SIZE, 1, BATCH_SIZE };
-ov::Shape input1_shape = { N_SIZE, 1, BATCH_SIZE, K_SIZE };
-std::vector<int64_t> input0_order = {3, 2, 0, 1};
-std::vector<int64_t> input1_order = {2, 1, 3, 0};
-std::vector<int64_t> output_order = {1, 0, 3, 2};
-auto input0_layout = layout{ov::PartialShape::dynamic(input0_shape.size()), data_types::f16, format::bfyx};
-auto input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f16, format::bfyx};
+ov::Shape input0_shape;
+ov::Shape input1_shape;
+std::vector<int64_t> input0_order;
+std::vector<int64_t> input1_order;
+std::vector<int64_t> output_order;
+cldnn::layout input0_layout;
+cldnn::layout input1_layout;
+
+if (num_dims == 1) {
+    input0_shape = { K_SIZE };
+    input1_shape = { N_SIZE, K_SIZE };
+    input0_order = { 0 };
+    input1_order = { 1, 0 };
+    output_order = { 0 };
+} else if (num_dims == 2) {
+    input0_shape = { K_SIZE, M_SIZE };
+    input1_shape = { N_SIZE, K_SIZE };
+    input0_order = { 1, 0 };
+    input1_order = { 1, 0 };
+    output_order = { 1, 0 };
+} else if (num_dims == 3) {
+    input0_shape = { BATCH_SIZE, K_SIZE, M_SIZE };
+    input1_shape = { N_SIZE, BATCH_SIZE, K_SIZE };
+    input0_order = { 0, 2, 1 };
+    input1_order = { 1, 2, 0 };
+    output_order = { 1, 0, 2 };
+} else if (num_dims == 4) {
+    input0_shape = { M_SIZE, K_SIZE, 1, BATCH_SIZE };
+    input1_shape = { N_SIZE, 1, BATCH_SIZE, K_SIZE };
+    input0_order = { 3, 2, 0, 1 };
+    input1_order = { 2, 1, 3, 0 };
+    output_order = { 1, 0, 3, 2 };
+}
+
+if (is_input_dynamic) {
+    input0_layout = layout{ov::PartialShape::dynamic(input0_shape.size()), data_types::f16, format::bfyx};
+    input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f16, format::bfyx};
+} else {
+    input0_layout = layout{ov::PartialShape(input0_shape), data_types::f16, format::bfyx};
+    input1_layout = layout{ov::PartialShape(input1_shape), data_types::f16, format::bfyx};
+}

 auto input0_mem = engine.allocate_memory(layout{ov::PartialShape(input0_shape), data_types::f16, format::bfyx});
 auto input1_mem = engine.allocate_memory(layout{ov::PartialShape(input1_shape), data_types::f16, format::bfyx});
@@ -857,17 +940,38 @@
 auto inst = network->get_primitive("gemm");
 auto impl = inst->get_impl();
 ASSERT_TRUE(impl != nullptr);
-ASSERT_TRUE(impl->is_dynamic());
+ASSERT_TRUE(impl->is_dynamic() == is_input_dynamic);

 auto outputs = network->execute();

 auto output_mem = outputs.at("gemm").get_memory();
 cldnn::mem_lock<ov::float16> output_ptr(output_mem, get_test_stream());

-ov::Shape ref_input0_shape = { BATCH_SIZE, 1, M_SIZE, K_SIZE };
-ov::Shape ref_input1_shape = { BATCH_SIZE, 1, K_SIZE, N_SIZE };
-ov::Shape ref_output_shape = { BATCH_SIZE, 1, M_SIZE, N_SIZE };
-ov::Shape transposed_output_shape = { 1, BATCH_SIZE, N_SIZE, M_SIZE };
+ov::Shape ref_input0_shape;
+ov::Shape ref_input1_shape;
+ov::Shape ref_output_shape;
+ov::Shape transposed_output_shape;
+if (num_dims == 1) {
+    ref_input0_shape = { K_SIZE };
+    ref_input1_shape = { K_SIZE, N_SIZE };
+    ref_output_shape = { 1, N_SIZE };
+    transposed_output_shape = { N_SIZE, 1 };
+} else if (num_dims == 2) {
+    ref_input0_shape = { M_SIZE, K_SIZE };
+    ref_input1_shape = { K_SIZE, N_SIZE };
+    ref_output_shape = { M_SIZE, N_SIZE };
+    transposed_output_shape = { N_SIZE, M_SIZE };
+} else if (num_dims == 3) {
+    ref_input0_shape = { BATCH_SIZE, M_SIZE, K_SIZE };
+    ref_input1_shape = { BATCH_SIZE, K_SIZE, N_SIZE };
+    ref_output_shape = { BATCH_SIZE, M_SIZE, N_SIZE };
+    transposed_output_shape = { M_SIZE, BATCH_SIZE, N_SIZE };
+} else if (num_dims == 4) {
+    ref_input0_shape = { BATCH_SIZE, 1, M_SIZE, K_SIZE };
+    ref_input1_shape = { BATCH_SIZE, 1, K_SIZE, N_SIZE };
+    ref_output_shape = { BATCH_SIZE, 1, M_SIZE, N_SIZE };
+    transposed_output_shape = { 1, BATCH_SIZE, N_SIZE, M_SIZE };
+}

 std::vector<ov::float16> ref_out_data;
 ref_out_data.resize(ov::shape_size(ref_output_shape));
@@ -936,12 +1040,68 @@ TEST_F(gemm_gpu_tests, dynamic_multi_inference_different_shape) {
 this->test_dynamic_multi_inference_different_shape(false);
 }

-TEST_F(gemm_gpu_tests, transpose_matmul) {
-this->test_transpose_matmul(false);
+TEST_F(gemm_gpu_tests, transpose_matmul_dynamic_1d) {
+this->test_transpose_matmul(1, true, false);
 }

+TEST_F(gemm_gpu_tests, transpose_matmul_static_1d) {
+this->test_transpose_matmul(1, false, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_dynamic_2d) {
+this->test_transpose_matmul(2, true, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_static_2d) {
+this->test_transpose_matmul(2, false, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_dynamic_3d) {
+this->test_transpose_matmul(3, true, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_static_3d) {
+this->test_transpose_matmul(3, false, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_dynamic_4d) {
+this->test_transpose_matmul(4, true, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_static_4d) {
+this->test_transpose_matmul(4, false, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_transpose_dynamic_1d) {
+this->test_transpose_matmul_transpose(1, true, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_transpose_static_1d) {
+this->test_transpose_matmul_transpose(1, false, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_transpose_dynamic_2d) {
+this->test_transpose_matmul_transpose(2, true, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_transpose_static_2d) {
+this->test_transpose_matmul_transpose(2, false, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_transpose_dynamic_3d) {
+this->test_transpose_matmul_transpose(3, true, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_transpose_static_3d) {
+this->test_transpose_matmul_transpose(3, false, false);
+}
+
+TEST_F(gemm_gpu_tests, transpose_matmul_transpose_dynamic_4d) {
+this->test_transpose_matmul_transpose(4, true, false);
+}
+
-TEST_F(gemm_gpu_tests, transpose_matmul_transpose) {
-this->test_transpose_matmul_transpose(false);
+TEST_F(gemm_gpu_tests, transpose_matmul_transpose_static_4d) {
+this->test_transpose_matmul_transpose(4, false, false);
 }

 INSTANTIATE_TEST_SUITE_P(
@@ -2366,11 +2526,11 @@ TEST_F(gemm_gpu_tests, basic_bfyx_t2_inplace_crop_with_pad_cached) {
 this->test_basic_bfyx_t2_inplace_crop_with_pad(true);
 }

-TEST_F(gemm_gpu_tests, transpose_matmul_cached) {
-this->test_transpose_matmul(true);
+TEST_F(gemm_gpu_tests, transpose_matmul_dynamic_4d_cached) {
+this->test_transpose_matmul(4, true, true);
 }

-TEST_F(gemm_gpu_tests, transpose_matmul_transpose_cached) {
-this->test_transpose_matmul_transpose(true);
+TEST_F(gemm_gpu_tests, transpose_matmul_transpose_dynamic_4d_cached) {
+this->test_transpose_matmul_transpose(4, true, true);
 }
 } // namespace
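
Assuming a standard OpenVINO developer build, the new cases can be run selectively with gtest filtering, e.g. `ov_gpu_unit_tests --gtest_filter=gemm_gpu_tests.transpose_matmul*` (the binary name and path depend on the build configuration).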
