From cbe797094289489c560dc6c9c5f887c050c4b328 Mon Sep 17 00:00:00 2001
From: adstraw
Date: Tue, 22 Aug 2023 19:05:55 +0000
Subject: [PATCH] [Codegen] CUDA async copy with barrier synchronization

Add four TIR builtins that expose PTX mbarrier-based synchronization for
asynchronous copies, together with their Python wrappers and TVMScript
bindings:

  * ptx_cp_async_barrier          -> cp.async.mbarrier.arrive
  * ptx_init_barrier_thread_count -> mbarrier.init
  * ptx_arrive_barrier            -> mbarrier.arrive
  * ptx_wait_barrier              -> mbarrier.try_wait.parity

Each builtin takes the name of a barrier array and an index into it; the
CUDA codegen prints `name[index]` and passes its address to inline PTX.

The generic-to-shared address conversion that was previously repeated
inline by the ldmatrix / cp.async printers is factored into a
`cast_smem_ptr_to_int` device helper, which is emitted into the generated
source only when one of those intrinsics is used.

Tests: op-type tests for the four new builtins (all named `test_op_*` so
that pytest collects them) and an end-to-end barrier-synchronized async
copy test; the expected CUDA source is updated for the new helper.
---
 include/tvm/tir/builtin.h                     |  32 ++++++
 python/tvm/script/ir_builder/tir/ir.py        |   8 ++
 python/tvm/tir/__init__.py                    |  11 +-
 python/tvm/tir/op.py                          |  80 +++++++++++++
 src/target/source/codegen_cuda.cc             |  39 +++++++
 src/target/source/codegen_cuda.h              |   2 +
 src/target/source/ptx.cc                      |  93 ++++++++++++---
 src/target/source/ptx.h                       |  26 +++++
 src/tir/op/builtin.cc                         |   9 ++
 tests/python/unittest/test_tir_op_types.py    |  20 ++++
 ...est_tir_transform_inject_ptx_async_copy.py | 108 ++++++++++++------
 11 files changed, 372 insertions(+), 56 deletions(-)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index e8bcc028fc58..b5c04f760da2 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -663,6 +663,38 @@ TVM_DLL const Op& ptx_cp_async();
 TVM_DLL const Op& ptx_commit_group();
 TVM_DLL const Op& ptx_wait_group();

