diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 44cbb09e6889..3c53eb323c79 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -248,6 +248,30 @@ def __init__(self, passes, opt_level, name, required) +def infer_type(expr, mod=None): + """Infer the type of an expr. + Adding Function into a Module will change it's binding, + and some passes need type inference to work without binding modification. + However, InferType() work by putting stuff into a Module, thus changing all the binding. + + This is an escape patch that allow type inference without binding changing. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + mod : Optional[tvm.relay.Module] + The input module + + Returns + ------- + ret : tvm.relay.Expr + The output expression. + """ + return _transform.infer_type(expr, mod) + + def InferType(): """Infer the type of an expr. diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 02f6cc3d4857..038437c2291a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -824,6 +824,9 @@ Function InferType(const Function& func, return Downcast(func_ret); } +TVM_REGISTER_API("relay._transform.infer_type") +.set_body_typed([](Expr l, Module r) { return InferType(l, r); }); + namespace transform { Pass InferType() {