Skip to content

Commit

Permalink
Added LinearOp support
Browse files Browse the repository at this point in the history
  • Loading branch information
azecevicTT committed Nov 22, 2024
1 parent e2c3fe4 commit a58f449
Show file tree
Hide file tree
Showing 16 changed files with 934 additions and 8 deletions.
28 changes: 28 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,34 @@ def TTIR_FillOp : TTIR_DPSOp<"fill", [AllShapesMatch<["value", "result"]>]> {
}];
}

// TTIR linear op: result = matmul(a, b) [+ bias], declared in
// destination-passing style (DPS) — `output` is the preallocated destination.
def TTIR_LinearOp : TTIR_DPSOp<"linear"> {
let summary = "Linear transformation of inputs.";
let description = [{
Produces the matmul of tensors `a` and `b` with optional addition with `bias`.

Example:
%a = tensor.empty() : () -> tensor<10x64x32xbf16>
%b = tensor.empty() : () -> tensor<32x128xbf16>
%bias = tensor.empty() : () -> tensor<128xbf16>
%output = tensor.empty() : () -> tensor<10x64x128xbf16>
%0 = "ttir.linear"(%a, %b, %bias, %output) : (tensor<10x64x32xbf16>, tensor<32x128xbf16>, tensor<128xbf16>, tensor<10x64x128xbf16>) -> tensor<10x64x128xbf16>
}];

// `a`/`b` are the matmul operands; `bias` is optional and added (with
// broadcasting) to the matmul product; `output` is the DPS destination.
let arguments = (ins AnyRankedTensor:$a,
AnyRankedTensor:$b,
Optional<AnyRankedTensor>:$bias,
AnyRankedTensor:$output,
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

// DestinationStyleOpInterface hook: `output` is the single DPS init operand.
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

// Shape and broadcast compatibility checks live in LinearOp::verify
// (lib/Dialect/TTIR/IR/TTIROps.cpp).
let hasVerifier = 1;
}

// ANCHOR: adding_an_op_matmul_ttir
def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
let summary = "Matrix multiply operation.";
Expand Down
28 changes: 28 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,34 @@ def TTNN_SliceOp: TTNN_NamedDPSOp<"slice"> {
let hasVerifier = 1;
}

// TTNN linear op: result = matmul(a, b) [+ bias], in destination-passing
// style — `output` is the preallocated destination tensor.
def TTNN_LinearOp : TTNN_NamedDPSOp<"linear"> {
let summary = "Linear transformation of inputs.";

let description = [{
Produces the matmul of tensors `a` and `b` with optional addition with `bias`.

Example:
// %a = [[1., 2.], [2., 1.]]
// %b = [[0., 1.], [1., 0.]]
// %bias = [[1.]]
"ttnn.linear"(%a, %b, %bias, %result) : (tensor<2x2xf16>, tensor<2x2xf16>, tensor<1xf16>, tensor<2x2xf16>) -> tensor<2x2xf16>
// %result = [[3., 2.], [2., 3.]]
}];

// `bias` is optional and added (with broadcasting) to the matmul of `a`/`b`.
let arguments = (ins AnyRankedTensor:$a,
AnyRankedTensor:$b,
Optional<AnyRankedTensor>:$bias,
AnyRankedTensor:$output);
let results = (outs AnyRankedTensor:$result);

// DestinationStyleOpInterface hook: `output` is the single DPS init operand.
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

// Shape and broadcast compatibility checks live in the op's C++ verifier.
let hasVerifier = 1;
}


// ANCHOR: adding_an_op_matmul_ttnn
def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul"> {
let arguments = (ins AnyRankedTensor:$a,
Expand Down
8 changes: 8 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ table SliceOp {
step: [int64];
}

// Serialized form of the TTNN linear op: out = matmul(in0, in1) [+ bias].
table LinearOp {
in0: tt.target.TensorRef;  // LHS matmul operand (`a` in the dialect op)
in1: tt.target.TensorRef;  // RHS matmul operand (`b` in the dialect op)
bias: tt.target.TensorRef; // optional; null when the op carries no bias
out: tt.target.TensorRef;  // preallocated output tensor (DPS destination)
}

// ANCHOR: adding_an_op_matmul_fbs
table MatmulOp {
in0: tt.target.TensorRef;
Expand Down Expand Up @@ -249,6 +256,7 @@ union OpType {
EmptyOp,
FullOp,
EltwiseOp,
LinearOp,
MatmulOp,
ReductionOp,
EmbeddingOp,
Expand Down
17 changes: 16 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,19 @@ class ConstantOpConversionPattern
}
};

} // namespace
// Lowers ttir.linear to ttnn.linear. The lowering is 1:1 — the result type
// is mapped through the type converter and the already-converted a/b/bias/
// output operands are forwarded unchanged.
class LinearOpConversionPattern : public OpConversionPattern<ttir::LinearOp> {
public:
  using OpConversionPattern<ttir::LinearOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::LinearOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Translate the TTIR result type into its TTNN counterpart.
    Type convertedType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<ttnn::LinearOp>(
        op, convertedType, adaptor.getA(), adaptor.getB(), adaptor.getBias(),
        adaptor.getOutput());
    return success();
  }
};