+/*!
+ * \brief tvm intrinsics for ptx async copy barrier using cp.async.mbarrier.arrive
+ *
+ * ptx_cp_async_barrier(barrier_array, barrier_id)
+ *
+ */
+TVM_DLL const Op& ptx_cp_async_barrier();
+
+/*!
+ * \brief tvm intrinsics for ptx barrier initialization of thread count using mbarrier.init
+ *
+ * ptx_init_barrier_thread_count(barrier_array, barrier_id, thread_count)
+ *
+ */
+TVM_DLL const Op& ptx_init_barrier_thread_count();
+
+/*!
+ * \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive
+ *
+ * ptx_arrive_barrier(barrier_array, barrier_id)
+ *
+ */
+TVM_DLL const Op& ptx_arrive_barrier();
+
+/*!
+ * \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait
+ *
+ * ptx_wait_barrier(barrier_array, barrier_id)
+ *
+ */
+TVM_DLL const Op& ptx_wait_barrier();
+
 /*!
  * \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
  * For example, if each thread in a warp of size 32 has 4 elements from the result of
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index efea9f1aea94..d7bebbacee05 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1844,6 +1844,10 @@ def wrapped(*args, **kwargs):
 tvm_warp_activemask = _tir_op.tvm_warp_activemask
 ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
 ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
+ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier)
+ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count)
+ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
+ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
 assume = _op_wrapper(_tir_op.assume)
 undef = _op_wrapper(_tir_op.undef)
 TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
@@ -2113,6 +2117,10 @@ def wrapped(*args, **kwargs):
     "ptx_cp_async",
     "ptx_wait_group",
     "ptx_commit_group",
+    "ptx_cp_async_barrier",
+    "ptx_init_barrier_thread_count",
+    "ptx_arrive_barrier",
+    "ptx_wait_barrier",
     "mma_store",
     "mma_fill",
     "vectorlow",
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 5eb1059d27c8..84c575333712 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -60,7 +60,16 @@
     tvm_fill_fragment,
 )
 from .op import ptx_mma, ptx_mma_sp, mma_store, mma_fill
-from .op import ptx_ldmatrix, ptx_cp_async, ptx_commit_group, ptx_wait_group
+from .op import (
+    ptx_ldmatrix,
+    ptx_cp_async,
+    ptx_commit_group,
+    ptx_wait_group,
+    ptx_cp_async_barrier,
+    ptx_init_barrier_thread_count,
+    ptx_arrive_barrier,
+    ptx_wait_barrier,
+)
 from .op import vectorlow, vectorhigh, vectorcombine
 from .op import infinity, reinterpret
 from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 378be84621ba..7e1c520cc432 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1397,6 +1397,86 @@ def ptx_wait_group(num):
     return call_intrin("", "tir.ptx_wait_group", num)


+def ptx_cp_async_barrier(barrier_arr, barrier_id):
+    """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
+
+    Parameters
+    ----------
+    barrier_arr : string
+        The name of the barrier array in shared memory
+    barrier_id : int
+        Index into the barrier array
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("", "tir.ptx_cp_async_barrier", barrier_arr, barrier_id)
+
+
+def ptx_init_barrier_thread_count(barrier_arr, barrier_id, thread_count):
+    """TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
+
+    Parameters
+    ----------
+    barrier_arr : string
+        The name of the barrier array in shared memory
+    barrier_id : int
+        Index into the barrier array
+    thread_count : int
+        Number of threads expected to arrive at the barrier
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        "", "tir.ptx_init_barrier_thread_count", barrier_arr, barrier_id, thread_count
+    )
+
+
+def ptx_arrive_barrier(barrier_arr, barrier_id):
+    """TVM intrinsic for ptx barrier arrival using mbarrier.arrive
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
+
+    Parameters
+    ----------
+    barrier_arr : string
+        The name of the barrier array in shared memory
+    barrier_id : int
+        Index into the barrier array
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("", "tir.ptx_arrive_barrier", barrier_arr, barrier_id)
+
+
+def ptx_wait_barrier(barrier_arr, barrier_id):
+    """TVM intrinsic for ptx barrier wait using mbarrier.try_wait
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
+
+    Parameters
+    ----------
+    barrier_arr : string
+        The name of the barrier array in shared memory
+    barrier_id : int
+        Index into the barrier array
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("", "tir.ptx_wait_barrier", barrier_arr, barrier_id)
+
+
 def vectorlow(dtype, vec):
     """Get the low level half of the vector

diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 6c0234819199..edbe8be0303f 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -141,6 +141,18 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#include <mma.h>\n";
   }

+  if (need_cast_smem_ptr_to_int_) {
+    decl_stream << "__forceinline__ __device__ unsigned int\n";
+    decl_stream << "cast_smem_ptr_to_int(const void* const smem_ptr)\n";
+    decl_stream << "{\n";
+    decl_stream << "  unsigned int smem_int;\n";
+    decl_stream << "  asm volatile (\"{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; "
+                   "cvt.u32.u64 %0, smem_int; }\"\n";
+    decl_stream << "    : \"=r\"(smem_int) : \"l\"(smem_ptr));\n";
+    decl_stream << "  return smem_int;\n";
+    decl_stream << "}\n";
+  }
+
   decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n";
   decl_stream << "     (__CUDACC_VER_MAJOR__ > 11))\n";
   decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n";
@@ -873,6 +885,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
       os << "}\n";
     } else {
       std::string smem_elem_offset = this->PrintExpr(op->args[6]);
+      need_cast_smem_ptr_to_int_ = true;
       this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
                                               smem_ptr, smem_elem_offset);
     }
@@ -941,6 +954,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     std::string src_offset = this->PrintExpr(op->args[3]);
     std::string size = this->PrintExpr(op->args[4]);
     // use size of argument list to indicate whether or not to use predicated cp.async
+    need_cast_smem_ptr_to_int_ = true;
     if (op->args.size() == 5) {
       this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
     } else {
@@ -952,6 +966,31 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
   } else if (op->op.same_as(builtin::ptx_wait_group())) {
     int n = Downcast<IntImm>(op->args[0])->value;
     this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n";
+  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
+    need_cast_smem_ptr_to_int_ = true;
+    std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
+    std::string barrier_id = this->PrintExpr(op->args[1]);
+    std::string barrier = barriers_arr + "[" + barrier_id + "]";
+    this->stream << PrintCpAsyncBarrierAsm(barrier);
+  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
+    need_cast_smem_ptr_to_int_ = true;
+    std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
+    std::string barrier_id = this->PrintExpr(op->args[1]);
+    std::string barrier = barriers_arr + "[" + barrier_id + "]";
+    std::string thread_count = this->PrintExpr(op->args[2]);
+    this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
+  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
+    need_cast_smem_ptr_to_int_ = true;
+    std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
+    std::string barrier_id = this->PrintExpr(op->args[1]);
+    std::string barrier = barriers_arr + "[" + barrier_id + "]";
+    this->stream << PrintArriveBarrierAsm(barrier);
+  } else if (op->op.same_as(builtin::ptx_wait_barrier())) {
+    need_cast_smem_ptr_to_int_ = true;
+    std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
+    std::string barrier_id = this->PrintExpr(op->args[1]);
+    std::string barrier = barriers_arr + "[" + barrier_id + "]";
+    this->stream << PrintWaitBarrierAsm(barrier);
   } else if (op->op.same_as(builtin::ptx_ldg32())) {
     /*
     asm volatile (
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 7de6ae05e87d..797ac9936375 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -104,6 +104,8 @@ class CodeGenCUDA final : public CodeGenC {
   bool need_math_constants_h_{false};
   // whether need mma.h
   bool need_mma_h_{false};
+  // whether need cast_smem_ptr_to_int helper function
+  bool need_cast_smem_ptr_to_int_{false};
   // Op attribute map
   OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");

diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index feffc3d304ef..6ff57f43bd2d 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -603,12 +603,7 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
   CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16.";
   std::string asm_code = R"(
   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-      : "=r"(addr)
-      : "l"((void *)({smem_addr}))
-    );
+    unsigned int addr = cast_smem_ptr_to_int({smem_addr});
     __asm__ __volatile__(
       "ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
       "{templates};\n"
@@ -638,12 +633,7 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
                                  const std::string& global_elem_offset, const std::string& bytes) {
   std::string asm_code = R"(
   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-      : "=r"(addr)
-      : "l"((void *)({smem_addr}))
-    );
+    unsigned int addr = cast_smem_ptr_to_int({smem_addr});
     __asm__ __volatile__(
       #if TVM_ENABLE_L2_PREFETCH
         "cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;"
@@ -674,12 +664,7 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
       << "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async";
   std::string predicated_asm_code = R"(
   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
-      : "=r"(addr)
-      : "l"((void *)({smem_addr}))
-    );
+    unsigned int addr = cast_smem_ptr_to_int({smem_addr});
     int pred_guard = (int){pred_guard};
     __asm__ __volatile__(
         "{  .reg .pred p;"
@@ -724,5 +709,77 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
   return predicated_asm_code;
 }

+std::string PrintCpAsyncBarrierAsm(const std::string& barrier) {
+  std::string asm_code = R"(
+  {
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+    __asm__ __volatile__(
+      "cp.async.mbarrier.arrive.shared.b64 [%0];"
+      :: "r" (barrier_addr_int)
+    );
+  }
+)";
+
+  Replacer replacer;
+  replacer.register_rule("{barrier}", barrier);
+  asm_code = replacer.rewrite(asm_code);
+  return asm_code;
+}
+
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
+                                           const std::string& thread_count) {
+  std::string asm_code = R"(
+  {
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+    int thread_count = {thread_count};
+    __asm__ __volatile__(
+      "mbarrier.init.shared.b64 [%0], %1;"
+      :: "r"(barrier_addr_int), "r"(thread_count)
+    );
+  }
+)";
+
+  Replacer replacer;
+  replacer.register_rule("{barrier}", barrier);
+  replacer.register_rule("{thread_count}", thread_count);
+  asm_code = replacer.rewrite(asm_code);
+  return asm_code;
+}
+
+std::string PrintArriveBarrierAsm(const std::string& barrier) {
+  std::string asm_code = R"(
+  {
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+    __asm__ __volatile__(
+      "{ .reg .b64 state; mbarrier.arrive.shared.b64 state, [%0]; }"
+      :: "r"(barrier_addr_int)
+    );
+  }
+)";
+
+  Replacer replacer;
+  replacer.register_rule("{barrier}", barrier);
+  asm_code = replacer.rewrite(asm_code);
+  return asm_code;
+}
+
+std::string PrintWaitBarrierAsm(const std::string& barrier) {
+  std::string asm_code = R"(
+  {
+    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
+    constexpr int phase_bit = 0;
+    __asm__ __volatile__(
+      "{ .reg .pred P; WAIT: mbarrier.try_wait.parity.shared.b64 P, [%0], %1; @P bra.uni DONE; bra.uni WAIT; DONE: }"
+      :: "r"(barrier_addr_int), "r"(phase_bit)
+    );
+  }
+)";
+
+  Replacer replacer;
+  replacer.register_rule("{barrier}", barrier);
+  asm_code = replacer.rewrite(asm_code);
+  return asm_code;
+}
+
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h
index 1e49b57c1790..18519d85f6a4 100644
--- a/src/target/source/ptx.h
+++ b/src/target/source/ptx.h
@@ -108,6 +108,32 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
                                            const std::string& bytes,
                                            const std::string& predicate_value);

+/*!
+ * \brief Print ptx async copy barrier using cp.async.mbarrier.arrive
+ * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ */
+std::string PrintCpAsyncBarrierAsm(const std::string& barrier);
+
+/*!
+ * \brief Print ptx barrier initialization of thread count using mbarrier.init
+ * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ * \param thread_count: The number of threads expected to arrive at the barrier
+ */
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
+                                           const std::string& thread_count);
+
+/*!
+ * \brief Print ptx barrier arrival using mbarrier.arrive
+ * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ */
+std::string PrintArriveBarrierAsm(const std::string& barrier);
+
+/*!
+ * \brief Print ptx barrier wait using mbarrier.try_wait
+ * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
+ */
+std::string PrintWaitBarrierAsm(const std::string& barrier);
+
 }  // namespace codegen
 }  // namespace tvm

diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index c85590428450..0ca61b409967 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -290,6 +290,15 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group)
 TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

+TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async_barrier)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_init_barrier_thread_count)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_wait_barrier)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(mma_store)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py
index 58954e745948..30e9ed2dfac5 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -244,6 +244,26 @@ def test_op_ptx_wait_group():
     assert expr.op.name == "tir.ptx_wait_group"


+def test_op_ptx_cp_async_barrier():
+    expr = tir.ptx_cp_async_barrier("barrier", 0)
+    assert expr.op.name == "tir.ptx_cp_async_barrier"
+
+
+def test_op_ptx_init_barrier_thread_count():
+    expr = tir.ptx_init_barrier_thread_count("barrier", 0, 32)
+    assert expr.op.name == "tir.ptx_init_barrier_thread_count"
+
+
+def test_op_ptx_arrive_barrier():
+    expr = tir.ptx_arrive_barrier("barrier", 0)
+    assert expr.op.name == "tir.ptx_arrive_barrier"
+
+
+def test_op_ptx_wait_barrier():
+    expr = tir.ptx_wait_barrier("barrier", 0)
+    assert expr.op.name == "tir.ptx_wait_barrier"
+
+
 def test_tir_op_vectorlow():
     buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
     vec = buffer.vload([0, 0], dtype="int8x16")
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index b39fca72c871..5d866199e79b 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -183,7 +183,71 @@ def test_inject_async_copy_shared_dyn():
         tvm.testing.assert_allclose(C_nd.numpy(), A_np + B_np)


