diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index 29e75ee5353f..107caae6ec62 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -80,6 +80,13 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, storeOp.getMaskMutable().assign(mask); return op; } + if (auto atomicRMWOp = dyn_cast(op)) { + rewriter.setInsertionPoint(atomicRMWOp); + Value mask = getPredMask(rewriter, atomicRMWOp.getPtr().getType(), + atomicRMWOp.getMask(), pred); + atomicRMWOp.getMaskMutable().assign(mask); + return op; + } assert("don't know how to predicate this op" && false); return op; diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index 641ff165d32f..69b88196fa2b 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -263,3 +263,71 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +// Check that the stream pipeliner updates atomic op in the k-loop correctly +// CHECK-LABEL: _triton_gemm_kernel_atomic_rmw +// CHECK: scf.for +// CHECK: tt.atomic_rmw fadd, acq_rel, gpu +// CHECK: tt.dot +// CHECK: scf.yield + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @_triton_gemm_kernel_atomic_rmw(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg3: i32 {tt.divisibility = 16 : i32} loc(unknown), %arg4: i32 {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %cst = arith.constant dense<32> : tensor<32x32xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c31_i32 = arith.constant 31 : i32 + %c32_i32 = arith.constant 32 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %2 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked> + %3 = arith.muli %1, %2 : tensor<32x1xi32, #blocked> + %4 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %6 = tt.broadcast %3 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked> + %7 = tt.broadcast %5 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %8 = arith.addi %6, %7 : tensor<32x32xi32, #blocked> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %10 = tt.addptr %9, %8 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %12 = tt.addptr %11, %8 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %13 = tt.splat %arg2 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %14 = tt.addptr %13, %3 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %15 = tt.broadcast %14 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> + %16 = tt.addptr %15, %7 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %17 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked> + %18 = arith.cmpi slt, %1, %17 : tensor<32x1xi32, #blocked> + %19 = tt.splat %arg3 : i32 -> tensor<1x32xi32, #blocked> + %20 = arith.cmpi slt, %5, %19 : tensor<1x32xi32, #blocked> + %21 = tt.broadcast %18 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> + %22 = tt.broadcast %20 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked> + %23 = arith.andi %21, %22 : tensor<32x32xi1, #blocked> + %24 = arith.addi %arg3, %c31_i32 : i32 + %25 = arith.divsi %24, %c32_i32 : i32 + %26 = arith.muli %arg4, %c32_i32 : i32 + %27 = tt.splat %26 : i32 -> tensor<32x32xi32, #blocked> + %28:3 = scf.for %arg5 = %c0_i32 to %25 step %c1_i32 iter_args(%arg6 = %cst_0, %arg7 = %10, %arg8 = %12) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked>) : i32 { + %32 = tt.load %arg7 : tensor<32x32x!tt.ptr, #blocked> + %33 = tt.load %arg8 : tensor<32x32x!tt.ptr, #blocked> + %34 = triton_gpu.convert_layout %32 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %35 = triton_gpu.convert_layout %33 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %36 = tt.dot %34, %35, %arg6 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma> + %37 = tt.addptr %arg7, %cst : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %38 = tt.addptr %arg8, %27 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> + %39 = arith.truncf %36 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %40 = triton_gpu.convert_layout %39 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked> + %41 = tt.atomic_rmw fadd, acq_rel, gpu, %16, %40, %23 : (tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked> + scf.yield %36, %37, %38 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked> + } + %29 = arith.truncf %28#0 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %30 = triton_gpu.convert_layout %16 : tensor<32x32x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #mma> + %31 = triton_gpu.convert_layout %23 : tensor<32x32xi1, #blocked> -> tensor<32x32xi1, #mma> + tt.store %30, %29, %31 : tensor<32x32x!tt.ptr, #mma> + tt.return + } +}