From 03d62a50cfeace70a33fda3b27cc4debf02f1586 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Thu, 21 Nov 2024 08:36:37 +0000 Subject: [PATCH] Addressed comments --- lib/Dialect/TTIR/Transforms/Reshape.cpp | 2 +- test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/TTIR/Transforms/Reshape.cpp b/lib/Dialect/TTIR/Transforms/Reshape.cpp index 9f804dae8..3a6b96cd7 100644 --- a/lib/Dialect/TTIR/Transforms/Reshape.cpp +++ b/lib/Dialect/TTIR/Transforms/Reshape.cpp @@ -13,7 +13,7 @@ namespace mlir::tt::ttir { #include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" //===----------------------------------------------------------------------===// -// Constant as fill pass +// Reshape folding pass //===----------------------------------------------------------------------===// class TTIRReshapeFoldingRewriter : public OpRewritePattern { diff --git a/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir b/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir index 8d7944869..a6ac203f4 100644 --- a/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir +++ b/test/ttmlir/Dialect/TTIR/reshape/reshape_test.mlir @@ -1,7 +1,8 @@ // RUN: ttmlir-opt --ttir-reshape-fold %s| FileCheck %s -#any_device_tile = #tt.operand_constraint -module @jit_ravel attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) { +#any_device_tile = #tt.operand_constraint +// Tests if we fold when translating from "ttir.reshape" which is called on the two same shapes. +module @reshape_test { + func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) { %0 = tensor.empty() : tensor<1xi32> %1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: return %arg0 : tensor<1xi32>