-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Pass] Merge two consecutive reshape ops #6052
Changes from 2 commits
8bce8da
8b69388
7b35182
842b5a7
b58b998
43a172a
f28a13f
d5c1efd
7bd9401
8479dd9
e0add95
45b2dcc
4ace1ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file src/relay/transforms/simplify_expr.cc | ||
* \brief A pass for simplifying the Relay expression. | ||
*/ | ||
|
||
#include <tvm/relay/expr.h> | ||
#include <tvm/relay/expr_functor.h> | ||
#include <tvm/relay/transform.h> | ||
#include <tvm/relay/dataflow_matcher.h> | ||
#include <tvm/support/logging.h> | ||
#include "../op/tensor/transform.h" | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
static Op reshape_op = Op::Get("reshape"); | ||
static Op reverse_reshape_op = Op::Get("contrib_reverse_reshape"); | ||
|
||
/*! | ||
* \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops, | ||
* and merges into one reshape op. | ||
*/ | ||
class SimplifyReshape { | ||
public: | ||
SimplifyReshape() { | ||
x_ = WildcardPattern(make_object<WildcardPatternNode>()); | ||
auto reshape1 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op)); | ||
auto reshape2 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op)); | ||
pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_}, Attrs{}, {})}, Attrs{}, {}); | ||
} | ||
|
||
Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) { | ||
auto x = node_map[x_][0]; | ||
bool const_shape = true; | ||
Array<Integer> newshape; | ||
for (auto dim : Downcast<TensorType>(pre->checked_type())->shape) { | ||
if (dim.as<IntImmNode>() == nullptr) { | ||
const_shape = false; | ||
break; | ||
} | ||
newshape.push_back(Downcast<Integer>(dim)); | ||
} | ||
if (const_shape) { | ||
return MakeReshape(x, newshape); | ||
} | ||
return post; | ||
} | ||
|
||
DFPattern pattern() const { return pattern_; } | ||
|
||
private: | ||
/*! \brief Pattern input */ | ||
DFPattern x_; | ||
/*! \brief Pattern for consecutive reshape or reverse_reshape ops */ | ||
DFPattern pattern_; | ||
|
||
}; | ||
|
||
/*! | ||
* \brief ExprSimplifier simplifies the Relay expression. | ||
*/ | ||
class ExprSimplifier { | ||
public: | ||
ExprSimplifier() { | ||
auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) { | ||
Expr pre = args[0]; | ||
Expr post = args[1]; | ||
Map<DFPattern, Array<Expr>> node_map = args[2]; | ||
*rv = simplify_reshape_.callback(pre, post, node_map); | ||
}; | ||
callbacks_.push_back(DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe have SimplifyReshape directly inherit DFPatternCallback? You could fold this directly into that and keep it out of the main Simplifier. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason that I didn't inherit directly from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. :/ I think I focused too much on the Python API and left an Ugly C++ API. I'll see if I can clean that up in a follow up PR. Thanks! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good. :) |
||
true)); | ||
} | ||
|
||
Expr Simplify(const Expr& expr) { | ||
return RewritePatterns(callbacks_, expr); | ||
} | ||
|
||
private: | ||
/*! \brief Simplify reshape pattern */ | ||
SimplifyReshape simplify_reshape_; | ||
/*! \brief Callbacks for expr simplification */ | ||
Array<DFPatternCallback> callbacks_; | ||
}; | ||
|
||
Expr SimplifyExpr(const Expr& expr, const IRModule& module) { | ||
return ExprSimplifier().Simplify(expr); | ||
} | ||
|
||
namespace transform { | ||
|
||
Pass SimplifyExpr() { | ||
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = | ||
[=](Function f, IRModule m, PassContext pc) { | ||
return Downcast<Function>(SimplifyExpr(f, m)); | ||
}; | ||
return CreateFunctionPass(pass_func, 0, "SimplifyExpr", {"InferType"}); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("relay._transform.SimplifyExpr").set_body_typed(SimplifyExpr); | ||
|
||
} // namespace transform | ||
|
||
} // namespace relay | ||
} // namespace tvm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://tvm.apache.org/docs/contribute/code_guide.html
https://google.github.io/styleguide/cppguide.html#Variable_Names
Why the move away from the Google Style Guide convention? You seem to use the var_name_ convention in simplify_expr.cc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because these variables are public, it's probably better and more consistent to name it without "_" at the end imo.