Support ingesting Opset 15 Shape Op #1684

Merged on Sep 17, 2022 (12 commits)
@@ -88,6 +88,17 @@ def replaceONNXBatchNormalizationInferenceModePattern : Pattern<
//
//===----------------------------------------------------------------------===//

// Create an integer attribute holding the default start axis (0).
def GetStart: NativeCodeCall<
//"IntegerType::get($_builder, 64, IntegerType::Signed)"
"$_builder.getIntegerAttr($_builder.getIntegerType(64, true), 0)"
>;

Collaborator:

Should remove the commented-out code. Have you thought about adding a custom builder for the operation that does not take the start/end and would generate them automatically, instead of modifying each of the patterns involved with the new Shape op?

Member Author:

I did consider adding a custom builder. In all honesty, I tried and ran into some issues getting it to work.

But you make good points. I have added a NativeCodeCall to create ONNX Shape ops more succinctly:

// Create an ONNX Shape op with type.
def CreateShapeOp: NativeCodeCall<
 "$_builder.create<mlir::ONNXShapeOp>($_loc, $0, $1, IntegerAttr(), 0)"
>;

I now call this to avoid having to spell out the start and end attributes everywhere.

Collaborator:

Awesome, glad you found a way to simplify the code.
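As an aside, the custom builder the reviewer suggests could look roughly like the following. This is a hypothetical sketch following standard MLIR ODS builder conventions, not code from this PR; it mirrors the defaults used by the CreateShapeOp NativeCodeCall above.

// Hypothetical convenience builder for ONNXShapeOp (the reviewer's
// suggestion): callers pass only the result type and input, and the
// builder supplies the default start/end attributes. Sketch only.
void ONNXShapeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
    mlir::Type resultType, mlir::Value data) {
  // end is omitted (null attribute); start defaults to axis 0.
  build(builder, state, resultType, data, /*end=*/mlir::IntegerAttr(),
      /*start=*/builder.getIntegerAttr(builder.getIntegerType(64, true), 0));
}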

def GetEnd: NativeCodeCall<
//"IntegerAttr::get(IntegerType(64, true, APInt(64, $0.getType().cast<ShapedType>().getRank(), true)))"
"$_builder.getIntegerAttr($_builder.getIntegerType(64, true), $0.getType().cast<ShapedType>().getRank())"
>;

// Get a type for a tensor that stores the shape of another tensor.
def GetShapeTypeOf: NativeCodeCall<
"RankedTensorType::get({$0.getType().cast<ShapedType>().getRank()}, $_builder.getIntegerType(64))"
@@ -112,7 +123,10 @@ def expandConstantOperandForAddOp1: Pat<
def expandConstantOperandForAddOp2: Pat<
(ONNXAddOp $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
(ONNXAddOp $x, (ONNXExpandOp $c,
(ONNXShapeOp $x, (returnType (GetShapeTypeOf $x))),
(ONNXShapeOp $x,
(GetEnd $x),
(GetStart),
(returnType (GetShapeTypeOf $x))),
(returnType $x))),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
>;
@@ -124,15 +138,21 @@ def expandConstantOperandForAddOp2: Pat<
def expandConstantOperandForDivOp1: Pat<
(ONNXDivOp $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
(ONNXDivOp $x, (ONNXExpandOp $c,
(ONNXShapeOp $x, (returnType (GetShapeTypeOf $x))),
(ONNXShapeOp $x,
(GetEnd $x),
(GetStart),
(returnType (GetShapeTypeOf $x))),
(returnType $x))),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
>;

def expandConstantOperandForDivOp2: Pat<
(ONNXDivOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
(ONNXDivOp (ONNXExpandOp $c,
(ONNXShapeOp $x, (returnType (GetShapeTypeOf $x))),
(ONNXShapeOp $x,
(GetEnd $x),
(GetStart),
(returnType (GetShapeTypeOf $x))),
(returnType $x)),
$x),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
@@ -151,7 +171,10 @@ def expandConstantOperandForMulOp1: Pat<
def expandConstantOperandForMulOp2: Pat<
(ONNXMulOp $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
(ONNXMulOp $x, (ONNXExpandOp $c,
(ONNXShapeOp $x, (returnType (GetShapeTypeOf $x))),
(ONNXShapeOp $x,
(GetEnd $x),
(GetStart),
(returnType (GetShapeTypeOf $x))),
(returnType $x))),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
>;
@@ -163,15 +186,21 @@ def expandConstantOperandForMulOp2: Pat<
def expandConstantOperandForSubOp1: Pat<
(ONNXSubOp $x, (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_)),
(ONNXSubOp $x, (ONNXExpandOp $c,
(ONNXShapeOp $x, (returnType (GetShapeTypeOf $x))),
(ONNXShapeOp $x,
(GetEnd $x),
(GetStart),
(returnType (GetShapeTypeOf $x))),
(returnType $x))),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
>;

def expandConstantOperandForSubOp2: Pat<
(ONNXSubOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x),
(ONNXSubOp (ONNXExpandOp $c,
(ONNXShapeOp $x, (returnType (GetShapeTypeOf $x))),
(ONNXShapeOp $x,
(GetEnd $x),
(GetStart),
(returnType (GetShapeTypeOf $x))),
(returnType $x)),
$x),
[(IsUniBroadcastingFromFirstToSecond $c, $x)]
4 changes: 2 additions & 2 deletions src/Builder/OpBuildTable.inc
@@ -159,7 +159,7 @@ op_dialect_version_map_["SequenceEmpty"] = {11};
op_dialect_version_map_["SequenceErase"] = {11};
op_dialect_version_map_["SequenceInsert"] = {11};
op_dialect_version_map_["SequenceLength"] = {11};
op_dialect_version_map_["Shape"] = {13};
op_dialect_version_map_["Shape"] = {15};
op_dialect_version_map_["Shrink"] = {9};
op_dialect_version_map_["Sigmoid"] = {13};
op_dialect_version_map_["Sign"] = {13};
@@ -739,7 +739,7 @@ op_dialect_top_version_map_["SequenceEmpty"] = 11;
op_dialect_top_version_map_["SequenceErase"] = 11;
op_dialect_top_version_map_["SequenceInsert"] = 11;
op_dialect_top_version_map_["SequenceLength"] = 11;
op_dialect_top_version_map_["Shape"] = 13;
op_dialect_top_version_map_["Shape"] = 15;
op_dialect_top_version_map_["Shrink"] = 9;
op_dialect_top_version_map_["Sigmoid"] = 13;
op_dialect_top_version_map_["Sign"] = 13;
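For context, these tables map each ONNX op name to the opset version(s) the frontend can ingest, so bumping "Shape" from 13 to 15 is what enables importing Opset 15 models. Below is a rough standalone sketch of the lookup the tables imply; the helper name and the shape of the API are assumptions for illustration, not onnx-mlir's actual code.

#include <map>
#include <string>
#include <vector>

// Version table as in the diff above: op name -> supported opset versions.
std::map<std::string, std::vector<int>> op_dialect_version_map_;

// Hypothetical helper: does the importer accept this op at this opset?
bool supportsOpset(const std::string &opName, int opsetVersion) {
  auto it = op_dialect_version_map_.find(opName);
  if (it == op_dialect_version_map_.end())
    return false;
  for (int version : it->second)
    if (version == opsetVersion)
      return true;
  return false;
}

int main() {
  op_dialect_version_map_["Shape"] = {15}; // was {13} before this PR
  return supportsOpset("Shape", 15) ? 0 : 1;
}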
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/Tensor/Shape.cpp
@@ -28,7 +28,7 @@ struct ONNXShapeOpLowering : public ConversionPattern {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Get shape.
ONNXShapeOpAdaptor operandAdaptor(operands);
ONNXShapeOpAdaptor operandAdaptor(operands, op->getAttrDictionary());
ONNXShapeOp shapeOp = cast<ONNXShapeOp>(op);
Location loc = op->getLoc();
ONNXShapeOpShapeHelper shapeHelper(&shapeOp, &rewriter,
32 changes: 31 additions & 1 deletion src/Dialect/ONNX/ONNXOps.td.inc
@@ -5565,8 +5565,38 @@ def ONNXShapeOp:ONNX_Op<"Shape",
let summary = "ONNX Shape operation";
let description = [{
"Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor."
"Optional attributes start and end can be used to compute a slice of the input tensor's shape."
"If start axis is omitted, the slice starts from axis 0."
"The end axis, if specified, is exclusive (and the returned value will not include the size of that axis)."
"If the end axis is omitted, the axes upto the last one will be included."
"Negative axes indicate counting back from the last axis."
"Note that axes will be clipped to the range [0, r-1], where r is the"
"rank of the input tensor if they are out-of-range (after adding r in the case of"
"negative axis). Thus, specifying any end value > r is equivalent to specifying an end"
"value of r, and specifying any start value < -r is equivalent to specifying a start"
"value of 0."
""
"For example:"
"Input tensor with shape: [2, 3, 4] "
"No attributes specified."
"Output: [2, 3, 4] "
""
"Input tensor with shape: [2, 3, 4] "
"start: -1"
"Output: [4] "
""
"Input tensor with shape: [2, 3, 4] "
"end: -1"
"Output: [2, 3]"
""
"Input tensor with shape: [2, 3, 4] "
"start: 1"
"end: 2"
"Output: [3] "
}];
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>]>:$data);
let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>]>:$data,
OptionalAttr<SI64Attr>:$end,
DefaultValuedAttr<SI64Attr, "0">:$start);
let results = (outs TensorOf<[I64]>:$shape);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
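To make the start/end semantics above concrete, here is a small standalone C++ sketch (illustrative only, not part of the PR) that implements the slice-with-clipping rules and reproduces the examples from the description:

#include <cstdint>
#include <iostream>
#include <vector>

// Clip a start/end axis to [0, rank] after resolving negative values,
// following the Opset 15 Shape spec quoted above.
int64_t normalizeAxis(int64_t value, int64_t rank) {
  if (value < 0)
    value += rank;
  if (value < 0)
    value = 0;
  if (value > rank)
    value = rank;
  return value;
}

// Slice the shape vector; assumes normalized start <= end, as the
// shape helper in this PR asserts.
std::vector<int64_t> shapeSlice(const std::vector<int64_t> &shape,
    int64_t start, int64_t end) {
  int64_t rank = static_cast<int64_t>(shape.size());
  start = normalizeAxis(start, rank);
  end = normalizeAxis(end, rank);
  return std::vector<int64_t>(shape.begin() + start, shape.begin() + end);
}

int main() {
  std::vector<int64_t> shape = {2, 3, 4};
  // No attributes: start = 0, end = rank.
  for (int64_t d : shapeSlice(shape, 0, 3)) std::cout << d << " "; // 2 3 4
  std::cout << "\n";
  for (int64_t d : shapeSlice(shape, -1, 3)) std::cout << d << " "; // 4
  std::cout << "\n";
  for (int64_t d : shapeSlice(shape, 0, -1)) std::cout << d << " "; // 2 3
  std::cout << "\n";
  for (int64_t d : shapeSlice(shape, 1, 2)) std::cout << d << " "; // 3
  std::cout << "\n";
  return 0;
}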
30 changes: 26 additions & 4 deletions src/Dialect/ONNX/ONNXOpsHelper.cpp
@@ -495,12 +495,34 @@ Value normalizeConstantOp(
}

// Create a DenseElementsAttr based on the shape of type.
DenseElementsAttr createDenseElementsAttrFromShape(
PatternRewriter &rewriter, Value value) {
DenseElementsAttr createDenseElementsAttrFromShape(PatternRewriter &rewriter,
Value value, Attribute startAttr, Attribute endAttr) {
// The start/end attributes are optional; default to the full shape below.

auto inType = value.getType().cast<ShapedType>();
auto shape = inType.getShape();
SmallVector<int64_t, 1> dims = {inType.getRank()};
SmallVector<int64_t, 4> values(shape.begin(), shape.end());
int64_t rank = inType.getRank();

int64_t start = 0;
int64_t end = rank;

if (startAttr) {
start = startAttr.cast<IntegerAttr>().getSInt();
}
if (endAttr) {
end = endAttr.cast<IntegerAttr>().getSInt();
}

// Normalize if start/end are not in (0, ..., rank)
if (start < 0) {
start = start + rank;
}
if (end < 0) {
end = end + rank;
}

SmallVector<int64_t, 1> dims = {end - start};
SmallVector<int64_t, 4> values(shape.begin() + start, shape.begin() + end);
auto tensorType = RankedTensorType::get(dims, rewriter.getIntegerType(64));
return DenseElementsAttr::get(tensorType, makeArrayRef(values));
}

Collaborator:

Would it make sense to add asserts to make sure that start and end are now inclusively in 0..rank-1 for start and 0..rank for end?

Member Author:

Good catch! This was actually wrong, because the behaviour is meant to be the following:

"Optional attributes start and end can be used to compute a slice of the input tensor's shape."
"If start axis is omitted, the slice starts from axis 0."
"The end axis, if specified, is exclusive (and the returned value will not include the size of that axis)."
"If the end axis is omitted, the axes up to the last one will be included."
"Negative axes indicate counting back from the last axis."
"Note that axes will be clipped to the range [0, r-1], where r is the"
"rank of the input tensor if they are out-of-range (after adding r in the case of"
"negative axis). Thus, specifying any end value > r is equivalent to specifying an end"
"value of r, and specifying any start value < -r is equivalent to specifying a start"
"value of 0."

I've rejigged the logic to use getDataShapeBounds, which handles the clipping. This puts the normalizing behaviour in one place, where it is handled correctly.
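To see the bug the reviewer caught, note what normalization without clipping does to out-of-range values. A standalone illustration, not PR code:

#include <cstdint>
#include <iostream>

// Mirrors the pre-review normalization: adds rank to negative values but
// never clips, so out-of-range inputs stay out of range.
int64_t normalizeNoClip(int64_t value, int64_t rank) {
  return value < 0 ? value + rank : value;
}

int main() {
  int64_t rank = 3; // e.g. a tensor of shape [2, 3, 4]
  // Per the spec, start = -5 should clip to 0 and end = 7 should clip to 3.
  std::cout << normalizeNoClip(-5, rank) << "\n"; // prints -2: invalid index
  std::cout << normalizeNoClip(7, rank) << "\n";  // prints 7: past the rank
  return 0;
}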
3 changes: 2 additions & 1 deletion src/Dialect/ONNX/ONNXOpsHelper.hpp
@@ -174,7 +174,8 @@ mlir::Value normalizeConstantOp(

// Create a DenseElementsAttr based on the shape of type.
mlir::DenseElementsAttr createDenseElementsAttrFromShape(
mlir::PatternRewriter &rewriter, mlir::Value value);
mlir::PatternRewriter &rewriter, mlir::Value value,
mlir::Attribute startAttr, mlir::Attribute endAttr);

// Create a DenseElementsAttr based on the size of type.
mlir::DenseElementsAttr createDenseElementsAttrFromSize(
10 changes: 6 additions & 4 deletions src/Dialect/ONNX/Rewrite.td
@@ -36,7 +36,7 @@ def createDenseElementsAttrFromFloatAttr : NativeCodeCall<

// Create a DenseElementsAttr from the shape of the type of a value.
def createDenseElementsAttrFromShape : NativeCodeCall<
"onnx_mlir::createDenseElementsAttrFromShape($_builder, $0)">;
"onnx_mlir::createDenseElementsAttrFromShape($_builder, $0, $1, $2)">;

// Create a DenseElementsAttr from the size of the type of a value.
def createDenseElementsAttrFromSize : NativeCodeCall<
@@ -531,7 +531,9 @@ def SwapReshapeMatMulPattern: Pattern<
// TODO: Support dynamic dimensions.
(ONNXMatMulOp:$res2 (ONNXReshapeOp:$res1 $A, $_, $az), $B),
[(ONNXReshapeOp (ONNXMatMulOp $A, $B, (returnType (GetReturnTypeForMatMulOpND2D $A, $B))),
(ONNXConstantOpFromDenseAttr (createDenseElementsAttrFromShape $res2)), $az)],
(ONNXConstantOpFromDenseAttr
(createDenseElementsAttrFromShape $res2, (GetNullAttr), (GetNullAttr))
), $az)],
[(HasRankGT<2> $A), (HasRankOf<2> $res1), (HasRankOf<2> $B), // A is reshaped to 2D.
(HaveSameLastDim $A, $res1), // The last dim of A is unchanged by reshape.
(IsStaticShapeTensor:$res2) // $res2 has static dims in order to create ReshapeOp.
Expand Down Expand Up @@ -719,9 +721,9 @@ def RewriteBatchNormInferenceModeConvPattern2: Pat<

// TODO: this rule may need to be updated once ShapeOp has start/end.
def ShapeToConstantPattern: Pat<
(ONNXShapeOp $A),
(ONNXShapeOp $A, $end, $start),
(ONNXConstantOp
(GetNullAttr), (createDenseElementsAttrFromShape $A),
(GetNullAttr), (createDenseElementsAttrFromShape $A, $start, $end),
(GetNullFloatAttr), (GetNullArrayAttr), (GetNullIntegerAttr),
(GetNullArrayAttr), (GetNullStringAttr), (GetNullArrayAttr)
),
69 changes: 47 additions & 22 deletions src/Dialect/ONNX/ShapeInference/Shape.cpp
@@ -9,60 +9,85 @@
//===----------------------------------------------------------------------===//

#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"
#include <tuple>
#include <utility>

using namespace mlir;

namespace onnx_mlir {

namespace {

// If start axis is omitted, the slice starts from axis 0.
// The end axis, if specified, is exclusive (and the returned value will not
// include the size of that axis). If the end axis is omitted, the axes upto the
// last one will be included. Negative axes indicate counting back from the last
// axis. Note that axes will be clipped to the range [0, r-1], where r is the
// rank of the input tensor if they are out-of-range (after adding r in the case
// of negative axis). Thus, specifying any end value > r is equivalent to
// specifying an end value of r, and specifying any start value < -r is
// equivalent to specifying a start value of 0.
int64_t normalize(int64_t value, int64_t rank) {
if (value < 0)
value += rank;

if (value < 0)
value = 0;

if (value > rank)
value = rank;

return value;
}

// Compute a slice of the input tensor's shape. The slice starts from axis 0.
// The axes upto the last one will be included. Negative axes indicate counting
// The axes up to the last one will be included. Negative axes indicate counting
// back from the last axis.
static std::pair<int64_t, int64_t> getDataShapeBounds(
std::pair<int64_t, int64_t> getDataShapeBounds(
ONNXShapeOpAdaptor &operandAdaptor) {
Value data = operandAdaptor.data();
MemRefBoundsIndexCapture dataBounds(data);
int64_t dataRank = dataBounds.getRank();
int64_t rank = dataBounds.getRank();

// Compute the normalized start/end. Negative value means counting
// dimensions from the back.
int64_t normalizedStart = 0;
int64_t normalizedEnd = dataRank;

if (normalizedStart < 0)
normalizedStart += dataRank;
if (normalizedEnd < 0)
normalizedEnd += dataRank;
int64_t start = operandAdaptor.start();
int64_t end = rank;
if (operandAdaptor.end().has_value()) {
end = operandAdaptor.end().value();
}

return std::make_pair(normalizedStart, normalizedEnd);
return std::make_pair(normalize(start, rank), normalize(end, rank));
}

} // namespace

LogicalResult ONNXShapeOpShapeHelper::computeShape(
ONNXShapeOpAdaptor operandAdaptor) {
Value data = operandAdaptor.data();
MemRefBoundsIndexCapture dataBounds(data);
int64_t dataRank = dataBounds.getRank();
std::pair<int64_t, int64_t> bounds = getDataShapeBounds(operandAdaptor);

if (bounds.first < 0 || bounds.first > dataRank)
return op->emitError("start value is out of bound");
if (bounds.second < 0 || bounds.second > dataRank)
return op->emitError("end value is out of bound");
int64_t start;
int64_t end;
std::tie(start, end) = getDataShapeBounds(operandAdaptor);

// Output is the actual number of values (1D)
dimsForOutput().emplace_back(LiteralIndexExpr(bounds.second - bounds.first));
dimsForOutput().emplace_back(LiteralIndexExpr(end - start));

return success();
}

// Compute the data selected by the Shape operator.
DimsExpr computeSelectedData(ONNXShapeOpAdaptor &operandAdaptor) {
MemRefBoundsIndexCapture dataBounds(operandAdaptor.data());
std::pair<int64_t, int64_t> bounds = getDataShapeBounds(operandAdaptor);
assert(bounds.first >= 0 && bounds.first <= bounds.second &&
bounds.second <= (int64_t)dataBounds.getRank() && "Unexpected bounds");
int64_t start;
int64_t end;
std::tie(start, end) = getDataShapeBounds(operandAdaptor);
assert(start >= 0 && start <= end && end <= (int64_t)dataBounds.getRank() &&
"Unexpected bounds");

DimsExpr selectedData;
for (int64_t i = bounds.first; i < bounds.second; ++i)
for (int64_t i = start; i < end; ++i)
selectedData.emplace_back(dataBounds.getDim(i));

return selectedData;