[BACKEND] Updated predicate for atomic ops (triton-lang#1619)
zahimoud authored May 5, 2023
1 parent fda10c0 commit 0e2ef48
Showing 3 changed files with 13 additions and 16 deletions.
23 changes: 10 additions & 13 deletions lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -399,13 +399,13 @@ struct AtomicCASOpConversion
     auto valElements = getTypeConverter()->unpackLLElements(
         loc, llVal, rewriter, op.getVal().getType());

-    auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    auto valueTy = op.getResult().getType();
+    auto TensorTy = valueTy.dyn_cast<RankedTensorType>();
     Type valueElemTy =
         TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
-                 : op.getResult().getType();
+                 : valueTy;
     auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
-    auto tid = tid_val();
-    Value pred = icmp_eq(tid, i32_val(0));
+    Value mask = getMask(valueTy, rewriter, loc);
     PTXBuilder ptxBuilderMemfence;
     auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
     memfence();
@@ -425,7 +425,7 @@
       auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
       auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
       atom.global().o("cas").o("b32");
-      atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
+      atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
       auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
       barrier();

@@ -434,7 +434,7 @@
       auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
       auto &st = *ptxBuilderStore.create<PTXInstr>("st");
       st.shared().o("b32");
-      st(dstOprStore, valOprStore).predicate(pred);
+      st(dstOprStore, valOprStore).predicate(mask);
       ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
       ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
       barrier();
@@ -483,10 +483,11 @@ struct AtomicRMWOpConversion
       maskElements = getTypeConverter()->unpackLLElements(
           loc, llMask, rewriter, op.getMask().getType());

-    auto tensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    auto valueTy = op.getResult().getType();
+    auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
     Type valueElemTy =
         tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
-                 : op.getResult().getType();
+                 : valueTy;
     const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
     auto elemsPerThread = getTotalElemsPerThread(val.getType());
     // vec = 1, numElements = 1 for scalar
@@ -499,10 +500,7 @@
       // mask
       numElems = tensorTy.getNumElements();
     }
-    Value mask = int_val(1, 1);
-    auto tid = tid_val();
-    mask = and_(mask,
-                icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
+    Value mask = getMask(valueTy, rewriter, loc);

     auto vecTy = vec_ty(valueElemTy, vec);
     SmallVector<Value> resultVals(elemsPerThread);
@@ -582,7 +580,6 @@ struct AtomicRMWOpConversion
         memfenc();
         auto ASMReturnTy = void_ty(ctx);
         ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
-        rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
         atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
         auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
         Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
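For context, the AtomicCASOpConversion path above guards both the atom.global.cas.b32 and the st.shared.b32 with the same mask and separates them with barriers; for an unranked (scalar) result that mask reduces to thread 0 (see the TritonGPUToLLVMBase.h hunk below), so the sequence amounts to a single-thread atomic whose result is broadcast through shared memory. A minimal CUDA sketch of that pattern, offered purely as an illustration of what the predicated PTX accomplishes, not as the code the backend emits; the kernel and variable names (scalar_cas_broadcast, slot, out) are invented for the example:

    // Illustrative only: one thread performs the atomic, the result is published
    // through shared memory, and a barrier makes it visible to every thread.
    __global__ void scalar_cas_broadcast(int *ptr, int cmp, int val, int *out) {
      __shared__ int slot;                // stand-in for the shared-memory slot
      if (threadIdx.x == 0)               // guard: only thread 0 issues the atomic
        slot = atomicCAS(ptr, cmp, val);  // returns the old value at ptr
      __syncthreads();                    // barrier() in the lowering
      out[threadIdx.x] = slot;            // all threads observe the same old value
    }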
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -457,7 +457,7 @@ class ConvertTritonGPUOpToLLVMPatternBase {
     } else {
       // If the tensor is not ranked, then it is a scalar and only thread 0 can
       // write
-      mask = and_(mask, icmp_slt(tid, i32_val(1)));
+      mask = and_(mask, icmp_eq(tid, i32_val(0)));
     }
     return mask;
   }
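Two notes on the guard logic that this commit routes through getMask: for a non-negative thread id, tid < 1 and tid == 0 select exactly the same thread, so the scalar branch above changes readability rather than behavior; and on the tensor side the commit drops the coarse bound that the AtomicRMW conversion previously built by hand. A rough CUDA analogue of the two predicates visible in this diff; write_guard and its parameters are illustrative names rather than the helper's signature, and the tensor-side condition shown is the removed hand-rolled bound, not necessarily what getMask itself computes:

    // Illustrative only: mirrors the two predicates that appear in this diff.
    __device__ bool write_guard(bool is_scalar, unsigned elems_per_thread,
                                unsigned num_elems) {
      const unsigned tid = threadIdx.x;
      if (is_scalar)
        return tid == 0;                         // unranked result: only thread 0 writes
      return tid * elems_per_thread < num_elems; // drop threads with nothing to write
    }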
4 changes: 2 additions & 2 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -1009,7 +1009,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: atomic_add_f32
   tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
-    // CHECK: llvm.icmp "slt"
     // CHECK: llvm.inline_asm
     // CHECK-SAME: @$3 atom.global.gpu.add.f32
     // CHECK: llvm.inline_asm
@@ -1026,6 +1025,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
   tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
     // CHECK: llvm.icmp "eq"
     // CHECK: llvm.inline_asm
+    // CHECK: llvm.inline_asm
     // CHECK-SAME: @$3 atom.global.gpu.add.f32
     %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
     tt.return
@@ -1052,7 +1052,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: store_f32_scalar
   tt.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
-    // CHECK: llvm.icmp "slt"
+    // CHECK: llvm.icmp "eq"
     // CHECK: llvm.inline_asm
     // CHECK-SAME: @$2 st.global.b32
     tt.store %arg0, %arg1 : f32
