[TensorIR] Support for L2 prefetch async copy and pred_guard enabled async in vectorized if_then_else #14329

Merged · 8 commits · Apr 4, 2023
7 changes: 7 additions & 0 deletions src/target/source/codegen_cuda.cc
@@ -134,6 +134,13 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <mma.h>\n";
}

decl_stream << "\n#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \\\n";
decl_stream << " (__CUDACC_VER_MAJOR__ > 11))\n";
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 1\n";
decl_stream << "#else\n";
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
decl_stream << "#endif\n";

decl_stream << "\n#ifdef _WIN32\n";
decl_stream << " using uint = unsigned int;\n";
decl_stream << " using uchar = unsigned char;\n";
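For reference, the preamble these decl_stream calls emit at the top of every generated CUDA source file looks roughly as sketched below; the `.L2::128B` cache hint used later in the patch is only understood by nvcc 11.4 and newer, which is what the version guard checks for.

```cuda
// Sketch of the preamble emitted by CodeGenCUDA::Finish(), assembled from the
// decl_stream lines above; TVM_ENABLE_L2_PREFETCH gates the L2 prefetch variant.
#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
     (__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif
```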
46 changes: 40 additions & 6 deletions src/target/source/ptx.cc
@@ -645,8 +645,12 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
: "l"((void *)({smem_addr}))
);
__asm__ __volatile__(
- "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
- :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
+ #if TVM_ENABLE_L2_PREFETCH
+ "cp.async.{cg_or_ca}.shared.global.L2::128B [%0], [%1], %2;"
+ #else
+ "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
+ #endif
+ :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
);
}
)";
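To illustrate the substitution, here is a minimal sketch of what the second inline-asm block above expands to for a hypothetical 16-byte copy (the names `A_global` and `global_off` are invented for the example; `{cg_or_ca}` resolves to `cg` for 16-byte copies):

```cuda
// Minimal sketch of the generated cp.async call after Replacer substitution
// (hypothetical names; addr is the shared-memory address converted by the
// first asm block above).
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
    "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
    "cp.async.cg.shared.global [%0], [%1], %2;"
#endif
    :: "r"(addr), "l"((void*)(A_global + global_off)), "n"(16)
);
```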
@@ -665,26 +669,56 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
const std::string& global_elem_offset,
const std::string& bytes,
const std::string& predicate_value) {
CHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || bytes == "2" ||
bytes == "1")
<< "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async";
std::string predicated_asm_code = R"(
{
unsigned int addr;
__asm__ __volatile__(
- "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
: "=r"(addr)
: "l"((void *)({smem_addr}))
);
- int src_bytes = {pred_guard} ? {bytes} : 0;
+ int pred_guard = (int){pred_guard};
__asm__ __volatile__(
- "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
- :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
+ "{ .reg .pred p;"
+ " setp.ne.b32 p, %0, 0;"
+ #if TVM_ENABLE_L2_PREFETCH
+ " @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;"
+ #else
+ " @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;"
+ #endif
+ " @!p {store_shared};}"
+ :: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), {nopreg}
);
}
)";
auto [store_shared, nopreg] = [](const std::string& bytes) {
if (bytes == "16")
return std::make_tuple("st.shared.v4.u32 [%1], {%4, %5, %6, %7}",
"\"r\"(0), \"r\"(0), \"r\"(0),\"r\"(0)");
else if (bytes == "12")
return std::make_tuple("st.shared.v3.u32 [%1], {%4, %5, %6}", "\"r\"(0), \"r\"(0), \"r\"(0)");
else if (bytes == "8")
return std::make_tuple("st.shared.v2.u32 [%1], {%4, %5}", "\"r\"(0), \"r\"(0)");
else if (bytes == "4")
return std::make_tuple("st.shared.u32 [%1], {%4}", "\"r\"(0)");
else if (bytes == "2")
return std::make_tuple("st.shared.u16 [%1], {%4}", "\"r\"(0)");
else if (bytes == "1")
return std::make_tuple("st.shared.u8 [%1], {%4}", "\"r\"(0)");
else
return std::make_tuple("", "");
}(bytes);

Replacer replacer;
replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
replacer.register_rule("{bytes}", bytes);
replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
replacer.register_rule("{store_shared}", store_shared);
replacer.register_rule("{nopreg}", nopreg);
replacer.register_rule("{pred_guard}", predicate_value);
predicated_asm_code = replacer.rewrite(predicated_asm_code);
return predicated_asm_code;
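Combining the template with the lambda above, a predicated 16-byte copy expands to roughly the sketch below (the pointer names, offsets, and guard variable are hypothetical). When the guard is false, the `@!p st.shared.v4.u32` arm zero-fills the destination in shared memory, replacing the previous approach of issuing `cp.async` with a source size of zero.

```cuda
// Sketch of PrintPredicatedCpAsyncAssembly output for bytes == "16"
// (store_shared = st.shared.v4.u32, nopreg = four zero operands).
{
  unsigned int addr;
  __asm__ __volatile__(
    "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
    : "=r"(addr)
    : "l"((void *)(A_shared + shared_off))
  );
  int pred_guard = (int)in_bounds;  // non-zero when the guarded element may be copied
  __asm__ __volatile__(
    "{ .reg .pred p;"
    " setp.ne.b32 p, %0, 0;"
#if TVM_ENABLE_L2_PREFETCH
    " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;"
#else
    " @p cp.async.cg.shared.global [%1], [%2], %3;"
#endif
    " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};}"
    :: "r"(pred_guard), "r"(addr), "l"((void*)(A_global + global_off)), "n"(16),
       "r"(0), "r"(0), "r"(0), "r"(0)
  );
}
```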
31 changes: 30 additions & 1 deletion src/tir/transforms/inject_ptx_async_copy.cc
@@ -112,12 +112,41 @@ class PTXAsyncCopyInjector : public StmtMutator {
}
return PrimExpr();
}();

if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
}
} else {
// Only some vectorized indexing patterns are supported for now.
auto src_offset = [=]() -> PrimExpr {
if (load->indices[0]->IsInstance<RampNode>()) {
return load->indices[0].as<RampNode>()->base;
}
return PrimExpr();
}();

auto dst_offset = [=]() -> PrimExpr {
if (store->indices[0].as<RampNode>()) {
return store->indices[0].as<RampNode>()->base;
} else if (store->indices[0].as<AddNode>()) {
// The case where the dst buffer is a byte buffer generated by merging dynamic
// shared memory.
// e.g. A_shared.dyn[ramp(..., 1, 8) + x8(17408)] = A_global[ramp(..., 1, 8)]
auto* add = store->indices[0].as<AddNode>();
if (!add->a->IsInstance<RampNode>()) return PrimExpr();
if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
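// e.g. for the ramp + x8(17408) pattern above, the lowered offset is the ramp base plus 17408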
return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value);
}
return PrimExpr();
}();

if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(
Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes), predicate_value}));
}
}
}
}
5 changes: 3 additions & 2 deletions src/tir/transforms/lower_async_dma.cc
@@ -43,17 +43,18 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer)
: IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {}

// TODO(leiwang1999): split lower async DMA support for CUDA and Hexagon Backend
Stmt VisitStmt_(const ForNode* loop) final {
// if for loop is not within async_commit_queue_scope
if (!async_queue_id_.has_value()) {
return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
}

- // if for loop is not a memcpy of a contiguous region
+ // if the for loop is not a memcpy of a contiguous region, it might be a CUDA cp.async pattern
std::optional<tvm::tir::MemCpyDetails> mem_copy = IdentifyMemCpy(GetRef<For>(loop), analyzer_);
if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
mem_copy->source->region.size() != 1) {
- LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access";
+ return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
}

// now that we are about to perform the `copy` transform
18 changes: 0 additions & 18 deletions tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -879,23 +879,5 @@ def test_meta(hexagon_session):
)


- def test_non_contiguous():
- """Test Non Contiguous memory lowering."""
- sch = tvm.tir.Schedule(conv2d_async_non_contig)
- target_hexagon = tvm.target.hexagon("v68", link_params=True)
- err_rgx = r"Unable to lower async dma due to non contiguous memory access"
- # Currently we do not support non contiguous memory access being lowered to
- # async dma so we throw an error.
- with pytest.raises(tvm.TVMError, match=err_rgx):
- with tvm.transform.PassContext(
- config={
- "tir.use_async_copy": 1,
- }
- ):
- tvm.build(
- sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)
- )


if __name__ == "__main__":
tvm.testing.main()