Skip to content

Commit

Permalink
Fix llm perf regression.
Browse files Browse the repository at this point in the history
Because of additional mul op, some regression was occured in lln.
Minimize calculating 4byte aligned check using leftover constant.

Signed-off-by: hyunback <[email protected]>
  • Loading branch information
hyunback committed May 21, 2024
1 parent 47ff903 commit 503bc3e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ KERNEL(gemm_tiled_opt)(
#if B_VEC_SIZE == 1
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else // B_VEC_SIZE == 1
if (N_IS_ALIGNED_4BYTE)
if (TILE_N_NOT_DIVISIBLE == 0 || N_IS_ALIGNED_4BYTE)
b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
else {
unroll_for (uint b_elem = 0; b_elem < B_VEC_SIZE; ++b_elem) {
Expand Down Expand Up @@ -387,7 +387,7 @@ KERNEL(gemm_tiled_opt)(

// Loading A tile and tile C calculation
#if IS_DYNAMIC && !INDIRECT_INPUT0 && !HAS_DYNAMIC_K_PADDING && TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST
A_FLOATN a_read = K_IS_ALIGNED_4BYTE ? BLOCK_READ_A(a_ptr, 0): a_ptr[sglid];
A_FLOATN a_read = (TILE_K_NOT_DIVISIBLE == 0 || K_IS_ALIGNED_4BYTE) ? BLOCK_READ_A(a_ptr, 0): a_ptr[sglid];
#endif
unroll_for (uint dot_id = 0; dot_id < tile_m_iterations; dot_id++) {
#if TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST
Expand Down Expand Up @@ -433,7 +433,7 @@ KERNEL(gemm_tiled_opt)(
}
#if IS_DYNAMIC && !INDIRECT_INPUT0 && !HAS_DYNAMIC_K_PADDING
// Read A for next dot_id
a_read = (dot_id + 1 < tile_m_iterations) ? K_IS_ALIGNED_4BYTE ? BLOCK_READ_A(a_ptr, 0) : a_ptr[sglid] : 0;
a_read = (dot_id + 1 < tile_m_iterations) ? (TILE_K_NOT_DIVISIBLE == 0 || K_IS_ALIGNED_4BYTE) ? BLOCK_READ_A(a_ptr, 0) : a_ptr[sglid] : 0;
#endif
#elif TRANSPOSE_INPUT0 == TRANSPOSE_OTHER // TRANSPOSE_INPUT0
#if INDIRECT_INPUT0
Expand Down Expand Up @@ -516,7 +516,7 @@ KERNEL(gemm_tiled_opt)(
#if B_VEC_SIZE == 1
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else // B_VEC_SIZE == 1
if (N_IS_ALIGNED_4BYTE)
if (TILE_N_NOT_DIVISIBLE == 0 || N_IS_ALIGNED_4BYTE)
b_tile[b_load_id] = BLOCK_READ_B(b_ptr, 0);
else {
unroll_for (uint b_elem = 0; b_elem < B_VEC_SIZE; ++b_elem) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,20 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
const std::string not_divisible_n = "(" + leftover_n + "!=0)";
const std::string not_divisible_k = "(" + leftover_k + "!=0)";
const std::string full_iteration_k = "(" + k_size + "/" + std::to_string(tuning_data.tile_k_size) + ")";
auto bytes_per_element = std::to_string(BytesPerElement(params.inputs[0].GetDType()));
auto n_aligned_4byte = "(" + n_size + "*" + bytes_per_element + " % 4 == 0)";
auto k_aligned_4byte = "(" + k_size + "*" + bytes_per_element + " % 4 == 0)";
std::string n_aligned_4byte = "0";
std::string k_aligned_4byte = "0";
if (BytesPerElement(params.inputs[0].GetDType()) == 4 || BytesPerElement(params.inputs[0].GetDType()) == 8) {
n_aligned_4byte = "1";
k_aligned_4byte = "1";
} else {
auto bytes_per_element = std::to_string(BytesPerElement(params.inputs[0].GetDType()));
if (n_size.find("shape_info") == std::string::npos) {
n_aligned_4byte = "(" + n_size + "*" + bytes_per_element + " % 4 == 0)";
}
if (k_size.find("shape_info") == std::string::npos) {
k_aligned_4byte = "(" + k_size + "*" + bytes_per_element + " % 4 == 0)";
}
}

jit.AddConstants({
MakeJitConstant("M", m_size),
Expand Down

0 comments on commit 503bc3e

Please sign in to comment.