[BUG] TMA Cooperative GeMM with Stream-K scheduler hangs for specific gemm shapes #1801

Open
Algy opened this issue Sep 10, 2024 · 6 comments

Algy (Contributor) commented Sep 10, 2024

Describe the bug

GEMM kernels with the following configuration hang for specific GEMM shapes.

  • Type: e4m3 x e4m3 -> bf16
  • Tile: 256x32x128
  • Cluster: 2x1x1
  • Kernel Schedule: KernelTmaWarpSpecializedCooperative
  • Epilogue Schedule: TmaWarpSpecializedCooperative
  • Tile Scheduler: Stream-K

Tested GEMM shapes (MxNxK):

  • 3584x1x4736: Hang
  • 3328x1x4736: Hang
  • 3200x1x4736: Hang
  • 3136x1x4736: Hang
  • 3104x1x4736: Hang
  • 3088x1x4736: Hang
  • 3072x1x4736: OK

When I change the epilogue schedule to NoSmemWarpSpecialized, the issue disappears, so I suspect something is wrong with the TMA epilogue when it is used with Stream-K.
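For reference, the non-hanging configuration differs only in the schedule tag passed to the epilogue collective builder. A minimal sketch of that swap, using the same type aliases as the patch below (NoSmemWarpSpecialized is the schedule tag in namespace cutlass::epilogue):

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, LayoutC, AlignmentC,
    ElementC, LayoutC, AlignmentC,
    cutlass::epilogue::NoSmemWarpSpecialized  // instead of TmaWarpSpecializedCooperative; no hang observed
  >::CollectiveOp;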

Steps/Code to reproduce bug

Apply the following patch file to 48_hopper_warp_specialized_gemm.cu:

diff --git a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu
index f26f4da3..da827d6d 100644
--- a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu
+++ b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu
@@ -60,6 +60,8 @@
 
 #include "cute/tensor.hpp"
 #include "cutlass/tensor_ref.h"
+#include "cutlass/float8.h"
+#include "cutlass/bfloat16.h"
 #include "cutlass/epilogue/collective/default_epilogue.hpp"
 #include "cutlass/epilogue/thread/linear_combination.h"
 #include "cutlass/gemm/dispatch_policy.hpp"
@@ -89,17 +91,17 @@ using namespace cute;
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 // A matrix configuration
-using         ElementA    = float;                                          // Element type for A matrix operand
+using         ElementA    = cutlass::float_e4m3_t;                                          // Element type for A matrix operand
 using         LayoutA     = cutlass::layout::RowMajor;                      // Layout type for A matrix operand
 constexpr int AlignmentA  = 128 / cutlass::sizeof_bits<ElementA>::value;    // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
 
 // B matrix configuration
-using         ElementB    = float;                                          // Element type for B matrix operand
+using         ElementB    = cutlass::float_e4m3_t;                                          // Element type for B matrix operand
 using         LayoutB     = cutlass::layout::ColumnMajor;                   // Layout type for B matrix operand
 constexpr int AlignmentB  = 128 / cutlass::sizeof_bits<ElementB>::value;    // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
 
 // C/D matrix configuration
-using         ElementC    = float;                                          // Element type for C and D matrix operands
+using         ElementC    = cutlass::bfloat16_t;                                          // Element type for C and D matrix operands
 using         LayoutC     = cutlass::layout::ColumnMajor;                   // Layout type for C and D matrix operands
 constexpr int AlignmentC  = 128 / cutlass::sizeof_bits<ElementC>::value;    // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
 
@@ -107,8 +109,8 @@ constexpr int AlignmentC  = 128 / cutlass::sizeof_bits<ElementC>::value;    // M
 using ElementAccumulator  = float;                                          // Element type for internal accumulation
 using ArchTag             = cutlass::arch::Sm90;                            // Tag indicating the minimum SM that supports the intended feature
 using OperatorClass       = cutlass::arch::OpClassTensorOp;                 // Operator class tag
-using TileShape           = Shape<_128,_128,_32>;                           // Threadblock-level tile size
-using ClusterShape        = Shape<_1,_2,_1>;                                // Shape of the threadblocks in a cluster
+using TileShape           = Shape<_256,_32,_128>;                           // Threadblock-level tile size
+using ClusterShape        = Shape<_2,_1,_1>;                                // Shape of the threadblocks in a cluster
 using StageCountType = cutlass::gemm::collective::StageCountAuto;           // Stage count maximized based on the tile size
 using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;       // Kernel to launch based on the default setting in the Collective Builder
 
@@ -119,7 +121,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
     ElementAccumulator, ElementAccumulator,
     ElementC, LayoutC, AlignmentC,
     ElementC, LayoutC, AlignmentC,
-    cutlass::epilogue::collective::EpilogueScheduleAuto
+    cutlass::epilogue::TmaWarpSpecializedCooperative
   >::CollectiveOp;
 
 using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
@@ -130,13 +132,14 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder
     TileShape, ClusterShape,
     cutlass::gemm::collective::StageCountAutoCarveout<
       static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
-    cutlass::gemm::collective::KernelScheduleAuto
+    cutlass::gemm::KernelTmaWarpSpecializedCooperative
   >::CollectiveOp;
 
 using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
     Shape<int,int,int>, // Indicates ProblemShape
     CollectiveMainloop,
-    CollectiveEpilogue
+    CollectiveEpilogue,
+    cutlass::gemm::StreamKScheduler
 >;
 
 using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
@@ -303,14 +306,14 @@ bool initialize_block(
   int bits_input = cutlass::sizeof_bits<Element>::value;
 
   if (bits_input == 1) {
-    scope_max = 2;
-    scope_min = 0;
+    scope_max = Element(2);
+    scope_min = Element(0);
   } else if (bits_input <= 8) {
-    scope_max = 2;
-    scope_min = -2;
+    scope_max = Element(2);
+    scope_min = Element(-2);
   } else {
-    scope_max = 8;
-    scope_min = -8;
+    scope_max = Element(8);
+    scope_min = Element(-8);
   }
 
   cutlass::reference::device::BlockFillRandomUniform(

(To apply the patch, use patch -p1 < xxx.patch)
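To build the patched example, a standard CUTLASS CMake build should work (the target name is assumed to match the example directory; -DCUTLASS_NVCC_ARCHS=90a selects Hopper):

mkdir build && cd build
cmake .. -DCUTLASS_NVCC_ARCHS=90a
make 48_hopper_warp_specialized_gemm -j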

Then execute the example with the following command:

./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=3584 --n=1 --k=4736

Environment details

  • Environment location: Bare metal on H100 80GB HBM3
thakkarV (Collaborator) commented
@jackkosaian

jackkosaian (Contributor) commented

Thanks for reporting. This is due to a bug in the CUTLASS 3.x implementation of "separate reduction." For the time being, you can circumvent it with the following change, which got this problem size working for me.

diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h
index 36888a29..46adb3ed 100644
--- a/include/cutlass/gemm/kernel/tile_scheduler_params.h
+++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h
@@ -1047,11 +1047,7 @@ struct PersistentTileSchedulerSm90StreamKParams {
   CUTLASS_HOST_DEVICE
   static bool
   should_perform_separate_reduction(uint32_t epilogue_subtile, uint64_t sk_units, uint64_t sk_tiles, uint64_t dp_tiles, uint64_t ctas_per_wave) {
-    // We perform separate reduction if we have fewer than one wave of output tiles
-    // and each output tile is covered by at least to stream-K units. When sk_units is
-    // multiple of sk_tiles, will choose basic split-k path instead of separate reduction for now.
-    return (epilogue_subtile != 1) && (dp_tiles == 0) && (sk_units > 2u * sk_tiles) &&
-           (sk_units + sk_tiles * epilogue_subtile <= ctas_per_wave);
+    return false;
   }

   // Get the amount of scratch workspace needed for the kernel. This variant of the method should only be used when

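If patching CUTLASS headers is not an option, it may also be possible to sidestep the heuristic at runtime by forcing an explicit decomposition through the stream-K scheduler arguments. This is a sketch under the assumption that a forced split-K decomposition bypasses the separate-reduction path (the scheduler member and DecompositionMode enum exist on the 3.x stream-K scheduler, but whether this avoids the hang is untested):

// Untested sketch: force split-K instead of the stream-K heuristic.
using DecompositionMode =
    cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;

typename Gemm::Arguments arguments = args_from_options(options);  // helper from example 48
arguments.scheduler.decomposition_mode = DecompositionMode::SplitK;
arguments.scheduler.splits = 2;  // hypothetical split count; tune per shape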
Algy (Contributor, Author) commented Sep 20, 2024

How soon is this bug expected to be fixed on the main branch? If it will take a while, I may fork the repository and use it with the patch you provided. The affected GEMM shapes come from LLMs that are quite popular right now.

I also wonder whether there is any performance implication to applying your patch. That is, is there a potential performance penalty from always turning off separate reduction?

jackkosaian (Contributor) commented

There is no timeline for when the separate reduction implementation will be fixed. We plan to roll out the patch I described soon, though.

There is no performance implication because, as far as I have seen, separate reduction is currently broken in all of its use cases.


This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.


NihalPotdar commented Nov 2, 2024

@jackkosaian I'm curious how long the separate-reduction fix is expected to take, and whether there are any suggested workarounds.

My understanding is that separate reduction would be very helpful for small GEMM shapes with a large K dimension; since it's disabled, this directly affects performance for these GEMMs. One such configuration is m=16, n=2560, k=8192.
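For intuition (my own back-of-the-envelope, assuming a hypothetical 128x128 output tile on a 132-SM H100): m=16, n=2560 gives ceil(16/128) x ceil(2560/128) = 1 x 20 = 20 output tiles, far fewer than one wave of CTAs, while k=8192 leaves many K iterations to distribute across the otherwise idle SMs, which is exactly the regime the (currently disabled) separate-reduction path targets.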
