Skip to content

Commit

Permalink
Inline all functions besides main function (#898)
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT authored Oct 24, 2024
1 parent b1595bb commit 2571a98
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/ttmlir/RegisterAll.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class DialectRegistry;
namespace mlir::tt {

void registerAllDialects(mlir::DialectRegistry &registry);
void registerAllExtensions(mlir::DialectRegistry &registry);
void registerAllPasses();

} // namespace mlir::tt
Expand Down
37 changes: 37 additions & 0 deletions lib/Dialect/TTIR/IR/TTIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,47 @@
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"

#include "mlir/InitAllDialects.h"
#include "mlir/Transforms/InliningUtils.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

using namespace mlir;
using namespace mlir::tt::ttir;

// This DialectInlinerInterface is nearly identical to the one found in
// mlir/lib/Dialect/Func/Extensions/InlinerExtension.cpp. We need
// to define one for the TTIRDialect as well since the IR uses
// FuncDialect for function definitions/calls, and TTIR for ops.
// We need to legalize inlining for all TTIR ops.
struct TTIRInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;

// Everything can be inlined
bool isLegalToInline(Operation *call, Operation *callable,
bool wouldBeCloned) const final {
return true;
}

bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
return true;
}

bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
return true;
}

void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
// This should only ever be a func::ReturnOp
auto returnOp = cast<func::ReturnOp>(op);

// Replace usages of the functions result with the result operands of the
// return op
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
}
}
};

#include "ttmlir/Dialect/TTIR/IR/TTIROpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
Expand All @@ -22,4 +58,5 @@ void TTIRDialect::initialize() {
#define GET_OP_LIST
#include "ttmlir/Dialect/TTIR/IR/TTIROps.cpp.inc"
>();
addInterfaces<TTIRInlinerInterface>();
}
5 changes: 5 additions & 0 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ void createTTNNPipelineTTIRPasses(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
ttir::TTIRLoadSystemDescOptions systemDescOptions;
systemDescOptions.path = options.systemDescPath;

// Inlines all private functions. I.e flattens the program into the main
// function. Removes all private functions.
pm.addPass(mlir::createInlinerPass());

pm.addPass(mlir::tt::ttir::createTTIRSlidingWindow2dFixShapes());
pm.addPass(mlir::tt::ttir::createTTIRLoadSystemDesc(systemDescOptions));

Expand Down
8 changes: 7 additions & 1 deletion lib/RegisterAll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

#include "ttmlir/RegisterAll.h"

#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"

#include "ttmlir/Conversion/Passes.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
Expand Down Expand Up @@ -38,6 +38,12 @@ void mlir::tt::registerAllDialects(mlir::DialectRegistry &registry) {
#endif
}

void mlir::tt::registerAllExtensions(mlir::DialectRegistry &registry) {
// Both the inliner for TTIRDialect and FuncDialect must be registered
// since we use a combination of TTIRDialect and FuncDialect in the IR
mlir::func::registerInlinerExtension(registry);
}

void mlir::tt::registerAllPasses() {
// Register all dialect conversion passes.
mlir::tt::registerTTMLIRConversionPasses();
Expand Down
1 change: 1 addition & 0 deletions python/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ void populatePassesModule(py::module &m) {

mlir::DialectRegistry registry;
mlir::tt::registerAllDialects(registry);
mlir::tt::registerAllExtensions(registry);
mlir::MLIRContext *ctx = unwrap(mlirModuleGetContext(module));
ctx->appendDialectRegistry(registry);

Expand Down
16 changes: 16 additions & 0 deletions test/ttmlir/Dialect/TTNN/multiple_func.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @main(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]]
%1 = call @do_mult(%arg0, %arg1, %0) : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}

func.func private @do_mult(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>, %arg2: tensor<64x128xf32>) -> tensor<64x128xf32> {
%0 = "ttir.multiply"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %0 : tensor<64x128xf32>
}
}
1 change: 1 addition & 0 deletions tools/ttmlir-opt/ttmlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ int main(int argc, char **argv) {

mlir::DialectRegistry registry;
mlir::tt::registerAllDialects(registry);
mlir::tt::registerAllExtensions(registry);

return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "ttmlir optimizer driver\n", registry));
Expand Down

0 comments on commit 2571a98

Please sign in to comment.