-expected_cuda_script = r"""
+@T.prim_func
+def ptx_global_to_shared_copy_fp32x1_barrier(
+    A: T.Buffer((32, 128), "float32"), B: T.Buffer((32, 128), "float32")
+) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    bx = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(bx, 1)
+    T.launch_thread(tx, 32)
+    with T.block():
+        barrier = T.alloc_buffer([1], "uint64", scope="shared")
+        A_shared = T.alloc_buffer([32, 128], "float32", scope="shared")
+        T.reads(A[0:32, 0:128])
+        T.writes(B[0:32, 0:128], barrier[0:1])
+
+        barrier[0] = 0
+        T.evaluate(T.ptx_init_barrier_thread_count("barrier", 0, 32, dtype=""))
+
+        T.attr("default", "async_scope", 1)
+        for i in T.serial(128):
+            A_shared[tx, i] = A[tx, i]
+
+        T.evaluate(T.ptx_cp_async_barrier("barrier", 0, dtype=""))
+        T.evaluate(T.ptx_arrive_barrier("barrier", 0, dtype=""))
+        T.evaluate(T.ptx_wait_barrier("barrier", 0, dtype=""))
+
+        for i in range(128):
+            B[tx, i] = A_shared[tx, i]
+
+
+@tvm.testing.requires_cuda
+def test_inject_async_copy_barrier():
+    dtype = "float32"
+    vec_size = 1
+    f = ptx_global_to_shared_copy_fp32x1_barrier
+
+    mod = tvm.IRModule.from_expr(f)
+    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
+    mod = tvm.tir.transform.FlattenBuffer()(mod)
+    mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+
+    assert count_cp_async(mod["main"].body) == 1
+
+    if tvm.testing.is_ampere_or_newer():
+        with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
+            mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
+
+        A_np = np.random.rand(32, 128).astype(dtype)
+        B_np = np.zeros((32, 128)).astype(dtype)
+        dev = tvm.cuda(0)
+        A_nd = tvm.nd.array(A_np, device=dev)
+        B_nd = tvm.nd.array(B_np, device=dev)
+        mod(A_nd, B_nd)
+        tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+expected_cuda_script = r"""__forceinline__ __device__ unsigned int
+cast_smem_ptr_to_int(const void* const smem_ptr)
+{
+  unsigned int smem_int;
+  asm volatile ("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; cvt.u32.u64 %0, smem_int; }"
+    : "=r"(smem_int) : "l"(smem_ptr));
+  return smem_int;
+}
+
 #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
      (__CUDACC_VER_MAJOR__ > 11))
 #define TVM_ENABLE_L2_PREFETCH 1
