Skip to content

Commit

Permalink
Merge branch 'main' into staylor/wf-updates
Browse files Browse the repository at this point in the history
  • Loading branch information
staylorTT authored Nov 22, 2024
2 parents 4bc9370 + c908d52 commit f7bbd7c
Show file tree
Hide file tree
Showing 37 changed files with 1,574 additions and 11 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ query_results.json
run_results.json
ttrt_report.xml
cluster_descriptor.yaml

# TTNN and TTMetal flatbuffers
*.ttnn
*.ttm
70 changes: 70 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,48 @@ def TTIR_ClampOp : TTIR_DPSOp<"clamp"> {
let hasVerifier = 1;
}

def TTIR_ArangeOp : TTIR_Op<"arange"> {
  let summary = "Arange operation.";
  let description = [{
    Tensor arange operation.

    Produces a tensor with values from `start` to `end` (exclusive) with a
    step size of `step`, along the dimension specified by `arange_dimension`.

    Examples:
      %0 = "ttir.arange"() {start = 0 : i64, end = 5 : i64, step = 1 : i64, arange_dimension = 0 : i64} : () -> tensor<5xi64>
      // %0: [0, 1, 2, 3, 4]

      %1 = "ttir.arange"() {start = 0 : i64, end = 10 : i64, step = 2 : i64, arange_dimension = 0 : i64} : () -> tensor<5xf32>
      // %1: [0.0, 2.0, 4.0, 6.0, 8.0]

      %2 = "ttir.arange"() {start = 0 : i64, end = 5 : i64, step = 1 : i64, arange_dimension = 0 : i64} : () -> tensor<5x3xi64>
      // %2: [
            [0, 0, 0],
            [1, 1, 1],
            [2, 2, 2],
            [3, 3, 3],
            [4, 4, 4]
          ]

      %3 = "ttir.arange"() {start = 0 : i64, end = 3 : i64, step = 1 : i64, arange_dimension = 1 : i64} : () -> tensor<5x3xi64>
      // %3: [
            [0, 1, 2],
            [0, 1, 2],
            [0, 1, 2],
            [0, 1, 2],
            [0, 1, 2]
          ]
  }];

  // NOTE(review): start/end/step are SI64Attr while arange_dimension is
  // I64Attr (and TTNN_ArangeOp uses I64Attr throughout) -- confirm the
  // signedness mix is intentional.
  let arguments = (ins SI64Attr:$start,
                       SI64Attr:$end,
                       SI64Attr:$step,
                       I64Attr:$arange_dimension);

  let results = (outs AnyRankedTensor:$result);
  let hasVerifier = 1;
}

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
AllShapesMatch<["value", "result"]>]> {
let summary = "Constant op.";
Expand Down Expand Up @@ -1091,6 +1133,34 @@ def TTIR_FillOp : TTIR_DPSOp<"fill", [AllShapesMatch<["value", "result"]>]> {
}];
}

def TTIR_LinearOp : TTIR_DPSOp<"linear"> {
  let summary = "Linear transformation of inputs.";
  let description = [{
    Produces the matmul of tensors `a` and `b` with optional addition with `bias`.

    Example:
      %a = tensor.empty() : () -> tensor<10x64x32xbf16>
      %b = tensor.empty() : () -> tensor<32x128xbf16>
      %bias = tensor.empty() : () -> tensor<128xbf16>
      %output = tensor.empty() : () -> tensor<10x64x128xbf16>
      %0 = "ttir.linear"(%a, %b, %bias, %output) : (tensor<10x64x32xbf16>, tensor<32x128xbf16>, tensor<128xbf16>, tensor<10x64x128xbf16>) -> tensor<10x64x128xbf16>
  }];

  // `bias` is optional: when absent the op reduces to a plain matmul of
  // `a` and `b`. `output` is the destination-passing-style init operand.
  let arguments = (ins AnyRankedTensor:$a,
                       AnyRankedTensor:$b,
                       Optional<AnyRankedTensor>:$bias,
                       AnyRankedTensor:$output,
                       TT_OperandConstraintArrayAttr:$operand_constraints);

  let results = (outs AnyRankedTensor:$result);

  // DPS hook: report `output` as the init operand for the
  // DestinationStyleOpInterface.
  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];

  let hasVerifier = 1;
}

// ANCHOR: adding_an_op_matmul_ttir
def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
let summary = "Matrix multiply operation.";
Expand Down
54 changes: 54 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,34 @@ def TTNN_SliceOp: TTNN_NamedDPSOp<"slice"> {
let hasVerifier = 1;
}

def TTNN_LinearOp : TTNN_NamedDPSOp<"linear"> {
  let summary = "Linear transformation of inputs.";

  let description = [{
    Produces the matmul of tensors `a` and `b` with optional addition with `bias`.

    Example:
      // %a = [[1., 2.], [2., 1.]]
      // %b = [[0., 1.], [1., 0.]]
      // %bias = [[1.]]
      "ttnn.linear"(%a, %b, %bias, %result) : (tensor<2x2xf16>, tensor<2x2xf16>, tensor<1xf16>, tensor<2x2xf16>) -> tensor<2x2xf16>
      // %result = [[3., 2.], [2., 3.]]
  }];

  // `bias` is optional: when absent the op reduces to a plain matmul of
  // `a` and `b`. `output` is the destination-passing-style init operand.
  let arguments = (ins AnyRankedTensor:$a,
                       AnyRankedTensor:$b,
                       Optional<AnyRankedTensor>:$bias,
                       AnyRankedTensor:$output);
  let results = (outs AnyRankedTensor:$result);

  // DPS hook: report `output` as the init operand for the
  // DestinationStyleOpInterface.
  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];

  let hasVerifier = 1;
}


