Fixing reshape op so it supports reshaping of scalars #1322

Merged · 15 commits · Dec 4, 2024
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -112,4 +112,11 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
];
}

def TTIRReshapeFold: Pass<"ttir-reshape-fold", "::mlir::ModuleOp"> {
  let summary = "Folds reshape ops whose output shape matches their input shape.";
  let description = [{
    This pass folds away ttir.reshape ops that reshape a tensor to its
    existing shape, replacing each such op with its input operand.
  }];
}

#endif
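For illustration, a minimal before/after sketch of the fold this pass performs, modeled on the lit test at the bottom of this diff (the operand_constraints attribute is elided; %0 is the reshape's empty output tensor):

// Before: the reshape's result type equals its operand type.
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
return %1 : tensor<1xi32>

// After: the op is erased and its uses are forwarded to %arg0.
return %arg0 : tensor<1xi32>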
8 changes: 6 additions & 2 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -119,8 +119,12 @@ class StableHLOToTTIRReduceOpConversionPattern
    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

-   mlir::ArrayAttr dimArg = rewriter.getArrayAttr(SmallVector<Attribute>(
-       1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0])));
+   // The dimensions attribute may be empty (e.g. when reducing a scalar
+   // input); in that case leave dimArg null instead of indexing into it.
+   mlir::ArrayAttr dimArg =
+       adaptor.getDimensionsAttr().size() > 0
+           ? rewriter.getArrayAttr(SmallVector<Attribute>(
+                 1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0])))
+           : nullptr;

// If someone changes definition of TTIR_ReductionOp this constant will
// become outdated, but I currently see no way to get this info (without
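The new empty-dimensions branch above covers reductions whose dimensions attribute is the empty list, as happens when the reduce input is already a scalar. A hypothetical StableHLO input of that form (op syntax, names, and shapes are illustrative assumptions, not taken from this PR's tests):

// Assumed example: reducing a 0-d tensor leaves no dimensions to reduce over.
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%0 = stablehlo.reduce(%arg0 init: %cst) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>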
1 change: 1 addition & 0 deletions lib/Dialect/TTIR/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTTIRTransforms
Constant.cpp
Generic.cpp
Layout.cpp
Reshape.cpp
Transforms.cpp
Utility.cpp

53 changes: 53 additions & 0 deletions lib/Dialect/TTIR/Transforms/Reshape.cpp
@@ -0,0 +1,53 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>

namespace mlir::tt::ttir {
#define GEN_PASS_DEF_TTIRRESHAPEFOLD
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Reshape folding pass
//===----------------------------------------------------------------------===//

class TTIRReshapeFoldingRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  // Fold a reshape whose result type equals its input type by forwarding
  // the operand to all users.
  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const final {
    if (op.getType() == op->getOperand(0).getType()) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};

class TTIRReshapeFold : public impl::TTIRReshapeFoldBase<TTIRReshapeFold> {
public:
  using impl::TTIRReshapeFoldBase<TTIRReshapeFold>::TTIRReshapeFoldBase;

  void runOnOperation() final {
    RewritePatternSet patterns(&getContext());
    patterns.add<TTIRReshapeFoldingRewriter>(&getContext());
    FrozenRewritePatternSet patternSet(std::move(patterns));
    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
      signalPassFailure();
      return;
    }
  }

  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::tt::ttir::TTIRDialect>();
    registry.insert<mlir::tt::TTDialect>();
  }
};

} // namespace mlir::tt::ttir
13 changes: 13 additions & 0 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -72,6 +72,11 @@ void createTTNNPipelineDeallocPass(
  pm.addPass(createTTNNDeallocate());
}

void createTTIRReshapeFoldPass(
    OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
  pm.addPass(mlir::tt::ttir::createTTIRReshapeFold());
}

void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm,
                                            std::string options) {
  auto optionsStruct =
@@ -107,8 +112,16 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
  createTTNNPipelineDeallocPass(pm, *optionsStruct);
}

void createTTIRReshapeFoldPassFromString(OpPassManager &pm,
                                         std::string options) {
  auto optionsStruct =
      TTIRToTTNNBackendPipelineOptions::createFromString(options);
  createTTIRReshapeFoldPass(pm, *optionsStruct);
}

void createTTIRToTTNNBackendPipeline(
    OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
  createTTIRReshapeFoldPass(pm, options);
  createTTNNPipelineTTIRPasses(pm, options);
  createTTNNPipelineLoweringPasses(pm, options);
  createTTNNPipelineAnalysisPasses(pm, options);
Expand Down
11 changes: 11 additions & 0 deletions test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir
@@ -0,0 +1,11 @@
// RUN: ttmlir-opt --ttir-reshape-fold %s | FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
// Tests that a "ttir.reshape" op whose output shape matches its input shape is folded away.
module @reshape_test {
  func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) {
    %0 = tensor.empty() : tensor<1xi32>
    %1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
    // CHECK: return %arg0 : tensor<1xi32>
    return %1 : tensor<1xi32>
  }
}