From 19b2cee389b02b3497021cebc6e081b70f4ed34c Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Mon, 18 Nov 2024 12:44:10 +0000 Subject: [PATCH 01/13] Reshape fix --- .../StableHLOToTTIR/StableHLOToTTIRPatterns.cpp | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 1ec8556cf..b65beefee 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -19,6 +19,7 @@ #include "ttmlir/Dialect/TTIR/IR/TTIROps.h" #include +#include #include #include #include @@ -220,6 +221,11 @@ class StableHLOToTTIRReshapeOpConversionPattern SmallVector(adaptor.getOperands().size() + 1, rewriter.getAttr( OperandConstraint::AnyDeviceTile)))); + checkForArgumentsAndReplace(srcOp, new_reshape_op); + rewriter.replaceOp(srcOp, new_reshape_op); + if (new_reshape_op.getType() == new_reshape_op->getOperand(0).getType()) { + rewriter.replaceOp(new_reshape_op, new_reshape_op->getOperand(0)); + } return success(); } @@ -235,6 +241,17 @@ class StableHLOToTTIRReshapeOpConversionPattern return success(); } + +private: + void + checkForArgumentsAndReplace(mlir::stablehlo::ReshapeOp &srcOp, + mlir::tt::ttir::ReshapeOp &newReshapeOp) const { + if (auto blockArg = mlir::cast(srcOp->getOperand(0))) { + newReshapeOp.setOperand( + 0, srcOp->getParentOfType().getArgument( + blockArg.getArgNumber())); + } + } }; class StableHLOToTTIRDotGeneralOpConversionPattern From 31bea8c165384f765648d7d2353818a409c368f1 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Tue, 19 Nov 2024 10:07:18 +0000 Subject: [PATCH 02/13] small fixes --- .../StableHLOToTTIR/StableHLOToTTIRPatterns.cpp | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index b65beefee..29f431415 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -214,15 +214,13 @@ class StableHLOToTTIRReshapeOpConversionPattern new_shape_i32.push_back(static_cast(dim)); } ArrayAttr new_shape_attr = rewriter.getI32ArrayAttr(new_shape_i32); - rewriter.replaceOpWithNewOp( + auto new_reshape_op = rewriter.replaceOpWithNewOp( srcOp, getTypeConverter()->convertType(outputTensor.getType()), adaptor.getOperand(), outputTensor, new_shape_attr, rewriter.getArrayAttr( SmallVector(adaptor.getOperands().size() + 1, rewriter.getAttr( OperandConstraint::AnyDeviceTile)))); - checkForArgumentsAndReplace(srcOp, new_reshape_op); - rewriter.replaceOp(srcOp, new_reshape_op); if (new_reshape_op.getType() == new_reshape_op->getOperand(0).getType()) { rewriter.replaceOp(new_reshape_op, new_reshape_op->getOperand(0)); } @@ -241,17 +239,6 @@ class StableHLOToTTIRReshapeOpConversionPattern return success(); } - -private: - void - checkForArgumentsAndReplace(mlir::stablehlo::ReshapeOp &srcOp, - mlir::tt::ttir::ReshapeOp &newReshapeOp) const { - if (auto blockArg = mlir::cast(srcOp->getOperand(0))) { - newReshapeOp.setOperand( - 0, srcOp->getParentOfType().getArgument( - blockArg.getArgNumber())); - } - } }; class StableHLOToTTIRDotGeneralOpConversionPattern From 20559f6e3f7297b4edc3b6375dddf393072a8215 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Tue, 19 Nov 2024 11:37:52 +0000 Subject: [PATCH 03/13] Addressed comments --- .../StableHLOToTTIRPatterns.cpp | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 29f431415..d7320aaca 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -19,7 +19,6 @@ #include "ttmlir/Dialect/TTIR/IR/TTIROps.h" #include -#include #include #include #include @@ -214,13 +213,17 @@ class StableHLOToTTIRReshapeOpConversionPattern new_shape_i32.push_back(static_cast(dim)); } ArrayAttr new_shape_attr = rewriter.getI32ArrayAttr(new_shape_i32); - auto new_reshape_op = rewriter.replaceOpWithNewOp( - srcOp, getTypeConverter()->convertType(outputTensor.getType()), - adaptor.getOperand(), outputTensor, new_shape_attr, - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + auto new_reshape_op = + rewriter.replaceOpWithNewOp( + srcOp, getTypeConverter()->convertType(outputTensor.getType()), + adaptor.getOperand(), outputTensor, new_shape_attr, + rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + + // If the reshape op is trying to reshape into the same shape, we can omit + // it completely. if (new_reshape_op.getType() == new_reshape_op->getOperand(0).getType()) { rewriter.replaceOp(new_reshape_op, new_reshape_op->getOperand(0)); } From 949ed143f6ff7b26151f5857ebe41e3f65cf1a76 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Tue, 19 Nov 2024 15:19:02 +0000 Subject: [PATCH 04/13] Added the pass to the TTIR decomposition --- .../StableHLOToTTIRPatterns.cpp | 20 +++++++------------ .../TTIRToTTIRDecomposition.cpp | 15 ++++++++++++++ .../TTIRToTTIRDecompositionPass.cpp | 1 - 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index d7320aaca..15fa17792 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -213,20 +213,14 @@ class StableHLOToTTIRReshapeOpConversionPattern new_shape_i32.push_back(static_cast(dim)); } ArrayAttr new_shape_attr = rewriter.getI32ArrayAttr(new_shape_i32); - auto new_reshape_op = - rewriter.replaceOpWithNewOp( - srcOp, getTypeConverter()->convertType(outputTensor.getType()), - adaptor.getOperand(), outputTensor, new_shape_attr, - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + rewriter.replaceOpWithNewOp( + srcOp, getTypeConverter()->convertType(outputTensor.getType()), + adaptor.getOperand(), outputTensor, new_shape_attr, + rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); - // If the reshape op is trying to reshape into the same shape, we can omit - // it completely. - if (new_reshape_op.getType() == new_reshape_op->getOperand(0).getType()) { - rewriter.replaceOp(new_reshape_op, new_reshape_op->getOperand(0)); - } return success(); } diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index 9b8c634ad..9c89dbf41 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -775,6 +775,21 @@ class GetDimensionSizeToConstantConversionPattern } }; +class ReshapeOpOmissionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getType() == op->getOperand(0).getType()) { + rewriter.replaceOp(op, op->getOperand(0)); + return success(); + } + return failure(); + } +}; + void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp index 76cbae96e..18b59ede9 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp @@ -50,7 +50,6 @@ struct TTIRToTTIRDecompositionPass target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); TypeConverter typeConverter; // All types map 1:1. From db3507c1316c161637f6609fdcf76a501b8ffbc0 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Tue, 19 Nov 2024 17:25:48 +0000 Subject: [PATCH 05/13] Add pattern change --- lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 15fa17792..64872da41 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -119,8 +119,8 @@ class StableHLOToTTIRReduceOpConversionPattern tensor::EmptyOp outputTensor = rewriter.create( srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); - mlir::ArrayAttr dimArg = rewriter.getArrayAttr(SmallVector( - 1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0]))); + mlir::ArrayAttr dimArg = adaptor.getDimensionsAttr().size() > 0 ? rewriter.getArrayAttr(SmallVector( + 1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0]))) : 0; // If someone changes definition of TTIR_ReductionOp this constant will // become outdated, but I currently see no way to get this info (without From ca82cbe410c9d72f07b7a64341f5249f70bd58ed Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Wed, 20 Nov 2024 12:08:34 +0000 Subject: [PATCH 06/13] Changed to a transformation pass --- .../ttmlir/Dialect/TTIR/Transforms/Passes.td | 7 +++ .../StableHLOToTTIRPatterns.cpp | 8 ++- .../TTIRToTTIRDecomposition.cpp | 15 ------ lib/Dialect/TTIR/Transforms/CMakeLists.txt | 1 + lib/Dialect/TTIR/Transforms/Reshape.cpp | 53 +++++++++++++++++++ lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp | 13 +++++ 6 files changed, 80 insertions(+), 17 deletions(-) create mode 100644 lib/Dialect/TTIR/Transforms/Reshape.cpp diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index 63ccb0d28..cc11551d7 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -112,4 +112,11 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> { ]; } +def TTIRReshapeFold: Pass<"ttir-reshape-fold", "::mlir::ModuleOp"> { + let summary = "Folds reshape ops that reshape the same shapes."; + let description = [{ + This pass converts folds the ttir.reshape ops that are called to reshape a tensor to the same shape. + }]; +} + #endif diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 64872da41..be6eec4f5 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -119,8 +119,12 @@ class StableHLOToTTIRReduceOpConversionPattern tensor::EmptyOp outputTensor = rewriter.create( srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); - mlir::ArrayAttr dimArg = adaptor.getDimensionsAttr().size() > 0 ? rewriter.getArrayAttr(SmallVector( - 1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0]))) : 0; + mlir::ArrayAttr dimArg = + adaptor.getDimensionsAttr().size() > 0 + ? rewriter.getArrayAttr(SmallVector( + 1, + rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0]))) + : 0; // If someone changes definition of TTIR_ReductionOp this constant will // become outdated, but I currently see no way to get this info (without diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index 9c89dbf41..9b8c634ad 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -775,21 +775,6 @@ class GetDimensionSizeToConstantConversionPattern } }; -class ReshapeOpOmissionPattern : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ttir::ReshapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op.getType() == op->getOperand(0).getType()) { - rewriter.replaceOp(op, op->getOperand(0)); - return success(); - } - return failure(); - } -}; - void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { diff --git a/lib/Dialect/TTIR/Transforms/CMakeLists.txt b/lib/Dialect/TTIR/Transforms/CMakeLists.txt index f5fec45a8..5dab54829 100644 --- a/lib/Dialect/TTIR/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTIR/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTTIRTransforms Constant.cpp Generic.cpp Layout.cpp + Reshape.cpp Transforms.cpp Utility.cpp diff --git a/lib/Dialect/TTIR/Transforms/Reshape.cpp b/lib/Dialect/TTIR/Transforms/Reshape.cpp new file mode 100644 index 000000000..9f804dae8 --- /dev/null +++ b/lib/Dialect/TTIR/Transforms/Reshape.cpp @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TT/IR/TT.h" +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include + +namespace mlir::tt::ttir { +#define GEN_PASS_DEF_TTIRRESHAPEFOLD +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" + +//===----------------------------------------------------------------------===// +// Constant as fill pass +//===----------------------------------------------------------------------===// + +class TTIRReshapeFoldingRewriter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ReshapeOp op, + PatternRewriter &rewriter) const final { + if (op.getType() == op->getOperand(0).getType()) { + rewriter.replaceOp(op, op->getOperand(0)); + return success(); + } + return failure(); + } +}; + +class TTIRReshapeFold : public impl::TTIRReshapeFoldBase { +public: + using impl::TTIRReshapeFoldBase::TTIRReshapeFoldBase; + + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { + signalPassFailure(); + return; + } + } + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } +}; + +} // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index 24980fb7c..88040bd64 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -72,6 +72,11 @@ void createTTNNPipelineDeallocPass( pm.addPass(createTTNNDeallocate()); } +void createTTIRReshapeFoldPass( + OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { + pm.addPass(mlir::tt::ttir::createTTIRReshapeFold()); +} + void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm, std::string options) { auto optionsStruct = @@ -107,8 +112,16 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm, createTTNNPipelineDeallocPass(pm, *optionsStruct); } +void createTTIRReshapeFoldPassFromString(OpPassManager &pm, + std::string options) { + auto optionsStruct = + TTIRToTTNNBackendPipelineOptions::createFromString(options); + createTTNNPipelineDeallocPass(pm, *optionsStruct); +} + void createTTIRToTTNNBackendPipeline( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { + createTTIRReshapeFoldPass(pm, options); createTTNNPipelineTTIRPasses(pm, options); createTTNNPipelineLoweringPasses(pm, options); createTTNNPipelineAnalysisPasses(pm, options); From d6e5e812062537cb1a288f824243e79e3b156c54 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Wed, 20 Nov 2024 13:59:25 +0000 Subject: [PATCH 07/13] Added test --- .../StableHLOToTTIR/StableHLOToTTIRPatterns.cpp | 1 - test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir | 11 +++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index be6eec4f5..9db51169c 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -224,7 +224,6 @@ class StableHLOToTTIRReshapeOpConversionPattern SmallVector(adaptor.getOperands().size() + 1, rewriter.getAttr( OperandConstraint::AnyDeviceTile)))); - return success(); } diff --git a/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir b/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir new file mode 100644 index 000000000..d2a740ca1 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-reshape-fold %s| FileCheck %s +#any_device_tile = #tt.operand_constraint +module @jit_ravel attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1xi32> + %1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: return %arg0 : tensor<1xi32> + return %1 : tensor<1xi32> + } +} + From 072a90ed09ffe72af4c506c4f5f4a87137da297f Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Wed, 20 Nov 2024 14:16:14 +0000 Subject: [PATCH 08/13] Formatting --- test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir | 1 - 1 file changed, 1 deletion(-) diff --git a/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir b/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir index d2a740ca1..8d7944869 100644 --- a/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir +++ b/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir @@ -8,4 +8,3 @@ module @jit_ravel attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = return %1 : tensor<1xi32> } } - From 850686a8e39a7748b28b4da1486f793c7d20541e Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Wed, 20 Nov 2024 14:18:33 +0000 Subject: [PATCH 09/13] Address comments --- .../TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp index 18b59ede9..76cbae96e 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp @@ -50,6 +50,7 @@ struct TTIRToTTIRDecompositionPass target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); TypeConverter typeConverter; // All types map 1:1. From 03d62a50cfeace70a33fda3b27cc4debf02f1586 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Thu, 21 Nov 2024 08:36:37 +0000 Subject: [PATCH 10/13] Addressed comments --- lib/Dialect/TTIR/Transforms/Reshape.cpp | 2 +- test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/TTIR/Transforms/Reshape.cpp b/lib/Dialect/TTIR/Transforms/Reshape.cpp index 9f804dae8..3a6b96cd7 100644 --- a/lib/Dialect/TTIR/Transforms/Reshape.cpp +++ b/lib/Dialect/TTIR/Transforms/Reshape.cpp @@ -13,7 +13,7 @@ namespace mlir::tt::ttir { #include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" //===----------------------------------------------------------------------===// -// Constant as fill pass +// Reshape folding pass //===----------------------------------------------------------------------===// class TTIRReshapeFoldingRewriter : public OpRewritePattern { diff --git a/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir b/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir index 8d7944869..a6ac203f4 100644 --- a/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir +++ b/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir @@ -1,7 +1,8 @@ // RUN: ttmlir-opt --ttir-reshape-fold %s| FileCheck %s -#any_device_tile = #tt.operand_constraint -module @jit_ravel attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) { +#any_device_tile = #tt.operand_constraint +// Tests if we fold when translating from "ttir.reshape" which is called on the two same shapes. +module @reshape_test { + func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) { %0 = tensor.empty() : tensor<1xi32> %1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: return %arg0 : tensor<1xi32> From 57364dc534085f874c75e32b393ad24a1904a1e6 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Fri, 29 Nov 2024 16:25:21 +0000 Subject: [PATCH 11/13] Added folding --- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 2 + .../ttmlir/Dialect/TTIR/Transforms/Passes.td | 7 --- lib/Dialect/TTIR/IR/TTIROps.cpp | 9 ++++ lib/Dialect/TTIR/Transforms/CMakeLists.txt | 1 - lib/Dialect/TTIR/Transforms/Reshape.cpp | 53 ------------------- lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp | 13 ----- .../reshape/reshape_folding_test.mlir} | 5 +- 7 files changed, 14 insertions(+), 76 deletions(-) delete mode 100644 lib/Dialect/TTIR/Transforms/Reshape.cpp rename test/ttmlir/Dialect/{TTIR/reshape/reshape_test.mlir => TTNN/reshape/reshape_folding_test.mlir} (75%) diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 489fd2faa..506d07ce0 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -899,6 +899,8 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> { }]; let hasVerifier = 1; + + let hasFolder = 1; } def TTIR_SliceOp: TTIR_DPSOp<"slice"> { diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index cc11551d7..63ccb0d28 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -112,11 +112,4 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> { ]; } -def TTIRReshapeFold: Pass<"ttir-reshape-fold", "::mlir::ModuleOp"> { - let summary = "Folds reshape ops that reshape the same shapes."; - let description = [{ - This pass converts folds the ttir.reshape ops that are called to reshape a tensor to the same shape. - }]; -} - #endif diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index ec415e090..feedc845c 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -358,6 +358,15 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() { return success(); } +// ReshapeOp folder +::mlir::OpFoldResult mlir::tt::ttir::ReshapeOp::fold(FoldAdaptor adaptor) { + + if (getType() == getOperand(0).getType()) { + return getOperand(0); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // SliceOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTIR/Transforms/CMakeLists.txt b/lib/Dialect/TTIR/Transforms/CMakeLists.txt index 5dab54829..f5fec45a8 100644 --- a/lib/Dialect/TTIR/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTIR/Transforms/CMakeLists.txt @@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRTTIRTransforms Constant.cpp Generic.cpp Layout.cpp - Reshape.cpp Transforms.cpp Utility.cpp diff --git a/lib/Dialect/TTIR/Transforms/Reshape.cpp b/lib/Dialect/TTIR/Transforms/Reshape.cpp deleted file mode 100644 index 3a6b96cd7..000000000 --- a/lib/Dialect/TTIR/Transforms/Reshape.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttmlir/Dialect/TT/IR/TT.h" -#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" - -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include - -namespace mlir::tt::ttir { -#define GEN_PASS_DEF_TTIRRESHAPEFOLD -#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" - -//===----------------------------------------------------------------------===// -// Reshape folding pass -//===----------------------------------------------------------------------===// - -class TTIRReshapeFoldingRewriter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ReshapeOp op, - PatternRewriter &rewriter) const final { - if (op.getType() == op->getOperand(0).getType()) { - rewriter.replaceOp(op, op->getOperand(0)); - return success(); - } - return failure(); - } -}; - -class TTIRReshapeFold : public impl::TTIRReshapeFoldBase { -public: - using impl::TTIRReshapeFoldBase::TTIRReshapeFoldBase; - - void runOnOperation() final { - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - FrozenRewritePatternSet patternSet(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { - signalPassFailure(); - return; - } - } - - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - registry.insert(); - } -}; - -} // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index 88040bd64..24980fb7c 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -72,11 +72,6 @@ void createTTNNPipelineDeallocPass( pm.addPass(createTTNNDeallocate()); } -void createTTIRReshapeFoldPass( - OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { - pm.addPass(mlir::tt::ttir::createTTIRReshapeFold()); -} - void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm, std::string options) { auto optionsStruct = @@ -112,16 +107,8 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm, createTTNNPipelineDeallocPass(pm, *optionsStruct); } -void createTTIRReshapeFoldPassFromString(OpPassManager &pm, - std::string options) { - auto optionsStruct = - TTIRToTTNNBackendPipelineOptions::createFromString(options); - createTTNNPipelineDeallocPass(pm, *optionsStruct); -} - void createTTIRToTTNNBackendPipeline( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { - createTTIRReshapeFoldPass(pm, options); createTTNNPipelineTTIRPasses(pm, options); createTTNNPipelineLoweringPasses(pm, options); createTTNNPipelineAnalysisPasses(pm, options); diff --git a/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir similarity index 75% rename from test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir rename to test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir index a6ac203f4..48b93444c 100644 --- a/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir +++ b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir @@ -1,11 +1,12 @@ -// RUN: ttmlir-opt --ttir-reshape-fold %s| FileCheck %s +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s #any_device_tile = #tt.operand_constraint // Tests if we fold when translating from "ttir.reshape" which is called on the two same shapes. module @reshape_test { func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) { %0 = tensor.empty() : tensor<1xi32> %1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - // CHECK: return %arg0 : tensor<1xi32> + // CHECK: return %arg0 : tensor<1xi32, #{{.*}}> + // CHECK-NOT: %[[C:.*]] = "ttnn.reshape"[C:.*]] return %1 : tensor<1xi32> } } From 14494e4d8887e5f60193b17ec288ecce414fc21b Mon Sep 17 00:00:00 2001 From: Vladimir Milosevic <157983820+vmilosevic@users.noreply.github.com> Date: Wed, 4 Dec 2024 11:09:33 +0100 Subject: [PATCH 12/13] Increase timeout to 45min --- .github/workflows/build-and-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 8ec0c93dc..e0cc7fb97 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -215,7 +215,7 @@ jobs: run-tests: - timeout-minutes: 30 + timeout-minutes: 45 needs: - build-image - build-ttmlir From 2eae67d57fe0a0f97ea16de55e5aa4e719fdca3e Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Wed, 4 Dec 2024 13:12:07 +0000 Subject: [PATCH 13/13] Addressed comments --- test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir index 48b93444c..c7f4442f0 100644 --- a/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir +++ b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir @@ -5,8 +5,8 @@ module @reshape_test { func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) { %0 = tensor.empty() : tensor<1xi32> %1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - // CHECK: return %arg0 : tensor<1xi32, #{{.*}}> // CHECK-NOT: %[[C:.*]] = "ttnn.reshape"[C:.*]] + // CHECK: return %arg0 : tensor<1xi32, #{{.*}}> return %1 : tensor<1xi32> } }