Skip to content

Commit

Permalink
Bringup ttir.arange, ttnn.arange.
Browse files Browse the repository at this point in the history
Add conversion patterns from stablehlo.iota and stablehlo.dynamic_iota
to ttir.arange

Add pattern in TTIRToTTIRDecompositionPass to rewrite all ttir.arange
ops where the arange_dimension is not the right-most dim. This has the
effect of making-explicit the broadcasts and tms that would need to be
done after executin ttnn.arange

add runtime support and basic silicon test

Add special TTNNLayout case for ttir.arange since it is a creation op

Added decomposition test

Bringup runtime tests for arange

add stablehlo silicon tests
  • Loading branch information
LPanosTT committed Nov 20, 2024
1 parent 351a587 commit fb05a84
Show file tree
Hide file tree
Showing 27 changed files with 590 additions and 4 deletions.
42 changes: 42 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,48 @@ def TTIR_ClampOp : TTIR_DPSOp<"clamp"> {
let hasVerifier = 1;
}

def TTIR_ArangeOp : TTIR_Op<"arange"> {
let summary = "Arange operation.";
let description = [{
Tensor arange operation.

Produces a tensor with values from `start` to `end` (exclusive) with a step size of `step`, along the dimension specified by `arange_dimension`.

Examples:
%0 = "ttir.arange"() {start = 0 : i64, end = 5 : i64 step = 1 : i64, arange_dimension = 0 : i64} : () -> tensor<5xi64>
// %0: [0, 1, 2, 3, 4]

%1 = "ttir.arange"() {start = 0 : i64, end = 10 : i64, step = 2 : i64, arange_dimension = 0 : i64} : () -> tensor<5xf32>
// %1: [0.0, 2.0, 4.0, 6.0, 8.0]

%2 = "ttir.arange"() {start = 0 : i64, end = 5 : i64, step = 1 : i64, arange_dimension = 0 : i64} : () -> tensor<5x3xi64>
// %2: [
[0, 0, 0],
[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]
]

%3 = "ttir.arange"() {start = 0 : i64, end = 3 : i64, step = 1 : i64, arange_dimension = 1 : i64} : () -> tensor<5x3xi64>
// %3: [
[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2]
]
}];

let arguments = (ins SI64Attr:$start,
SI64Attr:$end,
SI64Attr:$step,
I64Attr:$arange_dimension);

