diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 8397bb28932e..2a17c0e6d6d3 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -11,29 +11,39 @@ namespace mlir {
 
 class ReduceOpHelper {
-  ReduceOpHelper(Operation *op, int axis, bool withIndex)
-      : op(op), axis(axis), withIndex(withIndex) {
-    srcTy = op->getOperands().front().getType().cast<RankedTensorType>();
-  }
-
 public:
-  explicit ReduceOpHelper(triton::ReduceOp op):
-    ReduceOpHelper(
-      op.getOperation(),
-      op.getAxis(),
-      triton::ReduceOp::withIndex(op.getRedOp())) {
+  explicit ReduceOpHelper(triton::ReduceOp rop):
+      op(rop.getOperation()), axis(rop.getAxis()) {
+    auto srcTy = rop.getOperand().getType().cast<RankedTensorType>();
+    srcShape = srcTy.getShape();
+    srcEncoding = srcTy.getEncoding();
+    srcElementTypes.push_back(srcTy.getElementType());
+
+    if (triton::ReduceOp::withIndex(rop.getRedOp())) {
+      srcElementTypes.push_back(Builder(op).getI32Type());
+    }
   }
 
-  explicit ReduceOpHelper(triton::GenericReduceOp op):
-    ReduceOpHelper(
-      op.getOperation(),
-      op.getAxis(),
-      /*withIndex*/false) {
+  explicit ReduceOpHelper(triton::GenericReduceOp rop):
+      op(rop.getOperation()), axis(rop.getAxis()) {
+    auto firstTy = rop.getOperands()[0].getType().cast<RankedTensorType>();
+    srcShape = firstTy.getShape();
+    srcEncoding = firstTy.getEncoding();
+    srcElementTypes = rop.getElementTypes();
+
+    for (const auto &t : rop.getInputTypes()) {
+      if (t.getShape() != srcShape) {
+        rop.emitError() << "shape mismatch";
+      }
+      if (t.getEncoding() != srcEncoding) {
+        rop.emitError() << "encoding mismatch";
+      }
+    }
   }
 
-  ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
+  ArrayRef<int64_t> getSrcShape() { return srcShape; }
 
-  Attribute getSrcLayout() { return srcTy.getEncoding(); }
+  Attribute getSrcLayout() { return srcEncoding; }
 
   bool isFastReduction();
 
@@ -51,9 +61,10 @@ class ReduceOpHelper {
 
 private:
   Operation *op;
-  RankedTensorType srcTy{};
+  ArrayRef<int64_t> srcShape;
+  Attribute srcEncoding;
+  SmallVector<Type> srcElementTypes;
   int axis;
-  bool withIndex;
 };
 
 bool isSharedEncoding(Value value);
diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index e0c37a3700f3..0564cdfcb69c 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -366,16 +366,25 @@ def TT_ReduceOp : TT_Op<"reduce", [Pure,
 def TT_GenericReduceOp: TT_Op<"generic_reduce",
     [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>, SingleBlock]> {
   let summary = "Reduction using generic combination algorithm";
-  let arguments = (ins TT_Tensor:$operand, I32Attr:$axis);
-  let results = (outs TT_Type:$result);
-  let regions = (region SizedRegion<1>:$region);
+  let arguments = (ins Variadic<TT_Tensor>:$operands, I32Attr:$axis);
+  let results = (outs Variadic<TT_Type>:$result);
+  let regions = (region SizedRegion<1>:$combineOp);
+  let builders = [
+    OpBuilder<(ins "ValueRange":$operands, "int":$axis)>,
+  ];
+  let hasVerifier = 1;
   let hasRegionVerifier = 1;
+  let extraClassDeclaration = [{
+    llvm::SmallVector<RankedTensorType> getInputTypes();
+    llvm::SmallVector<Type> getElementTypes();
+    unsigned getNumOperands();
+  }];
 }
 
 def TT_GenericReduceReturnOp: TT_Op<"generic_reduce.return",
     [HasParent<"GenericReduceOp">, Pure, Terminator, ReturnLike]> {
   let summary = "terminator for reduce operator";
-  let arguments = (ins AnyType:$result);
+  let arguments = (ins Variadic<AnyType>:$result);
   let assemblyFormat = "$result attr-dict `:` type($result)";
 }
 
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index 6cb541b21a86..9d5162373e4b 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -64,7 +64,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
 // handle encodings
 // e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
 def TTG_CmpIOp : TTG_Op<"cmpi", [Pure, Elementwise,
-                                 SameOperandsAndResultShape,
+                                 SameOperandsAndResultShape,
                                  SameOperandsAndResultEncoding]> {
   let summary = "integer comparison operation";
 
@@ -78,7 +78,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [Pure, Elementwise,
 
 def TTG_CmpFOp : TTG_Op<"cmpf", [Pure, Elementwise,
-                                 SameOperandsAndResultShape,
+                                 SameOperandsAndResultShape,
                                  SameOperandsAndResultEncoding]> {
   let summary = "floating-point comparison operation";
 
@@ -100,10 +100,10 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
   let description = [{}];
 
   let arguments = (ins TT_BoolLike:$condition,
-                       TT_Tensor:$true_value,
-                       TT_Tensor:$false_value);
+                       TT_Type:$true_value,
+                       TT_Type:$false_value);
 
-  let results = (outs TT_Tensor:$result);
+  let results = (outs TT_Type:$result);
 }
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 85915ecad340..a5e594b40357 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -10,29 +10,24 @@ namespace mlir {
 
 bool ReduceOpHelper::isFastReduction() {
-  auto srcLayout = srcTy.getEncoding();
-  return axis == triton::gpu::getOrder(srcLayout)[0];
+  return axis == triton::gpu::getOrder(getSrcLayout())[0];
 }
 
 unsigned ReduceOpHelper::getInterWarpSize() {
-  auto srcLayout = srcTy.getEncoding();
-  auto srcShape = srcTy.getShape();
   auto srcReduceDimSize = static_cast<uint64_t>(srcShape[axis]);
   unsigned sizeIntraWarps = getIntraWarpSize();
   return std::min(srcReduceDimSize / sizeIntraWarps,
-                  triton::gpu::getWarpsPerCTA(srcLayout)[axis]);
+                  triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
 }
 
 unsigned ReduceOpHelper::getIntraWarpSize() {
-  auto srcLayout = srcTy.getEncoding();
-  auto srcShape = srcTy.getShape();
   auto srcReduceDimSize = static_cast<uint64_t>(srcShape[axis]);
   return std::min(srcReduceDimSize,
-                  triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
+                  triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
 }
 
 unsigned ReduceOpHelper::getThreadsReductionAxis() {
-  auto srcLayout = srcTy.getEncoding();
+  auto srcLayout = getSrcLayout();
   return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
          triton::gpu::getWarpsPerCTA(srcLayout)[axis];
 }
@@ -46,7 +41,7 @@ SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
 
 SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
   SmallVector<SmallVector<unsigned>> smemShapes(3);
-  auto argLayout = srcTy.getEncoding();
+  auto argLayout = getSrcLayout();
   auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
   if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
       triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
@@ -76,13 +71,11 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
     elems = product(smemShape);
   }
 
-  auto tensorType = op->getOperand(0).getType().cast<RankedTensorType>();
-  unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;
-
-  if (withIndex)
-    bytes += elems * sizeof(int32_t);
-
-  return bytes;
+  unsigned bytes_per_elem = 0;
+  for (const auto &ty: srcElementTypes) {
+    bytes_per_elem += ty.getIntOrFloatBitWidth() / 8;
+  }
+  return bytes_per_elem * elems;
 }
 
 bool isSharedEncoding(Value value) {
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
index 603c7d9ccbc6..35bb5379165e 100644
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -843,6 +843,8 @@ void populateElementwiseOpToLLVMPatterns(
   POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp)   // <<
   POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
   POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
+  POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
+  POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp)  // smin
 #undef POPULATE_BINARY_OP
 
 #define POPULATE_UNARY_OP(SRC_OP, DST_OP)                                      \
diff --git a/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp
index 7d6731638b1a..37fddd329bb2 100644
--- a/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/GenericReduceOpToLLVM.cpp
@@ -24,71 +24,111 @@ struct GenericReduceOpConversion
 private:
-  void accumulate(ConversionPatternRewriter &rewriter,
-                  Region &reduceOp, Value &acc, Value cur, bool isFirst) const {
+  void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
+                  llvm::SmallVectorImpl<Value> &acc, ValueRange cur, bool isFirst) const {
     if (isFirst) {
-      acc = cur;
+      acc.resize(cur.size());
+      for (unsigned i = 0; i < cur.size(); ++i) {
+        acc[i] = cur[i];
+      }
       return;
     }
 
     // Create a new copy of the reduce block, and inline it
     Block *currentBlock = rewriter.getBlock();
     Region &parent = *currentBlock->getParent();
-    rewriter.cloneRegionBefore(reduceOp, &parent.front());
+    rewriter.cloneRegionBefore(combineOp, &parent.front());
     auto &newReduce = parent.front();
     auto returnOp = dyn_cast<triton::GenericReduceReturnOp>(newReduce.getTerminator());
-    rewriter.mergeBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), {acc, cur});
-    acc = returnOp.getResult();
+
+    llvm::SmallVector<Value> combineArgs(2 * acc.size());
+    for (unsigned i = 0; i < acc.size(); ++i) {
+      combineArgs[i] = acc[i];
+      combineArgs[acc.size() + i] = cur[i];
+    }
+
+    rewriter.mergeBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), combineArgs);
+
+    auto results = returnOp.getResult();
+    for (unsigned i = 0; i < acc.size(); ++i) {
+      acc[i] = results[i];
+    }
+
     // Delete the terminator, which is no longer used
     rewriter.eraseOp(returnOp);
   }
 
+  SmallVector<SmallVector<Value>> unpackInputs(
+      Location loc, triton::GenericReduceOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const {
+    auto types = op.getInputTypes();
+    auto operands = adaptor.getOperands();
+    unsigned srcElems = getElemsPerThread(types[0]);
+    SmallVector<SmallVector<Value>> srcValues(srcElems);
+    for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+      auto values = getTypeConverter()->unpackLLElements(
+          loc, operands[i], rewriter, types[i]);
+
+      assert(values.size() == srcValues.size());
+      for (unsigned j = 0; j < srcValues.size(); ++j) {
+        srcValues[j].push_back(values[j]);
+      }
+    }
+    return srcValues;
+  }
+
   // Use shared memory for reduction within warps and across warps
   LogicalResult
   matchAndRewriteBasic(triton::GenericReduceOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
-    Location loc = op->getLoc();
+    Location loc = op.getLoc();
     unsigned axis = op.getAxis();
 
-    auto srcTy = op.getOperand().getType().cast<RankedTensorType>();
-    auto srcLayout = srcTy.getEncoding().cast<BlockedEncodingAttr>();
+    ReduceOpHelper helper(op);
+    auto srcTys = op.getInputTypes();
+    auto srcLayout = helper.getSrcLayout().cast<BlockedEncodingAttr>();
     auto srcOrd = srcLayout.getOrder();
-    auto srcShape = srcTy.getShape();
+    auto srcShape = helper.getSrcShape();
 
-    auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
+    SmallVector<Type> elemPtrTys(srcTys.size());
+    for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+      auto ty = srcTys[i].getElementType();
+      auto llvmElemTy = getTypeConverter()->convertType(ty);
+      elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
+    }
     auto llvmIndexTy = getTypeConverter()->getIndexType();
-    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
     auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
-    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
-    smemBase = bitcast(smemBase, elemPtrTy);
 
-    ReduceOpHelper helper(op);
     auto smemShape = helper.getScratchConfigBasic();
     unsigned elems = product(smemShape);
-    Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(elems));
-    indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
 
-    unsigned srcElems = getElemsPerThread(srcTy);
-    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
-    auto srcValues = getTypeConverter()->unpackLLElements(
-        loc, adaptor.getOperand(), rewriter, srcTy);
+    SmallVector<Value> smemBases(op.getNumOperands());
+    smemBases[0] = bitcast(
+        getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
+    for (unsigned i = 1; i < op.getNumOperands(); ++i) {
+      smemBases[i] = gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems));
+    }
+
+    unsigned srcElems = getElemsPerThread(srcTys[0]);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
+    auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
 
+    // Assumes offsets don't actually depend on type
     SmallVector<SmallVector<unsigned>> offset =
-        emitOffsetForLayout(srcLayout, srcTy);
+        emitOffsetForLayout(srcLayout, srcTys[0]);
 
-    std::map<SmallVector<unsigned>, Value> accs;
-    std::map<SmallVector<unsigned>, Value> accIndices;
+    std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
     std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
 
-    Region *reduceOp = &op.getRegion();
+    Region *combineOp = &op.getCombineOp();
 
     // reduce within threads
     for (unsigned i = 0; i < srcElems; ++i) {
       SmallVector<unsigned> key = offset[i];
       key[axis] = 0;
       bool isFirst = accs.find(key) == accs.end();
-      accumulate(rewriter, *reduceOp, accs[key], srcValues[i], isFirst);
+      accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
       if (isFirst)
         indices[key] = srcIndices[i];
     }
@@ -103,14 +143,17 @@ struct GenericReduceOpConversion
     // reduce across threads
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
-      Value acc = it.second;
+      auto &acc = it.second;
       SmallVector<Value> writeIdx = indices[key];
 
       writeIdx[axis] = udiv(writeIdx[axis], sizePerThread);
       Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, srcOrd);
-      Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
-      Value indexWritePtr = gep(indexPtrTy, indexSmemBase, writeOffset);
-      store(acc, writePtr);
+      SmallVector<Value> writePtrs(op.getNumOperands());
+      for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+        writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset);
+        store(acc[i], writePtrs[i]);
+      }
+
       SmallVector<Value> readIdx(writeIdx.size(), ints[0]);
       for (int N = smemShape[axis] / 2; N > 0; N >>= 1) {
@@ -119,44 +162,56 @@ struct GenericReduceOpConversion
         Value readOffset = select(
             readMask, linearize(rewriter, loc, readIdx, smemShape, srcOrd),
             ints[0]);
-        Value readPtr = gep(elemPtrTy, writePtr, readOffset);
+        SmallVector<Value> readPtrs(op.getNumOperands());
+        for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+          readPtrs[i] = gep(elemPtrTys[i], writePtrs[i], readOffset);
+        }
+
         barrier();
-        Value cur = load(readPtr);
-        accumulate(rewriter, *reduceOp, acc, cur, false);
+        SmallVector<Value> cur(op.getNumOperands());
+        for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+          cur[i] = load(readPtrs[i]);
+        }
+        accumulate(rewriter, *combineOp, acc, cur, false);
         barrier();
-        store(acc, writePtr);
+        for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+          store(acc[i], writePtrs[i]);
+        }
       }
     }
 
     barrier();
 
     // set output values
-    if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
-      // nd-tensor where n >= 1
-      auto resultLayout = resultTy.getEncoding();
-
-      unsigned resultElems = getElemsPerThread(resultTy);
-      auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
-      assert(resultIndices.size() == resultElems);
-
-      SmallVector<Value> resultVals(resultElems);
-      for (unsigned i = 0; i < resultElems; ++i) {
-        SmallVector<Value> readIdx = resultIndices[i];
-        readIdx.insert(readIdx.begin() + axis, ints[0]);
-        Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
-        Value readPtr = gep(elemPtrTy, smemBase, readOffset);
-        Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
-        resultVals[i] = load(readPtr);
+    SmallVector<Value> results(op.getNumOperands());
+    for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+      if (auto resultTy = op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
+        // nd-tensor where n >= 1
+
+        auto resultLayout = resultTy.getEncoding();
+
+        unsigned resultElems = getElemsPerThread(resultTy);
+        auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
+        assert(resultIndices.size() == resultElems);
+
+        SmallVector<Value> resultVals(resultElems);
+        for (unsigned j = 0; j < resultElems; ++j) {
+          SmallVector<Value> readIdx = resultIndices[j];
+          readIdx.insert(readIdx.begin() + axis, ints[0]);
+          Value readOffset = linearize(rewriter, loc, readIdx, smemShape, srcOrd);
+          Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
+          resultVals[j] = load(readPtr);
+        }
+        results[i] = getTypeConverter()->packLLElements(
+            loc, resultVals, rewriter, resultTy);
+      } else {
+        // 0d-tensor -> scalar
+        results[i] = load(smemBases[i]);
       }
-      Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
-                                                     resultTy);
-      rewriter.replaceOp(op, ret);
-    } else {
-      // 0d-tensor -> scalar
-      Value resultVal = load(smemBase);
-      rewriter.replaceOp(op, resultVal);
     }
+
+    auto parentBlock = op.getOperation()->getBlock();
+    rewriter.replaceOp(op, results);
     return success();
   }
 
@@ -168,52 +223,54 @@ struct GenericReduceOpConversion
     Location loc = op->getLoc();
     unsigned axis = adaptor.getAxis();
 
-    auto srcTy = op.getOperand().getType().cast<RankedTensorType>();
-    auto srcLayout = srcTy.getEncoding();
-    auto srcShape = srcTy.getShape();
-    auto order = getOrder(srcLayout);
-
-    auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
-    auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
+    ReduceOpHelper helper(op);
+    auto srcTys = op.getInputTypes();
+    auto srcLayout = helper.getSrcLayout().cast<BlockedEncodingAttr>();
+    auto srcOrd = srcLayout.getOrder();
+    auto srcShape = helper.getSrcShape();
 
-    auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
+    SmallVector<Type> elemPtrTys(srcTys.size());
+    for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+      auto ty = srcTys[i].getElementType();
+      auto llvmElemTy = getTypeConverter()->convertType(ty);
+      elemPtrTys[i] = LLVM::LLVMPointerType::get(llvmElemTy, 3);
+    }
     auto llvmIndexTy = getTypeConverter()->getIndexType();
-    auto elemPtrTy = LLVM::LLVMPointerType::get(llvmElemTy, 3);
     auto indexPtrTy = LLVM::LLVMPointerType::get(llvmIndexTy, 3);
-    Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
-    smemBase = bitcast(smemBase, elemPtrTy);
 
-    ReduceOpHelper helper(op);
     auto smemShapes = helper.getScratchConfigsFast();
     unsigned elems = product(smemShapes[0]);
     unsigned maxElems = std::max(elems, product(smemShapes[1]));
-    Value indexSmemBase = gep(elemPtrTy, smemBase, i32_val(maxElems));
-    indexSmemBase = bitcast(indexSmemBase, indexPtrTy);
+
+    SmallVector<Value> smemBases(op.getNumOperands());
+    smemBases[0] = bitcast(
+        getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
+    for (unsigned i = 1; i < op.getNumOperands(); ++i) {
+      smemBases[i] = gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems));
+    }
 
     unsigned sizeIntraWarps = helper.getIntraWarpSize();
     unsigned sizeInterWarps = helper.getInterWarpSize();
 
-    unsigned srcElems = getElemsPerThread(srcTy);
-    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
-    auto srcValues = getTypeConverter()->unpackLLElements(
-        loc, adaptor.getOperand(), rewriter, srcTy);
+    unsigned srcElems = getElemsPerThread(srcTys[0]);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
+    auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
 
-    SmallVector<SmallVector<unsigned>> offset =
-        emitOffsetForLayout(srcLayout, srcTy);
-
-    std::map<SmallVector<unsigned>, Value> accs;
-    std::map<SmallVector<unsigned>, Value> accIndices;
+    std::map<SmallVector<unsigned>, SmallVector<Value>> accs;
     std::map<SmallVector<unsigned>, SmallVector<Value>> indices;
 
-    auto &currentBlock = *rewriter.getBlock();
-    auto *reduceOp = &op.getRegion();
+    // Assumes offsets don't actually depend on type
+    SmallVector<SmallVector<unsigned>> offset =
+        emitOffsetForLayout(srcLayout, srcTys[0]);
+
+    auto *combineOp = &op.getCombineOp();
 
     // reduce within threads
     for (unsigned i = 0; i < srcElems; ++i) {
       SmallVector<unsigned> key = offset[i];
       key[axis] = 0;
       bool isFirst = accs.find(key) == accs.end();
-      accumulate(rewriter, *reduceOp, accs[key], srcValues[i], isFirst);
+      accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
       if (isFirst)
         indices[key] = srcIndices[i];
     }
@@ -223,6 +280,9 @@ struct GenericReduceOpConversion
     Value warpId = udiv(threadId, warpSize);
     Value laneId = urem(threadId, warpSize);
 
+    auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
+    auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcLayout);
+    auto order = getOrder(srcLayout);
     SmallVector<Value> multiDimLaneId =
         delinearize(rewriter, loc, laneId, threadsPerWarp, order);
     SmallVector<Value> multiDimWarpId =
@@ -236,21 +296,25 @@ struct GenericReduceOpConversion
 
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
-      Value acc = it.second;
-      Value accIndex;
+      SmallVector<Value> acc = it.second;
 
       // Reduce within warps
       for (unsigned N = sizeIntraWarps / 2; N > 0; N >>= 1) {
-        Value shfl = shflSync(loc, rewriter, acc, N);
-        accumulate(rewriter, *reduceOp, acc, shfl, false);
+        SmallVector<Value> shfl(op.getNumOperands());
+        for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+          shfl[i] = shflSync(loc, rewriter, acc[i], N);
+        }
+        accumulate(rewriter, *combineOp, acc, shfl, false);
       }
 
       SmallVector<Value> writeIdx = indices[key];
       writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
       Value writeOffset =
           linearize(rewriter, loc, writeIdx, smemShapes[0], order);
-      Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
-      storeShared(rewriter, loc, writePtr, acc, laneZero);
+      for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+        Value writePtr = gep(elemPtrTys[i], smemBases[i], writeOffset);
+        storeShared(rewriter, loc, writePtr, acc[i], laneZero);
+      }
     }
 
     barrier();
@@ -266,26 +330,37 @@ struct GenericReduceOpConversion
     unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
     Value readOffset = threadId;
     for (unsigned round = 0; round < elemsPerThread; ++round) {
-      Value readPtr = gep(elemPtrTy, smemBase, readOffset);
       // FIXME(Qingyi): need predicate icmp_slt(threadId,
       // i32_val(sizeInerWarps))
-      Value acc = load(readPtr);
-      Value accIndex;
+      SmallVector<Value> acc(op.getNumOperands());
+      for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+        Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
+        acc[i] = load(readPtr);
+      }
 
       for (unsigned N = sizeInterWarps / 2; N > 0; N >>= 1) {
-        Value shfl = shflSync(loc, rewriter, acc, N);
-        accumulate(rewriter, *reduceOp, acc, shfl, false);
+        SmallVector<Value> shfl(op.getNumOperands());
+        for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+          shfl[i] = shflSync(loc, rewriter, acc[i], N);
+        }
+        accumulate(rewriter, *combineOp, acc, shfl, false);
       }
 
       // only the first thread in each sizeInterWarps is writing
       Value writeOffset = readOffset;
-      Value writePtr = gep(elemPtrTy, smemBase, writeOffset);
+      SmallVector<Value> writePtrs(op.getNumOperands());
+      for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+        writePtrs[i] = gep(elemPtrTys[i], smemBases[i], writeOffset);
+      }
       Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));
       Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps));
       Value laneIdModSizeInterWarpsIsZero =
           icmp_eq(laneIdModSizeInterWarps, zero);
       Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
-      storeShared(rewriter, loc, writePtr, acc, pred);
+
+      for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+        storeShared(rewriter, loc, writePtrs[i], acc[i], pred);
+      }
 
       if (round != elemsPerThread - 1) {
         readOffset = add(readOffset, i32_val(numThreads));
@@ -298,32 +373,33 @@ struct GenericReduceOpConversion
     barrier();
 
     // set output values
-    if (auto resultTy = op.getType().dyn_cast<RankedTensorType>()) {
-      // nd-tensor where n >= 1
-      auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
-      unsigned resultElems = getElemsPerThread(resultTy);
-      auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
-      assert(resultIndices.size() == resultElems);
-
-      SmallVector<Value> resultVals(resultElems);
-      for (size_t i = 0; i < resultElems; ++i) {
-        SmallVector<Value> readIdx = resultIndices[i];
-        readIdx.insert(readIdx.begin() + axis, i32_val(0));
-        Value readOffset =
-            linearize(rewriter, loc, readIdx, smemShapes[0], order);
-        Value readPtr = gep(elemPtrTy, smemBase, readOffset);
-        Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
-        resultVals[i] = load(readPtr);
+    SmallVector<Value> results(op.getNumOperands());
+    for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+      if (auto resultTy = op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
+        // nd-tensor where n >= 1
+        auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
+        unsigned resultElems = getElemsPerThread(resultTy);
+        auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
+        assert(resultIndices.size() == resultElems);
+
+        SmallVector<Value> resultVals(resultElems);
+        for (size_t j = 0; j < resultElems; ++j) {
+          SmallVector<Value> readIdx = resultIndices[j];
+          readIdx.insert(readIdx.begin() + axis, i32_val(0));
+          Value readOffset =
+              linearize(rewriter, loc, readIdx, smemShapes[0], order);
+          Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset);
+          resultVals[j] = load(readPtr);
+        }
+
+        results[i] = getTypeConverter()->packLLElements(
+            loc, resultVals, rewriter, resultTy);
+      } else {
+        // 0d-tensor -> scalar
+        results[i] = load(smemBases[i]);
       }
-
-      Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
-                                                     resultTy);
-      rewriter.replaceOp(op, ret);
-    } else {
-      // 0d-tensor -> scalar
-      Value resultVal = load(smemBase);
-      rewriter.replaceOp(op, resultVal);
     }
+
+    rewriter.replaceOp(op, results);
     return success();
   }
 
diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
index 268c8c3e5684..79707d74fa50 100644
--- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -47,15 +47,29 @@ Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
 Value TritonGPUToLLVMTypeConverter::packLLElements(
     Location loc, ValueRange resultVals, ConversionPatternRewriter &rewriter,
     Type type) {
-  auto structType = this->convertType(type);
-  if (!structType.isa<LLVM::LLVMStructType>()) {
+  auto structType = this->convertType(type).dyn_cast<LLVM::LLVMStructType>();
+  if (!structType) {
+    assert(resultVals.size() == 1);
     return *resultVals.begin();
   }
 
+  auto elementTypes = structType.getBody();
+  if (elementTypes.size() != resultVals.size()) {
+    emitError(loc) << " size mismatch when packing elements for LLVM struct"
+                   << " expected " << elementTypes.size() << " but got "
+                   << resultVals.size();
+  }
   Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
-  // llvm::outs() << structType << "\n";
   for (const auto &v : llvm::enumerate(resultVals)) {
-    assert(v.value() && "can not insert null values");
+    if (!v.value()) {
+      emitError(loc) << "cannot insert null values into struct, but tried to insert "
+                     << v.value();
+    }
+    if (v.value().getType() != elementTypes[v.index()]) {
+      emitError(loc) << "invalid element type in packLLElements. Expected "
+                     << elementTypes[v.index()] << " but got " << v.value().getType();
+    }
     llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index());
   }
   return llvmStruct;
@@ -179,4 +193,4 @@ TritonGPUToLLVMTypeConverter::convertTritonTensorType(RankedTensorType type) {
   }
 
   return std::nullopt;
-}
\ No newline at end of file
+}
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
index d6859177a11a..8775a3e97da0 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -69,13 +69,15 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Type retType = getTypeConverter()->convertType(op.getType());
     auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
-    assert(value);
-    if (value.getElementType().isInteger(1) && value.isSplat())
-      // Workaround until https://reviews.llvm.org/D133743 is included.
-      value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
-    else
-      // This is a hack. We just want to add encoding
-      value = value.reshape(retType);
+    if (dyn_cast<RankedTensorType>(retType)) {
+      assert(value);
+      if (value.getElementType().isInteger(1) && value.isSplat())
+        // Workaround until https://reviews.llvm.org/D133743 is included.
+        value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
+      else
+        // This is a hack. We just want to add encoding
+        value = value.reshape(retType);
+    }
     addNamedAttrs(
         rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value),
         adaptor.getAttributes());
@@ -483,11 +485,11 @@ struct TritonGenericReducePattern : public OpConversionPattern<triton::GenericReduceOp> {
     auto newReduce = rewriter.create<triton::GenericReduceOp>(
-        op.getLoc(), adaptor.getOperand(), adaptor.getAxis());
+        op.getLoc(), adaptor.getOperands(), adaptor.getAxis());
     addNamedAttrs(newReduce, adaptor.getAttributes());
 
-    auto &newRegion = newReduce.getRegion();
-    rewriter.inlineRegionBefore(op.getRegion(), newRegion, newRegion.end());
+    auto &newCombineOp = newReduce.getCombineOp();
+    rewriter.inlineRegionBefore(op.getCombineOp(), newCombineOp, newCombineOp.end());
     rewriter.replaceOp(op, newReduce.getResult());
     return success();
   }
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index 1882275811fa..c9984db71e60 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -318,44 +318,112 @@ bool mlir::triton::ReduceOp::withIndex(mlir::triton::RedOp redOp) {
 }
 
 //-- GenericReduceOp --
+void GenericReduceOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                            mlir::ValueRange operands, int axis) {
+  SmallVector<Type> inferredReturnTypes;
+  for (unsigned i = 0; i < operands.size(); ++i) {
+    auto argTy = operands[i].getType().cast<RankedTensorType>();
+    auto retEltTy = argTy.getElementType();
+    (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes);
+  }
+
+  GenericReduceOp::build(builder, state, inferredReturnTypes, operands, axis);
+}
+
 mlir::LogicalResult mlir::triton::GenericReduceOp::inferReturnTypes(
     MLIRContext *context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
-  Value arg = operands[0];
-  auto argTy = arg.getType().cast<RankedTensorType>();
-  auto retEltTy = argTy.getElementType();
-  int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
-  return inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes);
+  for (auto arg : operands) {
+    auto argTy = arg.getType().cast<RankedTensorType>();
+    auto retEltTy = argTy.getElementType();
+    int axis = attributes.get("axis").cast<IntegerAttr>().getInt();
+    if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes)
+            .failed()) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+mlir::LogicalResult mlir::triton::GenericReduceOp::verify() {
+  if (this->getOperands().size() < 1) {
+    return this->emitOpError() << "tt.generic_reduce must have at least 1 operand";
+  }
+  for (const auto &operand: this->getOperands()) {
+    if (!dyn_cast<RankedTensorType>(operand.getType())) {
+      return this->emitOpError() << "tt.generic_reduce operands must be RankedTensorType";
+    }
+  }
+  return success();
 }
 
 mlir::LogicalResult mlir::triton::GenericReduceOp::verifyRegions() {
-  auto argTy = getOperand().getType().cast<RankedTensorType>();
-  auto argElemTy = argTy.getElementType();
-
-  constexpr unsigned num_args = 2;
-  auto &block = this->getBody();
-  if (block.getNumArguments() != num_args) {
-    return emitOpError() << "nested block must take " << num_args
-                         << " arguments, but given block with "
-                         << block.getNumArguments() << " arguments";
+  auto argElementTypes = this->getElementTypes();
+  const auto &operands = this->getOperands();
+  const auto numArgs = 2 * operands.size();
+  auto &block = *this->getBody();
+  if (block.getNumArguments() != numArgs) {
+    return this->emitOpError() << "nested block must take " << numArgs
+                               << " arguments, but given block with "
+                               << block.getNumArguments() << " arguments";
   }
 
   unsigned i = 0;
-  for (const auto & blockArgTy: block.getArgumentTypes()) {
+  const auto &blockArgTypes = block.getArgumentTypes();
+  for (unsigned i = 0; i < numArgs; ++i) {
+    const auto &blockArgTy = blockArgTypes[i];
+    const auto &argElemTy = argElementTypes[i % operands.size()];
     if (blockArgTy != argElemTy) {
-      return this->emitOpError() << "types mismatch on reduction block. Expected argument " << i
+      return this->emitOpError() << "type mismatch on combine operation. Expected argument " << i
                                  << " to have type " << argElemTy << " but got " << blockArgTy;
     }
-    ++i;
   }
 
-  if (!mlir::isa<mlir::triton::GenericReduceReturnOp>(block.getTerminator())) {
-    return this->emitOpError("the GenericReduceOp region must be terminated "
-                             "with a GenericReduceReturnOp but got") << block.getTerminator();
+  auto terminator = dyn_cast<GenericReduceReturnOp>(block.getTerminator());
+  if (!terminator) {
+    return this->emitOpError() << "combine operation must be terminated "
+                               << "with a GenericReduceReturnOp but got "
+                               << block.getTerminator();
+  }
+  const auto &combineResults = terminator->getOperands();
+  if (combineResults.size() != operands.size()) {
+    return this->emitOpError() << "expected combine operation to return " << operands.size()
+                               << " values but got " << combineResults.size();
+  }
+  for (unsigned i = 0; i < combineResults.size(); ++i) {
+    const auto &resultTy = combineResults[i].getType();
+    const auto &argElemTy = argElementTypes[i];
+    if (resultTy != argElemTy) {
+      return this->emitOpError() << "type mismatch on combine operation. Expected argument " << i
+                                 << " to have type " << argElemTy << " but got " << resultTy;
+    }
   }
 
   return mlir::success();
 }
 
+llvm::SmallVector<RankedTensorType> GenericReduceOp::getInputTypes() {
+  llvm::SmallVector<RankedTensorType> srcTys;
+  srcTys.reserve(this->getNumOperands());
+  for (const auto &ty: this->getOperands().getTypes()) {
+    srcTys.push_back(ty.cast<RankedTensorType>());
+  }
+  return srcTys;
+}
+
+llvm::SmallVector<Type> GenericReduceOp::getElementTypes() {
+  llvm::SmallVector<Type> srcElemTys;
+  srcElemTys.reserve(this->getNumOperands());
+  for (const auto &op: this->getOperands()) {
+    srcElemTys.push_back(op.getType().cast<RankedTensorType>().getElementType());
+  }
+  return srcElemTys;
+}
+
+unsigned GenericReduceOp::getNumOperands() {
+  return this->getOperands().size();
+}
+
 //-- SplatOp --
 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   auto constOperand = getSrc().getDefiningOp<arith::ConstantOp>();
diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
index 38aab8ab906c..6d48e31cd435 100644
--- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
@@ -80,6 +80,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
 
   // Some ops from SCF are illegal
   addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
               scf::ReduceReturnOp>();
+  // We have custom versions of some arith operators
+  addIllegalOp<arith::CmpIOp, arith::CmpFOp, arith::SelectOp>();
 
   addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 7ae7feefd495..8d1011902645 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -885,6 +885,18 @@ void init_triton_ir(py::module &&m) {
              auto loc = self.getUnknownLoc();
             return self.create(loc, lhs, rhs);
           })
+      .def("create_fmin",
+           [](mlir::OpBuilder &self, mlir::Value &lhs,
+              mlir::Value &rhs) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return self.create<mlir::arith::MinFOp>(loc, lhs, rhs);
+           })
+      .def("create_smin",
+           [](mlir::OpBuilder &self, mlir::Value &lhs,
+              mlir::Value &rhs) -> mlir::Value {
+             auto loc = self.getUnknownLoc();
+             return self.create<mlir::arith::MinSIOp>(loc, lhs, rhs);
+           })
       .def("create_add",
            [](mlir::OpBuilder &self, mlir::Value &lhs,
               mlir::Value &rhs) -> mlir::Value {
@@ -1325,24 +1337,21 @@ void init_triton_ir(py::module &&m) {
                 operand, axis);
           })
      .def("create_generic_reduce",
-           [](mlir::OpBuilder &self, mlir::Value &operand, int axis) -> mlir::triton::GenericReduceOp {
+           [](
+             mlir::OpBuilder &self, std::vector<mlir::Value> operands, int axis
+           ) -> mlir::triton::GenericReduceOp {
             auto loc = self.getUnknownLoc();
-             auto inputTensorType =
-                 operand.getType().dyn_cast<mlir::RankedTensorType>();
-             std::vector<int64_t> shape = inputTensorType.getShape();
-             shape.erase(shape.begin() + axis);
-             mlir::Type resType = inputTensorType.getElementType();
-             if (!shape.empty()) {
-               resType = mlir::RankedTensorType::get(shape, resType);
-             }
-             return self.create<mlir::triton::GenericReduceOp>(
-                 loc, resType, operand, axis);
+             return self.create<mlir::triton::GenericReduceOp>(loc, operands, axis);
           })
      .def("create_reduce_ret",
-           [](mlir::OpBuilder &self, mlir::Value &return_value) -> mlir::OpState {
+           [](mlir::OpBuilder &self, py::args args) -> mlir::OpState {
             auto loc = self.getUnknownLoc();
+             llvm::SmallVector<mlir::Value> return_values;
+             for (const auto & arg: args) {
+               return_values.push_back(py::cast<mlir::Value>(arg));
+             }
             return self.create<mlir::triton::GenericReduceReturnOp>(
-                 loc, return_value);
+                 loc, return_values);
           })
      .def("create_ptr_to_int",
           [](mlir::OpBuilder &self, mlir::Value &val,
diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py
index 52f7f8e3afa3..74bd28e64499 100644
--- a/python/triton/language/__init__.py
+++ b/python/triton/language/__init__.py
@@ -10,6 +10,7 @@
     abs,
     arange,
     argmin,
+    argmin2,
     argmax,
     atomic_add,
     atomic_and,
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index acd0d17107c5..9b50a42ef1e0 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -1165,11 +1165,27 @@ def xor_sum(input, axis, _builder=None):
 
 @builtin
 @_add_reduction_docstr("prod")
-def prod(input, axis, _builder):
+def prod(input, axis, _builder=None):
     axis = _constexpr_to_value(axis)
     return semantic.prod(input, axis, _builder)
 
 
+@builtin
+@_add_reduction_docstr("argmin2")
+def argmin2(input, axis, _builder=None):
+    axis = _constexpr_to_value(axis)
+    n = input.shape[axis]
+    index = arange(0, n, _builder=_builder)
+    new_shape = [constexpr(1)] * len(input.shape)
+    new_shape[axis] = constexpr(n)
+    index = view(index, new_shape, _builder=_builder)
+    index = broadcast_to(index, input.shape, _builder=_builder)
+
+    values, indices = semantic.min_with_index(input, index, axis, _builder)
+    return indices
+
+
 # -----------------------
 # Internal for debugging
 # -----------------------
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index b9fb4ddd2ad9..2810062e21e6 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1181,23 +1181,32 @@ def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
         raise ValueError("xor_sum only supported for integers")
     return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR)
 
-def reduction(input: tl.tensor, axis: int, region_builder_fn, builder: ir.builder) -> tl.tensor:
-    scalar_ty = input.type.scalar
-
-    # get result type
-    shape = input.type.shape
+def reduction(
+    inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder
+) -> Tuple[tl.tensor, ...]:
+    # get result shape
+    shape = inputs[0].type.shape
     ret_shape = [s for i, s in enumerate(shape) if i != axis]
-    if ret_shape:
-        res_ty = tl.block_type(scalar_ty, ret_shape)
-    else:
-        # 0d-tensor -> scalar
-        res_ty = out_scalar_ty
+    for t in inputs:
+        assert t.type.shape == shape
+
+    def wrap_tensor(x, scalar_ty):
+        if ret_shape:
+            res_ty = tl.block_type(scalar_ty, ret_shape)
+        else:
+            # 0d-tensor -> scalar
+            res_ty = scalar_ty
+        return tl.tensor(x, res_ty)
 
-    reduce_op = builder.create_generic_reduce(input.handle, axis)
+    reduce_op = builder.create_generic_reduce([t.handle for t in inputs], axis)
     region_builder_fn(reduce_op)
     reduce_op.verify()
 
-    return tl.tensor(reduce_op.get_result(0), res_ty)
+    return tuple(
+        wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar)
+        for i in range(len(inputs))
+    )
 
 @contextmanager
 def insertion_guard(builder):
@@ -1215,7 +1224,27 @@ def make_mul(reduce_op):
             fmul = builder.create_fmul(block.arg(0), block.arg(1))
             builder.create_reduce_ret(fmul)
 
-    return reduction(input, axis, make_mul, builder)
+    return reduction((input,), axis, make_mul, builder)[0]
+
+def min_with_index(keys: tl.tensor, values: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+
+    def make_min_with_index_combine(reduce_op):
+        ir_key_ty = keys.type.scalar.to_ir(builder)
+        ir_value_ty = values.type.scalar.to_ir(builder)
+        region = reduce_op.get_region(0)
+        with insertion_guard(builder):
+            block = builder.create_block_with_parent(region, [ir_key_ty, ir_value_ty] * 2)
+            value1, index1, value2, index2 = [block.arg(i) for i in range(4)]
+            lt = builder.create_fcmpOLT(value1, value2)
+            gt = builder.create_fcmpOGT(value1, value2)
+            index_min = builder.create_smin(index1, index2)
+            index_ret = builder.create_select(
+                lt, index1, builder.create_select(gt, index2, index_min))
+
+            value_min = builder.create_fmin(value1, value2)
+            builder.create_reduce_ret(value_min, index_ret)
+
+    return reduction((keys, values), axis, make_min_with_index_combine, builder)
 
 # ===----------------------------------------------------------------------===
 # Math
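For reference, below is a minimal kernel-level usage sketch of the new argmin2 builtin introduced by this patch. The kernel body, launch parameters, and tensor sizes are illustrative and not part of the diff; it assumes tl.argmin2 and the variadic tt.generic_reduce lowering behave as implemented above.

import torch
import triton
import triton.language as tl

@triton.jit
def argmin_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # Load one block of values; out-of-range lanes get +inf so they never win.
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs, mask=offs < n, other=float("inf"))
    # argmin2 pairs each value with its index and reduces both together with a
    # single tt.generic_reduce over (value, index), returning the index of the
    # minimum value along the given axis.
    idx = tl.argmin2(x, 0)
    tl.store(out_ptr, idx)

x = torch.randn(1024, device="cuda")
out = torch.empty((), dtype=torch.int32, device="cuda")
argmin_kernel[(1,)](x, out, x.numel(), BLOCK=1024)
assert out.item() == torch.argmin(x).item()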