[Relay, OpFusion] Better tuple fusion implementation #3092
Conversation
@tqchen I need to fix the TensorFlow test, but can you have a look at the patch and let me know if I am on the right track?
I think we can still simplify it. We do not need to record the inputs of the group. We can still just do the traversal in the node relation group. If there is a relation
Thanks, I was also able to remove kTupleFields, one of the two new op patterns I added. I think it is much simpler now.
@MarisaKirisame @vinx13 please help review this PR
Ready for review @tqchen @jroesch @MarisaKirisame @zhiics @vinx13
Given that there is a lot of recent interest in the fusor, I opened a new issue (#3109) for better docs.
include/tvm/relay/op_attr_types.h
@@ -41,14 +41,17 @@ enum OpPatternKind {
   // for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
   // Note that the axis need to be in order so transpose is not a bcast operator.
   kBroadcast = 1,
+  // The pattern for tuple nodes. Can fuse into subsequent injective ops.
+  kTuple = 2,
I think we need to justify a bit why kTuple is put at 2. Because tuple is special, I would rather put it at, say, 7, and use
pattern <= kInjective || pattern == kTuple
to indicate that the pattern is tuple-aware.
I thought the pattern of tuple needed to be smaller than kInjective, because when we fuse the tuple into subsequent injective ops, we want the pattern of the fused group to be kInjective (CombinePattern returns the larger of the two patterns being combined).
But I realized that CombinePattern is only called when the child group's master ref is non-null, so it doesn't work the way I expected. kTuple doesn't have to be smaller than kInjective.
It also means even if injective ops are fused into a broadcast op, the combined pattern is still kBroadcast. Is this intended?
I see; never mind, we can keep this as it is then. Please add a comment explaining why we chose this order.
I've just changed kTuple to 7; no change to fuse_ops.cc was needed. I also prefer this since it makes the diff smaller.
Let us update CombinePattern to specially handle tuple + injective and tuple + broadcast
Do you mean that both injective and broadcast win against tuple, even though their op patterns are smaller? (kTuple is now 7.)
@tqchen do you want to update both CombinePattern and its call site? (It seems CombinePattern is only called when one of its args is kOutEWiseFusable.)
LGTM
LGTM
@vinx13 @tqchen looks like CI is broken; I'm getting an error from the topi group conv2d test (verify_group_conv2d_NCHWc_int8). Maybe #3070 is related? http://ci.tvm.ai:8080/blue/organizations/jenkins/tvm/detail/PR-3092/5/pipeline/235
See #3039 for the context and discussion.
This is my second cut at fixing tuple fusion, which I hope is a better approach than the rather ad hoc one in #3049.