// ANCHOR: adding_an_op_matmul_ttnn
def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul"> {
let arguments = (ins AnyRankedTensor:$a,
Expand Down Expand Up @@ -759,6 +787,32 @@ def TTNN_EmptyOp : TTNN_Op<"empty", [NoMemoryEffect]> {
let hasVerifier = 1;
}

def TTNN_ArangeOp : TTNN_Op<"arange"> {
  let summary = "Arange operation.";
  let description = [{
    Tensor arange operation.

    Produces a (1, 1, 1, N)-shaped tensor with values from `start` to `end`
    (exclusive) with a step size of `step`.

    Examples:
      %0 = "ttnn.arange"() {start = 0 : i64, end = 5 : i64, step = 1 : i64} : () -> tensor<1x1x1x5xi64>
      // %0: [[[[0, 1, 2, 3, 4]]]]

      %1 = "ttnn.arange"() {start = 0 : i64, end = 10 : i64, step = 2 : i64} : () -> tensor<1x1x1x5xf32>
      // %1: [[[[0.0, 2.0, 4.0, 6.0, 8.0]]]]
  }];

  // dtype, device, and memory_config are optional; defaults are presumably
  // resolved by the runtime/backend -- TODO confirm.
  let arguments = (ins I64Attr:$start,
                       I64Attr:$end,
                       I64Attr:$step,
                       OptionalAttr<TT_DataTypeAttr>:$dtype,
                       Optional<TT_Device>:$device,
                       OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);

  let results = (outs AnyRankedTensor:$result);
  let hasVerifier = 1;
}

def TTNN_FullOp : TTNN_Op<"full"> {
let summary = "Full op.";
let description = [{
Expand Down
19 changes: 19 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ table FullOp {
out: tt.target.TensorRef;
}

// Serialized form of the TTNN arange op.
// NOTE(review): start/end/step are float here but declared as I64Attr on
// TTNN_ArangeOp -- confirm the widening/narrowing is intentional.
table ArangeOp {
  start: float;
  end: float;
  step: float;
  dtype: tt.target.DataType = null; // optional
  device: tt.target.DeviceRef; // optional
  memcfg: tt.target.MemoryConfigDesc; // optional
  out: tt.target.TensorRef;
}

enum EltwiseOpType: uint32 {
Add = 0,
Multiply = 1,
Expand Down Expand Up @@ -178,6 +188,13 @@ table SliceOp {
step: [int64];
}

// Serialized form of the TTNN linear op: out = in0 @ in1 (+ bias).
table LinearOp {
  in0: tt.target.TensorRef;
  in1: tt.target.TensorRef;
  bias: tt.target.TensorRef; // optional -- Optional<...>:$bias on the dialect op
  out: tt.target.TensorRef;
}

// ANCHOR: adding_an_op_matmul_fbs
table MatmulOp {
in0: tt.target.TensorRef;
Expand Down Expand Up @@ -249,6 +266,7 @@ union OpType {
EmptyOp,
FullOp,
EltwiseOp,
LinearOp,
MatmulOp,
ReductionOp,
EmbeddingOp,
Expand All @@ -261,6 +279,7 @@ union OpType {
MaxPool2dOp,
DeallocateOp,
AllGatherOp,
ArangeOp,
}

table Operation {
Expand Down
40 changes: 40 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,36 @@ class StableHLOToTTIRGatherOpConversionPattern
}
};

// Lowers stablehlo.iota and stablehlo.dynamic_iota to ttir.arange with
// start = 0, step = 1, and end taken from the converted result type's size
// along the iota dimension.
template <typename SrcIotaOp, typename Adaptor = typename SrcIotaOp::Adaptor>
class StableHLOToTTIROpIotaOpConversionPattern
    : public OpConversionPattern<SrcIotaOp> {

  using OpConversionPattern<SrcIotaOp>::OpConversionPattern;

public:
  LogicalResult
  matchAndRewrite(SrcIotaOp srcOp, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    RankedTensorType outputType = mlir::cast<RankedTensorType>(
        this->getTypeConverter()->convertType(srcOp.getResult().getType()));
    rewriter.replaceOpWithNewOp<ttir::ArangeOp>(
        srcOp, outputType, 0, outputType.getDimSize(adaptor.getIotaDimension()),
        1, adaptor.getIotaDimension());

    // stablehlo.dynamic_iota carries an output_shape operand, but the shape
    // is already fully encoded in the result type, so the shape-producing op
    // becomes dead code once the iota is replaced; erase it here.
    // NOTE(review): eraseOp assumes each defining op has no other users --
    // TODO confirm this holds when the shape computation is shared.
    for (auto operand : adaptor.getOperands()) {
      if (operand.getDefiningOp()) {
        rewriter.eraseOp(operand.getDefiningOp());
      }
    }

    return success();
  }
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down Expand Up @@ -1365,6 +1395,15 @@ void addGatherOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
patterns.add<StableHLOToTTIRGatherOpConversionPattern>(typeConverter, ctx);
}

// Registers the iota lowering for both the static and dynamic variants of
// the StableHLO op with a single variadic add() call.
void addIotaOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
                                TypeConverter &typeConverter) {
  patterns
      .add<StableHLOToTTIROpIotaOpConversionPattern<stablehlo::IotaOp>,
           StableHLOToTTIROpIotaOpConversionPattern<stablehlo::DynamicIotaOp>>(
          typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
Expand All @@ -1389,6 +1428,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addSliceOpConversionPattern(ctx, patterns, typeConverter);
addClampOpConversionPattern(ctx, patterns, typeConverter);
addGatherOpConversionPattern(ctx, patterns, typeConverter);
addIotaOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
Loading

0 comments on commit f7bbd7c

Please sign in to comment.