Skip to content

Commit

Permalink
Lower tt.generic_reduce to LLVM IR
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 committed Mar 8, 2023
1 parent ef844eb commit f6d4247
Show file tree
Hide file tree
Showing 9 changed files with 422 additions and 12 deletions.
23 changes: 20 additions & 3 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,24 @@
namespace mlir {

class ReduceOpHelper {
// Private delegate shared by the public constructors: records the operation,
// the reduction axis, and whether an index tensor accompanies the values,
// then caches the ranked tensor type of the first operand.
ReduceOpHelper(Operation *op, int axis, bool withIndex)
    : op(op), axis(axis), withIndex(withIndex) {
  auto firstOperand = op->getOperands().front();
  srcTy = firstOperand.getType().cast<RankedTensorType>();
}

public:
// Helper for the legacy tt.reduce op. Delegates to the private constructor;
// whether an index tensor is carried alongside the values is determined by
// the reduction kind, as reported by triton::ReduceOp::withIndex.
//
// NOTE(review): the flattened diff had left the removed pre-refactor
// constructor (which initialized `op`/`srcTy` directly) interleaved here,
// producing a second, unterminated definition; only the delegating
// constructor is the intended post-commit code.
explicit ReduceOpHelper(triton::ReduceOp op):
    ReduceOpHelper(
        op.getOperation(),
        op.getAxis(),
        triton::ReduceOp::withIndex(op.getRedOp())) {
}

// Helper for tt.generic_reduce: generic reductions never carry an index
// tensor, so withIndex is always false.
explicit ReduceOpHelper(triton::GenericReduceOp op)
    : ReduceOpHelper(op.getOperation(), op.getAxis(),
                     /*withIndex*/false) {}

// Shape of the source tensor being reduced (first operand's type).
ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
Expand All @@ -35,8 +50,10 @@ class ReduceOpHelper {
unsigned getScratchSizeInBytes();

private:
  // The reduction op under analysis. Stored type-erased as Operation* so the
  // helper serves both tt.reduce and tt.generic_reduce (see the two public
  // constructors). The removed-side member `triton::ReduceOp op;` from the
  // flattened diff duplicated this field and is dropped here.
  Operation *op;
  // Ranked tensor type of the first operand, cached by the constructor.
  RankedTensorType srcTy{};
  // Dimension along which the reduction runs.
  int axis;
  // Whether an index tensor accompanies the values (tt.reduce argmin/argmax
  // style reductions only; always false for tt.generic_reduce).
  bool withIndex;
};

bool isSharedEncoding(Value value);
Expand Down
4 changes: 4 additions & 0 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ class AllocationAnalysis {
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto reduceOp = dyn_cast<triton::GenericReduceOp>(op)) {
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.getSrc().getType().cast<RankedTensorType>();
auto dstTy = cvtLayout.getResult().getType().cast<RankedTensorType>();
Expand Down
3 changes: 3 additions & 0 deletions lib/Analysis/Membar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ void MembarAnalysis::visitTerminator(Operation *op,
}
return;
}
if (isa<triton::GenericReduceReturnOp>(op)) {
return;
}
// Otherwise, it could be a return op
assert(isa<func::ReturnOp>(op) && "Unknown terminator");
}
Expand Down
12 changes: 3 additions & 9 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,12 @@ namespace mlir {

bool ReduceOpHelper::isFastReduction() {
auto srcLayout = srcTy.getEncoding();
auto axis = op.getAxis();
return axis == triton::gpu::getOrder(srcLayout)[0];
}

unsigned ReduceOpHelper::getInterWarpSize() {
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto axis = op.getAxis();
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSize();
return std::min(srcReduceDimSize / sizeIntraWarps,
Expand All @@ -28,28 +26,24 @@ unsigned ReduceOpHelper::getInterWarpSize() {
unsigned ReduceOpHelper::getIntraWarpSize() {
auto srcLayout = srcTy.getEncoding();
auto srcShape = srcTy.getShape();
auto axis = op.getAxis();
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
triton::gpu::getThreadsPerWarp(srcLayout)[axis]);
}

unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto srcLayout = srcTy.getEncoding();
auto axis = op.getAxis();
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
}

// Shared-memory shape for the basic (non-fast) reduction: mirrors the source
// tensor shape, except the reduced axis is clamped to the number of threads
// that cooperate on it.
// Fix: drop the stale removed-side line `auto axis = op.getAxis();` — `op`
// is now an `Operation *` (no getAxis()), and `axis` is a member.
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
  auto smemShape = convertType<unsigned>(getSrcShape());
  smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
  return smemShape;
}

SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
auto axis = op.getAxis();
SmallVector<SmallVector<unsigned>> smemShapes(3);

auto argLayout = srcTy.getEncoding();
Expand All @@ -64,7 +58,7 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {

/// FIXME(Qingyi): This size is actually larger than required.
/// shared memory block1:
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
auto mod = op->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
smemShapes[1].push_back(numWarps * 32);

Expand All @@ -82,10 +76,10 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {
elems = product<unsigned>(smemShape);
}

auto tensorType = op.getOperand().getType().cast<RankedTensorType>();
auto tensorType = op->getOperand(0).getType().cast<RankedTensorType>();
unsigned bytes = elems * tensorType.getElementTypeBitWidth() / 8;

if (triton::ReduceOp::withIndex(op.getRedOp()))
if (withIndex)
bytes += elems * sizeof(int32_t);

return bytes;
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
TritonGPUToLLVMPass.cpp
PTXAsmFormat.cpp
ReduceOpToLLVM.cpp
GenericReduceOpToLLVM.cpp
Utility.cpp
TypeConverter.cpp
ViewOpToLLVM.cpp
Expand Down
Loading

0 comments on commit f6d4247

Please sign in to comment.