[RELAY][PASS] Common subexpression elimination #2639
Conversation
@kazum @ZihengJiang please help to take a look.
@jroesch can you manage this PR as per https://docs.tvm.ai/contribute/committer_guide.html? A good chance to test out your committer rights.
```cpp
 * \return Whether two expressions are equal scalars.
 */
inline bool IsEqualScalar(const Expr& a, const Expr& b) {
  const auto* constant_a = a.as<ConstantNode>();
```
`constant_a` is nullptr with `relay.var("x", shape=(1, 16))` in your test script. Is this what you expect?
Yes, this function is intended to enable combining different `Constant` instances that hold the same value.
With the code below, the result is as expected:

```python
from tvm import relay

x = relay.var("x", shape=(1, 16))
y1 = relay.add(x, relay.const(1.0, "float32"))
y2 = relay.add(x, relay.const(1.0, "float32"))
y = relay.add(y1, y2)
f = relay.Function([x], y)
f = relay.ir_pass.eliminate_common_subexpr(f)
print(f)
```
However, when I changed the code a bit as follows, elimination did not work. Is this outside the scope of this PR?

```python
from tvm import relay

x = relay.var("x", shape=(1, 16))
y1 = relay.add(relay.const(1.0, "float32"), x)
y2 = relay.add(relay.const(1.0, "float32"), x)
y = relay.add(y1, y2)
f = relay.Function([x], y)
f = relay.ir_pass.eliminate_common_subexpr(f)
print(f)
```
@kazum yes, it is a limitation of the current implementation.
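To make the limitation concrete, here is a minimal sketch in plain Python (not TVM code; every name below is invented for illustration). It mimics a pass that keys its candidate map on the identity of the first argument, as this PR does: `add(x, const)` duplicates share the key `x` and are merged, while `add(const, x)` duplicates key on two distinct constant objects and are missed.

```python
class Const:
    """Stand-in for a Relay constant; equal values, distinct objects."""
    def __init__(self, value):
        self.value = value

class Var:
    def __init__(self, name):
        self.name = name

class Call:
    def __init__(self, op, args):
        self.op, self.args = op, args

def args_match(a, b):
    # Same object, or two constants holding the same scalar value
    # (roughly what IsEqualScalar checks in the PR).
    return a is b or (isinstance(a, Const) and isinstance(b, Const)
                      and a.value == b.value)

def cse_by_first_arg(calls):
    """Toy CSE keyed on id(args[0]), mirroring expr_map_[new_call->args[0]]."""
    expr_map = {}  # id of first argument -> previously seen calls
    out = []
    for call in calls:
        key = id(call.args[0])
        hit = next((c for c in expr_map.get(key, [])
                    if c.op == call.op and len(c.args) == len(call.args)
                    and all(args_match(p, q)
                            for p, q in zip(c.args, call.args))), None)
        if hit is None:
            expr_map.setdefault(key, []).append(call)
            hit = call
        out.append(hit)
    return out

x = Var("x")

# add(x, const): both calls key on the same object x -> duplicate found.
merged = cse_by_first_arg([Call("add", [x, Const(1.0)]),
                           Call("add", [x, Const(1.0)])])
assert merged[0] is merged[1]

# add(const, x): each call keys on a distinct Const object -> missed.
missed = cse_by_first_arg([Call("add", [Const(1.0), x]),
                           Call("add", [Const(1.0), x])])
assert missed[0] is not missed[1]
```

Keying the map on the op instead of the first argument avoids the second miss, at the cost of searching longer candidate lists.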
```cpp
      return GetRef<Call>(candidate);
    }
  }
  expr_map_[new_call->args[0]].push_back(new_call);
```
Let me ask one more question. `expr_map_` is a map from `new_call->args[0]` to `new_call`. Can we change it to a map from `new_call->op` to `new_call`? Then this PR also handles the case of #2639 (comment), doesn't it? What I mean is as follows:
```cpp
auto it = expr_map_.find(new_call->op);
if (it != expr_map_.end()) {
  for (const CallNode* candidate : it->second) {
    bool is_equivalent = true;
    if (!attrs_equal(new_call->attrs, candidate->attrs)) {
      continue;
    }
    for (size_t i = 0; i < new_call->args.size(); i++) {
      if (!new_call->args[i].same_as(candidate->args[i]) &&
          !IsEqualScalar(new_call->args[i], candidate->args[i])) {
        is_equivalent = false;
        break;
      }
    }
    if (!is_equivalent) continue;
    return GetRef<Call>(candidate);
  }
}
expr_map_[new_call->op].push_back(new_call);
return new_expr;
```
The reason I chose to map from `new_call->args[0]` is to avoid searching a long list of candidates. But yes, you are right; on second thought, I think it is okay to map from the op.
Looks good to me, thanks!
This is an optimization pass that eliminates common subexpressions. During the pass, it tries to replace an expression with an equivalent expression that appeared earlier and has the same inputs and attributes. The `fskip` callback argument allows us to skip specific expressions.
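As a rough illustration of that description (plain Python, not the actual TVM implementation; only the `fskip` name comes from the pass description, the rest is invented for the sketch), a minimal op-keyed CSE with a skip callback might look like:

```python
class Var:
    def __init__(self, name):
        self.name = name

class Call:
    def __init__(self, op, args):
        self.op, self.args = op, args

def eliminate_common_subexpr(calls, fskip=None):
    """Toy CSE: reuse an earlier call with the same op and arguments.

    fskip(call) -> True means the call is left untouched, mirroring the
    fskip callback described above. Arguments are compared by identity
    only; the real pass also matches equal scalar constants and attrs.
    """
    expr_map = {}  # op -> previously seen calls
    out = []
    for call in calls:
        if fskip is not None and fskip(call):
            out.append(call)
            continue
        hit = next((c for c in expr_map.get(call.op, [])
                    if len(c.args) == len(call.args)
                    and all(p is q for p, q in zip(c.args, call.args))), None)
        if hit is None:
            expr_map.setdefault(call.op, []).append(call)
            hit = call
        out.append(hit)
    return out

x = Var("x")
y1, y2 = Call("add", [x, x]), Call("add", [x, x])

merged = eliminate_common_subexpr([y1, y2])
assert merged[0] is merged[1]          # duplicate add merged

kept = eliminate_common_subexpr([y1, y2], fskip=lambda c: c.op == "add")
assert kept[0] is not kept[1]          # skipped ops are left alone
```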
cc @tqchen