From 133d6e17dc94d6b5269d1da840cbc124b6c1f3fc Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Fri, 29 Nov 2024 13:28:34 +0000 Subject: [PATCH] Added support for scatter op --- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 34 ++++ include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | 7 + include/ttmlir/Target/TTNN/program.fbs | 1 + .../StableHLOToTTIRPatterns.cpp | 145 ++++++++++++++++++ lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 18 ++- lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp | 1 + lib/Dialect/TTIR/IR/TTIROps.cpp | 62 ++++++++ lib/Dialect/TTNN/IR/TTNNOps.cpp | 8 + lib/Target/TTNN/TTNNToFlatbuffer.cpp | 6 + .../eltwise/binary/binary_composite.cpp | 4 + .../eltwise/binary/binary_composite.h | 1 + .../ttnn/operations/eltwise/binary/utils.cpp | 12 +- .../StableHLOToTTIR/scatter_op.mlir | 16 ++ test/ttmlir/Dialect/TTNN/simple_scatter.mlir | 16 ++ test/ttmlir/Silicon/TTNN/simple_eltwise.mlir | 10 ++ 15 files changed, 338 insertions(+), 3 deletions(-) create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir create mode 100644 test/ttmlir/Dialect/TTNN/simple_scatter.mlir diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index f5e284078..8908e470e 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -1244,6 +1244,40 @@ def TTIR_DivOp : TTIR_GenericElementwiseBinaryOp<"div"> { }]; } +def TTIR_ScatterOp: TTIR_DPSOp<"scatter"> { + let summary = "Scatter operation"; + let description = [{ + Produces a 'result' tensor which are equal to `input` tensor except that + several slices specified by `scatter_indices` are updated with the values + `updates`. + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$scatter_indices, + AnyRankedTensor:$update, + DenseI32ArrayAttr:$update_window_dims, + DenseI32ArrayAttr:$inserted_window_dims, + DenseI32ArrayAttr:$input_batching_dims, + DenseI32ArrayAttr:$scatter_indices_batching_dims, + DenseI32ArrayAttr:$scatter_dims_to_operand_dims, + I32Attr:$index_vector_dim, + BoolAttr:$indices_are_sorted, + BoolAttr:$unique_indices, + AnyRankedTensor:$output, + TT_OperandConstraintArrayAttr:$operand_constraints); + + let regions = (region SizedRegion<1>:$update_computation); + + let results = (outs AnyRankedTensor:$result); + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + +} + //===----------------------------------------------------------------------===// // TTIR region ops (ops that may appear inside of ttir.generic region) //===----------------------------------------------------------------------===// diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 21eb704cf..57383c007 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -860,6 +860,13 @@ def TTNN_AllGatherOp: TTNN_Op<"all_gather"> { let hasVerifier = 1; } +def TTNN_ScatterOp: TTNN_ElementwiseBinaryOp<"scatter"> { + let summary = "Scatter op."; + let description = [{ + Embeds the values of the 'update' tensor into 'input' at the given index and puts the value in the 'output' tensor. + }]; +} + def TTNN_ReduceScatterOp: TTNN_Op<"reduce_scatter"> { let summary = "Reduce scatter op."; let description = [{ diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index 39535e2f0..f145aaf65 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -112,6 +112,7 @@ enum EltwiseOpType: uint32 { LogicalXor, Clamp, LeakyRelu, + Scatter } table ClampOpParams { diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index ccf21ff27..d81b6e214 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -1666,6 +1666,137 @@ class StableHLOToTTIROpIotaOpConversionPattern } }; +class StableHLOToTTIRScatterOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::ScatterOp srcOp, + mlir::stablehlo::ScatterOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto outputType = mlir::cast( + this->getTypeConverter()->convertType(srcOp.getResults()[0].getType())); + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + Value operand = srcOp.getInputs()[0]; + Value scatterIndices = srcOp.getScatterIndices(); + Value update = srcOp.getUpdates()[0]; + mlir::ArrayAttr binaryConstraints = rewriter.getArrayAttr( + SmallVector(4, rewriter.getAttr( + OperandConstraint::AnyDeviceTile))); + auto updateWindowsDims = + adaptor.getScatterDimensionNumbers().getUpdateWindowDims(); + auto insertedWindowDims = + adaptor.getScatterDimensionNumbers().getInsertedWindowDims(); + auto inputBatchingDims = + adaptor.getScatterDimensionNumbers().getInputBatchingDims(); + auto scatterIndicesBatchingDims = + adaptor.getScatterDimensionNumbers().getScatterIndicesBatchingDims(); + auto scatterDimsToOperandDims = + adaptor.getScatterDimensionNumbers().getScatterDimsToOperandDims(); + auto indexVectorDim = + adaptor.getScatterDimensionNumbers().getIndexVectorDim(); + auto indicesAreSorted = adaptor.getIndicesAreSorted(); + auto uniqueIndices = adaptor.getUniqueIndices(); + + auto newScatterOp = rewriter.create( + srcOp.getLoc(), outputType, operand, scatterIndices, update, + llvm::ArrayRef( + convertArrayRefToInt32vector(updateWindowsDims)), + llvm::ArrayRef( + convertArrayRefToInt32vector(insertedWindowDims)), + llvm::ArrayRef( + convertArrayRefToInt32vector(inputBatchingDims)), + llvm::ArrayRef( + convertArrayRefToInt32vector(scatterIndicesBatchingDims)), + llvm::ArrayRef( + convertArrayRefToInt32vector(scatterDimsToOperandDims)), + indexVectorDim, indicesAreSorted, uniqueIndices, outputTensor, + binaryConstraints); + + // Replaces with different types do not work and will fail silently, so we + // manually set the second operand, since the type changes there from i32 to + // i64. + newScatterOp.setOperand( + 1, adaptor.getScatterIndices().getDefiningOp()->getResult(0)); + + newScatterOp->getRegion(0).takeBody(adaptor.getUpdateComputation()); + changeRegionTypes(newScatterOp->getRegion(0), *getTypeConverter(), + rewriter); + + rewriter.replaceOp(srcOp, newScatterOp); + + return success(); + } + +private: + std::vector + convertArrayRefToInt32vector(const llvm::ArrayRef &source) const { + std::vector converted; + converted.reserve(source.size()); + + for (int64_t value : source) { + converted.push_back(static_cast(value)); + } + + return converted; + } + + void changeRegionTypes(mlir::Region ®ion, + const mlir::TypeConverter &typeConverter, + mlir::PatternRewriter &rewriter) const { + Block &block = *region.getBlocks().begin(); + llvm::SmallVector oldArguments( + block.getArguments().begin(), block.getArguments().end()); + llvm::SmallVector newArguments; + + // Add new arguments with updated types to the block. + for (auto arg : oldArguments) { + if (auto newType = typeConverter.convertType(arg.getType())) { + mlir::BlockArgument newArg = block.addArgument(newType, arg.getLoc()); + newArguments.push_back(newArg); + } else { + newArguments.push_back(arg); // Type didn't change + } + } + + for (auto it : llvm::zip(oldArguments, newArguments)) { + mlir::BlockArgument oldArg = std::get<0>(it); + mlir::Value newArg = std::get<1>(it); + if (oldArg != newArg) { + oldArg.replaceAllUsesWith(newArg); + } + } + + for (auto arg : oldArguments) { + if (!llvm::is_contained(newArguments, arg)) { + block.eraseArgument(arg.getArgNumber()); + } + } + } +}; + +class StableHLOToTTIRReturnOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::ReturnOp srcOp, + mlir::stablehlo::ReturnOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp(srcOp, + srcOp.getResults()); + + return success(); + } +}; + void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { @@ -1846,6 +1977,18 @@ void addIotaOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, typeConverter, ctx); } +void addScatterOpConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); +} + +void addReturnOpConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); +} + } // namespace namespace mlir::tt { @@ -1872,6 +2015,8 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx, addClampOpConversionPattern(ctx, patterns, typeConverter); addGatherOpConversionPattern(ctx, patterns, typeConverter); addIotaOpConversionPattern(ctx, patterns, typeConverter); + addScatterOpConversionPatterns(ctx, patterns, typeConverter); + addReturnOpConversionPatterns(ctx, patterns, typeConverter); } } // namespace mlir::tt diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 3241928f4..18efb982e 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -953,6 +953,21 @@ class ArangeOpConversionPattern : public OpConversionPattern { } }; +class ScatterOpConversionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttir::ScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The ttnn interface has the inverse inputs of the TTIR dialect op (which + // matches torch ops). + rewriter.replaceOpWithNewOp( + op, adaptor.getUpdate(), adaptor.getInput(), adaptor.getOutput()); + + return success(); + } +}; } // namespace namespace mlir::tt { @@ -1022,7 +1037,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, MaxPool2dOpConversionPattern, SubtractOpConversionPattern, AllGatherOpConversionPattern, - ArangeOpConversionPattern + ArangeOpConversionPattern, + ScatterOpConversionPattern >(typeConverter, ctx); // ANCHOR_END: op_rewriter_pattern_set // clang-format on diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index c5ab71b23..f04d5566b 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -713,6 +713,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, DefaultOpConversionPattern, DefaultOpConversionPattern, DefaultOpConversionPattern, + DefaultOpConversionPattern, DefaultOpConversionPattern>(typeConverter, ctx); diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 11cfbb8fb..aacb2a43d 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -1323,6 +1323,68 @@ ::mlir::LogicalResult mlir::tt::ttir::MeshShardOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ScatterOp +//===----------------------------------------------------------------------===// + +bool matchSimpleBlock(mlir::Region ®ion) { + if (!region.hasOneBlock()) { + return false; + } + mlir::Block &block = region.front(); + if (block.getNumArguments() != 2) { + return false; + } + auto argType1 = + mlir::cast(block.getArgument(0).getType()); + auto argType2 = + mlir::cast(block.getArgument(1).getType()); + if (!argType1 || !argType2) { + return false; + } + if (block.getOperations().size() != 1) { + return false; + } + mlir::tt::ttir::YieldOp returnOp = + mlir::cast(&block.front()); + if (!returnOp) { + return false; + } + if (returnOp.getNumOperands() != 1 || + returnOp.getOperand(0) != block.getArgument(1)) { + return false; + } + return true; +} + +::mlir::LogicalResult mlir::tt::ttir::ScatterOp::verify() { + + ArrayRef inputShape = + mlir::cast(getInput().getType()).getShape(); + + if (getUpdateWindowDims().size() + getInsertedWindowDims().size() != + inputShape.size()) { + return emitOpError("Batching currently not supported"); + } + + for (uint64_t insertedWindowDims : getInsertedWindowDims()) { + if (inputShape[insertedWindowDims] != 1) { + return emitOpError("Dimension size to slice into must be 1"); + } + } + + // We currently do not support custom functions in the scatter function, + // which is a possbility in StableHLO dialect. See issue: + // https://github.com/tenstorrent/tt-mlir/issues/1278 + if (!matchSimpleBlock(getUpdateComputation())) { + return emitOpError( + "Currently not supporting custom scatter function in TTNN " + "dialect and TT-metal."); + } + + return success(); +} + //===----------------------------------------------------------------------===// // GenericOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index cd2746aad..8e41368cb 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -950,6 +950,10 @@ ::mlir::LogicalResult mlir::tt::ttnn::SoftmaxOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AllGatherOp +//===----------------------------------------------------------------------===// + ::mlir::LogicalResult AllGatherOp::verify() { ::mlir::RankedTensorType inputType = getInput().getType(); int32_t dim = getDim(); @@ -961,6 +965,10 @@ ::mlir::LogicalResult AllGatherOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ReduceScatterOp +//===----------------------------------------------------------------------===// + ::mlir::LogicalResult ReduceScatterOp::verify() { // TODO(gfengTT) return success(); diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index e7df85956..34a0c4725 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -526,6 +526,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { type = ::tt::target::ttnn::EltwiseOpType::Div; } else if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::Sigmoid; + } else if constexpr (std::is_same_v) { + type = ::tt::target::ttnn::EltwiseOpType::Scatter; } else if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::Log1p; } else if constexpr (std::is_same_v) { @@ -819,6 +821,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, if (auto log1pOp = dyn_cast(op); log1pOp) { return createOperation(cache, createEltwiseOp(cache, log1pOp), debugString); } + if (auto scatterOp = dyn_cast(op); scatterOp) { + return createOperation(cache, createEltwiseOp(cache, scatterOp), + debugString); + } if (auto reciprocalOp = dyn_cast(op); reciprocalOp) { return createOperation(cache, createEltwiseOp(cache, reciprocalOp), debugString); diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp index 2a05d6246..5c1d056f9 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp @@ -41,6 +41,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { runEltwiseBinaryCompositeOp(op, tensorPool, ::ttnn::remainder); break; } + case ::tt::target::ttnn::EltwiseOpType::Scatter: { + runEltwiseBinaryCompositeOp(op, tensorPool, ::ttnn::scatter); + break; + } default: LOG_FATAL("Unsupported Eltwise Binary Composite operation"); } diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h index 9be8bc6b7..bd497fe98 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.h @@ -15,6 +15,7 @@ inline bool isBinaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) { case ::tt::target::ttnn::EltwiseOpType::Maximum: case ::tt::target::ttnn::EltwiseOpType::Minimum: case ::tt::target::ttnn::EltwiseOpType::Remainder: + case ::tt::target::ttnn::EltwiseOpType::Scatter: return true; default: return false; diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp index a54777ab2..f97f71e40 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp @@ -7,6 +7,15 @@ namespace tt::runtime::ttnn::operations::binary { +bool shouldSwapBinaryOperands(const ::tt::target::ttnn::EltwiseOp *op, + ::ttnn::Tensor **lhs, ::ttnn::Tensor **rhs) { + // For scatter, we expect the left-hand side operator to be lesser or equal in + // volume to the right hand side, so we omit the swap. + return (op->type() != ::tt::target::ttnn::EltwiseOpType::Scatter && + workaround::Env::get().swapBinaryOperands && + (*lhs)->volume() < (*rhs)->volume()); +} + void getEltwiseBinaryOpInputTensors(const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, ::ttnn::Tensor **lhs, @@ -21,8 +30,7 @@ void getEltwiseBinaryOpInputTensors(const ::tt::target::ttnn::EltwiseOp *op, // TODO(bug #1124): We're currently swapping the operands for binary ops // in runtime if the lhs operand is smaller (and requires broadcast onto the // rhs operand). We should add this check in the compiler. - if (workaround::Env::get().swapBinaryOperands && - (*lhs)->volume() < (*rhs)->volume()) { + if (shouldSwapBinaryOperands(op, lhs, rhs)) { std::swap(*lhs, *rhs); } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir new file mode 100644 index 000000000..92cd8895f --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir @@ -0,0 +1,16 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module @jit_scatter attributes {} { + func.func public @test_scatter(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x1xi64>, %arg2: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { + // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE1:tensor<[0-9]+x[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] + %result = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) : (tensor<1x3x320x320xf32>, tensor<1x1xi64>, tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> + // CHECK: [[VAL1:%[0-9]+]] = "ttir.scatter"(%arg0, %arg1, %arg2, [[VAL0]]) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile], scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array} + // CHECK: ([[TENSOR_SIZE1]], tensor<1x1xi32>, tensor<1x3x32x32xf32>, [[TENSOR_SIZE1]]) -> tensor<1x3x320x320xf32> + return %result : tensor<1x3x320x320xf32> + // CHECK: return [[VAL1]] : [[TENSOR_SIZE1]] + } +} diff --git a/test/ttmlir/Dialect/TTNN/simple_scatter.mlir b/test/ttmlir/Dialect/TTNN/simple_scatter.mlir new file mode 100644 index 000000000..5991efeab --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/simple_scatter.mlir @@ -0,0 +1,16 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint +module attributes {} { + func.func @forward(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { + %0 = tensor.empty() : tensor<1x3x320x320xf32> + %1 = tensor.empty() : tensor<1x1xi32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, shape = #ttnn.shape<[[TENSOR_SHAPE0:[0-9]+x[0-9]+x[0-9]+x[0-9]+]]>}> : (!tt.device<#device>) -> tensor<[[TENSOR_SHAPE1:[0-9]+x[0-9]+x[0-9]+x[0-9]+xf[0-9]+]], {{.*}}> + %2 = "ttir.scatter"(%arg0, %1, %arg1, %0) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile], scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array}> ({ + ^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>): + "ttir.yield"(%arg4) : (tensor<1xf32>) -> () + }) : (tensor<1x3x320x320xf32>, tensor<1x1xi32>, tensor<1x3x32x32xf32>, tensor<1x3x320x320xf32>) -> tensor<1x3x320x320xf32> + // CHECK: {{[0-9]+}} = "ttnn.scatter"(%4, %2, %5) <{operandSegmentSizes = array}> : (tensor<1x3x32x32xf32, {{.*}}>, tensor<[[TENSOR_SHAPE1]], {{.*}}>, tensor<[[TENSOR_SHAPE1]], {{.*}}>) -> tensor<[[TENSOR_SHAPE1]], {{.*}}> + return %2 : tensor<1x3x320x320xf32> + // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE1]], {{.*}}> + } +} diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir index 976f2867d..b7912d4c1 100644 --- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir @@ -306,3 +306,13 @@ func.func @addint32(%arg0: tensor<64x128xi32>, %arg1: tensor<64x128xi32>) -> ten %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xi32>, tensor<64x128xi32>, tensor<64x128xi32>) -> tensor<64x128xi32> return %1 : tensor<64x128xi32> } + +func.func @scatter(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { + %0 = tensor.empty() : tensor<1x3x320x320xf32> + %1 = tensor.empty() : tensor<1x1xi32> + %2 = "ttir.scatter"(%arg0, %1, %arg1, %0) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile], scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array}> ({ + ^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>): + "ttir.yield"(%arg4) : (tensor<1xf32>) -> () + }) : (tensor<1x3x320x320xf32>, tensor<1x1xi32>, tensor<1x3x32x32xf32>, tensor<1x3x320x320xf32>) -> tensor<1x3x320x320xf32> + return %2 : tensor<1x3x320x320xf32> +}