Merging consecutive ttnn ToLayoutOps (#1552)
sdjordjevicTT authored Dec 12, 2024
1 parent f226c8c commit d2cd95c
Showing 5 changed files with 155 additions and 6 deletions.
13 changes: 7 additions & 6 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.h
@@ -5,22 +5,23 @@
#ifndef TTMLIR_DIALECT_TTNN_IR_TTNNOPS_H
#define TTMLIR_DIALECT_TTNN_IR_TTNNOPS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h"

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

#include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.h.inc"
#include "ttmlir/Dialect/TTNN/IR/TTNNWorkaroundInterface.h"

#define GET_OP_CLASSES
#include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.h.inc"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h.inc"

#endif
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -61,6 +61,8 @@ def TTNN_ToLayoutOp : TTNN_Op<"to_layout"> {
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config,
Optional<TT_Device>:$device);
let results = (outs AnyRankedTensor:$result);

let hasCanonicalizeMethod = 1;
}

def TTNN_TypecastOp : TTNN_Op<"typecast"> {
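
Setting hasCanonicalizeMethod = 1 asks ODS to declare a static canonicalize hook on the generated ToLayoutOp class and to register a rewrite pattern that calls it from the op's canonicalization pattern list. A minimal sketch of the generated declarations (simplified for illustration; the exact tblgen output differs and this is not verbatim):

// Sketch of what ODS emits for hasCanonicalizeMethod = 1 (simplified).
class ToLayoutOp : public ::mlir::Op<ToLayoutOp /*, traits... */> {
public:
  // Declared by ODS; the definition is the hand-written method added to
  // TTNNOps.cpp in this commit.
  static ::mlir::LogicalResult canonicalize(ToLayoutOp op,
                                            ::mlir::PatternRewriter &rewriter);

  // The generated body registers a single pattern that forwards to the
  // static canonicalize method above.
  static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
                                          ::mlir::MLIRContext *context);
};
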
61 changes: 61 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -666,6 +666,67 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ToLayoutOp
//===----------------------------------------------------------------------===//

// ToLayoutOp canonicalization
//
// A ToLayoutOp can be merged into the ToLayoutOp that consumes it, provided
// the producing op has only one use.
// Legend: df - data format, l - layout, ms - memory space,
//         tml - tensor memory layout
//
//            |
// -----------------------
// |      ToLayoutOp     |
// |  df1, l1, ms1, tml1 |        -----------------------
// -----------------------        |      ToLayoutOp     |
//            |     -->           |  df2, l1, ms2, tml1 |
//            |                   -----------------------
// -----------------------                   |
// |      ToLayoutOp     |
// |       df2, ms2      |
// -----------------------
//            |
//
::mlir::LogicalResult
mlir::tt::ttnn::ToLayoutOp::canonicalize(ToLayoutOp toLayoutOp,
                                         PatternRewriter &rewriter) {
  // Get the input operand and check whether its defining op is also a
  // ToLayoutOp.
  ToLayoutOp previousToLayoutOp =
      toLayoutOp.getOperand(0).getDefiningOp<ToLayoutOp>();

  // NOLINTNEXTLINE
  if (!previousToLayoutOp) {
    return mlir::failure();
  }

  // We can only merge if the previous op has a single use.
  if (!previousToLayoutOp->hasOneUse()) {
    return mlir::failure();
  }

  // Replace the previous op with the merged ToLayoutOp. The layout is taken
  // from the current op; dtype and memory config are taken from the current
  // op when set, and from the previous op otherwise.
  Value mergedToLayout = rewriter.replaceOpWithNewOp<ToLayoutOp>(
      previousToLayoutOp, toLayoutOp.getType(), previousToLayoutOp.getInput(),
      toLayoutOp.getLayoutAttr(),
      toLayoutOp.getDtypeAttr() ? toLayoutOp.getDtypeAttr()
                                : previousToLayoutOp.getDtypeAttr(),
      toLayoutOp.getMemoryConfigAttr()
          ? toLayoutOp.getMemoryConfigAttr()
          : previousToLayoutOp.getMemoryConfigAttr(),
      toLayoutOp.getDevice());

  // Replace all uses of the current op with the merged ToLayoutOp.
  rewriter.replaceAllUsesWith(toLayoutOp, mergedToLayout);

  // Erase the current op.
  rewriter.eraseOp(toLayoutOp);

  return mlir::success();
}

//===----------------------------------------------------------------------===//
// LinearOp
//===----------------------------------------------------------------------===//
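
The hook above is picked up automatically once the canonicalizer runs (see the pipeline change below), but it can also be driven directly through MLIR's standard greedy rewrite driver. A minimal sketch, assuming the TTNN dialect is loaded and a ModuleOp is at hand; mergeToLayoutOps is an illustrative helper, not part of this commit:

// Sketch: apply only ToLayoutOp's canonicalization patterns to a module.
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

static mlir::LogicalResult mergeToLayoutOps(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  // Collects the pattern that forwards to ToLayoutOp::canonicalize.
  mlir::tt::ttnn::ToLayoutOp::getCanonicalizationPatterns(patterns, ctx);
  // The driver applies the pattern repeatedly, so longer chains of
  // ttnn.to_layout ops collapse into a single op (as exercised by the
  // merge_to_layout_op_4x test below).
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}
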
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -70,6 +70,7 @@ void createTTNNPipelineWorkaroundPass(
      options.layouotWorkaroundsEnabled,
      options.decompositionWorkaroundsEnabled};
  pm.addPass(createTTNNWorkarounds(workaroundOptions));
  pm.addPass(mlir::createCanonicalizerPass());
}

void createTTNNPipelineLayoutDecompositionPass(
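
For context, createCanonicalizerPass is the stock upstream canonicalizer; running it right after the workaround pass folds the back-to-back ttnn.to_layout ops those workarounds can introduce. A minimal sketch of how this fragment typically sits in a full pipeline run; the wrapper function, the TTNNWorkaroundsOptions type name, and the includes are illustrative assumptions, not code from this commit:

// Sketch: workaround stage followed by canonicalization on a ModuleOp.
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

static mlir::LogicalResult runWorkaroundsAndCanonicalize(
    mlir::ModuleOp module, const TTNNWorkaroundsOptions &workaroundOptions) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(createTTNNWorkarounds(workaroundOptions));
  // The canonicalizer picks up ToLayoutOp::canonicalize and folds the
  // to_layout chains left behind by the workaround pass.
  pm.addPass(mlir::createCanonicalizerPass());
  return pm.run(module);
}
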
@@ -0,0 +1,84 @@
// RUN: ttmlir-opt --canonicalize %s | FileCheck %s
#device = #tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
#dram = #ttnn.buffer_type<dram>
#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [<f32>, <f16>, <bf16>, <bfp_f8>, <bfp_bf8>, <bfp_f4>, <bfp_bf4>, <bfp_f2>, <bfp_bf2>, <u32>, <u16>, <u8>], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
#system_memory = #ttnn.buffer_type<system_memory>
#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<32x32xbf16, #system_memory>>
#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<32x32xbf16, #dram>, <interleaved>>
#ttnn_layout3 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #dram>, <interleaved>>
#ttnn_layout4 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #system_memory>>
#ttnn_layout5 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<32x32xf32, #system_memory>>
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
func.func @merge_to_layout_op_layout(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout2> {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
// Verify that the to_layout op is canonicalized to a single to_layout op and the attributes are merged.
// CHECK: "ttnn.to_layout"(%arg0, %0)
// CHECK-SAME: dtype = #tt.supportedDataTypes<bf16>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>
// CHECK-SAME: -> tensor<32x32xbf16, #ttnn_layout1>
// CHECK-NEXT: return
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xbf16, #ttnn_layout1>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout2>
return %2 : tensor<32x32xbf16, #ttnn_layout2>
}

func.func @merge_to_layout_op_data_type(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xf32, #ttnn_layout3> {
// Verify that the to_layout op is canonicalized to a single to_layout op and the attributes are merged.
// CHECK: "ttnn.to_layout"(%arg0, %0)
// CHECK-SAME: dtype = #tt.supportedDataTypes<f32>
// CHECK-SAME: layout = #ttnn.layout<tile>
// CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>
// CHECK-SAME: -> tensor<32x32xf32, #ttnn_layout2>
// CHECK-NEXT: return
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xbf16, #ttnn_layout1>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout3>
return %2 : tensor<32x32xf32, #ttnn_layout3>
}

func.func @merge_to_layout_op_memory_config(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xbf16, #ttnn_layout4> {
// Verify that the to_layout op is canonicalized to a single to_layout op and the attributes are merged.
// CHECK: "ttnn.to_layout"(%arg0, %0)
// CHECK-SAME: dtype = #tt.supportedDataTypes<bf16>
// CHECK-SAME: layout = #ttnn.layout<tile>
// CHECK-SAME: memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>
// CHECK-SAME: -> tensor<32x32xbf16, #ttnn_layout3>
// CHECK-NEXT: return
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>}> : (tensor<32x32xbf16, #ttnn_layout1>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout4>
return %2 : tensor<32x32xbf16, #ttnn_layout4>
}

func.func @merge_to_layout_op_all(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xf32, #ttnn_layout5> {
// Verify that the to_layout op is canonicalized to a single to_layout op and the attributes are merged.
// CHECK: "ttnn.to_layout"(%arg0, %0)
// CHECK-SAME: dtype = #tt.supportedDataTypes<f32>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>
// CHECK-SAME: -> tensor<32x32xf32, #ttnn_layout4>
// CHECK-NEXT: return
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>}> : (tensor<32x32xbf16, #ttnn_layout1>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout5>
return %2 : tensor<32x32xf32, #ttnn_layout5>
}

func.func @merge_to_layout_op_4x(%arg0: tensor<32x32xbf16, #ttnn_layout>) -> tensor<32x32xf32, #ttnn_layout5> {
// Verify that the to_layout op is canonicalized to a single to_layout op and the attributes are merged.
// CHECK: "ttnn.to_layout"(%arg0, %0)
// CHECK-SAME: dtype = #tt.supportedDataTypes<f32>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>
// CHECK-SAME: -> tensor<32x32xf32, #ttnn_layout4>
// CHECK-NEXT: return
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xbf16, #ttnn_layout1>
%2 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>}> : (tensor<32x32xbf16, #ttnn_layout1>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout5>
%3 = "ttnn.to_layout"(%2, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xf32, #ttnn_layout5>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout3>
%4 = "ttnn.to_layout"(%3, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<32x32>>>}> : (tensor<32x32xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout5>
return %4 : tensor<32x32xf32, #ttnn_layout5>
}
}
