Skip to content

Commit

Permalink
Avoid writing unnecessary output buffer for fc_tiled kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
riverlijunjie committed Jun 15, 2024
1 parent 42f7d3a commit bf874ad
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ inline void FUNC(fc_bf_tiled_kernel_default)(
uint feature_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV) % (CEIL_DIV(TILE_OUT_F_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);
uint batch_mega_block = gid / (DISPATCH_FSV * DISPATCH_BSV * CEIL_DIV(TILE_OUT_F_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);

uint left_batch = TILE_B;
uint last_batch_mega_block = (get_num_groups(0) - 1) / (DISPATCH_FSV * DISPATCH_BSV * CEIL_DIV(TILE_OUT_F_NUM, TILE_OFM * SIMD) / DISPATCH_FSV);
uint last_batch_mini_block = (get_num_groups(0) - 1) / DISPATCH_FSV % DISPATCH_BSV;
if (batch_mega_block == last_batch_mega_block && batch_mini_block == last_batch_mini_block){
left_batch = INPUT0_BATCH_NUM % TILE_B;
}
#if USE_SLM
uint out_f = gid * (TILE_OFM * SIMD);
uint out_b = LWS_BATCHES * TILE_B * (uint)get_group_id(2) + local_id * TILE_B;
Expand Down Expand Up @@ -641,7 +647,13 @@ inline void FUNC(fc_bf_tiled_kernel_default)(
output_offset += TILE_OUT_B_PITCH; \
} while (false)
#endif
CONST_LOOP(TILE_B, WRITE_OUTPUT);
if(left_batch == TILE_B) {
CONST_LOOP(TILE_B, WRITE_OUTPUT);
} else {
unroll_for (uint bi = 0; bi < left_batch; ++bi) {
CONST_LOOP(1, WRITE_OUTPUT);
}
}
#undef WRITE_OUTPUT
} else {
output_offset += sglid;
Expand All @@ -667,7 +679,7 @@ inline void FUNC(fc_bf_tiled_kernel_default)(
//#undef WRITE_OUTPUT
//#undef WRITE_OUTPUT_FEATURE

for (uint bi = 0; bi < TILE_B; ++bi) {
for (uint bi = 0; bi < left_batch; ++bi) {
for (uint fi = 0; fi < TILE_OFM; ++fi) {
const bool should_write =
#if IS_DYNAMIC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,12 @@ bool TuneParamsSelector::VerifyTuneParams(const fully_connected_params& params,

auto batch_size = params.is_shape_agnostic ? Align(output_b, tparams.tile_b) : output_b;
if (batch_size % (tparams.tile_b * tparams.dispatch_bsv) != 0) {
bool is_odd = (output_b > 1) && (output_b % 2 == 1);
// If batch_size is unaligned with tile_b, params can use dispatch_bsv==1 in case of odd batch_size to avoid
// performance drop.
if ((tparams.dispatch_bsv != 1) || !is_odd)
size_t tile = simd;
while (batch_size % tile != 0)
tile--;
if ((tparams.dispatch_bsv != 1) || (tile > 1) || batch_size == 1)
return false;
std::cout << "batch_size = " << batch_size << ", output_b = " << output_b << std::endl << std::endl;
}

if (CeilDiv(output_f, tparams.tile_ofm * simd) % tparams.dispatch_fsv != 0)
Expand Down

0 comments on commit bf874ad

Please sign in to comment.