Skip to content

Commit

Permalink
Implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
lezcano committed Mar 15, 2024
1 parent d42ca11 commit cb13259
Show file tree
Hide file tree
Showing 58 changed files with 341 additions and 195 deletions.
11 changes: 11 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,15 @@ def TT_PropagateNanAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}

// F32Backend
def TT_F32BackendAttr : I32EnumAttr<
"F32Backend", "",
[
I32EnumAttrCase<"TF32", 0, "tf32">,
I32EnumAttrCase<"TF32x3", 1, "tf32x3">,
I32EnumAttrCase<"IEEE", 2, "ieee">
]>{
let cppNamespace = "::mlir::triton";
}

#endif
9 changes: 7 additions & 2 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -566,14 +566,19 @@ def TT_DotOp : TT_Op<"dot", [Pure,
let summary = "dot";

let description = [{
$d = matrix_multiply($a, $b) + $c
$d = matrix_multiply($a, $b) + $c. $f32Backend describes how to exercise the TC
when the inputs are f32. It can be one of: tf32, tf32x3, ieee.
tf32: use TC with tf32 ops.
tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp
ieee: don't use TC, implement dot in software.
If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored.
}];

let arguments = (ins
TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
BoolAttr:$allowTF32,
TT_F32BackendAttr:$f32Backend,
I32Attr:$maxNumImpreciseAcc);

let results = (outs TT_FpIntTensor:$d);
Expand Down
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ std::unique_ptr<Pass> createPipelinePass(int numStages = 3, int numWarps = 4,

std::unique_ptr<Pass> createAccelerateMatmulPass(int computeCapability = 80);

std::unique_ptr<Pass> createF32DotTCPass();

std::unique_ptr<Pass> createPrefetchPass();

std::unique_ptr<Pass> createCoalescePass();
Expand Down
14 changes: 14 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
];
}

def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
let summary = "3xTF32 trick";

let description = [{
Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s
to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385
}];

let constructor = "mlir::triton::gpu::createF32DotTCPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
}

def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
let summary = "dot async";

let description = [{
$d = matrix_multiply($a, $b) + $c
$d = matrix_multiply($a, $b) + $c. For docs on F32BackendAttr, see TT_DotOp
}];

let arguments = (ins TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
BoolAttr:$allowTF32,
TT_F32BackendAttr:$f32Backend,
I32Attr:$maxNumImpreciseAcc);

let results = (outs TT_FpIntTensor:$d);
Expand Down
2 changes: 1 addition & 1 deletion lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ bool supportMMA(triton::DotOp op, int version) {
}
}
if (aElemTy.isF32() && bElemTy.isF32()) {
return op.getAllowTF32() && version >= 2;
return op.getF32Backend() == F32Backend::TF32 && version >= 2;
}
return supportMMA(op.getA(), version) && supportMMA(op.getB(), version);
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);

addNamedAttrs(rewriter.replaceOpWithNewOp<triton::DotOp>(
op, retType, a, b, c, adaptor.getAllowTF32(),
op, retType, a, b, c, adaptor.getF32Backend(),
adaptor.getMaxNumImpreciseAcc()),
adaptor.getAttributes());
return success();
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Triton/Transforms/Combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ class CombineBroadcastMulReducePattern : public RewritePattern {
rewriter.create<arith::ConstantOp>(op->getLoc(),
rewriter.getF32FloatAttr(0)));
rewriter.replaceOpWithNewOp<DotOp>(op, expandLhsOp.getSrc(),
expandRhsOp.getSrc(), newAcc, true, 0);
expandRhsOp.getSrc(), newAcc,
F32Backend::TF32, 0);
return success();
}
};
Expand Down
16 changes: 8 additions & 8 deletions lib/Dialect/Triton/Transforms/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,25 @@ include "mlir/IR/PatternBase.td"
// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d)
def CombineDotAddIPattern : Pat<
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $overflow),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)),
(Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $f32Backend, $maxNumImpreciseAcc), $overflow),
(TT_DotOp $a, $b, $d, $f32Backend, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $fastmath),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $f32Backend, $maxNumImpreciseAcc), $fastmath),
(TT_DotOp $a, $b, $d, $f32Backend, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;

def CombineDotAddIRevPattern : Pat<
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d, $overflow),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)),
(Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $f32Backend, $maxNumImpreciseAcc), $d, $overflow),
(TT_DotOp $a, $b, $d, $f32Backend, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d, $fastmath),
(TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc, (location $res)),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $f32Backend, $maxNumImpreciseAcc), $d, $fastmath),
(TT_DotOp $a, $b, $d, $f32Backend, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ class BlockedToMMA : public mlir::RewritePattern {
}
// convert dot instruction
auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
newAcc, dotOp.getAllowTF32(),
newAcc, dotOp.getF32Backend(),
dotOp.getMaxNumImpreciseAcc());

rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_triton_library(TritonGPUTransforms
AccelerateMatmul.cpp
Coalesce.cpp
F32DotTC.cpp
ReduceDataDuplication.cpp
OptimizeDotOperands.cpp
OptimizeThreadLocality.cpp
Expand Down
91 changes: 91 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

using namespace mlir;
namespace tt = mlir::triton;

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace {

// nb. We call the trick TF32x3 as C++ disallows varaibles starting with numbers
// Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385
// For a, b f32
// dot(a, b, f32Backend="tf32x3") ->
// let aBig = f32ToTF32(a), aSmall = a - aBig;
// let bBig = f32ToTF32(b), bSmall = b - bBig;
// dot(aSmall, bBig, f32Backend="tf32") +
// dot(aBig, bSmall, f32Backend="tf32") +
// dot(aBig, bBig, f32Backend="tf32")
class TF32x3 : public OpRewritePattern<tt::DotOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(tt::DotOp dotOp,
PatternRewriter &rewriter) const override {

auto isF32 = [](Value operand) {
return operand.getType()
.cast<RankedTensorType>()
.getElementType()
.isF32();
};

if (!(dotOp.getF32Backend() == tt::F32Backend::TF32x3 &&
isF32(dotOp.getA()) && isF32(dotOp.getB()))) {
return failure();
}

// Aux functions
auto f32ToTF32 = [&](Value value) -> Value {
return rewriter
.create<tt::ElementwiseInlineAsmOp>(
dotOp.getLoc(), value.getType(), "cvt.rna.tf32.f32 $0, $1;",
"=r,r",
/*isPure=*/true, /*pack=*/1, ArrayRef<Value>{value})
.getResult()[0];
};
auto sub = [&](Value a, Value b) -> Value {
return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b);
};
auto dot = [&](Value a, Value b, Value c) -> Value {
return rewriter.create<tt::DotOp>(dotOp->getLoc(), c.getType(), a, b, c,
tt::F32Backend::TF32,
dotOp.getMaxNumImpreciseAcc());
};

auto aBig = f32ToTF32(dotOp.getA());
auto aSmall = sub(dotOp.getA(), aBig);

auto bBig = f32ToTF32(dotOp.getB());
auto bSmall = sub(dotOp.getB(), bBig);

auto dot1 = dot(aSmall, bBig, dotOp.getC());
auto dot2 = dot(aBig, bSmall, dot1);
auto dot3 = dot(aBig, bBig, dot2);

rewriter.replaceOp(dotOp, dot3);
return success();
}
};

struct F32DotTCPass : public TritonGPUF32DotTCBase<F32DotTCPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();

RewritePatternSet decomposePatterns(context);
decomposePatterns.add<TF32x3>(context);
if (applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
.failed()) {
signalPassFailure();
}
}
};
} // anonymous namespace

