[Relay][Pass] Fix bug in re-processing call node in MergeComposite pass #4879
Conversation
The fix looks good to me. Could you add a simple unit test to cover this change?
Sure, I'll work on adding a unit test.

Good catch :) Fix looks correct; looking forward to the test case.
@mbarrett97 @comaniac I just pushed a test where "result" creates an incorrect graph and "expected" is correct. Even though these two graphs are different, and "result" generated an incorrect function, the two generated functions are computationally equivalent, which means the test actually passes. It is still worth fixing this bug, since this problem blows up when matching large patterns, but I am not sure how to actually fail the test with output like this. Do you have any suggestions? Here are the Relay outputs. The pattern I am trying to match is [pattern elided].

Result (incorrect output; generated function elided):

Expected (correct output): [output elided]
@soiferj Sorry, I don't quite understand the problem. Do you mean that these two expressions pass the alpha_equal check?

Yes, they pass the alpha_equal check.

@soiferj hmm, this looks a bit weird to me; I will take a look at it. Thanks.
You could try graph_equal, which claims to check for data-flow equivalence, although I am a bit surprised alpha_equal doesn't catch this. Failing that, maybe a static traversal is another option.
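For reference, a minimal sketch of invoking both checks on a "reused" versus a "duplicated" sub-expression, assuming the pre-0.7 `tvm.relay.analysis` API that exposes both `alpha_equal` and `graph_equal` (whether the two checks distinguish these shapes is exactly the question under discussion):

```python
import tvm
from tvm import relay
from tvm.relay import analysis

tt = relay.TensorType([10, 10], "float32")
x = relay.Var("x", tt)
y = relay.Var("y", tt)

# one shared add node, used twice
shared = x + y
reuse = relay.Function([x, y], shared + shared)

# two separate (but structurally identical) add nodes
x2 = relay.Var("x", tt)
y2 = relay.Var("y", tt)
dup = relay.Function([x2, y2], (x2 + y2) + (x2 + y2))

print(analysis.alpha_equal(reuse, dup))  # structural equality up to renaming
print(analysis.graph_equal(reuse, dup))  # claimed dataflow-graph equality
```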
graph_equal also succeeds on the buggy graph :( @zhiics, let me know when you have any findings!

Unless I'm missing something, those graphs don't appear to me to be 'data flow equivalent', so this may be a bug with graph_equal.

I'll do some investigation :)

I might not have time today. I can spend some time on it over the weekend.
hmm, I just tried this:

```python
import numpy as np
import tvm
from tvm import relay
from tvm.relay import analysis
from tvm.relay.testing import run_opt_pass


def test():
    tt = relay.TensorType([10, 10], "float32")
    a = relay.Var("a", tt)
    b = relay.Var("b", tt)
    sub = relay.subtract(a, b)

    # the "result"-style graph: (x + y) is computed twice inside the
    # composite function
    x = relay.Var("x", tt)
    y = relay.Var("y", tt)
    add1 = x + y
    add2 = x + add1
    add3 = x + y
    add4 = add2 + add3
    fn = relay.Function([x, y], add4)
    fn = fn.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
    fn = fn.set_attribute("Composite", tvm.tir.StringImm("add_add_add"))
    fn_call = relay.Call(fn, [sub, b])
    func = relay.Function([a, b], fn_call)
    func = run_opt_pass(func, relay.transform.InferType())
    print(func)

    # the "expected"-style graph: the first add is reused instead of
    # duplicated (note: no function attributes are set here)
    tt0 = relay.TensorType([10, 10], "float32")
    a0 = relay.Var("a0", tt0)
    b0 = relay.Var("b0", tt0)
    sub0 = relay.subtract(a0, b0)
    x0 = relay.Var("x0", tt0)
    y0 = relay.Var("y0", tt0)
    add01 = x0 + y0
    add02 = x0 + add01
    add03 = add02 + add01
    fn0 = relay.Function([x0, y0], add03)
    fn_call0 = relay.Call(fn0, [sub0, b0])
    func0 = relay.Function([a0, b0], fn_call0)
    func0 = run_opt_pass(func0, relay.transform.InferType())
    print(func0)

    assert analysis.alpha_equal(func, func0)


test()
```

It could not pass alpha_equal. Are we missing something here? Can you double-check whether the program I provided is identical to yours?
That's really strange - it looks right. Are you able to pull my branch and give the test a try?

I can give it a try over the weekend, but why do you feed "expected" with an unexpected expression?

Sorry, what exactly do you mean?
I thought you provided

[expression elided]

as expected, instead of

[expression elided]

If so, why did you provide that one? The mentioned bug (if it is one) is actually a separate issue that doesn't block this PR. So we are good for this PR, right?
I think so. If everyone else is okay, I think it's best to check this fix in.
So is the issue that the test case currently passes both pre- and post-fix? If so, I'd say it probably does block the PR. We need to understand whether alpha_equal is an acceptable way to test equality here and, if it's not, use an alternative method.
Actually, are you sure the second argument is "expected"? It looks like AlphaEqual loops through the LHS args. This has some weird implications when I run the other tests and flip the order of [arguments elided]. Maybe this is because the "expected" doesn't have the composite and primitive attributes?

Update: that's the part that's returning false. I think "expected" should be argument one, and we need to set the attributes on the function.
If I flip the arguments to alpha_equal and properly add the attributes to the "expected" function, the tests work as expected. With everyone's okay, can I go ahead with this change?
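For illustration, a minimal sketch of the order sensitivity being described, using the same `set_attribute` API as the repro above (whether the two orderings actually disagree is exactly the behavior under discussion):

```python
import tvm
from tvm import relay
from tvm.relay import analysis

tt = relay.TensorType([10, 10], "float32")

def make_add(suffix=""):
    x = relay.Var("x" + suffix, tt)
    y = relay.Var("y" + suffix, tt)
    return relay.Function([x, y], x + y)

plain = make_add("0")                        # no function attributes
attributed = make_add()
attributed = attributed.set_attribute(
    "Composite", tvm.tir.StringImm("add"))   # carries an extra attribute

# if the comparison only walks the LHS attributes, these can disagree
print(analysis.alpha_equal(plain, attributed))
print(analysis.alpha_equal(attributed, plain))
```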
I am okay with it, as the failure should be a separate issue from [elided].
So long as the test now correctly fails for the current behaviour, I'm happy. The ordering mattering is a concern for all the other tests, though, which have "expected" as the second argument.
This is actually a good exercise, as some other tests are now failing due to the graphs not being the same. For example:

Expected: [output elided]

Result: [output elided]
+1 that we should confirm which order is expected for alpha_equal.

Thanks everyone for working through this with me!

@soiferj Can you also try to create a minimal counterexample? Thanks.

Definitely, I'll do that after fixing this.
This test (test_branch_merge) is now failing at the de-duplicate pass. It's interesting, since the "correct" (result) graph has the exact same function duplicated twice, whereas the "expected" graph is just reusing the same function. What should the correct behavior be?

Update: I am still having some trouble with this test. I will look next week. I am wondering if we should check this fix in and update the tests wholesale in another PR?

I was able to fix this test. Will push temporary changes so I can work from another machine this weekend.
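To make the two shapes concrete, a hypothetical sketch (the names here are invented for illustration) of one shared function node versus two structurally identical copies:

```python
import tvm
from tvm import relay

tt = relay.TensorType([10, 10], "float32")

def make_add_fn():
    p0 = relay.Var("p0", tt)
    p1 = relay.Var("p1", tt)
    return relay.Function([p0, p1], p0 + p1)

a = relay.Var("a", tt)
b = relay.Var("b", tt)

# "expected" shape: one function node, reused by both call sites
shared = make_add_fn()
reused = relay.Function(
    [a, b], relay.Call(shared, [a, b]) + relay.Call(shared, [a, b]))

a2 = relay.Var("a", tt)
b2 = relay.Var("b", tt)

# "result" shape: two distinct but structurally identical function nodes
duplicated = relay.Function(
    [a2, b2],
    relay.Call(make_add_fn(), [a2, b2]) + relay.Call(make_add_fn(), [a2, b2]))
```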
Alright, sorry for all of the spam. All tests are now fixed. cc @zhiics @mbarrett97

Thanks for the effort. I will take a look later. So, do we still need to flip the args? If so, we still need to create a repro and open an issue, right?

It seems like we need to flip the args. I’ll open the issue on Monday.

@zhiics I merged your changes and updated the branch. Would you mind taking another look?
LGTM
@mbaret PTAL. Let's land this if it looks good to you as well.

Looks good.

Thanks @soiferj @mbaret @cbalint13
[Relay][Pass] Fix bug in re-processing call node in MergeComposite pass (apache#4879)

* Fix bug in re-processing call node
* Add test
* Add to main
* temp changes to work from another machine
* fix rest of tests
* fix test_reuse_call_merge
* fix merge

Co-authored-by: Jon Soifer <[email protected]>
This fixes a bug where call nodes are recursively processed more than once, potentially resulting in a composite function containing duplicate nodes. This change introduces a `call_map`, similar to the `var_map`, to keep track of call nodes that we've already processed.

I found this bug while writing a pattern for a single-layer transformer.
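The pass itself lives in C++, but the idea behind the fix can be shown with a toy Python sketch (the `Call` class and `rewrite` function here are invented for illustration, not TVM code): memoize each rewritten call node so a node reachable through two paths is processed only once.

```python
class Call:
    """Toy stand-in for a Relay call node."""
    def __init__(self, op, args):
        self.op, self.args = op, args

def rewrite(node, call_map):
    """Rewrite a DAG of Call nodes, reusing results for shared nodes."""
    if not isinstance(node, Call):
        return node
    if id(node) in call_map:       # the fix: consult the memo first
        return call_map[id(node)]
    new_node = Call(node.op, [rewrite(a, call_map) for a in node.args])
    call_map[id(node)] = new_node  # remember the rewritten node
    return new_node

shared = Call("add", ["x", "y"])
root = Call("add", [shared, shared])  # 'shared' is reachable twice
out = rewrite(root, {})
assert out.args[0] is out.args[1]     # rewritten once and reused, not duplicated
```

Without the memo, the second visit to `shared` would build a second, distinct copy, which is exactly the duplicated-node shape discussed in the thread above.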
@mbarrett97 @comaniac would you be able to take a look?