Skip to content

Commit

Permalink
[GPU] gemm_tile supports block read when leftover with 4byte-size ali…
Browse files Browse the repository at this point in the history
…gn in dynamic. (#24535)

* Use block read in the 4-byte-aligned leftover case. The static path
already does this (#23400); this PR
applies it in the dynamic case as well.



### Tickets:
 - *141032*

---------

Signed-off-by: hyunback <[email protected]>
  • Loading branch information
hyunback authored May 22, 2024
1 parent 19fd28b commit 7f4e766
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -281,13 +281,13 @@ KERNEL(gemm_tiled_opt)(
#if B_VEC_SIZE == 1
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else // B_VEC_SIZE == 1
#if TILE_N_NOT_DIVISIBLE
if (TILE_N_NOT_DIVISIBLE == 0 || N_IS_ALIGNED_4BYTE)
b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
else {
unroll_for (uint b_elem = 0; b_elem < B_VEC_SIZE; ++b_elem) {
b_tile[b_load_id][b_elem] = b_ptr[sglid + SIMD_WIDTH * b_elem];
}
#else // TILE_N_NOT_DIVISIBLE
b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
#endif // TILE_N_NOT_DIVISIBLE
}
#endif // B_VEC_SIZE == 1
b_ptr += input1_offset;
}
Expand Down Expand Up @@ -387,7 +387,7 @@ KERNEL(gemm_tiled_opt)(

// Loading A tile and tile C calculation
#if IS_DYNAMIC && !INDIRECT_INPUT0 && !HAS_DYNAMIC_K_PADDING && TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST
A_FLOATN a_read = TILE_K_NOT_DIVISIBLE ? a_ptr[sglid] : BLOCK_READ_A(a_ptr, 0);
A_FLOATN a_read = (TILE_K_NOT_DIVISIBLE == 0 || K_IS_ALIGNED_4BYTE) ? BLOCK_READ_A(a_ptr, 0): a_ptr[sglid];
#endif
unroll_for (uint dot_id = 0; dot_id < tile_m_iterations; dot_id++) {
#if TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST
Expand Down Expand Up @@ -433,7 +433,7 @@ KERNEL(gemm_tiled_opt)(
}
#if IS_DYNAMIC && !INDIRECT_INPUT0 && !HAS_DYNAMIC_K_PADDING
// Read A for next dot_id
a_read = (dot_id + 1 < tile_m_iterations) ? TILE_K_NOT_DIVISIBLE ? a_ptr[sglid] : BLOCK_READ_A(a_ptr, 0) : 0;
a_read = (dot_id + 1 < tile_m_iterations) ? (TILE_K_NOT_DIVISIBLE == 0 || K_IS_ALIGNED_4BYTE) ? BLOCK_READ_A(a_ptr, 0) : a_ptr[sglid] : 0;
#endif
#elif TRANSPOSE_INPUT0 == TRANSPOSE_OTHER // TRANSPOSE_INPUT0
#if INDIRECT_INPUT0
Expand Down Expand Up @@ -516,13 +516,13 @@ KERNEL(gemm_tiled_opt)(
#if B_VEC_SIZE == 1
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else // B_VEC_SIZE == 1
#if TILE_N_NOT_DIVISIBLE
if (TILE_N_NOT_DIVISIBLE == 0 || N_IS_ALIGNED_4BYTE)
b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
else {
unroll_for (uint b_elem = 0; b_elem < B_VEC_SIZE; ++b_elem) {
b_tile[b_load_id][b_elem] = b_ptr[sglid + SIMD_WIDTH * b_elem];
}
#else
b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
#endif // TILE_N_NOT_DIVISIBLE
}
#endif // B_VEC_SIZE == 1
b_ptr += input1_offset;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,29 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
const std::string not_divisible_n = "(" + leftover_n + "!=0)";
const std::string not_divisible_k = "(" + leftover_k + "!=0)";
const std::string full_iteration_k = "(" + k_size + "/" + std::to_string(tuning_data.tile_k_size) + ")";
std::string n_aligned_4byte = "0";
std::string k_aligned_4byte = "0";
if (BytesPerElement(params.inputs[0].GetDType()) == 4 || BytesPerElement(params.inputs[0].GetDType()) == 8) {
n_aligned_4byte = "1";
k_aligned_4byte = "1";
} else {
auto bytes_per_element = std::to_string(BytesPerElement(params.inputs[0].GetDType()));
if (n_size.find("shape_info") == std::string::npos) {
n_aligned_4byte = "(" + n_size + "*" + bytes_per_element + " % 4 == 0)";
}
if (k_size.find("shape_info") == std::string::npos) {
k_aligned_4byte = "(" + k_size + "*" + bytes_per_element + " % 4 == 0)";
}
}

jit.AddConstants({
MakeJitConstant("M", m_size),
MakeJitConstant("K", k_size),
MakeJitConstant("N", n_size),
MakeJitConstant("K_PADDED_IN0", k_padded_size_in0),
MakeJitConstant("N_PADDED", n_padded_size),
MakeJitConstant("K_IS_ALIGNED_4BYTE", k_aligned_4byte),
MakeJitConstant("N_IS_ALIGNED_4BYTE", n_aligned_4byte),
MakeJitConstant("SIMD_WIDTH", tuning_data.simd_size),
MakeJitConstant("TILE_M", tuning_data.tile_m_size),
MakeJitConstant("TILE_K", tuning_data.tile_k_size),
Expand Down

0 comments on commit 7f4e766

Please sign in to comment.