Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Legalize - Use Non-recursive Rewriter. #5296

Merged
merged 2 commits into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
*
* ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
*
* The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
* The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
* non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original
* node, called `pre`, and a node recreated with any alterned inputs, called `post`, to the
* ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
Expand Down Expand Up @@ -408,7 +408,7 @@ class ExprRewriter {

/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
*
* PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
* PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
* ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call,
* PostOrderRewrite provides the original node and the node with altered inputs for use by the
* ExprRewriter.
Expand Down
9 changes: 5 additions & 4 deletions src/relay/transforms/legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ namespace legalize {

// Call registered FTVMLegalize of an op
// Returns the legalized expression
class Legalizer : public ExprMutator {
class Legalizer : public ExprRewriter {
public:
explicit Legalizer(const std::string& legalize_map_attr_name)
: legalize_map_attr_name_{legalize_map_attr_name} {}

Expr VisitExpr_(const CallNode* call_node) {
Expr Rewrite_(const CallNode* call_node, const Expr& post) override {
// Get the new_call node without any changes to current call node.
Expr new_e = ExprMutator::VisitExpr_(call_node);
Expr new_e = post;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just rename post to new_e and remove this line?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could not directly rename, as new_e was supposed to be the new expr and post is a cont. But, I have cleaned up the code, so that new_e is not needed anymore.

Call new_call = Downcast<Call>(new_e);

// Check if the string is registered in the OpRegistry.
Expand Down Expand Up @@ -90,7 +90,8 @@ class Legalizer : public ExprMutator {
};

Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
return Legalizer(legalize_map_attr_name).Mutate(expr);
auto rewriter = Legalizer(legalize_map_attr_name);
return PostOrderRewrite(expr, &rewriter);
}

} // namespace legalize
Expand Down