
[TensorIR] Support for L2 prefetch async copy and pred_guard enabled async in vectorized if_then_else #14329

Merged
8 commits merged into apache:main on Apr 4, 2023

Conversation

LeiWang1999
Contributor

This pull request adds support for the L2 prefetch option of the cp.async instruction, which is available in CUDA 11.4 and later; the implementation follows the approach used in CUTLASS. It also adds support for asynchronous copy of if_then_else under vectorization, and fixes some bugs.
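As a rough sketch (an assumption, not necessarily how the codegen actually defines it), the TVM_ENABLE_L2_PREFETCH macro seen in the generated kernel below could be gated on the CUDA toolkit version, since the .L2::128B modifier is only accepted by CUDA 11.4 and later:

  // Hypothetical sketch: turn on the L2::128B form of cp.async only when the
  // CUDA toolkit is new enough (11.4+) to accept the prefetch modifier.
  #if defined(__CUDACC_VER_MAJOR__) && \
      (__CUDACC_VER_MAJOR__ > 11 ||    \
       (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4))
  #define TVM_ENABLE_L2_PREFETCH 1
  #else
  #define TVM_ENABLE_L2_PREFETCH 0
  #endif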

For example, the original async copy lowering cannot handle a vectorized if_then_else. Given the following template:

for ax0_ax1_0_fused_1 in T.thread_binding(2, thread="threadIdx.z"):
    for ax0_ax1_0_fused_2 in T.thread_binding(2, thread="threadIdx.y"):
        for ax0_ax1_0_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
            with T.block("data_im2col_reindex_shared.dyn_o"):
                v0 = T.axis.spatial(512, x_0_0 * 64 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) // 8)
                v1_o = T.axis.spatial(1440, k_0_0 * 8 + (ax0_ax1_0_fused_0 * 128 + ax0_ax1_0_fused_1 * 64 + ax0_ax1_0_fused_2 * 32 + ax0_ax1_0_fused_3) % 8)
                T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8:v1_o % 160 * 8 + 8])
                T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8:v1_o * 8 + 8])
                for ax1_1 in T.vectorized(8):
                    with T.block("data_im2col_reindex_shared.dyn"):
                        v1_i = T.axis.spatial(8, ax1_1)
                        T.reads(A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i])
                        T.writes(data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i])
                        T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]})
                        data_im2col_reindex_shared_dyn[v0, v1_o * 8 + v1_i] = T.if_then_else(1 <= v1_o // 480 + v0 % 256 // 16 and v1_o // 480 + v0 % 256 // 16 < 17 and 1 <= v1_o % 480 // 160 + v0 % 16 and v1_o % 480 // 160 + v0 % 16 < 17, A[v0 // 256, v1_o // 480 + v0 % 256 // 16 - 1, v1_o % 480 // 160 + v0 % 16 - 1, v1_o % 160 * 8 + v1_i], T.float16(0))

this PR lowers the code into:

  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
      : "=r"(addr)
      : "l"((void *)(buf_dyn_shmem + (((((ax0_ax1_0_fused_0 * 2304) + (((int)threadIdx.z) * 1152)) + (((int)threadIdx.y) * 576)) + ((((int)threadIdx.x) >> 3) * 144)) + ((((int)threadIdx.x) & 7) * 16))))
    );
    int pred_guard = (int)((((1 <= ((((((int)blockIdx.y) & 3) * 4) + (k_0_0 / 60)) + ax0_ax1_0_fused_0)) && (((((((int)blockIdx.y) & 3) * 4) + (k_0_0 / 60)) + ax0_ax1_0_fused_0) < 17)) && (1 <= ((((((int)threadIdx.z) * 8) + (((int)threadIdx.y) * 4)) + ((k_0_0 % 60) / 20)) + (((int)threadIdx.x) >> 3)))) && (((((((int)threadIdx.z) * 8) + (((int)threadIdx.y) * 4)) + ((k_0_0 % 60) / 20)) + (((int)threadIdx.x) >> 3)) < 17));
    __asm__ __volatile__(
        "{  .reg .pred p;"
        "  setp.ne.b32 p, %0, 0;"
      #if TVM_ENABLE_L2_PREFETCH
        " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;"
      #else
        " @p cp.async.ca.shared.global [%1], [%2], %3;"
      #endif
      "  @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};}"
        :: "r"(pred_guard), "r"(addr), "l"((void*)(A + (((((((((((int)blockIdx.y) * 81920) + ((k_0_0 / 60) * 20480)) + (ax0_ax1_0_fused_0 * 20480)) + (((int)threadIdx.z) * 10240)) + (((int)threadIdx.y) * 5120)) + ((((int)threadIdx.x) >> 3) * 1280)) + ((k_0_0 % 60) * 64)) + ((((int)threadIdx.x) & 7) * 8)) - 21760))), "n"(16), "r"(0), "r"(0), "r"(0),"r"(0)
    );
  }

The original InjectPTXAsyncCopy pass only supports a serial, 4-byte-aligned assignment of if_then_else, and the current code also has some bugs:

  std::string predicated_asm_code = R"(
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)({smem_addr}))
    );
    int src_bytes = {pred_guard} ? {bytes} : 0;
    __asm__ __volatile__(
      "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
       :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
    );
  }
)";

If the condition is false, the code above does nothing. The correct behavior is to store zero to the shared memory address; otherwise the memory contents are unpredictable, which makes the result incorrect. We fixed this with a native shared-memory store in the asm:

    int pred_guard = (int){pred_guard};
    __asm__ __volatile__(
        "{  .reg .pred p;"
        "  setp.ne.b32 p, %0, 0;"
      #if TVM_ENABLE_L2_PREFETCH
        " @p cp.async.{cg_or_ca}.shared.global.L2::128B [%1], [%2], %3;"
      #else
        " @p cp.async.{cg_or_ca}.shared.global [%1], [%2], %3;"
      #endif
      "  @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};}"
        :: "r"(pred_guard), "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(0), "r"(0), "r"(0),"r"(0)
    );

@tvm-bot
Collaborator

tvm-bot commented Mar 18, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: tensorir. See #10317 for details.

Generated by tvm-bot

@Hzfengsy
Member

Thanks @LeiWang1999 for the improvement. Could you please add some performance benchmarks as it is a perf-related PR :)

@LeiWang1999
Contributor Author

@Hzfengsy I'll have some updates next week. : )

@junrushao
Member

That makes a lot of sense. Thanks @LeiWang1999 for the bugfix! In addition, I am curious too if L2 prefetch could really help with performance :-)

@cblmemo @andy-yang-1 PTAL

@junrushao junrushao changed the title [Tensor IR][Bugfix] Support for L2 prefetch async copy and pred_guard enabled async in vectorized if_then_else [TensorIR] Support for L2 prefetch async copy and pred_guard enabled async in vectorized if_then_else Mar 18, 2023
@andy-yang-1
Contributor

This is an enhanced version of the ptx_ldg32 pass. I didn’t think about using async when I wrote ptx_ldg32, but cp.async needs to be careful about the GPU architecture. Some architectures cannot use cp.async. Can we distinguish between different GPU architectures?
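For context, here is a minimal sketch of the kind of architecture guard being discussed; it is illustrative only and not part of this PR, and the helper name is made up. cp.async is only available on sm_80 (Ampere) and newer, so older architectures would need a plain copy:

  // Illustrative helper (hypothetical): copy 16 bytes from global to shared
  // memory, using cp.async on sm_80+ and a synchronous fallback elsewhere.
  __device__ __forceinline__ void copy_16_bytes_to_shared(void* smem_dst, const void* gmem_src) {
  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
    unsigned int addr;
    __asm__ __volatile__(
        "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }"
        : "=r"(addr)
        : "l"(smem_dst));
    __asm__ __volatile__(
        "cp.async.ca.shared.global [%0], [%1], 16;"
        :: "r"(addr), "l"(gmem_src));
  #else
    // Pre-Ampere fallback: a plain 16-byte copy through registers.
    *static_cast<uint4*>(smem_dst) = *static_cast<const uint4*>(gmem_src);
  #endif
  }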

@LeiWang1999
Contributor Author

@andy-yang-1 Yeah, I noticed there is a pass named ptx_ldg32, but it seems to me that this pass is not complete yet: ldg32 only loads 4 bytes from global memory, while we sometimes need ldg64 and ldg128 for more efficient loads and stores on the GPU. As for the second consideration, I don't think this PR should take it into account, partly because support for asynchronous copy already exists in the current TensorIR, and partly because the pass is not enabled by default, so users have to annotate and enable it manually; I think this is handled more comfortably at the Python interface level.

@andy-yang-1
Contributor

@LeiWang1999 Yeah, you are right. The ptx_ldg32 pass only supports ldg32 instruction for loading 4 bytes from global memory. I will also add support for them in the future 😉 I confused this with the automatic async work done by Tian. This work does not need to consider the architecture issue. Thank you very much for your feedback!

@LeiWang1999
Contributor Author

LeiWang1999 commented Mar 22, 2023

Microbenchmark

  • Test device: A100 (async copy works better on devices with huge bandwidth, like the A100 or H100 GPUs)
  • CUDA version: 12.0
  • Tested diffusion Conv2d shapes
  • Tuner: none (scheduled by hand to see the performance influence; the schedule is not optimal)
  • Diffusion Conv2d benchmark of vectorized if_then_else async copy (nhwc_nhwc, fp16 precision, TensorCore enabled):
|     | N | C    | H  | W  | CO   | K | S | D | P |
|-----|---|------|----|----|------|---|---|---|---|
| C8  | 2 | 640  | 64 | 64 | 640  | 3 | 1 | 1 | 1 |
| C11 | 2 | 960  | 32 | 32 | 640  | 3 | 1 | 1 | 1 |
| C13 | 2 | 1280 | 32 | 32 | 1280 | 3 | 1 | 1 | 1 |

Performance (the vectorized weight load is asynchronous in both cases):

|     | without vectorized if_then_else async (ms) | with vectorized if_then_else async (ms) |
|-----|--------------------------------------------|-----------------------------------------|
| C8  | 0.779605                                   | 0.543061                                |
| C11 | 0.544085                                   | 0.356011                                |
| C13 | 0.267264                                   | 0.218794                                |
  • GEMM to see the influence of the :l2 cache intrinsic (float16-float16-row-col-tensorcore):

|       | M     | N     | K     | without :l2 (ms) | with :l2 (ms) |
|-------|-------|-------|-------|------------------|---------------|
| GEM-0 | 256   | 256   | 256   | 0.020821         | 0.020821      |
| GEM-1 | 16384 | 16384 | 16384 | 45.1103          | 45.1103       |

The :l2 result was as I expected; in my previous tests I didn't really see any impact on performance either. I leverage the L2 cache in a different way, and I will wait until that is ready before submitting another pull request.

Even though this :l2 feature shows no effect in these tests, I think it is still worth adding, because the kernels of state-of-the-art libraries like CUTLASS and cuBLAS implement it.


This may need more discussion.

@junrushao
Member

Very cool results @LeiWang1999!! Please fix the CI and let’s get this PR in!

@LeiWang1999
Contributor Author

I got all tests passing on my branch. However, commit c6c89c3, which was merged into main two days ago:

std::optional<tvm::tir::MemCpyDetails> mem_copy = IdentifyMemCpy(GetRef<For>(loop), analyzer_);
if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
    mem_copy->source->region.size() != 1) {
  LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access";
}

adds some constraints to the DMA lowering flow. It seems this change was made for the Hexagon backend, but it also affects CUDA codegen.

CC~ @junrushao @adstraw

@adstraw
Contributor

adstraw commented Mar 22, 2023

The same TIR annotation tir.use_async_copy is used to trigger both the InjectPTXAsyncCopy pass for CUDA codegen and the LowerAsyncDMA pass used by Hexagon. This TIR annotation and behavior are legacy --- they are NOT changed by commit c6c89c3. See here and here.

The advantage of reusing the tir.use_async_copy TIR annotation is that we have a converged way of handling async copy across multiple devices, with async copies lowering to PTX for CUDA and to DMA, e.g. for Hexagon.

The disadvantage of reusing the tir.use_async_copy TIR annotation is that (I believe) BOTH the InjectPTXAsyncCopy and LowerAsyncDMA passes run during CUDA codegen. This worked without issue in the past, but it seems that the recent changes in commit c6c89c3 combined with this PR expose an issue.

LowerAsyncDMA is meant to be a generic pass to lower async copies to DMA but it should probably have a target specific opt-in for devices (like Hexagon) that support this behavior rather than running for all devices. If it is within scope for this PR to make that change, please do. If not, please feel free to revert c6c89c3 so this PR can proceed.

@LeiWang1999 LeiWang1999 force-pushed the lei/fix_async_if_then_else branch from 35a2d3f to 630cdd5 Compare March 31, 2023 10:48
@LeiWang1999
Contributor Author

I think this PR is ready to be merged once I resolve the last lint CI failure. (The lint CI told me which file does not follow the format rules, but I couldn't find any script that auto-formats the file, like clang-format or pep8, so I may need some help.) @junrushao

@adstraw I have a tricky fix for the DMA lowering conflict between the CUDA backend and Hexagon in the latest commit e08fff9aeca293325d1c81734dea2543ddd7ce26: non-contiguous memory access is no longer treated as illegal behavior and instead falls back to the case that CUDA requires. This change seems to work for now, but in the long term I think we should separate the asynchronous copy path of the GPU from the DMA path.
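For illustration, a rough sketch of that kind of fallback (an assumption, not the exact diff in e08fff9; the base-class call is hypothetical): instead of LOG(FATAL), the DMA lowering can skip loops whose copies are not a single contiguous region and leave them for CUDA codegen.

  // Hypothetical fallback sketch: when the copy is not one contiguous region,
  // do not lower it to a DMA; keep the loop so a later pass (e.g.
  // InjectPTXAsyncCopy on CUDA) can still handle it.
  std::optional<tvm::tir::MemCpyDetails> mem_copy = IdentifyMemCpy(GetRef<For>(loop), analyzer_);
  if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
      mem_copy->source->region.size() != 1) {
    return StmtExprMutator::VisitStmt_(loop);  // assumed base class; fall back instead of aborting
  }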

@junrushao
Member

junrushao commented Apr 2, 2023

@LeiWang1999 To run linters locally, below are some commands you may find useful:

# Assume you are now at TVM's root directory

# Run all linters
bash ./tests/scripts/task_lint.sh
# Run python's formatter: black. CI uses black 22.12.0.
bash ./tests/lint/git-black.sh
# Run c++'s formatter: clang-format
bash ./tests/lint/git-clang-format.sh
# Run whitespace checking, another linter unhappy with your code
bash ./tests/lint/whitespace.sh

You might need to ensure the versions of black/clang-format are consistent with our CI, which can be a bit annoying at times.

@LeiWang1999
Contributor Author

CC @junrushao @Hzfengsy , this pr is ready I think.

@junrushao
Member

Please remove code instead of commenting it out :-)

@junrushao
Member

junrushao left a comment

I'm merging this in as there is no objection :-)

@junrushao junrushao merged commit 44dd644 into apache:main Apr 4, 2023