Skip to content

Commit

Permalink
Improved docs of ttir.permute and ttnn.permute
Browse files Browse the repository at this point in the history
  • Loading branch information
azecevicTT committed Dec 4, 2024
1 parent 408b36b commit ab5e243
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 11 deletions.
12 changes: 10 additions & 2 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1169,9 +1169,17 @@ def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
// ANCHOR_END: adding_an_op_matmul_ttir

def TTIR_PermuteOp : TTIR_DPSOp<"permute"> {
let summary = "Permute op.";
let summary = "Permute operation.";
let description = [{
Permute tensor.
Permute input tensor dimensions.

Attributes:
- `permutation` array<i64>: The permutation of the input tensor dimensions.

Example:
%a = tensor.empty() : () -> tensor<2x3x4xi32>
%output = tensor.empty() : () -> tensor<3x4x2xi32>
%0 = "ttir.permute"(%a, %output) {permutation = array<i64: 1, 2, 0>} : (tensor<2x3x4xi32>, tensor<3x4x2xi32>) -> tensor<3x4x2xi32>
}];

let arguments = (ins AnyRankedTensor:$input,
Expand Down
11 changes: 9 additions & 2 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -916,9 +916,16 @@ def TTNN_ReduceScatterOp: TTNN_Op<"reduce_scatter"> {
}

def TTNN_PermuteOp : TTNN_Op<"permute"> {
let summary = "Permute op.";
let summary = "Permute operation.";
let description = [{
Permute tensor.
Permute input tensor dimensions.

Attributes:
- `permutation` array<i64>: The permutation of the input tensor dimensions.

Example:
%a = tensor.empty() : () -> tensor<2x3x4xi32>
%0 = "ttir.permute"(%a) {permutation = array<i64: 1, 2, 0>} : (tensor<2x3x4xi32>) -> tensor<3x4x2xi32>
}];

let arguments = (ins AnyRankedTensor:$input,
Expand Down
2 changes: 0 additions & 2 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
#include "llvm/Support/raw_ostream.h"

#include <cassert>
#include <fstream>
#include <mlir/Support/LLVM.h>
#include <optional>

namespace mlir::tt {
Expand Down
4 changes: 2 additions & 2 deletions test/ttmlir/Dialect/TTIR/permute/permute_tests_negative.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ module {
// -----
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module {
func.func @permute_non_valid_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> {
func.func @permute_subset_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> {
// CHECK: error: 'ttir.permute' op Expected a permutation of {k | 0 <= k < 3} got (0, 1)
%0 = tensor.empty() : tensor<16x32x64xbf16>
%1 = "ttir.permute"(%arg0, %0) <{operand_constraints = [#any_device, #any_device], permutation = array<i64: 0, 1>}> : (tensor<16x32x64xbf16>, tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16>
Expand All @@ -27,7 +27,7 @@ module {
// -----
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module {
func.func @permute_non_valid_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> {
func.func @permute_non_valid_shape(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> {
// CHECK: error: 'ttir.permute' op Expected result shape (16, 64, 32), got (16, 32, 64)
%0 = tensor.empty() : tensor<16x32x64xbf16>
%1 = "ttir.permute"(%arg0, %0) <{operand_constraints = [#any_device, #any_device], permutation = array<i64: 0, 2, 1>}> : (tensor<16x32x64xbf16>, tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16>
Expand Down
4 changes: 2 additions & 2 deletions test/ttmlir/Dialect/TTNN/permute/permute_tests_negative.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module {
// -----
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module {
func.func @permute_non_valid_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> {
func.func @permute_subset_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> {
// CHECK: error: 'ttnn.permute' op Expected a permutation of {k | 0 <= k < 3} got (0, 1)
%0 = "ttnn.permute"(%arg0) <{permutation = array<i64: 0, 1>}> : (tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16>
return %0 : tensor<16x32x64xbf16>
Expand All @@ -25,7 +25,7 @@ module {
// -----
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module {
func.func @permute_non_valid_permutation(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> {
func.func @permute_non_valid_shape(%arg0: tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16> {
// CHECK: error: 'ttnn.permute' op Expected result shape (16, 64, 32), got (16, 32, 64)
%0 = "ttnn.permute"(%arg0) <{permutation = array<i64: 0, 2, 1>}> : (tensor<16x32x64xbf16>) -> tensor<16x32x64xbf16>
return %0 : tensor<16x32x64xbf16>
Expand Down
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/permute/simple_permute.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module {
func.func @forward(%arg0: tensor<1x4x32x64xf32>) -> tensor<4x32x64x1xf32> {
func.func @permute(%arg0: tensor<1x4x32x64xf32>) -> tensor<4x32x64x1xf32> {
%0 = tensor.empty() : tensor<4x32x64x1xf32>
// CHECK: "ttnn.permute"
// CHECK-SAME: permutation = array<i64: 1, 2, 3, 0>
Expand Down

0 comments on commit ab5e243

Please sign in to comment.