Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Fix operator fusion for multiple output #3871

Merged
merged 6 commits into from
Sep 5, 2019

Conversation

MarisaKirisame
Copy link
Contributor

See #3867.

@tqchen @masahi @vinx13

I cant get a test case as it depend on the gradient of a concat which is another WIP. I could add the test afterward though.

@vinx13
Copy link
Member

vinx13 commented Sep 1, 2019

Can you manually construct a Relay program for test?

@MarisaKirisame
Copy link
Contributor Author

@vinx13 I can try. No grantee though as the failure case is generated and complex.

@MarisaKirisame
Copy link
Contributor Author

@vinx13 the test case was found.

@masahi masahi self-assigned this Sep 3, 2019
@@ -304,14 +304,16 @@ class PrettyPrinter :
* \return The corresponding name.
*/
Doc AllocTypeVar(const TypeVar& var) {
if (memo_type_.count(var)) {
Doc val = memo_type_[var];
val << "-malformed-ir";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we change this to print in a more informative way? for example maybe use a colored highlight to show which part is malformed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no color highlighting in the current doc.

auto* ret_group = gmap_.at(tuple)->FindRoot();
if (ret_group == gmap_.at(tuple)) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
if (ret_group->root_ref == tuple) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra spaces?

@masahi
Copy link
Member

masahi commented Sep 4, 2019

Here is my understanding of the problem, based on @MarisaKirisame's testcast:

  • The Split node has multiple outputs. All outputs meet at the add node at the end, but since one of the output involves reference, they cannot be fused into a single op
  • So the split node doesn't have a parent in the postdom tree. (parent should be null)
  • In the first iteration, the split node gets one of its outputs node as its parent.
  • After the first LCA, the parent becomes null correctly, but in the next iteration the parent pointer is reset to another output node, which should not happen.

Is this correct? @MarisaKirisame

@MarisaKirisame
Copy link
Contributor Author

@masahi yes. the gradient of concatenate (which is split) triggered this problem.

@masahi
Copy link
Member

masahi commented Sep 4, 2019

Can you make a new function that return the LCA of all output nodes, and replace existing implementation based on raw loop? In your test case the new function should return null. I think this is clearer and less ad hoc.

@MarisaKirisame
Copy link
Contributor Author

@masahi like this?

@masahi
Copy link
Member

masahi commented Sep 5, 2019

I mean, I want to make the following snippet easier to understand. I know your fix works, but it took me a while to understand what problem this PR is supposed to solve. Especially the logic around assigning the parent pointer temporarily and letting LCA fix up when the parent should be null seems tricky.

// find the LCAs of all outputs.
OpPatternKind pattern = kElemWise;
Node* parent = nullptr;
bool init = true;
for (auto link = gnode->outputs.head; link != nullptr; link= link->next) {
  size_t oindex = link->value.node->index;
  CHECK_LT(oindex, nodes.size());
  Node* onode = nodes[oindex];
  CHECK(onode != nullptr);
  if (init) {
    parent = onode;
    init = false;
  } else {
    parent = LeastCommonAncestor(parent, onode, &pattern);
  }
  pattern = CombinePattern(pattern, link->value.pattern);
}
tnode->depth = parent ? parent->depth + 1 : 1;
tnode->parent = parent;
tnode->pattern = pattern;

I think the part of the problem is assignment of parent pointer is tightly coupled with LCA computation inside a loop. It would be great if we could replace snippet above with something like (using a mix of non c++ syntax):

parent, pattern = LeastCommonAncestor(gnode->outputs) // LCA of all output nodes, not just two
tnode->parent = parent;
tnode->depth = parent ? parent->depth + 1 : 1;
tnode->pattern = pattern;

@masahi
Copy link
Member

masahi commented Sep 5, 2019

How about something like this? It passes your test but fails on other fusion tests :)

  static std::pair<DominatorTree::Node*, OpPatternKind> 
  LeastCommonAncestor(const LinkedList<IndexedForwardGraph::Edge>& outputs,
                      const DominatorTree& tree) {
    std::vector<DominatorTree::Node*> nodes;
    for (auto link = outputs.head; link != nullptr; link = link->next) {
      size_t oindex = link->value.node->index;
      CHECK_LT(oindex, tree.nodes.size());
      auto* onode = tree.nodes[oindex];
      CHECK(onode != nullptr);
      nodes.push_back(onode);
    }
    OpPatternKind pattern = kElemWise;
    if (nodes.empty()) return std::make_pair(nullptr, pattern);
    auto combine_func = [=, &pattern](DominatorTree::Node* n1, DominatorTree::Node* n2) {
      return LeastCommonAncestor(n1, n2, &pattern);
    };
    auto lca = std::accumulate(nodes.begin(), nodes.end(), nodes[0], combine_func);
    return std::make_pair(lca, pattern);
  }

@MarisaKirisame
Copy link
Contributor Author

@masahi you have to treat the first node specially. You are not really doing a accumulate, but you are:
0: interleaving between all the children
1: if there is no children, provide a default value.

@masahi
Copy link
Member

masahi commented Sep 5, 2019

hmm I updated the code below and now it passes all fusion test cases. Not sure if I am doing something wrong. My fuse_ops.cc is here if you want to try it.

 static std::pair<DominatorTree::Node*, OpPatternKind> 
 LeastCommonAncestor(const LinkedList<IndexedForwardGraph::Edge>& outputs,
                     const DominatorTree& tree) {
   std::vector<std::pair<DominatorTree::Node*, OpPatternKind>> nodes;
   for (auto link = outputs.head; link != nullptr; link = link->next) {
     size_t oindex = link->value.node->index;
     CHECK_LT(oindex, tree.nodes.size());
     auto* onode = tree.nodes[oindex];
     CHECK(onode != nullptr);
     nodes.push_back(std::make_pair(onode, link->value.pattern));
   }
   if (nodes.empty()) return std::make_pair(nullptr, kElemWise);
   auto combine_func = [=](std::pair<DominatorTree::Node*, OpPatternKind> accum,
                           std::pair<DominatorTree::Node*, OpPatternKind> next) {
     return std::make_pair(
         LeastCommonAncestor(std::get<0>(accum), std::get<0>(next), &std::get<1>(accum)),
         CombinePattern(std::get<1>(accum), std::get<1>(next)));
   };
   auto lca = std::accumulate(nodes.begin(), nodes.end(), nodes[0], combine_func);
   return lca;
 }

@MarisaKirisame
Copy link
Contributor Author

@masahi I just wrote one that does basically the same thing. How is it?

@MarisaKirisame
Copy link
Contributor Author

Also your std::get<1>(accum) and &std::get<1>(accum) will create very subtle errors that depend on the order of which execute first.

@masahi
Copy link
Member

masahi commented Sep 5, 2019

Yeah it looks better. Thanks :)

@jroesch
Copy link
Member

jroesch commented Sep 5, 2019

@masahi can you approve explicitly? I can then merge then

@masahi masahi merged commit ca35277 into apache:master Sep 5, 2019
@masahi
Copy link
Member

masahi commented Sep 5, 2019

@jroesch sorry I was waiting for the CI to finish :)

MarisaKirisame added a commit to MarisaKirisame/tvm that referenced this pull request Sep 7, 2019
* save

* add test

* refactor

* fix indent

* save

* refactor
@MarisaKirisame MarisaKirisame deleted the fix-fuseops branch September 8, 2019 02:40
wweic pushed a commit to wweic/tvm that referenced this pull request Sep 16, 2019
* save

* add test

* refactor

* fix indent

* save

* refactor
wweic pushed a commit to wweic/tvm that referenced this pull request Sep 16, 2019
* save

* add test

* refactor

* fix indent

* save

* refactor
wweic pushed a commit to neo-ai/tvm that referenced this pull request Sep 16, 2019
* save

* add test

* refactor

* fix indent

* save

* refactor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants