Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUDA async copy with barrier synchronization #15613

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,38 @@ TVM_DLL const Op& ptx_cp_async();
TVM_DLL const Op& ptx_commit_group();
TVM_DLL const Op& ptx_wait_group();

/*!
 * \brief tvm intrinsics for ptx async copy barrier using cp.async.mbarrier.arrive
 *
 * ptx_cp_async_barrier(barrier_array, barrier_id)
 *
 * barrier_array is the name of the barrier array in shared memory and
 * barrier_id is the index of the barrier within that array.
 */
TVM_DLL const Op& ptx_cp_async_barrier();

/*!
 * \brief tvm intrinsics for ptx barrier initialization of thread count using mbarrier.init
 *
 * ptx_init_barrier_thread_count(barrier_array, barrier_id, thread_count)
 *
 * thread_count is the number of threads expected to arrive at the barrier.
 */
TVM_DLL const Op& ptx_init_barrier_thread_count();

/*!
 * \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive
 *
 * ptx_arrive_barrier(barrier_array, barrier_id)
 *
 * Signals that the calling thread has arrived at the selected barrier.
 */
TVM_DLL const Op& ptx_arrive_barrier();

/*!
 * \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait
 *
 * ptx_wait_barrier(barrier_array, barrier_id)
 *
 * Blocks until the selected barrier's current phase completes.
 */
TVM_DLL const Op& ptx_wait_barrier();

