Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OPTIMIZER] Added kWidth attribute to DotOperandEncoding #1584

Merged
merged 14 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -501,10 +501,22 @@ section 9.7.13.4.1 for more details.
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
"Attribute":$parent,
"unsigned":$MMAv2kWidth
);

let builders = [
// Specially for MMAV1(Volta)
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"Type":$eltTy), [{
MmaEncodingAttr parentAttr = parent.dyn_cast<MmaEncodingAttr>();
if (!parentAttr || !parentAttr.isAmpere())
return $_get(context, opIdx, parent, 0);
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
unsigned MMAv2kWidth = 32 / bitwidth;
return $_get(context, opIdx, parent, MMAv2kWidth);
}]>
];

let hasCustomAssemblyFormat = 1;
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ struct ConvertLayoutOpConversion

if (needTrans) {
// do transpose
auto aEncoding = DotOperandEncodingAttr::get(mma.getContext(), 0, mma);
auto aEncoding =
DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0);
int numM = aEncoding.getMMAv1NumOuter(shape);
int numN = accumSizePerThread / numM;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,11 @@ SmallVector<CoordTy> getMNCoords(Value thread,
Value _fpw1 = i32_val(fpw[1]);

// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

Expand Down
8 changes: 4 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -714,11 +714,11 @@ class ConvertTritonGPUOpToLLVMPatternBase {
Value _fpw1 = i32_val(fpw[1]);

// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

Expand Down Expand Up @@ -775,12 +775,12 @@ class ConvertTritonGPUOpToLLVMPatternBase {
// TODO: seems like the pattern below to get `rep`/`spw` appears quite often
// A info
auto aEncoding =
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout);
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding =
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout);
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

Expand Down
18 changes: 10 additions & 8 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
// a & b must be of smem layout
auto aType = adaptor.getA().getType().cast<RankedTensorType>();
auto bType = adaptor.getB().getType().cast<RankedTensorType>();
Type aEltType = aType.getElementType();
Type bEltType = bType.getElementType();
Attribute aEncoding = aType.getEncoding();
Attribute bEncoding = bType.getEncoding();
if (!aEncoding || !bEncoding)
Expand All @@ -276,17 +278,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
Value b = adaptor.getB();
Value c = adaptor.getC();
if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding);
auto dstType = RankedTensorType::get(aType.getShape(),
aType.getElementType(), encoding);
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
getContext(), 0, dEncoding, aEltType);
auto dstType =
RankedTensorType::get(aType.getShape(), aEltType, encoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
}
if (!bEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding);
auto dstType = RankedTensorType::get(bType.getShape(),
bType.getElementType(), encoding);
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
getContext(), 1, dEncoding, bEltType);
auto dstType =
RankedTensorType::get(bType.getShape(), bEltType, encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
Expand Down
19 changes: 16 additions & 3 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -774,14 +774,27 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
auto mmaParent = parent.dyn_cast<MmaEncodingAttr>();
unsigned kWidth = 0;
Attribute _kWidth = attrs.get("kWidth");
if (_kWidth) {
if (!mmaParent || mmaParent.isVolta()) {
auto loc = parser.getNameLoc();
parser.emitError(loc, "kWidth only supported for MMAv2+ parent");
return Attribute();
}
kWidth = _kWidth.cast<IntegerAttr>().getInt();
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
parent, kWidth);
}

void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>();
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent();
<< "opIdx = " << getOpIdx() << ", parent = " << getParent();
if (mmaParent && mmaParent.isAmpere())
printer << ", kWidth = " << getMMAv2kWidth();
printer << "}>";
}

Expand Down
15 changes: 9 additions & 6 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,17 @@ class BlockedToMMA : public mlir::RewritePattern {
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();

auto newAEncoding = triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(),
oldAType.getElementType());
auto newBEncoding = triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(),
oldBType.getElementType());

auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding()));
oldAType.getShape(), oldAType.getElementType(), newAEncoding);
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding()));
oldBType.getShape(), oldBType.getElementType(), newBEncoding);

a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
Expand Down
84 changes: 62 additions & 22 deletions lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "Utility.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
Expand Down Expand Up @@ -42,6 +43,9 @@ class LoopPipeliner {

/// Loads to be pipelined
SetVector<Value> loads;
/// Smallest data-type for each load (used to optimize swizzle and
/// create DotOpEncoding layout)
DenseMap<Value, Type> loadsSmallestType;
/// The value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
/// load => buffer
Expand Down Expand Up @@ -256,33 +260,62 @@ LogicalResult LoopPipeliner::initialize() {
use = *use->getResult(0).getUsers().begin();
}

if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult()
.getType()
.dyn_cast<RankedTensorType>()) {
if (auto dotOpEnc = tensorType.getEncoding()
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
isCandidate = true;
loadsMapping[loadOp] = convertLayout;
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()), ty.getElementType());
loadsBufferType[loadOp] = RankedTensorType::get(
bufferShape, ty.getElementType(), sharedEnc);
}
}
}
} else
auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use);
if (!convertLayout)
continue;
auto tensorType =
convertLayout.getResult().getType().dyn_cast<RankedTensorType>();
if (!tensorType)
continue;
auto dotOpEnc =
tensorType.getEncoding().dyn_cast<ttg::DotOperandEncodingAttr>();
if (!dotOpEnc)
continue;
isCandidate = true;
loadsMapping[loadOp] = convertLayout;
}

