[BACKEND] Implement 3xTF32 trick (#3234)
This PR implements the [3xTF32
trick](NVIDIA/cutlass#385) to make use of
the Tensor Cores (TCs) on f32 tensors without sacrificing accuracy. This is
particularly relevant for PyTorch, as TF32 is off by default there.
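
For intuition, here is a minimal NumPy sketch of the decomposition (illustration only, not code from this PR: TF32 is emulated by truncating the low 13 mantissa bits rather than using the `cvt.rna.tf32.f32` rounding the pass emits, and the matmuls below run in ordinary f32, so the errors only reflect the input quantization):
```
# 3xTF32: split each f32 operand into a TF32-representable "big" part and an
# f32 "small" remainder, then take three dots: small*big + big*small + big*big
# (the small*small term is dropped).
import numpy as np

def to_tf32(x: np.ndarray) -> np.ndarray:
    # Emulate TF32 by zeroing the low 13 mantissa bits of an f32 (truncation;
    # the actual pass rounds to nearest via PTX `cvt.rna.tf32.f32`).
    return (x.view(np.uint32) & np.uint32(0xFFFFE000)).view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal((512, 512), dtype=np.float32)
b = rng.standard_normal((512, 512), dtype=np.float32)

a_big, b_big = to_tf32(a), to_tf32(b)
a_small, b_small = a - a_big, b - b_big

exact = a.astype(np.float64) @ b.astype(np.float64)
one_tf32 = a_big @ b_big
three_tf32 = a_small @ b_big + a_big @ b_small + a_big @ b_big

print("1xTF32 max abs error:", np.abs(one_tf32 - exact).max())
print("3xTF32 max abs error:", np.abs(three_tf32 - exact).max())
```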

Benchmarks on an A100 from `python/tutorials/03-matrix-multiplication.py`,
run on `float32` data with `use_tf32=False` (throughput in TFLOP/s; higher is better):
```
         M       N       K     cuBLAS     Triton    This PR
0    256.0   256.0   256.0   1.927529   1.092267   1.489455
1    384.0   384.0   384.0   5.026909   3.567484   3.686400
2    512.0   512.0   512.0   8.192000   6.553600   6.898527
3    640.0   640.0   640.0  12.190476  10.448980  10.666666
4    768.0   768.0   768.0  13.405091  10.287628  14.503869
5    896.0   896.0   896.0  14.049280  13.380267  20.070399
6   1024.0  1024.0  1024.0  15.887515  12.264046  19.239927
7   1152.0  1152.0  1152.0  16.681475  15.633424  24.883201
8   1280.0  1280.0  1280.0  16.516129  15.340824  28.248276
9   1408.0  1408.0  1408.0  17.090206  14.774461  24.016635
10  1536.0  1536.0  1536.0  17.014154  15.624477  26.021647
11  1664.0  1664.0  1664.0  17.043394  15.073554  25.858942
12  1792.0  1792.0  1792.0  17.107190  16.171833  29.577431
13  1920.0  1920.0  1920.0  17.883570  15.762828  26.331430
14  2048.0  2048.0  2048.0  17.623127  17.032706  27.413751
15  2176.0  2176.0  2176.0  17.887688  16.686275  29.945905
16  2304.0  2304.0  2304.0  19.019006  17.933838  33.787654
17  2432.0  2432.0  2432.0  17.940270  17.288901  31.181425
18  2560.0  2560.0  2560.0  18.164080  17.075561  31.844508
19  2688.0  2688.0  2688.0  17.594183  16.703239  30.370742
20  2816.0  2816.0  2816.0  18.766871  18.089676  33.242537
21  2944.0  2944.0  2944.0  18.735350  17.855977  33.695763
22  3072.0  3072.0  3072.0  18.420008  17.766898  32.768000
23  3200.0  3200.0  3200.0  18.470418  17.704011  33.255391
24  3328.0  3328.0  3328.0  18.253370  17.710036  32.753092
25  3456.0  3456.0  3456.0  18.546485  17.793328  33.634362
26  3584.0  3584.0  3584.0  18.368824  17.833278  33.142423
27  3712.0  3712.0  3712.0  18.665424  17.938112  34.036574
28  3840.0  3840.0  3840.0  18.638578  18.076496  33.794348
29  3968.0  3968.0  3968.0  18.965486  18.190808  34.324595
30  4096.0  4096.0  4096.0  19.035276  18.365864  34.450135
```

It's an overall win, giving roughly an 85% speed-up over the existing Triton f32 path at large sizes.
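
For concreteness, that figure comes straight from the last rows of the table (a quick check using the printed numbers):
```
# Speed-up of "This PR" over the pre-existing Triton f32 numbers at the largest sizes.
triton_tflops  = [18.076496, 18.190808, 18.365864]   # M = N = K = 3840, 3968, 4096
this_pr_tflops = [33.794348, 34.324595, 34.450135]
for t, p in zip(triton_tflops, this_pr_tflops):
    print(f"{p / t:.2f}x")   # ~1.87x, i.e. roughly an 85-90% speed-up
```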

Note that the rounding differs a little from the one [implemented
in
CUTLASS](https://github.com/NVIDIA/cutlass/blob/a8f2c80db0564c74f4efccac71993b971dfc448b/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h#L99-L100).
We could implement that rounding as well if we wanted to.

This is still some way off the 2x speed-up announced by CUTLASS. To
get close to those numbers, we would probably need to remove the stores
to shared memory before `ldmatrix`.
lezcano authored Mar 28, 2024
1 parent 0ba87e2 commit 47a35b6
Showing 60 changed files with 404 additions and 245 deletions.
11 changes: 11 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -107,4 +107,15 @@ def TT_PropagateNanAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}

// InputPrecision
def TT_InputPrecisionAttr : I32EnumAttr<
"InputPrecision", "",
[
I32EnumAttrCase<"TF32", 0, "tf32">,
I32EnumAttrCase<"TF32x3", 1, "tf32x3">,
I32EnumAttrCase<"IEEE", 2, "ieee">
]>{
let cppNamespace = "::mlir::triton";
}

#endif
9 changes: 7 additions & 2 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -566,14 +566,19 @@ def TT_DotOp : TT_Op<"dot", [Pure,
let summary = "dot";

let description = [{
$d = matrix_multiply($a, $b) + $c
$d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TCs
when the inputs are f32. It can be one of: tf32, tf32x3, ieee.
tf32: use the TCs with tf32 inputs.
tf32x3: implement the 3xTF32 trick. For more information, see the pass in F32DotTC.cpp.
ieee: don't use the TCs; implement dot in software.
If the GPU does not have Tensor Cores or the inputs are not f32, this attribute is ignored.
}];

let arguments = (ins
TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
BoolAttr:$allowTF32,
TT_InputPrecisionAttr:$inputPrecision,
I32Attr:$maxNumImpreciseAcc);

let results = (outs TT_FpIntTensor:$d);
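
As a usage sketch (not part of this diff): recent Triton front ends expose this choice as the `input_precision` argument of `tl.dot`, accepting `"tf32"`, `"tf32x3"` and `"ieee"`; whether that exact keyword is already wired up at this commit is an assumption, so treat the snippet below as illustrative.
```
# Hedged front-end sketch: a tiny kernel requesting the 3xTF32 path.
# The `input_precision` keyword is assumed from recent Triton releases.
import torch
import triton
import triton.language as tl

@triton.jit
def small_dot(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    rm = tl.arange(0, M)
    rn = tl.arange(0, N)
    rk = tl.arange(0, K)
    a = tl.load(a_ptr + rm[:, None] * K + rk[None, :])  # f32 M x K tile
    b = tl.load(b_ptr + rk[:, None] * N + rn[None, :])  # f32 K x N tile
    c = tl.dot(a, b, input_precision="tf32x3")           # f32 dot via 3 tf32 dots
    tl.store(c_ptr + rm[:, None] * N + rn[None, :], c)

a = torch.randn(64, 32, device="cuda", dtype=torch.float32)
b = torch.randn(32, 64, device="cuda", dtype=torch.float32)
c = torch.empty(64, 64, device="cuda", dtype=torch.float32)
small_dot[(1,)](a, b, c, M=64, N=64, K=32)
print(torch.allclose(c, a @ b, atol=1e-4, rtol=1e-4))
```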
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.h
@@ -14,6 +14,8 @@ std::unique_ptr<Pass> createPipelinePass(int numStages = 3, int numWarps = 4,

std::unique_ptr<Pass> createAccelerateMatmulPass(int computeCapability = 80);

std::unique_ptr<Pass> createF32DotTCPass();

std::unique_ptr<Pass> createPrefetchPass();

std::unique_ptr<Pass> createCoalescePass();
14 changes: 14 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -36,6 +36,20 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
];
}

def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
let summary = "3xTF32 trick";

let description = [{
Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 tf32 `DotOp`s
to allow using Tensor Cores. See https://github.com/NVIDIA/cutlass/discussions/385
}];

let constructor = "mlir::triton::gpu::createF32DotTCPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
}

def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";

@@ -85,13 +85,13 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
let summary = "dot async";

let description = [{
$d = matrix_multiply($a, $b) + $c
$d = matrix_multiply($a, $b) + $c. For documentation on InputPrecisionAttr, see TT_DotOp.
}];

let arguments = (ins TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
BoolAttr:$allowTF32,
TT_InputPrecisionAttr:$inputPrecision,
I32Attr:$maxNumImpreciseAcc);

let results = (outs TT_FpIntTensor:$d);
2 changes: 1 addition & 1 deletion lib/Analysis/Utility.cpp
@@ -549,7 +549,7 @@ bool supportMMA(triton::DotOp op, int version) {
}
}
if (aElemTy.isF32() && bElemTy.isF32()) {
return op.getAllowTF32() && version >= 2;
return op.getInputPrecision() == InputPrecision::TF32 && version >= 2;
}
return supportMMA(op.getA(), version) && supportMMA(op.getB(), version);
}
2 changes: 1 addition & 1 deletion lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -270,7 +270,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);

addNamedAttrs(rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, c, adaptor.getAllowTF32(),
op, retType, a, b, c, adaptor.getInputPrecision(),
adaptor.getMaxNumImpreciseAcc()),
adaptor.getAttributes());
return success();
3 changes: 2 additions & 1 deletion lib/Dialect/Triton/Transforms/Combine.cpp
@@ -217,7 +217,8 @@ class CombineBroadcastMulReducePattern : public RewritePattern {
rewriter.create<arith::ConstantOp>(op->getLoc(),
rewriter.getF32FloatAttr(0)));
rewriter.replaceOpWithNewOp<DotOp>(op, expandLhsOp.getSrc(),
expandRhsOp.getSrc(), newAcc, true, 0);
expandRhsOp.getSrc(), newAcc,
InputPrecision::TF32, 0);
return success();
}
};
16 changes: 8 additions & 8 deletions lib/Dialect/Triton/Transforms/Combine.td
@@ -12,25 +12,25 @@ include "mlir/IR/PatternBase.td"
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $overflow),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)),
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $overflow),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $fastmath),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;

def CombineDotAddIRevPattern : Pat<
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d, $overflow),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)),
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $overflow),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d, $fastmath),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -327,7 +327,7 @@ class BlockedToMMA : public mlir::RewritePattern {
}
// convert dot instruction
auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
newAcc, dotOp.getAllowTF32(),
newAcc, dotOp.getInputPrecision(),
dotOp.getMaxNumImpreciseAcc());

rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_triton_library(TritonGPUTransforms
AccelerateMatmul.cpp
Coalesce.cpp
F32DotTC.cpp
ReduceDataDuplication.cpp
OptimizeDotOperands.cpp
OptimizeThreadLocality.cpp
91 changes: 91 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
@@ -0,0 +1,91 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

using namespace mlir;
namespace tt = mlir::triton;

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace {

// NB: we call the trick TF32x3 because C++ disallows variable names starting with numbers.
// Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385
// For a, b f32
// dot(a, b, inputPrecision="tf32x3") ->
// let aBig = f32ToTF32(a), aSmall = a - aBig;
// let bBig = f32ToTF32(b), bSmall = b - bBig;
// dot(aSmall, bBig, inputPrecision="tf32") +
// dot(aBig, bSmall, inputPrecision="tf32") +
// dot(aBig, bBig, inputPrecision="tf32")
class TF32x3 : public OpRewritePattern<tt::DotOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(tt::DotOp dotOp,
PatternRewriter &rewriter) const override {

auto isF32 = [](Value operand) {
return operand.getType()
.cast<RankedTensorType>()
.getElementType()
.isF32();
};

if (!(dotOp.getInputPrecision() == tt::InputPrecision::TF32x3 &&
isF32(dotOp.getA()) && isF32(dotOp.getB()))) {
return failure();
}

// Aux functions
auto f32ToTF32 = [&](Value value) -> Value {
return rewriter
.create<tt::ElementwiseInlineAsmOp>(
dotOp.getLoc(), value.getType(), "cvt.rna.tf32.f32 $0, $1;",
"=r,r",
/*isPure=*/true, /*pack=*/1, ArrayRef<Value>{value})
.getResult()[0];
};
auto sub = [&](Value a, Value b) -> Value {
return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b);
};
auto dot = [&](Value a, Value b, Value c) -> Value {
return rewriter.create<tt::DotOp>(dotOp->getLoc(), c.getType(), a, b, c,
tt::InputPrecision::TF32,
dotOp.getMaxNumImpreciseAcc());
};

auto aBig = f32ToTF32(dotOp.getA());
auto aSmall = sub(dotOp.getA(), aBig);

auto bBig = f32ToTF32(dotOp.getB());
auto bSmall = sub(dotOp.getB(), bBig);

auto dot1 = dot(aSmall, bBig, dotOp.getC());
auto dot2 = dot(aBig, bSmall, dot1);
auto dot3 = dot(aBig, bBig, dot2);

rewriter.replaceOp(dotOp, dot3);
return success();
}
};

struct F32DotTCPass : public TritonGPUF32DotTCBase<F32DotTCPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();

RewritePatternSet decomposePatterns(context);
decomposePatterns.add<TF32x3>(context);
if (applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
.failed()) {
signalPassFailure();
}
}
};
} // anonymous namespace

std::unique_ptr<Pass> mlir::triton::gpu::createF32DotTCPass() {
return std::make_unique<F32DotTCPass>();
}
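
The pass is exposed to Python further down in this diff through `ADD_PASS_WRAPPER_0("add_f32_dot_tc", createF32DotTCPass)`. A hedged sketch of where a backend pipeline might schedule it (the surrounding calls and the exact placement are illustrative assumptions modelled on Triton's NVIDIA backend, not taken from this PR):
```
# Illustrative only: scheduling the new TritonGPU pass from the Python bindings.
# `add_f32_dot_tc` is the wrapper added in python/src/passes.cc in this PR;
# the other calls and their order are assumptions.
from triton._C.libtriton import ir, passes

def run_ttgir_passes(mod, capability: int = 80):
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    passes.ttgpuir.add_accelerate_matmul(pm, capability)
    # Rewrites f32 tt.dot ops tagged tf32x3 into three tf32 dots plus pointwise ops.
    passes.ttgpuir.add_f32_dot_tc(pm)
    passes.ttgpuir.add_optimize_dot_operands(pm)
    passes.ttgpuir.add_remove_layout_conversions(pm)
    pm.run(mod)
```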
@@ -1174,8 +1174,8 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
if (isMMAv3Dot(dotOp)) {
builder.setInsertionPoint(dotOp);
builder.replaceOpWithNewOp<ttng::DotAsyncOp>(
dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
dotOp.getMaxNumImpreciseAcc());
dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(),
dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc());
}
}

@@ -64,7 +64,7 @@ class ConvertDotConvert : public RewritePattern {
auto _0 = rewriter.create<SplatOp>(op->getLoc(), dotOp.getType(), _0f);
auto newDot = rewriter.create<DotOp>(
op->getLoc(), dotOp.getType(), dotOp.getOperand(0), dotOp.getOperand(1),
_0, dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
_0, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc());
auto newCvt = rewriter.create<ConvertLayoutOp>(op->getLoc(), dstTy,
newDot.getResult());
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, newCvt, cvtOp.getSrc());
13 changes: 10 additions & 3 deletions python/src/ir.cc
@@ -191,6 +191,12 @@ void init_triton_ir(py::module &&m) {
.value("NONE", PropagateNan::NONE)
.value("ALL", PropagateNan::ALL);

py::enum_<InputPrecision>(m, "INPUT_PRECISION", py::module_local())
.value("TF32", InputPrecision::TF32)
.value("TF32x3", InputPrecision::TF32x3)
.value("IEEE", InputPrecision::IEEE)
.export_values();

py::class_<MLIRContext>(m, "context", py::module_local()).def(py::init<>());

m.def("load_dialects", [](MLIRContext &context) {
@@ -1363,9 +1369,10 @@ void init_triton_ir(py::module &&m) {
ProgramIDDim(axis)));
})
.def("create_dot",
[](TritonOpBuilder &self, Value &a, Value &b, Value &c,
bool allowTF32, int maxNumImpreciseAcc) -> Value {
return self.create<DotOp>(c.getType(), a, b, c, allowTF32,
[](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b,
mlir::Value &c, InputPrecision inputPrecision,
int maxNumImpreciseAcc) -> mlir::Value {
return self.create<DotOp>(c.getType(), a, b, c, inputPrecision,
maxNumImpreciseAcc);
})
.def("create_floor",
1 change: 1 addition & 0 deletions python/src/passes.cc
@@ -51,6 +51,7 @@ void init_triton_passes_ttgpuir(py::module &&m) {
ADD_PASS_WRAPPER_0("add_prefetch", createPrefetchPass);
ADD_PASS_WRAPPER_1("add_accelerate_matmul", createAccelerateMatmulPass, int);
ADD_PASS_WRAPPER_0("add_reorder_instructions", createReorderInstructionsPass);
ADD_PASS_WRAPPER_0("add_f32_dot_tc", createF32DotTCPass);
ADD_PASS_WRAPPER_0("add_optimize_dot_operands",
createOptimizeDotOperandsPass);
ADD_PASS_WRAPPER_0("add_remove_layout_conversions",
5 changes: 1 addition & 4 deletions python/test/regression/test_cast_matmul.py
@@ -34,7 +34,6 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype))
out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)

allow_tf32 = True
# launch kernel
BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1)
@@ -45,7 +44,6 @@ def matmul_kernel(A, B, C, M, N, K, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
dot_out_dtype: tl.constexpr, #
allow_tf32: tl.constexpr, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, #
BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
# matrix multiplication
@@ -75,7 +73,7 @@ def matmul_kernel(A, B, C, M, N, K, #
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
a = a.to(C.dtype.element_ty)
b = b.to(C.dtype.element_ty)
acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
acc += tl.dot(a, b, out_dtype=dot_out_dtype)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
acc = acc.to(C.dtype.element_ty)
@@ -91,7 +89,6 @@ def matmul_kernel(A, B, C, M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, #
allow_tf32=allow_tf32, #
GROUP_M=8, #
BLOCK_M=BLOCK_M, #
BLOCK_N=BLOCK_N, #
@@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %27, %28, %cst {inputPrecision = 0 : i32, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
@@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %27, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
@@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %21, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
@@ -47,7 +47,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %21, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
@@ -47,7 +47,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
triton_gpu.async_commit_group
triton_gpu.async_wait {num = 1 : i32}
%28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %21, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>