diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 489fd2faa9..506d07ce09 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -899,6 +899,8 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> {
   }];
 
   let hasVerifier = 1;
+
+  let hasFolder = 1;
 }
 
 def TTIR_SliceOp: TTIR_DPSOp<"slice"> {
diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
index cc11551d7d..63ccb0d28a 100644
--- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
+++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -112,11 +112,4 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
   ];
 }
 
-def TTIRReshapeFold: Pass<"ttir-reshape-fold", "::mlir::ModuleOp"> {
-  let summary = "Folds reshape ops that reshape the same shapes.";
-  let description = [{
-    This pass converts folds the ttir.reshape ops that are called to reshape a tensor to the same shape.
-  }];
-}
-
 #endif
diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index ec415e090b..feedc845cc 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -358,6 +358,15 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() {
   return success();
 }
 
+// ReshapeOp folder
+::mlir::OpFoldResult mlir::tt::ttir::ReshapeOp::fold(FoldAdaptor adaptor) {
+
+  if (getType() == getOperand(0).getType()) {
+    return getOperand(0);
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // SliceOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/TTIR/Transforms/CMakeLists.txt b/lib/Dialect/TTIR/Transforms/CMakeLists.txt
index 5dab548297..f5fec45a8b 100644
--- a/lib/Dialect/TTIR/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TTIR/Transforms/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRTTIRTransforms
   Constant.cpp
   Generic.cpp
   Layout.cpp
-  Reshape.cpp
   Transforms.cpp
   Utility.cpp
 
diff --git a/lib/Dialect/TTIR/Transforms/Reshape.cpp b/lib/Dialect/TTIR/Transforms/Reshape.cpp
deleted file mode 100644
index 3a6b96cd7f..0000000000
--- a/lib/Dialect/TTIR/Transforms/Reshape.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
-//
-// SPDX-License-Identifier: Apache-2.0
-
-#include "ttmlir/Dialect/TT/IR/TT.h"
-#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
-
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include
-
-namespace mlir::tt::ttir {
-#define GEN_PASS_DEF_TTIRRESHAPEFOLD
-#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"
-
-//===----------------------------------------------------------------------===//
-// Reshape folding pass
-//===----------------------------------------------------------------------===//
-
-class TTIRReshapeFoldingRewriter : public OpRewritePattern<ReshapeOp> {
-public:
-  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ReshapeOp op,
-                                PatternRewriter &rewriter) const final {
-    if (op.getType() == op->getOperand(0).getType()) {
-      rewriter.replaceOp(op, op->getOperand(0));
-      return success();
-    }
-    return failure();
-  }
-};
-
-class TTIRReshapeFold : public impl::TTIRReshapeFoldBase<TTIRReshapeFold> {
-public:
-  using impl::TTIRReshapeFoldBase<TTIRReshapeFold>::TTIRReshapeFoldBase;
-
-  void runOnOperation() final {
-    RewritePatternSet patterns(&getContext());
-    patterns.add<TTIRReshapeFoldingRewriter>(&getContext());
-    FrozenRewritePatternSet patternSet(std::move(patterns));
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
-      signalPassFailure();
-      return;
-    }
-  }
-
-  void getDependentDialects(mlir::DialectRegistry &registry) const override {
-    registry.insert<mlir::tt::ttir::TTIRDialect>();
-    registry.insert<mlir::tt::TTDialect>();
-  }
-};
-
-} // namespace mlir::tt::ttir
diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
index 88040bd64e..24980fb7c0 100644
--- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
+++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -72,11 +72,6 @@ void createTTNNPipelineDeallocPass(
   pm.addPass(createTTNNDeallocate());
 }
 
-void createTTIRReshapeFoldPass(
-    OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
-  pm.addPass(mlir::tt::ttir::createTTIRReshapeFold());
-}
-
 void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm,
                                             std::string options) {
   auto optionsStruct =
@@ -112,16 +107,8 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
   createTTNNPipelineDeallocPass(pm, *optionsStruct);
 }
 
-void createTTIRReshapeFoldPassFromString(OpPassManager &pm,
-                                         std::string options) {
-  auto optionsStruct =
-      TTIRToTTNNBackendPipelineOptions::createFromString(options);
-  createTTNNPipelineDeallocPass(pm, *optionsStruct);
-}
-
 void createTTIRToTTNNBackendPipeline(
     OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
-  createTTIRReshapeFoldPass(pm, options);
   createTTNNPipelineTTIRPasses(pm, options);
   createTTNNPipelineLoweringPasses(pm, options);
   createTTNNPipelineAnalysisPasses(pm, options);
diff --git a/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir
similarity index 81%
rename from test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir
rename to test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir
index a6ac203f4c..b33e35f319 100644
--- a/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir
+++ b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir
@@ -1,11 +1,11 @@
-// RUN: ttmlir-opt --ttir-reshape-fold %s| FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s
 #any_device_tile = #tt.operand_constraint<any_device_tile>
 // Tests if we fold when translating from "ttir.reshape" which is called on the two same shapes.
 module @reshape_test {
   func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) {
     %0 = tensor.empty() : tensor<1xi32>
     %1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
    // CHECK: return %arg0 : tensor<1xi32>
+    // CHECK: return %arg0 : tensor<1xi32, #{{.*}}>
     return %1 : tensor<1xi32>
   }
 }