// ANCHOR: adding_an_op_matmul_op_rewriter
class MatmulOpConversionPattern : public OpConversionPattern<ttir::MatmulOp> {
Expand Down Expand Up @@ -908,6 +920,8 @@ class AllGatherOpConversionPattern
}
};

} // namespace

namespace mlir::tt {

void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
Expand Down Expand Up @@ -969,6 +983,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
SqueezeOpConversionPattern,
UnsqueezeOpConversionPattern,
ConstantOpConversionPattern,
LinearOpConversionPattern,
MatmulOpConversionPattern,
Conv2dOpConversionPattern,
MaxPool2dOpConversionPattern,
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,

// Matmul ops
//
patterns.add<DefaultOpConversionPattern<ttnn::MatmulOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::LinearOp>,
DefaultOpConversionPattern<ttnn::MatmulOp>>(typeConverter, ctx);

// Reduction ops
//
Expand Down
152 changes: 152 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,158 @@ mlir::tt::ttir::ToLayoutOp::compoundComponents() {
isMemoryLayoutChange};
}

//===----------------------------------------------------------------------===//
// LinearOp
//===----------------------------------------------------------------------===//

// LinearOp verification.
//
// Checks that `matmul(a, b) [+ bias]` is well formed:
//   * a and b are at least 1-D; a 1-D operand is promoted for the matmul
//     (1 prepended to a, 1 appended to b) and the promoted dim is dropped
//     from the output shape again.
//   * the inner dimensions of a and b match.
//   * the batch dimensions of a and b are broadcast compatible.
//   * bias (if present) is broadcast compatible with the matmul shape.
//   * the declared output shape equals the computed expected shape.
::mlir::LogicalResult mlir::tt::ttir::LinearOp::verify() {
  ::mlir::RankedTensorType inputAType = getA().getType();
  ::mlir::RankedTensorType inputBType = getB().getType();
  std::optional<::mlir::RankedTensorType> biasType =
      getBias() ? std::make_optional(getBias().getType()) : std::nullopt;
  ::mlir::RankedTensorType outputType = getOutput().getType();

  llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
  llvm::SmallVector<int64_t> inputAShape(inputAType.getShape());
  llvm::SmallVector<int64_t> inputBShape(inputBType.getShape());

  // Verify that the input A is at least 1D tensor.
  if (inputAType.getRank() < 1) {
    return emitOpError("Input A must be at least a 1D tensor");
  }

  // Verify that the input B is at least 1D tensor.
  if (inputBType.getRank() < 1) {
    return emitOpError("Input B must be at least a 1D tensor");
  }

  // If input A is a vector (1D tensor), 1 is prepended to its dimension for
  // the purpose of the matrix multiplication. After the matrix multiplication,
  // the prepended dimension is removed.
  if (inputAType.getRank() == 1) {
    inputAShape.insert(inputAShape.begin(), 1);
  }

  // If input B is a vector (1D tensor), a 1 is appended to its dimension for
  // the purpose of the matrix-vector product and removed afterwards.
  if (inputBType.getRank() == 1) {
    inputBShape.push_back(1);
  }

  // Verify that the input A and input B has matching inner dimensions.
  if (inputAShape.back() != inputBShape[inputBShape.size() - 2]) {
    return emitOpError(
        "Input A[-1](" + std::to_string(inputAShape.back()) + ") and B[-2](" +
        std::to_string(inputBShape[inputBShape.size() - 2]) +
        ") must have matching inner dimensions");
  }

  llvm::SmallVector<int64_t> expectedOutputShape;
  // Verify that the batch dimensions are broadcast compatible and construct
  // the expected output shape.
  if (inputAShape.size() > 2 || inputBShape.size() > 2) {
    // After the 1-D promotion above both shapes have rank >= 2, so the batch
    // dims are simply everything except the trailing two (matrix) dims; a
    // rank-2 operand contributes an empty batch-dim list.
    llvm::SmallVector<int64_t> inputABatchDims(inputAShape.begin(),
                                               inputAShape.end() - 2);
    llvm::SmallVector<int64_t> inputBBatchDims(inputBShape.begin(),
                                               inputBShape.end() - 2);

    // Verify that the batch dimensions of input A and B are broadcast
    // compatible.
    llvm::SmallVector<int64_t, 4> broadcastedShape;
    if (!OpTrait::util::getBroadcastedShape(inputABatchDims, inputBBatchDims,
                                            broadcastedShape)) {
      return emitOpError("Batch dimensions of input A(" +
                         ttmlir::utils::join(inputABatchDims, ",") +
                         ") and B(" +
                         ttmlir::utils::join(inputBBatchDims, ",") +
                         ") are not broadcast compatible");
    }

    // The broadcasted batch dimensions lead the expected output shape.
    expectedOutputShape.append(broadcastedShape.begin(),
                               broadcastedShape.end());
  }

  // Insert the input A and B inner dimensions in expected output shape.
  // Consider the case where input A and B are vectors. In that case,
  // the dimension 1 is omitted from the output shape.
  if (inputAType.getRank() > 1) {
    expectedOutputShape.push_back(inputAShape[inputAShape.size() - 2]);
  }

  if (inputBType.getRank() > 1) {
    expectedOutputShape.push_back(inputBShape.back());
  }

  if (biasType) {
    // Verify that the input bias is at least 1D tensor.
    if (biasType->getRank() < 1) {
      return emitOpError("Bias must be at least a 1D tensor");
    }

    llvm::SmallVector<int64_t> biasShape(biasType->getShape());

    // Verify that the dimensions of the matmul of A and B are broadcast
    // compatible with input bias.
    llvm::SmallVector<int64_t> matmulShape = expectedOutputShape;
    if (!OpTrait::util::getBroadcastedShape(matmulShape, biasShape,
                                            expectedOutputShape)) {
      return emitOpError("Bias shape(" + ttmlir::utils::join(biasShape, ",") +
                         ") is not broadcast compatible with the matmul output "
                         "shape(" +
                         ttmlir::utils::join(matmulShape, ",") + ")");
    }
  }

  // Check the case of a vector-vector product. At this moment we don't support
  // scalars in IR, hence check that the output is at least 1D tensor of size 1.
  if (expectedOutputShape.empty()) {
    if (outputType.getRank() < 1) {
      return emitOpError("Scalar output is not supported, output must be at "
                         "least a 1D tensor");
    }

    if (outputType.getRank() > 1 || outputType.getShape()[0] != 1) {
      return emitOpError("Scalar output must be a 1D tensor of size 1");
    }

    return success();
  }

  // Verify that the output shape dimension count is correct.
  if (outputShape.size() != expectedOutputShape.size()) {
    return emitOpError("Output shape rank(" +
                       std::to_string(outputShape.size()) +
                       ") must match the expected output shape rank(" +
                       std::to_string(expectedOutputShape.size()) + ")");
  }

  // Verify each dim of the output shape.
  for (size_t i = 0; i < outputShape.size(); i++) {
    if (outputShape[i] != expectedOutputShape[i]) {
      return emitOpError(
          "Output shape dimension[" + std::to_string(i) + "](" +
          std::to_string(outputShape[i]) +
          ") doesn't match the expected output shape dimension[" +
          std::to_string(i) + "](" + std::to_string(expectedOutputShape[i]) +
          ")");
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// MatmulOp
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit a58f449

Please sign in to comment.