From 6c4a4fac829af58fa44316b5dcbadf9ff7c9086c Mon Sep 17 00:00:00 2001
From: Sasa Vuckovic <134393361+svuckovicTT@users.noreply.github.com>
Date: Mon, 9 Dec 2024 17:14:20 +0100
Subject: [PATCH] Add pass to create input tensor generator functions for
 emitc path (#1523)

---
 .../ttmlir/Dialect/TTNN/Transforms/Passes.td  |  39 ++++
 include/ttmlir/Dialect/TTNN/Utils/Utils.h     |   4 +
 lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp    |  28 +++
 lib/Dialect/TTNN/Transforms/Passes.cpp        | 195 +++++++++++++++++-
 lib/Dialect/TTNN/Utils/Utils.cpp              |  12 --
 .../Transforms/ttnn_create_input_gens_0.mlir  |  36 ++++
 6 files changed, 300 insertions(+), 14 deletions(-)
 create mode 100644 test/ttmlir/Dialect/TTNN/Transforms/ttnn_create_input_gens_0.mlir

diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
index 444927e34..13253d131 100644
--- a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
+++ b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
@@ -36,4 +36,43 @@ def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> {
   }];
 }
 
+def TTNNCreateInputGenerators: Pass<"ttnn-create-input-gens", "::mlir::ModuleOp"> {
+  let summary = "Create input generators for the forward functions.";
+  let description = [{
+    This pass creates input generators for the "forward" functions. It
+    additionally creates a main function that runs each forward function with
+    the generated inputs.
+
+    The pass is useful for the EmitC path: by creating input generators before
+    converting to the EmitC dialect and translating to C++, the resulting code
+    won't require any manual edits to run.
+
+    Given a forward function like this:
+
+    ```
+    func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
+      %0 = "ttnn.add"(%arg0, %arg1) : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+      return %0 : tensor<32x32xbf16>
+    }
+    ```
+
+    the pass will create two functions like this:
+
+    ```
+    func.func @createInputsFor_add() -> (tensor<32x32xbf16>, tensor<32x32xbf16>) {
+      %0 = "ttnn.empty"() <{shape = #ttnn.shape<32x32>}> : () -> tensor<32x32xbf16>
+      %1 = "ttnn.empty"() <{shape = #ttnn.shape<32x32>}> : () -> tensor<32x32xbf16>
+      return %0, %1 : tensor<32x32xbf16>, tensor<32x32xbf16>
+    }
+
+    func.func @main() -> i32 {
+      %0:2 = call @createInputsFor_add() : () -> (tensor<32x32xbf16>, tensor<32x32xbf16>)
+      %1 = call @add(%0#0, %0#1) : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+      %c0_i32 = arith.constant 0 : i32
+      return %c0_i32 : i32
+    }
+    ```
+  }];
+}
+
 #endif

diff --git a/include/ttmlir/Dialect/TTNN/Utils/Utils.h b/include/ttmlir/Dialect/TTNN/Utils/Utils.h
index f214fa793..2c4b7a250 100644
--- a/include/ttmlir/Dialect/TTNN/Utils/Utils.h
+++ b/include/ttmlir/Dialect/TTNN/Utils/Utils.h
@@ -35,6 +35,10 @@ mlir::tt::TensorMemoryLayout toTTTensorMemoryLayout(
 mlir::tt::MemorySpace
 toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType);
 
+// Get Layout from MemRefType
+//
+Layout getLayoutFromMemRef(mlir::MemRefType memref);
+
 mlir::Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context,
                                        DataType dtype);
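To make the transformation concrete, here is a conceptual C++ analogue of what the module looks like once this pass and the rest of the EmitC path have run. The `Tensor` struct and the function bodies are placeholders for illustration only; the real emitted code operates on ttnn tensors.

```
#include <tuple>

struct Tensor {};  // stand-in for a ttnn tensor type

// The "forward" function (body elided in this sketch).
Tensor add(Tensor a, Tensor /*b*/) { return a; }

// Generated input generator: one tensor per forward-function input.
std::tuple<Tensor, Tensor> createInputsFor_add() {
  return {Tensor{}, Tensor{}};
}

// Generated entry point: create inputs, run forward, return 0.
int main() {
  auto [a, b] = createInputsFor_add();  // %0:2 = call @createInputsFor_add()
  Tensor out = add(a, b);               // %1 = call @add(%0#0, %0#1)
  (void)out;
  return 0;                             // return %c0_i32 : i32
}
```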
"mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" @@ -616,6 +617,29 @@ class DeallocateOpConversionPattern } }; +// arith::ConstantOp conversion pattern +// +class ArithConstantOpConversionPattern + : public OpConversionPattern { + +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp constOp, arith::ConstantOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Type newTy = this->getTypeConverter()->convertType(constOp.getType()); + if (!newTy) { + return rewriter.notifyMatchFailure(constOp, "type conversion failed"); + } + + rewriter.replaceOpWithNewOp(constOp, newTy, + adaptor.getValue()); + return success(); + } +}; + // Module Op conversion pattern // // This conversion pattern removes attributes from the ModuleOp. Previously, @@ -762,6 +786,10 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, ctx); patterns.add>(typeConverter, ctx); + + // Arith ops + // + patterns.add(typeConverter, ctx); } } // namespace mlir::tt diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp index 01971b6c6..c842c4075 100644 --- a/lib/Dialect/TTNN/Transforms/Passes.cpp +++ b/lib/Dialect/TTNN/Transforms/Passes.cpp @@ -3,15 +3,27 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttmlir/Dialect/TTNN/Transforms/Passes.h" + +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" +#include "ttmlir/Dialect/TTNN/Utils/Utils.h" + #include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" -#include "ttmlir/Dialect/TTNN/Utils/Utils.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" namespace mlir::tt::ttnn { #define GEN_PASS_DEF_TTNNDEALLOCATE #define GEN_PASS_DEF_TTNNDECOMPOSELAYOUTS +#define GEN_PASS_DEF_TTNNCREATEINPUTGENERATORS #include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc" class TTNNDeallocate : public impl::TTNNDeallocateBase { @@ -873,4 +885,183 @@ class TTNNDecomposeLayouts } }; +class TTNNCreateInputGenerators + : public impl::TTNNCreateInputGeneratorsBase { + +public: + using impl::TTNNCreateInputGeneratorsBase< + TTNNCreateInputGenerators>::TTNNCreateInputGeneratorsBase; + + void runOnOperation() final { + ModuleOp module = getOperation(); + IRRewriter rewriter(&getContext()); + + // Ensure that the module has a single region and a single block within that + // region + assert(module->getRegions().size() == 1); + assert(module->getRegion(0).getBlocks().size() == 1); + + // Get the first block of the region at index 0 + // + Block *firstBlock = module.getBody(0); + + // Find all the func.func ops in the module + // + SmallVector forwardFuncOps; + for (mlir::Operation &op : firstBlock->getOperations()) { + if (mlir::func::FuncOp funcOp = dyn_cast(op)) { + + // Skip functions that are called elsewhere in the IR + // + // This will skip utility functions that are used by other functions, + // only top-level "forward" functions should be considered + // + if (!funcOp->getUses().empty()) { + continue; + } + + forwardFuncOps.push_back(funcOp); + } + } + + // Iterate over all the func ops and add input tensor generator functions + // + for 
diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp
index 01971b6c6..c842c4075 100644
--- a/lib/Dialect/TTNN/Transforms/Passes.cpp
+++ b/lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -3,15 +3,27 @@
 // SPDX-License-Identifier: Apache-2.0
 
 #include "ttmlir/Dialect/TTNN/Transforms/Passes.h"
+
+#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
+#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
+
 #include "mlir/Analysis/Liveness.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
-#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
-#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
 
 namespace mlir::tt::ttnn {
 #define GEN_PASS_DEF_TTNNDEALLOCATE
 #define GEN_PASS_DEF_TTNNDECOMPOSELAYOUTS
+#define GEN_PASS_DEF_TTNNCREATEINPUTGENERATORS
 #include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc"
 
 class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
@@ -873,4 +885,183 @@ class TTNNDecomposeLayouts
   }
 };
 
+class TTNNCreateInputGenerators
+    : public impl::TTNNCreateInputGeneratorsBase<TTNNCreateInputGenerators> {
+
+public:
+  using impl::TTNNCreateInputGeneratorsBase<
+      TTNNCreateInputGenerators>::TTNNCreateInputGeneratorsBase;
+
+  void runOnOperation() final {
+    ModuleOp module = getOperation();
+    IRRewriter rewriter(&getContext());
+
+    // Ensure that the module has a single region and a single block within
+    // that region
+    assert(module->getRegions().size() == 1);
+    assert(module->getRegion(0).getBlocks().size() == 1);
+
+    // Get the first block of the region at index 0
+    //
+    Block *firstBlock = module.getBody(0);
+
+    // Find all the func.func ops in the module
+    //
+    SmallVector<mlir::func::FuncOp> forwardFuncOps;
+    for (mlir::Operation &op : firstBlock->getOperations()) {
+      if (mlir::func::FuncOp funcOp = dyn_cast<mlir::func::FuncOp>(op)) {
+
+        // Skip functions that are called elsewhere in the IR
+        //
+        // This will skip utility functions that are used by other functions,
+        // only top-level "forward" functions should be considered
+        //
+        if (!SymbolTable::symbolKnownUseEmpty(funcOp, module)) {
+          continue;
+        }
+
+        forwardFuncOps.push_back(funcOp);
+      }
+    }
+
+    // Iterate over all the func ops and add input tensor generator functions
+    //
+    for (mlir::func::FuncOp forwardFuncOp : forwardFuncOps) {
+      // Get all the input tensors for the current forward func
+      //
+      llvm::SmallVector<mlir::RankedTensorType> inputTensors;
+      for (auto input : forwardFuncOp.getFunctionType().getInputs()) {
+        inputTensors.push_back(llvm::cast<mlir::RankedTensorType>(input));
+      }
+
+      // Create a new function that will generate the input tensors
+      //
+      std::string inputGenFuncName =
+          "createInputsFor_" + forwardFuncOp.getName().str();
+
+      // Create function type
+      //
+      mlir::TypeRange returnTypeRange =
+          mlir::TypeRange(forwardFuncOp.getFunctionType().getInputs());
+      FunctionType functionType =
+          mlir::FunctionType::get(&getContext(), {}, returnTypeRange);
+
+      // Set insertion point to end of first block
+      //
+      rewriter.setInsertionPointToEnd(firstBlock);
+
+      // Create the function
+      //
+      func::FuncOp inputGenFuncOp = rewriter.create<mlir::func::FuncOp>(
+          module->getLoc(), inputGenFuncName, functionType);
+
+      // Add a Block to func op and set insertion point to the beginning of
+      // the Block
+      //
+      ::mlir::Block *currFnBlock = inputGenFuncOp.addEntryBlock();
+      rewriter.setInsertionPointToStart(currFnBlock);
+
+      // Create the input tensors
+      //
+      SmallVector<Value> generatedTensors;
+      for (Type tensorType : returnTypeRange) {
+        assert(llvm::isa<mlir::RankedTensorType>(tensorType));
+
+        RankedTensorType tensor =
+            llvm::cast<mlir::RankedTensorType>(tensorType);
+
+        // Get the layout attribute
+        //
+        ttnn::TTNNLayoutAttr layoutAttr =
+            mlir::cast<ttnn::TTNNLayoutAttr>(tensor.getEncoding());
+
+        // Get the shape of the tensor, tensor layout, and data type
+        //
+        ShapeAttr shapeAttr =
+            ttnn::ShapeAttr::get(&getContext(), tensor.getShape());
+        ttnn::LayoutAttr tensorLayoutAttr =
+            ttnn::LayoutAttr::get(&getContext(), layoutAttr.getLayout());
+        DataTypeAttr dTypeAttr =
+            DataTypeAttr::get(&getContext(), layoutAttr.getDataType());
+
+        // Create a new tensor
+        //
+        // TODO(svuckovic): Move from ttnn::EmptyOp to ttnn::OnesOp once #1476
+        // lands
+        //
+        mlir::Value tensorValue = rewriter.create<ttnn::EmptyOp>(
+            forwardFuncOp->getLoc(), tensorType, nullptr, shapeAttr, dTypeAttr,
+            tensorLayoutAttr, nullptr);
+
+        generatedTensors.push_back(tensorValue);
+      }
+
+      // Return the generated tensors
+      //
+      rewriter.create<func::ReturnOp>(forwardFuncOp->getLoc(),
+                                      generatedTensors);
+    }
+
+    // Create a main function to call input generators and forward funcs
+    //
+    {
+      // Create the "main" entry-point function
+      //
+      std::string mainFuncName = "main";
+
+      // Create function type
+      //
+      mlir::TypeRange returnTypeRange = mlir::TypeRange(rewriter.getI32Type());
+      FunctionType functionType =
+          mlir::FunctionType::get(&getContext(), {}, returnTypeRange);
+
+      // Set insertion point to end of first block
+      //
+      rewriter.setInsertionPointToEnd(firstBlock);
+
+      // Create the function
+      //
+      func::FuncOp mainFuncOp = rewriter.create<mlir::func::FuncOp>(
+          module->getLoc(), mainFuncName, functionType);
+
+      ::mlir::Block *currFnBlock = mainFuncOp.addEntryBlock();
+
+      // Set insertion point to the beginning of the block
+      //
+      rewriter.setInsertionPointToStart(currFnBlock);
+
+      // Call the input generators and pass the results to the forward funcs
+      //
+      for (mlir::func::FuncOp forwardFuncOp : forwardFuncOps) {
+        std::string inputGenFuncName =
+            "createInputsFor_" + forwardFuncOp.getName().str();
+
+        // Get the input generator function
+        //
+        mlir::func::FuncOp inputGenFuncOp =
+            module.lookupSymbol<mlir::func::FuncOp>(inputGenFuncName);
+
+        // Call the input generator function
+        //
+        func::CallOp createdTensors = rewriter.create<mlir::func::CallOp>(
+            forwardFuncOp->getLoc(), inputGenFuncOp, ValueRange());
+
+        rewriter.create<mlir::func::CallOp>(forwardFuncOp->getLoc(),
+                                            forwardFuncOp,
+                                            createdTensors->getResults());
+      }
+
+      // Return 0
+      //
+      // func::ReturnOp requires a Value to be returned, which means that an
+      // SSA value needs to be returned; hence create a constant 0 via
+      // arith::ConstantOp
+      //
+      Value constantZero = rewriter.create<arith::ConstantOp>(
+          rewriter.getUnknownLoc(), rewriter.getI32Type(),
+          rewriter.getI32IntegerAttr(0));
+      rewriter.create<func::ReturnOp>(mainFuncOp->getLoc(), constantZero);
+    }
+  }
+};
+
 } // namespace mlir::tt::ttnn
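A minimal sketch of invoking the new pass programmatically. The factory name `createTTNNCreateInputGenerators()` is assumed from the usual tablegen-generated conventions; on the command line the pass is exposed as `ttmlir-opt --ttnn-create-input-gens`, as the test below shows.

```
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"

// Run the input-generator pass on a module.
mlir::LogicalResult runInputGenPass(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::tt::ttnn::createTTNNCreateInputGenerators());
  return pm.run(module);
}
```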
diff --git a/lib/Dialect/TTNN/Utils/Utils.cpp b/lib/Dialect/TTNN/Utils/Utils.cpp
index 015629921..514e17e52 100644
--- a/lib/Dialect/TTNN/Utils/Utils.cpp
+++ b/lib/Dialect/TTNN/Utils/Utils.cpp
@@ -80,18 +80,6 @@ toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType) {
   llvm_unreachable("Unknown MemorySpace");
 }
 
-DataType getDataTypeFromMemRef(mlir::MemRefType memref) {
-  Type elementType = memref.getElementType();
-  DataType dtype = DataType::Float32;
-  if (llvm::isa<TileType>(elementType)) {
-    auto tileType = mlir::cast<TileType>(elementType);
-    dtype = tileType.getDataType();
-  } else {
-    dtype = elementTypeToDataType(elementType);
-  }
-  return dtype;
-}
-
 Layout getLayoutFromMemRef(mlir::MemRefType memref) {
   ttnn::Layout ttnnLayoutEnum = ttnn::Layout::RowMajor;
   Type elementType = memref.getElementType();
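With `getDataTypeFromMemRef` removed, callers recover the data type (and layout) from a tensor's `TTNNLayoutAttr` encoding instead, which is exactly what the new pass does. A small sketch; the helper name `getDataTypeAndLayout` is hypothetical:

```
#include <utility>

#include "mlir/IR/BuiltinTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

// Hypothetical helper: read dtype and layout off the encoding attribute,
// assuming the tensor is encoded with a TTNNLayoutAttr.
std::pair<mlir::tt::DataType, mlir::tt::ttnn::Layout>
getDataTypeAndLayout(mlir::RankedTensorType tensor) {
  auto layoutAttr =
      mlir::cast<mlir::tt::ttnn::TTNNLayoutAttr>(tensor.getEncoding());
  return {layoutAttr.getDataType(), layoutAttr.getLayout()};
}
```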
diff --git a/test/ttmlir/Dialect/TTNN/Transforms/ttnn_create_input_gens_0.mlir b/test/ttmlir/Dialect/TTNN/Transforms/ttnn_create_input_gens_0.mlir
new file mode 100644
index 000000000..8342c4f5a
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/Transforms/ttnn_create_input_gens_0.mlir
@@ -0,0 +1,36 @@
+// RUN: ttmlir-opt --ttnn-create-input-gens %s | FileCheck %s
+
+#device = #tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
+#dram = #ttnn.buffer_type<dram>
+#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux"}], [{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 98816, erisc_l1_unreserved_base = 102624, dram_unreserved_base = 32, dram_unreserved_end = 1073083040, physical_cores = {worker = [ 1x1, 1x2, 1x3, 1x4, 1x6, 1x7, 1x8, 1x9, 2x1, 2x2, 2x3, 2x4, 2x6, 2x7, 2x8, 2x9, 3x1, 3x2, 3x3, 3x4, 3x6, 3x7, 3x8, 3x9, 4x1, 4x2, 4x3, 4x4, 4x6, 4x7, 4x8, 4x9, 5x1, 5x2, 5x3, 5x4, 5x6, 5x7, 5x8, 5x9, 7x1, 7x2, 7x3, 7x4, 7x6, 7x7, 7x8, 7x9, 8x1, 8x2, 8x3, 8x4, 8x6, 8x7, 8x8, 8x9, 9x1, 9x2, 9x3, 9x4, 9x6, 9x7, 9x8, 9x9] dram = [ 1x0, 1x5, 2x5, 3x5, 5x0, 5x5, 7x0, 7x5, 8x5, 9x5, 11x0, 11x5] eth_inactive = [ 0x1, 0x2, 0x3, 0x4, 0x6, 0x7, 0x8, 0x9, 6x2, 6x3, 6x6, 6x7, 6x8]}, supported_data_types = [<f32>, <f16>, <bf16>, <bfp_f8>, <bfp_bf8>, <bfp_f4>, <bfp_bf4>, <bfp_f2>, <bfp_bf2>, <u32>, <u16>, <u8>], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
+#system_memory = #ttnn.buffer_type<system_memory>
+#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<32x32xbf16, #system_memory>>
+#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
+module attributes {tt.device = #device, tt.system_desc = #system_desc} {
+  // CHECK: func.func @add(%arg0: [[TENSOR_A:.*]], %arg1: [[TENSOR_B:.*]]) -> [[TENSOR_OUT:.*]] {
+  func.func @add(%arg0: tensor<32x32xbf16, #ttnn_layout>, %arg1: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout> {
+    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
+    %1 = "ttnn.to_device"(%arg0, %0) <{memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1>
+    %2 = "ttnn.to_layout"(%1) <{layout = #ttnn.layout<tile>}> : (tensor<32x32xbf16, #ttnn_layout1>) -> tensor<32x32xbf16, #ttnn_layout1>
+    %3 = "ttnn.to_device"(%arg1, %0) <{memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1>
+    %4 = "ttnn.to_layout"(%3) <{layout = #ttnn.layout<tile>}> : (tensor<32x32xbf16, #ttnn_layout1>) -> tensor<32x32xbf16, #ttnn_layout1>
+    %5 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>, shape = #ttnn.shape<32x32>}> : (!tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1>
+    %6 = "ttnn.add"(%2, %4, %5) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xbf16, #ttnn_layout1>, tensor<32x32xbf16, #ttnn_layout1>, tensor<32x32xbf16, #ttnn_layout1>) -> tensor<32x32xbf16, #ttnn_layout1>
+    %7 = "ttnn.from_device"(%6) : (tensor<32x32xbf16, #ttnn_layout1>) -> tensor<32x32xbf16, #ttnn_layout>
+    %8 = "ttnn.to_layout"(%7) <{layout = #ttnn.layout<row_major>}> : (tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout>
+    return %8 : tensor<32x32xbf16, #ttnn_layout>
+  }
+
+// Confirm that the generator func is generated, and that the tensor attrs match:
+//
+// CHECK: func.func @createInputsFor_add() -> ([[TENSOR_A]], [[TENSOR_B]]) {
+// CHECK: {{.*}} -> [[TENSOR_A]]
+// CHECK: {{.*}} -> [[TENSOR_B]]
+// CHECK: return %0, %1 : [[TENSOR_A]], [[TENSOR_B]]
+
+// Confirm that the main func is generated, and that the tensor attrs match:
+//
+// CHECK: func.func @main() -> i32 {
+// CHECK: %0:2 = call @createInputsFor_add() : () -> ([[TENSOR_A]], [[TENSOR_B]])
+// CHECK: %1 = call @add(%0#0, %0#1) : ([[TENSOR_A]], [[TENSOR_B]]) -> [[TENSOR_OUT]]
+}
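For the bigger picture, here is a sketch of the EmitC path this test feeds into: run the input-generator pass, convert TTNN to the EmitC dialect, then translate to C++. `createConvertTTNNToEmitCPass()` and its header path are assumed from tt-mlir's conversion-pass conventions; `translateToCpp` is MLIR's EmitC-to-C++ translation entry point.

```
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/Support/raw_ostream.h"
#include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"

// Lower a TTNN module all the way down to C++ source on stdout.
mlir::LogicalResult emitCpp(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::tt::ttnn::createTTNNCreateInputGenerators());
  pm.addPass(mlir::tt::createConvertTTNNToEmitCPass()); // assumed factory name
  if (mlir::failed(pm.run(module))) {
    return mlir::failure();
  }
  return mlir::emitc::translateToCpp(module, llvm::outs());
}
```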