[GPU] Fix gemm_ref and gemm_tiled_opt #22617

Merged (4 commits, Feb 8, 2024)
src/plugins/intel_gpu/src/graph/gemm.cpp (25 additions, 0 deletions)
@@ -49,6 +49,17 @@ layout gemm_inst::calc_output_layout(gemm_node const& node, kernel_impl_params c
return input_shape_update;
};

auto transpose_shape = [](const ov::Shape& shape, const std::vector<int64_t>& order) {
auto shape_transposed = ov::Shape(shape);
auto rank_diff = shape.size() - order.size();
for (size_t i = 0; i < order.size(); i++) {
size_t idx = static_cast<size_t>(order[i]);
shape_transposed[i + rank_diff] = shape[idx + rank_diff];
}

return shape_transposed;
};

auto input0_shape_update = update_input_shape(input0_shape, input_rank, input0_order, true);
auto input1_shape_update = update_input_shape(input1_shape, weight_rank, input1_order, false);

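The transpose_shape lambda above applies output_order to the trailing dimensions of a shape: when the order vector is shorter than the shape's rank, the leading rank_diff dimensions are left untouched and only the tail is permuted. Below is a minimal standalone sketch of the same logic, with ov::Shape simplified to std::vector<size_t> and purely illustrative values; it is not the plugin code itself.

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    std::vector<size_t> transpose_shape(const std::vector<size_t>& shape,
                                        const std::vector<int64_t>& order) {
        std::vector<size_t> transposed(shape);
        const size_t rank_diff = shape.size() - order.size();
        for (size_t i = 0; i < order.size(); ++i) {
            const size_t idx = static_cast<size_t>(order[i]);
            transposed[i + rank_diff] = shape[idx + rank_diff];
        }
        return transposed;
    }

    int main() {
        // A 4D GEMM output [B, F, M, N] with output_order {0, 1, 3, 2}
        // becomes [B, F, N, M].
        const std::vector<size_t> shape{2, 8, 16, 32};
        const std::vector<int64_t> order{0, 1, 3, 2};
        for (size_t d : transpose_shape(shape, order))
            std::cout << d << ' ';  // prints: 2 8 32 16
        std::cout << '\n';
        return 0;
    }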
@@ -72,6 +83,9 @@ layout gemm_inst::calc_output_layout(gemm_node const& node, kernel_impl_params c
size_t ones_to_add = 4 - std::min(output_shape.size(), static_cast<size_t>(4));
output_shape.insert(output_shape.begin(), ones_to_add, 1);

if (prim->output_order.size() > 0)
output_shape = transpose_shape(output_shape, prim->output_order);

auto output_type = input0_layout.data_type;
if ((output_type == data_types::u8 || output_type == data_types::i8) && prim->output_data_types[0])
output_type = *prim->output_data_types[0];
@@ -216,6 +230,17 @@ layout gemm_inst::transform_output_layout(const std::shared_ptr<const gemm> prim

output_pshape[get_spatial_idx(updated_output_layout.format, 0)] = std::move(N);
output_pshape[get_spatial_idx(updated_output_layout.format, 1)] = std::move(M);

if (primitive->output_order.size() > 0) {
ov::PartialShape transposed_output_pshape = output_pshape;
auto rank_diff = output_pshape.size() - primitive->output_order.size();
for (size_t i = 0; i < primitive->output_order.size(); ++i) {
size_t idx = static_cast<size_t>(primitive->output_order[i]);
transposed_output_pshape[i + rank_diff] = std::move(output_pshape[idx + rank_diff]);
}
output_pshape = transposed_output_pshape;
}

updated_output_layout.set_partial_shape(output_pshape);
}
return updated_output_layout;
@@ -513,6 +513,15 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
}
}

auto gemm_prim = node.get_primitive();
for (size_t idx = 0; idx < gemm_prim->output_order.size(); ++idx) {
size_t output_order_idx = static_cast<size_t>(gemm_prim->output_order[idx]);
if (idx != output_order_idx) {
does_support_fusings = false;
break;
}
}

return does_support_fusings;
};

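The loop added above disables post-op fusing whenever output_order is a non-identity permutation, likely because a fused eltwise/activation would otherwise be applied against the already-transposed output layout. A compact sketch of the same check follows; the helper name is illustrative, not the plugin's API.

    #include <cstdint>
    #include <vector>

    // Returns true when `order` is empty or the identity permutation {0, 1, 2, ...}.
    bool is_identity_order(const std::vector<int64_t>& order) {
        for (size_t i = 0; i < order.size(); ++i) {
            if (static_cast<size_t>(order[i]) != i)
                return false;
        }
        return true;
    }

    // Usage, mirroring the guard above:
    //   does_support_fusings = does_support_fusings && is_identity_order(gemm_prim->output_order);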
@@ -67,6 +67,9 @@ inline uint FUNC(get_input2_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint
}
#endif // INPUT2_TYPE

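// INPUT0_SIZE_F / INPUT0_SIZE_B alias the feature/batch jit constants so that
// CAT(INPUT0_SIZE_, MATMUL_AXIS) below also resolves when MATMUL_AXIS is F or B
// (those sizes are otherwise only exposed as INPUT0_FEATURE_NUM / INPUT0_BATCH_NUM).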
#define INPUT0_SIZE_F INPUT0_FEATURE_NUM
#define INPUT0_SIZE_B INPUT0_BATCH_NUM

KERNEL(gemm_ref)(
OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* input0,
@@ -84,13 +87,13 @@ KERNEL(gemm_ref)(
const uint y = (uint)get_global_id(1);

uint bidx = get_global_id(2);
const uint b = bidx % OUTPUT_BATCH_NUM;
bidx /= OUTPUT_BATCH_NUM;
const uint f = bidx % OUTPUT_FEATURE_NUM;
bidx /= OUTPUT_FEATURE_NUM;
const uint z = bidx % OUTPUT_SIZE_Z;
bidx /= OUTPUT_SIZE_Z;
const uint w = bidx % OUTPUT_SIZE_W;
const uint b = bidx % TR_OUTPUT_BATCH_NUM;
bidx /= TR_OUTPUT_BATCH_NUM;
const uint f = bidx % TR_OUTPUT_FEATURE_NUM;
bidx /= TR_OUTPUT_FEATURE_NUM;
const uint z = bidx % TR_OUTPUT_SIZE_Z;
bidx /= TR_OUTPUT_SIZE_Z;
const uint w = bidx % TR_OUTPUT_SIZE_W;

const uint K = CAT(INPUT0_SIZE_, MATMUL_AXIS);

@@ -129,3 +132,6 @@ KERNEL(gemm_ref)(
output[dst_index] = dequantized;
#endif
}

#undef INPUT0_SIZE_F
#undef INPUT0_SIZE_B
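In the kernel body above, the flat get_global_id(2) value is now decomposed with the TR_OUTPUT_* sizes instead of OUTPUT_*: once calc_output_layout applies output_order, the OUTPUT_* constants describe the transposed shape, while the TR_OUTPUT_* constants added in GemmKernelBase::GetJitConstants select the matching OUTPUT_* macro through that order, so the modulus/division chain still walks the sizes the math loop iterates over. A host-side illustration of the decomposition, with made-up sizes:

    #include <cstdint>
    #include <iostream>

    int main() {
        // Hypothetical values of TR_OUTPUT_BATCH_NUM, TR_OUTPUT_FEATURE_NUM,
        // TR_OUTPUT_SIZE_Z and TR_OUTPUT_SIZE_W.
        const uint32_t TR_B = 2, TR_F = 4, TR_Z = 1, TR_W = 1;

        for (uint32_t gid = 0; gid < TR_B * TR_F * TR_Z * TR_W; ++gid) {
            uint32_t bidx = gid;
            const uint32_t b = bidx % TR_B; bidx /= TR_B;
            const uint32_t f = bidx % TR_F; bidx /= TR_F;
            const uint32_t z = bidx % TR_Z; bidx /= TR_Z;
            const uint32_t w = bidx % TR_W;
            std::cout << "gid=" << gid << " -> b=" << b << " f=" << f
                      << " z=" << z << " w=" << w << '\n';
        }
        return 0;
    }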
@@ -124,7 +124,7 @@ KERNEL(gemm_tiled_opt)(
// Start pointers offsets
#if TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST
const __global INPUT0_TYPE* a_ptr = input0 + batch_offset_input0;
#if HAS_DYNAMIC_K_PADDING
#if HAS_DYNAMIC_K_PADDING || INPUT0_HAS_PADDING
const uint input0_offset = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, (y+1), 0) - batch_offset_input0;
const uint input0_offset1 = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, (TILE_K)) - batch_offset_input0;
#else
@@ -133,7 +133,7 @@
#endif
#elif TRANSPOSE_INPUT0 == TRANSPOSE_Y_LAST
const __global INPUT0_TYPE* a_ptr = input0 + batch_offset_input0;
#if HAS_DYNAMIC_K_PADDING
#if HAS_DYNAMIC_K_PADDING || INPUT0_HAS_PADDING
const uint input0_offset = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, 1) - batch_offset_input0;
const uint input0_offset1 = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, (TILE_K)) - batch_offset_input0;
#else
@@ -143,14 +143,14 @@
#endif // TRANSPOSE_INPUT0
#if TRANSPOSE_INPUT1 == TRANSPOSE_X_LAST
const __global INPUT1_TYPE* b_ptr = input1 + batch_offset_input1;
#if HAS_DYNAMIC_K_PADDING
#if HAS_DYNAMIC_K_PADDING || INPUT1_HAS_PADDING
const uint input1_offset = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, 1, tile_n_offset) - batch_offset_input1;
#else
const uint input1_offset = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR 0, 0, 0, 0, 1, 0);
#endif
#elif TRANSPOSE_INPUT1 == TRANSPOSE_Y_LAST
const __global INPUT1_TYPE* b_ptr = input1 + batch_offset_input1;
#if HAS_DYNAMIC_K_PADDING
#if HAS_DYNAMIC_K_PADDING || INPUT1_HAS_PADDING
const uint input1_offset = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, 0, (tile_n_offset + 1)) - batch_offset_input1;
const uint input1_offset1 = FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, (TILE_K), tile_n_offset) - batch_offset_input1;
#else
@@ -187,7 +187,7 @@ KERNEL(gemm_tiled_opt)(
unroll_for (uint b_load_id = 0; b_load_id < TILE_K; b_load_id++) {
#if IS_DYNAMIC
#if TRANSPOSE_INPUT1 == TRANSPOSE_X_LAST
#if HAS_DYNAMIC_N_PADDING
#if HAS_DYNAMIC_N_PADDING || INPUT1_HAS_PADDING
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else
b_tile[b_load_id] = TILE_N_NOT_DIVISIBLE ? (b_raw_global_id > N - 1 ? 0 : b_ptr[sglid]) : BLOCK_READ_B(b_ptr, 0);
@@ -229,7 +229,7 @@ KERNEL(gemm_tiled_opt)(
unroll_for (uint dot_id = 0; dot_id < tile_m_iterations; dot_id++) {
#if TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST
#if IS_DYNAMIC
#if HAS_DYNAMIC_K_PADDING
#if HAS_DYNAMIC_K_PADDING || INPUT0_HAS_PADDING
// In case of dynamic padding we can't guarantee memory access alignment for
// block reads (4 bytes), so use scattered read
uint a_idx = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, (y + dot_id), (k * TILE_K + sglid));
@@ -286,7 +286,7 @@ KERNEL(gemm_tiled_opt)(
// Loading leftovers of the matrix B
unroll_for (uint b_load_id = 0; b_load_id < TILE_K_LEFTOVER; b_load_id++) {
#if TRANSPOSE_INPUT1 == TRANSPOSE_X_LAST
#if HAS_DYNAMIC_N_PADDING
#if HAS_DYNAMIC_N_PADDING || INPUT1_HAS_PADDING
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else
b_tile[b_load_id] = TILE_N_NOT_DIVISIBLE ? (b_raw_global_id > N - 1 ? 0 : b_ptr[sglid]) : BLOCK_READ_B(b_ptr, 0);
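All of the guards above extend the dynamic-padding path to statically padded inputs: when INPUT0/INPUT1 carry padding, consecutive rows are separated by the padded pitch rather than the logical row length, so offsets are recomputed through the get_input*_index helpers and block reads (which, as the comment in the kernel notes, assume 4-byte alignment) fall back to scattered reads. A small sketch of why a constant stride is wrong for a padded tensor, with illustrative numbers:

    #include <cstddef>
    #include <iostream>

    int main() {
        const size_t K = 13;         // logical row length
        const size_t pad_right = 2;  // physical padding after each row
        const size_t pitch = K + pad_right;

        auto index = [&](size_t y, size_t x) { return y * pitch + x; };

        // The distance between consecutive rows is the padded pitch (15), not K (13),
        // and it does not keep block reads aligned in general, hence the fallback.
        std::cout << index(1, 0) - index(0, 0) << '\n';  // 15
        return 0;
    }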
@@ -180,13 +180,37 @@ JitConstants GemmKernelBase::GetJitConstants(const gemm_params& params) const {
MakeJitConstant("QUANTIZATION_TERM", params.quantization != QuantizationType::NONE),
});

auto get_output_size = [this](const std::vector<int64_t>& output_order_idx, const int target_idx) {
auto output_dims_order = GetDimsOrder(output_order_idx);

switch (output_dims_order.at(target_idx)) {
case 'b':
return "OUTPUT_BATCH_NUM";
case 'f':
return "OUTPUT_FEATURE_NUM";
case 'w':
return "OUTPUT_SIZE_W";
case 'z':
return "OUTPUT_SIZE_Z";
case 'y':
return "OUTPUT_SIZE_Y";
case 'x':
return "OUTPUT_SIZE_X";
default:
return "";
}
};
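The get_output_size helper above turns a position in the transposed output into the name of the jit constant holding that dimension's size; the hard-coded indices 0/2/4/6 (and 10 for MATMUL_AXIS) suggest GetDimsOrder returns a comma-separated six-dimension string such as "b,f,w,z,y,x". A sketch of the mapping under that assumption:

    #include <iostream>
    #include <string>

    std::string output_size_macro(const std::string& dims_order, int target_idx) {
        switch (dims_order.at(target_idx)) {
            case 'b': return "OUTPUT_BATCH_NUM";
            case 'f': return "OUTPUT_FEATURE_NUM";
            case 'w': return "OUTPUT_SIZE_W";
            case 'z': return "OUTPUT_SIZE_Z";
            case 'y': return "OUTPUT_SIZE_Y";
            case 'x': return "OUTPUT_SIZE_X";
            default:  return "";
        }
    }

    int main() {
        // With an output_order that swaps batch and feature, the dims-order string
        // would look like "f,b,w,z,y,x": the "transposed" batch size then resolves
        // to the feature macro and vice versa.
        const std::string dims_order = "f,b,w,z,y,x";
        std::cout << output_size_macro(dims_order, 0) << '\n';  // OUTPUT_FEATURE_NUM (TR_OUTPUT_BATCH_NUM)
        std::cout << output_size_macro(dims_order, 2) << '\n';  // OUTPUT_BATCH_NUM   (TR_OUTPUT_FEATURE_NUM)
        std::cout << output_size_macro(dims_order, 4) << '\n';  // OUTPUT_SIZE_W      (TR_OUTPUT_SIZE_W)
        std::cout << output_size_macro(dims_order, 6) << '\n';  // OUTPUT_SIZE_Z      (TR_OUTPUT_SIZE_Z)
        return 0;
    }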

jit.AddConstants({
MakeJitConstant("TRANSPOSE_X_LAST", 0),
MakeJitConstant("TRANSPOSE_Y_LAST", 1),
MakeJitConstant("TRANSPOSE_OTHER", 2),
MakeJitConstant("INPUT0_DIMS_ORDER", GetDimsOrder(params.input0_order)),
MakeJitConstant("INPUT1_DIMS_ORDER", GetDimsOrder(params.input1_order)),
MakeJitConstant("MATMUL_AXIS", static_cast<char>(std::toupper(GetDimsOrder(params.input0_order).at(10)))),
MakeJitConstant("TR_OUTPUT_SIZE_Z", get_output_size(params.output_order, 6)),
MakeJitConstant("TR_OUTPUT_SIZE_W", get_output_size(params.output_order, 4)),
MakeJitConstant("TR_OUTPUT_FEATURE_NUM", get_output_size(params.output_order, 2)),
MakeJitConstant("TR_OUTPUT_BATCH_NUM", get_output_size(params.output_order, 0)),
});

return jit;
@@ -63,7 +63,29 @@ JitConstants GemmKernelRef::GetJitConstants(const gemm_params& params) const {
jit.Merge(MakeTypeJitConstants(Datatype::F32, "ACTIVATION"));
}

auto get_matmul_axis = [](const std::vector<int64_t>& order_idx) {
auto last_idx = static_cast<size_t>(order_idx.back());
last_idx = (last_idx >= order_idx.size()) ? (order_idx.size() - 1) : last_idx;

std::vector<std::string> dims;
if (order_idx.size() == 1) {
dims = {"X"};
} else if (order_idx.size() == 2) {
dims = {"Y", "X"};
} else if (order_idx.size() == 3) {
dims = {"F", "Y", "X"};
} else if (order_idx.size() == 4) {
dims = {"B", "F", "Y", "X"};
} else if (order_idx.size() == 5) {
dims = {"B", "F", "Z", "Y", "X"};
} else if (order_idx.size() == 6) {
dims = {"B", "F", "W", "Z", "Y", "X"};
}
return dims[last_idx];
};
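get_matmul_axis above picks the reduction (K) axis for the reference kernel: it is whichever dimension input0's transpose order places last, clamped to the valid range, and the returned letter is then pasted into CAT(INPUT0_SIZE_, MATMUL_AXIS) inside gemm_ref.cl. A standalone sketch of the same selection:

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    std::string matmul_axis(const std::vector<int64_t>& order_idx) {
        static const std::vector<std::vector<std::string>> dims_by_rank = {
            {"X"},
            {"Y", "X"},
            {"F", "Y", "X"},
            {"B", "F", "Y", "X"},
            {"B", "F", "Z", "Y", "X"},
            {"B", "F", "W", "Z", "Y", "X"},
        };
        size_t last_idx = static_cast<size_t>(order_idx.back());
        if (last_idx >= order_idx.size())
            last_idx = order_idx.size() - 1;  // clamp, as in the lambda above
        return dims_by_rank[order_idx.size() - 1][last_idx];
    }

    int main() {
        std::cout << matmul_axis({0, 1, 2, 3}) << '\n';  // X: K is the innermost stored dim
        std::cout << matmul_axis({0, 1, 3, 2}) << '\n';  // Y: the stored Y axis is read as K
        return 0;
    }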

jit.AddConstants({
MakeJitConstant("MATMUL_AXIS", get_matmul_axis(params.input0_order)),
MakeJitConstant("TR_B", GetTransposedDims(params.output_order).at(0)),
MakeJitConstant("TR_F", GetTransposedDims(params.output_order).at(1)),
MakeJitConstant("TR_W", GetTransposedDims(params.output_order).at(4)),
@@ -107,31 +107,9 @@ GemmKernelTiledOpt::GemmTuningData GemmKernelTiledOpt::SetTuningParams(const gem
JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) const {
JitConstants jit = Parent::GetJitConstants(params);

const auto& output = params.outputs[0];
GemmTuningData tuning_data = SetTuningParams(params);
auto b_vec_size = tuning_data.tile_n_size / tuning_data.simd_size;

auto get_output_size = [this](const std::vector<int64_t>& output_order_idx, const int target_idx) {
auto output_dims_order = Parent::GetDimsOrder(output_order_idx);

switch (output_dims_order.at(target_idx)) {
case 'b':
return "OUTPUT_BATCH_NUM";
case 'f':
return "OUTPUT_FEATURE_NUM";
case 'w':
return "OUTPUT_SIZE_W";
case 'z':
return "OUTPUT_SIZE_Z";
case 'y':
return "OUTPUT_SIZE_Y";
case 'x':
return "OUTPUT_SIZE_X";
default:
return "";
}
};

jit.Merge(MakeTypeJitConstants(params.inputs[0].GetDType(), "ACCUMULATOR"));
if (params.has_dynamic_tensors()) {
DimensionAccessHelper dims0(params.inputs[0]);
@@ -178,10 +156,6 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
MakeJitConstant("TR_Z", GetTransposedDims(params.output_order, true).at(5)),
MakeJitConstant("TR_Y", GetTransposedDims(params.output_order, true).at(6)),
MakeJitConstant("TR_X", GetTransposedDims(params.output_order, true).at(7)),
MakeJitConstant("TR_OUTPUT_SIZE_Z", get_output_size(params.output_order, 6)),
MakeJitConstant("TR_OUTPUT_SIZE_W", get_output_size(params.output_order, 4)),
MakeJitConstant("TR_OUTPUT_FEATURE_NUM", get_output_size(params.output_order, 2)),
MakeJitConstant("TR_OUTPUT_BATCH_NUM", get_output_size(params.output_order, 0)),
});

bool has_dynamic_k_padding = params.transpose_input0 ? params.inputs[0].Y().pad.is_dynamic
@@ -193,28 +167,25 @@
if (has_dynamic_n_padding)
jit.AddConstant(MakeJitConstant("HAS_DYNAMIC_N_PADDING", 1));
} else {
auto get_untransposed_dim_size = [](const kernel_selector::DataTensor &data_tensor,
const std::vector<int64_t>& dims_order, const std::string dim) {
auto get_transposed_dim_size = [](const kernel_selector::DataTensor &data_tensor,
const std::vector<int64_t>& dims_order, const std::string dim) {
int64_t target_dim_idx;
const size_t rank = data_tensor.GetDims().size();
if (dim.compare("Y") == 0) {
if (dims_order.size() > 1 && dim.compare("Y") == 0) {
target_dim_idx = dims_order.at(dims_order.size() - 2);
} else if (dims_order.size() > 0 && dim.compare("X") == 0) {
target_dim_idx = dims_order.back();
} else if (dims_order.size() == 0 && dim.compare("Y") == 0) {
target_dim_idx = rank - 2;
} else if (dim.compare("X") == 0) {
} else if (dims_order.size() == 0 && dim.compare("X") == 0) {
target_dim_idx = rank - 1;
} else {
OPENVINO_THROW("Unsupported dimension: ", dim);
}

size_t loc = (dims_order.size() < rank) ? (rank - dims_order.size()) : 0;
if (dims_order.size() == 0) {
loc = static_cast<size_t>(target_dim_idx);
} else {
target_dim_idx = (dims_order.size() < rank) ? (target_dim_idx + dims_order.size() - rank) : target_dim_idx;
for (auto dim_idx : dims_order) {
if (dim_idx == target_dim_idx)
break;
loc += 1;
}
size_t loc = static_cast<size_t>(target_dim_idx);
if (dims_order.size() > 0) {
loc += (dims_order.size() < rank) ? (rank - dims_order.size()) : 0;
}

if (loc == 0) {
@@ -233,9 +204,9 @@
OPENVINO_THROW("Target dimension is not found.");
};

auto m_size = get_untransposed_dim_size(output, params.output_order, "Y");
auto n_size = get_untransposed_dim_size(output, params.output_order, "X");
auto k_size = get_untransposed_dim_size(params.inputs[0], params.input0_order, "X");
auto m_size = get_transposed_dim_size(params.inputs[0], params.input0_order, "Y");
auto n_size = get_transposed_dim_size(params.inputs[1], params.input1_order, "X");
auto k_size = get_transposed_dim_size(params.inputs[0], params.input0_order, "X");
auto leftover_m = m_size % tuning_data.tile_m_size;
auto leftover_n = n_size % tuning_data.tile_n_size;
auto leftover_k = k_size % tuning_data.tile_k_size;
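With this change M and K are taken from input0 and N from input1 (each through its own transpose order) instead of being derived from the output tensor and output_order. The leftover values themselves are simple remainders per tile dimension; a tiny sketch with assumed sizes:

    #include <cstddef>
    #include <iostream>

    int main() {
        const size_t m_size = 100, n_size = 65, k_size = 40;  // assumed GEMM dims
        const size_t tile_m = 8, tile_n = 16, tile_k = 16;    // assumed tile sizes

        std::cout << "leftover_m = " << m_size % tile_m << '\n';  // 4
        std::cout << "leftover_n = " << n_size % tile_n << '\n';  // 1
        std::cout << "leftover_k = " << k_size % tile_k << '\n';  // 8
        return 0;
    }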
@@ -263,11 +234,12 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
MakeJitConstant("TR_Z", GetTransposedDims(params.output_order, true).at(5)),
MakeJitConstant("TR_Y", GetTransposedDims(params.output_order, true).at(6)),
MakeJitConstant("TR_X", GetTransposedDims(params.output_order, true).at(7)),
MakeJitConstant("TR_OUTPUT_SIZE_Z", get_output_size(params.output_order, 6)),
MakeJitConstant("TR_OUTPUT_SIZE_W", get_output_size(params.output_order, 4)),
MakeJitConstant("TR_OUTPUT_FEATURE_NUM", get_output_size(params.output_order, 2)),
MakeJitConstant("TR_OUTPUT_BATCH_NUM", get_output_size(params.output_order, 0)),
});

if (params.inputs[0].LogicalSize() != params.inputs[0].PhysicalSize())
jit.AddConstant(MakeJitConstant("INPUT0_HAS_PADDING", 1));
if (params.inputs[1].LogicalSize() != params.inputs[1].PhysicalSize())
jit.AddConstant(MakeJitConstant("INPUT1_HAS_PADDING", 1));
}

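The two checks above derive static-padding flags by comparing a tensor's logical element count with its physical (allocated) count; the counts differ exactly when some dimension carries padding, and the resulting INPUT0_HAS_PADDING / INPUT1_HAS_PADDING macros route the kernel onto the recomputed-offset path shown earlier. A minimal illustration with made-up sizes:

    #include <cstddef>
    #include <iostream>

    int main() {
        const size_t dims[2]      = {64, 13};  // logical sizes per dimension
        const size_t pad_after[2] = {0, 2};    // padding appended per dimension

        size_t logical = 1, physical = 1;
        for (int i = 0; i < 2; ++i) {
            logical  *= dims[i];
            physical *= dims[i] + pad_after[i];
        }
        // Differs as soon as any dimension is padded.
        std::cout << std::boolalpha << (logical != physical) << '\n';  // true
        return 0;
    }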
if (tuning_data.tile_k_size > tuning_data.simd_size) {
@@ -362,12 +334,7 @@ bool GemmKernelTiledOpt::Validate(const Params& params, const optional_params& o
return false;
}

bool gemm_leftovers = gmm_params.inputs[0].X().v % 16 || gmm_params.inputs[0].Y().v % 16 ||
gmm_params.inputs[1].X().v % 16 || gmm_params.inputs[1].Y().v % 16;
// If gmm_params has dynamic inputs, the correct dimension value cannot be obtained
// and leftovers cannot be calculated, so it returns false
if ((gmm_params.transpose_input0 || gmm_params.transpose_input1) && (gemm_leftovers || gmm_params.has_dynamic_inputs()) &&
!gmm_params.is_shape_agnostic)
if (gmm_params.has_dynamic_inputs() && !gmm_params.is_shape_agnostic)
return false;

for (size_t i = 1; i < gmm_params.inputs.size(); i++)
@@ -668,7 +668,7 @@ TEST(prepare_buffer_fusing, test_implicit_crop_and_outerpadding) {
cldnn::mem_lock<int8_t> output_ptr(output, get_test_stream());

ExecutionConfig ref_config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::optimize_data(false));
ref_config.set_property(ov::intel_gpu::optimize_data(false));
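// optimize_data(false) must be set on ref_config, the reference network's config;
// it was previously applied to config by mistake.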
cldnn::network ref_network(engine, topology, ref_config);
ref_network.set_input_data("Input", in_input);
ref_network.set_input_data("Input_idx_1", input_idx1);