diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 077c70af2f1b..3b3fdbc58a4b 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -134,6 +134,13 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#include <mma.h>\n";
   }
 
+  decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n";
+  decl_stream << "     (__CUDACC_VER_MAJOR__ > 11))\n";
+  decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n";
+  decl_stream << "#else\n";
+  decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
+  decl_stream << "#endif\n";
+
   decl_stream << "\n#ifdef _WIN32\n";
   decl_stream << "  using uint = unsigned int;\n";
   decl_stream << "  using uchar = unsigned char;\n";
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index b5299b4e4b2a..feffc3d304ef 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -645,8 +645,12 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
       : "l"((void *)({smem_addr}))
     );
     __asm__ __volatile__(
-      "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
-       :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
+    #if TVM_ENABLE_L2_PREFETCH
+      "cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;"
+    #else
+      "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
+    #endif
+       :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
     );
   }
 )";
@@ -665,26 +669,56 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
                                            const std::string& global_elem_offset,
                                            const std::string& bytes,
                                            const std::string& predicate_value) {
+  CHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || bytes == "2" ||
+        bytes == "1")
+      << "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async";
   std::string predicated_asm_code = R"(
   {
     unsigned int addr;
     __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
       : "=r"(addr)
       : "l"((void *)({smem_addr}))
    );
-    int src_bytes = {pred_guard} ? {bytes} : 0;
+    int pred_guard = (int){pred_guard};
     __asm__ __volatile__(
-      "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
-       :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
+      "{ .reg .pred p;"
+      "  setp.ne.b32 p, %0, 0;"
+    #if TVM_ENABLE_L2_PREFETCH
+      "  @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;"
+    #else
+      "  @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;"
+    #endif
+      "  @!p {store_shared};}"
+       :: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), {nopreg}
     );
   }
 )";
+  auto [store_shared, nopreg] = [](const std::string& bytes) {
+    if (bytes == "16")
+      return std::make_tuple("st.shared.v4.u32 [%1], {%4, %5, %6, %7}",
+                             "\"r\"(0), \"r\"(0), \"r\"(0),\"r\"(0)");
+    else if (bytes == "12")
+      return std::make_tuple("st.shared.v3.u32 [%1], {%4, %5, %6}", "\"r\"(0), \"r\"(0), \"r\"(0)");
+    else if (bytes == "8")
+      return std::make_tuple("st.shared.v2.u32 [%1], {%4, %5}", "\"r\"(0), \"r\"(0)");
+    else if (bytes == "4")
+      return std::make_tuple("st.shared.u32 [%1], {%4}", "\"r\"(0)");
+    else if (bytes == "2")
+      return std::make_tuple("st.shared.u16 [%1], {%4}", "\"r\"(0)");
+    else if (bytes == "1")
+      return std::make_tuple("st.shared.u8 [%1], {%4}", "\"r\"(0)");
+    else
+      return std::make_tuple("", "");
+  }(bytes);
+
   Replacer replacer;
   replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
   replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
   replacer.register_rule("{bytes}", bytes);
   replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
+  replacer.register_rule("{store_shared}", store_shared);
+  replacer.register_rule("{nopreg}", nopreg);
   replacer.register_rule("{pred_guard}", predicate_value);
   predicated_asm_code = replacer.rewrite(predicated_asm_code);
   return predicated_asm_code;
diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc
index 2e3c906e89c1..a6685fe87b48 100644
--- a/src/tir/transforms/inject_ptx_async_copy.cc
+++ b/src/tir/transforms/inject_ptx_async_copy.cc
@@ -112,12 +112,41 @@ class PTXAsyncCopyInjector : public StmtMutator {
           }
           return PrimExpr();
         }();
-
         if (src_offset.defined() && dst_offset.defined()) {
           return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
                                {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
                                 load->buffer->data, src_offset, PrimExpr(bytes)}));
         }
+      } else {
+        // Only some vectorized indexing patterns are supported for now.
+        auto src_offset = [=]() -> PrimExpr {
+          if (load->indices[0]->IsInstance<RampNode>()) {
+            return load->indices[0].as<RampNode>()->base;
+          }
+          return PrimExpr();
+        }();
+
+        auto dst_offset = [=]() -> PrimExpr {
+          if (store->indices[0].as<RampNode>()) {
+            return store->indices[0].as<RampNode>()->base;
+          } else if (store->indices[0].as<AddNode>()) {
+            // The case where the dst buffer is a byte buffer generated by merging dynamic
+            // shared memory.
+            // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)]
+            auto* add = store->indices[0].as<AddNode>();
+            if (!add->a->IsInstance<RampNode>()) return PrimExpr();
+            if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
+            return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value);
+          }
+          return PrimExpr();
+        }();
+
+        if (src_offset.defined() && dst_offset.defined()) {
+          return Evaluate(
+              Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
+                   {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
+                    load->buffer->data, src_offset, PrimExpr(bytes), predicate_value}));
+        }
       }
     }
   }
diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc
index d899b6ec70ab..e1ec0f1572c7 100644
--- a/src/tir/transforms/lower_async_dma.cc
+++ b/src/tir/transforms/lower_async_dma.cc
@@ -43,17 +43,18 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
   explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer)
       : IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {}
 
+  // TODO(leiwang1999): split lower async DMA support for CUDA and Hexagon Backend
   Stmt VisitStmt_(const ForNode* loop) final {
     // if for loop is not within async_commit_queue_scope
     if (!async_queue_id_.has_value()) {
       return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
     }
 
-    // if for loop is not a memcpy of a contiguous region
+    // if for loop is not a memcpy of a contiguous region, it might be a cuda cp.async behavior
     std::optional<MemCpyDetails> mem_copy = IdentifyMemCpy(GetRef<For>(loop), analyzer_);
     if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
         mem_copy->source->region.size() != 1) {
-      LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access";
+      return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
     }
 
     // now that we are about to perform the `copy` transform
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index 04e7595abf37..efe4920ec0b2 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -879,23 +879,5 @@ def test_meta(hexagon_session):
     )
 
 
-def test_non_contiguous():
-    """Test Non Contiguous memory lowering."""
-    sch = tvm.tir.Schedule(conv2d_async_non_contig)
-    target_hexagon = tvm.target.hexagon("v68", link_params=True)
-    err_rgx = r"Unable to lower async dma due to non contiguous memory access"
-    # Currently we do not support non contiguous memory access being lowered to
-    # async dma so we throw an error.
- with pytest.raises(tvm.TVMError, match=err_rgx): - with tvm.transform.PassContext( - config={ - "tir.use_async_copy": 1, - } - ): - tvm.build( - sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon) - ) - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 3d779bc7d114..168f8c879bcd 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -181,6 +181,13 @@ def test_inject_async_copy_shared_dyn(): expected_cuda_script = r""" +#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ + (__CUDACC_VER_MAJOR__ > 11)) +#define TVM_ENABLE_L2_PREFETCH 1 +#else +#define TVM_ENABLE_L2_PREFETCH 0 +#endif + #ifdef _WIN32 using uint = unsigned int; using uchar = unsigned char; @@ -210,8 +217,12 @@ def test_inject_async_copy_shared_dyn(): : "l"((void *)(A_shared + (((int)threadIdx.x) + 16))) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)(A + (((int)threadIdx.x) * 14))), "n"(4) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.ca.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(A + (((int)threadIdx.x) * 14))), "n"(4) ); } @@ -223,8 +234,12 @@ def test_inject_async_copy_shared_dyn(): : "l"((void *)(B_shared + (((int)threadIdx.x) + 16))) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)(B + (((int)threadIdx.x) * 14))), "n"(4) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.ca.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(B + (((int)threadIdx.x) * 14))), "n"(4) ); } __asm__ __volatile__("cp.async.commit_group;"); @@ -238,8 +253,12 @@ def test_inject_async_copy_shared_dyn(): : "l"((void *)(A_shared + (((int)threadIdx.x) + 32))) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)(A + ((((int)threadIdx.x) * 14) + 1))), "n"(4) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.ca.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(A + ((((int)threadIdx.x) * 14) + 1))), "n"(4) ); } @@ -251,8 +270,12 @@ def test_inject_async_copy_shared_dyn(): : "l"((void *)(B_shared + (((int)threadIdx.x) + 32))) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)(B + ((((int)threadIdx.x) * 14) + 1))), "n"(4) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.ca.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(B + ((((int)threadIdx.x) * 14) + 1))), "n"(4) ); } __asm__ __volatile__("cp.async.commit_group;"); @@ -263,14 +286,21 @@ def test_inject_async_copy_shared_dyn(): { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" : "=r"(addr) : "l"((void *)(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)))) ); - int src_bytes = cse_var_1 ? 
4 : 0; + int pred_guard = (int)cse_var_1; __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2, %3;" - :: "r"(addr), "l"((void*)(A + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(src_bytes) + "{ .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + #if TVM_ENABLE_L2_PREFETCH + " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;" + #else + " @p cp.async.ca.shared.global [%1], [%2], %3;" + #endif + " @!p st.shared.u32 [%1], {%4};}" + :: "r"(pred_guard), "r"(addr), "l"((void*)(A + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(0) ); } __asm__ __volatile__("cp.async.commit_group;"); @@ -284,14 +314,21 @@ def test_inject_async_copy_shared_dyn(): { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" : "=r"(addr) : "l"((void *)(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)))) ); - int src_bytes = cse_var_1 ? 4 : 0; + int pred_guard = (int)cse_var_1; __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2, %3;" - :: "r"(addr), "l"((void*)(B + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(src_bytes) + "{ .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + #if TVM_ENABLE_L2_PREFETCH + " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;" + #else + " @p cp.async.ca.shared.global [%1], [%2], %3;" + #endif + " @!p st.shared.u32 [%1], {%4};}" + :: "r"(pred_guard), "r"(addr), "l"((void*)(B + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(0) ); } __asm__ __volatile__("cp.async.commit_group;"); @@ -385,7 +422,6 @@ def simple_compute( mod = tvm.IRModule.from_expr(simple_compute) with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): tvm.build(mod, target="cuda") - assert generated_code == expected_cuda_script if not support_async: @@ -393,7 +429,474 @@ def simple_compute( support_async = True +@tvm.testing.requires_cuda +def test_vectorize_cp_async_in_if_then_else(): + global support_async + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # At least sm80 is required + support_async = False + + @T.prim_func + def complex_compute( + A: T.Buffer((2, 16, 16, 1280), "float16"), + W: T.Buffer((1280, 3, 3, 1280), "float16"), + Conv: T.Buffer((512, 1280), "float16"), + ): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # with T.block("root"): + data_im2col_reindex_shared_dyn = T.alloc_buffer((512, 11520), "float16", scope="shared.dyn") + data_im2col_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer( + (512, 11520), "float16", scope="wmma.matrix_a" + ) + weight_flatten_reindex_shared_dyn = T.alloc_buffer( + (1280, 11520), "float16", scope="shared.dyn" + ) + weight_flatten_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer( + (1280, 11520), "float16", scope="wmma.matrix_b" + ) + Conv_reindex_wmma_accumulator = T.alloc_buffer( + (512, 1280), "float16", scope="wmma.accumulator" + ) + for x_0_0 in T.thread_binding(8, thread="blockIdx.y"): + for y_0_0 in T.thread_binding(20, thread="blockIdx.x"): + for x_0_1 in T.thread_binding(2, thread="threadIdx.y"): + for y_0_1 in T.thread_binding(2, thread="threadIdx.z"): + for x_0_2_init, y_0_2_init in T.grid(2, 2): + with T.block("Conv_init_o"): + v_x_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + x_0_2_init) + v_y_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + y_0_2_init) + T.reads() + T.writes( + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, v_y_o * 16 : v_y_o * 16 + 16 + ] 
+ ) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer( + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, v_y_o * 16 : v_y_o * 16 + 16 + ], + (16, 16), + "float16", + strides=(C_s0, C_s1), + scope="wmma.accumulator", + offset_factor=16, + ) + T.tvm_fill_fragment( + C.data, + 16, + 16, + 16, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + T.float32(0), + ) + for k_0_0 in T.serial( + 180, + annotations={ + "software_pipeline_stage": [0, 0, 1], + "software_pipeline_order": [0, 1, 2], + "software_pipeline_async_stages": [0], + }, + ): + for ax0_ax1_0_fused_0 in range(4): + for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_0_fused_2 in T.thread_binding( + 2, thread="threadIdx.y" + ): + for ax0_ax1_0_fused_3 in T.thread_binding( + 32, thread="threadIdx.x" + ): + with T.block("data_im2col_reindex_shared.dyn_o"): + v0 = T.axis.spatial( + 512, + x_0_0 * 64 + + ( + ax0_ax1_0_fused_0 * 128 + + ax0_ax1_0_fused_1 * 64 + + ax0_ax1_0_fused_2 * 32 + + ax0_ax1_0_fused_3 + ) + // 8, + ) + v1_o = T.axis.spatial( + 1440, + k_0_0 * 8 + + ( + ax0_ax1_0_fused_0 * 128 + + ax0_ax1_0_fused_1 * 64 + + ax0_ax1_0_fused_2 * 32 + + ax0_ax1_0_fused_3 + ) + % 8, + ) + T.reads( + A[ + v0 // 256, + v1_o // 480 + v0 % 256 // 16 - 1, + v1_o % 480 // 160 + v0 % 16 - 1, + v1_o % 160 * 8 : v1_o % 160 * 8 + 8, + ] + ) + T.writes( + data_im2col_reindex_shared_dyn[ + v0, v1_o * 8 : v1_o * 8 + 8 + ] + ) + for ax1_1 in T.vectorized(8): + with T.block("data_im2col_reindex_shared.dyn"): + v1_i = T.axis.spatial(8, ax1_1) + T.reads( + A[ + v0 // 256, + v1_o // 480 + v0 % 256 // 16 - 1, + v1_o % 480 // 160 + v0 % 16 - 1, + v1_o % 160 * 8 + v1_i, + ] + ) + T.writes( + data_im2col_reindex_shared_dyn[ + v0, v1_o * 8 + v1_i + ] + ) + T.block_attr( + {"buffer_dim_align": [[0, 0, 32, 8]]} + ) + data_im2col_reindex_shared_dyn[ + v0, v1_o * 8 + v1_i + ] = T.if_then_else( + 1 <= v1_o // 480 + v0 % 256 // 16 + and v1_o // 480 + v0 % 256 // 16 < 17 + and 1 <= v1_o % 480 // 160 + v0 % 16 + and v1_o % 480 // 160 + v0 % 16 < 17, + A[ + v0 // 256, + v1_o // 480 + v0 % 256 // 16 - 1, + v1_o % 480 // 160 + v0 % 16 - 1, + v1_o % 160 * 8 + v1_i, + ], + T.float16(0), + ) + for ax0_ax1_0_fused_0 in range(4): + for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_0_fused_2 in T.thread_binding( + 2, thread="threadIdx.y" + ): + for ax0_ax1_0_fused_3 in T.thread_binding( + 32, thread="threadIdx.x" + ): + for ax1_1 in T.vectorized(8): + with T.block("weight_flatten_reindex_shared.dyn"): + v0 = T.axis.spatial( + 1280, + y_0_0 * 64 + + ( + ax0_ax1_0_fused_0 * 128 + + ax0_ax1_0_fused_1 * 64 + + ax0_ax1_0_fused_2 * 32 + + ax0_ax1_0_fused_3 + ) + // 8, + ) + v1 = T.axis.spatial( + 11520, + k_0_0 * 64 + + ( + ax0_ax1_0_fused_0 * 128 + + ax0_ax1_0_fused_1 * 64 + + ax0_ax1_0_fused_2 * 32 + + ax0_ax1_0_fused_3 + ) + % 8 + * 8 + + ax1_1, + ) + T.reads( + W[ + v0, + v1 // 3840, + v1 % 3840 // 1280, + v1 % 1280, + ] + ) + T.writes( + weight_flatten_reindex_shared_dyn[v0, v1] + ) + T.block_attr( + {"buffer_dim_align": [[0, 0, 32, 8]]} + ) + weight_flatten_reindex_shared_dyn[v0, v1] = W[ + v0, + v1 // 1280 // 3, + v1 // 1280 % 3, + v1 % 1280, + ] + for k_0_1 in range(4): + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("data_im2col_reindex_shared.dyn_wmma.matrix_a_o"): + v0_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(720, k_0_0 * 4 + k_0_1 + ax1_0) + T.reads( + data_im2col_reindex_shared_dyn[ + v0_o * 16 : v0_o * 16 + 16, + v1_o 
* 16 : v1_o * 16 + 16, + ] + ) + T.writes( + data_im2col_reindex_shared_dyn_wmma_matrix_a[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ] + ) + A_s0 = T.int32() + A_s1 = T.int32() + A_1 = T.match_buffer( + data_im2col_reindex_shared_dyn[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(A_s0, A_s1), + scope="shared.dyn", + offset_factor=16, + ) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer( + data_im2col_reindex_shared_dyn_wmma_matrix_a[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(C_s0, C_s1), + scope="wmma.matrix_a", + offset_factor=16, + ) + T.tvm_load_matrix_sync( + C.data, + 16, + 16, + 16, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + T.tvm_access_ptr( + T.type_annotation("float16"), + A_1.data, + A_1.elem_offset, + A_s0 * 16, + 1, + ), + A_s0, + "row_major", + ) + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block( + "weight_flatten_reindex_shared.dyn_wmma.matrix_b_o" + ): + v0_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(720, k_0_0 * 4 + k_0_1 + ax1_0) + T.reads( + weight_flatten_reindex_shared_dyn[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ] + ) + T.writes( + weight_flatten_reindex_shared_dyn_wmma_matrix_b[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ] + ) + A_s0 = T.int32() + A_s1 = T.int32() + A_1 = T.match_buffer( + weight_flatten_reindex_shared_dyn[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(A_s0, A_s1), + scope="shared.dyn", + offset_factor=16, + ) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer( + weight_flatten_reindex_shared_dyn_wmma_matrix_b[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(C_s0, C_s1), + scope="wmma.matrix_b", + offset_factor=16, + ) + T.tvm_load_matrix_sync( + C.data, + 16, + 16, + 16, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + T.tvm_access_ptr( + T.type_annotation("float16"), + A_1.data, + A_1.elem_offset, + A_s0 * 16, + 1, + ), + A_s0, + "col_major", + ) + for x_0_2, y_0_2 in T.grid(2, 2): + with T.block("Conv_update_o"): + v_x_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + x_0_2) + v_y_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + y_0_2) + v_k_o = T.axis.reduce(720, k_0_0 * 4 + k_0_1) + T.reads( + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, + v_y_o * 16 : v_y_o * 16 + 16, + ], + data_im2col_reindex_shared_dyn_wmma_matrix_a[ + v_x_o * 16 : v_x_o * 16 + 16, + v_k_o * 16 : v_k_o * 16 + 16, + ], + weight_flatten_reindex_shared_dyn_wmma_matrix_b[ + v_y_o * 16 : v_y_o * 16 + 16, + v_k_o * 16 : v_k_o * 16 + 16, + ], + ) + T.writes( + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, + v_y_o * 16 : v_y_o * 16 + 16, + ] + ) + A_s0 = T.int32() + A_s1 = T.int32() + A_1 = T.match_buffer( + data_im2col_reindex_shared_dyn_wmma_matrix_a[ + v_x_o * 16 : v_x_o * 16 + 16, + v_k_o * 16 : v_k_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(A_s0, A_s1), + scope="wmma.matrix_a", + offset_factor=16, + ) + B_s0 = T.int32() + B_s1 = T.int32() + B = T.match_buffer( + weight_flatten_reindex_shared_dyn_wmma_matrix_b[ + v_y_o * 16 : v_y_o * 16 + 16, + v_k_o * 16 : v_k_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(B_s0, B_s1), + scope="wmma.matrix_b", + offset_factor=16, + ) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer( + 
Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, + v_y_o * 16 : v_y_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(C_s0, C_s1), + scope="wmma.accumulator", + offset_factor=16, + ) + T.tvm_mma_sync( + C.data, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + A_1.data, + A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + + A_1.elem_offset % A_s0 // 16, + B.data, + B.elem_offset // B_s0 // 16 * (B_s0 // 16) + + B.elem_offset % B_s0 // 16, + C.data, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + ) + for ax0_0, ax1_0 in T.grid(2, 2): + with T.block("Conv_reindex_wmma.accumulator_o"): + v0_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + ax1_0) + T.reads( + Conv_reindex_wmma_accumulator[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ] + ) + T.writes( + Conv[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16] + ) + A_s0 = T.int32() + A_s1 = T.int32() + A_1 = T.match_buffer( + Conv_reindex_wmma_accumulator[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ], + (16, 16), + "float16", + strides=(A_s0, A_s1), + scope="wmma.accumulator", + offset_factor=16, + ) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer( + Conv[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], + (16, 16), + "float16", + strides=(C_s0, C_s1), + offset_factor=16, + ) + T.tvm_store_matrix_sync( + A_1.data, + 16, + 16, + 16, + A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + + A_1.elem_offset % A_s0 // 16, + T.tvm_access_ptr( + T.type_annotation("float16"), + C.data, + C.elem_offset, + C_s0 * 16, + 2, + ), + C_s0, + "row_major", + ) + + mod = tvm.IRModule.from_expr(complex_compute) + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + tvm.build(mod, target="cuda") + # generated_code must contain " setp.ne.b32 p, %0, 0;" + assert "setp.ne.b32" in generated_code + + if not support_async: + # avoid return dummy code to other tests + support_async = True + + if __name__ == "__main__": test_inject_async_copy() test_inject_async_copy_shared_dyn() test_cp_async_in_if_then_else() + test_vectorize_cp_async_in_if_then_else()
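
A minimal usage sketch (not part of the patch): assuming an IRModule `mod` that exercises the vectorized, predicated cp.async path (for example, built from the `complex_compute` function in the test above) and an sm80+ GPU with nvcc >= 11.4, the generated CUDA source can be inspected for the new predicated form as follows. The variable names here are illustrative only.

    import tvm

    # Build with asynchronous copy lowering enabled, as the tests above do.
    with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
        rt_mod = tvm.build(mod, target="cuda")

    # The CUDA source is attached as an imported module of the host module.
    cuda_src = rt_mod.imported_modules[0].get_source()

    # Predicated copies are guarded by "setp.ne.b32 p, %0, 0;" and, when
    # TVM_ENABLE_L2_PREFETCH is 1, use the cp.async ... .L2::128B variant.
    assert "setp.ne.b32" in cuda_src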