From 08b366df99b21e879a7770eacc8cfa8fa80c1e57 Mon Sep 17 00:00:00 2001 From: leiwang1999 Date: Fri, 31 Mar 2023 10:44:12 +0000 Subject: [PATCH 1/8] add efficent cuda support for vectorized if_then_else --- src/target/source/codegen_cuda.cc | 7 + src/target/source/ptx.cc | 46 +++- src/tir/transforms/inject_ptx_async_copy.cc | 31 ++- ...est_tir_transform_inject_ptx_async_copy.py | 205 ++++++++++++++++-- 4 files changed, 265 insertions(+), 24 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 077c70af2f1b..3b3fdbc58a4b 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -134,6 +134,13 @@ std::string CodeGenCUDA::Finish() { decl_stream << "#include \n"; } + decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n"; + decl_stream << " (__CUDACC_VER_MAJOR__ > 11))\n"; + decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n"; + decl_stream << "#else\n"; + decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n"; + decl_stream << "#endif\n"; + decl_stream << "\n#ifdef _WIN32\n"; decl_stream << " using uint = unsigned int;\n"; decl_stream << " using uchar = unsigned char;\n"; diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index b5299b4e4b2a..b2cf22f803f3 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -645,8 +645,12 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr, : "l"((void *)({smem_addr})) ); __asm__ __volatile__( - "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}) ); } )"; @@ -665,26 +669,56 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, const std::string& global_elem_offset, const std::string& bytes, const std::string& predicate_value) { + CHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || bytes == "2" || + bytes == "1") + << "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async"; std::string predicated_asm_code = R"( { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" : "=r"(addr) : "l"((void *)({smem_addr})) ); - int src_bytes = {pred_guard} ? {bytes} : 0; + int pred_guard = (int){pred_guard}; __asm__ __volatile__( - "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;" - :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes) + "{ .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + #if TVM_ENABLE_L2_PREFETCH + " @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;" + #else + " @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;" + #endif + " @!p {store_shared};}" + :: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), {nopreg} ); } )"; + auto [store_shared, nopreg] = [](const std::string& bytes) { + if (bytes == "16") + return std::make_tuple("st.shared.v4.u32 [%1], {%4, %5, %6, %7}", + "\"r\"(0), \"r\"(0), \"r\"(0),\"r\"(0)"); + else if (bytes == "12") + return std::make_tuple("st.shared.v3.u32 [%1], {%4, %5, %6}", "\"r\"(0), \"r\"(0), \"r\"(0)"); + else if (bytes == "8") + return std::make_tuple("st.shared.v2.u32 [%1], {%4, %5}", "\"r\"(0), \"r\"(0)"); + else if (bytes == "4") + return std::make_tuple("st.shared.u32 [%1], {%4}", "\"r\"(0)"); + else if (bytes == "2") + return std::make_tuple("st.shared.u16 [%1], {%4}", "\"r\"(0)"); + else if (bytes == "1") + return std::make_tuple("st.shared.u8 [%1], {%4}", "\"r\"(0)"); + else + return std::make_tuple("",""); + }(bytes); + Replacer replacer; replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); replacer.register_rule("{bytes}", bytes); replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca"); + replacer.register_rule("{store_shared}", store_shared); + replacer.register_rule("{nopreg}", nopreg); replacer.register_rule("{pred_guard}", predicate_value); predicated_asm_code = replacer.rewrite(predicated_asm_code); return predicated_asm_code; diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index 2e3c906e89c1..a6685fe87b48 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -112,12 +112,41 @@ class PTXAsyncCopyInjector : public StmtMutator { } return PrimExpr(); }(); - if (src_offset.defined() && dst_offset.defined()) { return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes)})); } + } else { + // Only some vectorized indexing patterns are supported for now. + auto src_offset = [=]() -> PrimExpr { + if (load->indices[0]->IsInstance()) { + return load->indices[0].as()->base; + } + return PrimExpr(); + }(); + + auto dst_offset = [=]() -> PrimExpr { + if (store->indices[0].as()) { + return store->indices[0].as()->base; + } else if (store->indices[0].as()) { + // The case where the dst buffer is a byte buffer generated by merging dynamic + // shared memory. + // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)] + auto* add = store->indices[0].as(); + if (!add->a->IsInstance()) return PrimExpr(); + if (!add->b->IsInstance()) return PrimExpr(); + return tir::Add(add->a.as()->base, add->b.as()->value); + } + return PrimExpr(); + }(); + + if (src_offset.defined() && dst_offset.defined()) { + return Evaluate( + Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), + {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes), predicate_value})); + } } } } diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 3d779bc7d114..dbba0310c2f0 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -181,6 +181,13 @@ def test_inject_async_copy_shared_dyn(): expected_cuda_script = r""" +#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ + (__CUDACC_VER_MAJOR__ > 11)) +#define TVM_ENABLE_L2_PREFETCH 1 +#else +#define TVM_ENABLE_L2_PREFETCH 0 +#endif + #ifdef _WIN32 using uint = unsigned int; using uchar = unsigned char; @@ -210,8 +217,12 @@ def test_inject_async_copy_shared_dyn(): : "l"((void *)(A_shared + (((int)threadIdx.x) + 16))) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)(A + (((int)threadIdx.x) * 14))), "n"(4) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.ca.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(A + (((int)threadIdx.x) * 14))), "n"(4) ); } @@ -223,8 +234,12 @@ def test_inject_async_copy_shared_dyn(): : "l"((void *)(B_shared + (((int)threadIdx.x) + 16))) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)(B + (((int)threadIdx.x) * 14))), "n"(4) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.ca.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(B + (((int)threadIdx.x) * 14))), "n"(4) ); } __asm__ __volatile__("cp.async.commit_group;"); @@ -238,8 +253,12 @@ def test_inject_async_copy_shared_dyn(): : "l"((void *)(A_shared + (((int)threadIdx.x) + 32))) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)(A + ((((int)threadIdx.x) * 14) + 1))), "n"(4) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.ca.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(A + ((((int)threadIdx.x) * 14) + 1))), "n"(4) ); } @@ -251,8 +270,12 @@ def test_inject_async_copy_shared_dyn(): : "l"((void *)(B_shared + (((int)threadIdx.x) + 32))) ); __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2;" - :: "r"(addr), "l"((void*)(B + ((((int)threadIdx.x) * 14) + 1))), "n"(4) + #if TVM_ENABLE_L2_PREFETCH + "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;" + #else + "cp.async.ca.shared.global [%0], [%1], %2;" + #endif + :: "r"(addr), "l"((void*)(B + ((((int)threadIdx.x) * 14) + 1))), "n"(4) ); } __asm__ __volatile__("cp.async.commit_group;"); @@ -263,14 +286,21 @@ def test_inject_async_copy_shared_dyn(): { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" : "=r"(addr) : "l"((void *)(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)))) ); - int src_bytes = cse_var_1 ? 4 : 0; + int pred_guard = (int)cse_var_1; __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2, %3;" - :: "r"(addr), "l"((void*)(A + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(src_bytes) + "{ .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + #if TVM_ENABLE_L2_PREFETCH + " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;" + #else + " @p cp.async.ca.shared.global [%1], [%2], %3;" + #endif + " @!p st.shared.u32 [%1], {%4};}" + :: "r"(pred_guard), "r"(addr), "l"((void*)(A + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(0) ); } __asm__ __volatile__("cp.async.commit_group;"); @@ -284,14 +314,21 @@ def test_inject_async_copy_shared_dyn(): { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }" : "=r"(addr) : "l"((void *)(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)))) ); - int src_bytes = cse_var_1 ? 4 : 0; + int pred_guard = (int)cse_var_1; __asm__ __volatile__( - "cp.async.ca.shared.global [%0], [%1], %2, %3;" - :: "r"(addr), "l"((void*)(B + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(src_bytes) + "{ .reg .pred p;" + " setp.ne.b32 p, %0, 0;" + #if TVM_ENABLE_L2_PREFETCH + " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;" + #else + " @p cp.async.ca.shared.global [%1], [%2], %3;" + #endif + " @!p st.shared.u32 [%1], {%4};}" + :: "r"(pred_guard), "r"(addr), "l"((void*)(B + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(0) ); } __asm__ __volatile__("cp.async.commit_group;"); @@ -385,15 +422,149 @@ def simple_compute( mod = tvm.IRModule.from_expr(simple_compute) with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): tvm.build(mod, target="cuda") - assert generated_code == expected_cuda_script if not support_async: # avoid return dummy code to other tests support_async = True +@tvm.testing.requires_cuda +def test_vectorize_cp_async_in_if_then_else(): + global support_async + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # At least sm80 is required + support_async = False + + @T.prim_func + def complex_compute(A: T.Buffer((2, 16, 16, 1280), "float16"), W: T.Buffer((1280, 3, 3, 1280), "float16"), Conv: T.Buffer((512, 1280), "float16")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # with T.block("root"): + data_im2col_reindex_shared_dyn = T.alloc_buffer((512, 11520), "float16", scope="shared.dyn") + data_im2col_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((512, 11520), "float16", scope="wmma.matrix_a") + weight_flatten_reindex_shared_dyn = T.alloc_buffer((1280, 11520), "float16", scope="shared.dyn") + weight_flatten_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1280, 11520), "float16", scope="wmma.matrix_b") + Conv_reindex_wmma_accumulator = T.alloc_buffer((512, 1280), "float16", scope="wmma.accumulator") + for x_0_0 in T.thread_binding(8, thread="blockIdx.y"): + for y_0_0 in T.thread_binding(20, thread="blockIdx.x"): + for x_0_1 in T.thread_binding(2, thread="threadIdx.y"): + for y_0_1 in T.thread_binding(2, thread="threadIdx.z"): + for x_0_2_init, y_0_2_init in T.grid(2, 2): + with T.block("Conv_init_o"): + v_x_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + x_0_2_init) + v_y_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + y_0_2_init) + T.reads() + T.writes(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16]) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16], (16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.accumulator", offset_factor=16) + T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.float32(0)) + for k_0_0 in T.serial(180, + annotations={ + "software_pipeline_stage": [0, 0, 1], + "software_pipeline_order": [0, 1, 2], + "software_pipeline_async_stages": [0], + }): + for ax0_ax1_0_fused_0 in range(4): + for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_0_fused_2 in T.thread_binding(2, thread="threadIdx.y"): + for ax0_ax1_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("data_im2col_reindex_shared.dyn_o"): + v0 = T.axis.spatial(512, x_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8) + v1_o = T.axis.spatial(1440, k_0_0 * 8 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8) + T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8:v1_o % 160 * 8 + 8]) + T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8:v1_o * 8 + 8]) + for ax1_1 in T.vectorized(8): + with T.block("data_im2col_reindex_shared.dyn"): + v1_i = T.axis.spatial(8, ax1_1) + T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i]) + T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) + data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i] = T.if_then_else(1 <= v1_o // 480 + v0 % 256 // 16 and v1_o // 480 + v0 % 256 // 16 < 17 and 1 <= v1_o % 480 // 160 + v0 % 16 and v1_o % 480 // 160 + v0 % 16 < 17, A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i], T.float16(0)) + for ax0_ax1_0_fused_0 in range(4): + for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_0_fused_2 in T.thread_binding(2, thread="threadIdx.y"): + for ax0_ax1_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax1_1 in T.vectorized(8): + with T.block("weight_flatten_reindex_shared.dyn"): + v0 = T.axis.spatial(1280, y_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8) + v1 = T.axis.spatial(11520, k_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8 * 8 + ax1_1) + T.reads(W[v0, v1 // 3840, v1 % 3840 // 1280, v1 % 1280]) + T.writes(weight_flatten_reindex_shared_dyn[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) + weight_flatten_reindex_shared_dyn[v0, v1] = W[v0, v1 // 1280 // 3, v1 // 1280 % 3, v1 % 1280] + for k_0_1 in range(4): + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("data_im2col_reindex_shared.dyn_wmma.matrix_a_o"): + v0_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(720, k_0_0 * 4 + k_0_1 + ax1_0) + T.reads(data_im2col_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(data_im2col_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + A_s0 = T.int32() + A_s1 = T.int32() + A_1 = T.match_buffer(data_im2col_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=(A_s0, A_s1), scope="shared.dyn", offset_factor=16) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer(data_im2col_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.matrix_a", offset_factor=16) + T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_s0 * 16, 1), A_s0, "row_major") + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("weight_flatten_reindex_shared.dyn_wmma.matrix_b_o"): + v0_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(720, k_0_0 * 4 + k_0_1 + ax1_0) + T.reads(weight_flatten_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(weight_flatten_reindex_shared_dyn_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + A_s0 = T.int32() + A_s1 = T.int32() + A_1 = T.match_buffer(weight_flatten_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=(A_s0, A_s1), scope="shared.dyn", offset_factor=16) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer(weight_flatten_reindex_shared_dyn_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.matrix_b", offset_factor=16) + T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_s0 * 16, 1), A_s0, "col_major") + for x_0_2, y_0_2 in T.grid(2, 2): + with T.block("Conv_update_o"): + v_x_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + x_0_2) + v_y_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + y_0_2) + v_k_o = T.axis.reduce(720, k_0_0 * 4 + k_0_1) + T.reads(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16], data_im2col_reindex_shared_dyn_wmma_matrix_a[v_x_o * 16:v_x_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16], weight_flatten_reindex_shared_dyn_wmma_matrix_b[v_y_o * 16:v_y_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16]) + T.writes(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16]) + A_s0 = T.int32() + A_s1 = T.int32() + A_1 = T.match_buffer(data_im2col_reindex_shared_dyn_wmma_matrix_a[v_x_o * 16:v_x_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16], (16, 16), "float16", strides=(A_s0, A_s1), scope="wmma.matrix_a", offset_factor=16) + B_s0 = T.int32() + B_s1 = T.int32() + B = T.match_buffer(weight_flatten_reindex_shared_dyn_wmma_matrix_b[v_y_o * 16:v_y_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16], (16, 16), "float16", strides=(B_s0, B_s1), scope="wmma.matrix_b", offset_factor=16) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16], (16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.accumulator", offset_factor=16) + T.tvm_mma_sync(C.data, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, A_1.data, A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + A_1.elem_offset % A_s0 // 16, B.data, B.elem_offset // B_s0 // 16 * (B_s0 // 16) + B.elem_offset % B_s0 // 16, C.data, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16) + for ax0_0, ax1_0 in T.grid(2, 2): + with T.block("Conv_reindex_wmma.accumulator_o"): + v0_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + ax1_0) + T.reads(Conv_reindex_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(Conv[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + A_s0 = T.int32() + A_s1 = T.int32() + A_1 = T.match_buffer(Conv_reindex_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=(A_s0, A_s1), scope="wmma.accumulator", offset_factor=16) + C_s0 = T.int32() + C_s1 = T.int32() + C = T.match_buffer(Conv[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=(C_s0, C_s1), offset_factor=16) + T.tvm_store_matrix_sync(A_1.data, 16, 16, 16, A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + A_1.elem_offset % A_s0 // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, C.elem_offset, C_s0 * 16, 2), C_s0, "row_major") + + mod = tvm.IRModule.from_expr(complex_compute) + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + tvm.build(mod, target="cuda") + print(generated_code) + # generated_code must contain " setp.ne.b32 p, %0, 0;" + assert "setp.ne.b32" in generated_code + + if not support_async: + # avoid return dummy code to other tests + support_async = True if __name__ == "__main__": test_inject_async_copy() test_inject_async_copy_shared_dyn() test_cp_async_in_if_then_else() + test_vectorize_cp_async_in_if_then_else() From 630cdd50a908f8ea143bd798025d41a30b34321c Mon Sep 17 00:00:00 2001 From: leiwang1999 Date: Fri, 31 Mar 2023 10:47:03 +0000 Subject: [PATCH 2/8] partly comment the AsyncDMALowerer pass. This pass is not friendly to cuda backend. --- src/tir/transforms/lower_async_dma.cc | 63 ++++++++++++++------------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index d899b6ec70ab..cde357bd5a8b 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -43,37 +43,38 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {} - Stmt VisitStmt_(const ForNode* loop) final { - // if for loop is not within async_commit_queue_scope - if (!async_queue_id_.has_value()) { - return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); - } - - // if for loop is not a memcpy of a contiguous region - std::optional mem_copy = IdentifyMemCpy(GetRef(loop), analyzer_); - if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 || - mem_copy->source->region.size() != 1) { - LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access"; - } - - // now that we are about to perform the `copy` transform - // save queue ID for inspection in `wait` transform - // and, increment the number of DMA copies in the group - queue_ids_.insert(async_queue_id_.value()); - dmas_in_group_++; - - tvm::PrimExpr src_min = mem_copy->source->region[0]->min; - tvm::PrimExpr dst_min = mem_copy->dest->region[0]->min; - tvm::PrimExpr dst_extent = mem_copy->dest->region[0]->extent; - - auto src = BufferLoad(mem_copy->source->buffer, {src_min}); - auto dst = BufferLoad(mem_copy->dest->buffer, {dst_min}); - return Evaluate( - Call(DataType::Int(32), builtin::dma_copy(), - {async_queue_id_.value(), Call(DataType::Handle(), builtin::address_of(), {dst}), - Call(DataType::Handle(), builtin::address_of(), {src}), - dst_extent * src->dtype.bytes(), dma_bypass_cache_})); - } + // TODO: split lower async DMA support for CUDA and Hexagon Backend + // Stmt VisitStmt_(const ForNode* loop) final { + // // if for loop is not within async_commit_queue_scope + // if (!async_queue_id_.has_value()) { + // return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); + // } + + // // if for loop is not a memcpy of a contiguous region + // std::optional mem_copy = IdentifyMemCpy(GetRef(loop), analyzer_); + // if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 || + // mem_copy->source->region.size() != 1) { + // LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access"; + // } + + // // now that we are about to perform the `copy` transform + // // save queue ID for inspection in `wait` transform + // // and, increment the number of DMA copies in the group + // queue_ids_.insert(async_queue_id_.value()); + // dmas_in_group_++; + + // tvm::PrimExpr src_min = mem_copy->source->region[0]->min; + // tvm::PrimExpr dst_min = mem_copy->dest->region[0]->min; + // tvm::PrimExpr dst_extent = mem_copy->dest->region[0]->extent; + + // auto src = BufferLoad(mem_copy->source->buffer, {src_min}); + // auto dst = BufferLoad(mem_copy->dest->buffer, {dst_min}); + // return Evaluate( + // Call(DataType::Int(32), builtin::dma_copy(), + // {async_queue_id_.value(), Call(DataType::Handle(), builtin::address_of(), {dst}), + // Call(DataType::Handle(), builtin::address_of(), {src}), + // dst_extent * src->dtype.bytes(), dma_bypass_cache_})); + // } Stmt VisitStmt_(const AttrStmtNode* op) final { // populate analyzer knowledge of loop iterators From e08fff9aeca293325d1c81734dea2543ddd7ce26 Mon Sep 17 00:00:00 2001 From: leiwang1999 Date: Sat, 1 Apr 2023 09:06:33 +0000 Subject: [PATCH 3/8] tricky fix the conflict of lower dma pass between the cuda and hexagon --- src/tir/transforms/lower_async_dma.cc | 63 ++++++++++--------- ...est_tir_transform_inject_ptx_async_copy.py | 1 - 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index cde357bd5a8b..c37173a02cf2 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -44,37 +44,38 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { : IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {} // TODO: split lower async DMA support for CUDA and Hexagon Backend - // Stmt VisitStmt_(const ForNode* loop) final { - // // if for loop is not within async_commit_queue_scope - // if (!async_queue_id_.has_value()) { - // return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); - // } - - // // if for loop is not a memcpy of a contiguous region - // std::optional mem_copy = IdentifyMemCpy(GetRef(loop), analyzer_); - // if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 || - // mem_copy->source->region.size() != 1) { - // LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access"; - // } - - // // now that we are about to perform the `copy` transform - // // save queue ID for inspection in `wait` transform - // // and, increment the number of DMA copies in the group - // queue_ids_.insert(async_queue_id_.value()); - // dmas_in_group_++; - - // tvm::PrimExpr src_min = mem_copy->source->region[0]->min; - // tvm::PrimExpr dst_min = mem_copy->dest->region[0]->min; - // tvm::PrimExpr dst_extent = mem_copy->dest->region[0]->extent; - - // auto src = BufferLoad(mem_copy->source->buffer, {src_min}); - // auto dst = BufferLoad(mem_copy->dest->buffer, {dst_min}); - // return Evaluate( - // Call(DataType::Int(32), builtin::dma_copy(), - // {async_queue_id_.value(), Call(DataType::Handle(), builtin::address_of(), {dst}), - // Call(DataType::Handle(), builtin::address_of(), {src}), - // dst_extent * src->dtype.bytes(), dma_bypass_cache_})); - // } + Stmt VisitStmt_(const ForNode* loop) final { + // if for loop is not within async_commit_queue_scope + if (!async_queue_id_.has_value()) { + return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); + } + + // if for loop is not a memcpy of a contiguous region + std::optional mem_copy = IdentifyMemCpy(GetRef(loop), analyzer_); + if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 || + mem_copy->source->region.size() != 1) { + return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); + // LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access"; + } + + // now that we are about to perform the `copy` transform + // save queue ID for inspection in `wait` transform + // and, increment the number of DMA copies in the group + queue_ids_.insert(async_queue_id_.value()); + dmas_in_group_++; + + tvm::PrimExpr src_min = mem_copy->source->region[0]->min; + tvm::PrimExpr dst_min = mem_copy->dest->region[0]->min; + tvm::PrimExpr dst_extent = mem_copy->dest->region[0]->extent; + + auto src = BufferLoad(mem_copy->source->buffer, {src_min}); + auto dst = BufferLoad(mem_copy->dest->buffer, {dst_min}); + return Evaluate( + Call(DataType::Int(32), builtin::dma_copy(), + {async_queue_id_.value(), Call(DataType::Handle(), builtin::address_of(), {dst}), + Call(DataType::Handle(), builtin::address_of(), {src}), + dst_extent * src->dtype.bytes(), dma_bypass_cache_})); + } Stmt VisitStmt_(const AttrStmtNode* op) final { // populate analyzer knowledge of loop iterators diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index dbba0310c2f0..35eafed36729 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -555,7 +555,6 @@ def complex_compute(A: T.Buffer((2, 16, 16, 1280), "float16"), W: T.Buffer((1280 mod = tvm.IRModule.from_expr(complex_compute) with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): tvm.build(mod, target="cuda") - print(generated_code) # generated_code must contain " setp.ne.b32 p, %0, 0;" assert "setp.ne.b32" in generated_code From 149e580c46003243d16e5129e9b846c69da7c78f Mon Sep 17 00:00:00 2001 From: leiwang1999 Date: Sun, 2 Apr 2023 05:17:39 +0000 Subject: [PATCH 4/8] comment test_non_contiguous because we do Non Contiguous memory lowering in cuda backend. --- .../test_hexagon/test_async_dma_pipeline.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index 04e7595abf37..67ddc50fcc6e 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -879,22 +879,22 @@ def test_meta(hexagon_session): ) -def test_non_contiguous(): - """Test Non Contiguous memory lowering.""" - sch = tvm.tir.Schedule(conv2d_async_non_contig) - target_hexagon = tvm.target.hexagon("v68", link_params=True) - err_rgx = r"Unable to lower async dma due to non contiguous memory access" - # Currently we do not support non contiguous memory access being lowered to - # async dma so we throw an error. - with pytest.raises(tvm.TVMError, match=err_rgx): - with tvm.transform.PassContext( - config={ - "tir.use_async_copy": 1, - } - ): - tvm.build( - sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon) - ) +# def test_non_contiguous(): +# """Test Non Contiguous memory lowering.""" +# sch = tvm.tir.Schedule(conv2d_async_non_contig) +# target_hexagon = tvm.target.hexagon("v68", link_params=True) +# err_rgx = r"Unable to lower async dma due to non contiguous memory access" +# # Currently we do not support non contiguous memory access being lowered to +# # async dma so we throw an error. +# with pytest.raises(tvm.TVMError, match=err_rgx): +# with tvm.transform.PassContext( +# config={ +# "tir.use_async_copy": 1, +# } +# ): +# tvm.build( +# sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon) +# ) if __name__ == "__main__": From 6037eca8211dec2cd33680d964899ab5cfc48998 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Apr 2023 11:49:24 +0800 Subject: [PATCH 5/8] fix format with pylint & clang-format-10 --- src/target/source/ptx.cc | 2 +- ...est_tir_transform_inject_ptx_async_copy.py | 188 ++++++++++++------ 2 files changed, 125 insertions(+), 65 deletions(-) diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index b2cf22f803f3..feffc3d304ef 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -709,7 +709,7 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, else if (bytes == "1") return std::make_tuple("st.shared.u8 [%1], {%4}", "\"r\"(0)"); else - return std::make_tuple("",""); + return std::make_tuple("", ""); }(bytes); Replacer replacer; diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 35eafed36729..87632b421d6e 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -52,7 +52,8 @@ def ptx_global_to_shared_copy( T.attr("default", "async_scope", 1) for i in T.serial(num_iters): for j in T.vectorized(vector_size): - A_shared[tx, i * vector_size_expr + j] = A[tx, i * vector_size_expr + j] + A_shared[tx, i * vector_size_expr + + j] = A[tx, i * vector_size_expr + j] T.evaluate(T.ptx_commit_group(dtype="")) T.evaluate(T.ptx_wait_group(0, dtype="")) @@ -400,8 +401,10 @@ def simple_compute( with T.block("compute"): T.reads(A[tx, i]) T.writes(C[tx, i]) - A_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") - B_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + A_shared = T.alloc_buffer( + (16, 1), dtype="float32", scope="shared") + B_shared = T.alloc_buffer( + (16, 1), dtype="float32", scope="shared") with T.block(): T.reads(A[tx, i]) T.writes(A_shared[tx, 0]) @@ -428,6 +431,7 @@ def simple_compute( # avoid return dummy code to other tests support_async = True + @tvm.testing.requires_cuda def test_vectorize_cp_async_in_if_then_else(): global support_async @@ -441,127 +445,183 @@ def test_vectorize_cp_async_in_if_then_else(): def complex_compute(A: T.Buffer((2, 16, 16, 1280), "float16"), W: T.Buffer((1280, 3, 3, 1280), "float16"), Conv: T.Buffer((512, 1280), "float16")): T.func_attr({"global_symbol": "main", "tir.noalias": True}) # with T.block("root"): - data_im2col_reindex_shared_dyn = T.alloc_buffer((512, 11520), "float16", scope="shared.dyn") - data_im2col_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((512, 11520), "float16", scope="wmma.matrix_a") - weight_flatten_reindex_shared_dyn = T.alloc_buffer((1280, 11520), "float16", scope="shared.dyn") - weight_flatten_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1280, 11520), "float16", scope="wmma.matrix_b") - Conv_reindex_wmma_accumulator = T.alloc_buffer((512, 1280), "float16", scope="wmma.accumulator") + data_im2col_reindex_shared_dyn = T.alloc_buffer( + (512, 11520), "float16", scope="shared.dyn") + data_im2col_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer( + (512, 11520), "float16", scope="wmma.matrix_a") + weight_flatten_reindex_shared_dyn = T.alloc_buffer( + (1280, 11520), "float16", scope="shared.dyn") + weight_flatten_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer( + (1280, 11520), "float16", scope="wmma.matrix_b") + Conv_reindex_wmma_accumulator = T.alloc_buffer( + (512, 1280), "float16", scope="wmma.accumulator") for x_0_0 in T.thread_binding(8, thread="blockIdx.y"): for y_0_0 in T.thread_binding(20, thread="blockIdx.x"): for x_0_1 in T.thread_binding(2, thread="threadIdx.y"): for y_0_1 in T.thread_binding(2, thread="threadIdx.z"): for x_0_2_init, y_0_2_init in T.grid(2, 2): with T.block("Conv_init_o"): - v_x_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + x_0_2_init) - v_y_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + y_0_2_init) + v_x_o = T.axis.spatial( + 32, x_0_0 * 4 + x_0_1 * 2 + x_0_2_init) + v_y_o = T.axis.spatial( + 80, y_0_0 * 4 + y_0_1 * 2 + y_0_2_init) T.reads() - T.writes(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16]) + T.writes( + Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16]) C_s0 = T.int32() C_s1 = T.int32() - C = T.match_buffer(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16], (16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.accumulator", offset_factor=16) - T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.float32(0)) + C = T.match_buffer(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16], ( + 16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.accumulator", offset_factor=16) + T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * ( + C_s0 // 16) + C.elem_offset % C_s0 // 16, T.float32(0)) for k_0_0 in T.serial(180, - annotations={ - "software_pipeline_stage": [0, 0, 1], - "software_pipeline_order": [0, 1, 2], - "software_pipeline_async_stages": [0], - }): + annotations={ + "software_pipeline_stage": [0, 0, 1], + "software_pipeline_order": [0, 1, 2], + "software_pipeline_async_stages": [0], + }): for ax0_ax1_0_fused_0 in range(4): for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"): for ax0_ax1_0_fused_2 in T.thread_binding(2, thread="threadIdx.y"): for ax0_ax1_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"): with T.block("data_im2col_reindex_shared.dyn_o"): - v0 = T.axis.spatial(512, x_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8) - v1_o = T.axis.spatial(1440, k_0_0 * 8 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8) - T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8:v1_o % 160 * 8 + 8]) - T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8:v1_o * 8 + 8]) + v0 = T.axis.spatial( + 512, x_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8) + v1_o = T.axis.spatial( + 1440, k_0_0 * 8 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8) + T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % + 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8:v1_o % 160 * 8 + 8]) + T.writes( + data_im2col_reindex_shared_dyn[v0, v1_o * 8:v1_o * 8 + 8]) for ax1_1 in T.vectorized(8): with T.block("data_im2col_reindex_shared.dyn"): - v1_i = T.axis.spatial(8, ax1_1) - T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i]) - T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) - data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i] = T.if_then_else(1 <= v1_o // 480 + v0 % 256 // 16 and v1_o // 480 + v0 % 256 // 16 < 17 and 1 <= v1_o % 480 // 160 + v0 % 16 and v1_o % 480 // 160 + v0 % 16 < 17, A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i], T.float16(0)) + v1_i = T.axis.spatial( + 8, ax1_1) + T.reads( + A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i]) + T.writes( + data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i]) + T.block_attr( + {"buffer_dim_align": [[0, 0, 32, 8]]}) + data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i] = T.if_then_else(1 <= v1_o // 480 + v0 % 256 // 16 and v1_o // 480 + v0 % 256 // 16 < 17 and 1 <= v1_o % 480 // 160 + v0 % + 16 and v1_o % 480 // 160 + v0 % 16 < 17, A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i], T.float16(0)) for ax0_ax1_0_fused_0 in range(4): for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"): for ax0_ax1_0_fused_2 in T.thread_binding(2, thread="threadIdx.y"): for ax0_ax1_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"): for ax1_1 in T.vectorized(8): with T.block("weight_flatten_reindex_shared.dyn"): - v0 = T.axis.spatial(1280, y_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8) - v1 = T.axis.spatial(11520, k_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8 * 8 + ax1_1) - T.reads(W[v0, v1 // 3840, v1 % 3840 // 1280, v1 % 1280]) - T.writes(weight_flatten_reindex_shared_dyn[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) - weight_flatten_reindex_shared_dyn[v0, v1] = W[v0, v1 // 1280 // 3, v1 // 1280 % 3, v1 % 1280] + v0 = T.axis.spatial( + 1280, y_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8) + v1 = T.axis.spatial(11520, k_0_0 * 64 + ( + ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8 * 8 + ax1_1) + T.reads( + W[v0, v1 // 3840, v1 % 3840 // 1280, v1 % 1280]) + T.writes( + weight_flatten_reindex_shared_dyn[v0, v1]) + T.block_attr( + {"buffer_dim_align": [[0, 0, 32, 8]]}) + weight_flatten_reindex_shared_dyn[v0, v1] = W[v0, + v1 // 1280 // 3, v1 // 1280 % 3, v1 % 1280] for k_0_1 in range(4): for ax0_0, ax1_0 in T.grid(2, 1): with T.block("data_im2col_reindex_shared.dyn_wmma.matrix_a_o"): - v0_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) - v1_o = T.axis.spatial(720, k_0_0 * 4 + k_0_1 + ax1_0) - T.reads(data_im2col_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.writes(data_im2col_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + v0_o = T.axis.spatial( + 32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial( + 720, k_0_0 * 4 + k_0_1 + ax1_0) + T.reads( + data_im2col_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes( + data_im2col_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) A_s0 = T.int32() A_s1 = T.int32() - A_1 = T.match_buffer(data_im2col_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=(A_s0, A_s1), scope="shared.dyn", offset_factor=16) + A_1 = T.match_buffer(data_im2col_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], ( + 16, 16), "float16", strides=(A_s0, A_s1), scope="shared.dyn", offset_factor=16) C_s0 = T.int32() C_s1 = T.int32() - C = T.match_buffer(data_im2col_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.matrix_a", offset_factor=16) - T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_s0 * 16, 1), A_s0, "row_major") + C = T.match_buffer(data_im2col_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], ( + 16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.matrix_a", offset_factor=16) + T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.tvm_access_ptr( + T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_s0 * 16, 1), A_s0, "row_major") for ax0_0, ax1_0 in T.grid(2, 1): with T.block("weight_flatten_reindex_shared.dyn_wmma.matrix_b_o"): - v0_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + ax0_0) - v1_o = T.axis.spatial(720, k_0_0 * 4 + k_0_1 + ax1_0) - T.reads(weight_flatten_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.writes(weight_flatten_reindex_shared_dyn_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + v0_o = T.axis.spatial( + 80, y_0_0 * 4 + y_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial( + 720, k_0_0 * 4 + k_0_1 + ax1_0) + T.reads( + weight_flatten_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes( + weight_flatten_reindex_shared_dyn_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) A_s0 = T.int32() A_s1 = T.int32() - A_1 = T.match_buffer(weight_flatten_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=(A_s0, A_s1), scope="shared.dyn", offset_factor=16) + A_1 = T.match_buffer(weight_flatten_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], ( + 16, 16), "float16", strides=(A_s0, A_s1), scope="shared.dyn", offset_factor=16) C_s0 = T.int32() C_s1 = T.int32() - C = T.match_buffer(weight_flatten_reindex_shared_dyn_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.matrix_b", offset_factor=16) - T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_s0 * 16, 1), A_s0, "col_major") + C = T.match_buffer(weight_flatten_reindex_shared_dyn_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], ( + 16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.matrix_b", offset_factor=16) + T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.tvm_access_ptr( + T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_s0 * 16, 1), A_s0, "col_major") for x_0_2, y_0_2 in T.grid(2, 2): with T.block("Conv_update_o"): - v_x_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + x_0_2) - v_y_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + y_0_2) - v_k_o = T.axis.reduce(720, k_0_0 * 4 + k_0_1) - T.reads(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16], data_im2col_reindex_shared_dyn_wmma_matrix_a[v_x_o * 16:v_x_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16], weight_flatten_reindex_shared_dyn_wmma_matrix_b[v_y_o * 16:v_y_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16]) - T.writes(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16]) + v_x_o = T.axis.spatial( + 32, x_0_0 * 4 + x_0_1 * 2 + x_0_2) + v_y_o = T.axis.spatial( + 80, y_0_0 * 4 + y_0_1 * 2 + y_0_2) + v_k_o = T.axis.reduce( + 720, k_0_0 * 4 + k_0_1) + T.reads(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16], data_im2col_reindex_shared_dyn_wmma_matrix_a[v_x_o * + 16:v_x_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16], weight_flatten_reindex_shared_dyn_wmma_matrix_b[v_y_o * 16:v_y_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16]) + T.writes( + Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16]) A_s0 = T.int32() A_s1 = T.int32() - A_1 = T.match_buffer(data_im2col_reindex_shared_dyn_wmma_matrix_a[v_x_o * 16:v_x_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16], (16, 16), "float16", strides=(A_s0, A_s1), scope="wmma.matrix_a", offset_factor=16) + A_1 = T.match_buffer(data_im2col_reindex_shared_dyn_wmma_matrix_a[v_x_o * 16:v_x_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16], ( + 16, 16), "float16", strides=(A_s0, A_s1), scope="wmma.matrix_a", offset_factor=16) B_s0 = T.int32() B_s1 = T.int32() - B = T.match_buffer(weight_flatten_reindex_shared_dyn_wmma_matrix_b[v_y_o * 16:v_y_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16], (16, 16), "float16", strides=(B_s0, B_s1), scope="wmma.matrix_b", offset_factor=16) + B = T.match_buffer(weight_flatten_reindex_shared_dyn_wmma_matrix_b[v_y_o * 16:v_y_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16], ( + 16, 16), "float16", strides=(B_s0, B_s1), scope="wmma.matrix_b", offset_factor=16) C_s0 = T.int32() C_s1 = T.int32() - C = T.match_buffer(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16], (16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.accumulator", offset_factor=16) - T.tvm_mma_sync(C.data, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, A_1.data, A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + A_1.elem_offset % A_s0 // 16, B.data, B.elem_offset // B_s0 // 16 * (B_s0 // 16) + B.elem_offset % B_s0 // 16, C.data, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16) + C = T.match_buffer(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16], ( + 16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.accumulator", offset_factor=16) + T.tvm_mma_sync(C.data, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, A_1.data, A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + A_1.elem_offset % + A_s0 // 16, B.data, B.elem_offset // B_s0 // 16 * (B_s0 // 16) + B.elem_offset % B_s0 // 16, C.data, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16) for ax0_0, ax1_0 in T.grid(2, 2): with T.block("Conv_reindex_wmma.accumulator_o"): - v0_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) - v1_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + ax1_0) - T.reads(Conv_reindex_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) - T.writes(Conv[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + v0_o = T.axis.spatial( + 32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial( + 80, y_0_0 * 4 + y_0_1 * 2 + ax1_0) + T.reads( + Conv_reindex_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes( + Conv[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) A_s0 = T.int32() A_s1 = T.int32() - A_1 = T.match_buffer(Conv_reindex_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=(A_s0, A_s1), scope="wmma.accumulator", offset_factor=16) + A_1 = T.match_buffer(Conv_reindex_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], ( + 16, 16), "float16", strides=(A_s0, A_s1), scope="wmma.accumulator", offset_factor=16) C_s0 = T.int32() C_s1 = T.int32() - C = T.match_buffer(Conv[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], (16, 16), "float16", strides=(C_s0, C_s1), offset_factor=16) - T.tvm_store_matrix_sync(A_1.data, 16, 16, 16, A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + A_1.elem_offset % A_s0 // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, C.elem_offset, C_s0 * 16, 2), C_s0, "row_major") + C = T.match_buffer(Conv[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], + (16, 16), "float16", strides=(C_s0, C_s1), offset_factor=16) + T.tvm_store_matrix_sync(A_1.data, 16, 16, 16, A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + A_1.elem_offset % + A_s0 // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, C.elem_offset, C_s0 * 16, 2), C_s0, "row_major") mod = tvm.IRModule.from_expr(complex_compute) with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): tvm.build(mod, target="cuda") # generated_code must contain " setp.ne.b32 p, %0, 0;" assert "setp.ne.b32" in generated_code - + if not support_async: # avoid return dummy code to other tests support_async = True + if __name__ == "__main__": test_inject_async_copy() test_inject_async_copy_shared_dyn() From 19b05c3e4b83dfe0e946ac9767424ff2a8d73453 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Apr 2023 11:55:47 +0800 Subject: [PATCH 6/8] resolve test_tir_transform_inject_ptx_async_copy.py format with black --- ...est_tir_transform_inject_ptx_async_copy.py | 485 ++++++++++++++---- 1 file changed, 379 insertions(+), 106 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 87632b421d6e..168f8c879bcd 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -52,8 +52,7 @@ def ptx_global_to_shared_copy( T.attr("default", "async_scope", 1) for i in T.serial(num_iters): for j in T.vectorized(vector_size): - A_shared[tx, i * vector_size_expr + - j] = A[tx, i * vector_size_expr + j] + A_shared[tx, i * vector_size_expr + j] = A[tx, i * vector_size_expr + j] T.evaluate(T.ptx_commit_group(dtype="")) T.evaluate(T.ptx_wait_group(0, dtype="")) @@ -401,10 +400,8 @@ def simple_compute( with T.block("compute"): T.reads(A[tx, i]) T.writes(C[tx, i]) - A_shared = T.alloc_buffer( - (16, 1), dtype="float32", scope="shared") - B_shared = T.alloc_buffer( - (16, 1), dtype="float32", scope="shared") + A_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + B_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") with T.block(): T.reads(A[tx, i]) T.writes(A_shared[tx, 0]) @@ -442,174 +439,450 @@ def test_vectorize_cp_async_in_if_then_else(): support_async = False @T.prim_func - def complex_compute(A: T.Buffer((2, 16, 16, 1280), "float16"), W: T.Buffer((1280, 3, 3, 1280), "float16"), Conv: T.Buffer((512, 1280), "float16")): + def complex_compute( + A: T.Buffer((2, 16, 16, 1280), "float16"), + W: T.Buffer((1280, 3, 3, 1280), "float16"), + Conv: T.Buffer((512, 1280), "float16"), + ): T.func_attr({"global_symbol": "main", "tir.noalias": True}) # with T.block("root"): - data_im2col_reindex_shared_dyn = T.alloc_buffer( - (512, 11520), "float16", scope="shared.dyn") + data_im2col_reindex_shared_dyn = T.alloc_buffer((512, 11520), "float16", scope="shared.dyn") data_im2col_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer( - (512, 11520), "float16", scope="wmma.matrix_a") + (512, 11520), "float16", scope="wmma.matrix_a" + ) weight_flatten_reindex_shared_dyn = T.alloc_buffer( - (1280, 11520), "float16", scope="shared.dyn") + (1280, 11520), "float16", scope="shared.dyn" + ) weight_flatten_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer( - (1280, 11520), "float16", scope="wmma.matrix_b") + (1280, 11520), "float16", scope="wmma.matrix_b" + ) Conv_reindex_wmma_accumulator = T.alloc_buffer( - (512, 1280), "float16", scope="wmma.accumulator") + (512, 1280), "float16", scope="wmma.accumulator" + ) for x_0_0 in T.thread_binding(8, thread="blockIdx.y"): for y_0_0 in T.thread_binding(20, thread="blockIdx.x"): for x_0_1 in T.thread_binding(2, thread="threadIdx.y"): for y_0_1 in T.thread_binding(2, thread="threadIdx.z"): for x_0_2_init, y_0_2_init in T.grid(2, 2): with T.block("Conv_init_o"): - v_x_o = T.axis.spatial( - 32, x_0_0 * 4 + x_0_1 * 2 + x_0_2_init) - v_y_o = T.axis.spatial( - 80, y_0_0 * 4 + y_0_1 * 2 + y_0_2_init) + v_x_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + x_0_2_init) + v_y_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + y_0_2_init) T.reads() T.writes( - Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16]) + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, v_y_o * 16 : v_y_o * 16 + 16 + ] + ) C_s0 = T.int32() C_s1 = T.int32() - C = T.match_buffer(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16], ( - 16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.accumulator", offset_factor=16) - T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * ( - C_s0 // 16) + C.elem_offset % C_s0 // 16, T.float32(0)) - for k_0_0 in T.serial(180, - annotations={ - "software_pipeline_stage": [0, 0, 1], - "software_pipeline_order": [0, 1, 2], - "software_pipeline_async_stages": [0], - }): + C = T.match_buffer( + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, v_y_o * 16 : v_y_o * 16 + 16 + ], + (16, 16), + "float16", + strides=(C_s0, C_s1), + scope="wmma.accumulator", + offset_factor=16, + ) + T.tvm_fill_fragment( + C.data, + 16, + 16, + 16, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + T.float32(0), + ) + for k_0_0 in T.serial( + 180, + annotations={ + "software_pipeline_stage": [0, 0, 1], + "software_pipeline_order": [0, 1, 2], + "software_pipeline_async_stages": [0], + }, + ): for ax0_ax1_0_fused_0 in range(4): for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"): - for ax0_ax1_0_fused_2 in T.thread_binding(2, thread="threadIdx.y"): - for ax0_ax1_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_0_fused_2 in T.thread_binding( + 2, thread="threadIdx.y" + ): + for ax0_ax1_0_fused_3 in T.thread_binding( + 32, thread="threadIdx.x" + ): with T.block("data_im2col_reindex_shared.dyn_o"): v0 = T.axis.spatial( - 512, x_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8) + 512, + x_0_0 * 64 + + ( + ax0_ax1_0_fused_0 * 128 + + ax0_ax1_0_fused_1 * 64 + + ax0_ax1_0_fused_2 * 32 + + ax0_ax1_0_fused_3 + ) + // 8, + ) v1_o = T.axis.spatial( - 1440, k_0_0 * 8 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8) - T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % - 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8:v1_o % 160 * 8 + 8]) + 1440, + k_0_0 * 8 + + ( + ax0_ax1_0_fused_0 * 128 + + ax0_ax1_0_fused_1 * 64 + + ax0_ax1_0_fused_2 * 32 + + ax0_ax1_0_fused_3 + ) + % 8, + ) + T.reads( + A[ + v0 // 256, + v1_o // 480 + v0 % 256 // 16 - 1, + v1_o % 480 // 160 + v0 % 16 - 1, + v1_o % 160 * 8 : v1_o % 160 * 8 + 8, + ] + ) T.writes( - data_im2col_reindex_shared_dyn[v0, v1_o * 8:v1_o * 8 + 8]) + data_im2col_reindex_shared_dyn[ + v0, v1_o * 8 : v1_o * 8 + 8 + ] + ) for ax1_1 in T.vectorized(8): with T.block("data_im2col_reindex_shared.dyn"): - v1_i = T.axis.spatial( - 8, ax1_1) + v1_i = T.axis.spatial(8, ax1_1) T.reads( - A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i]) + A[ + v0 // 256, + v1_o // 480 + v0 % 256 // 16 - 1, + v1_o % 480 // 160 + v0 % 16 - 1, + v1_o % 160 * 8 + v1_i, + ] + ) T.writes( - data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i]) + data_im2col_reindex_shared_dyn[ + v0, v1_o * 8 + v1_i + ] + ) T.block_attr( - {"buffer_dim_align": [[0, 0, 32, 8]]}) - data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i] = T.if_then_else(1 <= v1_o // 480 + v0 % 256 // 16 and v1_o // 480 + v0 % 256 // 16 < 17 and 1 <= v1_o % 480 // 160 + v0 % - 16 and v1_o % 480 // 160 + v0 % 16 < 17, A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i], T.float16(0)) + {"buffer_dim_align": [[0, 0, 32, 8]]} + ) + data_im2col_reindex_shared_dyn[ + v0, v1_o * 8 + v1_i + ] = T.if_then_else( + 1 <= v1_o // 480 + v0 % 256 // 16 + and v1_o // 480 + v0 % 256 // 16 < 17 + and 1 <= v1_o % 480 // 160 + v0 % 16 + and v1_o % 480 // 160 + v0 % 16 < 17, + A[ + v0 // 256, + v1_o // 480 + v0 % 256 // 16 - 1, + v1_o % 480 // 160 + v0 % 16 - 1, + v1_o % 160 * 8 + v1_i, + ], + T.float16(0), + ) for ax0_ax1_0_fused_0 in range(4): for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"): - for ax0_ax1_0_fused_2 in T.thread_binding(2, thread="threadIdx.y"): - for ax0_ax1_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_0_fused_2 in T.thread_binding( + 2, thread="threadIdx.y" + ): + for ax0_ax1_0_fused_3 in T.thread_binding( + 32, thread="threadIdx.x" + ): for ax1_1 in T.vectorized(8): with T.block("weight_flatten_reindex_shared.dyn"): v0 = T.axis.spatial( - 1280, y_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8) - v1 = T.axis.spatial(11520, k_0_0 * 64 + ( - ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8 * 8 + ax1_1) + 1280, + y_0_0 * 64 + + ( + ax0_ax1_0_fused_0 * 128 + + ax0_ax1_0_fused_1 * 64 + + ax0_ax1_0_fused_2 * 32 + + ax0_ax1_0_fused_3 + ) + // 8, + ) + v1 = T.axis.spatial( + 11520, + k_0_0 * 64 + + ( + ax0_ax1_0_fused_0 * 128 + + ax0_ax1_0_fused_1 * 64 + + ax0_ax1_0_fused_2 * 32 + + ax0_ax1_0_fused_3 + ) + % 8 + * 8 + + ax1_1, + ) T.reads( - W[v0, v1 // 3840, v1 % 3840 // 1280, v1 % 1280]) + W[ + v0, + v1 // 3840, + v1 % 3840 // 1280, + v1 % 1280, + ] + ) T.writes( - weight_flatten_reindex_shared_dyn[v0, v1]) + weight_flatten_reindex_shared_dyn[v0, v1] + ) T.block_attr( - {"buffer_dim_align": [[0, 0, 32, 8]]}) - weight_flatten_reindex_shared_dyn[v0, v1] = W[v0, - v1 // 1280 // 3, v1 // 1280 % 3, v1 % 1280] + {"buffer_dim_align": [[0, 0, 32, 8]]} + ) + weight_flatten_reindex_shared_dyn[v0, v1] = W[ + v0, + v1 // 1280 // 3, + v1 // 1280 % 3, + v1 % 1280, + ] for k_0_1 in range(4): for ax0_0, ax1_0 in T.grid(2, 1): with T.block("data_im2col_reindex_shared.dyn_wmma.matrix_a_o"): - v0_o = T.axis.spatial( - 32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) - v1_o = T.axis.spatial( - 720, k_0_0 * 4 + k_0_1 + ax1_0) + v0_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(720, k_0_0 * 4 + k_0_1 + ax1_0) T.reads( - data_im2col_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + data_im2col_reindex_shared_dyn[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ] + ) T.writes( - data_im2col_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + data_im2col_reindex_shared_dyn_wmma_matrix_a[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ] + ) A_s0 = T.int32() A_s1 = T.int32() - A_1 = T.match_buffer(data_im2col_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], ( - 16, 16), "float16", strides=(A_s0, A_s1), scope="shared.dyn", offset_factor=16) + A_1 = T.match_buffer( + data_im2col_reindex_shared_dyn[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(A_s0, A_s1), + scope="shared.dyn", + offset_factor=16, + ) C_s0 = T.int32() C_s1 = T.int32() - C = T.match_buffer(data_im2col_reindex_shared_dyn_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], ( - 16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.matrix_a", offset_factor=16) - T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.tvm_access_ptr( - T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_s0 * 16, 1), A_s0, "row_major") + C = T.match_buffer( + data_im2col_reindex_shared_dyn_wmma_matrix_a[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(C_s0, C_s1), + scope="wmma.matrix_a", + offset_factor=16, + ) + T.tvm_load_matrix_sync( + C.data, + 16, + 16, + 16, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + T.tvm_access_ptr( + T.type_annotation("float16"), + A_1.data, + A_1.elem_offset, + A_s0 * 16, + 1, + ), + A_s0, + "row_major", + ) for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("weight_flatten_reindex_shared.dyn_wmma.matrix_b_o"): - v0_o = T.axis.spatial( - 80, y_0_0 * 4 + y_0_1 * 2 + ax0_0) - v1_o = T.axis.spatial( - 720, k_0_0 * 4 + k_0_1 + ax1_0) + with T.block( + "weight_flatten_reindex_shared.dyn_wmma.matrix_b_o" + ): + v0_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(720, k_0_0 * 4 + k_0_1 + ax1_0) T.reads( - weight_flatten_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + weight_flatten_reindex_shared_dyn[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ] + ) T.writes( - weight_flatten_reindex_shared_dyn_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + weight_flatten_reindex_shared_dyn_wmma_matrix_b[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ] + ) A_s0 = T.int32() A_s1 = T.int32() - A_1 = T.match_buffer(weight_flatten_reindex_shared_dyn[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], ( - 16, 16), "float16", strides=(A_s0, A_s1), scope="shared.dyn", offset_factor=16) + A_1 = T.match_buffer( + weight_flatten_reindex_shared_dyn[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(A_s0, A_s1), + scope="shared.dyn", + offset_factor=16, + ) C_s0 = T.int32() C_s1 = T.int32() - C = T.match_buffer(weight_flatten_reindex_shared_dyn_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], ( - 16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.matrix_b", offset_factor=16) - T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, T.tvm_access_ptr( - T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_s0 * 16, 1), A_s0, "col_major") + C = T.match_buffer( + weight_flatten_reindex_shared_dyn_wmma_matrix_b[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(C_s0, C_s1), + scope="wmma.matrix_b", + offset_factor=16, + ) + T.tvm_load_matrix_sync( + C.data, + 16, + 16, + 16, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + T.tvm_access_ptr( + T.type_annotation("float16"), + A_1.data, + A_1.elem_offset, + A_s0 * 16, + 1, + ), + A_s0, + "col_major", + ) for x_0_2, y_0_2 in T.grid(2, 2): with T.block("Conv_update_o"): - v_x_o = T.axis.spatial( - 32, x_0_0 * 4 + x_0_1 * 2 + x_0_2) - v_y_o = T.axis.spatial( - 80, y_0_0 * 4 + y_0_1 * 2 + y_0_2) - v_k_o = T.axis.reduce( - 720, k_0_0 * 4 + k_0_1) - T.reads(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16], data_im2col_reindex_shared_dyn_wmma_matrix_a[v_x_o * - 16:v_x_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16], weight_flatten_reindex_shared_dyn_wmma_matrix_b[v_y_o * 16:v_y_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16]) + v_x_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + x_0_2) + v_y_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + y_0_2) + v_k_o = T.axis.reduce(720, k_0_0 * 4 + k_0_1) + T.reads( + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, + v_y_o * 16 : v_y_o * 16 + 16, + ], + data_im2col_reindex_shared_dyn_wmma_matrix_a[ + v_x_o * 16 : v_x_o * 16 + 16, + v_k_o * 16 : v_k_o * 16 + 16, + ], + weight_flatten_reindex_shared_dyn_wmma_matrix_b[ + v_y_o * 16 : v_y_o * 16 + 16, + v_k_o * 16 : v_k_o * 16 + 16, + ], + ) T.writes( - Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16]) + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, + v_y_o * 16 : v_y_o * 16 + 16, + ] + ) A_s0 = T.int32() A_s1 = T.int32() - A_1 = T.match_buffer(data_im2col_reindex_shared_dyn_wmma_matrix_a[v_x_o * 16:v_x_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16], ( - 16, 16), "float16", strides=(A_s0, A_s1), scope="wmma.matrix_a", offset_factor=16) + A_1 = T.match_buffer( + data_im2col_reindex_shared_dyn_wmma_matrix_a[ + v_x_o * 16 : v_x_o * 16 + 16, + v_k_o * 16 : v_k_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(A_s0, A_s1), + scope="wmma.matrix_a", + offset_factor=16, + ) B_s0 = T.int32() B_s1 = T.int32() - B = T.match_buffer(weight_flatten_reindex_shared_dyn_wmma_matrix_b[v_y_o * 16:v_y_o * 16 + 16, v_k_o * 16:v_k_o * 16 + 16], ( - 16, 16), "float16", strides=(B_s0, B_s1), scope="wmma.matrix_b", offset_factor=16) + B = T.match_buffer( + weight_flatten_reindex_shared_dyn_wmma_matrix_b[ + v_y_o * 16 : v_y_o * 16 + 16, + v_k_o * 16 : v_k_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(B_s0, B_s1), + scope="wmma.matrix_b", + offset_factor=16, + ) C_s0 = T.int32() C_s1 = T.int32() - C = T.match_buffer(Conv_reindex_wmma_accumulator[v_x_o * 16:v_x_o * 16 + 16, v_y_o * 16:v_y_o * 16 + 16], ( - 16, 16), "float16", strides=(C_s0, C_s1), scope="wmma.accumulator", offset_factor=16) - T.tvm_mma_sync(C.data, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16, A_1.data, A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + A_1.elem_offset % - A_s0 // 16, B.data, B.elem_offset // B_s0 // 16 * (B_s0 // 16) + B.elem_offset % B_s0 // 16, C.data, C.elem_offset // C_s0 // 16 * (C_s0 // 16) + C.elem_offset % C_s0 // 16) + C = T.match_buffer( + Conv_reindex_wmma_accumulator[ + v_x_o * 16 : v_x_o * 16 + 16, + v_y_o * 16 : v_y_o * 16 + 16, + ], + (16, 16), + "float16", + strides=(C_s0, C_s1), + scope="wmma.accumulator", + offset_factor=16, + ) + T.tvm_mma_sync( + C.data, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + A_1.data, + A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + + A_1.elem_offset % A_s0 // 16, + B.data, + B.elem_offset // B_s0 // 16 * (B_s0 // 16) + + B.elem_offset % B_s0 // 16, + C.data, + C.elem_offset // C_s0 // 16 * (C_s0 // 16) + + C.elem_offset % C_s0 // 16, + ) for ax0_0, ax1_0 in T.grid(2, 2): with T.block("Conv_reindex_wmma.accumulator_o"): - v0_o = T.axis.spatial( - 32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) - v1_o = T.axis.spatial( - 80, y_0_0 * 4 + y_0_1 * 2 + ax1_0) + v0_o = T.axis.spatial(32, x_0_0 * 4 + x_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(80, y_0_0 * 4 + y_0_1 * 2 + ax1_0) T.reads( - Conv_reindex_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + Conv_reindex_wmma_accumulator[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ] + ) T.writes( - Conv[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + Conv[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16] + ) A_s0 = T.int32() A_s1 = T.int32() - A_1 = T.match_buffer(Conv_reindex_wmma_accumulator[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], ( - 16, 16), "float16", strides=(A_s0, A_s1), scope="wmma.accumulator", offset_factor=16) + A_1 = T.match_buffer( + Conv_reindex_wmma_accumulator[ + v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16 + ], + (16, 16), + "float16", + strides=(A_s0, A_s1), + scope="wmma.accumulator", + offset_factor=16, + ) C_s0 = T.int32() C_s1 = T.int32() - C = T.match_buffer(Conv[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16], - (16, 16), "float16", strides=(C_s0, C_s1), offset_factor=16) - T.tvm_store_matrix_sync(A_1.data, 16, 16, 16, A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + A_1.elem_offset % - A_s0 // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, C.elem_offset, C_s0 * 16, 2), C_s0, "row_major") + C = T.match_buffer( + Conv[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], + (16, 16), + "float16", + strides=(C_s0, C_s1), + offset_factor=16, + ) + T.tvm_store_matrix_sync( + A_1.data, + 16, + 16, + 16, + A_1.elem_offset // A_s0 // 16 * (A_s0 // 16) + + A_1.elem_offset % A_s0 // 16, + T.tvm_access_ptr( + T.type_annotation("float16"), + C.data, + C.elem_offset, + C_s0 * 16, + 2, + ), + C_s0, + "row_major", + ) mod = tvm.IRModule.from_expr(complex_compute) with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): From 70202a409d8a16ced8051f69b64face1940b1158 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Apr 2023 12:39:38 +0800 Subject: [PATCH 7/8] [fix lint] todo without user asigned. --- src/tir/transforms/lower_async_dma.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index c37173a02cf2..75a62b6f15c9 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -43,7 +43,7 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {} - // TODO: split lower async DMA support for CUDA and Hexagon Backend + // TODO(leiwang1999): split lower async DMA support for CUDA and Hexagon Backend Stmt VisitStmt_(const ForNode* loop) final { // if for loop is not within async_commit_queue_scope if (!async_queue_id_.has_value()) { From 6aeeb2d140ada8eef93cb3ffea66b44e56ade159 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Apr 2023 22:45:58 +0800 Subject: [PATCH 8/8] [lint] remove code instead of commenting them out --- src/tir/transforms/lower_async_dma.cc | 3 +-- .../test_hexagon/test_async_dma_pipeline.py | 18 ------------------ 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index 75a62b6f15c9..e1ec0f1572c7 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -50,12 +50,11 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); } - // if for loop is not a memcpy of a contiguous region + // if for loop is not a memcpy of a contiguous region, it might be a cuda cp.async behavior std::optional mem_copy = IdentifyMemCpy(GetRef(loop), analyzer_); if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 || mem_copy->source->region.size() != 1) { return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); - // LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access"; } // now that we are about to perform the `copy` transform diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index 67ddc50fcc6e..efe4920ec0b2 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -879,23 +879,5 @@ def test_meta(hexagon_session): ) -# def test_non_contiguous(): -# """Test Non Contiguous memory lowering.""" -# sch = tvm.tir.Schedule(conv2d_async_non_contig) -# target_hexagon = tvm.target.hexagon("v68", link_params=True) -# err_rgx = r"Unable to lower async dma due to non contiguous memory access" -# # Currently we do not support non contiguous memory access being lowered to -# # async dma so we throw an error. -# with pytest.raises(tvm.TVMError, match=err_rgx): -# with tvm.transform.PassContext( -# config={ -# "tir.use_async_copy": 1, -# } -# ): -# tvm.build( -# sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon) -# ) - - if __name__ == "__main__": tvm.testing.main()