Add pass to create input tensor generator functions for emitc path (#…
svuckovicTT authored Dec 9, 2024
1 parent d8cc464 commit 6c4a4fa
Showing 6 changed files with 300 additions and 14 deletions.
39 changes: 39 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Passes.td
@@ -36,4 +36,43 @@ def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> {
}];
}

def TTNNCreateInputGenerators: Pass<"ttnn-create-input-gens", "::mlir::ModuleOp"> {
let summary = "Create input generators for the forward functions.";
let description = [{
This pass creates input generators for the "forward" functions. It
additionally creates a main function that runs the forward functions with
the generated inputs.

This pass is useful for the EmitC path. By creating input generators before
converting to the EmitC dialect, and then transforming to C++ code, the
resulting code won't require any manual edits in order to run.

Given a forward function like this:

```
func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
  %0 = "ttnn.add"(%arg0, %arg1) : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
  return %0 : tensor<32x32xbf16>
}
```

The pass will create two functions like the following:

```
func.func @createInputsFor_add() -> (tensor<32x32xbf16>, tensor<32x32xbf16>) {
  %0 = "ttnn.empty"() <{shape = #ttnn.shape<32x32>}> : () -> tensor<32x32xbf16>
  %1 = "ttnn.empty"() <{shape = #ttnn.shape<32x32>}> : () -> tensor<32x32xbf16>
  return %0, %1 : tensor<32x32xbf16>, tensor<32x32xbf16>
}

func.func @main() -> i32 {
  %0:2 = call @createInputsFor_add() : () -> (tensor<32x32xbf16>, tensor<32x32xbf16>)
  %1 = call @add(%0#0, %0#1) : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
  %c0_i32 = arith.constant 0 : i32
  return %c0_i32 : i32
}
```
}];
}

#endif
4 changes: 4 additions & 0 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
@@ -35,6 +35,10 @@ mlir::tt::TensorMemoryLayout toTTTensorMemoryLayout(
mlir::tt::MemorySpace
toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType);

// Get Layout from MemRefType
//
Layout getLayoutFromMemRef(mlir::MemRefType memref);

mlir::Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context,
DataType dtype);

28 changes: 28 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -9,6 +9,7 @@
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
@@ -616,6 +617,29 @@ class DeallocateOpConversionPattern
}
};

// arith::ConstantOp conversion pattern
//
class ArithConstantOpConversionPattern
: public OpConversionPattern<arith::ConstantOp> {

public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(arith::ConstantOp constOp, arith::ConstantOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Type newTy = this->getTypeConverter()->convertType(constOp.getType());
if (!newTy) {
return rewriter.notifyMatchFailure(constOp, "type conversion failed");
}

rewriter.replaceOpWithNewOp<emitc::ConstantOp>(constOp, newTy,
adaptor.getValue());
return success();
}
};

// Module Op conversion pattern
//
// This conversion pattern removes attributes from the ModuleOp. Previously,
@@ -762,6 +786,10 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
ctx);
patterns.add<DefaultOpConversionPattern<ttnn::FillCacheOp>>(typeConverter,
ctx);

// Arith ops
//
patterns.add<ArithConstantOpConversionPattern>(typeConverter, ctx);
}

} // namespace mlir::tt
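
For context, the new `ArithConstantOpConversionPattern` handles the `arith.constant 0 : i32` that the input-generator pass places in `@main`. A minimal before/after sketch of that rewrite (the exact printed form of the EmitC op may differ):

```
// Before conversion (as produced by ttnn-create-input-gens):
%c0_i32 = arith.constant 0 : i32

// After conversion to the EmitC dialect (sketch):
%c0_i32 = "emitc.constant"() <{value = 0 : i32}> : () -> i32
```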
195 changes: 193 additions & 2 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -3,15 +3,27 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"

#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"

namespace mlir::tt::ttnn {
#define GEN_PASS_DEF_TTNNDEALLOCATE
#define GEN_PASS_DEF_TTNNDECOMPOSELAYOUTS
#define GEN_PASS_DEF_TTNNCREATEINPUTGENERATORS
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc"

class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
@@ -873,4 +885,183 @@ class TTNNDecomposeLayouts
}
};

class TTNNCreateInputGenerators
: public impl::TTNNCreateInputGeneratorsBase<TTNNCreateInputGenerators> {

public:
using impl::TTNNCreateInputGeneratorsBase<
TTNNCreateInputGenerators>::TTNNCreateInputGeneratorsBase;

void runOnOperation() final {
ModuleOp module = getOperation();
IRRewriter rewriter(&getContext());

// Ensure that the module has a single region and a single block within that
// region
assert(module->getRegions().size() == 1);
assert(module->getRegion(0).getBlocks().size() == 1);

// Get the first block of the region at index 0
//
Block *firstBlock = module.getBody(0);

// Find all the func.func ops in the module
//
SmallVector<func::FuncOp, 1> forwardFuncOps;
for (mlir::Operation &op : firstBlock->getOperations()) {
if (mlir::func::FuncOp funcOp = dyn_cast<func::FuncOp>(op)) {

// Skip functions that are called elsewhere in the IR
//
// This will skip utility functions that are used by other functions,
// only top-level "forward" functions should be considered
//
if (!funcOp->getUses().empty()) {
continue;
}

forwardFuncOps.push_back(funcOp);
}
}

// Iterate over all the func ops and add input tensor generator functions
//
for (mlir::func::FuncOp forwardFuncOp : forwardFuncOps) {
// Get all the input tensors for the current forward func
//
llvm::SmallVector<mlir::RankedTensorType, 2> inputTensors;
for (auto input : forwardFuncOp.getFunctionType().getInputs()) {
inputTensors.push_back(llvm::cast<mlir::RankedTensorType>(input));
}

// Create a new function that will generate the input tensors
//
std::string inputGenFuncName =
"createInputsFor_" + forwardFuncOp.getName().str();

// Create function type
//
mlir::TypeRange returnTypeRange =
mlir::TypeRange(forwardFuncOp.getFunctionType().getInputs());
FunctionType functionType =
mlir::FunctionType::get(&getContext(), {}, returnTypeRange);

// Set insertion point to end of first block
//
rewriter.setInsertionPointToEnd(firstBlock);

// Create the function
//
func::FuncOp inputGenFuncOp = rewriter.create<mlir::func::FuncOp>(
module->getLoc(), inputGenFuncName, functionType);

// Add a Block to func op and set insertion point to the beginning of the
// Block
//
::mlir::Block *currFnBlock = inputGenFuncOp.addEntryBlock();
rewriter.setInsertionPointToStart(currFnBlock);

// Create the input tensors
//
SmallVector<Value, 2> generatedTensors;
for (Type tensorType : returnTypeRange) {
assert(llvm::isa<mlir::RankedTensorType>(tensorType));

RankedTensorType tensor =
llvm::cast<mlir::RankedTensorType>(tensorType);

// Get the layout attribute
//
ttnn::TTNNLayoutAttr layoutAttr =
mlir::cast<ttnn::TTNNLayoutAttr>(tensor.getEncoding());

// Get the shape of the tensor, tensor layout, and data type
//
ShapeAttr shapeAttr =
ttnn::ShapeAttr::get(&getContext(), tensor.getShape());
ttnn::LayoutAttr tensorLayoutAttr =
ttnn::LayoutAttr::get(&getContext(), layoutAttr.getLayout());
DataTypeAttr dTypeAttr =
DataTypeAttr::get(&getContext(), layoutAttr.getDataType());

// Create a new tensor
//
// TODO(svuckovic): Move from ttnn::EmptyOp to ttnn::OnesOp once #1476
// lands
//
mlir::Value tensorValue = rewriter.create<ttnn::EmptyOp>(
forwardFuncOp->getLoc(), tensorType, nullptr, shapeAttr, dTypeAttr,
tensorLayoutAttr, nullptr);

generatedTensors.push_back(tensorValue);
}

// Return the generated tensors
//
rewriter.create<func::ReturnOp>(forwardFuncOp->getLoc(),
generatedTensors);
}

// Create a main function to call input generators and forward funcs
//
{
// Create a new function that will call the input generators and forward funcs
//
std::string mainFuncName = "main";

// Create function type
//
mlir::TypeRange returnTypeRange = mlir::TypeRange(rewriter.getI32Type());
FunctionType functionType =
mlir::FunctionType::get(&getContext(), {}, returnTypeRange);

// Set insertion point to end of first block
//
rewriter.setInsertionPointToEnd(firstBlock);

// Create the function
//
func::FuncOp mainFuncOp = rewriter.create<mlir::func::FuncOp>(
module->getLoc(), mainFuncName, functionType);

::mlir::Block *currFnBlock = mainFuncOp.addEntryBlock();

// Set insertion point to the beginning of the block
//
rewriter.setInsertionPointToStart(currFnBlock);

// Call the input generators
//
for (mlir::func::FuncOp forwardFuncOp : forwardFuncOps) {
std::string inputGenFuncName =
"createInputsFor_" + forwardFuncOp.getName().str();

// Get the input generator function
//
mlir::func::FuncOp inputGenFuncOp =
module.lookupSymbol<mlir::func::FuncOp>(inputGenFuncName);

// Call the input generator function
//
func::CallOp createdTensors = rewriter.create<mlir::func::CallOp>(
forwardFuncOp->getLoc(), inputGenFuncOp, ValueRange());

rewriter.create<mlir::func::CallOp>(forwardFuncOp->getLoc(),
forwardFuncOp,
createdTensors->getResults());
}

// Return 0
//
// func::ReturnOp requires a Value to be returned, which means that an SSA
// needs to be returned, hence create a constant 0 via arith::ConstantOp
//
Value constantZero = rewriter.create<arith::ConstantOp>(
rewriter.getUnknownLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(0));
rewriter.create<func::ReturnOp>(mainFuncOp->getLoc(), constantZero);
}
}
};

} // namespace mlir::tt::ttnn
12 changes: 0 additions & 12 deletions lib/Dialect/TTNN/Utils/Utils.cpp
@@ -80,18 +80,6 @@ toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType) {
llvm_unreachable("Unknown MemorySpace");
}

DataType getDataTypeFromMemRef(mlir::MemRefType memref) {
Type elementType = memref.getElementType();
DataType dtype = DataType::Float32;
if (llvm::isa<TileType>(elementType)) {
auto tileType = mlir::cast<TileType>(elementType);
dtype = tileType.getDataType();
} else {
dtype = elementTypeToDataType(elementType);
}
return dtype;
}

Layout getLayoutFromMemRef(mlir::MemRefType memref) {
ttnn::Layout ttnnLayoutEnum = ttnn::Layout::RowMajor;
Type elementType = memref.getElementType();
36 changes: 36 additions & 0 deletions test/ttmlir/Dialect/TTNN/Transforms/ttnn_create_input_gens_0.mlir
@@ -0,0 +1,36 @@
// RUN: ttmlir-opt --ttnn-create-input-gens %s | FileCheck %s

#device = #tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
#dram = #ttnn.buffer_type<dram>
#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux"}], [{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 98816, erisc_l1_unreserved_base = 102624, dram_unreserved_base = 32, dram_unreserved_end = 1073083040, physical_cores = {worker = [ 1x1, 1x2, 1x3, 1x4, 1x6, 1x7, 1x8, 1x9, 2x1, 2x2, 2x3, 2x4, 2x6, 2x7, 2x8, 2x9, 3x1, 3x2, 3x3, 3x4, 3x6, 3x7, 3x8, 3x9, 4x1, 4x2, 4x3, 4x4, 4x6, 4x7, 4x8, 4x9, 5x1, 5x2, 5x3, 5x4, 5x6, 5x7, 5x8, 5x9, 7x1, 7x2, 7x3, 7x4, 7x6, 7x7, 7x8, 7x9, 8x1, 8x2, 8x3, 8x4, 8x6, 8x7, 8x8, 8x9, 9x1, 9x2, 9x3, 9x4, 9x6, 9x7, 9x8, 9x9] dram = [ 1x0, 1x5, 2x5, 3x5, 5x0, 5x5, 7x0, 7x5, 8x5, 9x5, 11x0, 11x5] eth_inactive = [ 0x1, 0x2, 0x3, 0x4, 0x6, 0x7, 0x8, 0x9, 6x2, 6x3, 6x6, 6x7, 6x8]}, supported_data_types = [<f32>, <f16>, <bf16>, <bfp_f8>, <bfp_bf8>, <bfp_f4>, <bfp_bf4>, <bfp_f2>, <bfp_bf2>, <u32>, <u16>, <u8>], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
#system_memory = #ttnn.buffer_type<system_memory>
#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<32x32xbf16, #system_memory>>
#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
// CHECK: func.func @add(%arg0: [[TENSOR_A:.*]], %arg1: [[TENSOR_B:.*]]) -> [[TENSOR_OUT:.*]] {
func.func @add(%arg0: tensor<32x32xbf16, #ttnn_layout>, %arg1: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout> {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_device"(%arg0, %0) <{memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%1) <{layout = #ttnn.layout<tile>}> : (tensor<32x32xbf16, #ttnn_layout1>) -> tensor<32x32xbf16, #ttnn_layout1>
%3 = "ttnn.to_device"(%arg1, %0) <{memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1>
%4 = "ttnn.to_layout"(%3) <{layout = #ttnn.layout<tile>}> : (tensor<32x32xbf16, #ttnn_layout1>) -> tensor<32x32xbf16, #ttnn_layout1>
%5 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>, shape = #ttnn.shape<32x32>}> : (!tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1>
%6 = "ttnn.add"(%2, %4, %5) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xbf16, #ttnn_layout1>, tensor<32x32xbf16, #ttnn_layout1>, tensor<32x32xbf16, #ttnn_layout1>) -> tensor<32x32xbf16, #ttnn_layout1>
%7 = "ttnn.from_device"(%6) : (tensor<32x32xbf16, #ttnn_layout1>) -> tensor<32x32xbf16, #ttnn_layout>
%8 = "ttnn.to_layout"(%7) <{layout = #ttnn.layout<row_major>}> : (tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout>
return %8 : tensor<32x32xbf16, #ttnn_layout>
}

// Confirm that the generator func is generated, and that the tensor attrs match:
//
// CHECK: func.func @createInputsFor_add() -> ([[TENSOR_A]], [[TENSOR_B]]) {
// CHECK: {{.*}} -> [[TENSOR_A]]
// CHECK: {{.*}} -> [[TENSOR_B]]
// CHECK: return %0, %1 : [[TENSOR_A]], [[TENSOR_B]]

// Confirm that the main func is generated, and that the tensor attrs match:
//
// CHECK: func.func @main() -> i32 {
// CHECK: %0:2 = call @createInputsFor_add() : () -> ([[TENSOR_A]], [[TENSOR_B]])
// CHECK: %1 = call @add(%0#0, %0#1) : ([[TENSOR_A]], [[TENSOR_B]]) -> [[TENSOR_OUT]]
}
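
For reference, the functions that the pass appends to this test module look roughly like the following (a sketch consistent with the CHECK lines above; the exact `ttnn.empty` attributes are derived from the input tensors' `#ttnn_layout` encoding and may differ in detail):

```
func.func @createInputsFor_add() -> (tensor<32x32xbf16, #ttnn_layout>, tensor<32x32xbf16, #ttnn_layout>) {
  %0 = "ttnn.empty"() <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, shape = #ttnn.shape<32x32>}> : () -> tensor<32x32xbf16, #ttnn_layout>
  %1 = "ttnn.empty"() <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, shape = #ttnn.shape<32x32>}> : () -> tensor<32x32xbf16, #ttnn_layout>
  return %0, %1 : tensor<32x32xbf16, #ttnn_layout>, tensor<32x32xbf16, #ttnn_layout>
}

func.func @main() -> i32 {
  %0:2 = call @createInputsFor_add() : () -> (tensor<32x32xbf16, #ttnn_layout>, tensor<32x32xbf16, #ttnn_layout>)
  %1 = call @add(%0#0, %0#1) : (tensor<32x32xbf16, #ttnn_layout>, tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout>
  %c0_i32 = arith.constant 0 : i32
  return %c0_i32 : i32
}
```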
