diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index aac8975e4..60d2b9664 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -1169,9 +1169,17 @@ def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> { // ANCHOR_END: adding_an_op_matmul_ttir def TTIR_PermuteOp : TTIR_DPSOp<"permute"> { - let summary = "Permute op."; + let summary = "Permute operation."; let description = [{ - Permute tensor. + Permute input tensor dimensions. + + Attributes: + - `permutation` array: The permutation of the input tensor dimensions. + + Example: + %a = tensor.empty() : () -> tensor<2x3x4xi32> + %output = tensor.empty() : () -> tensor<3x4x2xi32> + %0 = "ttir.permute"(%a, %output) {permutation = array} : (tensor<2x3x4xi32>, tensor<3x4x2xi32>) -> tensor<3x4x2xi32> }]; let arguments = (ins AnyRankedTensor:$input, diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 876f8c349..ef1808ab0 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -916,9 +916,16 @@ def TTNN_ReduceScatterOp: TTNN_Op<"reduce_scatter"> { } def TTNN_PermuteOp : TTNN_Op<"permute"> { - let summary = "Permute op."; + let summary = "Permute operation."; let description = [{ - Permute tensor. + Permute input tensor dimensions. + + Attributes: + - `permutation` array: The permutation of the input tensor dimensions. + + Example: + %a = tensor.empty() : () -> tensor<2x3x4xi32> + %0 = "ttir.permute"(%a) {permutation = array} : (tensor<2x3x4xi32>) -> tensor<3x4x2xi32> }]; let arguments = (ins AnyRankedTensor:$input, diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 04357033d..3857f5950 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -34,8 +34,6 @@ #include "llvm/Support/raw_ostream.h" #include -#include -#include #include namespace mlir::tt { diff --git a/test/ttmlir/Dialect/TTIR/permute/permute_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/permute/permute_tests_negative.mlir index d000b5161..8a7529936 100644 --- a/test/ttmlir/Dialect/TTIR/permute/permute_tests_negative.mlir +++ b/test/ttmlir/Dialect/TTIR/permute/permute_tests_negative.mlir @@ -15,7 +15,7 @@ module { // ----- #any_device = #tt.operand_constraint module { - func.func @permute_non_valid_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { + func.func @permute_subset_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { // CHECK: error: 'ttir.permute' op Expected a permutation of {k | 0 <= k < 3} got (0, 1) %0 = tensor.empty() : tensor<16x32x64xbf16> %1 = "ttir.permute"(%arg0, %0) <{operand_constraints = [#any_device, #any_device], permutation = array}> : (tensor<16x32x64xbf16>, tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> @@ -27,7 +27,7 @@ module { // ----- #any_device = #tt.operand_constraint module { - func.func @permute_non_valid_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { + func.func @permute_non_valid_shape(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { // CHECK: error: 'ttir.permute' op Expected result shape (16, 64, 32), got (16, 32, 64) %0 = tensor.empty() : tensor<16x32x64xbf16> %1 = "ttir.permute"(%arg0, %0) <{operand_constraints = [#any_device, #any_device], permutation = array}> : (tensor<16x32x64xbf16>, tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> diff --git a/test/ttmlir/Dialect/TTNN/permute/permute_tests_negative.mlir b/test/ttmlir/Dialect/TTNN/permute/permute_tests_negative.mlir index 6838d1664..c0dfba63a 100644 --- a/test/ttmlir/Dialect/TTNN/permute/permute_tests_negative.mlir +++ b/test/ttmlir/Dialect/TTNN/permute/permute_tests_negative.mlir @@ -14,7 +14,7 @@ module { // ----- #any_device = #tt.operand_constraint module { - func.func @permute_non_valid_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { + func.func @permute_subset_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { // CHECK: error: 'ttnn.permute' op Expected a permutation of {k | 0 <= k < 3} got (0, 1) %0 = "ttnn.permute"(%arg0) <{permutation = array}> : (tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> return %0 : tensor<16x32x64xbf16> @@ -25,7 +25,7 @@ module { // ----- #any_device = #tt.operand_constraint module { - func.func @permute_non_valid_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { + func.func @permute_non_valid_shape(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> { // CHECK: error: 'ttnn.permute' op Expected result shape (16, 64, 32), got (16, 32, 64) %0 = "ttnn.permute"(%arg0) <{permutation = array}> : (tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> return %0 : tensor<16x32x64xbf16> diff --git a/test/ttmlir/Dialect/TTNN/permute/simple_permute.mlir b/test/ttmlir/Dialect/TTNN/permute/simple_permute.mlir index 81ba6f483..bf514edb8 100644 --- a/test/ttmlir/Dialect/TTNN/permute/simple_permute.mlir +++ b/test/ttmlir/Dialect/TTNN/permute/simple_permute.mlir @@ -1,7 +1,7 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s #any_device = #tt.operand_constraint module { - func.func @forward(%arg0: tensor<1x4x32x64xf32>) -> tensor<4x32x64x1xf32> { + func.func @permute(%arg0: tensor<1x4x32x64xf32>) -> tensor<4x32x64x1xf32> { %0 = tensor.empty() : tensor<4x32x64x1xf32> // CHECK: "ttnn.permute" // CHECK-SAME: permutation = array