std::unique_ptr<Pass> mlir::triton::gpu::createF32DotTCPass() {
return std::make_unique<F32DotTCPass>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -1162,8 +1162,8 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
if (resEnc && resEnc.isHopper()) {
builder.setInsertionPoint(dotOp);
builder.replaceOpWithNewOp<ttng::DotAsyncOp>(
dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
dotOp.getMaxNumImpreciseAcc());
dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(),
dotOp.getF32Backend(), dotOp.getMaxNumImpreciseAcc());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ConvertDotConvert : public RewritePattern {
auto _0 = rewriter.create<SplatOp>(op->getLoc(), dotOp.getType(), _0f);
auto newDot = rewriter.create<DotOp>(
op->getLoc(), dotOp.getType(), dotOp.getOperand(0), dotOp.getOperand(1),
_0, dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
_0, dotOp.getF32Backend(), dotOp.getMaxNumImpreciseAcc());
auto newCvt = rewriter.create<ConvertLayoutOp>(op->getLoc(), dstTy,
newDot.getResult());
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, newCvt, cvtOp.getSrc());
Expand Down
13 changes: 10 additions & 3 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ void init_triton_ir(py::module &&m) {
.value("NONE", PropagateNan::NONE)
.value("ALL", PropagateNan::ALL);

py::enum_<F32Backend>(m, "F32BACKEND", py::module_local())
.value("TF32", F32Backend::TF32)
.value("TF32x3", F32Backend::TF32x3)
.value("IEEE", F32Backend::IEEE)
.export_values();

py::class_<MLIRContext>(m, "context", py::module_local()).def(py::init<>());

m.def("load_dialects", [](MLIRContext &context) {
Expand Down Expand Up @@ -1370,9 +1376,10 @@ void init_triton_ir(py::module &&m) {
ProgramIDDim(axis)));
})
.def("create_dot",
[](TritonOpBuilder &self, Value &a, Value &b, Value &c,
bool allowTF32, int maxNumImpreciseAcc) -> Value {
return self.create<DotOp>(c.getType(), a, b, c, allowTF32,
[](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b,
mlir::Value &c, F32Backend f32Backend,
int maxNumImpreciseAcc) -> mlir::Value {
return self.create<DotOp>(c.getType(), a, b, c, f32Backend,
maxNumImpreciseAcc);
})
.def("create_floor",
Expand Down
1 change: 1 addition & 0 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ void init_triton_passes_ttgpuir(py::module &&m) {
ADD_PASS_WRAPPER_0("add_prefetch", createPrefetchPass);
ADD_PASS_WRAPPER_1("add_accelerate_matmul", createAccelerateMatmulPass, int);
ADD_PASS_WRAPPER_0("add_reorder_instructions", createReorderInstructionsPass);
ADD_PASS_WRAPPER_0("add_f32_dot_tc", createF32DotTCPass);
ADD_PASS_WRAPPER_0("add_optimize_dot_operands",
createOptimizeDotOperandsPass);
ADD_PASS_WRAPPER_0("add_remove_layout_conversions",
Expand Down
5 changes: 1 addition & 4 deletions python/test/regression/test_cast_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
out_torch = torch.matmul(a.to(torch_dtype), b.to(torch_dtype))
out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)

allow_tf32 = True
# launch kernel
BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1)
Expand All @@ -45,7 +44,6 @@ def matmul_kernel(A, B, C, M, N, K, #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
dot_out_dtype: tl.constexpr, #
allow_tf32: tl.constexpr, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, #
BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
# matrix multiplication
Expand Down Expand Up @@ -75,7 +73,7 @@ def matmul_kernel(A, B, C, M, N, K, #
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
a = a.to(C.dtype.element_ty)
b = b.to(C.dtype.element_ty)
acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
acc += tl.dot(a, b, out_dtype=dot_out_dtype)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
acc = acc.to(C.dtype.element_ty)
Expand All @@ -91,7 +89,6 @@ def matmul_kernel(A, B, C, M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, #
allow_tf32=allow_tf32, #
GROUP_M=8, #
BLOCK_M=BLOCK_M, #
BLOCK_N=BLOCK_N, #
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %27, %28, %cst {f32Backend = 0 : i32, transA = true, transB = true} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%27 = triton_gpu.convert_layout %21 : (tensor<64x16xf16, #blocked0>) -> tensor<64x16xf16, #shared0>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %27, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %27, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%25 = tt.addptr %24, %16 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi32, #blocked2>
%26 = tt.load %25 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x64xf16, #blocked2>
%28 = triton_gpu.convert_layout %26 : (tensor<16x64xf16, #blocked2>) -> tensor<16x64xf16, #shared1>
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
triton_gpu.async_commit_group
triton_gpu.async_wait {num = 1 : i32}
%28 = triton_gpu.extract_slice %bt0[%ci0, 0, 0][1, 16, 64][1, 1, 1] : tensor<1x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
%29 = tt.dot %21, %28, %cst {allowTF32 = true, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%29 = tt.dot %21, %28, %cst {f32Backend = 0 : i32, transA = false, transB = false} : tensor<64x16xf16, #shared0> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
%30 = tt.splat %arg5 : (i32) -> tensor<64x1xi32, #blocked1>
%31 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%32 = tt.expand_dims %14 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
Expand Down
Loading

0 comments on commit cb13259

Please sign in to comment.