Skip to content

Commit

Permalink
Add support for get_dimension_size op (#1071)
Browse files Browse the repository at this point in the history
- Add end-to-end implementation of the get_dimension_size op in the StableHLO dialect
- It can be folded in the compiler, on the ttir level, producing a constant op
- Related to issue Add StableHLO to TTIR conversion for get_dimension_size OP #1016
  • Loading branch information
ajakovljevicTT authored Nov 8, 2024
1 parent 7122775 commit 9a16b18
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 0 deletions.
19 changes: 19 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,25 @@ def TTIR_GenericOp : TTIR_DPSOp<"generic", [AttrSizedOperandSegments]> {
}];
}

def TTIR_GetDimensionSizeOp : TTIR_Op<"get_dimension_size"> {
    let summary = "GetDimensionSize op.";
    let description = [{
      Produces the size of the given `dimension` of the `operand`, returned
      as a rank-1 tensor holding a single i32 value.

      Example:
        %operand: tensor<2x3xi32> = [[3, 2, 7], [1, 4, 4]]
        %result = "ttir.get_dimension_size"(%operand) <{dimension = 1 : i32}>
                    : (tensor<2x3xi32>) -> tensor<1xi32>
        // %result: [3]
    }];

    let arguments = (ins AnyRankedTensor:$operand,
                         I32Attr:$dimension);

    let results = (outs AnyRankedTensor:$result);

    // Folds to a constant when the dimension size is statically known.
    let hasFolder = 1;
    // Verifies that `dimension` is in range for the operand's rank.
    let hasVerifier = 1;
}

def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpInterface]> {
let summary = "Layout op.";
let description = [{
Expand Down
32 changes: 32 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,30 @@ class StableHLOToTTIRDotGeneralOpConversionPattern
}
};

// Converts stablehlo.get_dimension_size (which yields a scalar tensor<i32>)
// into ttir.get_dimension_size, which carries the queried dimension as an
// i32 attribute and yields a rank-1 tensor<1xi32>.
class StableHLOToTTIRGetDimensionSizeOpConversionPattern
    : public OpConversionPattern<mlir::stablehlo::GetDimensionSizeOp> {

  using OpConversionPattern<
      mlir::stablehlo::GetDimensionSizeOp>::OpConversionPattern;

public:
  LogicalResult
  matchAndRewrite(mlir::stablehlo::GetDimensionSizeOp srcOp,
                  mlir::stablehlo::GetDimensionSizeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    IntegerType intType = IntegerType::get(getContext(), 32);
    RankedTensorType outputType = RankedTensorType::get({1}, intType);
    // Build the attribute via the rewriter; no separate OpBuilder is needed.
    IntegerAttr dimensionAttr = rewriter.getIntegerAttr(
        intType, static_cast<int32_t>(srcOp.getDimension()));

    // Use the adaptor's (type-converted, remapped) operand rather than the
    // source op's original operand, as required by dialect conversion.
    rewriter.replaceOpWithNewOp<mlir::tt::ttir::GetDimensionSizeOp>(
        srcOp, outputType, adaptor.getOperand(), dimensionAttr);

    return success();
  }
};

class StableHLOToTTIRConstantOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::ConstantOp> {

Expand Down Expand Up @@ -973,6 +997,13 @@ void addMatmulOpsConversionPatterns(MLIRContext *ctx,
ctx);
}

// Registers the StableHLO->TTIR conversion pattern for
// stablehlo.get_dimension_size into the given pattern set.
void addGetDimensionSizeOpsConversionPatterns(MLIRContext *ctx,
                                              RewritePatternSet &patterns,
                                              TypeConverter &typeConverter) {
  patterns.add<StableHLOToTTIRGetDimensionSizeOpConversionPattern>(
      typeConverter, ctx);
}

void addTensorCreationOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down Expand Up @@ -1048,6 +1079,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addReduceOpsConversionPatterns(ctx, patterns, typeConverter);
addTransposeOpsConversionPatterns(ctx, patterns, typeConverter);
addMatmulOpsConversionPatterns(ctx, patterns, typeConverter);
addGetDimensionSizeOpsConversionPatterns(ctx, patterns, typeConverter);
addTensorCreationOpsConversionPatterns(ctx, patterns, typeConverter);
addBroadcastOpConversionPattern(ctx, patterns, typeConverter);
addConv2dOpConversionPattern(ctx, patterns, typeConverter);
Expand Down
29 changes: 29 additions & 0 deletions lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,40 @@ struct ConvolutionToConv2dPattern
}
};

