[BACKEND] Implement 3xTF32 trick (#3234)
This PR implements the [3xTF32 trick](NVIDIA/cutlass#385) to use Tensor Cores (TCs) on F32 tensors without sacrificing accuracy. This is particularly relevant for PyTorch, as TF32 is off by default.

Benchmarks on A100 from `python/tutorials/03-matrix-multiplication.py` run on `float32` data using `use_tf32=False`:

```
         M       N       K     cuBLAS     Triton    This PR
0    256.0   256.0   256.0   1.927529   1.092267   1.489455
1    384.0   384.0   384.0   5.026909   3.567484   3.686400
2    512.0   512.0   512.0   8.192000   6.553600   6.898527
3    640.0   640.0   640.0  12.190476  10.448980  10.666666
4    768.0   768.0   768.0  13.405091  10.287628  14.503869
5    896.0   896.0   896.0  14.049280  13.380267  20.070399
6   1024.0  1024.0  1024.0  15.887515  12.264046  19.239927
7   1152.0  1152.0  1152.0  16.681475  15.633424  24.883201
8   1280.0  1280.0  1280.0  16.516129  15.340824  28.248276
9   1408.0  1408.0  1408.0  17.090206  14.774461  24.016635
10  1536.0  1536.0  1536.0  17.014154  15.624477  26.021647
11  1664.0  1664.0  1664.0  17.043394  15.073554  25.858942
12  1792.0  1792.0  1792.0  17.107190  16.171833  29.577431
13  1920.0  1920.0  1920.0  17.883570  15.762828  26.331430
14  2048.0  2048.0  2048.0  17.623127  17.032706  27.413751
15  2176.0  2176.0  2176.0  17.887688  16.686275  29.945905
16  2304.0  2304.0  2304.0  19.019006  17.933838  33.787654
17  2432.0  2432.0  2432.0  17.940270  17.288901  31.181425
18  2560.0  2560.0  2560.0  18.164080  17.075561  31.844508
19  2688.0  2688.0  2688.0  17.594183  16.703239  30.370742
20  2816.0  2816.0  2816.0  18.766871  18.089676  33.242537
21  2944.0  2944.0  2944.0  18.735350  17.855977  33.695763
22  3072.0  3072.0  3072.0  18.420008  17.766898  32.768000
23  3200.0  3200.0  3200.0  18.470418  17.704011  33.255391
24  3328.0  3328.0  3328.0  18.253370  17.710036  32.753092
25  3456.0  3456.0  3456.0  18.546485  17.793328  33.634362
26  3584.0  3584.0  3584.0  18.368824  17.833278  33.142423
27  3712.0  3712.0  3712.0  18.665424  17.938112  34.036574
28  3840.0  3840.0  3840.0  18.638578  18.076496  33.794348
29  3968.0  3968.0  3968.0  18.965486  18.190808  34.324595
30  4096.0  4096.0  4096.0  19.035276  18.365864  34.450135
```

It's an overall win, with roughly an 85% speed-up on large sizes.

Note that the rounding differs slightly from the one [implemented in CUTLASS](https://github.com/NVIDIA/cutlass/blob/a8f2c80db0564c74f4efccac71993b971dfc448b/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h#L99-L100). We could implement that rounding if we wanted to, though.

This is still some way from the 2x speed-ups announced by CUTLASS. To get close to those numbers, we would probably need to remove the stores to shared memory before `ldmatrix`.
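As a rough motivation for the accuracy claim, here is a short worked bound. It is a sketch under stated assumptions, not an analysis taken from the PR: it assumes TF32 keeps 10 explicit mantissa bits (an 11-bit significand) with round-to-nearest, and it ignores both the additional rounding of the small parts inside each TF32 dot and the f32 accumulation error. The symbols $a_{\text{big}}$ and $a_{\text{small}}$ stand for the TF32-rounded value and its f32 remainder.

$$
\begin{aligned}
a &= a_{\text{big}} + a_{\text{small}}, \qquad b = b_{\text{big}} + b_{\text{small}}, \\
ab &= a_{\text{big}} b_{\text{big}} + a_{\text{big}} b_{\text{small}} + a_{\text{small}} b_{\text{big}} + a_{\text{small}} b_{\text{small}}, \\
|a_{\text{small}}| &\le 2^{-11}\,|a|, \quad |b_{\text{small}}| \le 2^{-11}\,|b|
\;\Longrightarrow\; |a_{\text{small}} b_{\text{small}}| \le 2^{-22}\,|ab|.
\end{aligned}
$$

Under these assumptions, the one term that the three-dot scheme drops is on the order of a few f32 rounding units ($2^{-24}$), whereas a single TF32 dot already discards terms of relative size $2^{-11}$.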
Showing 60 changed files with 404 additions and 245 deletions.
`@@ -0,0 +1,91 @@`

```cpp
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

using namespace mlir;
namespace tt = mlir::triton;

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace {

// nb. We call the trick TF32x3 as C++ disallows identifiers starting with
// digits.
// Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385
// For a, b f32
// dot(a, b, inputPrecision="tf32x3") ->
//  let aBig = f32ToTF32(a), aSmall = a - aBig;
//  let bBig = f32ToTF32(b), bSmall = b - bBig;
//  dot(aSmall, bBig, inputPrecision="tf32") +
//  dot(aBig, bSmall, inputPrecision="tf32") +
//  dot(aBig, bBig, inputPrecision="tf32")
class TF32x3 : public OpRewritePattern<tt::DotOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                PatternRewriter &rewriter) const override {

    auto isF32 = [](Value operand) {
      return operand.getType()
          .cast<RankedTensorType>()
          .getElementType()
          .isF32();
    };

    // Only rewrite dots that request TF32x3 precision on f32 operands.
    if (!(dotOp.getInputPrecision() == tt::InputPrecision::TF32x3 &&
          isF32(dotOp.getA()) && isF32(dotOp.getB()))) {
      return failure();
    }

    // Aux functions
    // Round an f32 tensor to TF32 elementwise via inline PTX.
    auto f32ToTF32 = [&](Value value) -> Value {
      return rewriter
          .create<tt::ElementwiseInlineAsmOp>(
              dotOp.getLoc(), value.getType(), "cvt.rna.tf32.f32 $0, $1;",
              "=r,r",
              /*isPure=*/true, /*pack=*/1, ArrayRef<Value>{value})
          .getResult()[0];
    };
    auto sub = [&](Value a, Value b) -> Value {
      return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b);
    };
    // TF32 dot that accumulates into c.
    auto dot = [&](Value a, Value b, Value c) -> Value {
      return rewriter.create<tt::DotOp>(dotOp->getLoc(), c.getType(), a, b, c,
                                        tt::InputPrecision::TF32,
                                        dotOp.getMaxNumImpreciseAcc());
    };

    auto aBig = f32ToTF32(dotOp.getA());
    auto aSmall = sub(dotOp.getA(), aBig);

    auto bBig = f32ToTF32(dotOp.getB());
    auto bSmall = sub(dotOp.getB(), bBig);

    // Chain the three TF32 dots through the accumulator: the two cross terms
    // first, then aBig * bBig.
    auto dot1 = dot(aSmall, bBig, dotOp.getC());
    auto dot2 = dot(aBig, bSmall, dot1);
    auto dot3 = dot(aBig, bBig, dot2);

    rewriter.replaceOp(dotOp, dot3);
    return success();
  }
};

struct F32DotTCPass : public TritonGPUF32DotTCBase<F32DotTCPass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();

    RewritePatternSet decomposePatterns(context);
    decomposePatterns.add<TF32x3>(context);
    if (applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
            .failed()) {
      signalPassFailure();
    }
  }
};
} // anonymous namespace

std::unique_ptr<Pass> mlir::triton::gpu::createF32DotTCPass() {
  return std::make_unique<F32DotTCPass>();
}
```
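
To see what the rewrite computes without a GPU, the decomposition can be emulated on the CPU. The following is a minimal sketch, not code from this PR: `dotTF32` and `dot3xTF32` are hypothetical helper names, and the CPU `f32ToTF32` only approximates `cvt.rna.tf32.f32` by rounding to 10 explicit mantissa bits with ties away from zero, for finite inputs. It also assumes the TF32 dot rounds its inputs the same way, which may not match what the hardware does with the low bits.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// CPU stand-in for `cvt.rna.tf32.f32` (assumption: finite input).
// Rounds the magnitude to nearest, ties away from zero, then clears the
// 13 low mantissa bits so that 10 explicit mantissa bits remain.
inline float f32ToTF32(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x1000u;     // half of the 13-bit field that gets dropped
  bits &= 0xFFFFE000u; // keep sign, exponent and 10 mantissa bits
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

// Hypothetical emulation of a TF32 dot: inputs are rounded to TF32,
// products are accumulated in f32 starting from `acc`.
inline float dotTF32(const std::vector<float> &a, const std::vector<float> &b,
                     float acc) {
  assert(a.size() == b.size());
  for (std::size_t i = 0; i < a.size(); ++i)
    acc += f32ToTF32(a[i]) * f32ToTF32(b[i]);
  return acc;
}

// Same three-dot chain as the TF32x3 pattern:
// dot(aSmall, bBig) + dot(aBig, bSmall) + dot(aBig, bBig).
inline float dot3xTF32(const std::vector<float> &a,
                       const std::vector<float> &b, float acc) {
  assert(a.size() == b.size());
  const std::size_t n = a.size();
  std::vector<float> aBig(n), aSmall(n), bBig(n), bSmall(n);
  for (std::size_t i = 0; i < n; ++i) {
    aBig[i] = f32ToTF32(a[i]);
    aSmall[i] = a[i] - aBig[i]; // remainder lost by the TF32 rounding
    bBig[i] = f32ToTF32(b[i]);
    bSmall[i] = b[i] - bBig[i];
  }
  acc = dotTF32(aSmall, bBig, acc);
  acc = dotTF32(aBig, bSmall, acc);
  return dotTF32(aBig, bBig, acc);
}
```

With one-element inputs `a = 1 + 2^-12` and `b = 3 + 2^-11`, both values round to `1` and `3` in TF32, so `dotTF32` returns exactly `3`, while `dot3xTF32` recovers the cross terms and lands much closer to the product computed in double precision.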