Support ingesting Opset 15 Shape Op #1684
Changes from 10 commits
@@ -9,60 +9,83 @@
 //===----------------------------------------------------------------------===//

 #include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"
+#include <tuple>
+#include <utility>

 using namespace mlir;

 namespace onnx_mlir {

+namespace {
+
+// The Shape op spec says:
+//
+// "Note that axes will be clipped to the range [0, r-1], where r is the
+// rank of the input tensor if they are out-of-range (after adding r in the
+// case of negative axis). Thus, specifying any end value > r is equivalent to
+// specifying an end value of r, and specifying any start value < -r is
+// equivalent to specifying a start value of 0."
+int64_t normalize(int64_t axis, int64_t rank) {
+  if (axis < 0)
+    axis += rank;
+
+  if (axis < 0)
+    axis = 0;
+
+  if (axis > rank)
+    axis = rank;
+
+  return axis;
+}
+
+} // namespace

Review thread on int64_t normalize(int64_t axis, int64_t rank):

Reviewer: I see what you do here; it sounds reasonable, but is it the standard? Or should we instead report an error? @chentong319, do you have an opinion on this issue?

Author: Sorry if this was confusing: this was actually lifted from the ONNX specification. It seems weird, but I have based it off the ONNX spec of Shape.

Reviewer: Great, I would then just reflect this policy in the function name, so that others are not tempted to use it where it is not applicable. Add something of the sort "ClampedPerSpec" or similar.
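To make the clamping policy concrete, here is a minimal, self-contained sketch (illustration only, not part of the patch; the helper name follows the reviewer's "ClampedPerSpec" suggestion):

#include <cstdint>
#include <iostream>

// Re-implementation of the diff's normalize() for illustration only.
// Negative axes first have rank added, then the result is clamped to
// [0, rank], per the ONNX Shape spec quoted above.
int64_t normalizeClampedPerSpec(int64_t axis, int64_t rank) {
  if (axis < 0)
    axis += rank;
  if (axis < 0)
    axis = 0;
  if (axis > rank)
    axis = rank;
  return axis;
}

int main() {
  // For a rank-3 tensor:
  std::cout << normalizeClampedPerSpec(-1, 3) << "\n"; // 2  (-1 + 3)
  std::cout << normalizeClampedPerSpec(-5, 3) << "\n"; // 0  (clamped up)
  std::cout << normalizeClampedPerSpec(5, 3) << "\n";  // 3  (clamped down)
  return 0;
}

Note that the helper clamps to [0, rank] rather than the spec's [0, r-1]: allowing the value rank itself is what lets an exclusive end bound cover the full shape.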
 // Compute a slice of the input tensor's shape. The slice starts from axis 0.
-// The axes upto the last one will be included. Negative axes indicate counting
+// The axes up to the last one will be included. Negative axes indicate counting
 // back from the last axis.
-static std::pair<int64_t, int64_t> getDataShapeBounds(
+std::pair<int64_t, int64_t> getDataShapeBounds(
     ONNXShapeOpAdaptor &operandAdaptor) {
   Value data = operandAdaptor.data();
   MemRefBoundsIndexCapture dataBounds(data);
-  int64_t dataRank = dataBounds.getRank();
+  int64_t rank = dataBounds.getRank();

   // Compute the normalized start/end. Negative value means counting
   // dimensions from the back.
-  int64_t normalizedStart = 0;
-  int64_t normalizedEnd = dataRank;
+  int64_t start = operandAdaptor.start();
+  int64_t end = rank;
+  if (operandAdaptor.end().has_value()) {
+    end = operandAdaptor.end().value();
+  }

-  if (normalizedStart < 0)
-    normalizedStart += dataRank;
-  if (normalizedEnd < 0)
-    normalizedEnd += dataRank;
-
-  return std::make_pair(normalizedStart, normalizedEnd);
+  return std::make_pair(normalize(start, rank), normalize(end, rank));
 }
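As a worked example of the bounds computation, here is a standalone sketch with hypothetical values that mirrors the logic above rather than calling into onnx-mlir (start and end play the role of the op's opset-15 attributes):

#include <cstdint>
#include <iostream>
#include <optional>
#include <utility>

// Illustration only: mirrors getDataShapeBounds for a tensor of the given
// rank, with `end` optional as in opset 15 Shape (absent means rank).
std::pair<int64_t, int64_t> bounds(
    int64_t start, std::optional<int64_t> end, int64_t rank) {
  int64_t e = end.has_value() ? *end : rank;
  auto clamp = [rank](int64_t axis) {
    if (axis < 0)
      axis += rank;
    if (axis < 0)
      axis = 0;
    if (axis > rank)
      axis = rank;
    return axis;
  };
  return {clamp(start), clamp(e)};
}

int main() {
  // For a tensor of shape (2, 3, 4, 5), i.e. rank 4:
  auto [s1, e1] = bounds(0, std::nullopt, 4); // (0, 4): the whole shape
  auto [s2, e2] = bounds(1, -1, 4);           // (1, 3): dims 1 and 2
  std::cout << "(" << s1 << ", " << e1 << ") "
            << "(" << s2 << ", " << e2 << ")\n";
  return 0;
}

For the (2, 3, 4, 5) input, the (1, 3) bounds select dims 1 and 2, so the Shape op would produce [3, 4].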
 LogicalResult ONNXShapeOpShapeHelper::computeShape(
     ONNXShapeOpAdaptor operandAdaptor) {
-  Value data = operandAdaptor.data();
-  MemRefBoundsIndexCapture dataBounds(data);
-  int64_t dataRank = dataBounds.getRank();
-  std::pair<int64_t, int64_t> bounds = getDataShapeBounds(operandAdaptor);
-
-  if (bounds.first < 0 || bounds.first > dataRank)
-    return op->emitError("start value is out of bound");
-  if (bounds.second < 0 || bounds.second > dataRank)
-    return op->emitError("end value is out of bound");
+  int64_t start;
+  int64_t end;
+  std::tie(start, end) = getDataShapeBounds(operandAdaptor);

   // Output is the actual number of values (1D)
-  dimsForOutput().emplace_back(LiteralIndexExpr(bounds.second - bounds.first));
+  dimsForOutput().emplace_back(LiteralIndexExpr(end - start));

   return success();
 }
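Because getDataShapeBounds now clamps both bounds to [0, rank], the old out-of-bound error paths can no longer trigger, which is why the diff drops them. Continuing the worked example above: a rank-4 input with start = 1 and end = -1 gives bounds (1, 3), so computeShape records a 1-D output of 3 - 1 = 2 elements.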
 // Compute the data selected by the Shape operator.
 DimsExpr computeSelectedData(ONNXShapeOpAdaptor &operandAdaptor) {
   MemRefBoundsIndexCapture dataBounds(operandAdaptor.data());
-  std::pair<int64_t, int64_t> bounds = getDataShapeBounds(operandAdaptor);
-  assert(bounds.first >= 0 && bounds.first <= bounds.second &&
-         bounds.second <= (int64_t)dataBounds.getRank() && "Unexpected bounds");
+  int64_t start;
+  int64_t end;
+  std::tie(start, end) = getDataShapeBounds(operandAdaptor);
+  assert(start >= 0 && start <= end && end <= (int64_t)dataBounds.getRank() &&
+         "Unexpected bounds");

   DimsExpr selectedData;
-  for (int64_t i = bounds.first; i < bounds.second; ++i)
+  for (int64_t i = start; i < end; ++i)
     selectedData.emplace_back(dataBounds.getDim(i));

   return selectedData;
Review comment on return selectedData;:

Reviewer: Would it make sense to have an assert here for sizes, or better, could it be added in the verifier? (Apologies if it is already there.)
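Following up on that suggestion, a possible verifier check could look like the sketch below. This is hypothetical and not part of the patch: it assumes ONNXShapeOp has a verify() hook and start()/end() accessors mirroring the adaptor methods used above, and it reuses the clamping helper under the hypothetical name from earlier; the actual onnx-mlir verifier API may differ.

// Hypothetical sketch (not part of the patch): reject start/end pairs that
// would yield an inverted slice even after spec-mandated clamping, e.g.
// start = 3, end = 1 on a rank-4 input.
LogicalResult ONNXShapeOp::verify() {
  ShapedType dataType = data().getType().dyn_cast<ShapedType>();
  if (!dataType || !dataType.hasRank())
    return success(); // Nothing to verify until the rank is known.

  int64_t rank = dataType.getRank();
  int64_t startVal = normalizeClampedPerSpec(start(), rank);
  int64_t endVal =
      end().has_value() ? normalizeClampedPerSpec(end().value(), rank) : rank;
  if (startVal > endVal)
    return emitOpError("start is greater than end after clamping");
  return success();
}

With such a check in the verifier, the assert in computeSelectedData would hold by construction.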