Support ingesting Opset 15 Shape Op #1684
```diff
@@ -495,12 +495,34 @@ Value normalizeConstantOp(
 }

 // Create a DenseElementsAttr based on the shape of type.
-DenseElementsAttr createDenseElementsAttrFromShape(
-    PatternRewriter &rewriter, Value value) {
+DenseElementsAttr createDenseElementsAttrFromShape(PatternRewriter &rewriter,
+    Value value, Attribute startAttr, Attribute endAttr) {
+  // Check that end is provided
   auto inType = value.getType().cast<ShapedType>();
   auto shape = inType.getShape();
-  SmallVector<int64_t, 1> dims = {inType.getRank()};
-  SmallVector<int64_t, 4> values(shape.begin(), shape.end());
+  int64_t rank = inType.getRank();
+
+  int64_t start = 0;
+  int64_t end = rank;
+
+  if (startAttr) {
+    start = startAttr.cast<IntegerAttr>().getSInt();
+  }
+  if (endAttr) {
+    end = endAttr.cast<IntegerAttr>().getSInt();
+  }
+
+  // Normalize if start/end are not in (0, ..., rank)
+  if (start < 0) {
+    start = start + rank;
+  }
+  if (end < 0) {
+    end = end + rank;
+  }
+
+  SmallVector<int64_t, 1> dims = {end - start};
+  SmallVector<int64_t, 4> values(shape.begin() + start, shape.begin() + end);
   auto tensorType = RankedTensorType::get(dims, rewriter.getIntegerType(64));
   return DenseElementsAttr::get(tensorType, makeArrayRef(values));
 }
```

Reviewer: Would it make sense to add asserts to make sure that start and end are now inclusively in 0..rank-1 for start and 0..rank for end?

Author: Good catch! This was actually wrong, because the behaviour is meant to be the following … I've rejigged the logic to use …
Reviewer: Should remove the commented code. Have you thought about adding a custom builder for the operation that does not have start/end and would generate them automatically, instead of modifying each of the patterns involved with the new Shape op?
Author: I did consider adding a custom builder; in all honesty, I tried and had some issues getting it to work. But you make good points. I have added a `NativeCodeCall` to more succinctly create ONNX Shape ops, and I now call this to avoid having to use the start and end attributes everywhere.
Reviewer: Awesome, glad you found a way to simplify the code.