else
isCandidate = false;

if (isCandidate)
loads.insert(loadOp);
}

// we need to find the smallest common dtype
// since this determines the layout of `mma.sync` operands
// in mixed-precision mode
Type smallestType;
for (auto loadCvt : loadsMapping) {
auto loadOp = loadCvt.first;
auto ty = loadOp.getType().cast<RankedTensorType>();
Type eltTy = ty.getElementType();
if (!smallestType ||
(eltTy.getIntOrFloatBitWidth() < smallestType.getIntOrFloatBitWidth()))
smallestType = eltTy;
}

for (auto loadCvt : loadsMapping)
loadsSmallestType[loadCvt.first] = smallestType;

for (auto loadCvt : loadsMapping) {
auto loadOp = loadCvt.first;
Value cvt = loadCvt.second;
auto dotOpEnc = cvt.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<ttg::DotOperandEncodingAttr>();
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()), loadsSmallestType[loadOp]);
loadsBufferType[loadOp] =
RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc);
}

// We have some loads to pipeline
if (!loads.empty()) {
// Update depArgs & depOps
Expand Down Expand Up @@ -551,8 +584,15 @@ scf::ForOp LoopPipeliner::createNewForOp() {
}
// we replace the new load use with a convert layout
size_t i = std::distance(loads.begin(), it);
auto cvtDstTy = op.getResult(0).getType().cast<RankedTensorType>();
auto cvtDstEnc = cvtDstTy.getEncoding().cast<ttg::DotOperandEncodingAttr>();
auto newDstTy = RankedTensorType::get(
cvtDstTy.getShape(), cvtDstTy.getElementType(),
ttg::DotOperandEncodingAttr::get(
cvtDstEnc.getContext(), cvtDstEnc.getOpIdx(), cvtDstEnc.getParent(),
loadsSmallestType[op.getOperand(0)]));
auto cvt = builder.create<ttg::ConvertLayoutOp>(
op.getLoc(), op.getResult(0).getType(),
op.getResult(0).getLoc(), newDstTy,
newForOp.getRegionIterArgs()[loadIdx + i]);
mapping.map(op.getResult(0), cvt.getResult());
}
Expand Down
20 changes: 15 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});

auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding);
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
Value prefetchSlice = builder.create<triton::gpu::ConvertLayoutOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
Expand Down Expand Up @@ -156,12 +156,22 @@ LogicalResult Prefetcher::initialize() {
};

for (triton::DotOp dot : dotsInFor) {
auto kSize = dot.getA().getType().cast<RankedTensorType>().getShape()[1];
auto aType = dot.getA().getType().cast<RankedTensorType>();
auto bType = dot.getB().getType().cast<RankedTensorType>();
auto aEnc = aType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto bEnc = bType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
int aKWidth = aEnc.getMMAv2kWidth();
int bKWidth = bEnc.getMMAv2kWidth();
assert(aKWidth == bKWidth);

auto kSize = aType.getShape()[1];

// works better with nvidia tensor cores
unsigned elementWidth =
dot.getA().getType().cast<RankedTensorType>().getElementTypeBitWidth();
prefetchWidth = 256 / elementWidth;
unsigned elementWidth = aType.getElementTypeBitWidth();
if (aKWidth == 0)
prefetchWidth = 256 / elementWidth;
else
prefetchWidth = 8 * aKWidth;

// Skip prefetching if kSize is less than prefetchWidth
if (kSize < prefetchWidth)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,10 @@ class RematerializeForward : public mlir::RewritePattern {
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
auto dstEncoding =
cvt.getResult().getType().cast<RankedTensorType>().getEncoding();
// XXX: why is this needed?
if (srcEncoding.isa<triton::gpu::SharedEncodingAttr>() ||
dstEncoding.isa<triton::gpu::SharedEncodingAttr>())
return failure();
// heuristics for flash attention
if (srcEncoding.isa<triton::gpu::SliceEncodingAttr>())
return failure();
SetVector<Operation *> cvtSlices;
Expand Down
10 changes: 7 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,15 @@ int simulateBackwardRematerialization(

//

Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping) {
Operation *newOp = rewriter.clone(*op, mapping);
auto origType = op->getResult(0).getType().cast<RankedTensorType>();
auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
if (newOp->getNumResults() == 0)
return newOp;
auto origType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
auto argType = newOp->getOperand(0).getType().dyn_cast<RankedTensorType>();
if (!origType || !argType)
return newOp;
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(), argType.getEncoding());
newOp->getResult(0).setType(newType);
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ int simulateBackwardRematerialization(
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
Attribute targetEncoding);

Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping);

void rematerializeConversionChain(
Expand Down
Binary file modified python/triton/third_party/cuda/bin/ptxas
Binary file not shown.
4 changes: 2 additions & 2 deletions test/Analysis/test-alias.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>

// CHECK-LABEL: matmul_loop
// There shouldn't be any aliasing with the dot op encoding.
Expand Down
4 changes: 2 additions & 2 deletions test/Analysis/test-allocation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>

module attributes {"triton_gpu.num-warps" = 4 : i32} {

Expand Down
4 changes: 2 additions & 2 deletions test/Analysis/test-membar.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>

module attributes {"triton_gpu.num-warps" = 4 : i32} {

Expand Down
Loading