Update gemm_impl::create to calculate input layouts and output layout for kernel_impl_params (openvinotoolkit#76)
Signed-off-by: Andrew Park <[email protected]>
andrew-k-park authored and vladimir-paramuzov committed Jul 26, 2022
1 parent a066477 commit b56274f
Showing 1 changed file with 82 additions and 48 deletions.
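At the heart of the change, the new get_gemm_input_layouts helper canonicalizes every input shape to at least rank 4 before it reaches the kernel selector: a rank-2 shape gains two leading 1s and a rank-3 shape gains one. A minimal standalone sketch of that padding step, using only ov::PartialShape (an illustrative restatement of the gemm_specific_pshape lambda in the diff below; main() and the sample shapes are not part of the commit):

    #include <openvino/core/partial_shape.hpp>
    #include <iostream>

    // Standalone restatement of the gemm_specific_pshape lambda from this commit:
    // rank 2 [M, N]    -> [1, 1, M, N]  (two leading 1s: batch and feature)
    // rank 3 [B, M, N] -> [1, B, M, N]  (one leading 1)
    void gemm_specific_pshape(ov::PartialShape& pshape) {
        switch (pshape.rank().get_length()) {
        case 2:
            pshape.insert(pshape.begin(), 1ul);
            pshape.insert(pshape.begin(), 1ul);
            break;
        case 3:
            pshape.insert(pshape.begin(), 1, 1ul);
            break;
        }
    }

    int main() {
        ov::PartialShape a{32, 64};       // rank 2
        ov::PartialShape b{8, 32, 64};    // rank 3
        gemm_specific_pshape(a);
        gemm_specific_pshape(b);
        std::cout << a << " " << b << std::endl;  // [1,1,32,64] [1,8,32,64]
    }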
src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp: 130 changes (82 additions, 48 deletions)
@@ -31,61 +31,95 @@ struct gemm_impl : typed_primitive_impl_ocl<gemm> {
-        auto gemm_optional_params =
-            get_default_optional_params<kernel_selector::gemm_optional_params>(arg.get_program());
-
-        auto gemmSpecificPartialShape = [](ov::PartialShape& pshape) {
-            switch (pshape.rank().get_length()) {
-                case 2: { // batch, feature representation (rank == 2)
-                    pshape.insert(pshape.begin(), 1ul);
-                    pshape.insert(pshape.begin(), 1ul);
-                    break;
-                }
-                case 3 : { // feature representation (rank == 3)
-                    pshape.insert(pshape.begin(), 1, 1ul);
-                    break;
-                }
-            }
-        };
-        auto output_layout = arg.get_output_layout();
-        auto output_pshape = output_layout.size;
-        auto output_rank = output_pshape.rank().get_length();
-        std::vector<ov::PartialShape> input_shapes;
-        for (size_t i = 0; i < arg.inputs_count(); i++) {
-            auto input_layout = arg.input(i).get_output_layout();
-            auto input_pshape = input_layout.get_partial_shape();
-            auto input_rank = input_pshape.rank().get_length();
-            if (input_rank != output_rank || input_rank < 4) {
-                if (input_rank == 1) {
-                    bool transpose = false;
-                    if (i == 0) {
-                        transpose = arg.get_primitive()->transpose_input0;
-                        input_pshape.insert(input_pshape.begin(), 1);
-                    } else {
-                        transpose = arg.get_primitive()->transpose_input1;
-                        input_pshape.insert(input_pshape.end(), 1);
-                    }
-                    if (transpose) {
-                        std::swap(input_pshape[0], input_pshape[1]);
-                    }
-                }
-                if (input_rank < output_rank)
-                    input_pshape.insert(input_pshape.begin(), output_rank - input_rank, 1ul);
-
-                gemmSpecificPartialShape(input_pshape);
-            }
-            input_layout.size = input_pshape;
-            input_shapes.push_back(input_pshape);
-            if (i == 0)
-                gemm_params.inputs[0] = convert_data_tensor(input_layout);
-            else
-                gemm_params.inputs.push_back(convert_data_tensor(input_layout));
-        }
-        if (output_rank < 4) {
-            ov::op::v0::MatMul op;
-            op.set_transpose_a(arg.get_primitive()->transpose_input0);
-            op.set_transpose_b(arg.get_primitive()->transpose_input1);
-            std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()};
-            shape_infer(&op, input_shapes, output_shapes);
-            output_layout.size = output_shapes[0];
-            gemm_params.outputs[0] = convert_data_tensor(output_layout);
-        }
+        auto get_gemm_input_layouts = [desc](const std::vector<layout>& input_layouts, const layout& output_layout) {
+            auto gemm_specific_pshape = [](ov::PartialShape& pshape) {
+                switch (pshape.rank().get_length()) {
+                    case 2: { // batch, feature representation (rank == 2)
+                        pshape.insert(pshape.begin(), 1ul);
+                        pshape.insert(pshape.begin(), 1ul);
+                        break;
+                    }
+                    case 3 : { // feature representation (rank == 3)
+                        pshape.insert(pshape.begin(), 1, 1ul);
+                        break;
+                    }
+                }
+            };
+            std::vector<layout> layouts;
+            auto output_pshape = output_layout.size;
+            auto output_rank = output_pshape.rank().get_length();
+            for (size_t i = 0; i != input_layouts.size(); ++i) {
+                auto input_layout = input_layouts[i];
+                auto input_pshape = input_layout.size;
+                auto input_rank = input_pshape.rank().get_length();
+                if (input_rank != output_rank || input_rank < 4) {
+                    if (input_rank == 1) {
+                        bool transpose = false;
+                        if (i == 0) {
+                            transpose = desc->transpose_input0;
+                            input_pshape.insert(input_pshape.begin(), 1);
+                        } else {
+                            transpose = desc->transpose_input1;
+                            input_pshape.insert(input_pshape.end(), 1);
+                        }
+                        if (transpose) {
+                            std::swap(input_pshape[0], input_pshape[1]);
+                        }
+                    }
+                    if (input_rank < output_rank)
+                        input_pshape.insert(input_pshape.begin(), output_rank - input_rank, 1ul);
+
+                    gemm_specific_pshape(input_pshape);
+                }
+                input_layout.size = input_pshape;
+                layouts.push_back(input_layout);
+            }
+            return layouts;
+        };
+
+        auto get_gemm_output_layout = [desc](const std::vector<layout>& input_layouts, const layout& output_layout) {
+            auto layout = output_layout;
+            auto output_pshape = output_layout.size;
+            auto output_rank = output_pshape.rank().get_length();
+            if (output_rank < 4) {
+                auto input0_layout = input_layouts[0];
+                auto input1_layout = input_layouts[1];
+                bool transpose_input0 = desc->transpose_input0;
+                bool transpose_input1 = desc->transpose_input1;
+
+                auto M = !transpose_input0 ? input0_layout.spatial(1) : input0_layout.spatial(0);
+                auto N = !transpose_input1 ? input1_layout.spatial(0) : input1_layout.spatial(1);
+
+                auto output_shape = input_layouts[0].size.to_shape();
+                for (size_t i = 0; i != input_layouts.size(); ++i) {
+                    auto input_pshape = input_layouts[i].size;
+                    auto input_shape = input_pshape.to_shape();
+                    for (size_t j = 0; j != input_pshape.rank().get_length(); ++j) {
+                        output_shape[j] = std::max(output_shape[j], input_shape[j]);
+                    }
+                }
+                layout.size = ov::PartialShape(output_shape);
+                auto get_spatial_idx = [](cldnn::format format, size_t spatial_idx) {
+                    const size_t idx = (format::is_grouped(format) ? 3 : 2) + (format.spatial_num() - 1 - spatial_idx);
+                    return idx;
+                };
+                layout.size[get_spatial_idx(layout.format, 0)] = N;
+                layout.size[get_spatial_idx(layout.format, 1)] = M;
+            }
+            return layout;
+        };
+
+        const auto input_layouts = get_gemm_input_layouts(arg.get_input_layouts(), arg.get_output_layout());
+        const auto output_layout = get_gemm_output_layout(input_layouts, arg.get_output_layout());
+
+        const auto& param_info = kernel_impl_params(arg.get_program(), desc, arg.get_unique_id(),
+                                                    input_layouts, output_layout,
+                                                    arg.get_fused_primitives(),
+                                                    arg.get_fused_activations_funcs(), arg.get_fused_activations_params());
+        auto gemm_params = get_default_params<kernel_selector::gemm_params>(param_info, 1);
+        auto gemm_optional_params =
+            get_default_optional_params<kernel_selector::gemm_optional_params>(arg.get_program());
+
+        for (size_t i = 1; i < arg.inputs_count(); i++) {
+            gemm_params.inputs.push_back(convert_data_tensor(param_info.input_layouts[i]));
+        }

         gemm_params.alpha = desc->alpha;
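For output shapes below rank 4, the new get_gemm_output_layout skips MatMul shape inference entirely: it broadcasts the canonicalized input shapes by element-wise max, then overwrites the two innermost axes with M and N read from the inputs. The index arithmetic of its get_spatial_idx lambda can be sanity-checked in isolation; a hedged sketch assuming a non-grouped rank-4 format such as bfyx, where spatial_num() == 2 (the free-function wrapper and the asserts are illustrative, not from the commit):

    #include <cassert>
    #include <cstddef>

    // Same arithmetic as the get_spatial_idx lambda in the diff, with the
    // cldnn::format queries replaced by plain parameters for illustration.
    std::size_t get_spatial_idx(bool is_grouped, std::size_t spatial_num, std::size_t spatial_idx) {
        return (is_grouped ? 3 : 2) + (spatial_num - 1 - spatial_idx);
    }

    int main() {
        // bfyx order is (b, f, y, x): spatial 0 is x and lands on axis 3 (set to N),
        // spatial 1 is y and lands on axis 2 (set to M).
        assert(get_spatial_idx(false, 2, 0) == 3);
        assert(get_spatial_idx(false, 2, 1) == 2);
        return 0;
    }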