let results = (outs AnyRankedTensor:$result);
let hasVerifier = 1;
}

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
AllShapesMatch<["value", "result"]>]> {
let summary = "Constant op.";
Expand Down
24 changes: 24 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,30 @@ def TTNN_EmptyOp : TTNN_Op<"empty", [NoMemoryEffect]> {
let hasVerifier = 1;
}

def TTNN_ArangeOp : TTNN_Op<"arange"> {
let summary = "Arange operation.";
let description = [{
Tensor arange operation.

Produces a (1, 1, 1, N)-shaped tensor with values from `start` to `end` (exclusive) with a step size of `step`.

Examples:
%0 = "ttnn.arange"() {start = 0 : i64, end = 5 : i64 step = 1 : i64} : () -> tensor<1x1x1x5xi64>
// %0: [[[[0, 1, 2, 3, 4]]]]

%1 = "ttnn.arange"() {start = 0 : i64, end = 10 : i64, step = 2 : i64} : () -> tensor<1x1x1x5xf32>
// %1: [[[[0.0, 2.0, 4.0, 6.0, 8.0]]]]
}];

let arguments = (ins I64Attr:$start,
I64Attr:$end,
I64Attr:$step,
TT_Device:$device);

let results = (outs AnyRankedTensor:$result);
let hasVerifier = 1;
}

def TTNN_FullOp : TTNN_Op<"full"> {
let summary = "Full op.";
let description = [{
Expand Down
10 changes: 10 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ table FullOp {
out: tt.target.TensorRef;
}

table ArangeOp {
start: float;
end: float;
step: float;
dtype: DataType;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
}

enum EltwiseOpType: uint32 {
Add = 0,
Multiply = 1,
Expand Down Expand Up @@ -261,6 +270,7 @@ union OpType {
MaxPool2dOp,
DeallocateOp,
AllGatherOp,
ArangeOp,
}

table Operation {
Expand Down
40 changes: 40 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,36 @@ class StableHLOToTTIRGatherOpConversionPattern
}
};

template <typename SrcIotaOp, typename Adaptor = typename SrcIotaOp::Adaptor>
class StableHLOToTTIROpIotaOpConversionPattern
: public OpConversionPattern<SrcIotaOp> {

using OpConversionPattern<SrcIotaOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(SrcIotaOp srcOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

RankedTensorType outputType = mlir::cast<RankedTensorType>(
this->getTypeConverter()->convertType(srcOp.getResult().getType()));
rewriter.replaceOpWithNewOp<ttir::ArangeOp>(
srcOp, outputType, 0, outputType.getDimSize(adaptor.getIotaDimension()),
1, adaptor.getIotaDimension());

// Dynamic Iota has an output_shape attribute but the output shape is
// already known by the result type This is to remove the operand that will
// become dead code
for (auto operand : adaptor.getOperands()) {
if (operand.getDefiningOp()) {
rewriter.eraseOp(operand.getDefiningOp());
}
}

return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down Expand Up @@ -1365,6 +1395,15 @@ void addGatherOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
patterns.add<StableHLOToTTIRGatherOpConversionPattern>(typeConverter, ctx);
}

void addIotaOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIROpIotaOpConversionPattern<stablehlo::IotaOp>>(
typeConverter, ctx);
patterns
.add<StableHLOToTTIROpIotaOpConversionPattern<stablehlo::DynamicIotaOp>>(
typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
Expand All @@ -1389,6 +1428,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addSliceOpConversionPattern(ctx, patterns, typeConverter);
addClampOpConversionPattern(ctx, patterns, typeConverter);
addGatherOpConversionPattern(ctx, patterns, typeConverter);
addIotaOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
139 changes: 139 additions & 0 deletions lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Transforms/DialectConversion.h"

#include <algorithm>
#include <cstdint>

using namespace mlir;
using namespace mlir::tt;
Expand Down Expand Up @@ -775,6 +776,143 @@ class GetDimensionSizeToConstantConversionPattern
}
};

/*
* This pattern rewrites ArangeOp by forcing the arange_dimension to be
* rightmost dimension of the output tensor. This is done by replacing the
* ArangeOp with a new one that has this property, and then transposing out last
* dimension to the dimension specified by the original ArangeOp, and also
* inserting a reshape to match the rank of the intended output and broadcasts
* to repeat the data along the other dimensions.
*
* The ArangeOp that is generated here will be equivalent to how ttnn::ArangeOp
* behaves. The reason this pass is done in TTIR rather than generated when we
* want to lower to TTNN is because in the future we will want to consteval the
* ArangeOp, but have the option to not include repeated data in the constant
* tensor and broadcast at runtime instead. Consteval will be implemented for
* the TTIR dialect only and so this explication of the TMs implicit in ArangeOp
* must be done in TTIR.
*/
struct ArangeForceLastDimensionPattern
: public OpConversionPattern<ttir::ArangeOp> {
public:
using OpConversionPattern<ttir::ArangeOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::ArangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

const RankedTensorType outputType =
mlir::cast<RankedTensorType>(op.getResult().getType());

int64_t arangeDimension = adaptor.getArangeDimension();
int64_t arangeDimensionNegative = arangeDimension - outputType.getRank();
int64_t start = adaptor.getStart();
int64_t end = adaptor.getEnd();
int64_t step = adaptor.getStep();

int64_t arangeLength = (end - start) / step;

ArrayRef<int64_t> ttnnShape = {1, 1, 1, arangeLength};
if (ttnnShape == outputType.getShape()) {
return success();
}

RankedTensorType arangeOutputType = RankedTensorType::get(
SmallVector<int64_t>({1, 1, 1, arangeLength}),
outputType.getElementType(), outputType.getEncoding());

Value output =
rewriter
.create<ttir::ArangeOp>( // perform arange on the last dimension to
// match how ttnn behaves
op.getLoc(), arangeOutputType, start, end, step, 3)
.getResult();

std::vector<int64_t> outputShape = arangeOutputType.getShape().vec();
// Must transpose the output so that the data changes along the axis defined
// by arangeDimension
if (arangeDimensionNegative != -1) {
std::vector<int64_t> transposeShape = outputShape;
transposeShape[arangeDimensionNegative + transposeShape.size()] =
arangeLength;
transposeShape[arangeOutputType.getRank() - 1] = 1;
RankedTensorType transposeType = RankedTensorType::get(
transposeShape, arangeOutputType.getElementType(),
arangeOutputType.getEncoding());

tensor::EmptyOp dpsOutput = rewriter.create<tensor::EmptyOp>(
op.getLoc(), transposeShape, transposeType.getElementType());

output = rewriter.create<ttir::TransposeOp>(
op.getLoc(), transposeType, output, dpsOutput,
arangeDimensionNegative + transposeShape.size(),
arangeOutputType.getRank() - 1,
rewriter.getArrayAttr(SmallVector<Attribute>(
2, rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));

outputShape = transposeShape;
}

// Must match up the rank of the output with the rank of the intended output
// from the original arange, with the arangeDimension in the correct
// position
if (outputType.getRank() != static_cast<int64_t>(outputShape.size())) {
std::vector<int32_t> reshapeShape;
for (uint32_t i = 0; i < outputType.getRank(); i++) {
i == arangeDimension ? reshapeShape.push_back(end)
: reshapeShape.push_back(1);
}

RankedTensorType reshapeType = RankedTensorType::get(
SmallVector<int64_t>(reshapeShape.begin(), reshapeShape.end()),
outputType.getElementType(), outputType.getEncoding());
tensor::EmptyOp dpsOutput = rewriter.create<tensor::EmptyOp>(
op.getLoc(),
SmallVector<int64_t>(reshapeShape.begin(), reshapeShape.end()),
reshapeType.getElementType());
output = rewriter.create<ttir::ReshapeOp>(
op.getLoc(), reshapeType, output, dpsOutput,
rewriter.getI32ArrayAttr(reshapeShape),
rewriter.getArrayAttr(SmallVector<Attribute>(
2, rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));

outputShape =
std::vector<int64_t>(reshapeShape.begin(), reshapeShape.end());
}

// Must broadcast the rest of the dimensions
SmallVector<Attribute> broadcastDims;
for (uint32_t i = 0; i < outputShape.size(); i++) {
if (i != arangeDimension && outputShape[i] != outputType.getShape()[i]) {
outputShape[i] = outputType.getShape()[i];
broadcastDims.push_back(rewriter.getI64IntegerAttr(i));
}
}
if (!broadcastDims.empty()) {
RankedTensorType broadcastType = RankedTensorType::get(
outputShape, outputType.getElementType(), outputType.getEncoding());

tensor::EmptyOp dpsOutput = rewriter.create<tensor::EmptyOp>(
op.getLoc(), outputShape, outputType.getElementType());

output = rewriter.create<ttir::BroadcastOp>(
op.getLoc(), broadcastType, output, dpsOutput,
rewriter.getArrayAttr(broadcastDims),
rewriter.getArrayAttr(SmallVector<Attribute>(
2, rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));

assert(mlir::cast<RankedTensorType>(output.getType()).getShape() ==
outputType.getShape() &&
"Output shape must match the shape of the input tensor");
}
rewriter.replaceOp(op, output);
return success();
}
};

void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand All @@ -783,6 +921,7 @@ void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
patterns.add<ConvolutionToConv2dPattern>(typeConverter, ctx);
patterns.add<GetDimensionSizeToConstantConversionPattern>(typeConverter, ctx);
patterns.add<GatherToEmbeddingConversionPattern>(typeConverter, ctx);
patterns.add<ArangeForceLastDimensionPattern>(typeConverter, ctx);
}

} // namespace mlir::tt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ struct TTIRToTTIRDecompositionPass
target.addIllegalOp<ttir::PoolingOp>();
target.addIllegalOp<ttir::GatherOp>();

// These are the ops that must satisfy some conditions after this pass
target.addDynamicallyLegalOp<ttir::ArangeOp>([&](ttir::ArangeOp op) {
auto shape = op.getResult().getType().getShape();
return (static_cast<int64_t>(op.getArangeDimension()) == 3 &&
shape.size() == 4 && shape[0] == 1 && shape[1] == 1 &&
shape[2] == 1);
});

TypeConverter typeConverter;
// All types map 1:1.
typeConverter.addConversion([](Type type) { return type; });
Expand Down
31 changes: 30 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,34 @@ class AllGatherOpConversionPattern
}
};

class ArangeOpConversionPattern : public OpConversionPattern<ttir::ArangeOp> {
public:
using OpConversionPattern<ttir::ArangeOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::ArangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

RankedTensorType outputType =
mlir::cast<RankedTensorType>(op.getResult().getType());
assert(static_cast<int64_t>(adaptor.getArangeDimension()) ==
outputType.getRank() - 1 &&
"Arange dimension must be the final dimension of the output tensor "
"to convert to ttnn.arange");

auto newOutputType = RankedTensorType::get(outputType.getShape(),
outputType.getElementType(),
outputType.getEncoding());

Value device = getOrInsertDevice(rewriter, op);
rewriter.replaceOpWithNewOp<ttnn::ArangeOp>(
op, newOutputType, adaptor.getStart(), adaptor.getEnd(),
adaptor.getStep(), device);

return success();
}
};

namespace mlir::tt {

void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
Expand Down Expand Up @@ -973,7 +1001,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
Conv2dOpConversionPattern,
MaxPool2dOpConversionPattern,
SubtractOpConversionPattern,
AllGatherOpConversionPattern
AllGatherOpConversionPattern,
ArangeOpConversionPattern
>(typeConverter, ctx);
// ANCHOR_END: op_rewriter_pattern_set
// clang-format on
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,8 +639,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Tensor ops
//
patterns
.add<EmptyOpConversionPattern, DefaultOpConversionPattern<ttnn::FullOp>>(
typeConverter, ctx);
.add<EmptyOpConversionPattern, DefaultOpConversionPattern<ttnn::FullOp>,
DefaultOpConversionPattern<ttnn::ArangeOp>>(typeConverter, ctx);

// Eltwise unary ops
//
Expand Down
Loading

0 comments on commit fb05a84

Please sign in to comment.