From 3f1236c314f15f72ee24e0993f2d5aa8bc603923 Mon Sep 17 00:00:00 2001 From: uazizTT Date: Thu, 21 Nov 2024 15:03:33 -0500 Subject: [PATCH] Address code review comments. --- include/ttmlir/Dialect/TTIR/Transforms/Passes.td | 8 +++++++- lib/Dialect/TTIR/Transforms/Broadcast.cpp | 16 ++++++++++++++++ lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp | 2 +- test/ttmlir/Dialect/TTNN/simple_broadcast.mlir | 12 ++++++++++++ test/ttmlir/Silicon/TTNN/simple_broadcast.mlir | 14 -------------- 5 files changed, 36 insertions(+), 16 deletions(-) delete mode 100644 test/ttmlir/Silicon/TTNN/simple_broadcast.mlir diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index fea8f83164..b6269f7153 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -115,7 +115,13 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> { def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> { let summary = "Broadcast operation is folded to all the consumers."; let description = [{ - As TTNN does not support fold operation, this pass folds the broadcast operations to all consumers. + This pass walks through the graph and folds all broadcast instructions since broadcast is supported implicitly by backend ops. + Example: + %1 = "ttir.broadcast"(%arg0) (tensor<1xf32>) -> tensor<512xf32> + %2 = "ttir.maximum"(%1, %arg1) (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32> + + This above broadcast is folded as: + %1 = "ttir.maximum"(%arg0, %arg1) (tensor<1xf32>, tensor<512xf32>) -> tensor<512xf32> }]; } diff --git a/lib/Dialect/TTIR/Transforms/Broadcast.cpp b/lib/Dialect/TTIR/Transforms/Broadcast.cpp index 3bd987089c..7823b021ed 100644 --- a/lib/Dialect/TTIR/Transforms/Broadcast.cpp +++ b/lib/Dialect/TTIR/Transforms/Broadcast.cpp @@ -14,6 +14,22 @@ namespace mlir::tt::ttir { //===----------------------------------------------------------------------===// // Broadcast Folding pass +// Our backend supports implicit broadcast of operands, so explicit broadcast +// instructions are folded. +// +// For Example: +// +// %0 = tensor.empty() : tensor<512xf32> +// %1 = "ttir.broadcast"(%arg0, %0) (tensor<1xf32>, tensor<512xf32>) -> +// tensor<512xf32> %2 = tensor.empty() : tensor<512xf32> %3 = "ttir.maximum"(%1, +// %arg1, %2) (tensor<512xf32>, tensor<512xf32>, tensor<512xf32>) -> +// tensor<512xf32> +// +// After folding: +// +// %0 = tensor.empty() : tensor<512xf32> +// %1 = "ttir.maximum"(%arg0, %arg1, %0) (tensor<1xf32>, tensor<512xf32>, +// tensor<512xf32>) -> tensor<512xf32> //===----------------------------------------------------------------------===// class TTIRBroadcastFoldRewriter : public OpRewritePattern { diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index 6a292de18c..13c2f3b33e 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -121,8 +121,8 @@ void createTTNNPipelineTTIRBroadcastFoldPassFromString(OpPassManager &pm, void createTTIRToTTNNBackendPipeline( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { - createTTNNPipelineTTIRBroadcastFoldPass(pm, options); createTTNNPipelineTTIRPasses(pm, options); + createTTNNPipelineTTIRBroadcastFoldPass(pm, options); createTTNNPipelineLoweringPasses(pm, options); createTTNNPipelineAnalysisPasses(pm, options); createTTNNPipelineLayoutDecompositionPass(pm, options); diff --git a/test/ttmlir/Dialect/TTNN/simple_broadcast.mlir b/test/ttmlir/Dialect/TTNN/simple_broadcast.mlir index e7aac7e2e9..b6e49e6bae 100644 --- a/test/ttmlir/Dialect/TTNN/simple_broadcast.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_broadcast.mlir @@ -10,4 +10,16 @@ module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replic %3 = "ttir.maximum"(%1, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32> return %3 : tensor<512x512xf32> } + + func.func public @broadcast() -> (tensor<32xf32>) { + %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> + %1 = tensor.empty() : tensor<32xf32> + %2 = "ttir.broadcast"(%0, %1) <{dimension = [0], operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1xf32>, tensor<32xf32>) -> tensor<32xf32> + %3 = tensor.empty() : tensor<32xf32> + %4 = "ttir.broadcast"(%2, %3) <{dimension = [0], operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> + // CHECK-NOT: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + return %4 : tensor<32xf32> + } } + + diff --git a/test/ttmlir/Silicon/TTNN/simple_broadcast.mlir b/test/ttmlir/Silicon/TTNN/simple_broadcast.mlir deleted file mode 100644 index 79d40d929d..0000000000 --- a/test/ttmlir/Silicon/TTNN/simple_broadcast.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device_tile = #tt.operand_constraint - -func.func public @main(%arg0: tensor<1xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512x512xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x512xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - // CHECK-NOT: %[[C:.*]] = "ttnn.broadcast"[[C:.*]] - %0 = tensor.empty() : tensor<512x512xf32> - %1 = "ttir.broadcast"(%arg0, %0) <{dimension = [1], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1xf32>, tensor<512x512xf32>) -> tensor<512x512xf32> - %2 = tensor.empty() : tensor<512x512xf32> - %3 = "ttir.maximum"(%1, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32> - return %3 : tensor<512x512xf32> -}