Skip to content

Commit

Permalink
add DepthToSpace
Browse files Browse the repository at this point in the history
Signed-off-by: Yan Xu <[email protected]>
  • Loading branch information
Connor-XY committed Dec 27, 2023
1 parent 36f0fe7 commit 9003ce1
Show file tree
Hide file tree
Showing 12 changed files with 313 additions and 29 deletions.
1 change: 1 addition & 0 deletions src/Conversion/ONNXToStableHlo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ add_onnx_mlir_library(OMONNXToStableHlo
Tensor/ArgMax.cpp
Tensor/Concat.cpp
Tensor/Constant.cpp
Tensor/DepthToSpace.cpp
Tensor/Expand.cpp
Tensor/Flatten.cpp
Tensor/Gather.cpp
Expand Down
1 change: 1 addition & 0 deletions src/Conversion/ONNXToStableHlo/ConvertONNXToStableHlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ void populateONNXToStableHloConversionPattern(
populateLoweringONNXArgMaxOpToStableHloPattern(patterns, ctx);
populateLoweringONNXConcatOpToStableHloPattern(patterns, ctx);
populateLoweringONNXConstantOpToStableHloPattern(patterns, ctx);
populateLoweringONNXDepthToSpaceOpToStableHloPattern(patterns, ctx);
populateLoweringONNXExpandOpToStableHloPattern(patterns, ctx);
populateLoweringONNXFlattenOpToStableHloPattern(patterns, ctx);
populateLoweringONNXGatherOpToStableHloPattern(patterns, ctx);
Expand Down
113 changes: 96 additions & 17 deletions src/Conversion/ONNXToStableHlo/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "llvm/ADT/TypeSwitch.h"

#include "src/Conversion/ONNXToStableHlo/DialectBuilder.hpp"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Support/TypeUtilities.hpp"
#include "stablehlo/dialect/StablehloOps.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

namespace onnx_mlir {

Value StablehloBuilder::constant(mlir::Type type, double val) const {
// =============================================================================
// stablehlo Builder
// =============================================================================

Value StablehloBuilder::constant(Type type, double val) const {
Value constant = nullptr;
// Could be a vector type; look at the element type.
Type elementType = type;
Expand Down Expand Up @@ -79,47 +82,123 @@ Value StablehloBuilder::constantI64(int64_t val) const {
return b().create<stablehlo::ConstantOp>(loc(), constantAttr);
}

Value StablehloBuilder::shaped_zero(mlir::Type type) const {
Value StablehloBuilder::shaped_zero(Type type) const {
return b().create<stablehlo::ConstantOp>(loc(), b().getZeroAttr(type));
}

Value StablehloBuilder::reshape(Type resultType, Value operand) const {
return b().create<stablehlo::ReshapeOp>(loc(), resultType, operand);
}

mlir::Value StablehloBuilder::real_dynamic_slice(mlir::Type type,
mlir::Value operand, mlir::Value startIndices, mlir::Value limitIndices,
mlir::Value strides) const {
Value StablehloBuilder::dynamic_reshape(
Type type, Value input, Value shape) const {
return b().create<stablehlo::DynamicReshapeOp>(loc(), type, input, shape);
}

Value StablehloBuilder::real_dynamic_slice(Type type, Value operand,
Value startIndices, Value limitIndices, Value strides) const {
return b().create<stablehlo::RealDynamicSliceOp>(
loc(), type, operand, startIndices, limitIndices, strides);
}

mlir::Value StablehloBuilder::dynamic_slice(mlir::Value operand,
Value StablehloBuilder::dynamic_slice(Value operand,
SmallVector<Value> startIndices, SmallVector<int64_t> sliceSizes) const {
return b().create<stablehlo::DynamicSliceOp>(
loc(), operand, startIndices, sliceSizes);
}

mlir::Value StablehloBuilder::dynamic_slice(mlir::Value operand,
Value StablehloBuilder::dynamic_slice(Value operand,
SmallVector<Value> startIndices, DenseI64ArrayAttr sliceSizes) const {
return b().create<stablehlo::DynamicSliceOp>(
loc(), operand, startIndices, sliceSizes);
}

mlir::Value StablehloBuilder::slice(mlir::Value operand,
SmallVector<int64_t> startIndices, SmallVector<int64_t> limitIndices,
SmallVector<int64_t> strides) const {
Value StablehloBuilder::slice(Value operand, SmallVector<int64_t> startIndices,
SmallVector<int64_t> limitIndices, SmallVector<int64_t> strides) const {
return b().create<stablehlo::SliceOp>(
loc(), operand, startIndices, limitIndices, strides);
}

mlir::Value StablehloBuilder::slice(mlir::Value operand,
DenseI64ArrayAttr startIndices, DenseI64ArrayAttr limitIndices,
DenseI64ArrayAttr strides) const {
Value StablehloBuilder::slice(Value operand, DenseI64ArrayAttr startIndices,
DenseI64ArrayAttr limitIndices, DenseI64ArrayAttr strides) const {
return b().create<stablehlo::SliceOp>(
loc(), operand, startIndices, limitIndices, strides);
}

//===----------------------------------------------------------------------===//
// Extends OnnxBuilder with member functions that might generate stablehlo
// related dialect operations.
//===----------------------------------------------------------------------===//

Value OnnxToStablehloBuilder::reshape(
const Value input, const ArrayRef<DimIndexExpr> shapeDims) const {
assert(!shapeDims.empty() && "Shape dimensions should not be empty");

ShapedType inputType = input.getType().cast<ShapedType>();
Type elementType = inputType.getElementType();
MultiDialectBuilder<StablehloBuilder, OnnxBuilder, ShapeBuilder> create(
b(), loc());

// If the output dimensions are all literals the 'onnx/Reshape' operation
// can take the new shape via an 'onnx.Constant'.
if (llvm::all_of(
shapeDims, [](const DimIndexExpr &dim) { return dim.isLiteral(); })) {
SmallVector<int64_t, 6> shape;
for (const IndexExpr &dim : shapeDims)
shape.push_back(dim.getLiteral());

auto constantOp = create.onnx.constantInt64(shape);

Value reshapeRes = create.onnx.reshape(
RankedTensorType::get(shape, elementType), input, constantOp);

return reshapeRes;
}

// When the output dimensions aren't all literals we need to generate code
// to compute the shape.
int64_t length = shapeDims.size();
SmallVector<Value> dims;
for (int64_t i = 0; i < length; ++i) {
Value data = shapeDims[i].getValue();
dims.push_back(data);
}

Value shapeExtent = create.shape.fromExtents(dims);
Value shapeTensor = create.shape.toExtentTensor(
RankedTensorType::get({length}, b().getIndexType()), shapeExtent);
// result shape
SmallVector<int64_t, 6> outputShape;
for (const IndexExpr &dim : shapeDims)
outputShape.push_back(
dim.isLiteral() ? dim.getLiteral() : ShapedType::kDynamic);
Value res = create.stablehlo.dynamic_reshape(
RankedTensorType::get(outputShape, elementType), input, shapeTensor);
return res;
}

Value OnnxToStablehloBuilder::transpose(const Value input,
const ArrayRef<int64_t> perm,
const ArrayRef<DimIndexExpr> outputDims) const {
assert(!outputDims.empty() && "Output dimensions should not be empty");
assert(!perm.empty() && perm.size() == outputDims.size() &&
"Expecting valid permutation array");
MultiDialectBuilder<OnnxBuilder> create(b(), loc());

// Compute the shape of the 'onnx.Transpose' result.
SmallVector<int64_t, 6> shape;
for (const IndexExpr &dim : outputDims)
shape.push_back(dim.isLiteral() ? dim.getLiteral() : ShapedType::kDynamic);

// Create the "onnx.Transpose" operation.
ShapedType inputType = input.getType().cast<ShapedType>();
Value transposeRes = create.onnx.transpose(
RankedTensorType::get(shape, inputType.getElementType()), input,
b().getI64ArrayAttr(perm));

return transposeRes;
}

// =============================================================================
// IndexExpr Builder for Lowering using Shape/StableHlo Dialect.
// =============================================================================
Expand Down
42 changes: 41 additions & 1 deletion src/Conversion/ONNXToStableHlo/DialectBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "src/Dialect/Mlir/DialectBuilder.hpp"
#include "src/Dialect/Mlir/IndexExpr.hpp"
#include "src/Dialect/Mlir/IndexExprBuilder.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"

namespace onnx_mlir {

Expand All @@ -44,6 +45,8 @@ struct StablehloBuilder : DialectBuilder {
mlir::Value shaped_zero(mlir::Type type) const;
// ReshapeOp
mlir::Value reshape(mlir::Type resultType, mlir::Value operand) const;
mlir::Value dynamic_reshape(
mlir::Type type, mlir::Value input, mlir::Value shape) const;
// SliceOp
mlir::Value real_dynamic_slice(mlir::Type type, mlir::Value operand,
mlir::Value startIndices, mlir::Value limitIndices,
Expand Down Expand Up @@ -73,6 +76,30 @@ struct StablehloBuilder : DialectBuilder {
mlir::OpBuilder *patternRewriter;
};

//===----------------------------------------------------------------------===//
// Extends OnnxBuilder with member functions that might generate Stablehlo
// related dialect operations.
//===----------------------------------------------------------------------===//

struct OnnxToStablehloBuilder : public OnnxBuilder {
OnnxToStablehloBuilder(mlir::Location loc) : OnnxBuilder(loc) {}
OnnxToStablehloBuilder(mlir::OpBuilder &b, mlir::Location loc)
: OnnxBuilder(b, loc) {}
OnnxToStablehloBuilder(const DialectBuilder &db) : OnnxBuilder(db) {}
virtual ~OnnxToStablehloBuilder() {}

// Generate an 'onnx.reshape' operation on the 'input' tensor, the new shape
// is provided by 'shapeDims'.
mlir::Value reshape(const mlir::Value input,
const llvm::ArrayRef<DimIndexExpr> shapeDims) const;

// Generate a 'onnx.Transpose' operation on the 'input' tensor given the
// permutation array 'perm' and the operator output dimensions 'outputDims'.
mlir::Value transpose(const mlir::Value input,
const llvm::ArrayRef<int64_t> perm,
const llvm::ArrayRef<DimIndexExpr> outputDims) const;
};

// =============================================================================
// IndexExpr Builder for Shape lowering
// =============================================================================
Expand Down Expand Up @@ -107,7 +134,20 @@ struct MultiDialectBuilder<StablehloBuilder, Ts...>
StablehloBuilder stablehlo;
};

// Recursive class specialized for AffineBuilder refereed to as affine.
// Recursive class specialized for OnnxToStablehloBuilder referred to as
// stablehloOnnx.
template <class... Ts>
struct MultiDialectBuilder<OnnxToStablehloBuilder, Ts...>
: MultiDialectBuilder<Ts...> {
MultiDialectBuilder(mlir::OpBuilder &b, mlir::Location loc)
: MultiDialectBuilder<Ts...>(b, loc), stablehloOnnx(b, loc) {}
MultiDialectBuilder(const DialectBuilder &db)
: MultiDialectBuilder<Ts...>(db), stablehloOnnx(db) {}
OnnxToStablehloBuilder stablehloOnnx;
};

// Recursive class specialized for IndexExprBuilderForStableHlo referred to as
// stableHloIE.
template <class... Ts>
struct MultiDialectBuilder<IndexExprBuilderForStableHlo, Ts...>
: MultiDialectBuilder<Ts...> {
Expand Down
12 changes: 5 additions & 7 deletions src/Conversion/ONNXToStableHlo/ONNXToStableHloCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,17 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"

#include "stablehlo/dialect/StablehloOps.h"

#include "src/Conversion/ONNXToStableHlo/DialectBuilder.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"
#include "src/Dialect/Mlir/IndexExpr.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
#include "src/Pass/Passes.hpp"
#include "stablehlo/dialect/StablehloOps.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

Expand Down Expand Up @@ -182,7 +180,7 @@ void populateLoweringONNXConcatOpToStableHloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXConstantOpToStableHloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXReshapeOpToStableHloPattern(
void populateLoweringONNXDepthToSpaceOpToStableHloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXExpandOpToStableHloPattern(
RewritePatternSet &, MLIRContext *);
Expand Down
99 changes: 99 additions & 0 deletions src/Conversion/ONNXToStableHlo/Tensor/DepthToSpace.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------------ DepthToSpace.cpp - Lowering DepthToSpace Op -------------===//
//
// Copyright 2023
//
// =============================================================================
//
// This file lowers the ONNX DepthToSpace Operator to Stablehlo dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToStableHlo/DialectBuilder.hpp"
#include "src/Conversion/ONNXToStableHlo/ONNXToStableHloCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

struct ONNXDepthToSpaceOpLoweringToStablehlo
: public OpConversionPattern<ONNXDepthToSpaceOp> {
ONNXDepthToSpaceOpLoweringToStablehlo(MLIRContext *ctx)
: OpConversionPattern(ctx) {}

LogicalResult matchAndRewrite(ONNXDepthToSpaceOp depthToSpaceOp,
ONNXDepthToSpaceOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Operation *op = depthToSpaceOp.getOperation();
Location loc = ONNXLoc<ONNXDepthToSpaceOp>(op);
ValueRange operands = adaptor.getOperands();
Value input = adaptor.getInput();

MultiDialectBuilder<IndexExprBuilderForStableHlo, OnnxToStablehloBuilder>
create(rewriter, loc);
ONNXDepthToSpaceOpShapeHelper shapeHelper(
op, operands, &create.stableHloIE);
shapeHelper.computeShapeAndAssertOnFailure();

int64_t bs = depthToSpaceOp.getBlocksize();
StringRef mode = depthToSpaceOp.getMode();
assert(create.stableHloIE.getShapedTypeRank(input) == 4 &&
"Input tensor should have rank equal to 4");

// Compute the new dimensions.

DimIndexExpr B(create.stableHloIE.getShapeAsDim(input, 0));
DimIndexExpr C(create.stableHloIE.getShapeAsDim(input, 1));
DimIndexExpr H(create.stableHloIE.getShapeAsDim(input, 2));
DimIndexExpr W(create.stableHloIE.getShapeAsDim(input, 3));
DimIndexExpr newC = C.floorDiv(bs * bs);
DimIndexExpr newH = H * bs;
DimIndexExpr newW = W * bs;

// Compute the output dimension of the first reshape operation, and the
// permutation array for the transpose operation.
LiteralIndexExpr bsLit(bs);
SmallVector<DimIndexExpr, 6> outputDims1;
SmallVector<int64_t, 6> perm;
if (mode == "DCR") {
outputDims1 = {B, bsLit, bsLit, newC, H, W};
perm = {0, 3, 4, 1, 5, 2};
} else {
assert(mode == "CRD" && "Unexpected mode");
outputDims1 = {B, newC, bsLit, bsLit, H, W};
perm = {0, 1, 4, 2, 5, 3};
}

// Reshape input tensor to shape:
// [B, bs, bs, C/(bs*bs), H, W] when mode=DCR
// [B, C/(bs*bs), bs, bs, H, W] when mode=CRD
Value reshapeRes1 = create.stablehloOnnx.reshape(input, outputDims1);

// Transpose the reshape result into shape [B, C/(bs*bs), H, bs, W, bs].
SmallVector<DimIndexExpr> outputDims2({B, newC, H, bsLit, W, bsLit});
Value transposeRes =
create.stablehloOnnx.transpose(reshapeRes1, perm, outputDims2);

// Reshape the transpose result into shape [B, C/(bs*bs), H*bs, W*bs].
SmallVector<DimIndexExpr> outputDims3({B, newC, newH, newW});
Value reshapeRes2 = create.stablehloOnnx.reshape(transposeRes, outputDims3);

rewriter.replaceOp(op, reshapeRes2);
return success();
}
};

} // namespace

void populateLoweringONNXDepthToSpaceOpToStableHloPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXDepthToSpaceOpLoweringToStablehlo>(ctx);
}

} // namespace onnx_mlir
6 changes: 2 additions & 4 deletions src/Conversion/ONNXToStableHlo/Tensor/Expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,9 @@ struct ONNXExpandOpLoweringToStableHlo : public ConversionPattern {
RankedTensorType onesType = RankedTensorType::get(onesShape, elementType);
broadcastedOnes = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
loc, onesType, ones, shape, rewriter.getI64TensorAttr({}));
} else if (ONNXConstantOp shapeOp =
dyn_cast_or_null<ONNXConstantOp>(shapeDefOp)) {
} else if (mlir::ElementsAttr constShape =
getElementAttributeFromConstValue(shape)) {
llvm::SmallVector<int64_t, 4> shapeValues;
mlir::ElementsAttr constShape =
shapeOp.getValueAttr().cast<ElementsAttr>();
for (mlir::IntegerAttr element : constShape.getValues<IntegerAttr>())
shapeValues.push_back(element.getInt());
RankedTensorType broadcastedType =
Expand Down
Loading

0 comments on commit 9003ce1

Please sign in to comment.