From a7308cabd25a737ca238329699e453089cce2543 Mon Sep 17 00:00:00 2001 From: Aleksandar Zecevic Date: Fri, 22 Nov 2024 19:21:07 +0100 Subject: [PATCH 1/3] Added LinearOp support (#1233) --- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 28 +++ include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | 28 +++ include/ttmlir/Target/TTNN/program.fbs | 8 + lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 17 +- lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp | 3 +- lib/Dialect/TTIR/IR/TTIROps.cpp | 152 ++++++++++++ lib/Dialect/TTNN/IR/TTNNOps.cpp | 152 ++++++++++++ lib/Target/TTNN/TTNNToFlatbuffer.cpp | 18 ++ runtime/lib/ttnn/operations/matmul/matmul.cpp | 38 ++- runtime/lib/ttnn/operations/matmul/matmul.h | 1 + runtime/lib/ttnn/program.cpp | 3 + .../TTIR/linear/linear_tests_negative.mlir | 194 ++++++++++++++++ .../TTNN/linear/linear_tests_positive.mlir | 216 ++++++++++++++++++ .../Dialect/TTNN/linear/simple_linear.mlir | 31 +++ .../TTNN/perf_unit/test_perf_linear.mlir | 20 ++ test/ttmlir/Silicon/TTNN/simple_linear.mlir | 33 +++ 16 files changed, 934 insertions(+), 8 deletions(-) create mode 100644 test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir create mode 100644 test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir create mode 100644 test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir create mode 100644 test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir create mode 100644 test/ttmlir/Silicon/TTNN/simple_linear.mlir diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 8782f63ae..5bfb77064 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -1091,6 +1091,34 @@ def TTIR_FillOp : TTIR_DPSOp<"fill", [AllShapesMatch<["value", "result"]>]> { }]; } +def TTIR_LinearOp : TTIR_DPSOp<"linear"> { + let summary = "Linear transformation of inputs."; + let description = [{ + Produces the matmul of tensors `a` and `b` with optional addition with `bias`. + + Example: + %a = tensor.empty() : () -> tensor<10x64x32xbf16> + %b = tensor.empty() : () -> tensor<32x128xbf16> + %bias = tensor.empty() : () -> tensor<128xbf16> + %output = tensor.empty() : () -> tensor<10x64x128xbf16> + %0 = "ttir.linear"(%a, %b, %bias, %output) : (tensor<10x64x32xbf16>, tensor<32x128xbf16>, tensor<128xbf16>, tensor<10x64x128xbf16>) -> tensor<10x64x128xbf16> + }]; + + let arguments = (ins AnyRankedTensor:$a, + AnyRankedTensor:$b, + Optional:$bias, + AnyRankedTensor:$output, + TT_OperandConstraintArrayAttr:$operand_constraints); + + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + // ANCHOR: adding_an_op_matmul_ttir def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> { let summary = "Matrix multiply operation."; diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 910ed7dfd..4147cc6d0 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -636,6 +636,34 @@ def TTNN_SliceOp: TTNN_NamedDPSOp<"slice"> { let hasVerifier = 1; } +def TTNN_LinearOp : TTNN_NamedDPSOp<"linear"> { + let summary = "Linear transformation of inputs."; + + let description = [{ + Produces the matmul of tensors `a` and `b` with optional addition with `bias`. 
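+    Rank-1 inputs are promoted to matrices for the multiplication (the inserted unit dimension is removed from the result), batch dimensions are broadcast, and `bias` must be broadcast-compatible with the matmul output.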
+
+    Example:
+      // %a = [[1., 2.], [2., 1.]]
+      // %b = [[0., 1.], [1., 0.]]
+      // %bias = [[1.]]
+      "ttnn.linear"(%a, %b, %bias, %result) : (tensor<2x2xf16>, tensor<2x2xf16>, tensor<1xf16>, tensor<2x2xf16>) -> tensor<2x2xf16>
+      // %result = [[3., 2.], [2., 3.]]
+  }];
+
+  let arguments = (ins AnyRankedTensor:$a,
+                       AnyRankedTensor:$b,
+                       Optional<AnyRankedTensor>:$bias,
+                       AnyRankedTensor:$output);
+  let results = (outs AnyRankedTensor:$result);
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+
+  let hasVerifier = 1;
+}
+
+
 // ANCHOR: adding_an_op_matmul_ttnn
 def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul"> {
   let arguments = (ins AnyRankedTensor:$a,
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index ec493e649..0be274b4b 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -178,6 +178,13 @@ table SliceOp {
   step: [int64];
 }
 
+table LinearOp {
+  in0: tt.target.TensorRef;
+  in1: tt.target.TensorRef;
+  bias: tt.target.TensorRef;
+  out: tt.target.TensorRef;
+}
+
 // ANCHOR: adding_an_op_matmul_fbs
 table MatmulOp {
   in0: tt.target.TensorRef;
@@ -249,6 +256,7 @@ union OpType {
   EmptyOp,
   FullOp,
   EltwiseOp,
+  LinearOp,
   MatmulOp,
   ReductionOp,
   EmbeddingOp,
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 12e29a960..52995b64c 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -579,7 +579,19 @@ class ConstantOpConversionPattern
   }
 };
 
-} // namespace
+class LinearOpConversionPattern : public OpConversionPattern<ttir::LinearOp> {
+public:
+  using OpConversionPattern<ttir::LinearOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ttir::LinearOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<ttnn::LinearOp>(
+        op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(),
+        adaptor.getB(), adaptor.getBias(), adaptor.getOutput());
+    return success();
+  }
+};
 
 // ANCHOR: adding_an_op_matmul_op_rewriter
 class MatmulOpConversionPattern : public OpConversionPattern<ttir::MatmulOp> {
@@ -908,6 +920,8 @@ class AllGatherOpConversionPattern
   }
 };
 
+} // namespace
+
 namespace mlir::tt {
 
 void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
@@ -969,6 +983,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            SqueezeOpConversionPattern,
            UnsqueezeOpConversionPattern,
            ConstantOpConversionPattern,
+           LinearOpConversionPattern,
            MatmulOpConversionPattern,
            Conv2dOpConversionPattern,
            MaxPool2dOpConversionPattern,
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index 92862cd9d..6c83200f3 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -725,7 +725,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   // Matmul ops
   //
-  patterns.add<DefaultOpConversionPattern<ttnn::MatmulOp>>(typeConverter, ctx);
+  patterns.add<DefaultOpConversionPattern<ttnn::MatmulOp>,
+               DefaultOpConversionPattern<ttnn::LinearOp>>(typeConverter, ctx);
 
   // Reduction ops
   //
diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index 5946cb2fe..bf734df95 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -895,6 +895,158 @@ mlir::tt::ttir::ToLayoutOp::compoundComponents() {
           isMemoryLayoutChange};
 }
 
+//===----------------------------------------------------------------------===//
+// LinearOp
+//===----------------------------------------------------------------------===//
+
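+// A sketch of the shape rules enforced below (shapes are illustrative, drawn
+// from the tests added in this patch):
+//   a: (7,64,128), b: (1,128,64) -> batch dims broadcast to (7) and the matmul
+//   dims append (64,64), so the expected output shape is (7,64,64).
+//   a: (128), b: (128) -> both inserted unit dims are dropped, leaving an
+//   empty expected shape; the output must then be a 1D tensor of size 1.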
+// LinearOp verification
+::mlir::LogicalResult mlir::tt::ttir::LinearOp::verify() {
+  ::mlir::RankedTensorType inputAType = getA().getType();
+  ::mlir::RankedTensorType inputBType = getB().getType();
+  std::optional<::mlir::RankedTensorType> biasType =
+      getBias() ? std::make_optional(getBias().getType()) : std::nullopt;
+  ::mlir::RankedTensorType outputType = getOutput().getType();
+
+  llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
+  llvm::SmallVector<int64_t> inputAShape(inputAType.getShape());
+  llvm::SmallVector<int64_t> inputBShape(inputBType.getShape());
+
+  // Verify that the input A is at least a 1D tensor.
+  if (inputAType.getRank() < 1) {
+    return emitOpError("Input A must be at least a 1D tensor");
+  }
+
+  // Verify that the input B is at least a 1D tensor.
+  if (inputBType.getRank() < 1) {
+    return emitOpError("Input B must be at least a 1D tensor");
+  }
+
+  // If input A is a vector (1D tensor), a 1 is prepended to its shape for the
+  // purpose of the matrix multiplication. After the matrix multiplication, the
+  // prepended dimension is removed.
+  if (inputAType.getRank() == 1) {
+    inputAShape.insert(inputAShape.begin(), 1);
+  }
+
+  // If input B is a vector (1D tensor), a 1 is appended to its shape for
+  // the purpose of the matrix-vector product and removed afterwards.
+  if (inputBType.getRank() == 1) {
+    inputBShape.push_back(1);
+  }
+
+  // Verify that input A and input B have matching inner dimensions.
+  if (inputAShape[inputAShape.size() - 1] !=
+      inputBShape[inputBShape.size() - 2]) {
+    return emitOpError(
+        "Input A[-1](" + std::to_string(inputAShape[inputAShape.size() - 1]) +
+        ") and B[-2](" + std::to_string(inputBShape[inputBShape.size() - 2]) +
+        ") must have matching inner dimensions");
+  }
+
+  llvm::SmallVector<int64_t> expectedOutputShape;
+  // Verify that the batch dimensions are broadcast compatible and construct
+  // the expected output shape.
+  if (inputAShape.size() > 2 || inputBShape.size() > 2) {
+    llvm::SmallVector<int64_t> inputABatchDims, inputBBatchDims;
+
+    if (inputAShape.size() > 2) {
+      inputABatchDims.insert(inputABatchDims.begin(), inputAShape.begin(),
+                             inputAShape.end() - 2);
+    }
+
+    if (inputBShape.size() > 2) {
+      inputBBatchDims.insert(inputBBatchDims.begin(), inputBShape.begin(),
+                             inputBShape.end() - 2);
+    }
+
+    // Verify that the batch dimensions of input A and B are broadcast
+    // compatible.
+    llvm::SmallVector<int64_t> broadcastedShape;
+    if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims,
+                                            broadcastedShape)) {
+      return emitOpError("Batch dimensions of input A(" +
+                         ttmlir::utils::join(inputABatchDims, ",") +
+                         ") and B(" +
+                         ttmlir::utils::join(inputBBatchDims, ",") +
+                         ") are not broadcast compatible");
+    }
+
+    // Insert the broadcasted batch dimensions in the expected output shape.
+    expectedOutputShape.insert(expectedOutputShape.begin(),
+                               broadcastedShape.begin(),
+                               broadcastedShape.end());
+  }
+
+  // Insert the input A and B inner dimensions in the expected output shape.
+  // Consider the case where input A and B are vectors. In that case,
+  // the dimension 1 is omitted from the output shape.
+  if (inputAType.getRank() > 1) {
+    expectedOutputShape.push_back(inputAShape[inputAShape.size() - 2]);
+  }
+
+  if (inputBType.getRank() > 1) {
+    expectedOutputShape.push_back(inputBShape[inputBShape.size() - 1]);
+  }
+
+  if (biasType) {
+    // Verify that the input bias is at least a 1D tensor.
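+    // For example, a bias of shape (64) broadcasts against a (3,64,64) matmul
+    // output, while a bias of shape (2,64) fails the check below against a
+    // (64,64) matmul output (2 vs. 64 in dimension -2).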
+ if (biasType.value().getRank() < 1) { + return emitOpError("Bias must be at least a 1D tensor"); + } + + llvm::SmallVector biasShape(biasType.value().getShape()); + + // Verify that the dimensions of the matmul of A and B are broadcast + // compatible with input bias. + llvm::SmallVector matmulShape = expectedOutputShape; + if (!OpTrait::util::getBroadcastedShape(matmulShape, biasShape, + expectedOutputShape)) { + return emitOpError("Bias shape(" + ttmlir::utils::join(biasShape, ",") + + ") is not broadcast compatible with the matmul output " + "shape(" + + ttmlir::utils::join(matmulShape, ",") + ")"); + } + } + + // Check the case of a vector-vector product. At this moment we don't support + // scalars in IR, hence check that the output is at least 1D tensor of size 1. + if (expectedOutputShape.size() == 0) { + if (outputType.getRank() < 1) { + return emitOpError("Scalar output is not supported, output must be at " + "least a 1D tensor"); + } + + if (outputType.getRank() > 1 || outputType.getShape()[0] != 1) { + return emitOpError("Scalar output must be a 1D tensor of size 1"); + } + + return success(); + } + + // Verify that the output shape dimension count is correct. + if (outputShape.size() != expectedOutputShape.size()) { + return emitOpError("Output shape rank(" + + std::to_string(outputShape.size()) + + ") must match the expected output shape rank(" + + std::to_string(expectedOutputShape.size()) + ")"); + } + + // Verify each dim of the output shape. + for (size_t i = 0; i < outputShape.size(); i++) { + if (outputShape[i] != expectedOutputShape[i]) { + return emitOpError( + "Output shape dimension[" + std::to_string(i) + "](" + + std::to_string(outputShape[i]) + + ") doesn't match the expected output shape dimension[" + + std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) + + ")"); + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // MatmulOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 4abd74d62..c4f0d7394 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -592,6 +592,158 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// LinearOp +//===----------------------------------------------------------------------===// + +// LinearOp verification +::mlir::LogicalResult mlir::tt::ttnn::LinearOp::verify() { + ::mlir::RankedTensorType inputAType = getA().getType(); + ::mlir::RankedTensorType inputBType = getB().getType(); + std::optional<::mlir::RankedTensorType> biasType = + getBias() ? std::make_optional(getBias().getType()) : std::nullopt; + ::mlir::RankedTensorType outputType = getOutput().getType(); + + llvm::ArrayRef outputShape = outputType.getShape(); + llvm::SmallVector inputAShape(inputAType.getShape()); + llvm::SmallVector inputBShape(inputBType.getShape()); + + // Verify that the input A is at least 1D tensor. + if (inputAType.getRank() < 1) { + return emitOpError("Input A must be at least a 1D tensor"); + } + + // Verify that the input B is at least 1D tensor. + if (inputBType.getRank() < 1) { + return emitOpError("Input B must be at least a 1D tensor"); + } + + // If input A is a vector (1D tensor), 1 is prepended to its dimension for the + // purpose of the matrix multiplication. 
After the matrix multiplication, the
+  // prepended dimension is removed.
+  if (inputAType.getRank() == 1) {
+    inputAShape.insert(inputAShape.begin(), 1);
+  }
+
+  // If input B is a vector (1D tensor), a 1 is appended to its shape for
+  // the purpose of the matrix-vector product and removed afterwards.
+  if (inputBType.getRank() == 1) {
+    inputBShape.push_back(1);
+  }
+
+  // Verify that input A and input B have matching inner dimensions.
+  if (inputAShape[inputAShape.size() - 1] !=
+      inputBShape[inputBShape.size() - 2]) {
+    return emitOpError(
+        "Input A[-1](" + std::to_string(inputAShape[inputAShape.size() - 1]) +
+        ") and B[-2](" + std::to_string(inputBShape[inputBShape.size() - 2]) +
+        ") must have matching inner dimensions");
+  }
+
+  llvm::SmallVector<int64_t> expectedOutputShape;
+  // Verify that the batch dimensions are broadcast compatible and construct
+  // the expected output shape.
+  if (inputAShape.size() > 2 || inputBShape.size() > 2) {
+    llvm::SmallVector<int64_t> inputABatchDims, inputBBatchDims;
+
+    if (inputAShape.size() > 2) {
+      inputABatchDims.insert(inputABatchDims.begin(), inputAShape.begin(),
+                             inputAShape.end() - 2);
+    }
+
+    if (inputBShape.size() > 2) {
+      inputBBatchDims.insert(inputBBatchDims.begin(), inputBShape.begin(),
+                             inputBShape.end() - 2);
+    }
+
+    // Verify that the batch dimensions of input A and B are broadcast
+    // compatible.
+    llvm::SmallVector<int64_t> broadcastedShape;
+    if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims,
+                                            broadcastedShape)) {
+      return emitOpError("Batch dimensions of input A(" +
+                         ttmlir::utils::join(inputABatchDims, ",") +
+                         ") and B(" +
+                         ttmlir::utils::join(inputBBatchDims, ",") +
+                         ") are not broadcast compatible");
+    }
+
+    // Insert the broadcasted batch dimensions in the expected output shape.
+    expectedOutputShape.insert(expectedOutputShape.begin(),
+                               broadcastedShape.begin(),
+                               broadcastedShape.end());
+  }
+
+  // Insert the input A and B inner dimensions in the expected output shape.
+  // Consider the case where input A and B are vectors. In that case,
+  // the dimension 1 is omitted from the output shape.
+  if (inputAType.getRank() > 1) {
+    expectedOutputShape.push_back(inputAShape[inputAShape.size() - 2]);
+  }
+
+  if (inputBType.getRank() > 1) {
+    expectedOutputShape.push_back(inputBShape[inputBShape.size() - 1]);
+  }
+
+  if (biasType) {
+    // Verify that the input bias is at least a 1D tensor.
+    if (biasType.value().getRank() < 1) {
+      return emitOpError("Bias must be at least a 1D tensor");
+    }
+
+    llvm::SmallVector<int64_t> biasShape(biasType.value().getShape());
+
+    // Verify that the dimensions of the matmul of A and B are broadcast
+    // compatible with the input bias.
+    llvm::SmallVector<int64_t> matmulShape = expectedOutputShape;
+    if (!OpTrait::util::getBroadcastedShape(matmulShape, biasShape,
+                                            expectedOutputShape)) {
+      return emitOpError("Bias shape(" + ttmlir::utils::join(biasShape, ",") +
+                         ") is not broadcast compatible with the matmul output "
+                         "shape(" +
+                         ttmlir::utils::join(matmulShape, ",") + ")");
+    }
+  }
+
+  // Check the case of a vector-vector product. At this moment we don't support
+  // scalars in IR, hence check that the output is at least a 1D tensor of
+  // size 1.
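+  // This case is only reachable for a vector-vector product, e.g. a: (128) and
+  // b: (128): both inserted unit dimensions were dropped above, so the op must
+  // produce a 1D tensor of size 1 rather than a true scalar.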
+ if (expectedOutputShape.size() == 0) { + if (outputType.getRank() < 1) { + return emitOpError("Scalar output is not supported, output must be at " + "least a 1D tensor"); + } + + if (outputType.getRank() > 1 || outputType.getShape()[0] != 1) { + return emitOpError("Scalar output must be a 1D tensor of size 1"); + } + + return success(); + } + + // Verify that the output shape dimension count is correct. + if (outputShape.size() != expectedOutputShape.size()) { + return emitOpError("Output shape rank(" + + std::to_string(outputShape.size()) + + ") must match the expected output shape rank(" + + std::to_string(expectedOutputShape.size()) + ")"); + } + + // Verify each dim of the output shape. + for (size_t i = 0; i < outputShape.size(); i++) { + if (outputShape[i] != expectedOutputShape[i]) { + return emitOpError( + "Output shape dimension[" + std::to_string(i) + "](" + + std::to_string(outputShape[i]) + + ") doesn't match the expected output shape dimension[" + + std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) + + ")"); + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // MatmulOp //===----------------------------------------------------------------------===// diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 30b83014d..8971963f2 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -333,6 +333,21 @@ createOp(FlatbufferObjectCache &cache, FullOp op) { kHostAllocatedSize)); } +::flatbuffers::Offset<::tt::target::ttnn::LinearOp> +createOp(FlatbufferObjectCache &cache, LinearOp op) { + auto in0 = + cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getA())); + auto in1 = + cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getB())); + auto bias = op.getODSOperands(2).empty() + ? 
flatbuffers::Offset<::tt::target::TensorRef>() + : cache.at<::tt::target::TensorRef>( + getOperandThroughDPSOps(op.getBias())); + auto output = cache.at<::tt::target::TensorRef>( + getOperandThroughDPSOps(op.getResult())); + return ::tt::target::ttnn::CreateLinearOp(*cache.fbb, in0, in1, bias, output); +} + // ANCHOR: adding_an_op_matmul_serialize_to_binary ::flatbuffers::Offset<::tt::target::ttnn::MatmulOp> createOp(FlatbufferObjectCache &cache, MatmulOp op) { @@ -801,6 +816,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, return createOperation(cache, createEltwiseOp(cache, leakyReluOp), debugString); } + if (auto linearOp = dyn_cast(op); linearOp) { + return createOperation(cache, createOp(cache, linearOp), debugString); + } if (auto matmulOp = dyn_cast(op); matmulOp) { return createOperation(cache, createOp(cache, matmulOp), debugString); } diff --git a/runtime/lib/ttnn/operations/matmul/matmul.cpp b/runtime/lib/ttnn/operations/matmul/matmul.cpp index abe71f970..a25102d9a 100644 --- a/runtime/lib/ttnn/operations/matmul/matmul.cpp +++ b/runtime/lib/ttnn/operations/matmul/matmul.cpp @@ -8,8 +8,8 @@ #include "tt/runtime/ttnn/operations/utils.h" #include -// ANCHOR: adding_an_op_matmul_runtime_operations namespace tt::runtime::ttnn::operations::matmul { +// ANCHOR: adding_an_op_matmul_runtime_operations void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) { ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id()); @@ -20,10 +20,6 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) { ::tt::tt_metal::MemoryConfig outputMemoryConfig = utils::createMemoryConfig(op->out()); - std::optional< - ::ttnn::operations::matmul::MatmulMultiCoreReuseMultiCast1DProgramConfig> - programConfig = std::nullopt; - const std::optional memoryConfig = std::make_optional(outputMemoryConfig); @@ -37,5 +33,35 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) { tensorPool.insert_or_assign(op->out()->global_id(), out); } -} // namespace tt::runtime::ttnn::operations::matmul // ANCHOR_END: adding_an_op_matmul_runtime_operations + +void run(const ::tt::target::ttnn::LinearOp *op, ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); + const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id()); + const ::ttnn::Tensor &rhs = tensorPool.at(op->in1()->global_id()); + std::optional<::ttnn::Tensor> bias = + op->bias() ? 
std::make_optional(tensorPool.at(op->bias()->global_id())) + : std::nullopt; + + DEBUG_ASSERT(lhs.is_allocated()); + DEBUG_ASSERT(rhs.is_allocated()); + DEBUG_ASSERT(!bias || bias->is_allocated()); + + ::ttnn::DataType outputDataType = utils::getDataType(op->out()); + ::tt::tt_metal::MemoryConfig outputMemoryConfig = + utils::createMemoryConfig(op->out()); + + const std::optional memoryConfig = + std::make_optional(outputMemoryConfig); + + const std::optional dtype = + std::make_optional(outputDataType); + + ::ttnn::Tensor out = ::ttnn::linear( + lhs, rhs, bias, /*transposeA*/ false, /*transposeB*/ false, memoryConfig, + dtype, /*programConfig*/ std::nullopt, /*activation*/ std::nullopt, + /*computeKernelConfig*/ std::nullopt, /*coreGrid*/ std::nullopt); + + tensorPool.insert_or_assign(op->out()->global_id(), out); +} +} // namespace tt::runtime::ttnn::operations::matmul diff --git a/runtime/lib/ttnn/operations/matmul/matmul.h b/runtime/lib/ttnn/operations/matmul/matmul.h index 5957a54a3..7b0583786 100644 --- a/runtime/lib/ttnn/operations/matmul/matmul.h +++ b/runtime/lib/ttnn/operations/matmul/matmul.h @@ -10,6 +10,7 @@ namespace tt::runtime::ttnn::operations::matmul { void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context); +void run(const ::tt::target::ttnn::LinearOp *op, ProgramContext &context); } // namespace tt::runtime::ttnn::operations::matmul #endif diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index 8cfa01389..fbd58c593 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -148,6 +148,9 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { case ::tt::target::ttnn::OpType::EltwiseOp: { return runEltwiseOperation(op->type_as_EltwiseOp()); } + case ::tt::target::ttnn::OpType::LinearOp: { + return operations::matmul::run(op->type_as_LinearOp(), context); + } // ANCHOR: adding_an_op_matmul_runtime_program case ::tt::target::ttnn::OpType::MatmulOp: { return operations::matmul::run(op->type_as_MatmulOp(), context); diff --git a/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir new file mode 100644 index 000000000..522628160 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir @@ -0,0 +1,194 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for linear operation + +// Verify that the parsing fails if either of operands is a scalar +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_scalar_a(%arg0: tensor, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { + // CHECK: error: 'ttir.linear' op Input A must be at least a 1D tensor + %0 = tensor.empty() : tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_1d_1d_scalar_b(%arg0: tensor<128xbf16>, %arg1: tensor) -> tensor<1xbf16> { + // CHECK: error: 'ttir.linear' op Input B must be at least a 1D tensor + %0 = tensor.empty() : tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module 
{
+  func.func @linear_negative_1d_1d_scalar_bias(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor<bf16>) -> tensor<1xbf16> {
+    // CHECK: error: 'ttir.linear' op Bias must be at least a 1D tensor
+    %0 = tensor.empty() : tensor<1xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<bf16>, tensor<1xbf16>) -> tensor<1xbf16>
+    return %1 : tensor<1xbf16>
+  }
+}
+
+// Verify that the parsing fails if the output is a scalar
+// -----
+#any_device_tile = #tt.operand_constraint
+module {
+  func.func @linear_negative_1d_1d_scalar_output(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<bf16> {
+    // CHECK: error: 'ttir.linear' op Scalar output is not supported, output must be at least a 1D tensor
+    %0 = tensor.empty() : tensor<bf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<bf16>) -> tensor<bf16>
+    return %1 : tensor<bf16>
+  }
+}
+
+// -----
+#any_device_tile = #tt.operand_constraint
+module {
+  func.func @linear_negative_1d_1d_output_dimension_mismatch(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<2xbf16> {
+    // CHECK: error: 'ttir.linear' op Scalar output must be a 1D tensor of size 1
+    %0 = tensor.empty() : tensor<2xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<2xbf16>) -> tensor<2xbf16>
+    return %1 : tensor<2xbf16>
+  }
+}
+
+// Inner dimension mismatch tests
+// -----
+#any_device_tile = #tt.operand_constraint
+module {
+  func.func @linear_negative_1d_1d_inner_dimension_mismatch(%arg0: tensor<128xbf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> {
+    // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions
+    %0 = tensor.empty() : tensor<1xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
+    return %1 : tensor<1xbf16>
+  }
+}
+
+// -----
+#any_device_tile = #tt.operand_constraint
+module {
+  func.func @linear_negative_1d_2d_inner_dimension_mismatch(%arg0: tensor<64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> {
+    // CHECK: error: 'ttir.linear' op Input A[-1](64) and B[-2](128) must have matching inner dimensions
+    %0 = tensor.empty() : tensor<64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16>
+    return %1 : tensor<64xbf16>
+  }
+}
+
+// -----
+#any_device_tile = #tt.operand_constraint
+module {
+  func.func @linear_negative_2d_1d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64xbf16>) -> tensor<64xbf16> {
+    // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions
+    %0 = tensor.empty() : tensor<64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64xbf16>, tensor<64xbf16>) -> tensor<64xbf16>
+    return %1 : tensor<64xbf16>
+  }
+}
+
+// -----
+#any_device_tile = #tt.operand_constraint
+module {
+  func.func @linear_negative_2d_2d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>)
-> tensor<64x64xbf16> { + // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_inner_dimension_mismatch(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions + %0 = tensor.empty() : tensor<7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<1x64x128xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + return %1 : tensor<7x64x64xbf16> + } +} + +// Batch dimension mismatch tests +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_same_rank_batch_broadcast_incompatible_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<2x128x64xbf16>) -> tensor<7x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Batch dimensions of input A(7) and B(2) are not broadcast compatible + %0 = tensor.empty() : tensor<7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<2x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + return %1 : tensor<7x64x64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_same_rank_batch_broadcast_incompatible_2(%arg0: tensor<2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Batch dimensions of input A(2,7) and B(7,1) are not broadcast compatible + %0 = tensor.empty() : tensor<7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x7x64x64xbf16> + return %1 : tensor<7x7x64x64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_different_rank_batch_broadcast_incompatible(%arg0: tensor<12x2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Batch dimensions of input A(12,2,7) and B(7,1) are not broadcast compatible + %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> + return %1 : tensor<12x7x7x64x64xbf16> + } +} + +// Bias shape mismatch tests +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_matmul_bias_broadcast_incompatible(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<2x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: error: 'ttir.linear' op Bias shape(2,64) is not broadcast compatible with the matmul output shape(64,64) + %0 = tensor.empty() : tensor<64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, 
#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<2x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_nd_nd_matmul_bias_broadcast_incompatible(%arg0: tensor<3x64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<2x64x64xbf16>) -> tensor<3x64x64xbf16> { + // CHECK: error: 'ttir.linear' op Bias shape(2,64,64) is not broadcast compatible with the matmul output shape(3,64,64) + %0 = tensor.empty() : tensor<3x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64x128xbf16>, tensor<128x64xbf16>, tensor<2x64x64xbf16>, tensor<3x64x64xbf16>) -> tensor<3x64x64xbf16> + return %1 : tensor<3x64x64xbf16> + } +} + +// Output shape mismatch tests +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_2d_2d_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { + // CHECK: error: 'ttir.linear' op Output shape rank(1) must match the expected output shape rank(2) + %0 = tensor.empty() : tensor<64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + return %1 : tensor<64xbf16> + } +} + +// ----- +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_negative_2d_2d_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x128xbf16> { + // CHECK: error: 'ttir.linear' op Output shape dimension[1](128) doesn't match the expected output shape dimension[1](64) + %0 = tensor.empty() : tensor<64x128xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir new file mode 100644 index 000000000..0e248623d --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir @@ -0,0 +1,216 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint +module { + func.func @linear_1d_1d(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<1xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<1xbf16 + %0 = tensor.empty() : tensor<1xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<1xbf16 + // CHECK-SAME: tensor<1xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } + + func.func @linear_1d_1d_bias(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor<1xbf16>) -> tensor<1xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<1xbf16 + %0 = tensor.empty() : tensor<1xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<1xbf16 + // CHECK-SAME: tensor<1xbf16 + // CHECK-SAME: tensor<1xbf16 + %1 = 
"ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + return %1 : tensor<1xbf16> + } + + func.func @linear_1d_1d_bias_broadcast(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor<128xbf16>) -> tensor<128xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<128xbf16 + %0 = tensor.empty() : tensor<128xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<128xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<128xbf16>, tensor<128xbf16>) -> tensor<128xbf16> + return %1 : tensor<128xbf16> + } + + func.func @linear_2d_1d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128xbf16>) -> tensor<64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64xbf16 + %0 = tensor.empty() : tensor<64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<64xbf16 + // CHECK-SAME: tensor<64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + return %1 : tensor<64xbf16> + } + + func.func @linear_2d_2d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func @linear_2d_2d_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func @linear_1d_nd(%arg0: tensor<128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x64xbf16 + %0 = tensor.empty() : tensor<12x7x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<128xbf16 + // CHECK-SAME: tensor<12x7x128x64xbf16 + // CHECK-SAME: tensor<12x7x64xbf16 + // CHECK-SAME: tensor<12x7x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64xbf16>) -> tensor<12x7x64xbf16> + return %1 : tensor<12x7x64xbf16> 
+ } + + func.func @linear_nd_1d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64xbf16>) -> tensor<12x7x128xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x128xbf16 + %0 = tensor.empty() : tensor<12x7x128xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<12x7x128x64xbf16 + // CHECK-SAME: tensor<64xbf16 + // CHECK-SAME: tensor<12x7x128xbf16 + // CHECK-SAME: tensor<12x7x128xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x7x128x64xbf16>, tensor<64xbf16>, tensor<12x7x128xbf16>) -> tensor<12x7x128xbf16> + return %1 : tensor<12x7x128xbf16> + } + + func.func @linear_2d_nd(%arg0: tensor<64x128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x64x64xbf16 + %0 = tensor.empty() : tensor<12x7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<12x7x128x64xbf16 + // CHECK-SAME: tensor<12x7x64x64xbf16 + // CHECK-SAME: tensor<12x7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64x64xbf16>) -> tensor<12x7x64x64xbf16> + return %1 : tensor<12x7x64x64xbf16> + } + + func.func @linear_nd_2d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<12x7x128x128xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x128x128xbf16 + %0 = tensor.empty() : tensor<12x7x128x128xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<12x7x128x64xbf16 + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<12x7x128x128xbf16 + // CHECK-SAME: tensor<12x7x128x128xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x7x128x64xbf16>, tensor<64x128xbf16>, tensor<12x7x128x128xbf16>) -> tensor<12x7x128x128xbf16> + return %1 : tensor<12x7x128x128xbf16> + } + + // linear nd - nd tests + func.func @linear_nd_nd_same_rank_same_dims(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<7x128x64xbf16>) -> tensor<7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<7x64x64xbf16 + %0 = tensor.empty() : tensor<7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<7x64x128xbf16 + // CHECK-SAME: tensor<7x128x64xbf16 + // CHECK-SAME: tensor<7x64x64xbf16 + // CHECK-SAME: tensor<7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<7x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + return %1 : tensor<7x64x64xbf16> + } + + func.func @linear_nd_nd_same_rank_broadcastable_dims_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x128x64xbf16>) -> tensor<7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<7x64x64xbf16 + %0 = tensor.empty() : tensor<7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<7x64x128xbf16 + // CHECK-SAME: tensor<1x128x64xbf16 + // CHECK-SAME: tensor<7x64x64xbf16 + // CHECK-SAME: tensor<7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + return %1 : tensor<7x64x64xbf16> + } + + func.func @linear_nd_nd_same_rank_broadcastable_dims_2(%arg0: tensor<1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> 
{ + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<7x7x64x64xbf16 + %0 = tensor.empty() : tensor<7x7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<1x7x64x128xbf16 + // CHECK-SAME: tensor<7x1x128x64xbf16 + // CHECK-SAME: tensor<7x7x64x64xbf16 + // CHECK-SAME: tensor<7x7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x7x64x64xbf16>) -> tensor<7x7x64x64xbf16> + return %1 : tensor<7x7x64x64xbf16> + } + + func.func @linear_nd_nd_different_rank_broadcastable_dims_2(%arg0: tensor<12x1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<12x7x7x64x64xbf16 + %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<12x1x7x64x128xbf16 + // CHECK-SAME: tensor<7x1x128x64xbf16 + // CHECK-SAME: tensor<12x7x7x64x64xbf16 + // CHECK-SAME: tensor<12x7x7x64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> + return %1 : tensor<12x7x7x64x64xbf16> + } + + func.func @linear_nd_nd_bias_broadcast_bias(%arg0: tensor<14x7x32x32xbf16>, %arg1:tensor<14x1x32x64xbf16>, %bias: tensor<64xbf16>) -> tensor<14x7x32x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<14x7x32x64xbf16 + %0 = tensor.empty() : tensor<14x7x32x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<14x7x32x32xbf16 + // CHECK-SAME: tensor<14x1x32x64xbf16 + // CHECK-SAME: tensor<64xbf16 + // CHECK-SAME: tensor<14x7x32x64xbf16 + // CHECK-SAME: tensor<14x7x32x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<14x7x32x32xbf16>, tensor<14x1x32x64xbf16>, tensor<64xbf16>, tensor<14x7x32x64xbf16>) -> tensor<14x7x32x64xbf16> + return %1 : tensor<14x7x32x64xbf16> + } + + func.func @linear_nd_nd_bias_broadcast_matmul(%arg0: tensor<3x64x128xbf16>, %arg1: tensor<4x3x128x32xbf16>, %bias: tensor<14x4x3x64x32xbf16>) -> tensor<14x4x3x64x32xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<14x4x3x64x32xbf16 + %0 = tensor.empty() : tensor<14x4x3x64x32xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<3x64x128xbf16 + // CHECK-SAME: tensor<4x3x128x32xbf16 + // CHECK-SAME: tensor<14x4x3x64x32xbf16 + // CHECK-SAME: tensor<14x4x3x64x32xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64x128xbf16>, tensor<4x3x128x32xbf16>, tensor<14x4x3x64x32xbf16>, tensor<14x4x3x64x32xbf16>) -> tensor<14x4x3x64x32xbf16> + return %1 : tensor<14x4x3x64x32xbf16> + } +} diff --git a/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir b/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir new file mode 100644 index 000000000..56728eb52 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir @@ -0,0 +1,31 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint + +module { + func.func @simple_linear_without_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // 
CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func @simple_linear_with_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir new file mode 100644 index 000000000..6da5d3910 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir @@ -0,0 +1,20 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +#any_device_tile = #tt.operand_constraint +module { + func.func @linear(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/simple_linear.mlir b/test/ttmlir/Silicon/TTNN/simple_linear.mlir new file mode 100644 index 000000000..f53de38cf --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/simple_linear.mlir @@ -0,0 +1,33 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +#any_device_tile = #tt.operand_constraint +module { + func.func @simple_linear_without_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x64xbf16 + %0 = tensor.empty() : tensor<64x64xbf16> + // CHECK: "ttnn.linear" + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<128x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + // CHECK-SAME: tensor<64x64xbf16 + %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + return %1 : tensor<64x64xbf16> + } + + func.func 
@simple_linear_with_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
+    // CHECK: "ttnn.empty"
+    // CHECK-SAME: tensor<64x64xbf16
+    %0 = tensor.empty() : tensor<64x64xbf16>
+    // CHECK: "ttnn.linear"
+    // CHECK-SAME: tensor<64x128xbf16
+    // CHECK-SAME: tensor<128x64xbf16
+    // CHECK-SAME: tensor<64x64xbf16
+    // CHECK-SAME: tensor<64x64xbf16
+    // CHECK-SAME: tensor<64x64xbf16
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+    return %1 : tensor<64x64xbf16>
+  }
+}

From feb127907958b14bb969ffafaf526a9509c858d4 Mon Sep 17 00:00:00 2001
From: Lewis Panos
Date: Fri, 22 Nov 2024 14:40:32 -0500
Subject: [PATCH 2/3] Bringup ttir.arange, ttnn.arange. (#1332)

Add conversion patterns from stablehlo.iota and stablehlo.dynamic_iota to
ttir.arange.
Add a pattern in TTIRToTTIRDecompositionPass to rewrite all ttir.arange ops
whose arange_dimension is not the right-most dim. This makes explicit the
broadcasts and TMs that would need to be done after executing ttnn.arange.
Add a special TTNNLayout case for ttir.arange since it is a creation op.
Add runtime support and a basic silicon test, plus stablehlo silicon tests.
Add a decomposition test.
---
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 42 ++++++
 include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | 26 ++++
 include/ttmlir/Target/TTNN/program.fbs | 11 ++
 .../StableHLOToTTIRPatterns.cpp | 40 +++++
 .../TTIRToTTIRDecomposition.cpp | 138 ++++++++++++++++++
 .../TTIRToTTIRDecompositionPass.cpp | 8 +
 lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 44 +++++-
 lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp | 4 +-
 lib/Dialect/TTIR/IR/TTIROps.cpp | 31 ++++
 lib/Dialect/TTNN/IR/TTNNOps.cpp | 26 ++++
 lib/Dialect/TTNN/Transforms/TTNNLayout.cpp | 29 ++++
 lib/Target/TTNN/TTNNToFlatbuffer.cpp | 29 ++++
 .../lib/ttnn/include/tt/runtime/ttnn/utils.h | 1 +
 runtime/lib/ttnn/operations/CMakeLists.txt | 1 +
 .../lib/ttnn/operations/creation/arange.cpp | 46 ++++++
 runtime/lib/ttnn/operations/creation/arange.h | 17 +++
 runtime/lib/ttnn/program.cpp | 4 +
 .../StableHLOToTTIR/dynamic_iota_op.mlir | 11 ++
 .../Conversion/StableHLOToTTIR/iota_op.mlir | 10 ++
 .../Decomposition/arange_decomposition.mlir | 11 ++
 .../select_decomposition_tests.mlir | 0
 .../TTNN/arange/arange_tests_negative.mlir | 12 ++
 .../TTNN/arange/arange_tests_positive.mlir | 11 ++
 .../Iota/simple_device_dynamic_iota_dim2.mlir | 15 ++
 .../Iota/simple_device_dynamic_iota_dim3.mlir | 16 ++
 .../Iota/simple_device_iota_dim2.mlir | 15 ++
 .../Iota/simple_device_iota_dim3.mlir | 15 ++
 .../arange/simple_device_arange_dim2.mlir | 13 ++
 .../arange/simple_device_arange_dim3.mlir | 13 ++
 29 files changed, 636 insertions(+), 3 deletions(-)
 create mode 100644 runtime/lib/ttnn/operations/creation/arange.cpp
 create mode 100644 runtime/lib/ttnn/operations/creation/arange.h
 create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/dynamic_iota_op.mlir
 create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/iota_op.mlir
 create mode 100644 test/ttmlir/Dialect/TTIR/Decomposition/arange_decomposition.mlir
 rename test/ttmlir/Dialect/TTIR/{decompositions => Decomposition}/select_decomposition_tests.mlir (100%)
 create mode 100644 test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir
 create mode 100644 test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir
 create mode 100644 test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim2.mlir
 create mode 100644 test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim3.mlir
 create mode 100644 test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim2.mlir
 create mode 100644 test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim3.mlir
 create mode 100644 test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir
 create mode 100644 test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir

diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 5bfb77064..aeb2de1ae 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -1048,6 +1048,48 @@ def TTIR_ClampOp : TTIR_DPSOp<"clamp"> {
   let hasVerifier = 1;
 }
 
+def TTIR_ArangeOp : TTIR_Op<"arange"> {
+  let summary = "Arange operation.";
+  let description = [{
+    Tensor arange operation.
+
+    Produces a tensor with values from `start` to `end` (exclusive) with a step size of `step`, along the dimension specified by `arange_dimension`.
+
+    Examples:
+      %0 = "ttir.arange"() {start = 0 : i64, end = 5 : i64, step = 1 : i64, arange_dimension = 0 : i64} : () -> tensor<5xi64>
+      // %0: [0, 1, 2, 3, 4]
+
+      %1 = "ttir.arange"() {start = 0 : i64, end = 10 : i64, step = 2 : i64, arange_dimension = 0 : i64} : () -> tensor<5xf32>
+      // %1: [0.0, 2.0, 4.0, 6.0, 8.0]
+
+      %2 = "ttir.arange"() {start = 0 : i64, end = 5 : i64, step = 1 : i64, arange_dimension = 0 : i64} : () -> tensor<5x3xi64>
+      // %2: [
+             [0, 0, 0],
+             [1, 1, 1],
+             [2, 2, 2],
+             [3, 3, 3],
+             [4, 4, 4]
+           ]
+
+      %3 = "ttir.arange"() {start = 0 : i64, end = 3 : i64, step = 1 : i64, arange_dimension = 1 : i64} : () -> tensor<5x3xi64>
+      // %3: [
+             [0, 1, 2],
+             [0, 1, 2],
+             [0, 1, 2],
+             [0, 1, 2],
+             [0, 1, 2]
+           ]
+  }];
+
+  let arguments = (ins SI64Attr:$start,
+                       SI64Attr:$end,
+                       SI64Attr:$step,
+                       I64Attr:$arange_dimension);
+
+  let results = (outs AnyRankedTensor:$result);
+  let hasVerifier = 1;
+}
+
 def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
                                            AllShapesMatch<["value", "result"]>]> {
   let summary = "Constant op.";
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 4147cc6d0..21eb704cf 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -787,6 +787,32 @@ def TTNN_EmptyOp : TTNN_Op<"empty", [NoMemoryEffect]> {
   let hasVerifier = 1;
 }
 
+def TTNN_ArangeOp : TTNN_Op<"arange"> {
+  let summary = "Arange operation.";
+  let description = [{
+    Tensor arange operation.
+
+    Produces a (1, 1, 1, N)-shaped tensor with values from `start` to `end` (exclusive) with a step size of `step`.
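+    N is computed as (end - start) / step; the optional `dtype`, `device`, and `memory_config` arguments control the element type and placement of the result.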
+ + Examples: + %0 = "ttnn.arange"() {start = 0 : i64, end = 5 : i64 step = 1 : i64} : () -> tensor<1x1x1x5xi64> + // %0: [[[[0, 1, 2, 3, 4]]]] + + %1 = "ttnn.arange"() {start = 0 : i64, end = 10 : i64, step = 2 : i64} : () -> tensor<1x1x1x5xf32> + // %1: [[[[0.0, 2.0, 4.0, 6.0, 8.0]]]] + }]; + + let arguments = (ins I64Attr:$start, + I64Attr:$end, + I64Attr:$step, + OptionalAttr:$dtype, + Optional:$device, + OptionalAttr:$memory_config); + + let results = (outs AnyRankedTensor:$result); + let hasVerifier = 1; +} + def TTNN_FullOp : TTNN_Op<"full"> { let summary = "Full op."; let description = [{ diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index 0be274b4b..5f486bac9 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -61,6 +61,16 @@ table FullOp { out: tt.target.TensorRef; } +table ArangeOp { + start: float; + end: float; + step: float; + dtype: tt.target.DataType = null; // optional + device: tt.target.DeviceRef; // optional + memcfg: tt.target.MemoryConfigDesc; // optional + out: tt.target.TensorRef; +} + enum EltwiseOpType: uint32 { Add = 0, Multiply = 1, @@ -269,6 +279,7 @@ union OpType { MaxPool2dOp, DeallocateOp, AllGatherOp, + ArangeOp, } table Operation { diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 28bf4f71d..8db1b44e6 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -1201,6 +1201,36 @@ class StableHLOToTTIRGatherOpConversionPattern } }; +template +class StableHLOToTTIROpIotaOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(SrcIotaOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + RankedTensorType outputType = mlir::cast( + this->getTypeConverter()->convertType(srcOp.getResult().getType())); + rewriter.replaceOpWithNewOp( + srcOp, outputType, 0, outputType.getDimSize(adaptor.getIotaDimension()), + 1, adaptor.getIotaDimension()); + + // Dynamic Iota has an output_shape attribute but the output shape is + // already known by the result type This is to remove the operand that will + // become dead code + for (auto operand : adaptor.getOperands()) { + if (operand.getDefiningOp()) { + rewriter.eraseOp(operand.getDefiningOp()); + } + } + + return success(); + } +}; + void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { @@ -1365,6 +1395,15 @@ void addGatherOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, patterns.add(typeConverter, ctx); } +void addIotaOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>( + typeConverter, ctx); + patterns + .add>( + typeConverter, ctx); +} + } // namespace namespace mlir::tt { @@ -1389,6 +1428,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx, addSliceOpConversionPattern(ctx, patterns, typeConverter); addClampOpConversionPattern(ctx, patterns, typeConverter); addGatherOpConversionPattern(ctx, patterns, typeConverter); + addIotaOpConversionPattern(ctx, patterns, typeConverter); } } // namespace mlir::tt diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index 9c5afd41e..ed7eb0be8 100644 
--- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -897,6 +897,143 @@ struct SelectToSliceConversionPattern } }; +/* + * This pattern rewrites ArangeOp by forcing the arange_dimension to be the + * rightmost dimension of the output tensor. This is done by replacing the + * ArangeOp with a new one that has this property, then transposing the last + * dimension to the dimension specified by the original ArangeOp, and also + * inserting a reshape to match the rank of the intended output as well as + * broadcasts to repeat the data along the other dimensions. + * + * The ArangeOp generated here will be equivalent to how ttnn::ArangeOp + * behaves. The reason this is done in TTIR, rather than during lowering to + * TTNN, is that in the future we will want to consteval the ArangeOp, but + * keep the option to not include repeated data in the constant tensor and + * broadcast at runtime instead. Consteval will be implemented for the TTIR + * dialect only, so making the TMs implicit in ArangeOp explicit must happen + * in TTIR. + */
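+/* + * A sketch of the rewrite (attribute lists abbreviated; shapes taken from the + * arange_decomposition.mlir test added below): + * %0 = "ttir.arange"() {start = 0, end = 32, step = 1, arange_dimension = 1} : () -> tensor<1x32x128x128xf32> + * becomes, roughly: + * %a = "ttir.arange"() {..., arange_dimension = 3} : () -> tensor<1x1x1x32xf32> + * %t = "ttir.transpose"(%a, ...) : ... -> tensor<1x32x1x1xf32> + * %b = "ttir.broadcast"(%t, ...) : ... -> tensor<1x32x128x128xf32> + * No reshape is needed in this case since the ranks already match. + */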
+struct ArangeForceLastDimensionPattern + : public OpConversionPattern<ttir::ArangeOp> { +public: + using OpConversionPattern<ttir::ArangeOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::ArangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + const RankedTensorType outputType = + mlir::cast<RankedTensorType>(op.getResult().getType()); + + int64_t arangeDimension = adaptor.getArangeDimension(); + int64_t arangeDimensionNegative = arangeDimension - outputType.getRank(); + int64_t start = adaptor.getStart(); + int64_t end = adaptor.getEnd(); + int64_t step = adaptor.getStep(); + + int64_t arangeLength = (end - start) / step; + + ArrayRef<int64_t> ttnnShape = {1, 1, 1, arangeLength}; + if (ttnnShape == outputType.getShape()) { + return success(); + } + + RankedTensorType arangeOutputType = RankedTensorType::get( + SmallVector<int64_t>({1, 1, 1, arangeLength}), + outputType.getElementType(), outputType.getEncoding()); + + Value output = + rewriter + .create<ttir::ArangeOp>( // perform arange on the last dimension to + // match how ttnn behaves + op.getLoc(), arangeOutputType, start, end, step, 3) + .getResult(); + + std::vector<int64_t> outputShape = arangeOutputType.getShape().vec(); + // Must transpose the output so that the data changes along the axis defined + // by arangeDimension + if (arangeDimensionNegative != -1) { + std::vector<int64_t> transposeShape = outputShape; + transposeShape[arangeDimensionNegative + transposeShape.size()] = + arangeLength; + transposeShape[arangeOutputType.getRank() - 1] = 1; + RankedTensorType transposeType = RankedTensorType::get( + transposeShape, arangeOutputType.getElementType(), + arangeOutputType.getEncoding()); + + tensor::EmptyOp dpsOutput = rewriter.create<tensor::EmptyOp>( + op.getLoc(), transposeShape, transposeType.getElementType()); + + output = rewriter.create<ttir::TransposeOp>( + op.getLoc(), transposeType, output, dpsOutput, + arangeDimensionNegative + transposeShape.size(), + arangeOutputType.getRank() - 1, + rewriter.getArrayAttr(SmallVector<Attribute>( + 2, rewriter.getAttr<OperandConstraintAttr>( + OperandConstraint::AnyDeviceTile)))); + + outputShape = transposeShape; + } + + // Must match up the rank of the output with the rank of the intended output + // from the original arange, with the arangeDimension in the correct + // position + if (outputType.getRank() != static_cast<int64_t>(outputShape.size())) { + std::vector<int32_t> reshapeShape; + for (uint32_t i = 0; i < outputType.getRank(); i++) { + i == arangeDimension ? reshapeShape.push_back(arangeLength) + : reshapeShape.push_back(1); + } + + RankedTensorType reshapeType = RankedTensorType::get( + SmallVector<int64_t>(reshapeShape.begin(), reshapeShape.end()), + outputType.getElementType(), outputType.getEncoding()); + tensor::EmptyOp dpsOutput = rewriter.create<tensor::EmptyOp>( + op.getLoc(), + SmallVector<int64_t>(reshapeShape.begin(), reshapeShape.end()), + reshapeType.getElementType()); + output = rewriter.create<ttir::ReshapeOp>( + op.getLoc(), reshapeType, output, dpsOutput, + rewriter.getI32ArrayAttr(reshapeShape), + rewriter.getArrayAttr(SmallVector<Attribute>( + 2, rewriter.getAttr<OperandConstraintAttr>( + OperandConstraint::AnyDeviceTile)))); + + outputShape = + std::vector<int64_t>(reshapeShape.begin(), reshapeShape.end()); + } + + // Must broadcast the rest of the dimensions + SmallVector<Attribute> broadcastDims; + for (uint32_t i = 0; i < outputShape.size(); i++) { + if (i != arangeDimension && outputShape[i] != outputType.getShape()[i]) { + outputShape[i] = outputType.getShape()[i]; + broadcastDims.push_back(rewriter.getI64IntegerAttr(i)); + } + } + if (!broadcastDims.empty()) { + RankedTensorType broadcastType = RankedTensorType::get( + outputShape, outputType.getElementType(), outputType.getEncoding()); + + tensor::EmptyOp dpsOutput = rewriter.create<tensor::EmptyOp>( + op.getLoc(), outputShape, outputType.getElementType()); + + output = rewriter.create<ttir::BroadcastOp>( + op.getLoc(), broadcastType, output, dpsOutput, + rewriter.getArrayAttr(broadcastDims), + rewriter.getArrayAttr(SmallVector<Attribute>( + 2, rewriter.getAttr<OperandConstraintAttr>( + OperandConstraint::AnyDeviceTile)))); + + assert(mlir::cast<RankedTensorType>(output.getType()).getShape() == + outputType.getShape() && + "Output shape must match the shape of the original arange output"); + } + rewriter.replaceOp(op, output); + return success(); + } +}; + void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { @@ -906,6 +1043,7 @@ void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add<ArangeForceLastDimensionPattern>(typeConverter, ctx); } } // namespace mlir::tt diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp index d91084f59..e244eea8f 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp @@ -53,6 +53,14 @@ struct TTIRToTTIRDecompositionPass target.addIllegalOp(); target.addIllegalOp(); + // These are the ops that must satisfy some conditions after this pass + target.addDynamicallyLegalOp<ttir::ArangeOp>([&](ttir::ArangeOp op) { + auto shape = op.getResult().getType().getShape(); + return (static_cast<int64_t>(op.getArangeDimension()) == 3 && + shape.size() == 4 && shape[0] == 1 && shape[1] == 1 && + shape[2] == 1); + }); + TypeConverter typeConverter; // All types map 1:1.
typeConverter.addConversion([](Type type) { return type; }); diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 52995b64c..9dbc9cf97 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -920,6 +920,47 @@ class AllGatherOpConversionPattern } }; +class ArangeOpConversionPattern : public OpConversionPattern<ttir::ArangeOp> { +public: + using OpConversionPattern<ttir::ArangeOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::ArangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + RankedTensorType outputType = + mlir::cast<RankedTensorType>(op.getResult().getType()); + assert(static_cast<int64_t>(adaptor.getArangeDimension()) == + outputType.getRank() - 1 && + "Arange dimension must be the final dimension of the output tensor " + "to convert to ttnn.arange"); + + // Get ttnn::TTNNLayoutAttr of the result type + // + ttnn::TTNNLayoutAttr layoutAttr = + mlir::cast<ttnn::TTNNLayoutAttr>(outputType.getEncoding()); + + DataTypeAttr dtypeAttr = rewriter.getAttr<DataTypeAttr>( + elementTypeToDataType(outputType.getElementType())); + Value device = getOrInsertDevice(rewriter, op); + + ttnn::MemoryConfigAttr memConfigAttr = + rewriter.getAttr<ttnn::MemoryConfigAttr>( + rewriter.getAttr<ttnn::TensorMemoryLayoutAttr>( + layoutAttr.getMemLayout()), + rewriter.getAttr<ttnn::BufferTypeAttr>(layoutAttr.getBufferType()), + rewriter.getAttr<ttnn::ShardSpecAttr>( + rewriter.getAttr<ttnn::ShapeAttr>( + layoutAttr.getMemref().getShape()))); + + rewriter.replaceOpWithNewOp<ttnn::ArangeOp>( + op, outputType, adaptor.getStart(), adaptor.getEnd(), adaptor.getStep(), + dtypeAttr, device, memConfigAttr); + + return success(); + } +}; + } // namespace namespace mlir::tt { @@ -988,7 +1029,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, Conv2dOpConversionPattern, MaxPool2dOpConversionPattern, SubtractOpConversionPattern, - AllGatherOpConversionPattern + AllGatherOpConversionPattern, + ArangeOpConversionPattern >(typeConverter, ctx); // ANCHOR_END: op_rewriter_pattern_set // clang-format on diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index 6c83200f3..c5ab71b23 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -668,8 +668,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, // Tensor ops // patterns - .add>( - typeConverter, ctx); + .add, + DefaultOpConversionPattern<ttnn::ArangeOp>>(typeConverter, ctx); // Eltwise unary ops // diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index bf734df95..3cd28626a 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -45,6 +45,37 @@ ::mlir::LogicalResult mlir::tt::ttir::ClampOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ArangeOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::tt::ttir::ArangeOp::verify() { + int64_t start = getStart(); + int64_t end = getEnd(); + int64_t step = getStep(); + + if (step == 0) { + return emitOpError("Step value cannot be zero"); + } + + int64_t numValues = (end - start) / step; + + if (numValues <= 0) { + return emitOpError() << "Invalid range: start=" << start << ", end=" << end + << ", step=" << step; + } + + if (numValues != getType().getDimSize(getArangeDimension())) { + return emitOpError() << "Output tensor shape must be " << numValues + << " at dim " << getArangeDimension() + << " (since start=" << start << ", end=" << end + << ", step=" << step << "), but got " + <<
getType().getDimSize(getArangeDimension()); + } + + return success(); +} + //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index c4f0d7394..b3201cf67 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -140,6 +140,32 @@ ::mlir::LogicalResult mlir::tt::ttnn::MaxPool2dOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ArangeOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::tt::ttnn::ArangeOp::verify() { + + if (getStep() == 0) { + return emitOpError("Step cannot be zero."); + } + + int64_t numValues = (getEnd() - getStart()) / getStep(); + + if (numValues <= 0) { + return emitOpError("Invalid range: start=") + << getStart() << ", end=" << getEnd() << ", step=" << getStep(); + } + + std::vector<int64_t> expectedShape = {1, 1, 1, numValues}; + if (getType().getShape().vec() != expectedShape) { + return emitOpError() << "Output tensor shape must be " << expectedShape + << ", but got " << getType().getShape(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // EmptyOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp index eebfdc13f..2d4a2ff8f 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp @@ -214,6 +214,28 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, .getResult(); } + // If the input tensor is an arange, we want to set the desired layout just + // like the other creation ops. However, a caveat is that in ttnn, arange is + // hardcoded to be ROW_MAJOR. So we must ensure that the layout we assign to + // it is ROW_MAJOR - and to get a tile layout we must still insert a + // ToLayoutOp on its output. We can do this by setting the element type to + // ty.getElementType() in case desiredElementType is a TileType.
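+ // In other words, the arange itself is always created row-major with the + // scalar element type, and the ToLayoutOp inserted on its output is what + // performs the conversion to the desired (possibly tiled) layout.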
+ ttir::ArangeOp existingArange = input.getDefiningOp<ttir::ArangeOp>(); + if (existingArange) { + TTNNLayoutAttr arangeLayout = rewriter.getAttr<TTNNLayoutAttr>( + ty.getShape(), ty.getElementType(), desiredBufferType, + tensorConfig.getGrid(), desiredMemLayout, g_defaultCollapseDims); + input = + rewriter + .replaceOpWithNewOp<ttir::ArangeOp>( + existingArange, + mlir::RankedTensorType::get(ty.getShape(), ty.getElementType(), + arangeLayout), + existingArange.getStart(), existingArange.getEnd(), + existingArange.getStep(), existingArange.getArangeDimension()) + .getResult(); + } + // If the input tensor is not a constant or empty tensor, we need to create a // new tensor with the desired layout which will be used as the output of the // ToLayoutOp @@ -281,6 +303,13 @@ class TTNNLayoutDPSOperandsRewriter continue; } + // If the operand is a BroadcastOp or a ToLayoutOp, do not put a + // ToLayoutOp on its output + if (operand.get().getDefiningOp<ttir::BroadcastOp>() || + operand.get().getDefiningOp<ttir::ToLayoutOp>()) { + continue; + } + // Read operand constraint for current operand OperandConstraint operandConstraint = mlir::cast<OperandConstraintAttr>( diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 8971963f2..5677ce94b 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -28,6 +28,7 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Support/LogicalResult.h" +#include "types_generated.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -333,6 +334,31 @@ createOp(FlatbufferObjectCache &cache, FullOp op) { kHostAllocatedSize)); } +::flatbuffers::Offset<::tt::target::ttnn::ArangeOp> +createOp(FlatbufferObjectCache &cache, ArangeOp op) { + + std::optional<::tt::target::DataType> dtype = + op.getDtype().has_value() + ? std::make_optional(toFlatbuffer(cache, op.getDtype().value())) + : std::nullopt; + auto device = + op.getDevice() ? cache.at<::tt::target::DeviceRef>(op.getDevice()) : 0; + + auto memoryConfigDesc = op.getMemoryConfig().has_value() + ?
cache.getOrCreate(op.getMemoryConfig().value(), + memoryConfigToFlatbuffer) + : 0; + + auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, + kHostAllocatedAddress, kHostAllocatedSize); + + return ::tt::target::ttnn::CreateArangeOp( + *cache.fbb, static_cast<float>(op.getStart()), + static_cast<float>(op.getEnd()), static_cast<float>(op.getStep()), + dtype /* optional */, device /* optional */, + memoryConfigDesc /* optional */, output); +} + ::flatbuffers::Offset<::tt::target::ttnn::LinearOp> createOp(FlatbufferObjectCache &cache, LinearOp op) { auto in0 = @@ -887,6 +913,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, if (auto geluOp = dyn_cast<GeluOp>(op); geluOp) { return createOperation(cache, createEltwiseOp(cache, geluOp), debugString); } + if (auto arangeOp = dyn_cast<ArangeOp>(op); arangeOp) { + return createOperation(cache, createOp(cache, arangeOp), debugString); + } llvm_unreachable("unhandled op in emitTTNNOperation"); } diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h index ca50ad58b..75b22d114 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h @@ -6,6 +6,7 @@ #define TT_RUNTIME_TTNN_UTILS_H #include "flatbuffers/vector.h" +#include "tt_metal/impl/buffers/buffer.hpp" #include "ttmlir/Target/Common/types_generated.h" #include "ttmlir/Target/TTNN/Target.h" #include "ttnn/types.hpp" diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt index 4edc4780b..38115803f 100644 --- a/runtime/lib/ttnn/operations/CMakeLists.txt +++ b/runtime/lib/ttnn/operations/CMakeLists.txt @@ -5,6 +5,7 @@ set(TTNN_OPS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/creation/arange.cpp ${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp ${CMAKE_CURRENT_SOURCE_DIR}/creation/full.cpp ${CMAKE_CURRENT_SOURCE_DIR}/data_movement/concat.cpp diff --git a/runtime/lib/ttnn/operations/creation/arange.cpp b/runtime/lib/ttnn/operations/creation/arange.cpp new file mode 100644 index 000000000..446cdf72a --- /dev/null +++ b/runtime/lib/ttnn/operations/creation/arange.cpp @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "arange.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" +#include <functional> +#include <optional> +#include <variant> + +namespace tt::runtime::ttnn::operations::creation { +void run(const ::tt::target::ttnn::ArangeOp *op, ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); + ::ttnn::DataType dtype = + ::ttnn::DataType::BFLOAT16; // Default in arange implementation + std::optional<std::reference_wrapper<::ttnn::Device>> device = std::nullopt; + ::ttnn::MemoryConfig memoryConfig = + ::ttnn::DRAM_MEMORY_CONFIG; // Default in arange implementation + + if (op->dtype()) { + dtype = ::tt::runtime::ttnn::utils::toTTNNDataType(*(op->dtype())); + } + + if (op->memcfg()) { + memoryConfig = utils::createMemoryConfig(op->memcfg(), op->out()); + } + + if (op->device()) { + // ttnn::arange supports no device (host) and single device + DeviceVariant targetDevice = + context.getTargetDevice(op->device()->global_id()); + + LOG_ASSERT(std::holds_alternative<std::reference_wrapper<::ttnn::Device>>( + targetDevice), + "ttnn::arange does not support MeshDevice."); + device = std::make_optional(
std::get<std::reference_wrapper<::ttnn::Device>>(targetDevice)); + } + ::ttnn::Tensor out = ::ttnn::arange(op->start(), op->end(), op->step(), dtype, + device, memoryConfig); + + utils::updateTensorPool(tensorPool, out, op->out()->global_id()); +} +} // namespace tt::runtime::ttnn::operations::creation diff --git a/runtime/lib/ttnn/operations/creation/arange.h b/runtime/lib/ttnn/operations/creation/arange.h new file mode 100644 index 000000000..157ee2dc6 --- /dev/null +++ b/runtime/lib/ttnn/operations/creation/arange.h @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef RUNTIME_LIB_TTNN_OPERATIONS_CREATION_ARANGE_H +#define RUNTIME_LIB_TTNN_OPERATIONS_CREATION_ARANGE_H + +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::creation { + +void run(const ::tt::target::ttnn::ArangeOp *op, ProgramContext &context); + +} // namespace tt::runtime::ttnn::operations::creation + +#endif diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index fbd58c593..48b0be7ff 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -4,6 +4,7 @@ #include "operations/ccl/all_gather.h" #include "operations/context/get_device.h" #include "operations/conv/conv2d.h" +#include "operations/creation/arange.h" #include "operations/creation/empty.h" #include "operations/creation/full.h" #include "operations/data_movement/concat.h" @@ -189,6 +190,9 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { case ::tt::target::ttnn::OpType::AllGatherOp: { return operations::ccl::run(op->type_as_AllGatherOp(), context); } + case ::tt::target::ttnn::OpType::ArangeOp: { + return operations::creation::run(op->type_as_ArangeOp(), context); + } default: { LOG_FATAL("Unsupported operation type"); } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/dynamic_iota_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/dynamic_iota_op.mlir new file mode 100644 index 000000000..43241ac6f --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/dynamic_iota_op.mlir @@ -0,0 +1,11 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module @jit_dynamic_iota attributes {} { + func.func public @test_dynamic_iota() -> tensor<1x32x128x128xf32> { + // CHECK: %[[C:.*]] = "ttir.arange"[[C:.*]] + %output_shape = stablehlo.constant dense<[1, 32, 128, 128]> : tensor<4xi64> + %0 = "stablehlo.dynamic_iota"(%output_shape) {iota_dimension = 1: i64} : (tensor<4xi64>) -> tensor<1x32x128x128xf32> + return %0 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/iota_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/iota_op.mlir new file mode 100644 index 000000000..857a621bb --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/iota_op.mlir @@ -0,0 +1,10 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module @jit_iota attributes {} { + func.func public @test_iota() -> tensor<1x32x128x128xf32> { + // CHECK: %[[C:.*]] = "ttir.arange"[[C:.*]] + %0 = "stablehlo.iota"() {iota_dimension = 1: i64} : () -> tensor<1x32x128x128xf32> + return %0 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/Decomposition/arange_decomposition.mlir b/test/ttmlir/Dialect/TTIR/Decomposition/arange_decomposition.mlir new file mode 100644 index 000000000..6f72e56f1 --- /dev/null +++
b/test/ttmlir/Dialect/TTIR/Decomposition/arange_decomposition.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttir-decomposition %s | FileCheck %s +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> { + // CHECK: %[[C:.*]] = "ttir.arange"[[C:.*]] + // CHECK: %[[C:.*]] = "ttir.transpose"[[C:.*]] + // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]] + %1 = "ttir.arange"() <{start = 0: si64, end = 32: si64, step = 1: si64, arange_dimension = 1: i64}> : () -> tensor<1x32x128x128xf32> + return %1 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/decompositions/select_decomposition_tests.mlir b/test/ttmlir/Dialect/TTIR/Decomposition/select_decomposition_tests.mlir similarity index 100% rename from test/ttmlir/Dialect/TTIR/decompositions/select_decomposition_tests.mlir rename to test/ttmlir/Dialect/TTIR/Decomposition/select_decomposition_tests.mlir diff --git a/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir b/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir new file mode 100644 index 000000000..dc3f09fba --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir @@ -0,0 +1,12 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for arange operation +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> { + // CHECK: error: 'ttir.arange' op Output tensor shape must be 16 at dim 1 (since start=0, end=32, step=2), but got 32 + %1 = "ttir.arange"() <{start = 0: si64, end = 32: si64, step = 2: si64, arange_dimension = 1: i64}> : () -> tensor<1x32x128x128xf32> + %dps = tensor.empty() : tensor<1x32x128x128xf32> + %2 = "ttir.multiply"(%arg0, %1, %dps) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> + return %2 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir new file mode 100644 index 000000000..4c04e138b --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.arange"[[C:.*]] + %1 = "ttir.arange"() <{start = 0: si64, end = 32: si64, step = 1: si64, arange_dimension = 1: i64}> : () -> tensor<1x32x128x128xf32> + %dps = tensor.empty() : tensor<1x32x128x128xf32> + %2 = "ttir.multiply"(%arg0, %1, %dps) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> + return %2 : tensor<1x32x128x128xf32> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim2.mlir b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim2.mlir new file mode 100644 index 000000000..d911ec6fe --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim2.mlir @@ -0,0 +1,15 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +//
RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: ttnn.arange + %0 = "stablehlo.iota"() {iota_dimension = 2: i64} : () -> tensor<1x1x32x128xbf16> + %2 = "stablehlo.multiply"(%arg0, %0) : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim3.mlir b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim3.mlir new file mode 100644 index 000000000..01aa0e91b --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_dynamic_iota_dim3.mlir @@ -0,0 +1,16 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + %output_shape = stablehlo.constant dense<[1, 1, 32, 128]> : tensor<4xi64> + // CHECK: ttnn.arange + %0 = "stablehlo.dynamic_iota"(%output_shape) {iota_dimension = 3: i64} : (tensor<4xi64>) -> tensor<1x1x32x128xbf16> + %2 = "stablehlo.multiply"(%arg0, %0) : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim2.mlir b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim2.mlir new file mode 100644 index 000000000..d911ec6fe --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim2.mlir @@ -0,0 +1,15 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: ttnn.arange + %0 = "stablehlo.iota"() {iota_dimension = 2: i64} : () -> tensor<1x1x32x128xbf16> + %2 = "stablehlo.multiply"(%arg0, %0) : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim3.mlir b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim3.mlir new file mode 100644 index 000000000..a231432ab --- /dev/null +++ b/test/ttmlir/Silicon/StableHLO/Iota/simple_device_iota_dim3.mlir @@ -0,0 +1,15 @@ +// REQUIRES: stablehlo +// RUN: rm -rf %t.ttnn +// RUN: rm -rf %t.mlir +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +// RUN: FileCheck --input-file=%t.mlir %s +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: ttnn.arange + %0 = "stablehlo.iota"() {iota_dimension = 3: i64} : () ->
tensor<1x1x32x128xbf16> + %2 = "stablehlo.multiply"(%arg0, %0) : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir new file mode 100644 index 000000000..ec509a1b6 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir @@ -0,0 +1,13 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: %[[C:.*]] = "ttnn.arange"[[C:.*]] + %0 = "ttir.arange"() <{start = 0: si64, end = 64: si64, step = 2: si64, arange_dimension = 2: i64}> : () -> tensor<1x1x32x128xbf16> + %1 = tensor.empty() : tensor<1x1x32x128xbf16> + %2 = "ttir.multiply"(%arg0, %0, %1) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir new file mode 100644 index 000000000..196e75709 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir @@ -0,0 +1,13 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { + // CHECK: %[[C:.*]] = "ttnn.arange"[[C:.*]] + %0 = "ttir.arange"() <{start = 0: si64, end = 128: si64, step = 1: si64, arange_dimension = 3: i64}> : () -> tensor<1x1x32x128xbf16> + %1 = tensor.empty() : tensor<1x1x32x128xbf16> + %2 = "ttir.multiply"(%arg0, %0, %1) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + return %2 : tensor<1x1x32x128xbf16> + } +} From c908d529c2d435aeaa1ef96a3bd3288fa0f736d5 Mon Sep 17 00:00:00 2001 From: Collin Tod Date: Fri, 22 Nov 2024 14:02:17 -0600 Subject: [PATCH 3/3] Ignore `*.ttnn` & `*.ttm` Files (#1365) These flatbuffer files are generated as part of `test_infra`, and should not be committed. --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index b20627983..274c39c1f 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,7 @@ query_results.json run_results.json ttrt_report.xml cluster_descriptor.yaml + +# TTNN and TTMetal flatbuffers +*.ttnn +*.ttm