Skip to content

Commit

Permalink
TTIRToTTMetal conversion reworked to be a correct SFPU op
Browse files Browse the repository at this point in the history
  • Loading branch information
vroubtsovTT committed Dec 2, 2024
1 parent d46c4d5 commit 1d55fa0
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 36 deletions.
46 changes: 25 additions & 21 deletions include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -256,27 +256,6 @@ def TTKernel_MulTilesOp : TTKernel_Op<"mul_tiles"> {
let arguments = (ins TTKernel_CB:$in0_cb, TTKernel_CB:$in1_cb, I32:$in0_tile_index, I32:$in1_tile_index, I32:$dst_index);
}

def TTKernel_MaxTilesInitOp : TTKernel_Op<"max_tiles_init"> {
let summary = "Short init function";
let description = [{
Must be run before max_tiles.
}];

let arguments = (ins TTKernel_CB:$in0_cb, TTKernel_CB:$in1_cb); // FIXME: , BOOL:$acc_to_dst);
}

def TTKernel_MaxTilesOp : TTKernel_Op<"max_tiles"> {
let summary = "Max operation";
let description = [{
Performs element-wise C=max(A, B) of tiles in two CBs at given indices
and writes the result to the DST register at index dst_tile_index. The DST
register buffer must be in acquired state via *tile_regs_acquire* call. This call
is blocking and is only available on the compute engine.
}];

let arguments = (ins TTKernel_CB:$in0_cb, TTKernel_CB:$in1_cb, I32:$in0_tile_index, I32:$in1_tile_index, I32:$dst_index);
}

