Skip to content

Commit

Permalink
Add EtaExpand to transform API (apache#3406)
Browse files Browse the repository at this point in the history
* Add EtaExpand to transform API

* Add test case
  • Loading branch information
wweic committed Jun 27, 2019
1 parent 268921d commit c9661b2
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
9 changes: 9 additions & 0 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,15 @@ def ToANormalForm():
"""
return _transform.ToANormalForm()

def EtaExpand():
"""Add abstraction over a function
Returns
-------
ret: tvm.relay.Pass
The registered pass that eta expands an expression.
"""
return _transform.EtaExpand()

def ToGraphNormalForm():
"""Turn A Normal Form expression into Graph Normal Form expression
Expand Down
15 changes: 15 additions & 0 deletions src/relay/pass/eta_expand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,20 @@ Expr EtaExpand(const Expr& e, const Module& mod) {

TVM_REGISTER_API("relay._ir_pass.eta_expand").set_body_typed(EtaExpand);

namespace transform {

Pass EtaExpand() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(EtaExpand(f, m));
};
return CreateFunctionPass(pass_func, 1, "EtaExpand", {});
}

TVM_REGISTER_API("relay._transform.EtaExpand")
.set_body_typed(EtaExpand);

} // namespace transform

} // namespace relay
} // namespace tvm
13 changes: 10 additions & 3 deletions tests/python/relay/test_pass_eta_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,20 @@
# specific language governing permissions and limitations
# under the License.
from tvm import relay
import tvm.relay.module as _module
import tvm.relay.transform as _transform

def test_eta_expand_basic():
mod = relay.Module()
x = relay.var('x', 'int32')
y = relay.var('y', 'int32')
orig = relay.Function([x], x)
got = relay.ir_pass.eta_expand(orig, mod)
mod = _module.Module.from_expr(orig)
seq = _transform.Sequential([_transform.EtaExpand()])
with _transform.PassContext(opt_level=3):
mod = seq(mod)

got = mod[mod.entry_func.name_hint]

y = relay.var('y', 'int32')
expected = relay.Function([y], orig(y))

got = relay.ir_pass.infer_type(got, mod)
Expand Down

0 comments on commit c9661b2

Please sign in to comment.