diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/tir/transforms/inject_permuted_layout.cc index cccf2c505a51..9f1ad1ffc437 100644 --- a/src/tir/transforms/inject_permuted_layout.cc +++ b/src/tir/transforms/inject_permuted_layout.cc @@ -58,13 +58,20 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; - Array PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size) { + Array PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size, + DataType dtype = DataType::Int(32)) { ICHECK(permute_); // Index after vectorizing by 8 - PrimExpr col_idx_outer = floordiv(col_idx, VECTORIZE_FACTOR), - col_idx_inner = floormod(col_idx, VECTORIZE_FACTOR); + PrimExpr col_idx_outer = floordiv(col_idx, BANK_SIZE_BYTES / dtype.bits()), + col_idx_inner = floormod(col_idx, BANK_SIZE_BYTES / dtype.bits()); PrimExpr new_col_idx_outer; - if (row_size % 64 == 0) { + // use transaction bits to support diverse dtype. + // for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits + // for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits + int coalescent_bits = dtype.bits() * row_size; + // permutation on 4 banks, each bank has 32 bits + int bank_elems = BANK_SIZE_BYTES / dtype.bits(); + if (coalescent_bits % 1024 == 0) { // Use 8 * 8 permuted layout // Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read // Every row below corresponds to 32 banks @@ -76,10 +83,10 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { // 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2 // 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1 // 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0 - auto row_idx_sub = floormod(row_idx, 8); + auto row_idx_sub = floormod(row_idx, bank_elems); new_col_idx_outer = col_idx_outer ^ row_idx_sub; } else { - ICHECK(row_size % 32 == 0); + ICHECK(coalescent_bits % 512 == 0); // Use 8 * 4 permuted layout // Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read // Every row below corresponds to 16 banks @@ -96,10 +103,12 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { // 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2 // 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1 // 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0 - auto row_idx_sub = floormod(row_idx, 8); - new_col_idx_outer = col_idx_outer ^ floordiv(row_idx_sub, 2); + auto row_idx_sub = floormod(row_idx, bank_elems); + // Interleave elems per byte + int interleave_elems = 32 / dtype.bits(); + new_col_idx_outer = col_idx_outer ^ floordiv(row_idx_sub, interleave_elems); } - return {row_idx, analyzer_->Simplify(new_col_idx_outer * 8 + col_idx_inner)}; + return {row_idx, analyzer_->Simplify(new_col_idx_outer * bank_elems + col_idx_inner)}; } static bool CheckAnnotation(ObjectRef annotation) { @@ -162,14 +171,14 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { return buffer_row_size; } - Array HandleBufferIndices(Buffer buffer, Array indices) { + Array HandleBufferIndices(Buffer buffer, Array indices, DataType dtype) { auto buffer_row_size = CheckAndGetBufferRowSize(buffer); // Mutate the last two indices auto indices_size = indices.size(); PrimExpr row_idx = indices[indices_size - 2]; PrimExpr col_idx = indices[indices_size - 1]; - auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size); + auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size, dtype); indices.Set(indices_size - 2, new_indices[0]); indices.Set(indices_size - 1, new_indices[1]); return indices; @@ -180,7 +189,6 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { // We assume the shape of the shared memory is [..., row_size, col_size], // where row_size is divisible by 64, or divisible by 32 and col_size is divisible by 2. auto store = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); - if (!permute_ || store->buffer->shape.size() < 2) { return store; } @@ -191,7 +199,8 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { } auto store_node = store.CopyOnWrite(); - store_node->indices = HandleBufferIndices(store_node->buffer, store_node->indices); + store_node->indices = + HandleBufferIndices(store_node->buffer, store_node->indices, store->buffer->dtype); return store; } @@ -209,11 +218,13 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { } auto load_node = load.CopyOnWrite(); - load_node->indices = HandleBufferIndices(load_node->buffer, load_node->indices); + load_node->indices = + HandleBufferIndices(load_node->buffer, load_node->indices, load->buffer->dtype); return load; } - PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, Optional offset = NullOpt) { + PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, Optional offset = NullOpt, + DataType dtype = DataType::Int(32)) { // The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and accumulate it to // smem_offset CHECK(access_ptr->IsInstance()) @@ -233,7 +244,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { PrimExpr row_idx = floordiv(smem_offset, buffer_row_size); PrimExpr col_idx = floormod(smem_offset, buffer_row_size); - auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size); + auto new_indices = PermuteIndices(row_idx, col_idx, buffer_row_size, dtype); auto new_offset = analyzer_->Simplify(new_indices[0] * buffer_row_size + new_indices[1]); auto new_access_ptr = access_ptr_call.CopyOnWrite(); @@ -258,7 +269,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask) auto access_ptr = call->args[5]; PrimExpr smem_offset = call->args[6]; - auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, smem_offset); + auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype); auto new_call = call.CopyOnWrite(); new_call->args.Set(5, new_access_ptr); new_call->args.Set(6, IntImm(smem_offset->dtype, 0)); @@ -267,7 +278,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { // TODO(yixin): mma_store is not fully tested yet // because we will directly store result to Buffer instead of calling mma_store now auto access_ptr = call->args[2]; - auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr); + auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, NullOpt, call->dtype); auto new_call = call.CopyOnWrite(); new_call->args.Set(2, new_access_ptr); return call;