diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp index 8e578ee557dfe2..c028216572d4f3 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp @@ -31,61 +31,95 @@ struct gemm_impl : typed_primitive_impl_ocl { auto gemm_optional_params = get_default_optional_params(arg.get_program()); - auto gemmSpecificPartialShape = [](ov::PartialShape& pshape) { - switch (pshape.rank().get_length()) { - case 2: { // batch, feature representation (rank == 2) - pshape.insert(pshape.begin(), 1ul); - pshape.insert(pshape.begin(), 1ul); - break; + auto get_gemm_input_layouts = [desc](const std::vector& input_layouts, const layout& output_layout) { + auto gemm_specific_pshape = [](ov::PartialShape& pshape) { + switch (pshape.rank().get_length()) { + case 2: { // batch, feature representation (rank == 2) + pshape.insert(pshape.begin(), 1ul); + pshape.insert(pshape.begin(), 1ul); + break; + } + case 3 : { // feature representation (rank == 3) + pshape.insert(pshape.begin(), 1, 1ul); + break; + } } - case 3 : { // feature representation (rank == 3) - pshape.insert(pshape.begin(), 1, 1ul); - break; + }; + std::vector layouts; + auto output_pshape = output_layout.size; + auto output_rank = output_pshape.rank().get_length(); + for (size_t i = 0; i != input_layouts.size(); ++i) { + auto input_layout = input_layouts[i]; + auto input_pshape = input_layout.size; + auto input_rank = input_pshape.rank().get_length(); + if (input_rank != output_rank || input_rank < 4) { + if (input_rank == 1) { + bool transpose = false; + if (i == 0) { + transpose = desc->transpose_input0; + input_pshape.insert(input_pshape.begin(), 1); + } else { + transpose = desc->transpose_input1; + input_pshape.insert(input_pshape.end(), 1); + } + if (transpose) { + std::swap(input_pshape[0], input_pshape[1]); + } + } + if (input_rank < output_rank) + input_pshape.insert(input_pshape.begin(), output_rank - input_rank, 1ul); + + gemm_specific_pshape(input_pshape); } + input_layout.size = input_pshape; + layouts.push_back(input_layout); } + return layouts; }; - auto output_layout = arg.get_output_layout(); - auto output_pshape = output_layout.size; - auto output_rank = output_pshape.rank().get_length(); - std::vector input_shapes; - for (size_t i = 0; i < arg.inputs_count(); i++) { - auto input_layout = arg.input(i).get_output_layout(); - auto input_pshape = input_layout.get_partial_shape(); - auto input_rank = input_pshape.rank().get_length(); - if (input_rank != output_rank || input_rank < 4) { - if (input_rank == 1) { - bool transpose = false; - if (i == 0) { - transpose = arg.get_primitive()->transpose_input0; - input_pshape.insert(input_pshape.begin(), 1); - } else { - transpose = arg.get_primitive()->transpose_input1; - input_pshape.insert(input_pshape.end(), 1); - } - if (transpose) { - std::swap(input_pshape[0], input_pshape[1]); + + auto get_gemm_output_layout = [desc](const std::vector& input_layouts, const layout& output_layout) { + auto layout = output_layout; + auto output_pshape = output_layout.size; + auto output_rank = output_pshape.rank().get_length(); + if (output_rank < 4) { + auto input0_layout = input_layouts[0]; + auto input1_layout = input_layouts[1]; + bool transpose_input0 = desc->transpose_input0; + bool transpose_input1 = desc->transpose_input1; + + auto M = !transpose_input0 ? input0_layout.spatial(1) : input0_layout.spatial(0); + auto N = !transpose_input1 ? input1_layout.spatial(0) : input1_layout.spatial(1); + + auto output_shape = input_layouts[0].size.to_shape(); + for (size_t i = 0; i != input_layouts.size(); ++i) { + auto input_pshape = input_layouts[i].size; + auto input_shape = input_pshape.to_shape(); + for (size_t j = 0; j != input_pshape.rank().get_length(); ++j) { + output_shape[j] = std::max(output_shape[j], input_shape[j]); } } - if (input_rank < output_rank) - input_pshape.insert(input_pshape.begin(), output_rank - input_rank, 1ul); - - gemmSpecificPartialShape(input_pshape); + layout.size = ov::PartialShape(output_shape); + auto get_spatial_idx = [](cldnn::format format, size_t spatial_idx) { + const size_t idx = (format::is_grouped(format) ? 3 : 2) + (format.spatial_num() - 1 - spatial_idx); + return idx; + }; + layout.size[get_spatial_idx(layout.format, 0)] = N; + layout.size[get_spatial_idx(layout.format, 1)] = M; } - input_layout.size = input_pshape; - input_shapes.push_back(input_pshape); - if (i == 0) - gemm_params.inputs[0] = convert_data_tensor(input_layout); - else - gemm_params.inputs.push_back(convert_data_tensor(input_layout)); - } - if (output_rank < 4) { - ov::op::v0::MatMul op; - op.set_transpose_a(arg.get_primitive()->transpose_input0); - op.set_transpose_b(arg.get_primitive()->transpose_input1); - std::vector output_shapes = {ov::PartialShape()}; - shape_infer(&op, input_shapes, output_shapes); - output_layout.size = output_shapes[0]; - gemm_params.outputs[0] = convert_data_tensor(output_layout); + return layout; + }; + const auto input_layouts = get_gemm_input_layouts(arg.get_input_layouts(), arg.get_output_layout()); + const auto output_layout = get_gemm_output_layout(input_layouts, arg.get_output_layout()); + const auto& param_info = kernel_impl_params(arg.get_program(), desc, arg.get_unique_id(), + input_layouts, output_layout, + arg.get_fused_primitives(), + arg.get_fused_activations_funcs(), arg.get_fused_activations_params()); + auto gemm_params = get_default_params(param_info, 1); + auto gemm_optional_params = + get_default_optional_params(arg.get_program()); + + for (size_t i = 1; i < arg.inputs_count(); i++) { + gemm_params.inputs.push_back(convert_data_tensor(param_info.input_layouts[i])); } gemm_params.alpha = desc->alpha;