-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay] Fix operator fusion for multiple output #3871
Conversation
Can you manually construct a Relay program for test? |
@vinx13 I can try. No grantee though as the failure case is generated and complex. |
@vinx13 the test case was found. |
@@ -304,14 +304,16 @@ class PrettyPrinter : | |||
* \return The corresponding name. | |||
*/ | |||
Doc AllocTypeVar(const TypeVar& var) { | |||
if (memo_type_.count(var)) { | |||
Doc val = memo_type_[var]; | |||
val << "-malformed-ir"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we change this to print in a more informative way? for example maybe use a colored highlight to show which part is malformed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is no color highlighting in the current doc.
src/relay/pass/fuse_ops.cc
Outdated
auto* ret_group = gmap_.at(tuple)->FindRoot(); | ||
if (ret_group == gmap_.at(tuple)) { | ||
auto* ret_group = gmap_.at(tuple)->FindRoot(); | ||
if (ret_group->root_ref == tuple) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Extra spaces?
Here is my understanding of the problem, based on @MarisaKirisame's testcast:
Is this correct? @MarisaKirisame |
@masahi yes. the gradient of concatenate (which is split) triggered this problem. |
Can you make a new function that return the LCA of all output nodes, and replace existing implementation based on raw loop? In your test case the new function should return null. I think this is clearer and less ad hoc. |
@masahi like this? |
I mean, I want to make the following snippet easier to understand. I know your fix works, but it took me a while to understand what problem this PR is supposed to solve. Especially the logic around assigning the parent pointer temporarily and letting LCA fix up when the parent should be null seems tricky.
I think the part of the problem is assignment of parent pointer is tightly coupled with LCA computation inside a loop. It would be great if we could replace snippet above with something like (using a mix of non c++ syntax):
|
How about something like this? It passes your test but fails on other fusion tests :)
|
@masahi you have to treat the first node specially. You are not really doing a accumulate, but you are: |
hmm I updated the code below and now it passes all fusion test cases. Not sure if I am doing something wrong. My fuse_ops.cc is here if you want to try it.
|
@masahi I just wrote one that does basically the same thing. How is it? |
Also your std::get<1>(accum) and &std::get<1>(accum) will create very subtle errors that depend on the order of which execute first. |
Yeah it looks better. Thanks :) |
@masahi can you approve explicitly? I can then merge then |
@jroesch sorry I was waiting for the CI to finish :) |
* save * add test * refactor * fix indent * save * refactor
* save * add test * refactor * fix indent * save * refactor
* save * add test * refactor * fix indent * save * refactor
* save * add test * refactor * fix indent * save * refactor
See #3867.
@tqchen @masahi @vinx13
I cant get a test case as it depend on the gradient of a concat which is another WIP. I could add the test afterward though.