From b663db2d0c6d57a56c2178109e4604fad5d0ab78 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Wed, 4 Dec 2024 16:26:10 +0100 Subject: [PATCH] Fixing reshape op so it supports reshaping of scalars (#1322) --- .github/workflows/build-and-test.yml | 2 +- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 2 ++ .../StableHLOToTTIR/StableHLOToTTIRPatterns.cpp | 4 +++- lib/Dialect/TTIR/IR/TTIROps.cpp | 9 +++++++++ .../Dialect/TTNN/reshape/reshape_folding_test.mlir | 12 ++++++++++++ 5 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index c54d734b23..62cfdd9455 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -215,7 +215,7 @@ jobs: run-tests: - timeout-minutes: 30 + timeout-minutes: 45 needs: - build-image - build-ttmlir diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 60d2b96643..b4881175ff 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -884,6 +884,8 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> { }]; let hasVerifier = 1; + + let hasFolder = 1; } def TTIR_SliceOp: TTIR_DPSOp<"slice"> { diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 6a3cfbf89c..c2b83bb542 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -121,7 +121,9 @@ class StableHLOToTTIRReduceOpConversionPattern srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); mlir::ArrayAttr dimArg = rewriter.getArrayAttr(SmallVector( - 1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0]))); + 1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr().size() > 0 + ? adaptor.getDimensionsAttr()[0] + : 1))); // If someone changes definition of TTIR_ReductionOp this constant will // become outdated, but I currently see no way to get this info (without diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 2f08c0d414..73cda8fd57 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -390,6 +390,15 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() { return success(); } +// ReshapeOp folder +::mlir::OpFoldResult mlir::tt::ttir::ReshapeOp::fold(FoldAdaptor adaptor) { + + if (getType() == getOperand(0).getType()) { + return getOperand(0); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // SliceOp //===----------------------------------------------------------------------===// diff --git a/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir new file mode 100644 index 0000000000..c7f4442f0b --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s +#any_device_tile = #tt.operand_constraint +// Tests if we fold when translating from "ttir.reshape" which is called on the two same shapes. +module @reshape_test { + func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) { + %0 = tensor.empty() : tensor<1xi32> + %1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK-NOT: %[[C:.*]] = "ttnn.reshape"[C:.*]] + // CHECK: return %arg0 : tensor<1xi32, #{{.*}}> + return %1 : tensor<1xi32> + } +}