[TensorIR] Support for L2 prefetch async copy and pred_guard enabled async in vectorized if_then_else (#14329)
This pull request adds support for the L2 prefetch option of the cp.async instruction, which is available in CUDA 11.4 and later; the implementation follows the approach used in CUTLASS. It also adds support for asynchronous copy of if_then_else under vectorization, and fixes some bugs. For example, the original async copy lowering cannot handle a vectorized if_then_else. Given the following template:

```python
for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"):
    for ax0_ax1_0_fused_2 in T.thread_binding(2, thread="threadIdx.y"):
        for ax0_ax1_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
            with T.block("data_im2col_reindex_shared.dyn_o"):
                v0 = T.axis.spatial(512, x_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8)
                v1_o = T.axis.spatial(1440, k_0_0 * 8 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8)
                T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8:v1_o % 160 * 8 + 8])
                T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8:v1_o * 8 + 8])
                for ax1_1 in T.vectorized(8):
                    with T.block("data_im2col_reindex_shared.dyn"):
                        v1_i = T.axis.spatial(8, ax1_1)
                        T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i])
                        T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i])
                        T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]})
                        data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i] = T.if_then_else(1 <= v1_o // 480 + v0 % 256 // 16 and v1_o // 480 + v0 % 256 // 16 < 17 and 1 <= v1_o % 480 // 160 + v0 % 16 and v1_o % 480 // 160 + v0 % 16 < 17, A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i], T.float16(0))
```

with this PR, the generated code becomes:

```c++
{
  unsigned int addr;
  __asm__ __volatile__(
    "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
    : "=r"(addr)
    : "l"((void *)(buf_dyn_shmem + (((((ax0_ax1_0_fused_0 * 2304) + (((int)threadIdx.z) * 1152)) + (((int)threadIdx.y) * 576)) + ((((int)threadIdx.x) >> 3) * 144)) + ((((int)threadIdx.x) & 7) * 16))))
  );
  int pred_guard = (int)((((1 <= ((((((int)blockIdx.y) & 3) * 4) + (k_0_0 / 60)) + ax0_ax1_0_fused_0)) && (((((((int)blockIdx.y) & 3) * 4) + (k_0_0 / 60)) + ax0_ax1_0_fused_0) < 17)) && (1 <= ((((((int)threadIdx.z) * 8) + (((int)threadIdx.y) * 4)) + ((k_0_0 % 60) / 20)) + (((int)threadIdx.x) >> 3)))) && (((((((int)threadIdx.z) * 8) + (((int)threadIdx.y) * 4)) + ((k_0_0 % 60) / 20)) + (((int)threadIdx.x) >> 3)) < 17));
  __asm__ __volatile__(
    "{ .reg .pred p;"
    " setp.ne.b32 p, %0, 0;"
#if TVM_ENABLE_L2_PREFETCH
    " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;"
#else
    " @p cp.async.ca.shared.global [%1], [%2], %3;"
#endif
    " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};}"
    :: "r"(pred_guard), "r"(addr), "l"((void*)(A + (((((((((((int)blockIdx.y) * 81920) + ((k_0_0 / 60) * 20480)) + (ax0_ax1_0_fused_0 * 20480)) + (((int)threadIdx.z) * 10240)) + (((int)threadIdx.y) * 5120)) + ((((int)threadIdx.x) >> 3) * 1280)) + ((k_0_0 % 60) * 64)) + ((((int)threadIdx.x) & 7) * 8)) - 21760))), "n"(16), "r"(0), "r"(0), "r"(0), "r"(0)
  );
}
```
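The TVM_ENABLE_L2_PREFETCH guard above lets the same generated source fall back to the plain cp.async form when the .L2::128B cache hint is not available. As a rough illustration only (the macro's actual definition in the codegen may differ), such a guard could be derived from the NVCC version macros, since the commit states the option requires CUDA 11.4 or later:

```c++
// Hypothetical sketch of a toolkit-version guard; the real definition emitted
// by TVM may differ. The .L2::128B qualifier is documented for CUDA 11.4+.
#if defined(__CUDACC_VER_MAJOR__) &&                                   \
    (__CUDACC_VER_MAJOR__ > 11 ||                                      \
     (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif
```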
By contrast, the original InjectPTXAsyncCopy pass only supports serial, 4-byte-aligned assignments of if_then_else, and its predicated template also has a bug:

```c++
std::string predicated_asm_code = R"(
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)({smem_addr}))
    );
    int src_bytes = {pred_guard} ? {bytes} : 0;
    __asm__ __volatile__(
      "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
      :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
    );
  }
)";
```

If the condition is false, the code above does nothing, but the correct behavior is to write zero to the shared memory address; otherwise the contents of that memory are unpredictable and the result is incorrect. We fix this with a native shared memory store in the asm:

```c++
int pred_guard = (int){pred_guard};
__asm__ __volatile__(
  "{ .reg .pred p;"
  " setp.ne.b32 p, %0, 0;"
#if TVM_ENABLE_L2_PREFETCH
  " @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;"
#else
  " @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;"
#endif
  " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};}"
  :: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}),
     "r"(0), "r"(0), "r"(0), "r"(0)
);
```

Co-authored-by: leiwang1999 <[email protected]>
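For readers who want to see the fixed pattern outside of generated code, below is a minimal standalone CUDA sketch, not part of this PR, that copies one 16-byte vector per thread with cp.async when its guard holds and zero-fills the shared-memory destination otherwise. The kernel name copy_or_zero and the parameter valid_rows are illustrative, the L2 prefetch variant is omitted for brevity, and it assumes an sm_80+ device with 16-byte-aligned global pointers (e.g., from cudaMalloc).

```c++
#include <cuda_fp16.h>

// Hypothetical demonstration kernel (sm_80+); not part of the PR.
// Each thread handles one 16-byte vector (8 halfs). Threads whose guard is
// false still initialize their shared-memory slot, which is exactly the
// behavior the PR restores in the generated template.
__global__ void copy_or_zero(const half* __restrict__ src, half* dst, int valid_rows) {
  extern __shared__ __align__(16) half smem[];
  const int row = threadIdx.x;
  half* smem_ptr = smem + row * 8;          // 8 halfs == 16 bytes per thread
  const half* gmem_ptr = src + row * 8;
  int pred_guard = (row < valid_rows) ? 1 : 0;

  // Convert the generic shared-memory pointer to a 32-bit shared-space address.
  unsigned int addr;
  __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
      : "=r"(addr)
      : "l"((void*)smem_ptr));

  // Guarded async copy: cp.async when the predicate holds, zero store otherwise.
  __asm__ __volatile__(
      "{ .reg .pred p;"
      " setp.ne.b32 p, %0, 0;"
      " @p cp.async.ca.shared.global [%1], [%2], %3;"
      " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};}"
      :: "r"(pred_guard), "r"(addr), "l"((void*)gmem_ptr), "n"(16),
         "r"(0), "r"(0), "r"(0), "r"(0));

  // Wait for the async copies before reading shared memory.
  __asm__ __volatile__("cp.async.commit_group;");
  __asm__ __volatile__("cp.async.wait_group 0;" ::: "memory");
  __syncthreads();

  // Copy the now well-defined shared-memory tile back out for inspection.
  for (int i = 0; i < 8; ++i) {
    dst[row * 8 + i] = smem_ptr[i];
  }
}
```

Launched as, say, copy_or_zero<<<1, 32, 32 * 8 * sizeof(half)>>>(src, dst, 16), rows 16..31 of dst come back as zeros rather than whatever happened to be in shared memory, which is the observable difference between the old and the fixed templates.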