// Decomposes ttir.get_dimension_size into a ttir.constant holding the
// statically-known size of the requested dimension.
class GetDimensionSizeToConstantConversionPattern
    : public OpConversionPattern<ttir::GetDimensionSizeOp> {
public:
  using OpConversionPattern<ttir::GetDimensionSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::GetDimensionSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    const RankedTensorType inputTensorType =
        mlir::cast<RankedTensorType>(op.getOperand().getType());

    int64_t dimensionIndex = op.getDimension();

    // A dynamic dimension has no compile-time size to materialize; bail out
    // instead of baking the dynamic-size sentinel into a constant.
    if (inputTensorType.isDynamicDim(dimensionIndex)) {
      return rewriter.notifyMatchFailure(op, "dimension size is dynamic");
    }

    // Explicit narrowing: dim sizes are int64_t, the result element type is
    // i32 (per the op's definition).
    int32_t dimSize =
        static_cast<int32_t>(inputTensorType.getDimSize(dimensionIndex));

    mlir::ShapedType valueType = mlir::cast<mlir::ShapedType>(op.getType());

    mlir::ElementsAttr valueAttr =
        mlir::DenseElementsAttr::get<int>(valueType, dimSize);

    rewriter.replaceOpWithNewOp<mlir::tt::ttir::ConstantOp>(op, valueType,
                                                            valueAttr);

    return success();
  }
};

// Populates the pattern set with all TTIR->TTIR decomposition patterns.
void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
                                             RewritePatternSet &patterns,
                                             TypeConverter &typeConverter) {
  // Register every decomposition pattern in a single variadic call.
  patterns.add<IndexToSliceConversionPattern,        // ttir.index -> ttir.slice
               ConvolutionToConv2dPattern,           // ttir.convolution -> ttir.conv2d
               GetDimensionSizeToConstantConversionPattern // -> ttir.constant
               >(typeConverter, ctx);
}

} // namespace mlir::tt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct TTIRToTTIRDecompositionPass
// These are the ops we intend to remove entirely with this pass
target.addIllegalOp<ttir::IndexOp>();
target.addIllegalOp<ttir::ConvolutionOp>();
target.addIllegalOp<ttir::GetDimensionSizeOp>();

TypeConverter typeConverter;
// All types map 1:1.
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNNPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "ttmlir/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
Expand Down
40 changes: 40 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,46 @@ ::mlir::OpFoldResult mlir::tt::ttir::ConstantOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}

//===----------------------------------------------------------------------===//
// GetDimensionSizeOp
//===----------------------------------------------------------------------===//

// GetDimensionSizeOp folder
//
// Folds to a dense i32 elements attribute holding the static size of the
// requested dimension. Returns nullptr (no fold) when the dimension index is
// out of range (left for the verifier to reject) or the dimension is dynamic.
::mlir::OpFoldResult
mlir::tt::ttir::GetDimensionSizeOp::fold(FoldAdaptor adaptor) {

  const RankedTensorType inputTensorType =
      mlir::cast<RankedTensorType>(getOperand().getType());

  int64_t dimensionIndex = getDimension();

  // Out-of-range index: indexing the shape would be UB; do not fold.
  if (dimensionIndex >= inputTensorType.getRank()) {
    return nullptr;
  }

  // Dynamic dimension: its size is unknown at compile time; do not fold.
  if (inputTensorType.isDynamicDim(dimensionIndex)) {
    return nullptr;
  }

  // Explicit narrowing: dim sizes are int64_t, the result element type is i32.
  int32_t dimSize =
      static_cast<int32_t>(inputTensorType.getDimSize(dimensionIndex));

  mlir::ShapedType valueType = mlir::cast<mlir::ShapedType>(getType());

  return mlir::DenseElementsAttr::get<int>(valueType, dimSize);
}

