Skip to content

Commit

Permalink
[GPU] In gemm_tile_kernel, block read is used when N and K sizes are odd.
Browse files Browse the repository at this point in the history

Signed-off-by: hyunback <[email protected]>
  • Loading branch information
hyunback committed Mar 12, 2024
1 parent 15921ea commit 166a9df
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ KERNEL(gemm_tiled_opt)(
#if HAS_DYNAMIC_N_PADDING || INPUT1_HAS_PADDING
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else
b_tile[b_load_id] = TILE_N_NOT_DIVISIBLE ? (b_raw_global_id > N - 1 ? 0 : b_ptr[sglid]) : BLOCK_READ_B(b_ptr, 0);
b_tile[b_load_id] = N_IS_ODD ? (b_raw_global_id > N - 1 ? 0 : b_ptr[sglid]) : BLOCK_READ_B(b_ptr, 0);
#endif
b_ptr += input1_offset;
}
Expand Down Expand Up @@ -275,7 +275,7 @@ KERNEL(gemm_tiled_opt)(
else
#endif // INDIRECT_INPUT1
{
#if TILE_N_NOT_DIVISIBLE
#if N_IS_ODD
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else
b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
Expand Down Expand Up @@ -334,17 +334,17 @@ KERNEL(gemm_tiled_opt)(
uint a_idx = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, (y + dot_id), (k * TILE_K + sglid));
A_FLOATN a_read = input0[a_idx];
#else
A_FLOATN a_read = TILE_K_NOT_DIVISIBLE ? a_ptr[sglid] : BLOCK_READ_A(a_ptr, 0);
A_FLOATN a_read = K_IS_ODD ? a_ptr[sglid] : BLOCK_READ_A(a_ptr, 0);
#endif
#else // IS_DYNAMIC
#if INDIRECT_INPUT0
uint a_idx = FUNC_CALL(get_input0_indirect_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, (y + dot_id), (k * TILE_K + sglid), beam_table);
A_FLOATN a_read = input0[a_idx];
#elif TILE_K_NOT_DIVISIBLE
#elif K_IS_ODD
A_FLOATN a_read = a_ptr[sglid];
#else // TILE_K_NOT_DIVISIBLE
#else // K_IS_ODD
A_FLOATN a_read = BLOCK_READ_A(a_ptr, 0);
#endif // TILE_K_NOT_DIVISIBLE
#endif // K_IS_ODD
#endif // IS_DYNAMIC
a_ptr += input0_offset;

Expand Down Expand Up @@ -414,7 +414,7 @@ KERNEL(gemm_tiled_opt)(
#if HAS_DYNAMIC_N_PADDING || INPUT1_HAS_PADDING
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else
b_tile[b_load_id] = TILE_N_NOT_DIVISIBLE ? (b_raw_global_id > N - 1 ? 0 : b_ptr[sglid]) : BLOCK_READ_B(b_ptr, 0);
b_tile[b_load_id] = N_IS_ODD ? (b_raw_global_id > N - 1 ? 0 : b_ptr[sglid]) : BLOCK_READ_B(b_ptr, 0);
#endif
b_ptr += input1_offset;
}
Expand Down Expand Up @@ -486,11 +486,11 @@ KERNEL(gemm_tiled_opt)(
else
#endif
{
#if TILE_N_NOT_DIVISIBLE
#if N_IS_ODD
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else // TILE_N_NOT_DIVISIBLE
#else
b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
#endif // TILE_N_NOT_DIVISIBLE
#endif
b_ptr += input1_offset;
}
#elif TRANSPOSE_INPUT1 == TRANSPOSE_OTHER // TRANSPOSE_INPUT1 == 0
Expand Down Expand Up @@ -530,14 +530,15 @@ KERNEL(gemm_tiled_opt)(
#endif // TRANSPOSE_INPUT1 == TRANSPOSE_Y_LAST

// Loading leftovers of the matrix A and tile C calculation
a_ptr = input0 + FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, (K_FULL_ITERATIONS * TILE_K));
unroll_for (uint dot_id = 0; dot_id < tile_m_iterations; dot_id++) {
#if INDIRECT_INPUT0
uint a_idx = FUNC_CALL(get_input0_indirect_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, (y + dot_id), (K_FULL_ITERATIONS * TILE_K + sglid), beam_table);
INPUT0_TYPE a_read = input0[a_idx];
#else
uint a_idx = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, (y + dot_id), (K_FULL_ITERATIONS * TILE_K + sglid));
INPUT0_TYPE a_read = BLOCK_READ_A(a_ptr, 0);
a_ptr += input0_offset;
#endif
INPUT0_TYPE a_read = input0[a_idx];

unroll_for (uint simd_id = 0; simd_id < TILE_K_LEFTOVER; simd_id++) {
c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_read, simd_id)), b_tile[simd_id], c_tile[dot_id]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,10 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
auto m_size = dims0.dims_sizes[input0_dims[6]];
auto n_size = dims1.dims_sizes[input1_dims[7]];
auto n_padded_size = "(" + dims1_padded.dims_sizes[input1_dims[7]] + ")";
auto n_odd = "(" + n_size + "%2) != 0";
auto k_size = dims0.dims_sizes[input0_dims[7]];
auto k_padded_size_in0 = "(" + dims0_padded.dims_sizes[input0_dims[7]] + ")";
auto k_odd = "(" + k_size + "%2) != 0";
const std::string leftover_m = "(" + m_size + "%" + std::to_string(tuning_data.tile_m_size) + ")";
const std::string leftover_n = "(" + n_size + "%" + std::to_string(tuning_data.tile_n_size) + ")";
const std::string leftover_k = "(" + k_size + "%" + std::to_string(tuning_data.tile_k_size) + ")";
Expand All @@ -141,6 +143,8 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
MakeJitConstant("N", n_size),
MakeJitConstant("K_PADDED_IN0", k_padded_size_in0),
MakeJitConstant("N_PADDED", n_padded_size),
MakeJitConstant("K_IS_ODD", k_odd),
MakeJitConstant("N_IS_ODD", n_odd),
MakeJitConstant("SIMD_WIDTH", tuning_data.simd_size),
MakeJitConstant("TILE_M", tuning_data.tile_m_size),
MakeJitConstant("TILE_K", tuning_data.tile_k_size),
Expand Down Expand Up @@ -219,6 +223,8 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
MakeJitConstant("N", n_size),
MakeJitConstant("K_PADDED_IN0", k_size),
MakeJitConstant("N_PADDED", n_size),
MakeJitConstant("K_IS_ODD", k_size % 2 != 0),
MakeJitConstant("N_IS_ODD", n_size % 2 != 0),
MakeJitConstant("SIMD_WIDTH", tuning_data.simd_size),
MakeJitConstant("TILE_M", tuning_data.tile_m_size),
MakeJitConstant("TILE_K", tuning_data.tile_k_size),
Expand Down

0 comments on commit 166a9df

Please sign in to comment.