Skip to content

Commit

Permalink
Added folding
Browse files Browse the repository at this point in the history
  • Loading branch information
ajakovljevicTT committed Dec 2, 2024
1 parent 03d62a5 commit 3c4d9f7
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 76 deletions.
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,8 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> {
}];

let hasVerifier = 1;

let hasFolder = 1;
}

def TTIR_SliceOp: TTIR_DPSOp<"slice"> {
Expand Down
7 changes: 0 additions & 7 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,4 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
];
}

def TTIRReshapeFold: Pass<"ttir-reshape-fold", "::mlir::ModuleOp"> {
let summary = "Folds reshape ops that reshape the same shapes.";
let description = [{
This pass converts folds the ttir.reshape ops that are called to reshape a tensor to the same shape.
}];
}

#endif
9 changes: 9 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,15 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() {
return success();
}

// ReshapeOp folder
::mlir::OpFoldResult mlir::tt::ttir::ReshapeOp::fold(FoldAdaptor adaptor) {

if (getType() == getOperand(0).getType()) {
return getOperand(0);
}
return nullptr;
}

//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 0 additions & 1 deletion lib/Dialect/TTIR/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRTTIRTransforms
Constant.cpp
Generic.cpp
Layout.cpp
Reshape.cpp
Transforms.cpp
Utility.cpp

Expand Down
53 changes: 0 additions & 53 deletions lib/Dialect/TTIR/Transforms/Reshape.cpp

This file was deleted.

13 changes: 0 additions & 13 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,6 @@ void createTTNNPipelineDeallocPass(
pm.addPass(createTTNNDeallocate());
}

void createTTIRReshapeFoldPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
pm.addPass(mlir::tt::ttir::createTTIRReshapeFold());
}

void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm,
std::string options) {
auto optionsStruct =
Expand Down Expand Up @@ -112,16 +107,8 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
createTTNNPipelineDeallocPass(pm, *optionsStruct);
}

void createTTIRReshapeFoldPassFromString(OpPassManager &pm,
std::string options) {
auto optionsStruct =
TTIRToTTNNBackendPipelineOptions::createFromString(options);
createTTNNPipelineDeallocPass(pm, *optionsStruct);
}

void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
createTTIRReshapeFoldPass(pm, options);
createTTNNPipelineTTIRPasses(pm, options);
createTTNNPipelineLoweringPasses(pm, options);
createTTNNPipelineAnalysisPasses(pm, options);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// RUN: ttmlir-opt --ttir-reshape-fold %s| FileCheck %s
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
// Tests if we fold when translating from "ttir.reshape" which is called on the two same shapes.
module @reshape_test {
func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) {
%0 = tensor.empty() : tensor<1xi32>
%1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: return %arg0 : tensor<1xi32>
// CHECK: return %arg0 : tensor<1xi32, #{{.*}}>
return %1 : tensor<1xi32>
}
}

0 comments on commit 3c4d9f7

Please sign in to comment.