Skip to content

Commit

Permalink
Added support for scatter op (#1279)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajakovljevicTT authored Dec 2, 2024
1 parent 99331c7 commit c4b3dff
Show file tree
Hide file tree
Showing 15 changed files with 338 additions and 3 deletions.
34 changes: 34 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1244,6 +1244,40 @@ def TTIR_DivOp : TTIR_GenericElementwiseBinaryOp<"div"> {
}];
}

def TTIR_ScatterOp: TTIR_DPSOp<"scatter"> {
let summary = "Scatter operation";
let description = [{
Produces a 'result' tensor which are equal to `input` tensor except that
several slices specified by `scatter_indices` are updated with the values
`updates`.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$scatter_indices,
AnyRankedTensor:$update,
DenseI32ArrayAttr:$update_window_dims,
DenseI32ArrayAttr:$inserted_window_dims,
DenseI32ArrayAttr:$input_batching_dims,
DenseI32ArrayAttr:$scatter_indices_batching_dims,
DenseI32ArrayAttr:$scatter_dims_to_operand_dims,
I32Attr:$index_vector_dim,
BoolAttr:$indices_are_sorted,
BoolAttr:$unique_indices,
AnyRankedTensor:$output,
TT_OperandConstraintArrayAttr:$operand_constraints);

let regions = (region SizedRegion<1>:$update_computation);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

}

//===----------------------------------------------------------------------===//
// TTIR region ops (ops that may appear inside of ttir.generic region)
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,13 @@ def TTNN_AllGatherOp: TTNN_Op<"all_gather"> {
let hasVerifier = 1;
}

def TTNN_ScatterOp: TTNN_ElementwiseBinaryOp<"scatter"> {
let summary = "Scatter op.";
let description = [{
Embeds the values of the 'update' tensor into 'input' at the given index and puts the value in the 'output' tensor.
}];
}

def TTNN_ReduceScatterOp: TTNN_Op<"reduce_scatter"> {
let summary = "Reduce scatter op.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ enum EltwiseOpType: uint32 {
LogicalXor,
Clamp,
LeakyRelu,
Scatter
}

table ClampOpParams {
Expand Down
145 changes: 145 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1666,6 +1666,137 @@ class StableHLOToTTIROpIotaOpConversionPattern
}
};

class StableHLOToTTIRScatterOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::ScatterOp> {

using OpConversionPattern<mlir::stablehlo::ScatterOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::ScatterOp srcOp,
mlir::stablehlo::ScatterOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto outputType = mlir::cast<RankedTensorType>(
this->getTypeConverter()->convertType(srcOp.getResults()[0].getType()));
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
Value operand = srcOp.getInputs()[0];
Value scatterIndices = srcOp.getScatterIndices();
Value update = srcOp.getUpdates()[0];
mlir::ArrayAttr binaryConstraints = rewriter.getArrayAttr(
SmallVector<Attribute>(4, rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile)));
auto updateWindowsDims =
adaptor.getScatterDimensionNumbers().getUpdateWindowDims();
auto insertedWindowDims =
adaptor.getScatterDimensionNumbers().getInsertedWindowDims();
auto inputBatchingDims =
adaptor.getScatterDimensionNumbers().getInputBatchingDims();
auto scatterIndicesBatchingDims =
adaptor.getScatterDimensionNumbers().getScatterIndicesBatchingDims();
auto scatterDimsToOperandDims =
adaptor.getScatterDimensionNumbers().getScatterDimsToOperandDims();
auto indexVectorDim =
adaptor.getScatterDimensionNumbers().getIndexVectorDim();
auto indicesAreSorted = adaptor.getIndicesAreSorted();
auto uniqueIndices = adaptor.getUniqueIndices();

auto newScatterOp = rewriter.create<mlir::tt::ttir::ScatterOp>(
srcOp.getLoc(), outputType, operand, scatterIndices, update,
llvm::ArrayRef<int32_t>(
convertArrayRefToInt32vector(updateWindowsDims)),
llvm::ArrayRef<int32_t>(
convertArrayRefToInt32vector(insertedWindowDims)),
llvm::ArrayRef<int32_t>(
convertArrayRefToInt32vector(inputBatchingDims)),
llvm::ArrayRef<int32_t>(
convertArrayRefToInt32vector(scatterIndicesBatchingDims)),
llvm::ArrayRef<int32_t>(
convertArrayRefToInt32vector(scatterDimsToOperandDims)),
indexVectorDim, indicesAreSorted, uniqueIndices, outputTensor,
binaryConstraints);

