Rename
lezcano committed Mar 21, 2024
1 parent c4334e6 commit e614d5b
Showing 47 changed files with 146 additions and 146 deletions.
6 changes: 3 additions & 3 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -107,9 +107,9 @@ def TT_PropagateNanAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}

-// F32Backend
-def TT_F32BackendAttr : I32EnumAttr<
-"F32Backend", "",
+// InputPrecision
+def TT_InputPrecisionAttr : I32EnumAttr<
+"InputPrecision", "",
[
I32EnumAttrCase<"TF32", 0, "tf32">,
I32EnumAttrCase<"TF32x3", 1, "tf32x3">,
4 changes: 2 additions & 2 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -566,7 +566,7 @@ def TT_DotOp : TT_Op<"dot", [Pure,
let summary = "dot";

let description = [{
-$d = matrix_multiply($a, $b) + $c. $f32Backend describes how to exercise the TC
+$d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC
when the inputs are f32. It can be one of: tf32, tf32x3, ieee.
tf32: use TC with tf32 ops.
tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp
@@ -578,7 +578,7 @@ def TT_DotOp : TT_Op<"dot", [Pure,
TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
-TT_F32BackendAttr:$f32Backend,
+TT_InputPrecisionAttr:$inputPrecision,
I32Attr:$maxNumImpreciseAcc);

let results = (outs TT_FpIntTensor:$d);
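The inputPrecision attribute introduced here is what the Python front end lowers the user-facing precision string onto (see python/triton/language/semantic.py further down). A minimal kernel-level sketch, assuming the tl.dot keyword is named input_precision as in the semantic.dot signature touched by this commit; pointer arithmetic and masking are simplified for illustration:

import triton
import triton.language as tl

@triton.jit
def dot_tile_kernel(a_ptr, b_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Load one BLOCK x BLOCK f32 tile of each operand (row-major, no masking).
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # "ieee" keeps plain f32 math, "tf32" uses the tensor cores directly,
    # "tf32x3" requests the three-dot compensated scheme handled by F32DotTC.cpp.
    c = tl.dot(a, b, input_precision="tf32x3")
    tl.store(out_ptr + offs[:, None] * BLOCK + offs[None, :], c)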

@@ -85,13 +85,13 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
let summary = "dot async";

let description = [{
-$d = matrix_multiply($a, $b) + $c. For docs on F32BackendAttr, see TT_DotOp
+$d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp
}];

let arguments = (ins TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
-TT_F32BackendAttr:$f32Backend,
+TT_InputPrecisionAttr:$inputPrecision,
I32Attr:$maxNumImpreciseAcc);

let results = (outs TT_FpIntTensor:$d);
2 changes: 1 addition & 1 deletion lib/Analysis/Utility.cpp
@@ -549,7 +549,7 @@ bool supportMMA(triton::DotOp op, int version) {
}
}
if (aElemTy.isF32() && bElemTy.isF32()) {
-return op.getF32Backend() == F32Backend::TF32 && version >= 2;
+return op.getInputPrecision() == InputPrecision::TF32 && version >= 2;
}
return supportMMA(op.getA(), version) && supportMMA(op.getB(), version);
}
2 changes: 1 addition & 1 deletion lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -270,7 +270,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);

addNamedAttrs(rewriter.replaceOpWithNewOp<triton::DotOp>(
-op, retType, a, b, c, adaptor.getF32Backend(),
+op, retType, a, b, c, adaptor.getInputPrecision(),
adaptor.getMaxNumImpreciseAcc()),
adaptor.getAttributes());
return success();
2 changes: 1 addition & 1 deletion lib/Dialect/Triton/Transforms/Combine.cpp
@@ -218,7 +218,7 @@ class CombineBroadcastMulReducePattern : public RewritePattern {
rewriter.getF32FloatAttr(0)));
rewriter.replaceOpWithNewOp<DotOp>(op, expandLhsOp.getSrc(),
expandRhsOp.getSrc(), newAcc,
-F32Backend::TF32, 0);
+InputPrecision::TF32, 0);
return success();
}
};
16 changes: 8 additions & 8 deletions lib/Dialect/Triton/Transforms/Combine.td
@@ -12,25 +12,25 @@ include "mlir/IR/PatternBase.td"
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
-(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $f32Backend, $maxNumImpreciseAcc), $overflow),
-(TT_DotOp $a, $b, $d, $f32Backend, $maxNumImpreciseAcc, (location $res)),
+(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $overflow),
+(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFPattern : Pat<
-(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $f32Backend, $maxNumImpreciseAcc), $fastmath),
-(TT_DotOp $a, $b, $d, $f32Backend, $maxNumImpreciseAcc, (location $res)),
+(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath),
+(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;

def CombineDotAddIRevPattern : Pat<
-(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $f32Backend, $maxNumImpreciseAcc), $d, $overflow),
-(TT_DotOp $a, $b, $d, $f32Backend, $maxNumImpreciseAcc, (location $res)),
+(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $overflow),
+(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFRevPattern : Pat<
-(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $f32Backend, $maxNumImpreciseAcc), $d, $fastmath),
-(TT_DotOp $a, $b, $d, $f32Backend, $maxNumImpreciseAcc, (location $res)),
+(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath),
+(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
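These DRR patterns fold a following add into the dot's accumulator when the original accumulator is zero (and, for the float case, when maxNumImpreciseAcc is 0). A small NumPy sketch of the algebraic identity the rewrite relies on; the array names are illustrative:

import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((128, 32)).astype(np.float32)
b = rng.standard_normal((32, 128)).astype(np.float32)
d = rng.standard_normal((128, 128)).astype(np.float32)

# add(d, dot(a, b, c=0)) computes the same values as dot(a, b, c=d),
# so the separate add can be folded into the dot's accumulator.
zero = np.zeros((128, 128), dtype=np.float32)
folded = a @ b + d             # dot(a, b, c=d)
unfolded = (a @ b + zero) + d  # add(d, dot(a, b, c=0))
assert np.allclose(folded, unfolded)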
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -327,7 +327,7 @@ class BlockedToMMA : public mlir::RewritePattern {
}
// convert dot instruction
auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
-newAcc, dotOp.getF32Backend(),
+newAcc, dotOp.getInputPrecision(),
dotOp.getMaxNumImpreciseAcc());

rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
12 changes: 6 additions & 6 deletions lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
@@ -13,12 +13,12 @@ namespace {
// nb. We call the trick TF32x3 as C++ disallows varaibles starting with numbers
// Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385
// For a, b f32
-// dot(a, b, f32Backend="tf32x3") ->
+// dot(a, b, inputPrecision="tf32x3") ->
// let aBig = f32ToTF32(a), aSmall = a - aBig;
// let bBig = f32ToTF32(b), bSmall = b - bBig;
-// dot(aSmall, bBig, f32Backend="tf32") +
-// dot(aBig, bSmall, f32Backend="tf32") +
-// dot(aBig, bBig, f32Backend="tf32")
+// dot(aSmall, bBig, inputPrecision="tf32") +
+// dot(aBig, bSmall, inputPrecision="tf32") +
+// dot(aBig, bBig, inputPrecision="tf32")
class TF32x3 : public OpRewritePattern<tt::DotOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -33,7 +33,7 @@ class TF32x3 : public OpRewritePattern<tt::DotOp> {
.isF32();
};

-if (!(dotOp.getF32Backend() == tt::F32Backend::TF32x3 &&
+if (!(dotOp.getInputPrecision() == tt::InputPrecision::TF32x3 &&
isF32(dotOp.getA()) && isF32(dotOp.getB()))) {
return failure();
}
@@ -52,7 +52,7 @@ class TF32x3 : public OpRewritePattern<tt::DotOp> {
};
auto dot = [&](Value a, Value b, Value c) -> Value {
return rewriter.create<tt::DotOp>(dotOp->getLoc(), c.getType(), a, b, c,
-tt::F32Backend::TF32,
+tt::InputPrecision::TF32,
dotOp.getMaxNumImpreciseAcc());
};

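The comment at the top of this file is the whole algorithm: split each f32 operand into a TF32-representable part plus an f32 remainder, then accumulate three tf32 dots. A NumPy sketch of that decomposition; it emulates f32ToTF32 by truncating the 13 low mantissa bits, which is an assumption about the rounding used, not a statement about the actual hardware conversion:

import numpy as np

def f32_to_tf32(x: np.ndarray) -> np.ndarray:
    # TF32 keeps 10 mantissa bits, so zero out the low 13 bits of the f32 mantissa.
    return (x.view(np.uint32) & np.uint32(0xFFFFE000)).view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)

a_big, b_big = f32_to_tf32(a), f32_to_tf32(b)
a_small, b_small = a - a_big, b - b_big

tf32_only = a_big @ b_big                                   # dot(aBig, bBig)
tf32x3 = a_small @ b_big + a_big @ b_small + a_big @ b_big  # the three-dot sum

reference = (a.astype(np.float64) @ b.astype(np.float64)).astype(np.float32)
print(np.abs(tf32_only - reference).max(), np.abs(tf32x3 - reference).max())
# The three-dot sum recovers most of the accuracy lost to the tf32 truncation.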

@@ -1163,7 +1163,7 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
builder.setInsertionPoint(dotOp);
builder.replaceOpWithNewOp<ttng::DotAsyncOp>(
dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(),
-dotOp.getF32Backend(), dotOp.getMaxNumImpreciseAcc());
+dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc());
}
}


@@ -64,7 +64,7 @@ class ConvertDotConvert : public RewritePattern {
auto _0 = rewriter.create<SplatOp>(op->getLoc(), dotOp.getType(), _0f);
auto newDot = rewriter.create<DotOp>(
op->getLoc(), dotOp.getType(), dotOp.getOperand(0), dotOp.getOperand(1),
-_0, dotOp.getF32Backend(), dotOp.getMaxNumImpreciseAcc());
+_0, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc());
auto newCvt = rewriter.create<ConvertLayoutOp>(op->getLoc(), dstTy,
newDot.getResult());
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, newCvt, cvtOp.getSrc());
12 changes: 6 additions & 6 deletions python/src/ir.cc
@@ -191,10 +191,10 @@ void init_triton_ir(py::module &&m) {
.value("NONE", PropagateNan::NONE)
.value("ALL", PropagateNan::ALL);

-py::enum_<F32Backend>(m, "F32BACKEND", py::module_local())
-.value("TF32", F32Backend::TF32)
-.value("TF32x3", F32Backend::TF32x3)
-.value("IEEE", F32Backend::IEEE)
+py::enum_<InputPrecision>(m, "INPUT_PRECISION", py::module_local())
+.value("TF32", InputPrecision::TF32)
+.value("TF32x3", InputPrecision::TF32x3)
+.value("IEEE", InputPrecision::IEEE)
.export_values();

py::class_<MLIRContext>(m, "context", py::module_local()).def(py::init<>());
@@ -1377,9 +1377,9 @@ void init_triton_ir(py::module &&m) {
})
.def("create_dot",
[](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b,
-mlir::Value &c, F32Backend f32Backend,
+mlir::Value &c, InputPrecision inputPrecision,
int maxNumImpreciseAcc) -> mlir::Value {
-return self.create<DotOp>(c.getType(), a, b, c, f32Backend,
+return self.create<DotOp>(c.getType(), a, b, c, inputPrecision,
maxNumImpreciseAcc);
})
.def("create_floor",
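On the Python side the binding is renamed as well, so IR-building code selects the precision through ir.INPUT_PRECISION, mirroring the getattr lookup in semantic.py further down. A hedged sketch of driving create_dot after this change; the import path and the builder/value handles are assumptions for illustration, not verbatim repo code:

from triton._C.libtriton import ir  # import path is an assumption

# `builder` is a TritonOpBuilder exposed by python/src/ir.cc and a, b, acc are
# ir.Value tensor handles produced earlier; the trailing 0 is maxNumImpreciseAcc.
precision = getattr(ir.INPUT_PRECISION, "TF32x3")  # or ir.INPUT_PRECISION.IEEE / .TF32
d = builder.create_dot(a, b, acc, precision, 0)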

@@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
-%29 = tt.dot %27, %28, %cst {f32Backend = 0 : i32, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+%29 = tt.dot %27, %28, %cst {inputPrecision = 0 : i32, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>

@@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
-%29 = tt.dot %27, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+%29 = tt.dot %27, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>

@@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
-%29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+%29 = tt.dot %21, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>

@@ -47,7 +47,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
-%29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+%29 = tt.dot %21, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>

@@ -47,7 +47,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
triton_gpu.async_commit_group
triton_gpu.async_wait {num = 1 : i32}
%28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
-%29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+%29 = tt.dot %21, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>

@@ -51,7 +51,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
triton_nvidia_gpu.mbarrier_wait %mbar1, %i1_false : !tt.ptr<i64, 3>
// triton_gpu.async_wait {num = 1 : i32}
%28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
-%29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+%29 = tt.dot %21, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>

@@ -53,7 +53,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
triton_nvidia_gpu.mbarrier_wait %mbar1_s, %i1_false : !tt.ptr<i64, 3>
// triton_gpu.async_wait {num = 1 : i32}
%28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
-%29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+%29 = tt.dot %21, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>

@@ -46,7 +46,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%28 = triton_gpu.extract_slice %b_smem_loaded[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>

// Calling MMA
-%29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+%29 = tt.dot %21, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>

// Epilogue
%30 = tt.splat %cStride0: (i32) -> tensor<64x1xi32, #blocked1>
2 changes: 1 addition & 1 deletion python/triton/language/semantic.py
@@ -1298,7 +1298,7 @@ def _str_to_dot_f32_backend(input_precision, builder):
input_precision = input_precision.upper()
if input_precision == "TF32X3":
input_precision = "TF32x3"
-return getattr(ir.F32BACKEND, input_precision)
+return getattr(ir.INPUT_PRECISION, input_precision)


def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int,
2 changes: 1 addition & 1 deletion test/Analysis/test-alias.mlir
@@ -29,7 +29,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
%a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT>
-%c = tt.dot %a, %b, %prev_c {transA = false, transB = false, f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
+%c = tt.dot %a, %b, %prev_c {transA = false, transB = false, inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>