[TensorIR] Support for L2 prefetch async copy and pred_guard enabled async in vectorized if_then_else (#14329)

This pull request adds support for the L2 prefetch option of the `cp.async` instruction, which is available in CUDA 11.4 and later; the implementation follows the approach used in CUTLASS. It also adds support for asynchronous copy of vectorized `if_then_else` assignments, and fixes several related bugs.
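To gate the new instruction form on the toolkit version, the generated CUDA source is prefixed with a guard macro (this is the preamble added in `codegen_cuda.cc` below); the `.L2::128B` suffix is a prefetch-size hint that asks the memory system to fetch the surrounding 128-byte L2 line and does not change the copy's semantics:

```c++
#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
     (__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif
```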

For example, the original async copy lowering cannot handle vectorized `if_then_else`. Given a template such as:

```python
for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"):
    for ax0_ax1_0_fused_2 in T.thread_binding(2, thread="threadIdx.y"):
        for ax0_ax1_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
            with T.block("data_im2col_reindex_shared.dyn_o"):
                v0 = T.axis.spatial(512, x_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8)
                v1_o = T.axis.spatial(1440, k_0_0 * 8 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8)
                T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8:v1_o % 160 * 8 + 8])
                T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8:v1_o * 8 + 8])
                for ax1_1 in T.vectorized(8):
                    with T.block("data_im2col_reindex_shared.dyn"):
                        v1_i = T.axis.spatial(8, ax1_1)
                        T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i])
                        T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i])
                        T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]})
                        data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i] = T.if_then_else(1 <= v1_o // 480 + v0 % 256 // 16 and v1_o // 480 + v0 % 256 // 16 < 17 and 1 <= v1_o % 480 // 160 + v0 % 16 and v1_o % 480 // 160 + v0 % 16 < 17, A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i], T.float16(0))
```

This PR lowers it to the following CUDA code:

```c++
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
      : "=r"(addr)
      : "l"((void *)(buf_dyn_shmem + (((((ax0_ax1_0_fused_0 * 2304) + (((int)threadIdx.z) * 1152)) + (((int)threadIdx.y) * 576)) + ((((int)threadIdx.x) >> 3) * 144)) + ((((int)threadIdx.x) & 7) * 16))))
    );
    int pred_guard = (int)((((1 <= ((((((int)blockIdx.y) & 3) * 4) + (k_0_0 / 60)) + ax0_ax1_0_fused_0)) && (((((((int)blockIdx.y) & 3) * 4) + (k_0_0 / 60)) + ax0_ax1_0_fused_0) < 17)) && (1 <= ((((((int)threadIdx.z) * 8) + (((int)threadIdx.y) * 4)) + ((k_0_0 % 60) / 20)) + (((int)threadIdx.x) >> 3)))) && (((((((int)threadIdx.z) * 8) + (((int)threadIdx.y) * 4)) + ((k_0_0 % 60) / 20)) + (((int)threadIdx.x) >> 3)) < 17));
    __asm__ __volatile__(
        "{  .reg .pred p;"
        "  setp.ne.b32 p, %0, 0;"
      #if TVM_ENABLE_L2_PREFETCH
        " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;"
      #else
        " @p cp.async.ca.shared.global [%1], [%2], %3;"
      #endif
      "  @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};}"
        :: "r"(pred_guard), "r"(addr), "l"((void*)(A + (((((((((((int)blockIdx.y) * 81920) + ((k_0_0 / 60) * 20480)) + (ax0_ax1_0_fused_0 * 20480)) + (((int)threadIdx.z) * 10240)) + (((int)threadIdx.y) * 5120)) + ((((int)threadIdx.x) >> 3) * 1280)) + ((k_0_0 % 60) * 64)) + ((((int)threadIdx.x) & 7) * 8)) - 21760))), "n"(16), "r"(0), "r"(0), "r"(0),"r"(0)
    );
  }
```
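In the predicated path, `setp.ne.b32 p, %0, 0` materializes `pred_guard` in the predicate register `p`: when the guard holds, `@p` issues the (optionally L2-prefetching) `cp.async`; when it does not, `@!p` zero-fills the 16-byte destination with `st.shared.v4.u32`, so the shared buffer never holds stale data.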

The original `InjectPTXAsyncCopy` pass, by contrast, only supports serial, 4-byte-aligned assignments of `if_then_else`, and its current code generation also has a correctness bug:

```c++
  std::string predicated_asm_code = R"(
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)({smem_addr}))
    );
    int src_bytes = {pred_guard} ? {bytes} : 0;
    __asm__ __volatile__(
      "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
       :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
    );
  }
)";
```

If the condition is false, the code above does nothing. The right behavior is to store zero to the shared memory address; otherwise the memory contents are unpredictable, which makes the result incorrect. We fixed this with a native shared-memory store instruction:

```c++
    int pred_guard = (int){pred_guard};
    __asm__ __volatile__(
        "{  .reg .pred p;"
        "  setp.ne.b32 p, %0, 0;"
      #if TVM_ENABLE_L2_PREFETCH
        " @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;"
      #else
        " @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;"
      #endif
      "  @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};}"
        :: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(0), "r"(0), "r"(0),"r"(0)
    );
```
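As a rough illustration of the semantics (a synchronous sketch of the 16-byte case, not the emitted code; `predicated_copy_16` and its arguments are hypothetical names):

```c++
#include <cuda_runtime.h>

// Synchronous analogue of the predicated copy above (sketch only): when the
// guard holds, copy 16 bytes from global memory; otherwise zero-fill the
// destination, matching the "@!p st.shared.v4.u32" branch.
__device__ void predicated_copy_16(uint4* smem_dst, const uint4* gmem_src,
                                   bool pred_guard) {
  if (pred_guard) {
    *smem_dst = *gmem_src;               // stands in for cp.async
  } else {
    *smem_dst = make_uint4(0, 0, 0, 0);  // zero-fill, like st.shared.v4.u32
  }
}
```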

Co-authored-by: leiwang1999 <[email protected]>
LeiWang1999 authored Apr 4, 2023
1 parent 4d7e890 commit 44dd644
Showing 6 changed files with 600 additions and 44 deletions.
7 changes: 7 additions & 0 deletions src/target/source/codegen_cuda.cc
@@ -134,6 +134,13 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <mma.h>\n";
}

decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n";
decl_stream << " (__CUDACC_VER_MAJOR__ > 11))\n";
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n";
decl_stream << "#else\n";
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
decl_stream << "#endif\n";

decl_stream << "\n#ifdef _WIN32\n";
decl_stream << " using uint = unsigned int;\n";
decl_stream << " using uchar = unsigned char;\n";
46 changes: 40 additions & 6 deletions src/target/source/ptx.cc
@@ -645,8 +645,12 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
: "l"((void *)({smem_addr}))
);
__asm__ __volatile__(
"cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
#if TVM_ENABLE_L2_PREFETCH
"cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;"
#else
"cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
#endif
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
);
}
)";
@@ -665,26 +669,56 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
const std::string& global_elem_offset,
const std::string& bytes,
const std::string& predicate_value) {
CHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || bytes == "2" ||
bytes == "1")
<< "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async";
std::string predicated_asm_code = R"(
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
: "=r"(addr)
: "l"((void *)({smem_addr}))
);
int src_bytes = {pred_guard} ? {bytes} : 0;
int pred_guard = (int){pred_guard};
__asm__ __volatile__(
"cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
"{ .reg .pred p;"
" setp.ne.b32 p, %0, 0;"
#if TVM_ENABLE_L2_PREFETCH
" @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;"
#else
" @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;"
#endif
" @!p {store_shared};}"
:: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), {nopreg}
);
}
)";
auto [store_shared, nopreg] = [](const std::string& bytes) {
if (bytes == "16")
return std::make_tuple("st.shared.v4.u32 [%1], {%4, %5, %6, %7}",
"\"r\"(0), \"r\"(0), \"r\"(0),\"r\"(0)");
else if (bytes == "12")
return std::make_tuple("st.shared.v3.u32 [%1], {%4, %5, %6}", "\"r\"(0), \"r\"(0), \"r\"(0)");
else if (bytes == "8")
return std::make_tuple("st.shared.v2.u32 [%1], {%4, %5}", "\"r\"(0), \"r\"(0)");
else if (bytes == "4")
return std::make_tuple("st.shared.u32 [%1], {%4}", "\"r\"(0)");
else if (bytes == "2")
return std::make_tuple("st.shared.u16 [%1], {%4}", "\"r\"(0)");
else if (bytes == "1")
return std::make_tuple("st.shared.u8 [%1], {%4}", "\"r\"(0)");
else
return std::make_tuple("", "");
}(bytes);

Replacer replacer;
replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
replacer.register_rule("{bytes}", bytes);
replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
replacer.register_rule("{store_shared}", store_shared);
replacer.register_rule("{nopreg}", nopreg);
replacer.register_rule("{pred_guard}", predicate_value);
predicated_asm_code = replacer.rewrite(predicated_asm_code);
return predicated_asm_code;
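For reference, a hypothetical call to the rewritten helper (all argument strings here are made up; real callers pass expression strings produced by the CUDA codegen):

```c++
// Emit a predicated 16-byte cp.async; the offsets and guard are C expression
// strings that get spliced into the inline-asm template above.
std::string asm_code = PrintPredicatedCpAsyncAssembly(
    "buf_dyn_shmem", "dst_offset", "A", "src_offset", "16", "pred_guard");
```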
31 changes: 30 additions & 1 deletion src/tir/transforms/inject_ptx_async_copy.cc
@@ -112,12 +112,41 @@ class PTXAsyncCopyInjector : public StmtMutator {
}
return PrimExpr();
}();

if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
}
} else {
// Only some vectorized indexing patterns are supported for now.
auto src_offset = [=]() -> PrimExpr {
if (load->indices[0]->IsInstance<RampNode>()) {
return load->indices[0].as<RampNode>()->base;
}
return PrimExpr();
}();

auto dst_offset = [=]() -> PrimExpr {
if (store->indices[0].as<RampNode>()) {
return store->indices[0].as<RampNode>()->base;
} else if (store->indices[0].as<AddNode>()) {
// The case where the dst buffer is a byte buffer generated by merging dynamic
// shared memory.
// A_shared.dyn[ramp(..., 1, 8) + x8(17408)] = A_global[ramp(..., 1, 8)]
auto* add = store->indices[0].as<AddNode>();
if (!add->a->IsInstance<RampNode>()) return PrimExpr();
if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value);
}
return PrimExpr();
}();

if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(
Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes), predicate_value}));
}
}
}
}
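A minimal sketch of the destination index shape the new branch accepts (hypothetical values, built with TVM's C++ expression API):

```c++
#include <tvm/tir/expr.h>

using namespace tvm;
using namespace tvm::tir;

// Index of the form Ramp(base, 1, lanes) + Broadcast(offset, lanes), as
// produced when dynamic shared memory buffers are merged into a byte buffer.
PrimExpr MakeMergedDynIndex(PrimExpr base) {
  PrimExpr ramp = Ramp(base, IntImm(DataType::Int(32), 1), 8);
  PrimExpr offset = Broadcast(IntImm(DataType::Int(32), 17408), 8);
  // The pass recovers the scalar dst offset as ramp->base + broadcast->value.
  return Add(ramp, offset);
}
```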
5 changes: 3 additions & 2 deletions src/tir/transforms/lower_async_dma.cc
@@ -43,17 +43,18 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer)
: IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {}

// TODO(leiwang1999): split lower async DMA support for CUDA and Hexagon Backend
Stmt VisitStmt_(const ForNode* loop) final {
// if for loop is not within async_commit_queue_scope
if (!async_queue_id_.has_value()) {
return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
}

// if for loop is not a memcpy of a contiguous region
// if for loop is not a memcpy of a contiguous region, it might be a cuda cp.async behavior
std::optional<tvm::tir::MemCpyDetails> mem_copy = IdentifyMemCpy(GetRef<For>(loop), analyzer_);
if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
mem_copy->source->region.size() != 1) {
LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access";
return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
}

// now that we are about to perform the `copy` transform
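With this change, a loop inside an async scope that is not a contiguous memcpy falls through to the default mutator, where the CUDA `cp.async` path can handle it, instead of aborting with `LOG(FATAL)`; the Hexagon test that asserted the old fatal error is therefore removed below.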
18 changes: 0 additions & 18 deletions tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -879,23 +879,5 @@ def test_meta(hexagon_session):
)


def test_non_contiguous():
"""Test Non Contiguous memory lowering."""
sch = tvm.tir.Schedule(conv2d_async_non_contig)
target_hexagon = tvm.target.hexagon("v68", link_params=True)
err_rgx = r"Unable to lower async dma due to non contiguous memory access"
# Currently we do not support non contiguous memory access being lowered to
# async dma so we throw an error.
with pytest.raises(tvm.TVMError, match=err_rgx):
with tvm.transform.PassContext(
config={
"tir.use_async_copy": 1,
}
):
tvm.build(
sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)
)


if __name__ == "__main__":
tvm.testing.main()