// Replaces with different types do not work and will fail silently, so we
// manually set the second operand, since the type changes there from i32 to
// i64.
newScatterOp.setOperand(
1, adaptor.getScatterIndices().getDefiningOp()->getResult(0));

newScatterOp->getRegion(0).takeBody(adaptor.getUpdateComputation());
changeRegionTypes(newScatterOp->getRegion(0), *getTypeConverter(),
rewriter);

rewriter.replaceOp(srcOp, newScatterOp);

return success();
}

private:
std::vector<int32_t>
convertArrayRefToInt32vector(const llvm::ArrayRef<int64_t> &source) const {
std::vector<int32_t> converted;
converted.reserve(source.size());

for (int64_t value : source) {
converted.push_back(static_cast<int32_t>(value));
}

return converted;
}

void changeRegionTypes(mlir::Region &region,
const mlir::TypeConverter &typeConverter,
mlir::PatternRewriter &rewriter) const {
Block &block = *region.getBlocks().begin();
llvm::SmallVector<mlir::BlockArgument, 4> oldArguments(
block.getArguments().begin(), block.getArguments().end());
llvm::SmallVector<mlir::Value, 4> newArguments;

// Add new arguments with updated types to the block.
for (auto arg : oldArguments) {
if (auto newType = typeConverter.convertType(arg.getType())) {
mlir::BlockArgument newArg = block.addArgument(newType, arg.getLoc());
newArguments.push_back(newArg);
} else {
newArguments.push_back(arg); // Type didn't change
}
}

for (auto it : llvm::zip(oldArguments, newArguments)) {
mlir::BlockArgument oldArg = std::get<0>(it);
mlir::Value newArg = std::get<1>(it);
if (oldArg != newArg) {
oldArg.replaceAllUsesWith(newArg);
}
}

for (auto arg : oldArguments) {
if (!llvm::is_contained(newArguments, arg)) {
block.eraseArgument(arg.getArgNumber());
}
}
}
};

class StableHLOToTTIRReturnOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::ReturnOp> {

using OpConversionPattern<mlir::stablehlo::ReturnOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::ReturnOp srcOp,
mlir::stablehlo::ReturnOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

rewriter.replaceOpWithNewOp<mlir::tt::ttir::YieldOp>(srcOp,
srcOp.getResults());

return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down Expand Up @@ -1846,6 +1977,18 @@ void addIotaOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
typeConverter, ctx);
}

void addScatterOpConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIRScatterOpConversionPattern>(typeConverter, ctx);
}

void addReturnOpConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIRReturnOpConversionPattern>(typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
Expand All @@ -1872,6 +2015,8 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addClampOpConversionPattern(ctx, patterns, typeConverter);
addGatherOpConversionPattern(ctx, patterns, typeConverter);
addIotaOpConversionPattern(ctx, patterns, typeConverter);
addScatterOpConversionPatterns(ctx, patterns, typeConverter);
addReturnOpConversionPatterns(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
18 changes: 17 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,21 @@ class ArangeOpConversionPattern : public OpConversionPattern<ttir::ArangeOp> {
}
};

class ScatterOpConversionPattern : public OpConversionPattern<ttir::ScatterOp> {
public:
using OpConversionPattern<ttir::ScatterOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::ScatterOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// The ttnn interface has the inverse inputs of the TTIR dialect op (which
// matches torch ops).
rewriter.replaceOpWithNewOp<ttnn::ScatterOp>(
op, adaptor.getUpdate(), adaptor.getInput(), adaptor.getOutput());

return success();
}
};
} // namespace

