Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Legalize - Use Non-recursive Rewriter. #5296

Merged
merged 2 commits into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
*
* ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
*
* The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
* The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
* non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original
* node, called `pre`, and a node recreated with any alterned inputs, called `post`, to the
* ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
Expand Down Expand Up @@ -408,7 +408,7 @@ class ExprRewriter {

/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
*
* PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
* PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
* ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call,
* PostOrderRewrite provides the original node and the node with altered inputs for use by the
* ExprRewriter.
Expand Down
19 changes: 9 additions & 10 deletions src/relay/transforms/legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,18 @@ namespace legalize {

// Call registered FTVMLegalize of an op
// Returns the legalized expression
class Legalizer : public ExprMutator {
class Legalizer : public ExprRewriter {
public:
explicit Legalizer(const std::string& legalize_map_attr_name)
: legalize_map_attr_name_{legalize_map_attr_name} {}

Expr VisitExpr_(const CallNode* call_node) {
Expr Rewrite_(const CallNode* call_node, const Expr& post) override {
// Get the new_call node without any changes to current call node.
Expr new_e = ExprMutator::VisitExpr_(call_node);
Call new_call = Downcast<Call>(new_e);
Call new_call = Downcast<Call>(post);

// Check if the string is registered in the OpRegistry.
if (!Op::HasAttr(legalize_map_attr_name_)) {
return new_e;
return post;
}

// Collect the registered legalize function.
Expand All @@ -70,27 +69,27 @@ class Legalizer : public ExprMutator {
// Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);

// Reassign new_e if the transformation succeeded.
// Return the new expr if the transformation succeeded.
if (legalized_value.defined()) {
// Check that the returned Expr from legalize is CallNode.
const CallNode* legalized_call_node = legalized_value.as<CallNode>();
CHECK(legalized_call_node)
<< "Can only replace the original operator with another call node";

new_e = legalized_value;
return legalized_value;
}
}
}

return new_e;
return post;
}

private:
std::string legalize_map_attr_name_;
};

Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
return Legalizer(legalize_map_attr_name).Mutate(expr);
auto rewriter = Legalizer(legalize_map_attr_name);
return PostOrderRewrite(expr, &rewriter);
}

} // namespace legalize
Expand Down