[BUG] TMA Cooperative GeMM with Stream-K scheduler hangs #1917

Open
NihalPotdar opened this issue Nov 4, 2024 · 10 comments
Labels
? - Needs Triage bug Something isn't working

Comments

@NihalPotdar

Describe the bug
Gemm kernels with the following configurations hang for specific gemm shapes.

Type: uint4_t * half_t
Tile: m=16,n=2560,k=8192
Cluster: 1x1x1
Kernel Schedule: KernelTmaWarpSpecializedCooperative
Epilogue Schedule: TmaWarpSpecializedCooperative
Tile Scheduler: Stream-K

Expected behavior

With CUTLASS 3.X, the call to this kernel just hangs, with no other changes. This is not expected.

Based on #1801, this should have been resolved.

@NihalPotdar NihalPotdar added ? - Needs Triage bug Something isn't working labels Nov 4, 2024
@thakkarV
Collaborator

thakkarV commented Nov 4, 2024

@jackkosaian

@jackkosaian
Contributor

Hi, @NihalPotdar. Can you please provide a reproducer for this bug?

@NihalPotdar
Author

@jackkosaian Set the scheduler in examples/55_hopper_mixed_dtype_gemm.cu to cutlass::gemm::StreamKScheduler and then run ./55_hopper_int4_bf16_gemm --m=16 --n=2560 --k=8192 --mode=1

@thakkarV
Collaborator

thakkarV commented Nov 4, 2024

Please include your build flags and full steps to repro, starting at a checkout of the repo. We find that users often do not use the flags generated by our build system. Please also provide your CUDA toolkit version.

@NihalPotdar
Author

  1. clone cutlass, https://github.com/NVIDIA/cutlass
  2. follow https://github.com/NVIDIA/cutlass/blob/main/README.md to generate build files with cmake .. -DCUTLASS_NVCC_ARCHS=sm90a
  3. update 55_hopper_mixed_dtype_gemm.cu to use the stream-K scheduler, e.g.:
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>, // Indicates ProblemShape
    CollectiveMainloopScaleOnly,
    CollectiveEpilogue,
    cutlass::gemm::StreamKScheduler
>;
  4. run make from build/examples/55_hopper_mixed_dtype_gemm
  5. call ./55_hopper_int4_bf16_gemm --m=16 --n=2560 --k=8192 --mode=1

@jackkosaian
Contributor

Thanks for the detailed steps.

It looks like that example is not calling gemm.initialize() before each run of the GEMM in the profiling loop here.

This is required for stream-K in order to initialize counters used for coordinating inter-CTA reduction. If these are not properly initialized, stream-K is likely to hang.

Can you please try changing the loop linked above to be the following?

    for (int iter = 0; iter < options.iterations; ++iter) {
      CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); // Added this line
      CUTLASS_CHECK(gemm.run());
    }
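
To illustrate why stale coordination counters can cause a hang, here is a minimal host-side sketch in plain C++. It is a toy model only, not CUTLASS code: `ToyWorkspace`, `toy_run`, and the "wait until the counter equals my index" rule are hypothetical stand-ins chosen for illustration, and the actual stream-K synchronization may differ in its details.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the scheduler workspace: one arrival counter per output tile.
struct ToyWorkspace {
  std::vector<int> arrivals;
  explicit ToyWorkspace(std::size_t tiles) : arrivals(tiles, 0) {}
  // Plays the role of re-initializing the workspace before a run.
  void initialize() { arrivals.assign(arrivals.size(), 0); }
};

// Simulates one GEMM launch in which each tile is reduced by `units_per_tile` units.
// Unit i may proceed only once the tile's counter equals i, then increments it.
// Returns true if every unit could proceed, false if some unit would wait forever.
bool toy_run(ToyWorkspace& ws, int units_per_tile) {
  assert(units_per_tile > 0);
  for (std::size_t tile = 0; tile < ws.arrivals.size(); ++tile) {
    for (int unit = 0; unit < units_per_tile; ++unit) {
      if (ws.arrivals[tile] != unit) {
        return false;  // on a GPU this would be a spin wait that never exits
      }
      ++ws.arrivals[tile];
    }
  }
  return true;
}
```

In this model the first `toy_run` on a fresh workspace succeeds, a second call without `initialize()` reports a would-be hang because the counters still hold the values left by the previous run, and calling `initialize()` again restores correct behavior.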

@NihalPotdar
Author

@jackkosaian That works, thank you. However, for smaller problem shapes like this one, where the atomic write overhead becomes significant, using a separate reduction wave would make a lot of sense. I noticed that in the CUTLASS implementation, https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/tile_scheduler_params.h#L1051, this is currently turned off.

Are there any plans to fix this, and what's the ETA? Do you have any suggested workarounds in the meantime?

@jackkosaian
Contributor

The current plan is to improve reduction performance in 3.7.

@NihalPotdar
Author

NihalPotdar commented Nov 8, 2024

@jackkosaian Sounds good. I am also seeing correctness issues with the existing implementation: when the output is compared with fp8, as in build/examples/55_hopper_mixed_dtype_gemm, the comparison fails only when using stream-K. Any thoughts on why this might be the case?

@jackkosaian
Contributor

How are tensors initialized? With random floating-point values, or with random integer values within some tight range?

How is error checking being performed? Exact match or relative?

Since stream-K involves splitting a GEMM along the K mode, it can accumulate results in a different order than a non-stream-K GEMM. Since floating point addition is not associative, the results can be different.
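
As a rough illustration of the point about accumulation order, the following self-contained C++ sketch sums the same values once sequentially and once in K-style splits, and shows a tolerance-based comparison. The helper names (`sum_sequential`, `sum_split`, `nearly_equal`) and the tolerance values are assumptions made for this example; they are not taken from CUTLASS or from the example's verification code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Accumulate all values in order, as a single pass over K would.
float sum_sequential(const std::vector<float>& v) {
  float acc = 0.0f;
  for (float x : v) acc += x;
  return acc;
}

// Accumulate each contiguous chunk separately, then reduce the partial sums,
// loosely mimicking a GEMM whose K dimension is split across workers.
float sum_split(const std::vector<float>& v, std::size_t splits) {
  assert(splits > 0);
  const std::size_t chunk = (v.size() + splits - 1) / splits;
  float total = 0.0f;
  for (std::size_t s = 0; s < splits; ++s) {
    const std::size_t begin = s * chunk;
    const std::size_t end = std::min(begin + chunk, v.size());
    float partial = 0.0f;
    for (std::size_t i = begin; i < end; ++i) partial += v[i];
    total += partial;
  }
  return total;
}

// Combined absolute/relative tolerance check, instead of an exact match.
bool nearly_equal(float a, float b, float rel_tol, float abs_tol) {
  const float scale = std::fmax(std::fabs(a), std::fabs(b));
  return std::fabs(a - b) <= abs_tol + rel_tol * scale;
}
```

With IEEE-754 single precision, the input `{1e8f, 1.0f, -1e8f, 1.0f}` sums to `1.0f` sequentially but to `0.0f` when split in two, because each `1.0f` is absorbed when added to a value of magnitude `1e8f`. For well-scaled data, such as 1024 copies of `0.1f`, the two orders agree within a small relative tolerance even if they do not match bit for bit, which is why a relative check is usually more appropriate than an exact one here.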


3 participants