@@ -214,12 +278,7 @@ def test_inject_async_copy_shared_dyn():


   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-      : "=r"(addr)
-      : "l"((void *)(A_shared + (((int)threadIdx.x) + 16)))
-    );
+    unsigned int addr = cast_smem_ptr_to_int(A_shared + (((int)threadIdx.x) + 16));
     __asm__ __volatile__(
       #if TVM_ENABLE_L2_PREFETCH
         "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
@@ -231,12 +290,7 @@ def test_inject_async_copy_shared_dyn():
   }

   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-      : "=r"(addr)
-      : "l"((void *)(B_shared + (((int)threadIdx.x) + 16)))
-    );
+    unsigned int addr = cast_smem_ptr_to_int(B_shared + (((int)threadIdx.x) + 16));
     __asm__ __volatile__(
       #if TVM_ENABLE_L2_PREFETCH
         "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
@@ -250,12 +304,7 @@ def test_inject_async_copy_shared_dyn():


   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-      : "=r"(addr)
-      : "l"((void *)(A_shared + (((int)threadIdx.x) + 32)))
-    );
+    unsigned int addr = cast_smem_ptr_to_int(A_shared + (((int)threadIdx.x) + 32));
     __asm__ __volatile__(
       #if TVM_ENABLE_L2_PREFETCH
         "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
@@ -267,12 +316,7 @@ def test_inject_async_copy_shared_dyn():
   }

   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-      : "=r"(addr)
-      : "l"((void *)(B_shared + (((int)threadIdx.x) + 32)))
-    );
+    unsigned int addr = cast_smem_ptr_to_int(B_shared + (((int)threadIdx.x) + 32));
     __asm__ __volatile__(
       #if TVM_ENABLE_L2_PREFETCH
         "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
@@ -288,12 +332,7 @@ def test_inject_async_copy_shared_dyn():
     bool cse_var_1 = (i < 12);

   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
-      : "=r"(addr)
-      : "l"((void *)(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))))
-    );
+    unsigned int addr = cast_smem_ptr_to_int(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)));
     int pred_guard = (int)cse_var_1;
     __asm__ __volatile__(
         "{  .reg .pred p;"
@@ -316,12 +355,7 @@ def test_inject_async_copy_shared_dyn():
     __syncthreads();

   {
-    unsigned int addr;
-    __asm__ __volatile__(
-      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
-      : "=r"(addr)
-      : "l"((void *)(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))))
-    );
+    unsigned int addr = cast_smem_ptr_to_int(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)));
     int pred_guard = (int)cse_var_1;
     __asm__ __volatile__(
         "{  .reg .pred p;"