[BACKEND] Implement 3xTF32 trick #3234

Merged (3 commits) on Mar 28, 2024
11 changes: 11 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -107,4 +107,15 @@ def TT_PropagateNanAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}

// InputPrecision
def TT_InputPrecisionAttr : I32EnumAttr<
"InputPrecision", "",
[
I32EnumAttrCase<"TF32", 0, "tf32">,
I32EnumAttrCase<"TF32x3", 1, "tf32x3">,
I32EnumAttrCase<"IEEE", 2, "ieee">
]>{
let cppNamespace = "::mlir::triton";
}

#endif
9 changes: 7 additions & 2 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -566,14 +566,19 @@ def TT_DotOp : TT_Op<"dot", [Pure,
let summary = "dot";

let description = [{
$d = matrix_multiply($a, $b) + $c
$d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC
when the inputs are f32. It can be one of: tf32, tf32x3, ieee.
tf32: use TC with tf32 ops.
tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp
ieee: don't use TC, implement dot in software.
If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored.
Review comments on this description:

Collaborator: Presumably these descriptions should be on the enum now.

Collaborator: I think this comment still applies?

Contributor (author): I figured it'd be better to leave them in the place where one would first look when finding out what this flag really does, but sure, I can move them.

Collaborator: For example, the enum is also used on dot_async. (If you wanted a "see X" comment I'd be fine with that, although it's probably pretty obvious where to look.)

}];

let arguments = (ins
TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
BoolAttr:$allowTF32,
TT_InputPrecisionAttr:$inputPrecision,
I32Attr:$maxNumImpreciseAcc);

let results = (outs TT_FpIntTensor:$d);
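For intuition about the three precision modes described above, here is a minimal NumPy sketch (not part of this PR) that emulates tf32 by rounding each f32 operand to a 10-bit mantissa, roughly what the cvt.rna.tf32.f32 conversion used later in this PR does in hardware; "ieee" keeps the full f32 inputs. Only operand rounding is modeled, not the tensor-core datapath.

import numpy as np

def round_to_tf32(x: np.ndarray) -> np.ndarray:
    # Emulate tf32: keep 10 explicit mantissa bits of an f32 (round to nearest).
    bits = x.astype(np.float32).view(np.uint32)
    bits = (bits + np.uint32(0x1000)) & np.uint32(0xFFFFE000)
    return bits.view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal((128, 128)).astype(np.float32)
b = rng.standard_normal((128, 128)).astype(np.float32)

ref  = a.astype(np.float64) @ b.astype(np.float64)   # high-precision reference
ieee = a @ b                                          # "ieee": full f32 inputs
tf32 = round_to_tf32(a) @ round_to_tf32(b)            # "tf32": inputs rounded to tf32

print("ieee max abs error:", np.abs(ieee - ref).max())
print("tf32 max abs error:", np.abs(tf32 - ref).max())

The "tf32x3" mode recovers most of the accuracy lost by plain tf32; a sketch of that decomposition follows the F32DotTC.cpp diff further down.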
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.h
@@ -14,6 +14,8 @@ std::unique_ptr<Pass> createPipelinePass(int numStages = 3, int numWarps = 4,

std::unique_ptr<Pass> createAccelerateMatmulPass(int computeCapability = 80);

std::unique_ptr<Pass> createF32DotTCPass();

std::unique_ptr<Pass> createPrefetchPass();

std::unique_ptr<Pass> createCoalescePass();
14 changes: 14 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -36,6 +36,20 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
];
}

def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
let summary = "3xTF32 trick";

let description = [{
Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 tf32 `DotOp`s
to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385
}];

let constructor = "mlir::triton::gpu::createF32DotTCPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
}

def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";

@@ -85,13 +85,13 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
let summary = "dot async";

let description = [{
$d = matrix_multiply($a, $b) + $c
$d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp
}];

let arguments = (ins TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
BoolAttr:$allowTF32,
TT_InputPrecisionAttr:$inputPrecision,
I32Attr:$maxNumImpreciseAcc);

let results = (outs TT_FpIntTensor:$d);
2 changes: 1 addition & 1 deletion lib/Analysis/Utility.cpp
@@ -549,7 +549,7 @@ bool supportMMA(triton::DotOp op, int version) {
}
}
if (aElemTy.isF32() && bElemTy.isF32()) {
return op.getAllowTF32() && version >= 2;
return op.getInputPrecision() == InputPrecision::TF32 && version >= 2;
}
return supportMMA(op.getA(), version) && supportMMA(op.getB(), version);
}
@@ -270,7 +270,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);

addNamedAttrs(rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, c, adaptor.getAllowTF32(),
op, retType, a, b, c, adaptor.getInputPrecision(),
adaptor.getMaxNumImpreciseAcc()),
adaptor.getAttributes());
return success();
3 changes: 2 additions & 1 deletion lib/Dialect/Triton/Transforms/Combine.cpp
@@ -217,7 +217,8 @@ class CombineBroadcastMulReducePattern : public RewritePattern {
rewriter.create<arith::ConstantOp>(op->getLoc(),
rewriter.getF32FloatAttr(0)));
rewriter.replaceOpWithNewOp<DotOp>(op, expandLhsOp.getSrc(),
expandRhsOp.getSrc(), newAcc, true, 0);
expandRhsOp.getSrc(), newAcc,
InputPrecision::TF32, 0);
return success();
}
};
16 changes: 8 additions & 8 deletions lib/Dialect/Triton/Transforms/Combine.td
@@ -12,25 +12,25 @@ include "mlir/IR/PatternBase.td"
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $overflow),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)),
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $overflow),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $fastmath),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;

def CombineDotAddIRevPattern : Pat<
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d, $overflow),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)),
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $overflow),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d, $fastmath),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
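These patterns only change which tensor acts as the accumulator and now carry $inputPrecision through unchanged. An illustrative NumPy sketch (with tt.dot modeled simply as matmul plus accumulator, an assumption for illustration only) of the identity they rely on:

import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((16, 32)).astype(np.float32)
b = rng.standard_normal((32, 8)).astype(np.float32)
d = rng.standard_normal((16, 8)).astype(np.float32)
c = np.zeros((16, 8), dtype=np.float32)               # the c == 0 constraint

def dot(a, b, acc):                                   # stand-in for tt.dot
    return a @ b + acc

before = dot(a, b, c) + d                             # AddFOp(d, DotOp(a, b, c))
after  = dot(a, b, d)                                 # rewritten DotOp(a, b, d)
assert np.allclose(before, after)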
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -327,7 +327,7 @@ class BlockedToMMA : public mlir::RewritePattern {
}
// convert dot instruction
auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
newAcc, dotOp.getAllowTF32(),
newAcc, dotOp.getInputPrecision(),
dotOp.getMaxNumImpreciseAcc());

rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_triton_library(TritonGPUTransforms
AccelerateMatmul.cpp
Coalesce.cpp
F32DotTC.cpp
ReduceDataDuplication.cpp
OptimizeDotOperands.cpp
OptimizeThreadLocality.cpp
91 changes: 91 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
@@ -0,0 +1,91 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

using namespace mlir;
namespace tt = mlir::triton;

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace {

// nb. We call the trick TF32x3 as C++ disallows variables starting with numbers
// Implement the 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385
// For a, b f32
// dot(a, b, inputPrecision="tf32x3") ->
// let aBig = f32ToTF32(a), aSmall = a - aBig;
// let bBig = f32ToTF32(b), bSmall = b - bBig;
// dot(aSmall, bBig, inputPrecision="tf32") +
// dot(aBig, bSmall, inputPrecision="tf32") +
// dot(aBig, bBig, inputPrecision="tf32")
class TF32x3 : public OpRewritePattern<tt::DotOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(tt::DotOp dotOp,
PatternRewriter &rewriter) const override {

auto isF32 = [](Value operand) {
return operand.getType()
.cast<RankedTensorType>()
.getElementType()
.isF32();
};

if (!(dotOp.getInputPrecision() == tt::InputPrecision::TF32x3 &&
isF32(dotOp.getA()) && isF32(dotOp.getB()))) {
return failure();
}

// Aux functions
auto f32ToTF32 = [&](Value value) -> Value {
return rewriter
.create<tt::ElementwiseInlineAsmOp>(
dotOp.getLoc(), value.getType(), "cvt.rna.tf32.f32 $0, $1;",
"=r,r",
/*isPure=*/true, /*pack=*/1, ArrayRef<Value>{value})
.getResult()[0];
};
auto sub = [&](Value a, Value b) -> Value {
return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b);
};
auto dot = [&](Value a, Value b, Value c) -> Value {
return rewriter.create<tt::DotOp>(dotOp->getLoc(), c.getType(), a, b, c,
tt::InputPrecision::TF32,
dotOp.getMaxNumImpreciseAcc());
};

auto aBig = f32ToTF32(dotOp.getA());
auto aSmall = sub(dotOp.getA(), aBig);

auto bBig = f32ToTF32(dotOp.getB());
auto bSmall = sub(dotOp.getB(), bBig);

auto dot1 = dot(aSmall, bBig, dotOp.getC());
auto dot2 = dot(aBig, bSmall, dot1);
auto dot3 = dot(aBig, bBig, dot2);

rewriter.replaceOp(dotOp, dot3);
return success();
}
};

struct F32DotTCPass : public TritonGPUF32DotTCBase<F32DotTCPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();

RewritePatternSet decomposePatterns(context);
decomposePatterns.add<TF32x3>(context);
if (applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
.failed()) {
signalPassFailure();
}
}
};
} // anonymous namespace

std::unique_ptr<Pass> mlir::triton::gpu::createF32DotTCPass() {
return std::make_unique<F32DotTCPass>();
}
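A hedged NumPy model of the rewrite above, emulating only the operand rounding (cvt.rna.tf32.f32) and f32 accumulation, not the actual tensor-core datapath:

import numpy as np

def round_to_tf32(x):
    # Same emulation as in the earlier sketch: keep 10 explicit mantissa bits.
    bits = x.view(np.uint32)
    return ((bits + np.uint32(0x1000)) & np.uint32(0xFFFFE000)).view(np.float32)

def dot_tf32(a, b, acc):
    # Model of tt.dot with inputPrecision = tf32: rounded inputs, f32 accumulate.
    return (round_to_tf32(a) @ round_to_tf32(b) + acc).astype(np.float32)

def dot_tf32x3(a, b, acc):
    # Mirrors the TF32x3 pattern: 2 converts + 2 subtracts, then 3 tf32 dots.
    a_big, b_big = round_to_tf32(a), round_to_tf32(b)
    a_small, b_small = a - a_big, b - b_big
    d1 = dot_tf32(a_small, b_big, acc)
    d2 = dot_tf32(a_big, b_small, d1)
    return dot_tf32(a_big, b_big, d2)

rng = np.random.default_rng(0)
a = rng.standard_normal((256, 256)).astype(np.float32)
b = rng.standard_normal((256, 256)).astype(np.float32)
c = np.zeros((256, 256), dtype=np.float32)
ref = a.astype(np.float64) @ b.astype(np.float64)

print("tf32   max abs error:", np.abs(dot_tf32(a, b, c) - ref).max())
print("tf32x3 max abs error:", np.abs(dot_tf32x3(a, b, c) - ref).max())

The small-magnitude terms are accumulated first (aSmall*bBig, then aBig*bSmall, then aBig*bBig), matching the order in the pattern; the aSmall*bSmall term is omitted, as in the CUTLASS discussion linked above.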
@@ -1174,8 +1174,8 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
if (isMMAv3Dot(dotOp)) {
builder.setInsertionPoint(dotOp);
builder.replaceOpWithNewOp<ttng::DotAsyncOp>(
dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
dotOp.getMaxNumImpreciseAcc());
dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(),
dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc());
}
}

@@ -64,7 +64,7 @@ class ConvertDotConvert : public RewritePattern {
auto _0 = rewriter.create<SplatOp>(op->getLoc(), dotOp.getType(), _0f);
auto newDot = rewriter.create<DotOp>(
op->getLoc(), dotOp.getType(), dotOp.getOperand(0), dotOp.getOperand(1),
_0, dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
_0, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc());
auto newCvt = rewriter.create<ConvertLayoutOp>(op->getLoc(), dstTy,
newDot.getResult());
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, newCvt, cvtOp.getSrc());
13 changes: 10 additions & 3 deletions python/src/ir.cc
@@ -191,6 +191,12 @@ void init_triton_ir(py::module &&m) {
.value("NONE", PropagateNan::NONE)
.value("ALL", PropagateNan::ALL);

py::enum_<InputPrecision>(m, "INPUT_PRECISION", py::module_local())
.value("TF32", InputPrecision::TF32)
.value("TF32x3", InputPrecision::TF32x3)
.value("IEEE", InputPrecision::IEEE)
.export_values();

py::class_<MLIRContext>(m, "context", py::module_local()).def(py::init<>());

m.def("load_dialects", [](MLIRContext &context) {
@@ -1363,9 +1369,10 @@
ProgramIDDim(axis)));
})
.def("create_dot",
[](TritonOpBuilder &self, Value &a, Value &b, Value &c,
bool allowTF32, int maxNumImpreciseAcc) -> Value {
return self.create<DotOp>(c.getType(), a, b, c, allowTF32,
[](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b,
mlir::Value &c, InputPrecision inputPrecision,
int maxNumImpreciseAcc) -> mlir::Value {
return self.create<DotOp>(c.getType(), a, b, c, inputPrecision,
maxNumImpreciseAcc);
})
.def("create_floor",
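With the binding above, callers pass the new enum where they previously passed a bool. A hedged usage sketch (the module path and the surrounding builder setup are assumptions, not shown in this diff):

# Hypothetical caller of the updated create_dot binding.
from triton._C.libtriton import ir

# given an ir builder `builder` and MLIR values a, b, c of matching shapes:
d = builder.create_dot(a, b, c, ir.INPUT_PRECISION.TF32x3, 0)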
1 change: 1 addition & 0 deletions python/src/passes.cc
@@ -51,6 +51,7 @@ void init_triton_passes_ttgpuir(py::module &&m) {
ADD_PASS_WRAPPER_0("add_prefetch", createPrefetchPass);
ADD_PASS_WRAPPER_1("add_accelerate_matmul", createAccelerateMatmulPass, int);
ADD_PASS_WRAPPER_0("add_reorder_instructions", createReorderInstructionsPass);
ADD_PASS_WRAPPER_0("add_f32_dot_tc", createF32DotTCPass);
ADD_PASS_WRAPPER_0("add_optimize_dot_operands",
createOptimizeDotOperandsPass);
ADD_PASS_WRAPPER_0("add_remove_layout_conversions",
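A hedged sketch of how the new wrapper could be used to run the pass from Python (module paths and the pass_manager API are assumptions inferred from the existing bindings, not part of this diff):

# Hypothetical pipeline sketch using the wrapper added above.
from triton._C.libtriton import ir, passes

ctx = ir.context()
ir.load_dialects(ctx)
# mod: a TritonGPU-dialect module obtained elsewhere (e.g. parsed ttgir)
pm = ir.pass_manager(ctx)
passes.ttgpuir.add_f32_dot_tc(pm)   # wrapper added in this diff
pm.run(mod)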
5 changes: 1 addition & 4 deletions python/test/regression/test_cast_matmul.py
@@ -34,7 +34,6 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype))
out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)

allow_tf32 = True
# launch kernel
BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1)
@@ -45,7 +44,6 @@ def matmul_kernel(A, B, C, M, N, K, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
dot_out_dtype: tl.constexpr, #
allow_tf32: tl.constexpr, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, #
BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
# matrix multiplication
@@ -75,7 +73,7 @@ def matmul_kernel(A, B, C, M, N, K, #
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
a = a.to(C.dtype.element_ty)
b = b.to(C.dtype.element_ty)
acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
acc += tl.dot(a, b, out_dtype=dot_out_dtype)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
acc = acc.to(C.dtype.element_ty)
@@ -91,7 +89,6 @@ def matmul_kernel(A, B, C, M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, #
allow_tf32=allow_tf32, #
GROUP_M=8, #
BLOCK_M=BLOCK_M, #
BLOCK_N=BLOCK_N, #
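The test now relies on the default input precision instead of threading allow_tf32 through. For completeness, a hedged sketch of a kernel that requests the 3xTF32 path explicitly, assuming a tl.dot input_precision argument on the frontend (which may land separately from this backend change):

import triton
import triton.language as tl

@triton.jit
def dot_kernel(A, B, C, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # Single-tile f32 matmul; assumes M, N, K are tl.dot-compatible block sizes.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(A + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(B + offs_k[:, None] * N + offs_n[None, :])
    acc = tl.dot(a, b, input_precision="tf32x3")  # assumed frontend spelling
    tl.store(C + offs_m[:, None] * N + offs_n[None, :], acc)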
@@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %27, %28, %cst {inputPrecision = 0 : i32, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
@@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %27, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
@@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %21, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
@@ -47,7 +47,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %21, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
@@ -47,7 +47,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
triton_gpu.async_commit_group
triton_gpu.async_wait {num = 1 : i32}
%28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %21, %28, %cst {inputPrecision = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>