Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Pass] Fix bug in re-processing call node in MergeComposite pass #4879

Merged
merged 10 commits into from
Feb 17, 2020

Conversation

soiferj
Copy link
Contributor

@soiferj soiferj commented Feb 14, 2020

This fixes a bug where call nodes are recursively processed more than once, potentially resulting in a composite function containing duplicate nodes. This change introduces a call_map, similar to the var_map, to keep track of call nodes that we've processed.

I found this bug while writing a pattern for a single-layer transformer.

@mbarrett97 @comaniac would you be able to take a look?

Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix looks good to me. Could you add a simple unit test to cover this change?

@soiferj
Copy link
Contributor Author

soiferj commented Feb 14, 2020

Sure, I'll work on adding a unit test.

@tqchen tqchen added status: need test case need test cases to cover the change status: need update need update based on feedbacks labels Feb 14, 2020
@mbaret
Copy link
Contributor

mbaret commented Feb 14, 2020

Good catch :) Fix looks to be correct, looking forward to the test case.

@soiferj
Copy link
Contributor Author

soiferj commented Feb 14, 2020

@mbarrett97 @comaniac I just pushed a test where "result" creates an incorrect graph, and "expected" is correct. Even though these two graphs are different, and "result" generated an incorrect function, the two functions generated are computationally equivalent. This means that the test actually passes alpha_equal both with and without the bug.

It is still worth fixing this bug, since this problem blows up when matching large patterns, but I am not sure how to actually fail the test with output like this. Do you have any suggestions?

Here are the Relay outputs. The pattern I am trying to match is add -> add -> add.

Result (incorrect output that generated function of add -> add -> add -> add):

v0.0.4
fn (%a: Tensor[(10, 10), float32], %b: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
  %0 = subtract(%a, %b) /* ty=Tensor[(10, 10), float32] */;
  %4 = fn (%x: Tensor[(10, 10), float32], %y: Tensor[(10, 10), float32], Primitive=1, Composite="add_add_add") -> Tensor[(10, 10), float32] {
    %1 = add(%x, %y) /* ty=Tensor[(10, 10), float32] */;
    %2 = add(%x, %1) /* ty=Tensor[(10, 10), float32] */;
    %3 = add(%x, %y) /* ty=Tensor[(10, 10), float32] */;
    add(%2, %3) /* ty=Tensor[(10, 10), float32] */
  };
  %4(%0, %b) /* ty=Tensor[(10, 10), float32] */
}

Expected (correct output):

v0.0.4
fn (%a: Tensor[(10, 10), float32], %b: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
  %0 = subtract(%a, %b) /* ty=Tensor[(10, 10), float32] */;
  %3 = fn (%in_1: Tensor[(10, 10), float32], %in_2: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
    %1 = add(%in_1, %in_2) /* ty=Tensor[(10, 10), float32] */;
    %2 = add(%in_1, %1) /* ty=Tensor[(10, 10), float32] */;
    add(%2, %1) /* ty=Tensor[(10, 10), float32] */
  };
  %3(%0, %b) /* ty=Tensor[(10, 10), float32] */
}

@zhiics
Copy link
Member

zhiics commented Feb 14, 2020

@soiferj Sorry. I don't quite understand the problem. Do you mean that these two expressions can pass alpha_equal check?

@soiferj
Copy link
Contributor Author

soiferj commented Feb 14, 2020

Yes, they pass the alpha_equal check.

@zhiics
Copy link
Member

zhiics commented Feb 14, 2020

@soiferj hmm, this looks a bit weird to me, I will take a look at it. Thanks.

@mbaret
Copy link
Contributor

mbaret commented Feb 14, 2020

You could try graph_equal which claims to check for data flow equivalency, although I am a bit surprised alpha_equal doesn't catch this. Failing that, maybe a static traversal is another option.

@soiferj
Copy link
Contributor Author

soiferj commented Feb 14, 2020