namespace mlir::tt {
Expand Down Expand Up @@ -1022,7 +1037,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
MaxPool2dOpConversionPattern,
SubtractOpConversionPattern,
AllGatherOpConversionPattern,
ArangeOpConversionPattern
ArangeOpConversionPattern,
ScatterOpConversionPattern
>(typeConverter, ctx);
// ANCHOR_END: op_rewriter_pattern_set
// clang-format on
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
DefaultOpConversionPattern<ttnn::MaximumOp>,
DefaultOpConversionPattern<ttnn::MinimumOp>,
DefaultOpConversionPattern<ttnn::DivOp>,
DefaultOpConversionPattern<ttnn::ScatterOp>,
DefaultOpConversionPattern<ttnn::RemainderOp>>(typeConverter,
ctx);

Expand Down
62 changes: 62 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,68 @@ ::mlir::LogicalResult mlir::tt::ttir::MeshShardOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

bool matchSimpleBlock(mlir::Region &region) {
if (!region.hasOneBlock()) {
return false;
}
mlir::Block &block = region.front();
if (block.getNumArguments() != 2) {
return false;
}
auto argType1 =
mlir::cast<mlir::RankedTensorType>(block.getArgument(0).getType());
auto argType2 =
mlir::cast<mlir::RankedTensorType>(block.getArgument(1).getType());
if (!argType1 || !argType2) {
return false;
}
if (block.getOperations().size() != 1) {
return false;
}
mlir::tt::ttir::YieldOp returnOp =
mlir::cast<mlir::tt::ttir::YieldOp>(&block.front());
if (!returnOp) {
return false;
}
if (returnOp.getNumOperands() != 1 ||
returnOp.getOperand(0) != block.getArgument(1)) {
return false;
}
return true;
}

::mlir::LogicalResult mlir::tt::ttir::ScatterOp::verify() {

ArrayRef<int64_t> inputShape =
mlir::cast<RankedTensorType>(getInput().getType()).getShape();

if (getUpdateWindowDims().size() + getInsertedWindowDims().size() !=
inputShape.size()) {
return emitOpError("Batching currently not supported");
}

for (uint64_t insertedWindowDims : getInsertedWindowDims()) {
if (inputShape[insertedWindowDims] != 1) {
return emitOpError("Dimension size to slice into must be 1");
}
}

// We currently do not support custom functions in the scatter function,
// which is a possbility in StableHLO dialect. See issue:
// https://github.com/tenstorrent/tt-mlir/issues/1278
if (!matchSimpleBlock(getUpdateComputation())) {
return emitOpError(
"Currently not supporting custom scatter function in TTNN "
"dialect and TT-metal.");
}

return success();
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 8 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,10 @@ ::mlir::LogicalResult mlir::tt::ttnn::SoftmaxOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// AllGatherOp
//===----------------------------------------------------------------------===//

::mlir::LogicalResult AllGatherOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
int32_t dim = getDim();
Expand All @@ -961,6 +965,10 @@ ::mlir::LogicalResult AllGatherOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ReduceScatterOp
//===----------------------------------------------------------------------===//

::mlir::LogicalResult ReduceScatterOp::verify() {
// TODO(gfengTT)
return success();
Expand Down
6 changes: 6 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::Div;
} else if constexpr (std::is_same_v<EltwiseOp, SigmoidOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Sigmoid;
} else if constexpr (std::is_same_v<EltwiseOp, ScatterOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Scatter;
} else if constexpr (std::is_same_v<EltwiseOp, Log1pOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Log1p;
} else if constexpr (std::is_same_v<EltwiseOp, ExpOp>) {
Expand Down Expand Up @@ -819,6 +821,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto log1pOp = dyn_cast<Log1pOp>(op); log1pOp) {
return createOperation(cache, createEltwiseOp(cache, log1pOp), debugString);
}
if (auto scatterOp = dyn_cast<ScatterOp>(op); scatterOp) {
return createOperation(cache, createEltwiseOp(cache, scatterOp),
debugString);
}
if (auto reciprocalOp = dyn_cast<ReciprocalOp>(op); reciprocalOp) {
return createOperation(cache, createEltwiseOp(cache, reciprocalOp),
debugString);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseBinaryCompositeOp(op, tensorPool, ::ttnn::remainder);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Scatter: {
runEltwiseBinaryCompositeOp(op, tensorPool, ::ttnn::scatter);
break;
}
default:
LOG_FATAL("Unsupported Eltwise Binary Composite operation");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ inline bool isBinaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
case ::tt::target::ttnn::EltwiseOpType::Maximum:
case ::tt::target::ttnn::EltwiseOpType::Minimum:
case ::tt::target::ttnn::EltwiseOpType::Remainder:
case ::tt::target::ttnn::EltwiseOpType::Scatter:
return true;
default:
return false;
Expand Down
Loading

0 comments on commit c4b3dff

Please sign in to comment.