From 6ce1307862b769a0e9ab815e9b3ef1a02ee6ad2c Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 4 Mar 2024 13:24:57 +0000 Subject: [PATCH] Implementation --- .../Dialect/Triton/IR/TritonAttrDefs.td | 11 +++ include/triton/Dialect/Triton/IR/TritonOps.td | 9 +- .../Dialect/TritonGPU/Transforms/Passes.h | 2 + .../Dialect/TritonGPU/Transforms/Passes.td | 14 +++ .../TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td | 4 +- lib/Analysis/Utility.cpp | 2 +- .../TritonToTritonGPUPass.cpp | 2 +- lib/Dialect/Triton/Transforms/Combine.cpp | 3 +- lib/Dialect/Triton/Transforms/Combine.td | 16 ++-- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 2 +- .../TritonGPU/Transforms/CMakeLists.txt | 1 + lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp | 91 +++++++++++++++++++ .../Pipeliner/MatmulLoopPipeline.cpp | 4 +- .../Transforms/RemoveLayoutConversions.cpp | 2 +- python/src/ir.cc | 13 ++- python/src/passes.cc | 1 + python/test/regression/test_cast_matmul.py | 5 +- .../ttgir_tests/wgmma_64_64_16_f16_NT.ttgir | 2 +- .../ttgir_tests/wgmma_64_64_16_f16_TN.ttgir | 2 +- .../wgmma_a_ldgsts_64_64_16_f16.ttgir | 2 +- ...wgmma_a_ldgsts_mbarrier_64_64_16_f16.ttgir | 2 +- .../wgmma_ldgsts_64_64_16_f16.ttgir | 2 +- .../wgmma_ldgsts_mbarrier_64_64_16_f16.ttgir | 2 +- ...mma_ldgsts_mbarrier_vec_64_64_16_f16.ttgir | 2 +- .../ttgir_tests/wgmma_tma_64_64_16_f16.ttgir | 2 +- python/test/unit/language/test_core.py | 57 ++++++------ python/triton/language/core.py | 16 ++-- python/triton/language/semantic.py | 18 +++- python/triton/ops/flash_attention.py | 16 ++-- python/triton/ops/matmul.py | 14 +-- python/triton/runtime/interpreter.py | 5 +- test/Analysis/test-alias.mlir | 2 +- test/Analysis/test-allocation.mlir | 6 +- test/Analysis/test-membar.mlir | 2 +- .../amd/tritongpu_wmma_dot_to_llvm.mlir | 2 +- test/Conversion/invalid.mlir | 6 +- test/Conversion/triton_ops.mlir | 8 +- test/Conversion/triton_to_tritongpu.mlir | 2 +- test/Conversion/tritongpu_to_llvm.mlir | 16 ++-- test/Conversion/tritongpu_to_llvm_hopper.mlir | 12 +-- test/Triton/combine.mlir | 10 +- test/Triton/rewrite-tensor-pointer.mlir | 2 +- test/TritonGPU/accelerate-matmul.mlir | 14 +-- test/TritonGPU/combine.mlir | 8 +- test/TritonGPU/dot-operands.mlir | 18 ++-- test/TritonGPU/fence-inserstion.mlir | 4 +- test/TritonGPU/loop-pipeline-hopper.mlir | 26 +++--- test/TritonGPU/loop-pipeline.mlir | 36 ++++---- test/TritonGPU/matmul.mlir | 2 +- .../pipeline-hopper-remove-wait.mlir | 4 +- test/TritonGPU/prefetch.mlir | 2 +- test/TritonGPU/reorder-instructions.mlir | 8 +- third_party/amd/backend/compiler.py | 4 +- .../AccelerateAMDMatmul.cpp | 2 +- .../RemoveLayoutConversions.cpp | 2 +- third_party/nvidia/backend/compiler.py | 6 +- .../DotOpToLLVM/MMAv2.cpp | 2 +- .../DotOpToLLVM/WGMMA.cpp | 8 +- 58 files changed, 342 insertions(+), 196 deletions(-) create mode 100644 lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index f65daa1ddc19..96e92a54241d 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -107,4 +107,15 @@ def TT_PropagateNanAttr : I32EnumAttr< let cppNamespace = "::mlir::triton"; } +// F32Backend +def TT_F32BackendAttr : I32EnumAttr< + "F32Backend", "", + [ + I32EnumAttrCase<"TF32", 0, "tf32">, + I32EnumAttrCase<"TF32x3", 1, "tf32x3">, + I32EnumAttrCase<"IEEE", 2, "ieee"> + ]>{ + let cppNamespace = "::mlir::triton"; +} + #endif diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td 
b/include/triton/Dialect/Triton/IR/TritonOps.td index 75200dde8e8a..739a2cdf36d7 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -566,14 +566,19 @@ def TT_DotOp : TT_Op<"dot", [Pure, let summary = "dot"; let description = [{ - $d = matrix_multiply($a, $b) + $c + $d = matrix_multiply($a, $b) + $c. $f32Backend describes how to exercise the TC + when the inputs are f32. It can be one of: tf32, tf32x3, ieee. + tf32: use TC with tf32 ops. + tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp + ieee: don't use TC, implement dot in software. + If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. }]; let arguments = (ins TT_TensorOrMemDesc:$a, TT_TensorOrMemDesc:$b, TT_FpIntTensor:$c, - BoolAttr:$allowTF32, + TT_F32BackendAttr:$f32Backend, I32Attr:$maxNumImpreciseAcc); let results = (outs TT_FpIntTensor:$d); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index f4e7e95ca5b1..05f853de8f64 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -14,6 +14,8 @@ std::unique_ptr createPipelinePass(int numStages = 3, int numWarps = 4, std::unique_ptr createAccelerateMatmulPass(int computeCapability = 80); +std::unique_ptr createF32DotTCPass(); + std::unique_ptr createPrefetchPass(); std::unique_ptr createCoalescePass(); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index dc4553b52b6d..09cf8a0910f8 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -36,6 +36,20 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { ]; } +def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> { + let summary = "3xTF32 trick"; + + let description = [{ + Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s + to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385 + }]; + + let constructor = "mlir::triton::gpu::createF32DotTCPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; +} + def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> { let summary = "prefetch"; diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index 4c9c285da388..68afac9314c7 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -85,13 +85,13 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure, let summary = "dot async"; let description = [{ - $d = matrix_multiply($a, $b) + $c + $d = matrix_multiply($a, $b) + $c. 
For docs on F32BackendAttr, see TT_DotOp }]; let arguments = (ins TT_TensorOrMemDesc:$a, TT_TensorOrMemDesc:$b, TT_FpIntTensor:$c, - BoolAttr:$allowTF32, + TT_F32BackendAttr:$f32Backend, I32Attr:$maxNumImpreciseAcc); let results = (outs TT_FpIntTensor:$d); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 43e6e183037e..fcfd846c192d 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -501,7 +501,7 @@ bool supportMMA(triton::DotOp op, int version) { } } if (aElemTy.isF32() && bElemTy.isF32()) { - return op.getAllowTF32() && version >= 2; + return op.getF32Backend() == F32Backend::TF32 && version >= 2; } return supportMMA(op.getA(), version) && supportMMA(op.getB(), version); } diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 8846553dab0b..e3f703a3eb8c 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -270,7 +270,7 @@ struct TritonDotPattern : public OpConversionPattern { c = rewriter.create(c.getLoc(), retType, c); addNamedAttrs(rewriter.replaceOpWithNewOp( - op, retType, a, b, c, adaptor.getAllowTF32(), + op, retType, a, b, c, adaptor.getF32Backend(), adaptor.getMaxNumImpreciseAcc()), adaptor.getAttributes()); return success(); diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 3cfa8e65a11b..51398668e179 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -217,7 +217,8 @@ class CombineBroadcastMulReducePattern : public RewritePattern { rewriter.create(op->getLoc(), rewriter.getF32FloatAttr(0))); rewriter.replaceOpWithNewOp(op, expandLhsOp.getSrc(), - expandRhsOp.getSrc(), newAcc, true, 0); + expandRhsOp.getSrc(), newAcc, + F32Backend::TF32, 0); return success(); } }; diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index 49e6950aa69a..c827f185f4f9 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -12,25 +12,25 @@ include "mlir/IR/PatternBase.td" // AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) // AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) def CombineDotAddIPattern : Pat< - (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $overflow), - (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)), + (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $f32Backend, $maxNumImpreciseAcc), $overflow), + (TT_DotOp $a, $b, $d, $f32Backend, $maxNumImpreciseAcc, (location $res)), [(Constraint> $c), (ConstrainthasOneUse()">, "dot result has a single use">)]>; def CombineDotAddFPattern : Pat< - (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $fastmath), - (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)), + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $f32Backend, $maxNumImpreciseAcc), $fastmath), + (TT_DotOp $a, $b, $d, $f32Backend, $maxNumImpreciseAcc, (location $res)), [(Constraint> $c), (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), (ConstrainthasOneUse()">, "dot result has a single use">)]>; def CombineDotAddIRevPattern : Pat< - (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d, $overflow), - (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)), + (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $f32Backend, 
$maxNumImpreciseAcc), $d, $overflow), + (TT_DotOp $a, $b, $d, $f32Backend, $maxNumImpreciseAcc, (location $res)), [(Constraint> $c), (ConstrainthasOneUse()">, "dot result has a single use">)]>; def CombineDotAddFRevPattern : Pat< - (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d, $fastmath), - (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)), + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $f32Backend, $maxNumImpreciseAcc), $d, $fastmath), + (TT_DotOp $a, $b, $d, $f32Backend, $maxNumImpreciseAcc, (location $res)), [(Constraint> $c), (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), (ConstrainthasOneUse()">, "dot result has a single use">)]>; diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 615111f2ec8d..b23e2bcfe061 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -325,7 +325,7 @@ class BlockedToMMA : public mlir::RewritePattern { } // convert dot instruction auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, - newAcc, dotOp.getAllowTF32(), + newAcc, dotOp.getF32Backend(), dotOp.getMaxNumImpreciseAcc()); rewriter.replaceOpWithNewOp(op, oldRetType, diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index c263a18d4295..2f4a4e9a9a52 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonGPUTransforms AccelerateMatmul.cpp Coalesce.cpp + F32DotTC.cpp ReduceDataDuplication.cpp OptimizeDotOperands.cpp OptimizeThreadLocality.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp new file mode 100644 index 000000000000..b733ca56759d --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -0,0 +1,91 @@ +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +using namespace mlir; +namespace tt = mlir::triton; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +// nb. 
We call the trick TF32x3 as C++ disallows varaibles starting with numbers +// Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385 +// For a, b f32 +// dot(a, b, f32Backend="tf32x3") -> +// let aBig = f32ToTF32(a), aSmall = a - aBig; +// let bBig = f32ToTF32(b), bSmall = b - bBig; +// dot(aSmall, bBig, f32Backend="tf32") + +// dot(aBig, bSmall, f32Backend="tf32") + +// dot(aBig, bBig, f32Backend="tf32") +class TF32x3 : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tt::DotOp dotOp, + PatternRewriter &rewriter) const override { + + auto isF32 = [](Value operand) { + return operand.getType() + .cast() + .getElementType() + .isF32(); + }; + + if (!(dotOp.getF32Backend() == tt::F32Backend::TF32x3 && + isF32(dotOp.getA()) && isF32(dotOp.getB()))) { + return failure(); + } + + // Aux functions + auto f32ToTF32 = [&](Value value) -> Value { + return rewriter + .create( + dotOp.getLoc(), value.getType(), "cvt.rna.tf32.f32 $0, $1;", + "=r,r", + /*isPure=*/true, /*pack=*/1, ArrayRef{value}) + .getResult()[0]; + }; + auto sub = [&](Value a, Value b) -> Value { + return rewriter.create(dotOp.getLoc(), a, b); + }; + auto dot = [&](Value a, Value b, Value c) -> Value { + return rewriter.create(dotOp->getLoc(), c.getType(), a, b, c, + tt::F32Backend::TF32, + dotOp.getMaxNumImpreciseAcc()); + }; + + auto aBig = f32ToTF32(dotOp.getA()); + auto aSmall = sub(dotOp.getA(), aBig); + + auto bBig = f32ToTF32(dotOp.getB()); + auto bSmall = sub(dotOp.getB(), bBig); + + auto dot1 = dot(aSmall, bBig, dotOp.getC()); + auto dot2 = dot(aBig, bSmall, dot1); + auto dot3 = dot(aBig, bBig, dot2); + + rewriter.replaceOp(dotOp, dot3); + return success(); + } +}; + +struct F32DotTCPass : public TritonGPUF32DotTCBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + RewritePatternSet decomposePatterns(context); + decomposePatterns.add(context); + if (applyPatternsAndFoldGreedily(m, std::move(decomposePatterns)) + .failed()) { + signalPassFailure(); + } + } +}; +} // anonymous namespace + +std::unique_ptr mlir::triton::gpu::createF32DotTCPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index b3d34dc22e96..61bb087f7b78 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -1162,8 +1162,8 @@ void triton::asyncLaunchDots(scf::ForOp forOp) { if (resEnc && resEnc.isHopper()) { builder.setInsertionPoint(dotOp); builder.replaceOpWithNewOp( - dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(), - dotOp.getMaxNumImpreciseAcc()); + dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(), + dotOp.getF32Backend(), dotOp.getMaxNumImpreciseAcc()); } } diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index d92a048af753..a45e9d98979e 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -64,7 +64,7 @@ class ConvertDotConvert : public RewritePattern { auto _0 = rewriter.create(op->getLoc(), dotOp.getType(), _0f); auto newDot = rewriter.create( op->getLoc(), dotOp.getType(), dotOp.getOperand(0), dotOp.getOperand(1), - _0, dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc()); + _0, 
dotOp.getF32Backend(), dotOp.getMaxNumImpreciseAcc()); auto newCvt = rewriter.create(op->getLoc(), dstTy, newDot.getResult()); rewriter.replaceOpWithNewOp(op, newCvt, cvtOp.getSrc()); diff --git a/python/src/ir.cc b/python/src/ir.cc index 65bddee98cdf..37478ed1ebe3 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -191,6 +191,12 @@ void init_triton_ir(py::module &&m) { .value("NONE", PropagateNan::NONE) .value("ALL", PropagateNan::ALL); + py::enum_(m, "F32BACKEND", py::module_local()) + .value("TF32", F32Backend::TF32) + .value("TF32x3", F32Backend::TF32x3) + .value("IEEE", F32Backend::IEEE) + .export_values(); + py::class_(m, "context", py::module_local()).def(py::init<>()); m.def("load_dialects", [](MLIRContext &context) { @@ -1370,9 +1376,10 @@ void init_triton_ir(py::module &&m) { ProgramIDDim(axis))); }) .def("create_dot", - [](TritonOpBuilder &self, Value &a, Value &b, Value &c, - bool allowTF32, int maxNumImpreciseAcc) -> Value { - return self.create(c.getType(), a, b, c, allowTF32, + [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, + mlir::Value &c, F32Backend f32Backend, + int maxNumImpreciseAcc) -> mlir::Value { + return self.create(c.getType(), a, b, c, f32Backend, maxNumImpreciseAcc); }) .def("create_floor", diff --git a/python/src/passes.cc b/python/src/passes.cc index 2836ad7f7646..12f23db205f5 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -51,6 +51,7 @@ void init_triton_passes_ttgpuir(py::module &&m) { ADD_PASS_WRAPPER_0("add_prefetch", createPrefetchPass); ADD_PASS_WRAPPER_1("add_accelerate_matmul", createAccelerateMatmulPass, int); ADD_PASS_WRAPPER_0("add_reorder_instructions", createReorderInstructionsPass); + ADD_PASS_WRAPPER_0("add_f32_dot_tc", createF32DotTCPass); ADD_PASS_WRAPPER_0("add_optimize_dot_operands", createOptimizeDotOperandsPass); ADD_PASS_WRAPPER_0("add_remove_layout_conversions", diff --git a/python/test/regression/test_cast_matmul.py b/python/test/regression/test_cast_matmul.py index 1477bc5a41bd..36250a3b35d8 100644 --- a/python/test/regression/test_cast_matmul.py +++ b/python/test/regression/test_cast_matmul.py @@ -34,7 +34,6 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype): out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype)) out_triton = torch.empty((M, N), device=device, dtype=torch_dtype) - allow_tf32 = True # launch kernel BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32 grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1) @@ -45,7 +44,6 @@ def matmul_kernel(A, B, C, M, N, K, # stride_bk, stride_bn, # stride_cm, stride_cn, # dot_out_dtype: tl.constexpr, # - allow_tf32: tl.constexpr, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, # BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr): # matrix multiplication @@ -75,7 +73,7 @@ def matmul_kernel(A, B, C, M, N, K, # b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) a = a.to(C.dtype.element_ty) b = b.to(C.dtype.element_ty) - acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + acc += tl.dot(a, b, out_dtype=dot_out_dtype) A += BLOCK_K * stride_ak B += BLOCK_K * stride_bk acc = acc.to(C.dtype.element_ty) @@ -91,7 +89,6 @@ def matmul_kernel(A, B, C, M, N, K, # a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, # - allow_tf32=allow_tf32, # GROUP_M=8, # BLOCK_M=BLOCK_M, # BLOCK_N=BLOCK_N, # diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_NT.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_NT.ttgir index a42660a9b909..a13c08018890 100644 
--- a/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_NT.ttgir +++ b/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_NT.ttgir @@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2> %27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0> %28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1> - %29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %29 = tt.dot %27, %28, %cst {f32Backend = 0 : i32, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_TN.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_TN.ttgir index 80aca369353b..ed880081d570 100644 --- a/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_TN.ttgir +++ b/python/test/unit/hopper/ttgir_tests/wgmma_64_64_16_f16_TN.ttgir @@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2> %27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0> %28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1> - %29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %29 = tt.dot %27, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_64_64_16_f16.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_64_64_16_f16.ttgir index 8a7c32878592..45f4d5b8e9aa 100644 --- a/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_64_64_16_f16.ttgir +++ b/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_64_64_16_f16.ttgir @@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr, #blocked2>, tensor<16x64xi32, #blocked2> %26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2> %28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1> - %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> 
tensor<64x64xf32, #mma> %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_mbarrier_64_64_16_f16.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_mbarrier_64_64_16_f16.ttgir index e0e92a0fd11e..5f4bf09e7ae7 100644 --- a/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_mbarrier_64_64_16_f16.ttgir +++ b/python/test/unit/hopper/ttgir_tests/wgmma_a_ldgsts_mbarrier_64_64_16_f16.ttgir @@ -47,7 +47,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr, #blocked2>, tensor<16x64xi32, #blocked2> %26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2> %28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1> - %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_64_64_16_f16.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_64_64_16_f16.ttgir index d324abd362db..200e8bdcf9b2 100644 --- a/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_64_64_16_f16.ttgir +++ b/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_64_64_16_f16.ttgir @@ -47,7 +47,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : triton_gpu.async_commit_group triton_gpu.async_wait {num = 1 : i32} %28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1> - %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_64_64_16_f16.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_64_64_16_f16.ttgir index 9fb4b6f93f02..2ada6f4ebfc2 100644 --- a/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_64_64_16_f16.ttgir +++ b/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_64_64_16_f16.ttgir @@ -51,7 +51,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : triton_nvidia_gpu.mbarrier_wait %mbar1, %i1_false : !tt.ptr // triton_gpu.async_wait {num = 1 : i32} %28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 
16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1> - %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_vec_64_64_16_f16.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_vec_64_64_16_f16.ttgir index 979cad7c9fb4..9a127479b7a0 100644 --- a/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_vec_64_64_16_f16.ttgir +++ b/python/test/unit/hopper/ttgir_tests/wgmma_ldgsts_mbarrier_vec_64_64_16_f16.ttgir @@ -53,7 +53,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : triton_nvidia_gpu.mbarrier_wait %mbar1_s, %i1_false : !tt.ptr // triton_gpu.async_wait {num = 1 : i32} %28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1> - %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> %30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1> %31 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x1x!tt.ptr, #blocked1> %32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1> diff --git a/python/test/unit/hopper/ttgir_tests/wgmma_tma_64_64_16_f16.ttgir b/python/test/unit/hopper/ttgir_tests/wgmma_tma_64_64_16_f16.ttgir index f64283a1b39c..15d102c6fa27 100644 --- a/python/test/unit/hopper/ttgir_tests/wgmma_tma_64_64_16_f16.ttgir +++ b/python/test/unit/hopper/ttgir_tests/wgmma_tma_64_64_16_f16.ttgir @@ -46,7 +46,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %28 = triton_gpu.extract_slice %b_smem_loaded[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1> // Calling MMA - %29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> + %29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma> // Epilogue %30 = tt.splat %cStride0: (i32) -> tensor<64x1xi32, #blocked1> diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 22095c66c67b..c40a60d10b54 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2759,28 +2759,28 @@ def convert_fp8_to_fp32(x, device, dtype_str): @pytest.mark.interpreter @pytest.mark.parametrize( - "M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype", - [(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype) + "M, N, K, num_warps, col_a, 
col_b, epilogue, input_precision, in_dtype, out_dtype", + [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype) for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] - for allow_tf32 in [True, False] + for input_precision in ['tf32', 'tf32x3', 'ieee'] for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')] - if not (allow_tf32 and (in_dtype in ['float16']))] + - [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype) + if not (input_precision != 'ieee' and (in_dtype in ['float16']))] + + [(*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype) for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] - for allow_tf32 in [True] + for input_precision in ["tf32"] for col_a in [True, False] for col_b in [True, False] for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')]] + - [(64, 64, 64, 4, col_a, col_b, 'none', False, 'float32', 'float32') + [(64, 64, 64, 4, col_a, col_b, 'none', 'ieee', 'float32', 'float32') for col_a in [True, False] - for col_b in [True, False]] + [(64, 64, 64, 4, False, False, 'chain-dot', False, 'bfloat16', 'float32')] + - [(128, 128, 64, 4, False, False, 'chain-dot', False, float8_type, 'float32') + for col_b in [True, False]] + [(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32')] + + [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32') for float8_type in ["float8e5", "float8e4nv"]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) -def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, num_ctas, device): +def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, num_ctas, device): check_cuda_only(device) capability = torch.cuda.get_device_capability() @@ -2790,7 +2790,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o if capability[0] < 8: if capability[1] == 0 and in_dtype == 'int8': pytest.skip("Only test int8 on devices with sm >= 75") - if allow_tf32: + if input_precision != "ieee": pytest.skip("Only test tf32 on devices with sm >= 80") if capability[0] == 7: if (M, N, K, num_warps) in [(128, 256, 32, 8), (64, 128, 128, 4), (64, 128, 128, 2)]: @@ -2807,7 +2807,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o "numpy.dot with int8 inputs will overflow while tl.dot doesn't because MMA instruction's accumulator is 32-bit" ) - torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + torch.backends.cuda.matmulallow_tf32 = input_precision == "tf32" if num_ctas > 1 and in_dtype == 'int8': # FIXME: mma v2 with num_ctas > 1 does not work @@ -2817,7 +2817,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o @triton.jit def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, stride_wl, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ADD_MATRIX: tl.constexpr, - ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, ALLOW_TF32: tl.constexpr, DO_SOFTMAX: tl.constexpr, + ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, F32_BACKEND: tl.constexpr, DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, COL_A: 
tl.constexpr, COL_B: tl.constexpr, out_dtype: tl.constexpr = tl.float32): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) @@ -2829,7 +2829,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn x = tl.load(Xs) y = tl.load(Ys) - z = tl.dot(x, y, allow_tf32=ALLOW_TF32, out_dtype=out_dtype) + z = tl.dot(x, y, input_precision=F32_BACKEND, out_dtype=out_dtype) if ADD_MATRIX: z += tl.load(Zs) if ADD_ROWS: @@ -2846,7 +2846,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid z = num / den[:, None] if CHAIN_DOT: w = tl.load(Ws) - z = tl.dot(z.to(w.dtype), w, allow_tf32=ALLOW_TF32, out_dtype=out_dtype) + z = tl.dot(z.to(w.dtype), w, input_precision=F32_BACKEND, out_dtype=out_dtype) tl.store(Zs, z) # input @@ -2863,7 +2863,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid if 'int' not in in_dtype and 'float8' not in in_dtype: x *= .1 y *= .1 - if in_dtype == 'float32' and allow_tf32: + if in_dtype == 'float32' and input_precision == "tf32": x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') @@ -2894,10 +2894,10 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), COL_A=col_a, COL_B=col_b, BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, ADD_MATRIX=epilogue == 'add-matrix', ADD_ROWS=epilogue == 'add-rows', ADD_COLS=epilogue == 'add-cols', - DO_SOFTMAX=epilogue == 'softmax', CHAIN_DOT=epilogue == 'chain-dot', ALLOW_TF32=allow_tf32, - num_warps=num_warps, num_ctas=num_ctas, out_dtype=out_dtype) + DO_SOFTMAX=epilogue == 'softmax', CHAIN_DOT=epilogue == 'chain-dot', + F32_BACKEND=input_precision, num_warps=num_warps, num_ctas=num_ctas, out_dtype=out_dtype) - if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32): + if epilogue == 'softmax' and (in_dtype != 'float32' or input_precision == "tf32"): if not is_cuda(): pass else: @@ -2956,7 +2956,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid assert 'st.global.v2' in ptx else: assert 'st.global.v4' in ptx - if in_dtype == 'float32' and allow_tf32: + if in_dtype == 'float32' and input_precision != "ieee": assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx) elif in_dtype == 'float16' and out_dtype == tl.float32: if capability[0] == 7 and capability[1] == 5: # Turing @@ -3012,7 +3012,7 @@ def kernel( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - ALLOW_TF32: tl.constexpr, + F32_BACKEND: tl.constexpr, out_dtype: tl.constexpr = tl.float32, ): startm = tl.program_id(0) * BLOCK_M @@ -3027,7 +3027,7 @@ def kernel( None, None, :] * stride_kn q = tl.load(q_ptrs) k = tl.load(k_ptrs) - qk = tl.dot(q, k, allow_tf32=ALLOW_TF32, out_dtype=out_dtype) + qk = tl.dot(q, k, input_precision=F32_BACKEND, out_dtype=out_dtype) o_ptrs = o_ptr + offs_b[:, None, None] * stride_ob + offs_m[None, :, None] * stride_om + offs_n[ None, None, :] * stride_on tl.store(o_ptrs, qk) @@ -3076,7 +3076,7 @@ def kernel( BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, - ALLOW_TF32=bool(in_dtype_str == 'float32'), + F32_BACKEND="tf32" if in_dtype_str == 'float32' else "ieee", out_dtype=out_dtype, num_warps=num_warps, ) @@ -3296,17 +3296,12 @@ def kernel_constexpr(in_ptr: tl.const, 
out, c_out: tl.const, choose_const: tl.co @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ['float32', 'float16']) def test_dot_without_load(dtype_str, device): - if is_cuda(): - capability = torch.cuda.get_device_capability() - allow_tf32 = capability[0] > 7 - else: - allow_tf32 = True @triton.jit - def _kernel(out, ALLOW_TF32: tl.constexpr): + def _kernel(out): a = GENERATE_TEST_HERE b = GENERATE_TEST_HERE - c = tl.dot(a, b, allow_tf32=ALLOW_TF32) + c = tl.dot(a, b) out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] tl.store(out_ptr, c) @@ -3315,7 +3310,7 @@ def _kernel(out, ALLOW_TF32: tl.constexpr): b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) out_ref = torch.matmul(a, b) out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) - kernel[(1, )](out, ALLOW_TF32=allow_tf32) + kernel[(1, )](out) assert torch.all(out == out_ref) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index db749aa13933..0c1857c8bacb 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1463,7 +1463,7 @@ def expand_dims(input, axis, _builder=None): @builtin -def dot(input, other, acc=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, _builder=None): +def dot(input, other, acc=None, input_precision=None, max_num_imprecise_acc=None, out_dtype=float32, _builder=None): """ Returns the matrix product of two blocks. @@ -1473,16 +1473,16 @@ def dot(input, other, acc=None, allow_tf32=None, max_num_imprecise_acc=None, out :type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} :param other: The second tensor to be multiplied. :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + :param input_precision: How to exercise the Tensor Cores for f32 x f32. If the device does not have Tensor Cores + or the inputs are not of dtype f32, this option is ignored. + :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`. """ - if allow_tf32 is None: - if get_bool_env_var("TRITON_F32_DEFAULT"): - allow_tf32 = False - else: - allow_tf32 = True - allow_tf32 = _constexpr_to_value(allow_tf32) + if input_precision is None: + input_precision = os.getenv("TRITON_F32_DEFAULT", None) + input_precision = _constexpr_to_value(input_precision) out_dtype = _constexpr_to_value(out_dtype) max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) - return semantic.dot(input, other, acc, allow_tf32, max_num_imprecise_acc, out_dtype, _builder) + return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder) # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index ac8e5af53201..f8eb11eef9f1 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1292,7 +1292,16 @@ def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope # ===----------------------------------------------------------------------===// -def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_num_imprecise_acc: int, +def _str_to_dot_f32_backend(input_precision, builder): + assert input_precision.lower() in builder.options.allowed_dot_input_precisions, \ + f"input_precision must be one of {builder.options.allowed_dot_input_precisions}.
Got {input_precision}" + input_precision = input_precision.upper() + if input_precision == "TF32X3": + input_precision = "TF32x3" + return getattr(ir.F32BACKEND, input_precision) + + +def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): @@ -1327,6 +1336,11 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options) + if input_precision is None: + input_precision = builder.options.default_dot_input_precision + + f32_backend = _str_to_dot_f32_backend(input_precision, builder) + lhs_rank = len(lhs.shape) rhs_rank = len(rhs.shape) assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" @@ -1368,7 +1382,7 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): else: max_num_imprecise_acc = 0 - return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty) + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, f32_backend, max_num_imprecise_acc), ret_ty) # ===----------------------------------------------------------------------===// diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py index 8fa2d61f5885..f18e480a07a7 100644 --- a/python/triton/ops/flash_attention.py +++ b/python/triton/ops/flash_attention.py @@ -78,14 +78,14 @@ def _fwd_kernel(Q, K, V, sm_scale, # qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) if IS_CAUSAL: qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - qk += tl.dot(q, k, allow_tf32=True) + qk += tl.dot(q, k) # -- compute scaling constant --- m_i_new = tl.maximum(m_i, tl.max(qk, 1)) alpha = tl.math.exp2(m_i - m_i_new) p = tl.math.exp2(qk - m_i_new[:, None]) # -- scale and update acc -- acc *= alpha[:, None] - acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True) + acc += tl.dot(p.to(V.dtype.element_ty), v) # -- update m_i and l_i -- l_i = l_i * alpha + tl.sum(p, 1) m_i = m_i_new @@ -197,26 +197,26 @@ def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, # p = tl.math.exp2(qk - l_i[:, None]) # compute dv do = tl.load(DO_block_ptr) - dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do, allow_tf32=True) + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) # compute dp = dot(v, do) Di = tl.load(D_ptrs + offs_m_curr) # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] - dp = tl.dot(do, tl.trans(v), allow_tf32=True) + dp = tl.dot(do, tl.trans(v)) # compute ds = p * (dp - delta[:, None]) ds = (p * (dp - Di[:, None]) * sm_scale).to(Q.dtype.element_ty) # compute dk = dot(ds.T, q) - dk += tl.dot(tl.trans(ds), q, allow_tf32=True) + dk += tl.dot(tl.trans(ds), q) # compute dq if not SEQUENCE_PARALLEL: dq = tl.load(DQ_block_ptr) - dq += tl.dot(ds, k, allow_tf32=True) + dq += tl.dot(ds, k) tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) elif SEQUENCE_PARALLEL: if MMA_V3: - dq = tl.dot(ds, k, allow_tf32=True) + dq = tl.dot(ds, k) else: # not work with mma v3, because M % 64 != 0 - dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True)) + dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds))) tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) # increment pointers diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index e170091d67e4..f7f577a1b80b 100644 --- 
a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -90,7 +90,7 @@ def _kernel(A, B, C, M, N, K, # stride_bk, stride_bn, # stride_cm, stride_cn, # acc_dtype: tl.constexpr, # - allow_tf32: tl.constexpr, # + input_precision: tl.constexpr, # fp8_fast_accum: tl.constexpr, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr # @@ -129,9 +129,9 @@ def _kernel(A, B, C, M, N, K, # a = a.to(AB_DTYPE) b = b.to(AB_DTYPE) if fp8_fast_accum: - acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) + acc = tl.dot(a, b, acc, out_dtype=acc_dtype, input_precision=input_precision) else: - acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32) + acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk acc = acc.to(C.dtype.element_ty) @@ -153,7 +153,7 @@ class _matmul(torch.autograd.Function): _locks = {} @staticmethod - def _call(a, b, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype): + def _call(a, b, acc_dtype, input_precision, fp8_fast_accum, output_dtype): device = a.device # handle non-contiguous inputs if necessary if a.stride(0) > 1 and a.stride(1) > 1: @@ -205,14 +205,14 @@ def to_tl_type(ty): b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # acc_dtype=acc_dtype, # - allow_tf32=allow_tf32, # + input_precision=input_precision, # fp8_fast_accum=fp8_fast_accum, # GROUP_M=8, AB_DTYPE=ab_dtype) return c @staticmethod - def forward(ctx, a, b, acc_dtype=None, allow_tf32=True, fp8_fast_accum=True, output_dtype=None): - return _matmul._call(a, b, acc_dtype=acc_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum, + def forward(ctx, a, b, acc_dtype=None, input_precision=None, fp8_fast_accum=True, output_dtype=None): + return _matmul._call(a, b, acc_dtype=acc_dtype, input_precision=input_precision, fp8_fast_accum=fp8_fast_accum, output_dtype=output_dtype) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 304b8ca53358..ecb44e587a71 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -1,4 +1,5 @@ import inspect +from typing import Tuple import numpy as np @@ -62,6 +63,8 @@ class InterpreterOptions: debug: bool = False arch: str = None allow_fp8e4nv: bool = False + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "3xtf32", "ieee") max_num_imprecise_acc_default: int = 0 @@ -342,7 +345,7 @@ def unary_op(self, arg, op): def create_trans(self, arg, perm): return TensorHandle(np.transpose(arg.data, perm), arg.dtype) - def create_dot(self, a, b, d, allow_tf32, max_num_imprecise_acc): + def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc): return TensorHandle(np.matmul(a.data, b.data) + d.data, d.dtype) def create_make_range(self, start, stop): diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 9735dac15a68..59da32eebd45 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -29,7 +29,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> - %c = 
tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {transA = false, transB = false, f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 8d0187ab6f38..583dd1c05aa7 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -34,7 +34,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-NEXT: offset = 0, size = 4224 %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -64,11 +64,11 @@ tt.func @reusable(%A : !tt.ptr) { %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> // CHECK-NEXT: offset = 0, size = 4608 %a3 = triton_gpu.convert_layout %a3_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> - %c = tt.dot %a1, %a2, %c_init {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c = tt.dot %a1, %a2, %c_init {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> // CHECK-NEXT: offset = 0, size = 1152 %a4 = triton_gpu.convert_layout %a4_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> - %c1 = tt.dot %a3, %a4, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c1 = tt.dot %a3, %a4, %c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> tt.return // CHECK-NEXT: size = 4608 } diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 3a467f0414ad..3505b7424826 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -32,7 +32,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, 
#B_DOT> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> diff --git a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir index 0944659ff986..a8ddf30254f2 100644 --- a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir @@ -20,7 +20,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c // CHECK: llvm.mlir.undef : vector<16xf16> // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xf16> // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> - %0 = tt.dot %arg0, %arg1, %arg2 {allowTF32 = false, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<16x16xf16, #mma> + %0 = tt.dot %arg0, %arg1, %arg2 {f32Backend = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<16x16xf16, #mma> // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16> // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> diff --git a/test/Conversion/invalid.mlir b/test/Conversion/invalid.mlir index c1a104be544d..8c8457daa2aa 100644 --- a/test/Conversion/invalid.mlir +++ b/test/Conversion/invalid.mlir @@ -6,7 +6,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf32, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{element types of operands A and B must have same bit width}} - %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : + %D = tt.dot %A, %B, %C {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf32, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> tt.return } @@ -20,7 +20,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{mismatching encoding between A and B operands}} - %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : + %D = tt.dot %A, %B, %C {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> tt.return } @@ -34,7 +34,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: 
tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{mismatching kWidth between A and B operands}} - %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : + %D = tt.dot %A, %B, %C {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> tt.return } diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir index 842338f7662c..71f5d1ee5563 100644 --- a/test/Conversion/triton_ops.mlir +++ b/test/Conversion/triton_ops.mlir @@ -163,13 +163,13 @@ tt.func @dot_ops_infer(%ptr: !tt.ptr, %v : f32) { %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32> - %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> + %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32> - %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32> + %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32> - %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32> + %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32> - %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32> + %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32> %ptr128x128 = tt.splat %ptr : !tt.ptr -> tensor<128x128x!tt.ptr> %ptr32x32 = tt.splat %ptr : !tt.ptr -> tensor<32x32x!tt.ptr> diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index 40340ad85efe..26dd1fe2e4cf 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -6,7 +6,7 @@ tt.func @ops() { %a = arith.constant dense<1.00e+00> : tensor<128x32xf16> %b = arith.constant dense<2.00e+00> : tensor<32x128xf16> %c = arith.constant dense<3.00e+00> : tensor<128x128xf32> - %0 = tt.dot %a, %b, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> + %0 = tt.dot %a, %b, %c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> tt.return } } diff --git a/test/Conversion/tritongpu_to_llvm.mlir 
b/test/Conversion/tritongpu_to_llvm.mlir index c5e00971c08d..d8b81d71fd4f 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -820,7 +820,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 - %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> + %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> tt.return } @@ -969,7 +969,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_mat = triton_gpu.local_load %a : !tt.memdesc<128x32xf16, #shared> -> tensor<128x32xf16, #dot_operand_a> %b_mat = triton_gpu.local_load %b : !tt.memdesc<32x256xf16, #shared> -> tensor<32x256xf16, #dot_operand_b> - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> + %28 = tt.dot %a_mat, %b_mat, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> %38 = triton_gpu.convert_layout %28 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> @@ -995,7 +995,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x64xf16, #shared0> -> tensor<32x64xf16, #dot_operand_a> %b_mat = triton_gpu.local_load %b : !tt.memdesc<64x64xf16, #shared1> -> tensor<64x64xf16, #dot_operand_b> - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma> + %28 = tt.dot %a_mat, %b_mat, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma> %38 = triton_gpu.convert_layout %28 : tensor<32x64xf32, #mma> -> tensor<32x64xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x64x!tt.ptr, #blocked> @@ -1018,7 +1018,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared> -> tensor<32x16xf32, #dot_operand_a> %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared> -> tensor<16x32xf32, #dot_operand_b> - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = false, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked> + %28 = tt.dot %a_mat, %b_mat, %cst {f32Backend = 2 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, 
#blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> tt.store %36, %28 : tensor<32x32xf32, #blocked> @@ -1055,7 +1055,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> + %28 = tt.dot %a_mat, %b_mat, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> %38 = triton_gpu.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> @@ -1279,7 +1279,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[SI:.+]] = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> %b_mat = arith.constant dense<1.000000e+00> : tensor<16x32xf32, #dot_operand_b> - %28 = tt.dot %a, %b_mat, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> + %28 = tt.dot %a, %b_mat, %c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> %38 = triton_gpu.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x32x!tt.ptr, #blocked> @@ -1309,7 +1309,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %cst_1 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %cst_2 = arith.constant dense<32> : tensor<32x1xi32, #blocked> - %0 = tt.dot %cst_0, %cst_1, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %0 = tt.dot %cst_0, %cst_1, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> %1 = triton_gpu.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> @@ -1479,7 +1479,7 @@ module attributes {"triton_gpu.compute-capability" = 70 : i32, "triton_gpu.num-c %a = arith.constant 
dense<0.000000e+00> : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> %b = arith.constant dense<0.000000e+00> : tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> - %87 = tt.dot %a, %b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<32x32xf32, #mma> + %87 = tt.dot %a, %b, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<32x32xf32, #mma> tt.return } } diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index b09e4ee9d3ce..2b0de245a7f4 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -15,7 +15,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: nvgpu.wgmma // CHECK-COUNT-128: llvm.fadd %m = triton_nvidia_gpu.dot_async %a, %b, %c - {maxNumImpreciseAcc = 32 : i32, allowTF32 = true} : + {maxNumImpreciseAcc = 32 : i32, f32Backend = 0 : i32} : !tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> tt.return } @@ -39,7 +39,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-NOT: llvm.fadd // CHECK: llvm.return %m = triton_nvidia_gpu.dot_async %a, %b, %c - {maxNumImpreciseAcc = 129 : i32, allowTF32 = true} : + {maxNumImpreciseAcc = 129 : i32, f32Backend = 0 : i32} : !tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> tt.return } @@ -63,7 +63,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-COUNT-128: llvm.fadd // CHECK: llvm.return %m = triton_nvidia_gpu.dot_async %a, %b, %c - {maxNumImpreciseAcc = 64 : i32, allowTF32 = true} : + {maxNumImpreciseAcc = 64 : i32, f32Backend = 0 : i32} : !tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> tt.return } @@ -80,7 +80,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: nvgpu.wgmma %{{.*}}, %{{.*}} { tt.func @dot_zero_acc(%a: !tt.memdesc<128x64xf16, #shared>, %b: !tt.memdesc<64x64xf16, #shared1>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %m = triton_nvidia_gpu.dot_async %a, %b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : + %m = triton_nvidia_gpu.dot_async %a, %b, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x64xf16, #shared1> -> tensor<128x64xf32, #mma> tt.return } @@ -98,7 +98,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> - %m = tt.dot %opA, %b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : + %m = tt.dot %opA, %b, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return } @@ -116,7 +116,7 
@@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1> - %m = tt.dot %a, %b, %cst {allowTF32 = true, maxNumImpreciseAcc = 1073741824 : i32} : + %m = tt.dot %a, %b, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> tt.return } diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index a7afba95bbc3..3b73c6c4fdbe 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -11,7 +11,7 @@ tt.func @test_combine_dot_add_invalid_pattern() -> (tensor<128x128xf32>, tensor< %d = arith.constant dense<3.0> : tensor<128x128xf32> %e = arith.constant dense<4.0> : tensor<128x128xf32> - %dot_out = tt.dot %a, %b, %zero {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + %dot_out = tt.dot %a, %b, %zero {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> // CHECK: arith.addf %{{.*}}, %[[d]] : tensor<128x128xf32> %res0 = arith.addf %dot_out, %d : tensor<128x128xf32> @@ -33,9 +33,9 @@ tt.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>) { %zero = arith.constant dense<0.0> : tensor<128x128xf32> %d = arith.constant dense<3.0> : tensor<128x128xf32> - %dot_out = tt.dot %a, %b, %zero {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + %dot_out = tt.dot %a, %b, %zero {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> - // CHECK-NEXT: %[[res:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + // CHECK-NEXT: %[[res:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> // CHECK-NEXT: tt.return %[[res]] : tensor<128x128xf32> %res = arith.addf %dot_out, %d : tensor<128x128xf32> @@ -53,9 +53,9 @@ tt.func @test_combine_dot_add_rev_pattern() -> (tensor<128x128xf32>) { %zero = arith.constant dense<0.0> : tensor<128x128xf32> %d = arith.constant dense<3.0> : tensor<128x128xf32> - %dot_out = tt.dot %a, %b, %zero {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + %dot_out = tt.dot %a, %b, %zero {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> - // CHECK-NEXT: %[[res:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + // CHECK-NEXT: %[[res:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> // CHECK-NEXT: tt.return %[[res]] : 
tensor<128x128xf32> %res = arith.addf %d, %dot_out : tensor<128x128xf32> diff --git a/test/Triton/rewrite-tensor-pointer.mlir b/test/Triton/rewrite-tensor-pointer.mlir index 2449374420e7..3c8fa7506e30 100644 --- a/test/Triton/rewrite-tensor-pointer.mlir +++ b/test/Triton/rewrite-tensor-pointer.mlir @@ -46,7 +46,7 @@ tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %55 = tt.load %arg11 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 2 : i32} : !tt.ptr> -> tensor<128x32xf16> // CHECK: tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16> %56 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 2 : i32} : !tt.ptr> -> tensor<32x32xf16> - %57 = tt.dot %55, %56, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16> * tensor<32x32xf16> -> tensor<128x32xf32> + %57 = tt.dot %55, %56, %arg10 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16> * tensor<32x32xf16> -> tensor<128x32xf32> // CHECK-NOT: tt.advance %58 = tt.advance %arg11, [%c0_i32, %c32_i32] : !tt.ptr> // CHECK-NOT: tt.advance diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index fe6b03503ab8..2dd122e288ed 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -27,18 +27,18 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c // CHECK: tt.dot {{.*}} -> tensor<128x16xf16, #[[MMA]]> // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> %115 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_0) -> (tensor<128x64xf16, #blocked1>) : i32 { - %172 = tt.dot %170, %171, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked> + %172 = tt.dot %170, %171, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked> %178 = triton_gpu.convert_layout %172 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> - %180 = tt.dot %178, %179, %arg16 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> + %180 = tt.dot %178, %179, %arg16 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> scf.yield %180 : tensor<128x64xf16, #blocked1> } // CHECK: scf.for // CHECK: tt.dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]> // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> %149 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 { - %166 = tt.dot %164, %165, %cst_2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2> + %166 = tt.dot 
%164, %165, %cst_2 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2> %172 = triton_gpu.convert_layout %166 : tensor<128x32xf16, #blocked2> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> - %174 = tt.dot %172, %173, %arg16 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> + %174 = tt.dot %172, %173, %arg16 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> scf.yield %174 : tensor<128x64xf16, #blocked1> } tt.store %153, %149 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked1> @@ -61,12 +61,12 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1> // CHECK-80: tt.dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]> - %d = tt.dot %arg0, %arg1, %cst_0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : + %d = tt.dot %arg0, %arg1, %cst_0 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> %c = triton_gpu.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> // CHECK-80: tt.dot {{.*}} -> tensor<64x128xf32, #[[$MMA]]> - %r = tt.dot %c, %arg2, %cst_1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : + %r = tt.dot %c, %arg2, %cst_1 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> tt.return %r : tensor<64x128xf32, #blocked1> } @@ -88,7 +88,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c // CHECK-80: tt.fp_to_fp {{.*}} : tensor<64x128xf8E4M3B11FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$MMA]], kWidth = 4}>> // CHECK-80: tt.fp_to_fp {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$MMA]], kWidth = 4}>> // CHECK-80: tt.dot {{.*}} -> tensor<64x64xf32, #[[$MMA]]> - %d = tt.dot %arg0, %arg1, %cst_0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : + %d = tt.dot %arg0, %arg1, %cst_0 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<64x128xf8E4M3B11FNUZ, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> tt.return %d : tensor<64x64xf32, #blocked> } diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 276b36d4d6a7..8410f4ba4b6d 100644 --- a/test/TritonGPU/combine.mlir +++ 
b/test/TritonGPU/combine.mlir @@ -1548,7 +1548,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %26 = triton_gpu.convert_layout %19 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> %27 = triton_gpu.convert_layout %25 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> %28 = triton_gpu.convert_layout %cst : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked5> - %29 = tt.dot %26, %27, %28 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5> + %29 = tt.dot %26, %27, %28 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5> %30 = triton_gpu.convert_layout %29 : tensor<32x32xf32, #blocked5> -> tensor<32x32xf32, #blocked> %31:2 = "tt.reduce"(%30, %11) <{axis = 1 : i32}> ({ ^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32): @@ -1698,7 +1698,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %117 = tt.load %116 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked3> %118 = triton_gpu.convert_layout %41 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> %119 = triton_gpu.convert_layout %97 : tensor<64x64xf16, #blocked6> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %120 = tt.dot %118, %119, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> + %120 = tt.dot %118, %119, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> %121 = triton_gpu.convert_layout %120 : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #blocked2> %122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> %123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({ @@ -1727,7 +1727,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %142 = triton_gpu.convert_layout %141 : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> %143 = triton_gpu.convert_layout %117 : tensor<64x64xf16, #blocked3> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> %144 = triton_gpu.convert_layout %140 : tensor<128x64xf32, #blocked2> -> tensor<128x64xf32, #blocked> - %145 = tt.dot %142, %143, %144 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> + %145 = tt.dot %142, %143, %144 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> %146 = triton_gpu.convert_layout 
%145 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked2> %147 = arith.mulf %arg24, %127 : tensor<128xf32, #blocked1> %148 = "tt.reduce"(%133) <{axis = 1 : i32}> ({ @@ -1980,7 +1980,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c %74 = triton_gpu.convert_layout %arg8 : tensor<32x256xf32, #blocked3> -> tensor<32x256xf32, #mma> %75 = triton_gpu.convert_layout %72 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %76 = triton_gpu.convert_layout %73 : tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %77 = tt.dot %75, %76, %74 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x256xf32, #mma> + %77 = tt.dot %75, %76, %74 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x256xf32, #mma> %78 = triton_gpu.convert_layout %77 : tensor<32x256xf32, #mma> -> tensor<32x256xf32, #blocked3> scf.yield %78 : tensor<32x256xf32, #blocked3> } diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index bdef81286803..5d64397ad7df 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -36,7 +36,7 @@ tt.func @push_elementwise( %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> %dota = triton_gpu.convert_layout %a : tensor<16x16xf16, #ALR> -> tensor<16x16xf16, #Av2k4> %dotb = triton_gpu.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> - %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + %newc = tt.dot %dota, %dotb, %c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> tt.return %newc : tensor<16x16xf32, #Cv2> } @@ -58,7 +58,7 @@ tt.func @succeeds_if_arg_is_not_convert_layout( %dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4> %dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4> %dotb = triton_gpu.convert_layout %b : tensor<16x16xf16, #BLC> -> tensor<16x16xf16, #Bv2k4> - %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + %newc = tt.dot %dota, %dotb, %c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> tt.return %newc : tensor<16x16xf32, #Cv2> } @@ -82,7 +82,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capabil // CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : 
tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.func @push_convert_both_operands( %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -93,7 +93,7 @@ tt.func @push_convert_both_operands( %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> %al = triton_gpu.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %bl = triton_gpu.convert_layout %be : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %r = tt.dot %al, %bl, %c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.return %r : tensor<16x16xf32, #mma> } @@ -119,7 +119,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capabil // CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : tensor<16x16xf16, #[[BB]]> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.func @update_kwidth_slice( %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : 
i32}, @@ -132,7 +132,7 @@ tt.func @update_kwidth_slice( %add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB> %al = triton_gpu.convert_layout %ae : tensor<16x16xf32, #blockedA> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %bl = triton_gpu.convert_layout %add : tensor<16x16xf32, #blockedB> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %r = tt.dot %al, %bl, %c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.return %r : tensor<16x16xf32, #mma> } @@ -149,7 +149,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c // CHECK: tt.dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdesc<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !tt.memdesc<128x64xf16, #shared1> - %r = tt.dot %A, %arg1, %arg2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + %r = tt.dot %A, %arg1, %arg2 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } @@ -165,7 +165,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c // CHECK: tt.dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !tt.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !tt.memdesc<128x64xf8E5M2, #shared1> - %r = tt.dot %A, %arg1, %arg2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> + %r = tt.dot %A, %arg1, %arg2 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } @@ -188,7 +188,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c %tb = tt.broadcast %tc : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> %ts = arith.select %tb, %tl, %cst_4 : tensor<128x128xi1, #blocked>, tensor<128x128xf16, #blocked> %conv = triton_gpu.convert_layout %ts : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %td = tt.dot %cst_0, %conv, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, 
kWidth = 2}>> -> tensor<128x128xf32, #mma> + %td = tt.dot %cst_0, %conv, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> tt.return %td : tensor<128x128xf32, #mma> } } diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir index 3b2fe3633937..8bbf6473cc3a 100644 --- a/test/TritonGPU/fence-inserstion.mlir +++ b/test/TritonGPU/fence-inserstion.mlir @@ -12,7 +12,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %0 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared> %1 = triton_gpu.local_alloc %arg1 : (tensor<128x64xf16, #blocked2>) -> !tt.memdesc<128x64xf16, #shared1> // CHECK: triton_nvidia_gpu.fence_async_shared - %2 = tt.dot %0, %1, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %2 = tt.dot %0, %1, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> tt.return } } @@ -39,7 +39,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c // CHECK: tt.dot scf.for %iv0 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { scf.for %iv1 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { - %2 = tt.dot %0, %1, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %2 = tt.dot %0, %1, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> } } tt.return diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 78c7857d6bcc..1fdae088a182 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -85,7 +85,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -184,7 +184,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {f32Backend = 0 : i32, maxNumImpreciseAcc 
= 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -261,7 +261,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } @@ -334,7 +334,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // // %sa = triton_gpu.local_alloc %a : (tensor<128x32xf16, #BA>) -> !tt.memdesc<128x32xf16, #SA> // %sb = triton_gpu.local_alloc %b : (tensor<32x128xf16, #BB>) -> !tt.memdesc<32x128xf16, #SB> -// %c = triton_gpu_nvidia.dot_async %sa, %sb, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> +// %c = triton_gpu_nvidia.dot_async %sa, %sb, %prev_c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> // // %a_tileptr_next = tt.advance %a_tileptr, [%c0, %c32_i32] : !tt.ptr, 1> // %b_tileptr_next = tt.advance %b_tileptr, [%c32_i32, %c0] : !tt.ptr, 1> @@ -394,11 +394,11 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %18 = tt.load %arg5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked> %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %21 = tt.dot %19, %20, %cst_2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %21 = tt.dot %19, %20, %cst_2 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> %23 = tt.trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1> -> !tt.memdesc<16x64xf16, #shared> %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %25 = tt.dot %24, %23, %arg4 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %25 = tt.dot %24, %23, %arg4 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, 
#triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked> } @@ -447,8 +447,6 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %18 = tt.load %16 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked> %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - // CHECK: %[[ALLOC1:.+]] = triton_gpu.local_alloc - // CHECK: %[[ALLOC2:.+]] = triton_gpu.local_alloc // CHECK: %[[R:.+]]:{{.+}} = scf.for // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.dot_async{{.*}} // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} @@ -458,11 +456,11 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c // CHECK: scf.yield // CHECK: %{{.*}}:2 = triton_nvidia_gpu.dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16, %arg6 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>) : i32 { - %21 = tt.dot %19, %20, %arg6 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %21 = tt.dot %19, %20, %arg6 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> %l = tt.load %arg5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked> %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1> -> !tt.memdesc<16x64xf16, #shared> - %25 = tt.dot %cst_4, %23, %arg4 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %25 = tt.dot %cst_4, %23, %arg4 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> } @@ -526,7 +524,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c // CHECK: scf.for // CHECK: triton_nvidia_gpu.dot_async // CHECK-NEXT: triton_nvidia_gpu.dot_wait - %39 = tt.dot %37, %38, %arg4 {allowTF32 = true, maxNumImpreciseAcc = 1073741824 : i32} : !tt.memdesc<128x64xf8E5M2, #shared> * !tt.memdesc<64x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + %39 = tt.dot %37, %38, %arg4 {f32Backend = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !tt.memdesc<128x64xf8E5M2, #shared> * !tt.memdesc<64x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> %40 = tt.addptr %arg5, %cst_6 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> %41 = tt.addptr %arg6, %cst_5 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> scf.yield %39, %40, %41 : 
tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1> @@ -606,9 +604,9 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c // CHECK: triton_nvidia_gpu.dot_wait %[[LOOP]]#3, %[[LOOP]]#0 {pendings = 0 : i32} %17:4 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%prev_dot2 = %cst_3, %arg5 = %16, %prev_dot1 = %cst_2, %prev_dot0 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { // This one can be async. - %dot0 = tt.dot %19, %20, %prev_dot1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %dot0 = tt.dot %19, %20, %prev_dot1 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> // This can't be async because its result is modified before it's yielded. - %dot1 = tt.dot %19, %20, %prev_dot1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %dot1 = tt.dot %19, %20, %prev_dot1 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> %dot1.1 = arith.addf %dot1, %dot1 : tensor<128x16xf32, #mma1> %l = tt.load %arg5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked> %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> @@ -616,7 +614,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c // This dot can be async even though %prev_dot2 is not used directly by an // async dot, because that use follows the synchronous dot above. 
%prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> - %dot2 = tt.dot %cst_4, %23, %prev_dot2.1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %dot2 = tt.dot %cst_4, %23, %prev_dot2.1 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %dot2, %26, %dot1.1, %dot0 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index dc52a7fb2bfb..bf22d4b8d952 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -90,7 +90,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %b_ = triton_gpu.convert_layout %b__ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> %b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -178,7 +178,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -246,7 +246,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, 
tensor<32x128xi32, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } @@ -288,7 +288,7 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> - %90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %90 = tt.dot %88, %89, %arg19 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr @@ -335,7 +335,7 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, %87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> - %90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %90 = tt.dot %88, %89, %arg19 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 scf.yield %90, %91, %92, %83 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr, i64 @@ -382,7 +382,7 @@ tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 %87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> - %90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %90 = tt.dot %88, %89, %arg19 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> @@ -432,7 +432,7 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %116 = tt.load %arg12, %115, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> %117 = triton_gpu.convert_layout %112 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> %118 = triton_gpu.convert_layout %116 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> - %119 = tt.dot %117, %118, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = 
#C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %119 = tt.dot %117, %118, %arg10 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> %131 = arith.index_cast %arg9 : index to i32 %120 = arith.addi %131, %c1_i32 : i32 %121 = arith.muli %120, %c32_i32 : i32 @@ -489,7 +489,7 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %150 = tt.load %arg12, %149, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> %151 = triton_gpu.convert_layout %146 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> %152 = triton_gpu.convert_layout %150 : tensor<32x32xf32, #AL> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> - %153 = tt.dot %151, %152, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %153 = tt.dot %151, %152, %arg10 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> %162 = arith.index_cast %arg9 : index to i32 %154 = arith.addi %162, %c2_i32 : i32 %155 = arith.muli %154, %c32_i32 : i32 @@ -561,7 +561,7 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %199 = tt.load %arg24, %198, %88 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %200 = triton_gpu.convert_layout %193 : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> %201 = triton_gpu.convert_layout %199 : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> - %202 = tt.dot %200, %201, %arg23 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C> + %202 = tt.dot %200, %201, %arg23 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C> %203 = tt.addptr %arg24, %90 : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi64, #BL> scf.yield %190, %196, %197, %202, %203 : tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL> } @@ -617,13 +617,13 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c %18 = tt.load %16 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked> %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %21 = tt.dot %19, %20, %cst_1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 
2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %21 = tt.dot %19, %20, %cst_1 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared> %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared> -> !tt.memdesc<16x64xf16, #shared1> %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %27 = tt.dot %23, %26, %arg4 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %27 = tt.dot %23, %26, %arg4 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> @@ -671,13 +671,13 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c %18 = tt.load %16 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked> %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %21 = tt.dot %19, %20, %cst_1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %21 = tt.dot %19, %20, %cst_1 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared> %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared> -> !tt.memdesc<16x64xf16, #shared1> %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %27 = tt.dot %23, %26, %arg4 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + %27 = 
tt.dot %23, %26, %arg4 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> @@ -764,7 +764,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c %30 = tt.load %29 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked> %31 = triton_gpu.convert_layout %30 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %32 = triton_gpu.convert_layout %17 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %33 = tt.dot %31, %32, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %33 = tt.dot %31, %32, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %34 = tt.addptr %23, %28 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> %35 = triton_gpu.convert_layout %33 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %34, %35 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf32, #blocked> @@ -858,11 +858,11 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared> %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared> -> !tt.memdesc<64x32xf32, #shared1> %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %75 = tt.dot %71, %74, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %75 = tt.dot %71, %74, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> %76 = tt.load %61 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked1> %77 = triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %78 = triton_gpu.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %79 = tt.dot %77, %78, %arg7 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %79 = tt.dot %77, %78, %arg7 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, 
#triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> scf.yield %79 : tensor<64x32xf32, #mma> } %64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked> @@ -920,7 +920,7 @@ tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibilit %87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> %88 = triton_gpu.convert_layout %82 : tensor<16x16xf16, #AL> -> tensor<16x16xf16, #A> %89 = triton_gpu.convert_layout %87 : tensor<16x16xf16, #BL> -> tensor<16x16xf16, #B> - %90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %90 = tt.dot %88, %89, %arg19 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> @@ -969,7 +969,7 @@ module attributes {"triton_gpu.compute-capability" = 86 : i32, "triton_gpu.num-c %50 = tt.broadcast %49 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> %51 = tt.load %46, %50, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked> %52 = triton_gpu.convert_layout %51 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %53 = tt.dot %cst_1, %52, %arg8 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %53 = tt.dot %cst_1, %52, %arg8 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %54 = triton_gpu.convert_layout %53 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %34, %54 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf32, #blocked> scf.yield %cst1 : tensor<32x32xf32, #mma> diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir index e048dcb15f2e..257f4341b7cb 100644 --- a/test/TritonGPU/matmul.mlir +++ b/test/TritonGPU/matmul.mlir @@ -62,7 +62,7 @@ tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__1 %47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>) { %76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32> %77 = tt.load %arg15, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32> - %78 = tt.dot %76, %77, %cst_0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32> + %78 = tt.dot %76, %77, %cst_0 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32> %79 = arith.addf %arg13, %78 : tensor<64x64xf32> %80 = arith.muli %arg7, %c64_i32 : i32 %81 = tt.splat 
%80 : i32 -> tensor<64x64xi32> diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir index 5878d841a9f9..6f1fb5af43fd 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -110,7 +110,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %112 = tt.load %111 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x128xf16, #blocked> %113 = triton_gpu.local_alloc %38 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared> %114 = triton_gpu.local_alloc %90 : (tensor<128x64xf16, #blocked2>) -> !tt.memdesc<128x64xf16, #shared1> - %115 = tt.dot %113, %114, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} :!tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %115 = tt.dot %113, %114, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} :!tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> %116 = arith.truncf %115 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> %117 = triton_gpu.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !tt.memdesc<64x128xf16, #shared> %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> @@ -121,7 +121,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c // CHECK: triton_nvidia_gpu.dot_async // CHECK-NOT: triton_nvidia_gpu.dot_wait // CHECK: scf.yield - %119 = tt.dot %118, %117, %arg23 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> + %119 = tt.dot %118, %117, %arg23 {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %121 = arith.addf %120, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %122 = arith.extsi %c0_i32 : i32 to i64 diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index 56bf5e3b0451..5a0cdcbdd879 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -56,7 +56,7 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr %a_op_ = triton_gpu.local_load %a : !tt.memdesc<128x32xf8E5M2, #A> -> tensor<128x32xf8E5M2, #A_OP> %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP> %b_op = triton_gpu.local_load %b : !tt.memdesc<32x128xf16, #B> -> tensor<32x128xf16, #B_OP> - %c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> + %c = tt.dot %a_op, %b_op, %prev_c {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> diff --git a/test/TritonGPU/reorder-instructions.mlir b/test/TritonGPU/reorder-instructions.mlir index 
eecea36373c1..90e885c3d0fc 100644 --- a/test/TritonGPU/reorder-instructions.mlir +++ b/test/TritonGPU/reorder-instructions.mlir @@ -15,7 +15,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %9 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked> %10 = triton_gpu.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %12 = tt.dot %11, %cst_0, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %12 = tt.dot %11, %cst_0, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf32, #blocked> tt.return @@ -64,7 +64,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %A = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked> %AS = triton_gpu.local_alloc %A : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> %AD = triton_gpu.local_load %AS : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %12 = tt.dot %AD, %BD, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %12 = tt.dot %AD, %BD, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf32, #blocked> tt.return @@ -92,11 +92,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %A0 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked> %AS0 = triton_gpu.local_alloc %A0 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> %AD0 = triton_gpu.local_load %AS0 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %12 = tt.dot %AD0, %BD, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> + %12 = tt.dot %AD0, %BD, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %A1 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : 
tensor<32x32xf32, #blocked>
     %AS1 = triton_gpu.local_alloc %A1 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared>
     %AD1 = triton_gpu.local_load %AS1 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
-    %13 = tt.dot %AD1, %BD, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
+    %13 = tt.dot %AD1, %BD, %cst {f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
     tt.return
   }
 }
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index eee4fd77721e..3137e25e7df7 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -1,7 +1,7 @@
 from triton.backends.compiler import BaseBackend
 from triton._C.libtriton import ir, passes, llvm, amd
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Tuple
 import hashlib
 import tempfile
 import os
@@ -21,6 +21,8 @@ class HIPOptions:
     debug: bool = False
     arch: str = None
     allow_fp8e4nv: bool = False
+    default_dot_input_precision: str = "ieee"
+    allowed_dot_input_precisions: Tuple[str] = ("ieee",)
     enable_fp_fusion: bool = True
     capability: int = None
     # TODO:
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
index 449007cfa234..ade3b8125556 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -296,7 +296,7 @@ class BlockedToMFMA : public mlir::RewritePattern {
     a = rewriter.create(a.getLoc(), newAType, a);
     b = rewriter.create(b.getLoc(), newBType, b);
     auto newDot = rewriter.create(dotOp.getLoc(), newAcc.getType(),
-                                  a, b, newAcc, dotOp.getAllowTF32(),
+                                  a, b, newAcc, dotOp.getF32Backend(),
                                   dotOp.getMaxNumImpreciseAcc());

     Value dotOutput =
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/RemoveLayoutConversions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/RemoveLayoutConversions.cpp
index 5dbdaade7177..1aec9bcce24a 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/RemoveLayoutConversions.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/RemoveLayoutConversions.cpp
@@ -66,7 +66,7 @@ class ConvertDotConvert : public mlir::RewritePattern {
         op->getLoc(), dotOp.getType(), _0f);
     auto newDot = rewriter.create(
         op->getLoc(), dotOp.getType(), dotOp.getOperand(0),
-        dotOp.getOperand(1), _0, dotOp.getAllowTF32(),
+        dotOp.getOperand(1), _0, dotOp.getF32Backend(),
         dotOp.getMaxNumImpreciseAcc());
     auto newCvt = rewriter.create(
         op->getLoc(), dstTy, newDot.getResult());
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 337b5b0408eb..b1ed44cb2c3a 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -4,7 +4,7 @@
 from dataclasses import dataclass
 import functools
-from typing import Any
+from typing import Any, Tuple
 import hashlib
 import re
 import tempfile
@@ -69,6 +69,8 @@ class CUDAOptions:
     ptx_version: int = None
     enable_fp_fusion: bool = True
     allow_fp8e4nv: bool = False
+    default_dot_input_precision: str = "tf32"
+    allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
     max_num_imprecise_acc_default: bool = None
     extern_libs: dict = None
     debug: bool = False
@@ -138,6 +140,8 @@ def make_ttgir(mod, metadata, opt, capability):
     passes.ttir.add_convert_to_ttgpuir(pm, opt.num_warps, 32, opt.num_ctas, capability)
     # optimize TTGIR
     passes.ttgpuir.add_coalesce(pm)
+    if capability // 10 >= 8:
+        passes.ttgpuir.add_f32_dot_tc(pm)
     # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
     nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
     passes.ttgpuir.add_remove_layout_conversions(pm)
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index ebfde2f6b9f2..dc10a9029e3f 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -128,7 +128,7 @@ TensorCoreType getMmaType(triton::DotOp op) {
     if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16())
       return TensorCoreType::FP32_BF16_BF16_FP32;
     if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
-        op.getAllowTF32())
+        op.getF32Backend() == F32Backend::TF32)
       return TensorCoreType::FP32_TF32_TF32_FP32;
   } else if (dTy.getElementType().isInteger(32)) {
     if (aTy.getElementType().isInteger(8) && bTy.getElementType().isInteger(8))
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
index 5bc776d20c36..aef988ee1f58 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp
@@ -523,8 +523,8 @@ LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
   return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), //
                     op.getA(), op.getB(), op.getC(), op.getD(), //
                     adaptor.getA(), adaptor.getB(), adaptor.getC(), //
-                    op.getAllowTF32(), op.getMaxNumImpreciseAcc(), true,
-                    thread);
+                    op.getF32Backend() == F32Backend::TF32,
+                    op.getMaxNumImpreciseAcc(), true, thread);
 }

 LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
@@ -540,6 +540,6 @@ LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op,
   return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), //
                     op.getA(), op.getB(), op.getC(), op.getD(), //
                     adaptor.getA(), adaptor.getB(), adaptor.getC(),
-                    op.getAllowTF32(), op.getMaxNumImpreciseAcc(), false,
-                    thread);
+                    op.getF32Backend() == F32Backend::TF32,
+                    op.getMaxNumImpreciseAcc(), false, thread);
 }
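
Note on the tf32x3 backend introduced above: the F32DotTC pass rewrites an fp32 tt.dot into three TF32 dots plus a few pointwise ops, following the 3xTF32 trick described in the CUTLASS discussion referenced by the pass. The sketch below is a minimal NumPy model of that decomposition, not the pass itself; the helper names (to_tf32, dot_tf32x3) and the mantissa-truncation used to emulate TF32 rounding are illustrative assumptions (real Tensor Cores round to nearest).

    import numpy as np

    def to_tf32(x: np.ndarray) -> np.ndarray:
        # Emulate TF32 by keeping only the top 10 explicit mantissa bits of an f32
        # (truncation; an approximation of the hardware's round-to-nearest).
        return (x.astype(np.float32).view(np.uint32) & np.uint32(0xFFFFE000)).view(np.float32)

    def dot_tf32x3(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        # Split each operand into a TF32-representable "hi" part and the residual "lo",
        # then accumulate three TF32 dots; the lo @ lo term is dropped because its
        # contribution falls below f32 precision.
        a_hi = to_tf32(a)
        a_lo = to_tf32(a - a_hi)
        b_hi = to_tf32(b)
        b_lo = to_tf32(b - b_hi)
        return (a_hi @ b_lo) + (a_lo @ b_hi) + (a_hi @ b_hi)

With this patch the per-dot choice between tf32, tf32x3 and ieee is carried by the f32Backend attribute, and each backend advertises what it accepts through default_dot_input_precision / allowed_dot_input_precisions (only ieee on HIP, all three on CUDA). tf32x3 trades roughly three times the Tensor Core work for accuracy close to that of an ieee fp32 dot.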