From 44dd6445efdcd955d41e5ae2b21e73da55b7a738 Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Tue, 4 Apr 2023 12:20:33 +0800
Subject: [PATCH] [TensorIR] Support for L2 prefetch async copy and pred_guard
 enabled async in vectorized if_then_else (#14329)

This pull request adds support for the L2 prefetch option of the cp.async
instruction, which is available in CUDA 11.4 and later; the implementation
follows the approach used by CUTLASS. It also adds support for asynchronous
copies of vectorized if_then_else, and fixes a correctness bug in the existing
predicated path.

For example, the original async-copy lowering cannot handle a vectorized
if_then_else. Given the following template:

```python
for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"):
    for ax0_ax1_0_fused_2 in T.thread_binding(2, thread="threadIdx.y"):
        for ax0_ax1_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
            with T.block("data_im2col_reindex_shared.dyn_o"):
                v0 = T.axis.spatial(512, x_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8)
                v1_o = T.axis.spatial(1440, k_0_0 * 8 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8)
                T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8:v1_o % 160 * 8 + 8])
                T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8:v1_o * 8 + 8])
                for ax1_1 in T.vectorized(8):
                    with T.block("data_im2col_reindex_shared.dyn"):
                        v1_i = T.axis.spatial(8, ax1_1)
                        T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i])
                        T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i])
                        T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]})
                        data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i] = T.if_then_else(1 <= v1_o // 480 + v0 % 256 // 16 and v1_o // 480 + v0 % 256 // 16 < 17 and 1 <= v1_o % 480 // 160 + v0 % 16 and v1_o % 480 // 160 + v0 % 16 < 17, A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i], T.float16(0))
```

this PR lowers the copy to:

```c++
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
      : "=r"(addr)
      : "l"((void *)(buf_dyn_shmem + (((((ax0_ax1_0_fused_0 * 2304) + (((int)threadIdx.z) * 1152)) + (((int)threadIdx.y) * 576)) + ((((int)threadIdx.x) >> 3) * 144)) + ((((int)threadIdx.x) & 7) * 16))))
    );
    int pred_guard = (int)((((1 <= ((((((int)blockIdx.y) & 3) * 4) + (k_0_0 / 60)) + ax0_ax1_0_fused_0)) && (((((((int)blockIdx.y) & 3) * 4) + (k_0_0 / 60)) + ax0_ax1_0_fused_0) < 17)) && (1 <= ((((((int)threadIdx.z) * 8) + (((int)threadIdx.y) * 4)) + ((k_0_0 % 60) / 20)) + (((int)threadIdx.x) >> 3)))) && (((((((int)threadIdx.z) * 8) + (((int)threadIdx.y) * 4)) + ((k_0_0 % 60) / 20)) + (((int)threadIdx.x) >> 3)) < 17));
    __asm__ __volatile__(
      "{ .reg .pred p;"
      " setp.ne.b32 p, %0, 0;"
      #if TVM_ENABLE_L2_PREFETCH
      " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;"
      #else
      " @p cp.async.ca.shared.global [%1], [%2], %3;"
      #endif
      " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};}"
      :: "r"(pred_guard), "r"(addr), "l"((void*)(A + (((((((((((int)blockIdx.y) * 81920) + ((k_0_0 / 60) * 20480)) + (ax0_ax1_0_fused_0 * 20480)) + (((int)threadIdx.z) * 10240)) + (((int)threadIdx.y) * 5120)) + ((((int)threadIdx.x) >> 3) * 1280)) + ((k_0_0 % 60) * 64)) + ((((int)threadIdx.x) & 7) * 8)) - 21760))), "n"(16), "r"(0), "r"(0), "r"(0), "r"(0)
    );
  }
```

The original InjectPTXAsyncCopy pass could only handle a serial, 4-byte-aligned
if_then_else assignment, and its predicated code path has a correctness bug:

```c++
  std::string predicated_asm_code = R"(
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)({smem_addr}))
    );
    int src_bytes = {pred_guard} ? {bytes} : 0;
    __asm__ __volatile__(
      "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
        :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
    );
  }
)";
```

If the condition is false, the code above does nothing, but the correct
behavior is to store zero to the shared memory address; otherwise the shared
memory contents are left undefined and the results are incorrect. We fix this
by emitting a predicated native shared-memory store in the inline assembly:

```c++
    int pred_guard = (int){pred_guard};
    __asm__ __volatile__(
      "{ .reg .pred p;"
      " setp.ne.b32 p, %0, 0;"
      #if TVM_ENABLE_L2_PREFETCH
      " @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;"
      #else
      " @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;"
      #endif
      " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};}"
      :: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(0), "r"(0), "r"(0), "r"(0)
    );
```
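To see the new lowering end to end, the sketch below (not part of this patch)
builds a PrimFunc such as the `complex_compute` function defined in the new
unit test, enables `tir.use_async_copy`, and inspects the generated CUDA
source. It assumes a CUDA toolchain and an sm_80+ device are available, since
cp.async requires Ampere or newer:

```python
import tvm

# `complex_compute` stands in for a TIR PrimFunc whose global->shared copies
# are vectorized if_then_else loads inside an async software pipeline, e.g.
# the conv2d im2col workload added to
# tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py.
mod = tvm.IRModule.from_expr(complex_compute)

# `tir.use_async_copy` lets InjectPTXAsyncCopy rewrite eligible copies.
with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
    rt_mod = tvm.build(mod, target="cuda")

# The CUDA C source is attached as an imported module; the predicated path
# should contain the PTX predicate setup and the @!p fallback store.
cuda_src = rt_mod.imported_modules[0].get_source()
assert "setp.ne.b32" in cuda_src
assert "st.shared" in cuda_src
```

This mirrors the check performed by the new test
`test_vectorize_cp_async_in_if_then_else` below.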
Co-authored-by: leiwang1999
---
 src/target/source/codegen_cuda.cc             |   7 +
 src/target/source/ptx.cc                      |  46 +-
 src/tir/transforms/inject_ptx_async_copy.cc   |  31 +-
 src/tir/transforms/lower_async_dma.cc         |   5 +-
 .../test_hexagon/test_async_dma_pipeline.py   |  18 -
 ...est_tir_transform_inject_ptx_async_copy.py | 537 +++++++++++++++++-
 6 files changed, 600 insertions(+), 44 deletions(-)

diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 077c70af2f1b..3b3fdbc58a4b 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -134,6 +134,13 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#include <mma.h>\n";
   }
 
+  decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n";
+  decl_stream << "     (__CUDACC_VER_MAJOR__ > 11))\n";
+  decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n";
+  decl_stream << "#else\n";
+  decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
+  decl_stream << "#endif\n";
+
   decl_stream << "\n#ifdef _WIN32\n";
   decl_stream << "  using uint = unsigned int;\n";
   decl_stream << "  using uchar = unsigned char;\n";
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index b5299b4e4b2a..feffc3d304ef 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -645,8 +645,12 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
       : "l"((void *)({smem_addr}))
     );
     __asm__ __volatile__(
-      "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
-        :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
+      #if TVM_ENABLE_L2_PREFETCH
+        "cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;"
+      #else
+        "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
+      #endif
+        :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
     );
   }
 )";
@@ -665,26 +669,56 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
                                            const std::string& global_elem_offset,
                                            const std::string& bytes,
                                            const std::string& predicate_value) {
+  CHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || bytes == "2" ||
+        bytes == "1")
+      << "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async";
   std::string predicated_asm_code = R"(
   {
     unsigned int addr;
     __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+      "{ .reg .u64 addr; cvta.to.shared.u64
addr, %1; cvt.u32.u64 %0, addr; }" : "=r"(addr) : "l"((void *)({smem_addr})) ); - int src_bytes = {pred_guard} ? {bytes} : 0; + int pred_guard = (int){pred_guard}; __asm__ __volatile__( - "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;" - :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes) + "{ .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + #if TVM_ENABLE_L2_PREFETCH + " @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;" + #else + " @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;" + #endif + " @!p {store_shared};}" + :: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), {nopreg} ); } )"; + auto [store_shared, nopreg] = [](const std::string& bytes) { + if (bytes == "16") + return std::make_tuple("st.shared.v4.u32 [%1], {%4, %5, %6, %7}", + "\"r\"(0), \"r\"(0), \"r\"(0),\"r\"(0)"); + else if (bytes == "12") + return std::make_tuple("st.shared.v3.u32 [%1], {%4, %5, %6}", "\"r\"(0), \"r\"(0), \"r\"(0)"); + else if (bytes == "8") + return std::make_tuple("st.shared.v2.u32 [%1], {%4, %5}", "\"r\"(0), \"r\"(0)"); + else if (bytes == "4") + return std::make_tuple("st.shared.u32 [%1], {%4}", "\"r\"(0)"); + else if (bytes == "2") + return std::make_tuple("st.shared.u16 [%1], {%4}", "\"r\"(0)"); + else if (bytes == "1") + return std::make_tuple("st.shared.u8 [%1], {%4}", "\"r\"(0)"); + else + return std::make_tuple("", ""); + }(bytes); + Replacer replacer; replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); replacer.register_rule("{bytes}", bytes); replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca"); + replacer.register_rule("{store_shared}", store_shared); + replacer.register_rule("{nopreg}", nopreg); replacer.register_rule("{pred_guard}", predicate_value); predicated_asm_code = replacer.rewrite(predicated_asm_code); return predicated_asm_code; diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index 2e3c906e89c1..a6685fe87b48 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -112,12 +112,41 @@ class PTXAsyncCopyInjector : public StmtMutator { } return PrimExpr(); }(); - if (src_offset.defined() && dst_offset.defined()) { return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes)})); } + } else { + // Only some vectorized indexing patterns are supported for now. + auto src_offset = [=]() -> PrimExpr { + if (load->indices[0]->IsInstance()) { + return load->indices[0].as()->base; + } + return PrimExpr(); + }(); + + auto dst_offset = [=]() -> PrimExpr { + if (store->indices[0].as()) { + return store->indices[0].as()->base; + } else if (store->indices[0].as()) { + // The case where the dst buffer is a byte buffer generated by merging dynamic + // shared memory. 
+ // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)] + auto* add = store->indices[0].as(); + if (!add->a->IsInstance()) return PrimExpr(); + if (!add->b->IsInstance()) return PrimExpr(); + return tir::Add(add->a.as()->base, add->b.as()->value); + } + return PrimExpr(); + }(); + + if (src_offset.defined() && dst_offset.defined()) { + return Evaluate( + Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), + {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes), predicate_value})); + } } } } diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index d899b6ec70ab..e1ec0f1572c7 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -43,17 +43,18 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {} + // TODO(leiwang1999): split lower async DMA support for CUDA and Hexagon Backend Stmt VisitStmt_(const ForNode* loop) final { // if for loop is not within async_commit_queue_scope if (!async_queue_id_.has_value()) { return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); } - // if for loop is not a memcpy of a contiguous region + // if for loop is not a memcpy of a contiguous region, it might be a cuda cp.async behavior std::optional mem_copy = IdentifyMemCpy(GetRef(loop), analyzer_); if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 || mem_copy->source->region.size() != 1) { - LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access"; + return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); } // now that we are about to perform the `copy` transform diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index 04e7595abf37..efe4920ec0b2 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -879,23 +879,5 @@ def test_meta(hexagon_session): ) -def test_non_contiguous(): - """Test Non Contiguous memory lowering.""" - sch = tvm.tir.Schedule(conv2d_async_non_contig) - target_hexagon = tvm.target.hexagon("v68", link_params=True) - err_rgx = r"Unable to lower async dma due to non contiguous memory access" - # Currently we do not support non contiguous memory access being lowered to - # async dma so we throw an error. 
- with pytest.raises(tvm.TVMError, match=err_rgx): - with tvm.transform.PassContext( - config={ - "tir.use_async_copy": 1, - } - ): - tvm.build( - sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon) - ) - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 3d779bc7d114..168f8c879bcd 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -181,6 +181,13 @@ def test_inject_async_copy_shared_dyn(): expected_cuda_script = r""" +#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ + (__CUDACC_VER_MAJOR__ > 11)) +#define TVM_ENABLE_L2_PREFETCH 1 +#else +#define TVM_ENABLE_L2_PREFETCH 0 +#endif + #ifdef _WIN32 using uint = unsigned int; using uchar = unsigned char; @@ -210,8 +217,12 @@ def test_inject_async_copy_shared_dyn(): : "l"((void *)(A_shared + (((int)threadIdx.x) + 16))) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)(A + (((int)threadIdx.x) * 14))), "n"(4) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.ca.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(A + (((int)threadIdx.x) * 14))), "n"(4) ); } @@ -223,8 +234,12 @@ def test_inject_async_copy_shared_dyn(): : "l"((void *)(B_shared + (((int)threadIdx.x) + 16))) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)(B + (((int)threadIdx.x) * 14))), "n"(4) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.ca.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(B + (((int)threadIdx.x) * 14))), "n"(4) ); } __asm__ __volatile__("cp.async.commit_group;"); @@ -238,8 +253,12 @@ def test_inject_async_copy_shared_dyn(): : "l"((void *)(A_shared + (((int)threadIdx.x) + 32))) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)(A + ((((int)threadIdx.x) * 14) + 1))), "n"(4) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.ca.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(A + ((((int)threadIdx.x) * 14) + 1))), "n"(4) ); } @@ -251,8 +270,12 @@ def test_inject_async_copy_shared_dyn(): : "l"((void *)(B_shared + (((int)threadIdx.x) + 32))) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)(B + ((((int)threadIdx.x) * 14) + 1))), "n"(4) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.ca.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(B + ((((int)threadIdx.x) * 14) + 1))), "n"(4) ); } __asm__ __volatile__("cp.async.commit_group;"); @@ -263,14 +286,21 @@ def test_inject_async_copy_shared_dyn(): { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" : "=r"(addr) : "l"((void *)(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)))) ); - int src_bytes = cse_var_1 ? 
4 : 0; + int pred_guard = (int)cse_var_1; __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2, %3;" - :: "r"(addr), "l"((void*)(A + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(src_bytes) + "{ .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + #if TVM_ENABLE_L2_PREFETCH + " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;" + #else + " @p cp.async.ca.shared.global [%1], [%2], %3;" + #endif + " @!p st.shared.u32 [%1], {%4};}" + :: "r"(pred_guard), "r"(addr), "l"((void*)(A + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(0) ); } __asm__ __volatile__("cp.async.commit_group;"); @@ -284,14 +314,21 @@ def test_inject_async_copy_shared_dyn(): { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" : "=r"(addr) : "l"((void *)(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)))) ); - int src_bytes = cse_var_1 ? 4 : 0; + int pred_guard = (int)cse_var_1; __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2, %3;" - :: "r"(addr), "l"((void*)(B + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(src_bytes) + "{ .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + #if TVM_ENABLE_L2_PREFETCH + " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;" + #else + " @p cp.async.ca.shared.global [%1], [%2], %3;" + #endif + " @!p st.shared.u32 [%1], {%4};}" + :: "r"(pred_guard), "r"(addr), "l"((void*)(B + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(0) ); } __asm__ __volatile__("cp.async.commit_group;"); @@ -385,7 +422,6 @@ def simple_compute( mod = tvm.IRModule.from_expr(simple_compute) with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): tvm.build(mod, target="cuda") - assert generated_code == expected_cuda_script if not support_async: @@ -393,7 +429,474 @@ def simple_compute( support_async = True +@tvm.testing.requires_cuda +def test_vectorize_cp_async_in_if_then_else(): + global support_async + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # At least sm80 is required + support_async = False + + @T.prim_func + def complex_compute( + A: T.Buffer((2, 16, 16, 1280), "float16"), + W: T.Buffer((1280, 3, 3, 1280), "float16"), + Conv: T.Buffer((512, 1280), "float16"), + ): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # with T.block("root"): + data_im2col_reindex_shared_dyn = T.alloc_buffer((512, 11520), "float16", scope="shared.dyn") + data_im2col_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer( + (512, 11520), "float16", scope="wmma.matrix_a" + ) + weight_flatten_reindex_shared_dyn = T.alloc_buffer( + (1280, 11520), "float16", scope="shared.dyn" + ) + weight_flatten_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer( + (1280, 11520), "float16", scope="wmma.matrix_b" + ) + Conv_reindex_wmma_accumulator = T.alloc_buffer( + (512, 1280), "float16", scope="wmma.accumulator" + ) + for x_0_0 in T.thread_binding(8, thread="blockIdx.y"): + for y_0_0 in T.thread_binding(20, thread="blockIdx.x"): + for x_0_1 in T.thread_binding(2, thread="threadIdx.y"): + for y_0_1 in T.thread_binding(2, thread="threadIdx.z"): + for x_0_2_init, y_0_2_init in T.grid(2, 2): + with T.block("Conv_init_o"): + v_x_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + x_0_2_init) + v_y_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + y_0_2_init) + T.reads() + T.writes( + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, v_y_o * 16 : v_y_o * 16 + 16 + ] 
+ ) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer( + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, v_y_o * 16 : v_y_o * 16 + 16 + ], + (16, 16), + "float16", + strides=(C_s0, C_s1), + scope="wmma.accumulator", + offset_factor=16, + ) + T.tvm_fill_fragment( + C.data, + 16, + 16, + 16, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + T.float32(0), + ) + for k_0_0 in T.serial( + 180, + annotations={ + "software_pipeline_stage": [0, 0, 1], + "software_pipeline_order": [0, 1, 2], + "software_pipeline_async_stages": [0], + }, + ): + for ax0_ax1_0_fused_0 in range(4): + for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_0_fused_2 in T.thread_binding( + 2, thread="threadIdx.y" + ): + for ax0_ax1_0_fused_3 in T.thread_binding( + 32, thread="threadIdx.x" + ): + with T.block("data_im2col_reindex_shared.dyn_o"): + v0 = T.axis.spatial( + 512, + x_0_0 * 64 + + ( + ax0_ax1_0_fused_0 * 128 + + ax0_ax1_0_fused_1 * 64 + + ax0_ax1_0_fused_2 * 32 + + ax0_ax1_0_fused_3 + ) + // 8, + ) + v1_o = T.axis.spatial( + 1440, + k_0_0 * 8 + + ( + ax0_ax1_0_fused_0 * 128 + + ax0_ax1_0_fused_1 * 64 + + ax0_ax1_0_fused_2 * 32 + + ax0_ax1_0_fused_3 + ) + % 8, + ) + T.reads( + A[ + v0 // 256, + v1_o // 480 + v0 % 256 // 16 - 1, + v1_o % 480 // 160 + v0 % 16 - 1, + v1_o % 160 * 8 : v1_o % 160 * 8 + 8, + ] + ) + T.writes( + data_im2col_reindex_shared_dyn[ + v0, v1_o * 8 : v1_o * 8 + 8 + ] + ) + for ax1_1 in T.vectorized(8): + with T.block("data_im2col_reindex_shared.dyn"): + v1_i = T.axis.spatial(8, ax1_1) + T.reads( + A[ + v0 // 256, + v1_o // 480 + v0 % 256 // 16 - 1, + v1_o % 480 // 160 + v0 % 16 - 1, + v1_o % 160 * 8 + v1_i, + ] + ) + T.writes( + data_im2col_reindex_shared_dyn[ + v0, v1_o * 8 + v1_i + ] + ) + T.block_attr( + {"buffer_dim_align": [[0, 0, 32, 8]]} + ) + data_im2col_reindex_shared_dyn[ + v0, v1_o * 8 + v1_i + ] = T.if_then_else( + 1 <= v1_o // 480 + v0 % 256 // 16 + and v1_o // 480 + v0 % 256 // 16 < 17 + and 1 <= v1_o % 480 // 160 + v0 % 16 + and v1_o % 480 // 160 + v0 % 16 < 17, + A[ + v0 // 256, + v1_o // 480 + v0 % 256 // 16 - 1, + v1_o % 480 // 160 + v0 % 16 - 1, + v1_o % 160 * 8 + v1_i, + ], + T.float16(0), + ) + for ax0_ax1_0_fused_0 in range(4): + for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_0_fused_2 in T.thread_binding( + 2, thread="threadIdx.y" + ): + for ax0_ax1_0_fused_3 in T.thread_binding( + 32, thread="threadIdx.x" + ): + for ax1_1 in T.vectorized(8): + with T.block("weight_flatten_reindex_shared.dyn"): + v0 = T.axis.spatial( + 1280, + y_0_0 * 64 + + ( + ax0_ax1_0_fused_0 * 128 + + ax0_ax1_0_fused_1 * 64 + + ax0_ax1_0_fused_2 * 32 + + ax0_ax1_0_fused_3 + ) + // 8, + ) + v1 = T.axis.spatial( + 11520, + k_0_0 * 64 + + ( + ax0_ax1_0_fused_0 * 128 + + ax0_ax1_0_fused_1 * 64 + + ax0_ax1_0_fused_2 * 32 + + ax0_ax1_0_fused_3 + ) + % 8 + * 8 + + ax1_1, + ) + T.reads( + W[ + v0, + v1 // 3840, + v1 % 3840 // 1280, + v1 % 1280, + ] + ) + T.writes( + weight_flatten_reindex_shared_dyn[v0, v1] + ) + T.block_attr( + {"buffer_dim_align": [[0, 0, 32, 8]]} + ) + weight_flatten_reindex_shared_dyn[v0, v1] = W[ + v0, + v1 // 1280 // 3, + v1 // 1280 % 3, + v1 % 1280, + ] + for k_0_1 in range(4): + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("data_im2col_reindex_shared.dyn_wmma.matrix_a_o"): + v0_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(720, k_0_0 * 4 + k_0_1 + ax1_0) + T.reads( + data_im2col_reindex_shared_dyn[ + v0_o * 16 : v0_o * 16 + 16, + v1_o 
* 16 : v1_o * 16 + 16, + ] + ) + T.writes( + data_im2col_reindex_shared_dyn_wmma_matrix_a[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ] + ) + A_s0 = T.int32() + A_s1 = T.int32() + A_1 = T.match_buffer( + data_im2col_reindex_shared_dyn[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(A_s0, A_s1), + scope="shared.dyn", + offset_factor=16, + ) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer( + data_im2col_reindex_shared_dyn_wmma_matrix_a[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(C_s0, C_s1), + scope="wmma.matrix_a", + offset_factor=16, + ) + T.tvm_load_matrix_sync( + C.data, + 16, + 16, + 16, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + T.tvm_access_ptr( + T.type_annotation("float16"), + A_1.data, + A_1.elem_offset, + A_s0 * 16, + 1, + ), + A_s0, + "row_major", + ) + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block( + "weight_flatten_reindex_shared.dyn_wmma.matrix_b_o" + ): + v0_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(720, k_0_0 * 4 + k_0_1 + ax1_0) + T.reads( + weight_flatten_reindex_shared_dyn[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ] + ) + T.writes( + weight_flatten_reindex_shared_dyn_wmma_matrix_b[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ] + ) + A_s0 = T.int32() + A_s1 = T.int32() + A_1 = T.match_buffer( + weight_flatten_reindex_shared_dyn[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(A_s0, A_s1), + scope="shared.dyn", + offset_factor=16, + ) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer( + weight_flatten_reindex_shared_dyn_wmma_matrix_b[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(C_s0, C_s1), + scope="wmma.matrix_b", + offset_factor=16, + ) + T.tvm_load_matrix_sync( + C.data, + 16, + 16, + 16, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + T.tvm_access_ptr( + T.type_annotation("float16"), + A_1.data, + A_1.elem_offset, + A_s0 * 16, + 1, + ), + A_s0, + "col_major", + ) + for x_0_2, y_0_2 in T.grid(2, 2): + with T.block("Conv_update_o"): + v_x_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + x_0_2) + v_y_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + y_0_2) + v_k_o = T.axis.reduce(720, k_0_0 * 4 + k_0_1) + T.reads( + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, + v_y_o * 16 : v_y_o * 16 + 16, + ], + data_im2col_reindex_shared_dyn_wmma_matrix_a[ + v_x_o * 16 : v_x_o * 16 + 16, + v_k_o * 16 : v_k_o * 16 + 16, + ], + weight_flatten_reindex_shared_dyn_wmma_matrix_b[ + v_y_o * 16 : v_y_o * 16 + 16, + v_k_o * 16 : v_k_o * 16 + 16, + ], + ) + T.writes( + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, + v_y_o * 16 : v_y_o * 16 + 16, + ] + ) + A_s0 = T.int32() + A_s1 = T.int32() + A_1 = T.match_buffer( + data_im2col_reindex_shared_dyn_wmma_matrix_a[ + v_x_o * 16 : v_x_o * 16 + 16, + v_k_o * 16 : v_k_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(A_s0, A_s1), + scope="wmma.matrix_a", + offset_factor=16, + ) + B_s0 = T.int32() + B_s1 = T.int32() + B = T.match_buffer( + weight_flatten_reindex_shared_dyn_wmma_matrix_b[ + v_y_o * 16 : v_y_o * 16 + 16, + v_k_o * 16 : v_k_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(B_s0, B_s1), + scope="wmma.matrix_b", + offset_factor=16, + ) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer( + 
Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, + v_y_o * 16 : v_y_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(C_s0, C_s1), + scope="wmma.accumulator", + offset_factor=16, + ) + T.tvm_mma_sync( + C.data, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + A_1.data, + A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + + A_1.elem_offset % A_s0 // 16, + B.data, + B.elem_offset // B_s0 // 16 * (B_s0 // 16) + + B.elem_offset % B_s0 // 16, + C.data, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + ) + for ax0_0, ax1_0 in T.grid(2, 2): + with T.block("Conv_reindex_wmma.accumulator_o"): + v0_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + ax1_0) + T.reads( + Conv_reindex_wmma_accumulator[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ] + ) + T.writes( + Conv[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16] + ) + A_s0 = T.int32() + A_s1 = T.int32() + A_1 = T.match_buffer( + Conv_reindex_wmma_accumulator[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ], + (16, 16), + "float16", + strides=(A_s0, A_s1), + scope="wmma.accumulator", + offset_factor=16, + ) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer( + Conv[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], + (16, 16), + "float16", + strides=(C_s0, C_s1), + offset_factor=16, + ) + T.tvm_store_matrix_sync( + A_1.data, + 16, + 16, + 16, + A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + + A_1.elem_offset % A_s0 // 16, + T.tvm_access_ptr( + T.type_annotation("float16"), + C.data, + C.elem_offset, + C_s0 * 16, + 2, + ), + C_s0, + "row_major", + ) + + mod = tvm.IRModule.from_expr(complex_compute) + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + tvm.build(mod, target="cuda") + # generated_code must contain " setp.ne.b32 p, %0, 0;" + assert "setp.ne.b32" in generated_code + + if not support_async: + # avoid return dummy code to other tests + support_async = True + + if __name__ == "__main__": test_inject_async_copy() test_inject_async_copy_shared_dyn() test_cp_async_in_if_then_else() + test_vectorize_cp_async_in_if_then_else()