def TTKernel_UnaryOpInitCommonOp : TTKernel_Op<"unary_op_init_common"> {
let summary = "Initialization function for unary operations.";
let description = [{
Expand Down Expand Up @@ -363,6 +342,31 @@ def TTKernel_ReduceTileOp : TTKernel_Op<"reduce_tile"> {
TTKernel_ReduceDimAttr:$reduce_dim);
}

//===----------------------------------------------------------------------===//
// TTKernel SFPU operations
//===----------------------------------------------------------------------===//

def TTKernel_MaxTilesInitOp : TTKernel_Op<"max_tile_init"> {
let summary = "Short init function";
let description = [{
Must be run before max_tile.
}];

let arguments = (ins);
}

def TTKernel_MaxTilesOp : TTKernel_Op<"max_tile"> {
let summary = "Max operation";
let description = [{
Performs element-wise computation of maximum operation
DST[dst0_index] <- max(DST[dst0_index], DST[dst1_index])
on DST register operands. The DST register buffer must be in
acquired state via *tile_regs_acquire* call.
}];

let arguments = (ins I32:$dst0_index, I32:$dst1_index);
}

//===----------------------------------------------------------------------===//
// TTKernel CB operations
//===----------------------------------------------------------------------===//
Expand Down
81 changes: 66 additions & 15 deletions lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,8 +800,7 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
} else if (mlir::isa<arith::DivFOp>(arithOrMathOp)) {
builder.create<ttkernel::MulTilesInitFOp>(arithOrMathOp.getLoc());
} else if (mlir::isa<arith::MaximumFOp>(arithOrMathOp)) {
builder.create<ttkernel::MaxTilesInitOp>(arithOrMathOp.getLoc(), inCB0,
inCB1);
builder.create<ttkernel::MaxTilesInitOp>(arithOrMathOp.getLoc());
} else {
llvm_unreachable("Unhandled binary op init conversion.");
}
Expand Down Expand Up @@ -912,7 +911,7 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
// from cbOperands[1] and store the result C in DST register on
// dstTileIndex.
if (mlir::isa<arith::AddFOp>(arithOrMathOp)) {
commonComputeBinaryOp<ttkernel::AddTilesOp>(
convertComputeBinaryFPUOp<ttkernel::AddTilesOp>(
arithOrMathOp, cbOperands, iterators, blockArgIteratorMapping,
builder);
} else if (mlir::isa<arith::MulFOp>(arithOrMathOp)) {
Expand Down Expand Up @@ -941,7 +940,7 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {

builder.create<ttkernel::CBPopFrontOp>(location, inCB1, one);
} else if (mlir::isa<arith::MaximumFOp>(arithOrMathOp)) {
commonComputeBinaryOp<ttkernel::MaxTilesOp>(
convertComputeBinarySFPUOp<ttkernel::MaxTilesOp>(
arithOrMathOp, cbOperands, iterators, blockArgIteratorMapping,
builder);
} else {
Expand All @@ -950,13 +949,12 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
}
}

template <typename TilesOp>
void
commonComputeBinaryOp(Operation &arithOrMathOp,
ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
const SmallVector<unsigned> &blockArgIteratorMapping,
OpBuilder &builder) const {
template <typename TTKernelTilesOp>
void convertComputeBinaryFPUOp(
Operation &arithOrMathOp, ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
const SmallVector<unsigned> &blockArgIteratorMapping,
OpBuilder &builder) const {
auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]];
auto inCB0 = cbOperands[0];
auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]];
Expand All @@ -967,14 +965,67 @@ class TTIRToTTMetalDispatchRewriter : public OpRewritePattern<ttir::GenericOp> {
auto location = arithOrMathOp.getLoc();

Value dstIndex = i32(0, builder);

// acquire DST register lock (MATH)
builder.create<ttkernel::TileRegsAcquireOp>(location);
builder.create<TilesOp>(location, inCB0, inCB1, inCB0TileIndex,
inCB1TileIndex, dstIndex);
{
builder.create<TTKernelTilesOp>(location, inCB0, inCB1, inCB0TileIndex,
inCB1TileIndex, dstIndex);
}
builder.create<ttkernel::TileRegsCommitOp>(location);
// release DST register lock (MATH)

// acquire DST register lock (PACK)
builder.create<ttkernel::TileRegsWaitOp>(location);
builder.create<ttkernel::PackTileOp>(location, dstIndex, outCB,
outCBTileIndex);
{
builder.create<ttkernel::PackTileOp>(location, dstIndex, outCB,
outCBTileIndex);
}
builder.create<ttkernel::TileRegsReleaseOp>(location);
// release DST register lock (PACK)
}

template <typename TTKernelTilesOp>
void convertComputeBinarySFPUOp(
Operation &arithOrMathOp, ArrayRef<BlockArgument> cbOperands,
ArrayRef<BlockArgument> iterators,
const SmallVector<unsigned> &blockArgIteratorMapping,
OpBuilder &builder) const {
auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]];
auto inCB0 = cbOperands[0];
auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]];
auto inCB1 = cbOperands[1];
auto outCB = cbOperands[2];
auto outCBTileIndex = iterators[blockArgIteratorMapping[2]];

auto location = arithOrMathOp.getLoc();

Value dstLhsTileIndex = i32(0, builder);
Value dstRhsTileIndex = i32(1, builder); // note: rhs is always lhs+1

// acquire DST register lock (MATH)
builder.create<ttkernel::TileRegsAcquireOp>(location);
{
// copy inCB0[inCB0TileIndex] and inCB1[inCB1TileIndex] to DST:
builder.create<ttkernel::CopyTileOp>(location, inCB0, inCB0TileIndex,
dstLhsTileIndex);
builder.create<ttkernel::CopyTileOp>(location, inCB1, inCB1TileIndex,
dstRhsTileIndex);
// SFPU ooperates on DST tiles:
builder.create<TTKernelTilesOp>(location, dstLhsTileIndex,
dstRhsTileIndex);
}
builder.create<ttkernel::TileRegsCommitOp>(location);
// release DST register lock (MATH)

// acquire DST register lock (PACK)
builder.create<ttkernel::TileRegsWaitOp>(location);
{
builder.create<ttkernel::PackTileOp>(location, dstLhsTileIndex, outCB,
outCBTileIndex);
}
builder.create<ttkernel::TileRegsReleaseOp>(location);
// release DST register lock (PACK)
}

void commonComputeMulOp(Operation &op, ArrayRef<BlockArgument> cbOperands,
Expand Down

0 comments on commit 1d55fa0

Please sign in to comment.