diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index aeb2de1aed..50de8824dd 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -172,8 +172,12 @@ def TTIR_DeallocOp : TTIR_Op<"dealloc"> {
 // TTIR top level named ops
 //===----------------------------------------------------------------------===//
 
+def 2Operands : ParamNativeOpTrait<"NOperands", "2">;
+def 3Operands : ParamNativeOpTrait<"NOperands", "3">;
+def 4Operands : ParamNativeOpTrait<"NOperands", "4">;
+
 class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
-    TTIR_DPSOp<mnemonic, !listconcat(traits, [AttrSizedOperandSegments, TTIR_ElementwiseOpInterface])> {
+    TTIR_DPSOp<mnemonic, !listconcat(traits, [TTIR_Broadcastable])> {
 
     let description = [{
       Base class for elementwise operations. Elementwise operations can take inputs with different shape,
@@ -187,7 +191,7 @@ class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
 }
 
 class TTIR_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
-    TTIR_ElementwiseOp<mnemonic, traits> {
+    TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [4Operands])> {
     let summary = "Eltwise ternary op.";
     let description = [{
       Eltwise ternary op.
@@ -210,7 +214,7 @@ def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> {
 }
 
 class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
-    TTIR_ElementwiseOp<mnemonic, traits> {
+    TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [2Operands])> {
     let summary = "Eltwise unary op.";
     let description = [{
       Eltwise unary op.
@@ -424,7 +428,7 @@ def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> {
 }
 
 class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
-    TTIR_ElementwiseOp<mnemonic, traits> {
+    TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [3Operands])> {
     let summary = "Eltwise binary op.";
     let description = [{
       Eltwise binary op.
@@ -1196,11 +1200,10 @@ class TTIR_GenericElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
     void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);
 
     std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) {
-      assert(getNumOperands() == 2 && "Input and output operand must have the same rank");
-      assert(sameRank(getOperands()) &&
-             "Elementwise unary op must have only one input and one output operand.");
+      assert(sameRank(getOperation()->getOperands()) &&
+             "Input and output operand must have the same rank");
 
-      auto rank = mlir::cast<RankedTensorType>(getOperand(0).getType()).getRank();
+      auto rank = mlir::cast<RankedTensorType>(getOperation()->getOperand(0).getType()).getRank();
 
       SmallVector<mlir::AffineMap> indexingMaps(2, builder.getMultiDimIdentityMap(rank));
       SmallVector<mlir::Attribute> iteratorTypes(
@@ -1209,19 +1212,6 @@ class TTIR_GenericElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
       return {builder.getAffineMapArrayAttr(indexingMaps),
               builder.getArrayAttr(iteratorTypes)};
     }
-
-    static bool sameRank(mlir::OperandRange operands) {
-      if (operands.empty()) {
-        return true;
-      }
-      auto rank = mlir::cast<RankedTensorType>(operands[0].getType()).getRank();
-      for (auto operand : operands) {
-        if (mlir::cast<RankedTensorType>(operand.getType()).getRank() != rank) {
-          return false;
-        }
-      }
-      return true;
-    }
   }];
 }
@@ -1241,29 +1231,16 @@ class TTIR_GenericElementwiseBinaryOp<string mnemonic, list<Trait> traits = []>
     void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);
 
     std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) {
-      assert(sameRank(getOperands()) &&
+      assert(sameRank(getOperation()->getOperands()) &&
             "For now all operands must have the same rank");
 
-      auto rank = mlir::cast<RankedTensorType>(getOperand(0).getType()).getRank();
-      SmallVector<mlir::AffineMap> indexingMaps(getNumOperands(),
+      auto rank = mlir::cast<RankedTensorType>(getOperation()->getOperand(0).getType()).getRank();
+      SmallVector<mlir::AffineMap> indexingMaps(getOperation()->getNumOperands(),
                                                 builder.getMultiDimIdentityMap(rank));
       SmallVector<mlir::Attribute> iteratorTypes(
           rank, builder.getAttr<IteratorTypeAttr>(IteratorType::Parallel));
 
       return {builder.getAffineMapArrayAttr(indexingMaps),
               builder.getArrayAttr(iteratorTypes)};
     }
-
-    static bool sameRank(mlir::OperandRange operands) {
-      if (operands.empty()) {
-        return true;
-      }
-      auto rank = mlir::cast<RankedTensorType>(operands[0].getType()).getRank();
-      for (auto operand : operands) {
-        if (mlir::cast<RankedTensorType>(operand.getType()).getRank() != rank) {
-          return false;
-        }
-      }
-      return true;
-    }
   }];
 }
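Context for the TableGen changes above: `ParamNativeOpTrait<"NOperands", "N">` binds to the upstream `mlir::OpTrait::NOperands<N>::Impl` trait, and an op interface named `Broadcastable` attaches to ops as `Broadcastable::Trait`. A rough sketch of the shape of the mlir-tblgen output for an elementwise binary op (abbreviated and assumed, not the literal generated code):

```cpp
// Sketch: trait list of a generated elementwise binary op such as ttir.add.
// 3Operands          -> mlir::OpTrait::NOperands<3>::Impl (2 inputs + 1 DPS init)
// TTIR_Broadcastable -> mlir::tt::ttir::Broadcastable::Trait, which now pulls
//                       in AttrSizedOperandSegments as a dependent trait.
class AddOp : public mlir::Op<AddOp,
                              mlir::OpTrait::NOperands<3>::Impl,
                              mlir::tt::ttir::Broadcastable::Trait,
                              mlir::OpTrait::AttrSizedOperandSegments
                              /* DPS and other traits elided */> {
  // body generated by mlir-tblgen ...
};
```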
diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h
index 1d88e8a657..01b6772972 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h
@@ -12,7 +12,7 @@ namespace mlir {
 namespace tt {
 namespace ttir {
 namespace detail {
-mlir::LogicalResult verifyElementwiseOp(mlir::Operation *op);
+mlir::LogicalResult verifyBroadcastable(mlir::Operation *op);
 } // namespace detail
 } // namespace ttir
 } // namespace tt
diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td
index cbc0056737..a130332f0d 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td
@@ -64,11 +64,13 @@ def TTIROpInterface : OpInterface<"TTIROp"> {
   ];
 }
 
-def TTIR_ElementwiseOpInterface : OpInterface<"ElementwiseOp"> {
+def TTIR_Broadcastable : OpInterface<"Broadcastable"> {
   let cppNamespace = "::mlir::tt::ttir";
 
+  let dependentTraits = [AttrSizedOperandSegments];
+
   let verify = [{
-    return detail::verifyElementwiseOp($_op);
+    return detail::verifyBroadcastable($_op);
   }];
 }
@@ -105,6 +107,20 @@ def TTIR_GenericRegionOpInterface : OpInterface<"GenericRegionOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/""
     >,
+    StaticInterfaceMethod<
+      /*desc=*/[{
+        Return if the given operands have the same rank.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"sameRank",
+      /*args=*/(ins "::mlir::OperandRange":$operands),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return llvm::all_equal(llvm::map_range(operands, [](Value operand) {
+          return mlir::cast<RankedTensorType>(operand.getType()).getRank();
+        }));
+      }]
+    >
   ];
 }
diff --git a/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp b/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp
index 84409174a3..10619f24b8 100644
--- a/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp
@@ -17,37 +17,33 @@
 #include "llvm/ADT/SmallVector.h"
 
 mlir::LogicalResult
-mlir::tt::ttir::detail::verifyElementwiseOp(mlir::Operation *op) {
+mlir::tt::ttir::detail::verifyBroadcastable(mlir::Operation *op) {
+  const auto getShape = [](const Value val) {
+    return mlir::cast<mlir::RankedTensorType>(val.getType()).getShape();
+  };
+
+  const auto operandSegmentSizes =
+      op->getAttrOfType<mlir::DenseI32ArrayAttr>("operandSegmentSizes");
+  // DPS operands shouldn't affect the result shape.
+  const auto outputSegmentSize =
+      operandSegmentSizes[operandSegmentSizes.size() - 1];
+  const auto operandShapes = llvm::map_range(op->getOperands(), getShape);
   llvm::SmallVector<int64_t> broadcastedShape;
-  mlir::OperandRange operands = op->getOperands();
-  mlir::OperandRange::iterator operand_it = operands.begin();
-  llvm::SmallVector<int64_t> prevOperandShape(
-      mlir::cast<mlir::RankedTensorType>((*operand_it).getType()).getShape());
-
-  while (++operand_it != operands.end()) {
-    llvm::SmallVector<int64_t> nextOperandShape(
-        mlir::cast<mlir::RankedTensorType>((*operand_it).getType()).getShape());
-
-    if (!OpTrait::util::getBroadcastedShape(prevOperandShape, nextOperandShape,
+  for (const auto operandShape :
+       llvm::drop_end(operandShapes, outputSegmentSize)) {
+    const auto prevBroadcastedShape = broadcastedShape;
+    if (!OpTrait::util::getBroadcastedShape(prevBroadcastedShape, operandShape,
                                             broadcastedShape)) {
       return op->emitOpError("Operands are not broadcast compatible");
     }
-    prevOperandShape = broadcastedShape;
   }
 
-  llvm::SmallVector<int64_t> resultShape(
-      mlir::cast<mlir::RankedTensorType>(op->getResult(0).getType())
-          .getShape());
+  // Check that the result shape matches the broadcasted shape of the operands.
+  llvm::SmallVector<int64_t> resultShape(getShape(op->getResults().front()));
   if (broadcastedShape != resultShape) {
     return op->emitOpError(
         "Result shape must match operand shapes after broadcasting");
   }
 
-  TypeID expectedBaseTy = op->getResultTypes().front().getTypeID();
-  if (!llvm::all_of(op->getOperandTypes(),
-                    [&](Type t) { return t.getTypeID() == expectedBaseTy; })) {
-    return op->emitOpError() << "All operands/results must have the same type";
-  }
-
   return success();
 }
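For reference: `OpTrait::util::getBroadcastedShape` (declared in `mlir/Dialect/Traits.h`) implements NumPy-style broadcasting, and an empty accumulator shape broadcasts with any other shape, which is what lets `verifyBroadcastable` start its fold from an empty `broadcastedShape`. A minimal standalone sketch of the same fold (`foldBroadcastedShape` is a hypothetical helper, not part of the patch):

```cpp
#include "mlir/Dialect/Traits.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Folds shapes pairwise, exactly like the loop in verifyBroadcastable.
static bool
foldBroadcastedShape(llvm::ArrayRef<llvm::SmallVector<int64_t>> shapes,
                     llvm::SmallVectorImpl<int64_t> &result) {
  llvm::SmallVector<int64_t> acc; // empty shape broadcasts with anything
  for (const llvm::SmallVector<int64_t> &shape : shapes) {
    const llvm::SmallVector<int64_t> prev = acc;
    if (!mlir::OpTrait::util::getBroadcastedShape(prev, shape, acc))
      return false; // e.g. {2, 64} vs. {3, 64} is not broadcast compatible
  }
  result.assign(acc.begin(), acc.end());
  return true;
}

// With the ttir.where test shapes below: {3,64} then {1,3,64} gives {1,3,64},
// then {2,1,64} gives {2,3,64} -- which is why the declared 1x2x3x64 result
// fails the equality check in the verifier above.
```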
diff --git a/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir b/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir
new file mode 100644
index 0000000000..e1454ad0a0
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir
@@ -0,0 +1,28 @@
+// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
+// Negative tests for Broadcastable interface
+
+// CHECK: 'ttir.abs' op Result shape must match operand shapes after broadcasting
+#any_device_tile = #tt.operand_constraint<any_device_tile>
+func.func @eltwise_unary(%arg0: tensor<1x64xbf16>) -> tensor<2x64xbf16> {
+  %0 = tensor.empty() : tensor<2x64xbf16>
+  %1 = "ttir.abs"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1x64xbf16>, tensor<2x64xbf16>) -> tensor<2x64xbf16>
+  return %1 : tensor<2x64xbf16>
+}
+
+// -----
+// CHECK: error: 'ttir.add' op Result shape must match operand shapes after broadcasting
+#any_device_tile = #tt.operand_constraint<any_device_tile>
+func.func @eltwise_binary(%arg0: tensor<2x3x64xf32>, %arg1: tensor<64xf32>) -> tensor<4x2x3x64xf32> {
+  %0 = tensor.empty() : tensor<4x2x3x64xf32>
+  %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<2x3x64xf32>, tensor<64xf32>, tensor<4x2x3x64xf32>) -> tensor<4x2x3x64xf32>
+  return %1 : tensor<4x2x3x64xf32>
+}
+
+// -----
+// CHECK: error: 'ttir.where' op Result shape must match operand shapes after broadcasting
+#any_device_tile = #tt.operand_constraint<any_device_tile>
+func.func @eltwise_ternary(%arg0: tensor<3x64xf32>, %arg1: tensor<1x3x64xf32>, %arg2: tensor<2x1x64xf32>) -> tensor<1x2x3x64xf32> {
+  %0 = tensor.empty() : tensor<1x2x3x64xf32>
+  %1 = "ttir.where"(%arg0, %arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64xf32>, tensor<1x3x64xf32>, tensor<2x1x64xf32>, tensor<1x2x3x64xf32>) -> tensor<1x2x3x64xf32>
+  return %1 : tensor<1x2x3x64xf32>
+}
diff --git a/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir b/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir
new file mode 100644
index 0000000000..a22dc28370
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir
@@ -0,0 +1,37 @@
+// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
+// Negative tests for NOperands trait
+
+// CHECK: error: 'ttir.abs' op expected 2 operands, but found 3
+#any_device_tile = #tt.operand_constraint<any_device_tile>
+func.func @eltwise_unary(%arg0: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
+  %0 = tensor.empty() : tensor<64x64xbf16>
+  %1 = "ttir.abs"(%arg0, %arg0, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+  return %1 : tensor<64x64xbf16>
+}
+
+// -----
+// CHECK: error: 'ttir.add' op expected 3 operands, but found 4
+#any_device_tile = #tt.operand_constraint<any_device_tile>
+func.func @eltwise_binary(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = tensor.empty() : tensor<64x64xf32>
+  %1 = "ttir.add"(%arg0, %arg1, %arg1, %0) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %1 : tensor<64x64xf32>
+}
+
+// -----
+// CHECK: error: 'ttir.add' op expected 3 operands, but found 2
+#any_device_tile = #tt.operand_constraint<any_device_tile>
+func.func @eltwise_binary(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = tensor.empty() : tensor<64x64xf32>
+  %1 = "ttir.add"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %1 : tensor<64x64xf32>
+}
+
+// -----
+// CHECK: error: 'ttir.where' op expected 4 operands, but found 5
+#any_device_tile = #tt.operand_constraint<any_device_tile>
+func.func @eltwise_ternary(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = tensor.empty() : tensor<64x64xf32>
+  %1 = "ttir.where"(%arg0, %arg1, %arg2, %arg2, %0) <{operandSegmentSizes = array<i32: 4, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %1 : tensor<64x64xf32>
+}
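The diagnostics checked in this file come from MLIR's stock NOperands verifier; note that the DPS init buffer counts toward the total, so `ttir.abs` expects 2 operands (1 input + 1 init) and `ttir.where` expects 4 (3 inputs + 1 init). A paraphrased sketch of the upstream check (the actual implementation lives in MLIR's op-definition machinery, shown here only to explain the expected messages):

```cpp
#include "mlir/IR/Operation.h"

// Roughly what mlir::OpTrait::NOperands<N>::Impl verifies for each op;
// this is what produces the "expected N operands, but found M" messages above.
static mlir::LogicalResult verifyNOperandsSketch(mlir::Operation *op,
                                                 unsigned expected) {
  if (op->getNumOperands() != expected) {
    return op->emitOpError() << "expected " << expected
                             << " operands, but found "
                             << op->getNumOperands();
  }
  return mlir::success();
}
```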