diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index fe019fae40a2..f775936045ab 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -123,7 +123,7 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
     return expensiveLoadOrStore(op, targetEncoding);
   if (isa<triton::CatOp, triton::AtomicRMWOp,
-          triton::AtomicCASOp, triton::DotOp>(op))
+          triton::AtomicCASOp, triton::DotOp, triton::ReduceOp>(op))
     return true;
   if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp, scf::ConditionOp>(
           op))
    return true;
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index 35c389239508..190e687b7591 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -1013,12 +1013,15 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
    %1 = triton_gpu.convert_layout %0 : (tensor<2xi32, #blocked1>) -> tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %2 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) -> tensor<1x2xi32, #blocked>
    %3 = "triton_gpu.cmpi"(%2, %cst_0) {predicate = 2 : i64} : (tensor<1x2xi32, #blocked>, tensor<1x2xi32, #blocked>) -> tensor<1x2xi1, #blocked>
+    // CHECK-DAG: }) {axis = 1 : i32}
    %4 = "tt.reduce" (%cst) ({
    ^bb0(%arg3: i32, %arg4: i32):
      %add = arith.addi %arg3, %arg4 : i32
      tt.reduce.return %add : i32
    }) {axis = 1 : i32} : (tensor<1x2xi32, #blocked>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    // CHECK-NEXT: triton_gpu.convert_layout {{%.*}} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %5 = triton_gpu.convert_layout %4 : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<1xi32, #blocked1>
+    // CHECK-NOT: triton_gpu.convert_layout
    %6 = triton_gpu.convert_layout %5 : (tensor<1xi32, #blocked1>) -> tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
    %7 = tt.expand_dims %6 {axis = 1 : i32} : (tensor<1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>) -> tensor<1x1xi32, #blocked2>
    %8 = triton_gpu.convert_layout %7 : (tensor<1x1xi32, #blocked2>) -> tensor<1x1xi32, #blocked>