Skip to content

Commit

Permalink
Address code review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
uazizTT committed Nov 22, 2024
1 parent e3b34e7 commit 3f1236c
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 16 deletions.
8 changes: 7 additions & 1 deletion include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,13 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> {
let summary = "Broadcast operation is folded to all the consumers.";
let description = [{
As TTNN does not support fold operation, this pass folds the broadcast operations to all consumers.
This pass walks through the graph and folds all broadcast instructions since broadcast is supported implicitly by backend ops.
Example:
%1 = "ttir.broadcast"(%arg0) (tensor<1xf32>) -> tensor<512xf32>
%2 = "ttir.maximum"(%1, %arg1) (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>

The above broadcast is folded as:
%1 = "ttir.maximum"(%arg0, %arg1) (tensor<1xf32>, tensor<512xf32>) -> tensor<512xf32>
}];
}

Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTIR/Transforms/Broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ namespace mlir::tt::ttir {

//===----------------------------------------------------------------------===//
// Broadcast Folding pass
// Our backend supports implicit broadcast of operands, so explicit broadcast
// instructions are folded.
//
// For Example:
//
// %0 = tensor.empty() : tensor<512xf32>
// %1 = "ttir.broadcast"(%arg0, %0)
//     (tensor<1xf32>, tensor<512xf32>) -> tensor<512xf32>
// %2 = tensor.empty() : tensor<512xf32>
// %3 = "ttir.maximum"(%1, %arg1, %2)
//     (tensor<512xf32>, tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>
//
// After folding:
//
// %0 = tensor.empty() : tensor<512xf32>
// %1 = "ttir.maximum"(%arg0, %arg1, %0) (tensor<1xf32>, tensor<512xf32>,
// tensor<512xf32>) -> tensor<512xf32>
//===----------------------------------------------------------------------===//

class TTIRBroadcastFoldRewriter : public OpRewritePattern<BroadcastOp> {
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ void createTTNNPipelineTTIRBroadcastFoldPassFromString(OpPassManager &pm,

void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
createTTNNPipelineTTIRBroadcastFoldPass(pm, options);
createTTNNPipelineTTIRPasses(pm, options);
createTTNNPipelineTTIRBroadcastFoldPass(pm, options);
createTTNNPipelineLoweringPasses(pm, options);
createTTNNPipelineAnalysisPasses(pm, options);
createTTNNPipelineLayoutDecompositionPass(pm, options);
Expand Down
12 changes: 12 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_broadcast.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,16 @@ module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replic
%3 = "ttir.maximum"(%1, %arg1, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
return %3 : tensor<512x512xf32>
}

func.func public @broadcast() -> (tensor<32xf32>) {
  // Chains two broadcasts: a 1 -> 32 expansion of a constant, followed by a
  // second broadcast whose input and output shapes are both tensor<32xf32>.
  %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
  %1 = tensor.empty() : tensor<32xf32>
  %2 = "ttir.broadcast"(%0, %1) <{dimension = [0], operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1xf32>, tensor<32xf32>) -> tensor<32xf32>
  %3 = tensor.empty() : tensor<32xf32>
  %4 = "ttir.broadcast"(%2, %3) <{dimension = [0], operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
  // The negative check below is a plain string match: no captures are needed
  // because nothing later in the test refers to the matched values.
  // CHECK-NOT: "ttir.broadcast"
  return %4 : tensor<32xf32>
}
}


14 changes: 0 additions & 14 deletions test/ttmlir/Silicon/TTNN/simple_broadcast.mlir

This file was deleted.

0 comments on commit 3f1236c

Please sign in to comment.