Skip to content

Commit

Permalink
Updated masking for scalars, updated lit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zahimoud committed May 4, 2023
1 parent 27ae36f commit 0dcdad0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ class ConvertTritonGPUOpToLLVMPatternBase {
} else {
// If the tensor is not ranked, then it is a scalar and only thread 0 can
// write
- mask = and_(mask, icmp_slt(tid, i32_val(1)));
+ mask = and_(mask, icmp_eq(tid, i32_val(0)));
}
return mask;
}
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32
tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.icmp "slt"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.add.f32
// CHECK: llvm.inline_asm
Expand All @@ -1026,6 +1025,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
// CHECK: llvm.icmp "eq"
// CHECK: llvm.inline_asm
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
tt.return
Expand All @@ -1052,7 +1052,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_f32_scalar
tt.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
- // CHECK: llvm.icmp "slt"
+ // CHECK: llvm.icmp "eq"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$2 st.global.b32
tt.store %arg0, %arg1 : f32
Expand Down

0 comments on commit 0dcdad0

Please sign in to comment.