From d2cd95c196de5395538c3a4132c9674e1a3703da Mon Sep 17 00:00:00 2001
From: Stefan Djordjevic <157365107+sdjordjevicTT@users.noreply.github.com>
Date: Thu, 12 Dec 2024 14:16:44 +0100
Subject: [PATCH] Merging consecutive ttnn ToLayoutOps (#1552)

---
 include/ttmlir/Dialect/TTNN/IR/TTNNOps.h      | 13 +--
 include/ttmlir/Dialect/TTNN/IR/TTNNOps.td     |  2 +
 lib/Dialect/TTNN/IR/TTNNOps.cpp               | 61 ++++++++++++++
 lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp  |  1 +
 .../simple_to_layout_op_canonicalizer.mlir    | 84 +++++++++++++++++++
 5 files changed, 155 insertions(+), 6 deletions(-)
 create mode 100644 test/ttmlir/Dialect/TTNN/Canonicalizer/simple_to_layout_op_canonicalizer.mlir

diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h
index 457c7722b..3122e2323 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.h
@@ -5,22 +5,23 @@
 #ifndef TTMLIR_DIALECT_TTNN_IR_TTNNOPS_H
 #define TTMLIR_DIALECT_TTNN_IR_TTNNOPS_H
 
+#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h"
+
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
-#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
-#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
-
-#include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.h.inc"
-#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h"
 
 #define GET_OP_CLASSES
+#include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.h.inc"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h.inc"
 
 #endif
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 29505d102..758fb41d7 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -61,6 +61,8 @@ def TTNN_ToLayoutOp : TTNN_Op<"to_layout"> {
                        OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config,
                        Optional<TT_Device>:$device);
   let results = (outs AnyRankedTensor:$result);
+
+  let hasCanonicalizeMethod = 1;
 }
 
 def TTNN_TypecastOp : TTNN_Op<"typecast"> {
diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index 3798745d5..cfe0dacc3 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -666,6 +666,67 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ToLayoutOp
+//===----------------------------------------------------------------------===//
+
+// ToLayoutOp canonicalization
+// A ToLayoutOp can be canonicalized if the previous op is also a ToLayoutOp.
+// The previous op can be merged into the current ToLayoutOp if the previous op
+// has only one use.
+// df - data format, l - layout, ms - memory space, tml - tensor memory layout
+//
+//            |
+//   -----------------------
+//   |      ToLayoutOp     |                   |
+//   |  df1, l1, ms1, tml1 |        -----------------------
+//   -----------------------        |      ToLayoutOp     |
+//            |              -->    |  df2, l1, ms2, tml1 |
+//            |                     -----------------------
+//   -----------------------                   |
+//   |      ToLayoutOp     |
+//   |       df2, ms2      |
+//   -----------------------
+//            |
+//
+::mlir::LogicalResult
+mlir::tt::ttnn::ToLayoutOp::canonicalize(ToLayoutOp toLayoutOp,
+                                         PatternRewriter &rewriter) {
+  // Get the input operand and verify that the previous op is a ToLayoutOp.
+  ToLayoutOp previousToLayoutOp =
+      toLayoutOp.getOperand(0).getDefiningOp<ToLayoutOp>();
+
+  // NOLINTNEXTLINE
+  if (!previousToLayoutOp) {
+    return mlir::failure();
+  }
+
+  // Check if the previous op has only one use. We can only merge if the
+  // previous op has a single use.
+  if (!previousToLayoutOp->hasOneUse()) {
+    return mlir::failure();
+  }
+
+  // Replace the previous op with the merged ToLayoutOp.
+  Value mergedToLayout = rewriter.replaceOpWithNewOp<ToLayoutOp>(
+      previousToLayoutOp, toLayoutOp.getType(), previousToLayoutOp.getInput(),
+      toLayoutOp.getLayoutAttr(),
+      toLayoutOp.getDtypeAttr() ? toLayoutOp.getDtypeAttr()
+                                : previousToLayoutOp.getDtypeAttr(),
+      toLayoutOp.getMemoryConfigAttr()
+          ? toLayoutOp.getMemoryConfigAttr()
+          : previousToLayoutOp.getMemoryConfigAttr(),
+      toLayoutOp.getDevice());
+
+  // Replace all uses of the current op with the merged ToLayoutOp.
+  rewriter.replaceAllUsesWith(toLayoutOp, mergedToLayout);
+
+  // Erase the current op.
+  rewriter.eraseOp(toLayoutOp);
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LinearOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
index b2f257a7e..9b2d11a6e 100644
--- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
+++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -70,6 +70,7 @@ void createTTNNPipelineWorkaroundPass(
       options.layouotWorkaroundsEnabled,
       options.decompositionWorkaroundsEnabled};
   pm.addPass(createTTNNWorkarounds(workaroundOptions));
+  pm.addPass(mlir::createCanonicalizerPass());
 }
 
 void createTTNNPipelineLayoutDecompositionPass(
diff --git a/test/ttmlir/Dialect/TTNN/Canonicalizer/simple_to_layout_op_canonicalizer.mlir b/test/ttmlir/Dialect/TTNN/Canonicalizer/simple_to_layout_op_canonicalizer.mlir
new file mode 100644
index 000000000..73f68a5f2
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/Canonicalizer/simple_to_layout_op_canonicalizer.mlir
@@ -0,0 +1,84 @@
+// RUN: ttmlir-opt --canonicalize %s | FileCheck %s
+#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
+#dram = #ttnn.buffer_type
+#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024,
erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]> +#system_memory = #ttnn.buffer_type +#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<32x32xbf16, #system_memory>> +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, bf16>, #dram>, > +#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<32x32xbf16, #dram>, > +#ttnn_layout3 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #dram>, > +#ttnn_layout4 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #system_memory>> +#ttnn_layout5 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<32x32xf32, #system_memory>> +module attributes {tt.device = #device, tt.system_desc = #system_desc} { + func.func @merge_to_layout_op_layout(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout2> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + // Verify that the to_layout op is canonicalized to a single to_layout op and the attributes are merged. + // CHECK: "ttnn.to_layout"(%arg0, %0) + // CHECK-SAME: dtype = #tt.supportedDataTypes + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<1x1>>, > + // CHECK-SAME: -> tensor<32x32xbf16, #ttnn_layout1> + // CHECK-NEXT: return + %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1> + %2 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<32x32xbf16, #ttnn_layout1>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout2> + return %2 : tensor<32x32xbf16, #ttnn_layout2> + } + + func.func @merge_to_layout_op_data_type(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xf32, #ttnn_layout3> { + // Verify that the to_layout op is canonicalized to a single to_layout op and the attributes are merged. 
+ // CHECK: "ttnn.to_layout"(%arg0, %0) + // CHECK-SAME: dtype = #tt.supportedDataTypes + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<1x1>>, > + // CHECK-SAME: -> tensor<32x32xf32, #ttnn_layout2> + // CHECK-NEXT: return + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1> + %2 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<32x32xbf16, #ttnn_layout1>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout3> + return %2 : tensor<32x32xf32, #ttnn_layout3> + } + + func.func @merge_to_layout_op_memory_config(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout4> { + // Verify that the to_layout op is canonicalized to a single to_layout op and the attributes are merged. + // CHECK: "ttnn.to_layout"(%arg0, %0) + // CHECK-SAME: dtype = #tt.supportedDataTypes + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: memory_config = #ttnn.memory_config<#system_memory, <<32x32>>> + // CHECK-SAME: -> tensor<32x32xbf16, #ttnn_layout3> + // CHECK-NEXT: return + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1> + %2 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>}> : (tensor<32x32xbf16, #ttnn_layout1>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout4> + return %2 : tensor<32x32xbf16, #ttnn_layout4> + } + + func.func @merge_to_layout_op_all(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xf32, #ttnn_layout5> { + // Verify that the to_layout op is canonicalized to a single to_layout op and the attributes are merged. + // CHECK: "ttnn.to_layout"(%arg0, %0) + // CHECK-SAME: dtype = #tt.supportedDataTypes + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: memory_config = #ttnn.memory_config<#system_memory, <<32x32>>> + // CHECK-SAME: -> tensor<32x32xf32, #ttnn_layout4> + // CHECK-NEXT: return + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1> + %2 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>}> : (tensor<32x32xbf16, #ttnn_layout1>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout5> + return %2 : tensor<32x32xf32, #ttnn_layout5> + } + + func.func @merge_to_layout_op_4x(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xf32, #ttnn_layout5> { + // Verify that the to_layout op is canonicalized to a single to_layout op and the attributes are merged. 
+ // CHECK: "ttnn.to_layout"(%arg0, %0) + // CHECK-SAME: dtype = #tt.supportedDataTypes + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: memory_config = #ttnn.memory_config<#system_memory, <<32x32>>> + // CHECK-SAME: -> tensor<32x32xf32, #ttnn_layout4> + // CHECK-NEXT: return + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1> + %2 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>}> : (tensor<32x32xbf16, #ttnn_layout1>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout5> + %3 = "ttnn.to_layout"(%2, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<1x1>>, >}> : (tensor<32x32xf32, #ttnn_layout5>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout3> + %4 = "ttnn.to_layout"(%3, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>}> : (tensor<32x32xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout5> + return %4 : tensor<32x32xf32, #ttnn_layout5> + } +}
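
Illustrative addendum (not part of the applied patch): a minimal before/after sketch of the rewrite
performed by the canonicalization above, modeled on the @merge_to_layout_op_layout test case.
Attribute values and layout aliases are elided; only the shape of the transformation is shown.

  Before --canonicalize (the first ttnn.to_layout has a single use):
    %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = ..., layout = ..., memory_config = ...}>
         : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1>
    %2 = "ttnn.to_layout"(%1, %0) <{dtype = ..., layout = ..., memory_config = ...}>
         : (tensor<32x32xbf16, #ttnn_layout1>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout2>

  After --canonicalize (a single op remains, consuming %arg0 directly; it keeps the second op's
  layout, dtype, and memory_config, falling back to the first op's dtype/memory_config only when
  the second op leaves those optional attributes unset):
    %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = ..., layout = ..., memory_config = ...}>
         : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout2>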