Skip to content

Commit

Permalink
Address code review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
uazizTT committed Nov 21, 2024
1 parent 43e9226 commit c142254
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
8 changes: 7 additions & 1 deletion include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,13 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> {
let summary = "Broadcast operation is folded to all the consumers.";
let description = [{
As TTNN does not support fold operation, this pass folds the broadcast operations to all consumers.
This pass walks through the graph and folds all broadcast instructions since broadcast is supported implicitly by backend ops.
Example:
%1 = "ttir.broadcast"(%arg0) (tensor<1xf32>) -> tensor<512xf32>
%2 = "ttir.maximum"(%1, %arg1) (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>

The above broadcast is folded as:
%1 = "ttir.maximum"(%arg0, %arg1) (tensor<1xf32>, tensor<512xf32>) -> tensor<512xf32>
}];
}

Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTIR/Transforms/Broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ namespace mlir::tt::ttir {

//===----------------------------------------------------------------------===//
// Broadcast Folding pass
// Our backend supports implicit broadcast of operands, so explicit broadcast
// instructions are folded.
//
// For Example:
//
// %0 = tensor.empty() : tensor<512xf32>
// %1 = "ttir.broadcast"(%arg0, %0) (tensor<1xf32>, tensor<512xf32>) ->
// tensor<512xf32>
// %2 = tensor.empty() : tensor<512xf32>
// %3 = "ttir.maximum"(%1, %arg1, %2) (tensor<512xf32>, tensor<512xf32>,
// tensor<512xf32>) -> tensor<512xf32>
//
// After folding:
//
// %0 = tensor.empty() : tensor<512xf32>
// %1 = "ttir.maximum"(%arg0, %arg1, %0) (tensor<1xf32>, tensor<512xf32>,
// tensor<512xf32>) -> tensor<512xf32>
//===----------------------------------------------------------------------===//

class TTIRBroadcastFoldRewriter : public OpRewritePattern<BroadcastOp> {
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ void createTTNNPipelineTTIRBroadcastFoldPassFromString(OpPassManager &pm,

void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
createTTNNPipelineTTIRBroadcastFoldPass(pm, options);
createTTNNPipelineTTIRPasses(pm, options);
createTTNNPipelineTTIRBroadcastFoldPass(pm, options);
createTTNNPipelineLoweringPasses(pm, options);
createTTNNPipelineAnalysisPasses(pm, options);
createTTNNPipelineLayoutDecompositionPass(pm, options);
Expand Down

0 comments on commit c142254

Please sign in to comment.