forked from onnx/onnx-mlir
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Yan Xu <[email protected]>
- Loading branch information
Showing
12 changed files
with
313 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
//===------------ DepthToSpace.cpp - Lowering DepthToSpace Op -------------===// | ||
// | ||
// Copyright 2023 | ||
// | ||
// ============================================================================= | ||
// | ||
// This file lowers the ONNX DepthToSpace Operator to Stablehlo dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "src/Conversion/ONNXToStableHlo/DialectBuilder.hpp" | ||
#include "src/Conversion/ONNXToStableHlo/ONNXToStableHloCommon.hpp" | ||
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" | ||
|
||
using namespace mlir; | ||
|
||
namespace onnx_mlir { | ||
|
||
namespace { | ||
|
||
struct ONNXDepthToSpaceOpLoweringToStablehlo | ||
: public OpConversionPattern<ONNXDepthToSpaceOp> { | ||
ONNXDepthToSpaceOpLoweringToStablehlo(MLIRContext *ctx) | ||
: OpConversionPattern(ctx) {} | ||
|
||
LogicalResult matchAndRewrite(ONNXDepthToSpaceOp depthToSpaceOp, | ||
ONNXDepthToSpaceOpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const final { | ||
Operation *op = depthToSpaceOp.getOperation(); | ||
Location loc = ONNXLoc<ONNXDepthToSpaceOp>(op); | ||
ValueRange operands = adaptor.getOperands(); | ||
Value input = adaptor.getInput(); | ||
|
||
MultiDialectBuilder<IndexExprBuilderForStableHlo, OnnxToStablehloBuilder> | ||
create(rewriter, loc); | ||
ONNXDepthToSpaceOpShapeHelper shapeHelper( | ||
op, operands, &create.stableHloIE); | ||
shapeHelper.computeShapeAndAssertOnFailure(); | ||
|
||
int64_t bs = depthToSpaceOp.getBlocksize(); | ||
StringRef mode = depthToSpaceOp.getMode(); | ||
assert(create.stableHloIE.getShapedTypeRank(input) == 4 && | ||
"Input tensor should have rank equal to 4"); | ||
|
||
// Compute the new dimensions. | ||
|
||
DimIndexExpr B(create.stableHloIE.getShapeAsDim(input, 0)); | ||
DimIndexExpr C(create.stableHloIE.getShapeAsDim(input, 1)); | ||
DimIndexExpr H(create.stableHloIE.getShapeAsDim(input, 2)); | ||
DimIndexExpr W(create.stableHloIE.getShapeAsDim(input, 3)); | ||
DimIndexExpr newC = C.floorDiv(bs * bs); | ||
DimIndexExpr newH = H * bs; | ||
DimIndexExpr newW = W * bs; | ||
|
||
// Compute the output dimension of the first reshape operation, and the | ||
// permutation array for the transpose operation. | ||
LiteralIndexExpr bsLit(bs); | ||
SmallVector<DimIndexExpr, 6> outputDims1; | ||
SmallVector<int64_t, 6> perm; | ||
if (mode == "DCR") { | ||
outputDims1 = {B, bsLit, bsLit, newC, H, W}; | ||
perm = {0, 3, 4, 1, 5, 2}; | ||
} else { | ||
assert(mode == "CRD" && "Unexpected mode"); | ||
outputDims1 = {B, newC, bsLit, bsLit, H, W}; | ||
perm = {0, 1, 4, 2, 5, 3}; | ||
} | ||
|
||
// Reshape input tensor to shape: | ||
// [B, bs, bs, C/(bs*bs), H, W] when mode=DCR | ||
// [B, C/(bs*bs), bs, bs, H, W] when mode=CRD | ||
Value reshapeRes1 = create.stablehloOnnx.reshape(input, outputDims1); | ||
|
||
// Transpose the reshape result into shape [B, C/(bs*bs), H, bs, W, bs]. | ||
SmallVector<DimIndexExpr> outputDims2({B, newC, H, bsLit, W, bsLit}); | ||
Value transposeRes = | ||
create.stablehloOnnx.transpose(reshapeRes1, perm, outputDims2); | ||
|
||
// Reshape the transpose result into shape [B, C/(bs*bs), H*bs, W*bs]. | ||
SmallVector<DimIndexExpr> outputDims3({B, newC, newH, newW}); | ||
Value reshapeRes2 = create.stablehloOnnx.reshape(transposeRes, outputDims3); | ||
|
||
rewriter.replaceOp(op, reshapeRes2); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void populateLoweringONNXDepthToSpaceOpToStableHloPattern( | ||
RewritePatternSet &patterns, MLIRContext *ctx) { | ||
patterns.insert<ONNXDepthToSpaceOpLoweringToStablehlo>(ctx); | ||
} | ||
|
||
} // namespace onnx_mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.