// GetDimensionSizeOp verification
//
// Rejects a `dimension` attribute that is out of range for the operand's
// rank, emitting a diagnostic (a bare failure() would report no reason).
::mlir::LogicalResult mlir::tt::ttir::GetDimensionSizeOp::verify() {
  const RankedTensorType inputTensorType =
      mlir::cast<RankedTensorType>(getOperand().getType());

  int64_t dimensionIndex = getDimension();

  if (dimensionIndex >= inputTensorType.getRank()) {
    return emitOpError("dimension attribute ")
           << dimensionIndex << " is out of range for operand of rank "
           << inputTensorType.getRank();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Conv2dOp
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 11 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/get_dimension_size_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
// Checks the StableHLO->TTIR conversion of get_dimension_size: the scalar
// tensor<i32> result becomes a rank-1 tensor<1xi32> carrying the dim size,
// and the dimension moves into an i32 attribute.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_get_dimension_size attributes {} {
  func.func public @test_get_dimension_size(%arg0: tensor<13x21x3xf32>) -> tensor<i32> {
    %0 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<13x21x3xf32>) -> tensor<i32>
    // CHECK: [[VAL:%[0-9]+]] = "ttir.get_dimension_size"(%arg0) <{dimension = 1 : i32}> : (tensor<{{[0-9]+}}x{{[0-9]+}}x{{[0-9]+}}xf32>) -> tensor<1xi32>
    return %0 : tensor<i32>
    // CHECK: return [[VAL]] : tensor<1xi32>
  }
}
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_get_dimension_size.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
// Checks that ttir.get_dimension_size folds away during the TTNN backend
// pipeline: dim 1 of tensor<13x21x3xf32> has size 21, which is materialized
// on-device as a ttnn.full fill of 21.0 (per the CHECK below).
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
  func.func @forward(%arg0: tensor<13x21x3xf32>) -> tensor<1xi32> {
    %0 = "ttir.get_dimension_size"(%arg0) <{dimension = 1 : i32}> : (tensor<13x21x3xf32>) -> tensor<1xi32>
    // CHECK: [[VAL:%[0-9]+]] = "ttnn.full"(%{{[0-9]+}}) <{fillValue = 2.100000e+01 : f32}> : (!tt.device<#device>) -> tensor<1xi32, {{.*}}>
    return %0 : tensor<1xi32>
    // CHECK: return [[VAL]] : tensor<1xi32, {{.*}}>
  }
}
7 changes: 7 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,10 @@ func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tens
return %1 : tensor<32x32xf32>
// CHECK: return {{.*}} : tensor<32x32xf32, {{.*}}
}

// get_dimension_size folds to a constant: dim 1 of 13x21x3 has size 21,
// emitted on-device as a ttnn.full fill of 21.0 (per the CHECK below).
func.func @get_dimension_size(%arg0: tensor<13x21x3xf32>) -> tensor<1xi32> {
  %0 = "ttir.get_dimension_size"(%arg0) <{dimension = 1 : i32}> : (tensor<13x21x3xf32>) -> tensor<1xi32>
  // CHECK: [[VAL:%[0-9]+]] = "ttnn.full"(%{{[0-9]+}}) <{fillValue = 2.100000e+01 : f32}> : (!tt.device<#device>) -> tensor<1xi32, {{.*}}>
  return %0 : tensor<1xi32>
  // CHECK: return [[VAL]] : tensor<1xi32, {{.*}}>
}

0 comments on commit 9a16b18

Please sign in to comment.