Support ingesting Opset 15 Shape Op #1684

Merged (12 commits), Sep 17, 2022
@@ -88,6 +88,12 @@ def replaceONNXBatchNormalizationInferenceModePattern : Pattern<
//
//===----------------------------------------------------------------------===//


// Create an ONNX Shape op with an explicit result type.
def CreateShapeOp: NativeCodeCall<
"$_builder.create<mlir::ONNXShapeOp>($_loc, $0, $1, IntegerAttr(), 0)"
>;

// Get a type for a tensor that stores the shape of another tensor.
def GetShapeTypeOf: NativeCodeCall<
"RankedTensorType::get({$0.getType().cast<ShapedType>().getRank()}, $_builder.getIntegerType(64))"
@@ -112,7 +118,7 @@ def expandConstantOperandForAddOp1: Pat<
def expandConstantOperandForAddOp2: Pat<
(ONNXAddOp $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
(ONNXAddOp $x, (ONNXExpandOp $c,
(ONNXShapeOp $x, (returnType (GetShapeTypeOf $x))),
(CreateShapeOp (GetShapeTypeOf $x), $x),
(returnType $x))),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
>;
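These patterns fire only when the constant uni-broadcasts to the other operand. As a rough illustration of that guard, here is a minimal standalone sketch in plain C++; uniBroadcastsTo is a hypothetical stand-in for the actual IsUniBroadcastingFromFirstToSecond predicate, and static dimensions are assumed:

#include <cstdint>
#include <vector>

// Every dimension of `from` must equal the corresponding trailing
// dimension of `to`, or be 1 (ONNX unidirectional broadcasting).
bool uniBroadcastsTo(const std::vector<int64_t> &from,
                     const std::vector<int64_t> &to) {
  if (from.size() > to.size())
    return false;
  size_t offset = to.size() - from.size();
  for (size_t i = 0; i < from.size(); ++i)
    if (from[i] != 1 && from[i] != to[offset + i])
      return false;
  return true;
}

For example, uniBroadcastsTo({1, 4}, {2, 3, 4}) holds, so a constant of that shape can be expanded to the shape of the other operand, while uniBroadcastsTo({5}, {2, 3, 4}) does not.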
@@ -124,15 +130,15 @@ def expandConstantOperandForAddOp2: Pat<
def expandConstantOperandForDivOp1: Pat<
(ONNXDivOp $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
(ONNXDivOp $x, (ONNXExpandOp $c,
(ONNXShapeOp $x, (returnType (GetShapeTypeOf $x))),
(CreateShapeOp (GetShapeTypeOf $x), $x),
(returnType $x))),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
>;

def expandConstantOperandForDivOp2: Pat<
(ONNXDivOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
(ONNXDivOp (ONNXExpandOp $c,
(ONNXShapeOp $x, (returnType (GetShapeTypeOf $x))),
(CreateShapeOp (GetShapeTypeOf $x), $x),
(returnType $x)),
$x),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
@@ -151,7 +157,7 @@ def expandConstantOperandForMulOp1: Pat<
def expandConstantOperandForMulOp2: Pat<
(ONNXMulOp $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
(ONNXMulOp $x, (ONNXExpandOp $c,
(ONNXShapeOp $x, (returnType (GetShapeTypeOf $x))),
(CreateShapeOp (GetShapeTypeOf $x), $x),
(returnType $x))),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
>;
@@ -163,15 +169,15 @@ def expandConstantOperandForMulOp2: Pat<
def expandConstantOperandForSubOp1: Pat<
(ONNXSubOp $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
(ONNXSubOp $x, (ONNXExpandOp $c,
(ONNXShapeOp $x, (returnType (GetShapeTypeOf $x))),
(CreateShapeOp (GetShapeTypeOf $x), $x),
(returnType $x))),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
>;

def expandConstantOperandForSubOp2: Pat<
(ONNXSubOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
(ONNXSubOp (ONNXExpandOp $c,
(ONNXShapeOp $x, (returnType (GetShapeTypeOf $x))),
(CreateShapeOp (GetShapeTypeOf $x), $x),
(returnType $x)),
$x),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
4 changes: 2 additions & 2 deletions src/Builder/OpBuildTable.inc
@@ -159,7 +159,7 @@ op_dialect_version_map_["SequenceEmpty"] = {11};
op_dialect_version_map_["SequenceErase"] = {11};
op_dialect_version_map_["SequenceInsert"] = {11};
op_dialect_version_map_["SequenceLength"] = {11};
op_dialect_version_map_["Shape"] = {13};
op_dialect_version_map_["Shape"] = {15};
op_dialect_version_map_["Shrink"] = {9};
op_dialect_version_map_["Sigmoid"] = {13};
op_dialect_version_map_["Sign"] = {13};
@@ -739,7 +739,7 @@ op_dialect_top_version_map_["SequenceEmpty"] = 11;
op_dialect_top_version_map_["SequenceErase"] = 11;
op_dialect_top_version_map_["SequenceInsert"] = 11;
op_dialect_top_version_map_["SequenceLength"] = 11;
op_dialect_top_version_map_["Shape"] = 13;
op_dialect_top_version_map_["Shape"] = 15;
op_dialect_top_version_map_["Shrink"] = 9;
op_dialect_top_version_map_["Sigmoid"] = 13;
op_dialect_top_version_map_["Sign"] = 13;
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/Tensor/Shape.cpp
@@ -28,7 +28,7 @@ struct ONNXShapeOpLowering : public ConversionPattern {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Get shape.
ONNXShapeOpAdaptor operandAdaptor(operands);
ONNXShapeOpAdaptor operandAdaptor(operands, op->getAttrDictionary());
ONNXShapeOp shapeOp = cast<ONNXShapeOp>(op);
Location loc = op->getLoc();
ONNXShapeOpShapeHelper shapeHelper(&shapeOp, &rewriter,
32 changes: 31 additions & 1 deletion src/Dialect/ONNX/ONNXOps.td.inc
@@ -5565,8 +5565,38 @@ def ONNXShapeOp:ONNX_Op<"Shape",
let summary = "ONNX Shape operation";
let description = [{
"Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor."
"Optional attributes start and end can be used to compute a slice of the input tensor's shape."
"If start axis is omitted, the slice starts from axis 0."
"The end axis, if specified, is exclusive (and the returned value will not include the size of that axis)."
"If the end axis is omitted, the axes upto the last one will be included."
"Negative axes indicate counting back from the last axis."
"Note that axes will be clipped to the range [0, r-1], where r is the"
"rank of the input tensor if they are out-of-range (after adding r in the case of"
"negative axis). Thus, specifying any end value > r is equivalent to specifying an end"
"value of r, and specifying any start value < -r is equivalent to specifying a start"
"value of 0."
""
"For example:"
"Input tensor with shape: [2, 3, 4] "
"No attributes specified."
"Output: [2, 3, 4] "
""
"Input tensor with shape: [2, 3, 4] "
"start: -1"
"Output: [4] "
""
"Input tensor with shape: [2, 3, 4] "
"end: -1"
"Output: [2, 3]"
""
"Input tensor with shape: [2, 3, 4] "
"start: 1"
"end: 2"
"Output: [3] "
}];
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>]>:$data);
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>]>:$data,
OptionalAttr<SI64Attr>:$end,
DefaultValuedAttr<SI64Attr, "0">:$start);
let results = (outs TensorOf<[I64]>:$shape);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
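The start/end semantics documented above reduce to clamp-then-slice index arithmetic. A minimal sketch in plain C++, with no MLIR dependencies; clampAxisPerSpec and shapeSlice are illustrative names, not part of this PR:

#include <algorithm>
#include <cstdint>
#include <vector>

// Clamp an axis to [0, rank] per the Shape-15 spec: add rank once for a
// negative axis, then clip to the valid range.
int64_t clampAxisPerSpec(int64_t axis, int64_t rank) {
  if (axis < 0)
    axis += rank;
  return std::max<int64_t>(0, std::min(axis, rank));
}

// The slice of the input shape that Shape-15 returns for a given start/end.
std::vector<int64_t> shapeSlice(
    const std::vector<int64_t> &shape, int64_t start, int64_t end) {
  int64_t rank = static_cast<int64_t>(shape.size());
  int64_t s = clampAxisPerSpec(start, rank);
  int64_t e = std::max(clampAxisPerSpec(end, rank), s); // empty if end < start
  return std::vector<int64_t>(shape.begin() + s, shape.begin() + e);
}

This reproduces the spec examples: shapeSlice({2, 3, 4}, 0, 3) == {2, 3, 4}, shapeSlice({2, 3, 4}, -1, 3) == {4}, shapeSlice({2, 3, 4}, 0, -1) == {2, 3}, and shapeSlice({2, 3, 4}, 1, 2) == {3}.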
11 changes: 0 additions & 11 deletions src/Dialect/ONNX/ONNXOpsHelper.cpp
@@ -494,17 +494,6 @@ Value normalizeConstantOp(
ArrayAttr(), StringAttr(), ArrayAttr());
}

// Create a DenseElementsAttr based on the shape of type.
DenseElementsAttr createDenseElementsAttrFromShape(
PatternRewriter &rewriter, Value value) {
auto inType = value.getType().cast<ShapedType>();
auto shape = inType.getShape();
SmallVector<int64_t, 1> dims = {inType.getRank()};
SmallVector<int64_t, 4> values(shape.begin(), shape.end());
auto tensorType = RankedTensorType::get(dims, rewriter.getIntegerType(64));
return DenseElementsAttr::get(tensorType, makeArrayRef(values));
}

// Create a DenseElementsAttr based on the size of type.
DenseElementsAttr createDenseElementsAttrFromSize(
PatternRewriter &rewriter, Value value) {
4 changes: 0 additions & 4 deletions src/Dialect/ONNX/ONNXOpsHelper.hpp
@@ -172,10 +172,6 @@ mlir::DenseElementsAttr createDenseElementsAttrFromRawBuffer(
mlir::Value normalizeConstantOp(
mlir::PatternRewriter &rewriter, mlir::Value output, mlir::Attribute attr);

// Create a DenseElementsAttr based on the shape of type.
mlir::DenseElementsAttr createDenseElementsAttrFromShape(
mlir::PatternRewriter &rewriter, mlir::Value value);

// Create a DenseElementsAttr based on the size of type.
mlir::DenseElementsAttr createDenseElementsAttrFromSize(
mlir::PatternRewriter &rewriter, mlir::Value value);
31 changes: 31 additions & 0 deletions src/Dialect/ONNX/Rewrite.cpp
@@ -21,6 +21,7 @@
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"
#include "src/Support/TypeUtilities.hpp"

using namespace mlir;
@@ -102,6 +103,36 @@ bool areProducedByTransposeOp(ValueRange values) {
});
}

// Create a DenseElementsAttr based on the shape of a value's type.
DenseElementsAttr createDenseElementsAttrFromShape(PatternRewriter &rewriter,
Value value, int64_t start = 0, llvm::Optional<int64_t> end = llvm::None) {

auto inType = value.getType().cast<ShapedType>();
assert(inType.hasRank() && "inType must be ranked");
auto shape = inType.getShape();
int64_t rank = inType.getRank();

int64_t endValue = end.has_value() ? end.value() : rank;

[Review comment from a Collaborator]
Would it make sense to have an assert here for the sizes? Or, better, it could be added in the verifier (apologies if it is already there).

SmallVector<int64_t, 1> dims = {endValue - start};
SmallVector<int64_t, 4> values(
shape.begin() + start, shape.begin() + endValue);
auto tensorType = RankedTensorType::get(dims, rewriter.getIntegerType(64));
return DenseElementsAttr::get(tensorType, makeArrayRef(values));
}

// Create a DenseElementsAttr from a Shape op, honoring its start/end bounds.
DenseElementsAttr createDenseElementsAttrFromShapeOp(
PatternRewriter &rewriter, Operation *op) {
ONNXShapeOp shapeOp = llvm::cast<ONNXShapeOp>(op);
ONNXShapeOpAdaptor operandAdaptor(shapeOp);

int64_t start, end;
std::tie(start, end) = getDataShapeBounds(operandAdaptor);

return createDenseElementsAttrFromShape(rewriter, shapeOp.data(), start, end);
}
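Concretely, for a Shape op whose input has the static type tensor<2x3x4xf32> and whose attributes are start = 1 with end unset, the attribute materialized here is dense<[3, 4]> : tensor<2xi64>. A tiny check using the shapeSlice sketch from the earlier example (illustrative; assumes that sketch is in scope):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // start = 1, end defaults to rank (3): keep dims 1 and 2.
  assert((shapeSlice({2, 3, 4}, 1, 3) == std::vector<int64_t>{3, 4}));
  return 0;
}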

} // namespace onnx_mlir

// =============================================================================
11 changes: 8 additions & 3 deletions src/Dialect/ONNX/Rewrite.td
@@ -38,6 +38,9 @@ def createDenseElementsAttrFromFloatAttr : NativeCodeCall<
def createDenseElementsAttrFromShape : NativeCodeCall<
"onnx_mlir::createDenseElementsAttrFromShape($_builder, $0)">;

def createDenseElementsAttrFromShapeResult : NativeCodeCall<
"onnx_mlir::createDenseElementsAttrFromShapeOp($_builder, $0.getDefiningOp())">;

// Create a DenseElementsAttr from the size of the type of a value.
def createDenseElementsAttrFromSize : NativeCodeCall<
"onnx_mlir::createDenseElementsAttrFromSize($_builder, $0)">;
@@ -531,7 +534,9 @@ def SwapReshapeMatMulPattern: Pattern<
// TODO: Support dynamic dimensions.
(ONNXMatMulOp:$res2 (ONNXReshapeOp:$res1 $A, $_, $az), $B),
[(ONNXReshapeOp (ONNXMatMulOp $A, $B, (returnType (GetReturnTypeForMatMulOpND2D $A, $B))),
(ONNXConstantOpFromDenseAttr (createDenseElementsAttrFromShape $res2)), $az)],
(ONNXConstantOpFromDenseAttr
(createDenseElementsAttrFromShape $res2)
), $az)],
[(HasRankGT<2> $A), (HasRankOf<2> $res1), (HasRankOf<2> $B), // A is reshaped to 2D.
(HaveSameLastDim $A, $res1), // The last dim of A is unchanged by reshape.
(IsStaticShapeTensor:$res2) // $res2 has static dims in order to create ReshapeOp.
@@ -719,9 +724,9 @@ def RewriteBatchNormInferenceModeConvPattern2: Pat<

// TODO: this rule may need to be updated once ShapeOp has start/end.
def ShapeToConstantPattern: Pat<
(ONNXShapeOp $A),
(ONNXShapeOp:$res $A, $end, $start),
(ONNXConstantOp
(GetNullAttr), (createDenseElementsAttrFromShape $A),
(GetNullAttr), (createDenseElementsAttrFromShapeResult $res),
(GetNullFloatAttr), (GetNullArrayAttr), (GetNullIntegerAttr),
(GetNullArrayAttr), (GetNullStringAttr), (GetNullArrayAttr)
),
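// Example (illustrative, not in the original diff): onnx.Shape applied to a
// tensor<2x3x4xf32> value with end = -1 (start omitted) folds to an
// onnx.Constant holding dense<[2, 3]> : tensor<2xi64>, provided all
// dimensions are static.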
8 changes: 8 additions & 0 deletions src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp
@@ -15,6 +15,8 @@

#pragma once

#include <utility>

#include "llvm/ADT/BitVector.h"

#include "mlir/IR/AffineExpr.h"
@@ -248,6 +250,12 @@ DECLARE_SHAPE_HELPER(ONNXUnsqueezeOp)
DECLARE_SHAPE_HELPER(ONNXUnsqueezeV11Op)
#undef DECLARE_SHAPE_HELPER

// Compute the (start, end) bounds selected by a Shape op's start/end
// attributes. start defaults to 0 and end defaults to the input's rank;
// negative axes count back from the last axis.
std::pair<int64_t, int64_t> getDataShapeBounds(
mlir::ONNXShapeOpAdaptor &operandAdaptor);

// Compute the data selected by the Shape operator.
DimsExpr computeSelectedData(mlir::ONNXShapeOpAdaptor &operandAdaptor);

67 changes: 45 additions & 22 deletions src/Dialect/ONNX/ShapeInference/Shape.cpp
@@ -9,60 +9,83 @@
//===----------------------------------------------------------------------===//

#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"
#include <tuple>
#include <utility>

using namespace mlir;

namespace onnx_mlir {

namespace {

// The Shape op spec says:
//
// "Note that axes will be clipped to the range [0, r-1], where r is the
// rank of the input tensor if they are out-of-range (after adding r in the case
// of negative axis). Thus, specifying any end value > r is equivalent to
// specifying an end value of r, and specifying any start value < -r is
// equivalent to specifying a start value of 0."
int64_t normalize(int64_t axis, int64_t rank) {
[Review comment from a Collaborator]
I see what you do here; it sounds reasonable, but is it the standard? Or should we instead report an error?
In general, we have erred in favor of reporting errors so as to make the user aware of the issue, instead of "fixing" the model the way we "hope" the user wanted.
As a minimum, I would issue a warning.
Note also that, if I recall correctly, sometimes the clipping is at rank and sometimes at rank-1... so this method may not be applicable everywhere.

@chentong319 do you have an opinion on this issue?

[Reply from the Member Author]
Sorry if this was confusing. This was actually lifted from the ONNX specification:

Shape-15
Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor. Optional attributes start and end can be used to compute a slice of the input tensor's shape. If start axis is omitted, the slice starts from axis 0. The end axis, if specified, is exclusive (and the returned value will not include the size of that axis). If the end axis is omitted, the axes upto the last one will be included. Negative axes indicate counting back from the last axis. Note that axes will be clamped to the range [0, r-1], where r is the rank of the input tensor if they are out-of-range (after adding r in the case of negative axis). Thus, specifying any end value > r is equivalent to specifying an end value of r, and specifying any start value < -r is equivalent to specifying a start value of 0.

It seems weird, but I have based it off the ONNX spec of Shape.

[Reply from the Collaborator]
Great, I would then just reflect this policy in the function name, so that others are not tempted to use it where it is not applicable. Add something like "ClampedPerSpec".

if (axis < 0)
axis += rank;

if (axis < 0)
axis = 0;

if (axis > rank)
axis = rank;

return axis;
}
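// Illustrative values (not part of the original diff): with rank = 3,
// normalize(-1, 3) == 2, normalize(-5, 3) == 0, and normalize(5, 3) == 3,
// matching the clamping described in the spec excerpt above.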

} // namespace

// Compute a slice of the input tensor's shape. The slice starts from axis 0.
// The axes upto the last one will be included. Negative axes indicate counting
// The axes up to the last one will be included. Negative axes indicate counting
// back from the last axis.
static std::pair<int64_t, int64_t> getDataShapeBounds(
std::pair<int64_t, int64_t> getDataShapeBounds(
ONNXShapeOpAdaptor &operandAdaptor) {
Value data = operandAdaptor.data();
MemRefBoundsIndexCapture dataBounds(data);
int64_t dataRank = dataBounds.getRank();
int64_t rank = dataBounds.getRank();

// Compute the normalized start/end. Negative value means counting
// dimensions from the back.
int64_t normalizedStart = 0;
int64_t normalizedEnd = dataRank;
int64_t start = operandAdaptor.start();
int64_t end = rank;
if (operandAdaptor.end().has_value()) {
end = operandAdaptor.end().value();
}

if (normalizedStart < 0)
normalizedStart += dataRank;
if (normalizedEnd < 0)
normalizedEnd += dataRank;

return std::make_pair(normalizedStart, normalizedEnd);
return std::make_pair(normalize(start, rank), normalize(end, rank));
}
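// Example (illustrative): for a rank-3 input with start = 1 and end unset,
// the bounds are {1, 3}, so computeShape below emits a 2-element result.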

LogicalResult ONNXShapeOpShapeHelper::computeShape(
ONNXShapeOpAdaptor operandAdaptor) {
Value data = operandAdaptor.data();
MemRefBoundsIndexCapture dataBounds(data);
int64_t dataRank = dataBounds.getRank();
std::pair<int64_t, int64_t> bounds = getDataShapeBounds(operandAdaptor);

if (bounds.first < 0 || bounds.first > dataRank)
return op->emitError("start value is out of bound");
if (bounds.second < 0 || bounds.second > dataRank)
return op->emitError("end value is out of bound");
int64_t start;
int64_t end;
std::tie(start, end) = getDataShapeBounds(operandAdaptor);

// Output is the actual number of values (1D)
dimsForOutput().emplace_back(LiteralIndexExpr(bounds.second - bounds.first));
dimsForOutput().emplace_back(LiteralIndexExpr(end - start));

return success();
}

// Compute the data selected by the Shape operator.
DimsExpr computeSelectedData(ONNXShapeOpAdaptor &operandAdaptor) {
MemRefBoundsIndexCapture dataBounds(operandAdaptor.data());
std::pair<int64_t, int64_t> bounds = getDataShapeBounds(operandAdaptor);
assert(bounds.first >= 0 && bounds.first <= bounds.second &&
bounds.second <= (int64_t)dataBounds.getRank() && "Unexpected bounds");
int64_t start;
int64_t end;
std::tie(start, end) = getDataShapeBounds(operandAdaptor);
assert(start >= 0 && start <= end && end <= (int64_t)dataBounds.getRank() &&
"Unexpected bounds");

DimsExpr selectedData;
for (int64_t i = bounds.first; i < bounds.second; ++i)
for (int64_t i = start; i < end; ++i)
selectedData.emplace_back(dataBounds.getDim(i));

return selectedData;