Skip to content

Commit

Permalink
Addressed comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ajakovljevicTT committed Nov 25, 2024
1 parent 850686a commit 03d62a5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion lib/Dialect/TTIR/Transforms/Reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace mlir::tt::ttir {
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Constant as fill pass
// Reshape folding pass
//===----------------------------------------------------------------------===//

class TTIRReshapeFoldingRewriter : public OpRewritePattern<ReshapeOp> {
Expand Down
7 changes: 4 additions & 3 deletions test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// RUN: ttmlir-opt --ttir-reshape-fold %s| FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>
module @jit_ravel attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) {
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
// Tests if we fold when translating from "ttir.reshape" which is called on the two same shapes.
module @reshape_test {
func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) {
%0 = tensor.empty() : tensor<1xi32>
%1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: return %arg0 : tensor<1xi32>
Expand Down

0 comments on commit 03d62a5

Please sign in to comment.