Skip to content

Commit

Permalink
Add tilize / untilize ops to the ttkernel dialect (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
nsmithtt authored Aug 14, 2024
1 parent 6b58bde commit 03049c5
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 7 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
MemorySpace getMemorySpace() const;
bool isSystemMemorySpace() const { return ::mlir::tt::isSystemMemorySpace(getMemorySpace()); }
bool isDeviceMemorySpace() const { return ::mlir::tt::isDeviceMemorySpace(getMemorySpace()); }
bool isTiled() const;
Type getElementType() const;
uint64_t getElementSizeBytes() const;
llvm::SmallVector<int64_t> getStride(ArrayRef<int64_t> logicalShape) const;
Expand Down
48 changes: 48 additions & 0 deletions include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,54 @@ def TTKernel_CBWaitFrontOp : TTKernel_Op<"cb_wait_front"> {
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// TTKernel Tile operations
//===----------------------------------------------------------------------===//

def TTKernel_TilizeInitOp : TTKernel_Op<"tilize_init"> {
let summary = "TilizeInitOp call.";
let description = [{
TilizeInitOp operation
}];

let arguments = (ins TTKernel_CB:$cbIn, I32:$numTiles, TTKernel_CB:$cbOut);

let hasVerifier = 1;
}

def TTKernel_UntilizeInitOp : TTKernel_Op<"untilize_init"> {
let summary = "UntilizeInitOp call.";
let description = [{
UntilizeInitOp operation
}];

let arguments = (ins TTKernel_CB:$cbIn, TTKernel_CB:$cbOut);

let hasVerifier = 1;
}

def TTKernel_TilizeBlockOp : TTKernel_Op<"tilize_block"> {
let summary = "TilizeBlockOp call.";
let description = [{
TilizeBlockOp operation
}];

let arguments = (ins TTKernel_CB:$cbIn, I32:$numTiles, TTKernel_CB:$cbOut);

let hasVerifier = 1;
}

def TTKernel_UntilizeBlockOp : TTKernel_Op<"untilize_block"> {
let summary = "UntilizeBlockOp call.";
let description = [{
UntilizeBlockOp operation
}];

let arguments = (ins TTKernel_CB:$cbIn, I32:$numTiles, TTKernel_CB:$cbOut);

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// TTKernel NOC operations
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,10 @@ mlir::Type LayoutAttr::getElementType() const {
return getMemref().getElementType();
}

bool LayoutAttr::isTiled() const {
return ::mlir::isa<::mlir::tt::TileType>(getElementType());
}

uint64_t LayoutAttr::getElementSizeBytes() const {
mlir::Type elementType = getElementType();
if (mlir::isa<TileType>(elementType)) {
Expand Down
78 changes: 71 additions & 7 deletions lib/Dialect/TTKernel/IR/TTKernelOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,45 +54,109 @@ static bool insideThread(mlir::Operation *op, ttkernel::ThreadType threadType) {
}

::mlir::LogicalResult BuiltinOp::verify() {
if (not insideDispatchOpRegion(getOperation())) {
if (!insideDispatchOpRegion(getOperation())) {
return emitOpError("KernelOp must be inside of a DispatchOp region");
}
if (not insideThread(getOperation(), ttkernel::ThreadType::Tensix)) {
if (!insideThread(getOperation(), ttkernel::ThreadType::Tensix)) {
return emitOpError("KernelOp must be inside of a Tensix thread");
}
return success();
}

::mlir::LogicalResult CBPushBackOp::verify() {
if (not insideDispatchOpRegion(getOperation())) {
if (!insideDispatchOpRegion(getOperation())) {
return emitOpError("CBPushBackOp must be inside of a DispatchOp region");
}
return success();
}

::mlir::LogicalResult CBPopFrontOp::verify() {
if (not insideDispatchOpRegion(getOperation())) {
if (!insideDispatchOpRegion(getOperation())) {
return emitOpError("CBPopFrontOp must be inside of a DispatchOp region");
}
return success();
}

::mlir::LogicalResult CBReserveBackOp::verify() {
if (not insideDispatchOpRegion(getOperation())) {
if (!insideDispatchOpRegion(getOperation())) {
return emitOpError("CBReserveBackOp must be inside of a DispatchOp region");
}
return success();
}

::mlir::LogicalResult CBWaitFrontOp::verify() {
if (not insideDispatchOpRegion(getOperation())) {
if (!insideDispatchOpRegion(getOperation())) {
return emitOpError("CBWaitFrontOp must be inside of a DispatchOp region");
}
return success();
}

static std::string verifyTilizeUntilizeCBs(CBType tilizedCB, CBType scalarCB) {
if (tilizedCB.getPort() == scalarCB.getPort()) {
return "Input circular buffer port and output circular buffer "
"port must be different";
}
if (mlir::isa<tt::TileType>(scalarCB.getMemref().getElementType())) {
return "Input to TilizeOp or Output to UntilizeOp must have scalar "
"element type";
}
if (!mlir::isa<tt::TileType>(tilizedCB.getMemref().getElementType())) {
return "Input to UntilizeOp or Output to TilizeOp must have tile "
"element type";
}
return std::string();
}

::mlir::LogicalResult TilizeInitOp::verify() {
if (!insideDispatchOpRegion(getOperation())) {
return emitOpError("TilizeInitOp must be inside of a DispatchOp region");
}
std::string err =
verifyTilizeUntilizeCBs(getCbOut().getType(), getCbIn().getType());
if (!err.empty()) {
return emitOpError(err);
}
return success();
}

::mlir::LogicalResult UntilizeInitOp::verify() {
if (!insideDispatchOpRegion(getOperation())) {
return emitOpError("UntilizeInitOp must be inside of a DispatchOp region");
}
std::string err =
verifyTilizeUntilizeCBs(getCbIn().getType(), getCbOut().getType());
if (!err.empty()) {
return emitOpError(err);
}
return success();
}

::mlir::LogicalResult TilizeBlockOp::verify() {
if (!insideDispatchOpRegion(getOperation())) {
return emitOpError("TilizeBlockOp must be inside of a DispatchOp region");
}
std::string err =
verifyTilizeUntilizeCBs(getCbOut().getType(), getCbIn().getType());
if (!err.empty()) {
return emitOpError(err);
}
return success();
}

::mlir::LogicalResult UntilizeBlockOp::verify() {
if (!insideDispatchOpRegion(getOperation())) {
return emitOpError("UntilizeBlockOp must be inside of a DispatchOp region");
}
std::string err =
verifyTilizeUntilizeCBs(getCbIn().getType(), getCbOut().getType());
if (!err.empty()) {
return emitOpError(err);
}
return success();
}

::mlir::LogicalResult ReturnOp::verify() {
if (not insideDispatchOpRegion(getOperation())) {
if (!insideDispatchOpRegion(getOperation())) {
return emitOpError("ReturnOp must be inside of a DispatchOp region");
}
return success();
Expand Down
20 changes: 20 additions & 0 deletions lib/Dialect/TTMetal/Transforms/KernelsToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ LogicalResult emitDispatchOpRegionAsCpp(DispatchOp origOp,
builder.create<emitc::IncludeOp>(module.getLoc(),
"compute_kernel_api/common.h",
/*isStandard=*/false);
builder.create<emitc::IncludeOp>(module.getLoc(),
"compute_kernel_api/tilize.h",
/*isStandard=*/false);
builder.create<emitc::IncludeOp>(module.getLoc(),
"compute_kernel_api/untilize.h",
/*isStandard=*/false);
}

if (threadTypeAttr.getValue() == ttkernel::ThreadType::Tensix) {
builder.create<emitc::VerbatimOp>(module.getLoc(), "namespace NAMESPACE {");
}

// Create a new func op and move the existing block into it.
Expand All @@ -154,6 +164,12 @@ LogicalResult emitDispatchOpRegionAsCpp(DispatchOp origOp,
IRMapping irMapper;
funcBody->takeBody(region);

if (threadTypeAttr.getValue() == ttkernel::ThreadType::Tensix) {
builder.create<emitc::VerbatimOp>(module.getLoc(),
"void MAIN { kernel_main(); }");
builder.create<emitc::VerbatimOp>(module.getLoc(), "}");
}

// Apply arith to emitc conversion first
{
ConversionTarget target(*module.getContext());
Expand All @@ -179,6 +195,10 @@ LogicalResult emitDispatchOpRegionAsCpp(DispatchOp origOp,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBPopFrontOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBReserveBackOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::CBWaitFrontOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TilizeInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::UntilizeInitOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::TilizeBlockOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::UntilizeBlockOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::GetNocAddrOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncReadOp>,
TTMetalToEmitCOpaqueRewriter<ttkernel::NocAsyncReadBarrierOp>,
Expand Down

0 comments on commit 03049c5

Please sign in to comment.