Skip to content

Commit

Permalink
Changed to a transformation pass
Browse files Browse the repository at this point in the history
  • Loading branch information
ajakovljevicTT committed Nov 20, 2024
1 parent db3507c commit ca82cbe
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 17 deletions.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,11 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
];
}

def TTIRReshapeFold: Pass<"ttir-reshape-fold", "::mlir::ModuleOp"> {
let summary = "Folds reshape ops that reshape the same shapes.";
let description = [{
This pass converts folds the ttir.reshape ops that are called to reshape a tensor to the same shape.
}];
}

#endif
8 changes: 6 additions & 2 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,12 @@ class StableHLOToTTIRReduceOpConversionPattern
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

mlir::ArrayAttr dimArg = adaptor.getDimensionsAttr().size() > 0 ? rewriter.getArrayAttr(SmallVector<Attribute>(
1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0]))) : 0;
mlir::ArrayAttr dimArg =
adaptor.getDimensionsAttr().size() > 0
? rewriter.getArrayAttr(SmallVector<Attribute>(
1,
rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0])))
: 0;

// If someone changes definition of TTIR_ReductionOp this constant will
// become outdated, but I currently see no way to get this info (without
Expand Down
15 changes: 0 additions & 15 deletions lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -775,21 +775,6 @@ class GetDimensionSizeToConstantConversionPattern
}
};

class ReshapeOpOmissionPattern : public OpConversionPattern<ttir::ReshapeOp> {
public:
using OpConversionPattern<ttir::ReshapeOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::ReshapeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.getType() == op->getOperand(0).getType()) {
rewriter.replaceOp(op, op->getOperand(0));
return success();
}
return failure();
}
};

void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TTIR/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTTIRTransforms
Constant.cpp
Generic.cpp
Layout.cpp
Reshape.cpp
Transforms.cpp
Utility.cpp

Expand Down
53 changes: 53 additions & 0 deletions lib/Dialect/TTIR/Transforms/Reshape.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>

namespace mlir::tt::ttir {
#define GEN_PASS_DEF_TTIRRESHAPEFOLD
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Constant as fill pass
//===----------------------------------------------------------------------===//

class TTIRReshapeFoldingRewriter : public OpRewritePattern<ReshapeOp> {
public:
using OpRewritePattern<ReshapeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ReshapeOp op,
PatternRewriter &rewriter) const final {
if (op.getType() == op->getOperand(0).getType()) {
rewriter.replaceOp(op, op->getOperand(0));
return success();
}
return failure();
}
};

class TTIRReshapeFold : public impl::TTIRReshapeFoldBase<TTIRReshapeFold> {
public:
using impl::TTIRReshapeFoldBase<TTIRReshapeFold>::TTIRReshapeFoldBase;

void runOnOperation() final {
RewritePatternSet patterns(&getContext());
patterns.add<TTIRReshapeFoldingRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
signalPassFailure();
return;
}
}

void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<mlir::tt::ttir::TTIRDialect>();
registry.insert<mlir::tt::TTDialect>();
}
};

} // namespace mlir::tt::ttir
13 changes: 13 additions & 0 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ void createTTNNPipelineDeallocPass(
pm.addPass(createTTNNDeallocate());
}

void createTTIRReshapeFoldPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
pm.addPass(mlir::tt::ttir::createTTIRReshapeFold());
}

void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm,
std::string options) {
auto optionsStruct =
Expand Down Expand Up @@ -107,8 +112,16 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
createTTNNPipelineDeallocPass(pm, *optionsStruct);
}

void createTTIRReshapeFoldPassFromString(OpPassManager &pm,
std::string options) {
auto optionsStruct =
TTIRToTTNNBackendPipelineOptions::createFromString(options);
createTTNNPipelineDeallocPass(pm, *optionsStruct);
}

void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
createTTIRReshapeFoldPass(pm, options);
createTTNNPipelineTTIRPasses(pm, options);
createTTNNPipelineLoweringPasses(pm, options);
createTTNNPipelineAnalysisPasses(pm, options);
Expand Down

0 comments on commit ca82cbe

Please sign in to comment.