Skip to content

Commit

Permalink
Address code review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
uazizTT committed Nov 22, 2024
1 parent 8f69507 commit c051721
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 17 deletions.
8 changes: 7 additions & 1 deletion include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,13 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> {
let summary = "Broadcast operation is folded to all the consumers.";
let description = [{
As TTNN does not support fold operation, this pass folds the broadcast operations to all consumers.
This pass walks through the graph and folds all broadcast instructions since broadcast is supported implicitly by backend ops.
Example:
%1 = "ttir.broadcast"(%arg0) (tensor<1xf32>) -> tensor<512xf32>
%2 = "ttir.maximum"(%1, %arg1) (tensor<512xf32>, tensor<512xf32>) -> tensor<512xf32>

This above broadcast is folded as:
%1 = "ttir.maximum"(%arg0, %arg1) (tensor<1xf32>, tensor<512xf32>) -> tensor<512xf32>
}];
}

Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTIR/Transforms/Broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ namespace mlir::tt::ttir {

//===----------------------------------------------------------------------===//
// Broadcast Folding pass
// Our backend supports implicit broadcast of operands, so explicit broadcast
// instructions are folded.
//
// For Example:
//
// %0 = tensor.empty() : tensor<512xf32>
// %1 = "ttir.broadcast"(%arg0, %0) (tensor<1xf32>, tensor<512xf32>) ->
// tensor<512xf32> %2 = tensor.empty() : tensor<512xf32> %3 = "ttir.maximum"(%1,
// %arg1, %2) (tensor<512xf32>, tensor<512xf32>, tensor<512xf32>) ->
// tensor<512xf32>
//
// After folding:
//
// %0 = tensor.empty() : tensor<512xf32>
// %1 = "ttir.maximum"(%arg0, %arg1, %0) (tensor<1xf32>, tensor<512xf32>,
// tensor<512xf32>) -> tensor<512xf32>
//===----------------------------------------------------------------------===//

class TTIRBroadcastFoldRewriter : public OpRewritePattern<BroadcastOp> {
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ void createTTNNPipelineTTIRBroadcastFoldPassFromString(OpPassManager &pm,
std::string options) {
auto optionsStruct =
TTIRToTTNNBackendPipelineOptions::createFromString(options);
createTTNNPipelineDeallocPass(pm, *optionsStruct);
createTTNNPipelineTTIRBroadcastFoldPass(pm, *optionsStruct);
}

void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
createTTNNPipelineTTIRBroadcastFoldPass(pm, options);
createTTNNPipelineTTIRPasses(pm, options);
createTTNNPipelineTTIRBroadcastFoldPass(pm, options);
createTTNNPipelineLoweringPasses(pm, options);
createTTNNPipelineAnalysisPasses(pm, options);
createTTNNPipelineLayoutDecompositionPass(pm, options);
Expand Down
14 changes: 0 additions & 14 deletions test/ttmlir/Silicon/TTNN/simple_broadcast.mlir

This file was deleted.

0 comments on commit c051721

Please sign in to comment.