/*!
* \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
* For example, if each thread in a warp of size 32 has 4 elements from the result of
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1844,6 +1844,10 @@ def wrapped(*args, **kwargs):
tvm_warp_activemask = _tir_op.tvm_warp_activemask
ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
ptx_cp_async_barrier = _op_wrapper(_tir_op.ptx_cp_async_barrier)
ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_count)
ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
Expand Down Expand Up @@ -2113,6 +2117,10 @@ def wrapped(*args, **kwargs):
"ptx_cp_async",
"ptx_wait_group",
"ptx_commit_group",
"ptx_cp_async_barrier",
"ptx_init_barrier_thread_count",
"ptx_arrive_barrier",
"ptx_wait_barrier",
"mma_store",
"mma_fill",
"vectorlow",
Expand Down
11 changes: 10 additions & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,16 @@
tvm_fill_fragment,
)
from .op import ptx_mma, ptx_mma_sp, mma_store, mma_fill
from .op import ptx_ldmatrix, ptx_cp_async, ptx_commit_group, ptx_wait_group
from .op import (
ptx_ldmatrix,
ptx_cp_async,
ptx_commit_group,
ptx_wait_group,
ptx_cp_async_barrier,
ptx_init_barrier_thread_count,
ptx_arrive_barrier,
ptx_wait_barrier,
)
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
Expand Down
80 changes: 80 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,86 @@ def ptx_wait_group(num):
return call_intrin("", "tir.ptx_wait_group", num)


def ptx_cp_async_barrier(barrier_arr, barrier_id):
    """TVM intrinsic that lowers to the ptx ``cp.async.mbarrier.arrive``
    instruction, tracking completion of async copies on a shared-memory barrier.
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive

    Parameters
    ----------
    barrier_arr : string
        The name of the barrier array in shared memory
    barrier_id : int
        Index into the barrier array

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.ptx_cp_async_barrier"
    return call_intrin("", op_name, barrier_arr, barrier_id)


def ptx_init_barrier_thread_count(barrier_arr, barrier_id, thread_count):
    """TVM intrinsic that lowers to the ptx ``mbarrier.init`` instruction,
    initializing a shared-memory barrier with its expected arrival count.
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init

    Parameters
    ----------
    barrier_arr : string
        The name of the barrier array in shared memory
    barrier_id : int
        Index into the barrier array
    thread_count : int
        Number of threads expected to arrive at the barrier

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.ptx_init_barrier_thread_count"
    return call_intrin("", op_name, barrier_arr, barrier_id, thread_count)


def ptx_arrive_barrier(barrier_arr, barrier_id):
    """TVM intrinsic that lowers to the ptx ``mbarrier.arrive`` instruction,
    signalling the calling thread's arrival at a shared-memory barrier.
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive

    Parameters
    ----------
    barrier_arr : string
        The name of the barrier array in shared memory
    barrier_id : int
        Index into the barrier array

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.ptx_arrive_barrier"
    return call_intrin("", op_name, barrier_arr, barrier_id)


def ptx_wait_barrier(barrier_arr, barrier_id):
    """TVM intrinsic that lowers to the ptx ``mbarrier.try_wait`` instruction,
    waiting on a shared-memory barrier until its current phase completes.
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait

    Parameters
    ----------
    barrier_arr : string
        The name of the barrier array in shared memory
    barrier_id : int
        Index into the barrier array

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    op_name = "tir.ptx_wait_barrier"
    return call_intrin("", op_name, barrier_arr, barrier_id)


def vectorlow(dtype, vec):
"""Get the low level half of the vector

Expand Down
39 changes: 39 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,18 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <mma.h>\n";
}

if (need_cast_smem_ptr_to_int_) {
decl_stream << "__forceinline__ __device__ unsigned int\n";
decl_stream << "cast_smem_ptr_to_int(const void* const smem_ptr)\n";
decl_stream << "{\n";
decl_stream << " unsigned int smem_int;\n";
decl_stream << " asm volatile (\"{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; "
"cvt.u32.u64 %0, smem_int; }\"\n";
decl_stream << " : \"=r\"(smem_int) : \"l\"(smem_ptr));\n";
decl_stream << " return smem_int;\n";
decl_stream << "}\n";
}

decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n";
decl_stream << " (__CUDACC_VER_MAJOR__ > 11))\n";
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n";
Expand Down Expand Up @@ -873,6 +885,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
os << "}\n";
} else {
std::string smem_elem_offset = this->PrintExpr(op->args[6]);
need_cast_smem_ptr_to_int_ = true;
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
smem_ptr, smem_elem_offset);
}
Expand Down Expand Up @@ -941,6 +954,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
// use size of argument list to indicate whether or not to use predicated cp.async
need_cast_smem_ptr_to_int_ = true;
if (op->args.size() == 5) {
this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
} else {
Expand All @@ -952,6 +966,31 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
} else if (op->op.same_as(builtin::ptx_wait_group())) {
int n = Downcast<IntImm>(op->args[0])->value;
this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n";
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
need_cast_smem_ptr_to_int_ = true;
std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
std::string barrier_id = this->PrintExpr(op->args[1]);
std::string barrier = barriers_arr + "[" + barrier_id + "]";
this->stream << PrintCpAsyncBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
need_cast_smem_ptr_to_int_ = true;
std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
std::string barrier_id = this->PrintExpr(op->args[1]);
std::string barrier = barriers_arr + "[" + barrier_id + "]";
std::string thread_count = this->PrintExpr(op->args[2]);
this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
need_cast_smem_ptr_to_int_ = true;
std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
std::string barrier_id = this->PrintExpr(op->args[1]);
std::string barrier = barriers_arr + "[" + barrier_id + "]";
this->stream << PrintArriveBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_wait_barrier())) {
need_cast_smem_ptr_to_int_ = true;
std::string barriers_arr = Downcast<StringImm>(op->args[0])->value;
std::string barrier_id = this->PrintExpr(op->args[1]);
std::string barrier = barriers_arr + "[" + barrier_id + "]";
this->stream << PrintWaitBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_ldg32())) {
/*
asm volatile (
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class CodeGenCUDA final : public CodeGenC {
bool need_math_constants_h_{false};
// whether need mma.h
bool need_mma_h_{false};
// whether need cast_smem_ptr_to_int helper function
bool need_cast_smem_ptr_to_int_{false};
// Op attribute map
OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");

Expand Down
93 changes: 75 additions & 18 deletions src/target/source/ptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -603,12 +603,7 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16.";
std::string asm_code = R"(
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)({smem_addr}))
);
unsigned int addr = cast_smem_ptr_to_int({smem_addr});
__asm__ __volatile__(
"ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
"{templates};\n"
Expand Down Expand Up @@ -638,12 +633,7 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
const std::string& global_elem_offset, const std::string& bytes) {
std::string asm_code = R"(
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)({smem_addr}))
);
unsigned int addr = cast_smem_ptr_to_int({smem_addr});
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
"cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;"
Expand Down Expand Up @@ -674,12 +664,7 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
<< "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async";
std::string predicated_asm_code = R"(
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
: "=r"(addr)
: "l"((void *)({smem_addr}))
);
unsigned int addr = cast_smem_ptr_to_int({smem_addr});
int pred_guard = (int){pred_guard};
__asm__ __volatile__(
"{ .reg .pred p;"
Expand Down Expand Up @@ -724,5 +709,77 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
return predicated_asm_code;
}

// Build the CUDA source snippet that emits PTX cp.async.mbarrier.arrive for
// the given shared-memory barrier expression (e.g. "barriers[0]").
// The "{barrier}" token inside the raw string is a template placeholder that
// is substituted below; it is not C++ text.
std::string PrintCpAsyncBarrierAsm(const std::string& barrier) {
  const std::string asm_template = R"(
  {
    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
    __asm__ __volatile__(
      "cp.async.mbarrier.arrive.shared.b64 [%0];"
      :: "r" (barrier_addr_int)
    );
  }
)";

  Replacer replacer;
  replacer.register_rule("{barrier}", barrier);
  return replacer.rewrite(asm_template);
}

// Build the CUDA source snippet that emits PTX mbarrier.init, initializing the
// given shared-memory barrier with its expected arrival count.
// "{barrier}" and "{thread_count}" in the raw string are template placeholders
// substituted below, not C++ text.
std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
                                           const std::string& thread_count) {
  const std::string asm_template = R"(
  {
    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
    int thread_count = {thread_count};
    __asm__ __volatile__(
      "mbarrier.init.shared.b64 [%0], %1;"
      :: "r"(barrier_addr_int), "r"(thread_count)
    );
  }
)";

  Replacer replacer;
  replacer.register_rule("{barrier}", barrier);
  replacer.register_rule("{thread_count}", thread_count);
  return replacer.rewrite(asm_template);
}

// Build the CUDA source snippet that emits PTX mbarrier.arrive on the given
// shared-memory barrier. The returned arrival state token is discarded inside
// a local PTX register scope.
// "{barrier}" in the raw string is a template placeholder substituted below.
std::string PrintArriveBarrierAsm(const std::string& barrier) {
  const std::string asm_template = R"(
  {
    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
    __asm__ __volatile__(
      "{ .reg .b64 state; mbarrier.arrive.shared.b64 state, [%0]; }"
      :: "r"(barrier_addr_int)
    );
  }
)";

  Replacer replacer;
  replacer.register_rule("{barrier}", barrier);
  return replacer.rewrite(asm_template);
}

// Build the CUDA source snippet that spins on PTX mbarrier.try_wait until the
// given shared-memory barrier's phase completes. The snippet always tests
// parity phase bit 0 (hard-coded in the template below).
// "{barrier}" in the raw string is a template placeholder substituted below.
std::string PrintWaitBarrierAsm(const std::string& barrier) {
  const std::string asm_template = R"(
  {
    unsigned int barrier_addr_int = cast_smem_ptr_to_int(&{barrier});
    constexpr int phase_bit = 0;
    __asm__ __volatile__(
      "{ .reg .pred P; WAIT: mbarrier.try_wait.parity.shared.b64 P, [%0], %1; @P bra.uni DONE; bra.uni WAIT; DONE: }"
      :: "r"(barrier_addr_int), "r"(phase_bit)
    );
  }
)";

  Replacer replacer;
  replacer.register_rule("{barrier}", barrier);
  return replacer.rewrite(asm_template);
}

} // namespace codegen
} // namespace tvm
26 changes: 26 additions & 0 deletions src/target/source/ptx.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,32 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
const std::string& bytes,
const std::string& predicate_value);

/*!
 * \brief Print ptx async copy barrier using cp.async.mbarrier.arrive
 * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
 * \return The generated CUDA source fragment as a string.
 */
std::string PrintCpAsyncBarrierAsm(const std::string& barrier);

/*!
 * \brief Print ptx barrier initialization of thread count using mbarrier.init
 * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
 * \param thread_count: The number of threads expected to arrive at the barrier
 * \return The generated CUDA source fragment as a string.
 */
std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
                                           const std::string& thread_count);

/*!
 * \brief Print ptx barrier arrival using mbarrier.arrive
 * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
 * \return The generated CUDA source fragment as a string.
 */
std::string PrintArriveBarrierAsm(const std::string& barrier);

/*!
 * \brief Print ptx barrier wait using mbarrier.try_wait
 * \param barrier: The barrier in shared memory in the form barrier_array[barrier_index]
 * \return The generated CUDA source fragment as a string.
 */
std::string PrintWaitBarrierAsm(const std::string& barrier);

} // namespace codegen
} // namespace tvm

Expand Down
Loading