[OPTIMIZER] Added kWidth attribute to DotOperandEncoding (triton-lang…#1584)

This is a prerequisite for efficient mixed-precision matmul.

ptillet authored Apr 27, 2023
1 parent 517144d commit b4437fe
Showing 19 changed files with 162 additions and 74 deletions.
14 changes: 13 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -501,10 +501,22 @@ section 9.7.13.4.1 for more details.
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
"Attribute":$parent,
"unsigned":$MMAv2kWidth
);

let builders = [
// Special case for MMAv1 (Volta)
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"Type":$eltTy), [{
MmaEncodingAttr parentAttr = parent.dyn_cast<MmaEncodingAttr>();
if (!parentAttr || !parentAttr.isAmpere())
return $_get(context, opIdx, parent, 0);
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
unsigned MMAv2kWidth = 32 / bitwidth;
return $_get(context, opIdx, parent, MMAv2kWidth);
}]>
];

let hasCustomAssemblyFormat = 1;
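For reference, the rule this new builder encodes can be restated standalone (a plain C++ sketch, not code from the commit): an MMAv2 `mma.sync` operand packs 32 bits of the K dimension per register, so kWidth is the number of elements per 32-bit register, and Volta (MMAv1) or non-MMA parents get 0.

#include <cassert>

// Sketch of the builder's kWidth rule (assumed semantics; mmaV2KWidth is a
// hypothetical helper, not a Triton API).
unsigned mmaV2KWidth(unsigned elementBitWidth, bool parentIsAmpereMma) {
  if (!parentIsAmpereMma)
    return 0; // MMAv1 (Volta) and non-MMA parents carry no kWidth
  assert(elementBitWidth != 0 && 32 % elementBitWidth == 0);
  return 32 / elementBitWidth; // fp32 -> 1, fp16/bf16 -> 2, int8/fp8 -> 4
}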
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -325,7 +325,8 @@ struct ConvertLayoutOpConversion

if (needTrans) {
// do transpose
auto aEncoding = DotOperandEncodingAttr::get(mma.getContext(), 0, mma);
auto aEncoding =
DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0);
int numM = aEncoding.getMMAv1NumOuter(shape);
int numN = accumSizePerThread / numM;

@@ -358,11 +358,11 @@ SmallVector<CoordTy> getMNCoords(Value thread,
Value _fpw1 = i32_val(fpw[1]);

// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

8 changes: 4 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -714,11 +714,11 @@ class ConvertTritonGPUOpToLLVMPatternBase {
Value _fpw1 = i32_val(fpw[1]);

// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

@@ -775,12 +775,12 @@ class ConvertTritonGPUOpToLLVMPatternBase {
// TODO: seems like the pattern below to get `rep`/`spw` appears quite often
// A info
auto aEncoding =
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout);
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding =
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout);
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

18 changes: 10 additions & 8 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -268,6 +268,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
// a & b must be of smem layout
auto aType = adaptor.getA().getType().cast<RankedTensorType>();
auto bType = adaptor.getB().getType().cast<RankedTensorType>();
Type aEltType = aType.getElementType();
Type bEltType = bType.getElementType();
Attribute aEncoding = aType.getEncoding();
Attribute bEncoding = bType.getEncoding();
if (!aEncoding || !bEncoding)
@@ -276,17 +278,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
Value b = adaptor.getB();
Value c = adaptor.getC();
if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding);
auto dstType = RankedTensorType::get(aType.getShape(),
aType.getElementType(), encoding);
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
getContext(), 0, dEncoding, aEltType);
auto dstType =
RankedTensorType::get(aType.getShape(), aEltType, encoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
}
if (!bEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding);
auto dstType = RankedTensorType::get(bType.getShape(),
bType.getElementType(), encoding);
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
getContext(), 1, dEncoding, bEltType);
auto dstType =
RankedTensorType::get(bType.getShape(), bEltType, encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
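Note the call shape above: TritonDotPattern now passes each operand's element type, i.e. the new type-based builder, so A and B derive kWidth from their own widths. A minimal model of that behavior (plain C++ sketch; the struct and function are illustrative, not the MLIR API):

#include <cstdio>

struct DotOperandEnc { unsigned opIdx, kWidth; };

// Models get(ctx, opIdx, parent, eltTy) for an Ampere MMA parent.
DotOperandEnc makeDotOperandEnc(unsigned opIdx, unsigned eltBits) {
  return {opIdx, 32 / eltBits};
}

int main() {
  // Hypothetical mixed-precision dot: fp16 A operand, int8 B operand.
  DotOperandEnc a = makeDotOperandEnc(0, 16); // kWidth = 2
  DotOperandEnc b = makeDotOperandEnc(1, 8);  // kWidth = 4
  std::printf("A kWidth = %u, B kWidth = %u\n", a.kWidth, b.kWidth);
}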
19 changes: 16 additions & 3 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -774,14 +774,27 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
auto mmaParent = parent.dyn_cast<MmaEncodingAttr>();
unsigned kWidth = 0;
Attribute _kWidth = attrs.get("kWidth");
if (_kWidth) {
if (!mmaParent || mmaParent.isVolta()) {
auto loc = parser.getNameLoc();
parser.emitError(loc, "kWidth only supported for MMAv2+ parent");
return Attribute();
}
kWidth = _kWidth.cast<IntegerAttr>().getInt();
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
parent, kWidth);
}

void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>();
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent();
<< "opIdx = " << getOpIdx() << ", parent = " << getParent();
if (mmaParent && mmaParent.isAmpere())
printer << ", kWidth = " << getMMAv2kWidth();
printer << "}>";
}

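Net effect of the parse/print changes: kWidth is accepted and printed only when the parent is an MMAv2 (Ampere) layout, so an attribute such as #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}> round-trips, while kWidth on a Volta parent is rejected at parse time. A compact model of the print-side gating (plain C++ sketch, not the MLIR printer):

#include <cstdio>

// Mirrors the logic of DotOperandEncodingAttr::print above (hypothetical helper).
void printDotOpAttr(unsigned opIdx, const char *parent, bool parentIsAmpere,
                    unsigned kWidth) {
  std::printf("#triton_gpu.dot_op<{opIdx = %u, parent = %s", opIdx, parent);
  if (parentIsAmpere)
    std::printf(", kWidth = %u", kWidth); // omitted for Volta/non-MMA parents
  std::printf("}>\n");
}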
15 changes: 9 additions & 6 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -170,14 +170,17 @@ class BlockedToMMA : public mlir::RewritePattern {
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();

auto newAEncoding = triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(),
oldAType.getElementType());
auto newBEncoding = triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(),
oldBType.getElementType());

auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding()));
oldAType.getShape(), oldAType.getElementType(), newAEncoding);
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding()));
oldBType.getShape(), oldBType.getElementType(), newBEncoding);

a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
84 changes: 62 additions & 22 deletions lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -1,3 +1,4 @@
#include "Utility.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
@@ -42,6 +43,9 @@ class LoopPipeliner {

/// Loads to be pipelined
SetVector<Value> loads;
/// Smallest data-type for each load (used to optimize swizzle and
/// create DotOpEncoding layout)
DenseMap<Value, Type> loadsSmallestType;
/// The value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
/// load => buffer
@@ -256,33 +260,62 @@ LogicalResult LoopPipeliner::initialize() {
use = *use->getResult(0).getUsers().begin();
}

if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult()
.getType()
.dyn_cast<RankedTensorType>()) {
if (auto dotOpEnc = tensorType.getEncoding()
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
isCandidate = true;
loadsMapping[loadOp] = convertLayout;
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()), ty.getElementType());
loadsBufferType[loadOp] = RankedTensorType::get(
bufferShape, ty.getElementType(), sharedEnc);
}
}
}
} else
auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use);
if (!convertLayout)
continue;
auto tensorType =
convertLayout.getResult().getType().dyn_cast<RankedTensorType>();
if (!tensorType)
continue;
auto dotOpEnc =
tensorType.getEncoding().dyn_cast<ttg::DotOperandEncodingAttr>();
if (!dotOpEnc)
continue;
isCandidate = true;
loadsMapping[loadOp] = convertLayout;
}

else
isCandidate = false;

if (isCandidate)
loads.insert(loadOp);
}

// we need to find the smallest common dtype
// since this determines the layout of `mma.sync` operands
// in mixed-precision mode
Type smallestType;
for (auto loadCvt : loadsMapping) {
auto loadOp = loadCvt.first;
auto ty = loadOp.getType().cast<RankedTensorType>();
Type eltTy = ty.getElementType();
if (!smallestType ||
(eltTy.getIntOrFloatBitWidth() < smallestType.getIntOrFloatBitWidth()))
smallestType = eltTy;
}

for (auto loadCvt : loadsMapping)
loadsSmallestType[loadCvt.first] = smallestType;

for (auto loadCvt : loadsMapping) {
auto loadOp = loadCvt.first;
Value cvt = loadCvt.second;
auto dotOpEnc = cvt.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<ttg::DotOperandEncodingAttr>();
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()), loadsSmallestType[loadOp]);
loadsBufferType[loadOp] =
RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc);
}

// We have some loads to pipeline
if (!loads.empty()) {
// Update depArgs & depOps
@@ -551,8 +584,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
}
// we replace the new load's use with a convert layout
size_t i = std::distance(loads.begin(), it);
auto cvtDstTy = op.getResult(0).getType().cast<RankedTensorType>();
auto cvtDstEnc = cvtDstTy.getEncoding().cast<ttg::DotOperandEncodingAttr>();
auto newDstTy = RankedTensorType::get(
cvtDstTy.getShape(), cvtDstTy.getElementType(),
ttg::DotOperandEncodingAttr::get(
cvtDstEnc.getContext(), cvtDstEnc.getOpIdx(), cvtDstEnc.getParent(),
loadsSmallestType[op.getOperand(0)]));
auto cvt = builder.create<ttg::ConvertLayoutOp>(
op.getLoc(), op.getResult(0).getType(),
op.getResult(0).getLoc(), newDstTy,
newForOp.getRegionIterArgs()[loadIdx + i]);
mapping.map(op.getResult(0), cvt.getResult());
}
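The pipeliner now sizes the swizzle pattern and the rebuilt DotOperandEncoding by the narrowest element type among all pipelined loads, since that type determines the mma.sync operand layout in mixed-precision mode. A plain C++ sketch of the selection (illustrative helper, not the pass code):

#include <algorithm>
#include <cassert>
#include <vector>

// Picks the smallest element bit-width across all pipelined loads.
unsigned smallestLoadBitWidth(const std::vector<unsigned> &loadEltBitWidths) {
  assert(!loadEltBitWidths.empty());
  return *std::min_element(loadEltBitWidths.begin(), loadEltBitWidths.end());
}
// E.g. a hypothetical fp16 x int8 dot yields widths {16, 8} -> 8, so both
// shared-memory buffers are swizzled for the 8-bit operand layout.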
20 changes: 15 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -110,7 +110,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});

auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding);
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
Value prefetchSlice = builder.create<triton::gpu::ConvertLayoutOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
@@ -156,12 +156,22 @@ LogicalResult Prefetcher::initialize() {
};

for (triton::DotOp dot : dotsInFor) {
auto kSize = dot.getA().getType().cast<RankedTensorType>().getShape()[1];
auto aType = dot.getA().getType().cast<RankedTensorType>();
auto bType = dot.getB().getType().cast<RankedTensorType>();
auto aEnc = aType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto bEnc = bType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
int aKWidth = aEnc.getMMAv2kWidth();
int bKWidth = bEnc.getMMAv2kWidth();
assert(aKWidth == bKWidth);

auto kSize = aType.getShape()[1];

// works better with nvidia tensor cores
unsigned elementWidth =
dot.getA().getType().cast<RankedTensorType>().getElementTypeBitWidth();
prefetchWidth = 256 / elementWidth;
unsigned elementWidth = aType.getElementTypeBitWidth();
if (aKWidth == 0)
prefetchWidth = 256 / elementWidth;
else
prefetchWidth = 8 * aKWidth;

// Skip prefetching if kSize is less than prefetchWidth
if (kSize < prefetchWidth)
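The two prefetch-width branches agree whenever kWidth was derived from the operand's own element type (kWidth = 32/bitwidth makes 8 * kWidth equal 256/bitwidth); they diverge in mixed precision, where kWidth may follow the smallest type in the dot, as in the pipeliner above. A worked sketch (plain C++, assumed semantics):

// Models the prefetch-width rule above (hypothetical helper).
unsigned prefetchWidthFor(unsigned elementBitWidth, unsigned kWidth) {
  return kWidth == 0 ? 256 / elementBitWidth // legacy: 256 bits per operand
                     : 8 * kWidth;           // 8 registers of kWidth elements
}
// fp16 with kWidth = 2 -> 16 elements, same as the legacy 256/16.
// fp16 operand of a hypothetical fp16 x int8 dot (kWidth = 4) -> 32 elements.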
5 changes: 4 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -341,7 +341,10 @@ class RematerializeForward : public mlir::RewritePattern {
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
auto dstEncoding =
cvt.getResult().getType().cast<RankedTensorType>().getEncoding();
// XXX: why is this needed?
if (srcEncoding.isa<triton::gpu::SharedEncodingAttr>() ||
dstEncoding.isa<triton::gpu::SharedEncodingAttr>())
return failure();
// heuristics for flash attention
if (srcEncoding.isa<triton::gpu::SliceEncodingAttr>())
return failure();
SetVector<Operation *> cvtSlices;
10 changes: 7 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -206,11 +206,15 @@ int simulateBackwardRematerialization(

//

Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping) {
Operation *newOp = rewriter.clone(*op, mapping);
auto origType = op->getResult(0).getType().cast<RankedTensorType>();
auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
if (newOp->getNumResults() == 0)
return newOp;
auto origType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
auto argType = newOp->getOperand(0).getType().dyn_cast<RankedTensorType>();
if (!origType || !argType)
return newOp;
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(), argType.getEncoding());
newOp->getResult(0).setType(newType);
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/Utility.h
@@ -21,7 +21,7 @@ int simulateBackwardRematerialization(
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
Attribute targetEncoding);

Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping);

void rematerializeConversionChain(
4 changes: 2 additions & 2 deletions test/Analysis/test-alias.mlir
@@ -6,8 +6,8 @@
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>

// CHECK-LABEL: matmul_loop
// There shouldn't be any aliasing with the dot op encoding.
4 changes: 2 additions & 2 deletions test/Analysis/test-allocation.mlir
@@ -7,8 +7,8 @@
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>

module attributes {"triton_gpu.num-warps" = 4 : i32} {

4 changes: 2 additions & 2 deletions test/Analysis/test-membar.mlir
@@ -7,8 +7,8 @@
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>

module attributes {"triton_gpu.num-warps" = 4 : i32} {
