Skip to content

Commit

Permalink
Fixing reshape op so it supports reshaping of scalars (#1322)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajakovljevicTT authored Dec 4, 2024
1 parent 0aab97b commit 824b256
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ jobs:

run-tests:

timeout-minutes: 30
timeout-minutes: 45
needs:
- build-image
- build-ttmlir
Expand Down
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,8 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> {
}];

let hasVerifier = 1;

let hasFolder = 1;
}

def TTIR_SliceOp: TTIR_DPSOp<"slice"> {
Expand Down
4 changes: 3 additions & 1 deletion lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ class StableHLOToTTIRReduceOpConversionPattern
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

mlir::ArrayAttr dimArg = rewriter.getArrayAttr(SmallVector<Attribute>(
1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr()[0])));
1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr().size() > 0
? adaptor.getDimensionsAttr()[0]
: 1)));

// If someone changes definition of TTIR_ReductionOp this constant will
// become outdated, but I currently see no way to get this info (without
Expand Down
9 changes: 9 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,15 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() {
return success();
}

// ReshapeOp folder
::mlir::OpFoldResult mlir::tt::ttir::ReshapeOp::fold(FoldAdaptor adaptor) {

if (getType() == getOperand(0).getType()) {
return getOperand(0);
}
return nullptr;
}

//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
// Tests if we fold when translating from "ttir.reshape" which is called on the two same shapes.
module @reshape_test {
func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) {
%0 = tensor.empty() : tensor<1xi32>
%1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK-NOT: %[[C:.*]] = "ttnn.reshape"[C:.*]]
// CHECK: return %arg0 : tensor<1xi32, #{{.*}}>
return %1 : tensor<1xi32>
}
}

0 comments on commit 824b256

Please sign in to comment.