Skip to content

Commit

Permalink
[BACKEND] Updated slice layout semantics, updated vectorization logic used for load/store ops. (triton-lang#1587)
Browse files Browse the repository at this point in the history
  • Loading branch information
zahimoud authored and zhanglx13 committed Jun 5, 2023
1 parent 5c943d8 commit e031b27
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 39 deletions.
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ SmallVector<unsigned> getSizePerThread(Attribute layout);

SmallVector<unsigned> getContigPerThread(Attribute layout);

SmallVector<unsigned> getUniqueContigPerThread(Type type);

SmallVector<unsigned> getThreadsPerCTA(Attribute layout);

SmallVector<unsigned>
Expand Down
10 changes: 6 additions & 4 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,11 +921,13 @@ unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) {
auto order = triton::gpu::getOrder(layout);
unsigned align = getPtrAlignment(ptr);

unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
contigPerThread = std::min(align, contigPerThread);
contigPerThread = std::min<unsigned>(shape[order[0]], contigPerThread);
auto uniqueContigPerThread = triton::gpu::getUniqueContigPerThread(tensorTy);
assert(order[0] < uniqueContigPerThread.size() &&
"Unxpected uniqueContigPerThread size");
unsigned contiguity = uniqueContigPerThread[order[0]];
contiguity = std::min(align, contiguity);

return contigPerThread;
return contiguity;
}

unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
Expand Down
12 changes: 8 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,17 @@ struct ConvertLayoutOpConversion
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parentEncoding = sliceLayout.getParent();
auto parentSizePerThread = getSizePerThread(parentEncoding);
unsigned stride = 1;
if (getOrder(parentEncoding)[0] == dim)
stride = parentSizePerThread[dim];
auto parentShape = sliceLayout.paddedShape(shape);
auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
parentEncoding);
auto multiDimOffsetParent =
getMultiDimOffset(parentEncoding, loc, rewriter, elemId, parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
auto multiDimOffsetParent = getMultiDimOffset(
parentEncoding, loc, rewriter, elemId * stride, parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d == dim)
Expand Down
42 changes: 28 additions & 14 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "Utility.h"
#include "mlir/IR/TypeUtilities.h"
#include "triton/Analysis/AxisInfo.h"

#include <set>
using namespace mlir;
using namespace mlir::triton;

Expand Down Expand Up @@ -633,6 +633,13 @@ class ConvertTritonGPUOpToLLVMPatternBase {
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, type);
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy = RankedTensorType::get(
parentShape, type.getElementType(), parentLayout);
result = emitBaseIndexForLayout(loc, rewriter, parentLayout, parentTy);
result.erase(result.begin() + sliceLayout.getDim());
} else {
llvm_unreachable("unsupported emitBaseIndexForLayout");
}
Expand All @@ -654,6 +661,8 @@ class ConvertTritonGPUOpToLLVMPatternBase {
if (mmaLayout.isAmpere())
return emitOffsetForMmaLayoutV2(mmaLayout, type);
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>())
return emitOffsetForSliceLayout(sliceLayout, type);
llvm_unreachable("unsupported emitOffsetForLayout");
}

Expand Down Expand Up @@ -681,7 +690,7 @@ class ConvertTritonGPUOpToLLVMPatternBase {
} else if (auto mfma = layout.dyn_cast<MfmaEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, mfma, type);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
result = emitIndicesForSliceLayout(loc, b, slice, type);
result = emitIndicesForDistributedLayout(loc, b, slice, type);
} else {
llvm_unreachable(
"emitIndices for layouts other than blocked & slice not "
Expand Down Expand Up @@ -1100,24 +1109,29 @@ class ConvertTritonGPUOpToLLVMPatternBase {
return multiDimIdx;
}

// Computes the per-thread logical offsets for a tensor with a slice layout.
// Offsets are derived from the parent layout: each parent offset has the
// sliced dimension removed, and duplicates (parent elements that collapse
// onto the same sliced element) are dropped while preserving first-seen
// order, so the result matches the slice layout's unique elements.
SmallVector<SmallVector<unsigned>>
emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout,
                         RankedTensorType type) const {
  auto parentEncoding = sliceLayout.getParent();
  unsigned dim = sliceLayout.getDim();
  // Reconstruct the parent tensor type by re-inserting the sliced dim.
  auto parentShape = sliceLayout.paddedShape(type.getShape());
  RankedTensorType parentTy = RankedTensorType::get(
      parentShape, type.getElementType(), parentEncoding);
  auto parentOffsets = emitOffsetForLayout(parentEncoding, parentTy);

  unsigned numOffsets = parentOffsets.size();
  SmallVector<SmallVector<unsigned>> resultOffsets;
  std::set<SmallVector<unsigned>> uniqueOffsets;

  for (unsigned i = 0; i < numOffsets; ++i) {
    SmallVector<unsigned> offsets = parentOffsets[i];
    offsets.erase(offsets.begin() + dim);
    // Keep only the first occurrence of each collapsed offset.
    if (uniqueOffsets.find(offsets) == uniqueOffsets.end()) {
      resultOffsets.push_back(offsets);
      uniqueOffsets.insert(offsets);
    }
  }
  return resultOffsets;
}

protected:
Expand Down
64 changes: 52 additions & 12 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,64 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
}
};

template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
using OpAdaptor = typename SourceOp::Adaptor;
explicit ViewLikeOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern<ViewOp> {
using OpAdaptor = typename ViewOp::Adaptor;
explicit ViewOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<ViewOp>(typeConverter, benefit) {}

LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
matchAndRewrite(ViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
auto vals = this->getTypeConverter()->unpackLLElements(
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
Value view =
Value ret =
this->getTypeConverter()->packLLElements(loc, vals, rewriter, resultTy);
rewriter.replaceOp(op, view);
rewriter.replaceOp(op, ret);
return success();
}
};

// Lowers triton::ExpandDimsOp: the source tensor must carry a slice
// encoding, and each result element is located by matching its offset
// (with the expanded dimension removed) against the source offsets.
struct ExpandDimsOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp> {
  using OpAdaptor = typename ExpandDimsOp::Adaptor;
  explicit ExpandDimsOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
                                  PatternBenefit benefit = 1)
      : ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp>(typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(ExpandDimsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto srcElems = this->getTypeConverter()->unpackLLElements(
        loc, adaptor.getSrc(), rewriter, op.getOperand().getType());

    auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
    auto resultTy = op.getType().template cast<RankedTensorType>();

    assert(srcTy.getEncoding().isa<SliceEncodingAttr>() &&
           "ExpandDimsOp only support SliceEncodingAttr");
    auto srcLayout = srcTy.getEncoding().dyn_cast<SliceEncodingAttr>();
    auto resultLayout = resultTy.getEncoding();

    // Index the unpacked source elements by their logical offsets.
    auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
    auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
    DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> offsetToElem;
    size_t idx = 0;
    for (const auto &off : srcOffsets)
      offsetToElem[off] = srcElems[idx++];

    // For every result element, drop the expanded dimension from its
    // offset and pick the matching source element.
    SmallVector<Value> resultElems;
    for (auto off : resultOffsets) {
      off.erase(off.begin() + srcLayout.getDim());
      resultElems.push_back(offsetToElem.lookup(off));
    }
    Value ret = this->getTypeConverter()->packLLElements(loc, resultElems,
                                                         rewriter, resultTy);
    rewriter.replaceOp(op, ret);
    return success();
  }
};
Expand Down Expand Up @@ -168,9 +209,8 @@ void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
PatternBenefit benefit) {
patterns.add<ViewLikeOpConversion<triton::ViewOp>>(typeConverter, benefit);
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
benefit);
patterns.add<ViewOpConversion>(typeConverter, benefit);
patterns.add<ExpandDimsOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<CatOpConversion>(typeConverter, benefit);
Expand Down
40 changes: 36 additions & 4 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,9 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
blockedLayout.getSizePerThread().end());
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto ret = getSizePerThread(sliceLayout.getParent());
return ret;
// ret.erase(ret.begin() + sliceLayout.getDim());
return ret;
auto sizePerThread = getSizePerThread(sliceLayout.getParent());
sizePerThread.erase(sizePerThread.begin() + sliceLayout.getDim());
return sizePerThread;
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isAmpere()) {
return {2, 2};
Expand Down Expand Up @@ -158,11 +157,43 @@ SmallVector<unsigned> getContigPerThread(Attribute layout) {
return {1, 2};
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
return {4, 1};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
return getContigPerThread(parentLayout);
} else {
return getSizePerThread(layout);
}
}

// Returns, per dimension, the number of *unique* contiguous elements a
// single thread holds for the given type.
SmallVector<unsigned> getUniqueContigPerThread(Type type) {
  // Scalars and pointers behave like a single-element rank-1 tensor.
  if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
    return SmallVector<unsigned>(1, 1);
  auto tensorTy = type.cast<RankedTensorType>();
  auto tensorShape = tensorTy.getShape();
  auto encoding = tensorTy.getEncoding();
  // Slice layouts: recurse on the padded parent type, then drop the
  // sliced dimension from the result.
  if (auto sliceEnc = encoding.dyn_cast<SliceEncodingAttr>()) {
    auto paddedTy =
        RankedTensorType::get(sliceEnc.paddedShape(tensorShape),
                              tensorTy.getElementType(), sliceEnc.getParent());
    auto contig = getUniqueContigPerThread(paddedTy);
    contig.erase(contig.begin() + sliceEnc.getDim());
    return contig;
  }
  // Base case: clamp each dimension's per-thread contiguity to the
  // tensor extent in that dimension.
  auto perThread = getContigPerThread(encoding);
  auto rank = tensorShape.size();
  assert(perThread.size() == rank && "Unexpected contigPerThread size");
  SmallVector<unsigned> result;
  result.reserve(rank);
  for (size_t d = 0; d < rank; ++d)
    result.push_back(std::min<unsigned>(tensorShape[d], perThread[d]));
  return result;
}

SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
SmallVector<unsigned> threads;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
Expand Down Expand Up @@ -395,6 +426,7 @@ SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
auto parent = getParent();
auto parentElemsPerThread =
::getElemsPerThread(parent, paddedShape(shape), eltTy);
parentElemsPerThread.erase(parentElemsPerThread.begin() + getDim());
return parentElemsPerThread;
}
unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice1
tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
// CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return
}
Expand Down

0 comments on commit e031b27

Please sign in to comment.