graph_equal also succeeds on the buggy graph :( @zhiics, let me know when you have any findings!

@mbaret
Copy link
Contributor

mbaret commented Feb 14, 2020

Unless I'm missing something, those graphs don't appear to me to be 'data flow equivalent', so this may be a bug with graph_equal.

@soiferj
Copy link
Contributor Author

soiferj commented Feb 14, 2020

I'll do some investigation :)

@zhiics
Copy link
Member

zhiics commented Feb 14, 2020

I might not have time today. I can spend some time on it over the weekend.

@zhiics
Copy link
Member

zhiics commented Feb 14, 2020

hmm, I just tried this:

import numpy as np
import tvm
from tvm import relay
from tvm.relay import analysis
from tvm.relay.testing import run_opt_pass

def test():
    tt = relay.TensorType([10, 10], "float32")
    a = relay.Var("a", tt)
    b = relay.Var("b", tt)
    sub = relay.subtract(a, b)

    x = relay.Var("x", tt)
    y = relay.Var("y", tt)

    add1 = x + y
    add2 = x + add1
    add3 = x + y
    add4 = add2 + add3

    fn = relay.Function([x, y], add4)
    fn = fn.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
    fn = fn.set_attribute("Composite", tvm.tir.StringImm("add_add_add"))
    fn_call = relay.Call(fn, [sub, b])

    func = relay.Function([a, b], fn_call)
    func = run_opt_pass(func, relay.transform.InferType())
    print(func)

    tt0 = relay.TensorType([10, 10], "float32")
    a0 = relay.Var("a0", tt0)
    b0 = relay.Var("b0", tt0)
    sub0 = relay.subtract(a0, b0)

    x0 = relay.Var("x0", tt0)
    y0 = relay.Var("y0", tt0)

    add01 = x0 + y0
    add02 = x0 + add01
    add03 = add02 + add01

    fn0 = relay.Function([x0, y0], add03)
    fn_call0 = relay.Call(fn0, [sub0, b0])
    func0 = relay.Function([a0, b0], fn_call0)
    func0 = run_opt_pass(func0, relay.transform.InferType())

    print(func0)
    assert analysis.alpha_equal(func, func0)

It could not pass alpha_equal. Are we missing something here? Can you double check if the program I provided are identical yours here?

@soiferj
Copy link
Contributor Author

soiferj commented Feb 14, 2020

That's really strange - it looks right. Are you able to pull my branch and give the test a try?

@zhiics
Copy link
Member

zhiics commented Feb 14, 2020

I can give it a try over the weekend, but why do you feed the expected with an unexpected expression?

@soiferj
Copy link
Contributor Author

soiferj commented Feb 14, 2020

Sorry, what exactly do you mean?

@zhiics
Copy link
Member

zhiics commented Feb 14, 2020

I though you provided

v0.0.4
fn (%a: Tensor[(10, 10), float32], %b: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
  %0 = subtract(%a, %b) /* ty=Tensor[(10, 10), float32] */;
  %3 = fn (%in_1: Tensor[(10, 10), float32], %in_2: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
    %1 = add(%in_1, %in_2) /* ty=Tensor[(10, 10), float32] */;
    %2 = add(%in_1, %1) /* ty=Tensor[(10, 10), float32] */;
    add(%2, %1) /* ty=Tensor[(10, 10), float32] */
  };
  %3(%0, %b) /* ty=Tensor[(10, 10), float32] */
}

as expected, instead of,

v0.0.4
fn (%a: Tensor[(10, 10), float32], %b: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
  %0 = subtract(%a, %b) /* ty=Tensor[(10, 10), float32] */;
  %4 = fn (%x: Tensor[(10, 10), float32], %y: Tensor[(10, 10), float32], Primitive=1, Composite="add_add_add") -> Tensor[(10, 10), float32] {
    %1 = add(%x, %y) /* ty=Tensor[(10, 10), float32] */;
    %2 = add(%x, %1) /* ty=Tensor[(10, 10), float32] */;
    %3 = add(%x, %y) /* ty=Tensor[(10, 10), float32] */;
    add(%2, %3) /* ty=Tensor[(10, 10), float32] */
  };
  %4(%0, %b) /* ty=Tensor[(10, 10), float32] */
}

If so, why did you provide that one? The mentioned bug (if it is) is actually a separate issue that doesn't block this PR. So we are good for this PR, right?

@soiferj
Copy link
Contributor Author

soiferj commented Feb 14, 2020

I think so. If everyone else is okay, I think it's best to check this fix in.

@mbaret
Copy link
Contributor

mbaret commented Feb 14, 2020

So is the issue that the test case currently passes both pre and post fix? If so, I'd say it probably does block the PR. We need to understand whether alpha_equal is an acceptable way to test equality here and if it's not, use an alternative method.

@soiferj
Copy link
Contributor Author

soiferj commented Feb 14, 2020

Actually, are you sure the second argument is expected? It looks like AlphaEqual loops through the LHS args. This has some weird implications, when I run the other tests and flip the order of result and expected in alpha_equal, they fail.

Maybe this is because the "expected" doesn't have the composite and primitive attributes?

Update: that's the part that's returning false. I think "expected" should be arg one, and we need to set the attributes on the function.

@soiferj
Copy link
Contributor Author

soiferj commented Feb 15, 2020

If I flip the arguments to alpha_equals and properly add the attributes to the "expected" function, the tests work as expected. With everyone's okay, can I go ahead with this change?

@zhiics
Copy link
Member

zhiics commented Feb 15, 2020

I am okay with it as the failure should be separate issue from alpha_equal. But I would suggest we create minimal example to reproduce the bug and open an issue for it so that ppl can conveniently look into it.

@mbaret
Copy link
Contributor

mbaret commented Feb 15, 2020

So long as the test now correctly fails for the current behaviour, I'm happy. The ordering mattering is a concern for all the other tests though which have 'expected' as the second argument.

@soiferj
Copy link
Contributor Author

soiferj commented Feb 15, 2020

This is actually a good exercise, as some other tests are actually failing now due to the graphs not being the same. For example, test_branch_merge. @mbarrett97, I'll push the changes, let me know what you think.

Expected:

fn (%a: Tensor[(10, 10), float32], %b: Tensor[(10, 10), float32], %c: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
  %2 = fn (%in_1: Tensor[(10, 10), float32], %in_2: Tensor[(10, 10), float32], Composite="add_sub_mul", Primitive=1) -> Tensor[(10, 10), float32] {
    %0 = add(%in_1, %in_2) /* ty=Tensor[(10, 10), float32] */;
    %1 = subtract(%in_1, %in_2) /* ty=Tensor[(10, 10), float32] */;
    multiply(%0, %1) /* ty=Tensor[(10, 10), float32] */
  };
  %3 = %2(%a, %b) /* ty=Tensor[(10, 10), float32] */;
  %4 = %2(%c, %3) /* ty=Tensor[(10, 10), float32] */;
  nn.relu(%4) /* ty=Tensor[(10, 10), float32] */
}

Result:

fn (%a: Tensor[(10, 10), float32], %b: Tensor[(10, 10), float32], %c: Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] {
  %2 = fn (%x: Tensor[(10, 10), float32], %y: Tensor[(10, 10), float32], Primitive=1, Composite="add_sub_mul") -> Tensor[(10, 10), float32] {
    %0 = add(%x, %y) /* ty=Tensor[(10, 10), float32] */;
    %1 = subtract(%x, %y) /* ty=Tensor[(10, 10), float32] */;
    multiply(%0, %1) /* ty=Tensor[(10, 10), float32] */
  };
  %3 = %2(%a, %b) /* ty=Tensor[(10, 10), float32] */;
  %6 = fn (%x1: Tensor[(10, 10), float32], %y1: Tensor[(10, 10), float32], Primitive=1, Composite="add_sub_mul") -> Tensor[(10, 10), float32] {
    %4 = add(%x1, %y1) /* ty=Tensor[(10, 10), float32] */;
    %5 = subtract(%x1, %y1) /* ty=Tensor[(10, 10), float32] */;
    multiply(%4, %5) /* ty=Tensor[(10, 10), float32] */
  };
  %7 = %6(%c, %3) /* ty=Tensor[(10, 10), float32] */;
  nn.relu(%7) /* ty=Tensor[(10, 10), float32] */
}

+1 that we should confirm which order is expected for alpha_equals.

@soiferj
Copy link
Contributor Author

soiferj commented Feb 15, 2020

Thanks everyone for working through this with me!

@zhiics
Copy link
Member

zhiics commented Feb 15, 2020

@soiferj Can you also try to create a minimal counterexample? Thanks.

@soiferj
Copy link
Contributor Author

soiferj commented Feb 15, 2020

Definitely, I'll do that after fixing this.

@soiferj
Copy link
Contributor Author

soiferj commented Feb 15, 2020

This test (test_branch_merge) is now failing at the de-duplicate pass. It's interesting since the "correct" (result) graph has the exact same function duplicated twice. Whereas the "expected" graph is just re-using the same function. What should be the correct behavior?

Update: I am still having some trouble with this test. I will look next week. I am wondering if we should check this fix in and update the tests wholesale in another PR?

I was able to fix this test. Will push temporary changes so I can work from another machine this weekend.

@soiferj
Copy link
Contributor Author

soiferj commented Feb 15, 2020

Alright, sorry for all of the spam. All tests are now fixed. cc @zhiics @mbarrett97

@zhiics
Copy link
Member

zhiics commented Feb 15, 2020

Thanks for the effort. I will take a look later. So, do we still need to flip the args? If so, we still need to create a repo and open an issue, right?

@soiferj
Copy link
Contributor Author

soiferj commented Feb 15, 2020

It seems like we need to flip the args. I’ll open the issue on Monday.

@soiferj
Copy link
Contributor Author

soiferj commented Feb 17, 2020

@zhiics I merged your changes and updated the branch. Would you mind taking another look?

Copy link
Member

@zhiics zhiics left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhiics
Copy link
Member

zhiics commented Feb 17, 2020

@mbaret PTAL. Let's land this if it looks good to you as well.

@mbaret
Copy link
Contributor

mbaret commented Feb 17, 2020

Looks good.

@zhiics zhiics merged commit 27a0284 into apache:master Feb 17, 2020
@zhiics
Copy link
Member

zhiics commented Feb 17, 2020

Thanks @soiferj @mbaret @cbalint13

@zhiics zhiics added status: accepted and removed status: need test case need test cases to cover the change status: need update need update based on feedbacks labels Feb 17, 2020
alexwong pushed a commit to alexwong/tvm that referenced this pull request Feb 26, 2020
…ss (apache#4879)

* Fix bug in re-processing call node

* Add test

* Add to main

* temp changes to work from another machine

* fix rest of tests

* fix test_reuse_call_merge

* fix merge

Co-authored-by: Jon Soifer <[email protected]>
alexwong pushed a commit to alexwong/tvm that referenced this pull request Feb 28, 2020
…ss (apache#4879)

* Fix bug in re-processing call node

* Add test

* Add to main

* temp changes to work from another machine

* fix rest of tests

* fix test_reuse_call_merge

* fix merge

Co-authored-by: Jon Soifer <[email protected]>
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Mar 2, 2020
…ss (apache#4879)

* Fix bug in re-processing call node

* Add test

* Add to main

* temp changes to work from another machine

* fix rest of tests

* fix test_reuse_call_merge

* fix merge

Co-authored-by: Jon Soifer <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants