diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index c54d734b2..62cfdd945 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -215,7 +215,7 @@ jobs: run-tests: - timeout-minutes: 30 + timeout-minutes: 45 needs: - build-image - build-ttmlir diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 53976e56a..3c53a156a 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -884,6 +884,8 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> { }]; let hasVerifier = 1; + + let hasFolder = 1; } def TTIR_SliceOp: TTIR_DPSOp<"slice"> { diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 4b08f7b6e..96ef7ca01 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -121,7 +121,9 @@ class StableHLOToTTIRReduceOpConversionPattern srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); mlir::ArrayAttr dimArg = rewriter.getArrayAttr(SmallVector( - 1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0]))); + 1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr().size() > 0 + ? adaptor.getDimensionsAttr()[0] + : 1))); // If someone changes definition of TTIR_ReductionOp this constant will // become outdated, but I currently see no way to get this info (without diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index a3a6dd586..bc1f02868 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -389,6 +389,15 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() { return success(); } +// ReshapeOp folder +::mlir::OpFoldResult mlir::tt::ttir::ReshapeOp::fold(FoldAdaptor adaptor) { + + if (getType() == getOperand(0).getType()) { + return getOperand(0); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // SliceOp //===----------------------------------------------------------------------===// diff --git a/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir new file mode 100644 index 000000000..c7f4442f0 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s +#any_device_tile = #tt.operand_constraint +// Tests if we fold when translating from "ttir.reshape" which is called on the two same shapes. +module @reshape_test { + func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1xi32> + %1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK-NOT: %[[C:.*]] = "ttnn.reshape"[C:.*]] + // CHECK: return %arg0 : tensor<1xi32, #{{.*}}> + return %1 : tensor<1xi32> + } +}