diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index ba9a62aa0596..9317f1ed3250 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -285,7 +285,7 @@ class IRModule : public ObjectRef { Map type_definitions = {}, std::unordered_set import_set = {}); /*! \brief default constructor */ - IRModule() {} + IRModule() : IRModule(Map()) {} /*! * \brief constructor * \param n The object pointer. @@ -298,12 +298,6 @@ class IRModule : public ObjectRef { return static_cast(ptr); } - /*! - * \brief Construct an empty module. - * - * \returns The constructed module - */ - static IRModule Empty() { return IRModule(Map()); } /*! * \brief Construct a module from a standalone expression. * @@ -330,6 +324,10 @@ class IRModule : public ObjectRef { /*! \brief Declare the container type. */ using ContainerType = IRModuleNode; + + /*! \brief Declare whether Ref is nullable. */ + static constexpr bool _type_is_nullable = false; + // allow copy on write. TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode); }; diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index c58c67937cf4..1d17c421b5a6 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -656,7 +656,8 @@ def to_cps(func, mod=None): result: tvm.relay.Function The output function. """ - return _ffi_api.to_cps(func, mod) + use_mod = mod if mod is not None else tvm.ir.IRModule() + return _ffi_api.to_cps(func, use_mod) def un_cps(func): diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc index 9e94459d7018..6be956c70d34 100644 --- a/src/relay/analysis/feature.cc +++ b/src/relay/analysis/feature.cc @@ -96,8 +96,11 @@ FeatureSet DetectFeature(const IRModule& mod) { return fs; } -Array PyDetectFeature(const Expr& expr, const IRModule& mod) { - FeatureSet fs = DetectFeature(expr) + DetectFeature(mod); +Array PyDetectFeature(const Expr& expr, const Optional& mod) { + FeatureSet fs = DetectFeature(expr); + if (mod.defined()) { + fs = fs + DetectFeature(mod.value()); + } return static_cast>(fs); } diff --git a/src/relay/analysis/match_exhaustion.cc b/src/relay/analysis/match_exhaustion.cc index 96dab6b60495..e852c40dfeba 100644 --- a/src/relay/analysis/match_exhaustion.cc +++ b/src/relay/analysis/match_exhaustion.cc @@ -305,11 +305,8 @@ Array UnmatchedCases(const Match& match, const IRModule& mod) { // expose for testing only TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases") - .set_body_typed([](const Match& match, const IRModule& mod_ref) { - IRModule call_mod = mod_ref; - if (!call_mod.defined()) { - call_mod = IRModule({}, {}); - } + .set_body_typed([](const Match& match, const Optional& mod_ref) { + IRModule call_mod = mod_ref.defined() ? mod_ref.value() : IRModule({}, {}); return UnmatchedCases(match, call_mod); }); diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 9abe80c363c5..a5f3f6359f89 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -82,7 +82,7 @@ struct CachedFuncNode : public Object { /*! \brief The schedule to the function */ te::Schedule schedule; /*! \brief The lowered functions to support the function. */ - IRModule funcs = IRModule::Empty(); + IRModule funcs = IRModule(); /*! \brief Parameter usage states in the shape function. */ tvm::Array shape_func_param_states; diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index c8ec1bf1e767..85d439b85b3c 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -207,7 +207,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorUpdate(kv.second); @@ -395,7 +395,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorstr())) { - lowered_funcs_[target->str()] = IRModule::Empty(); + lowered_funcs_[target->str()] = IRModule(); } lowered_funcs_[target->str()]->Update(lowered_func->funcs); return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name); diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index afe556818371..7801c03910d0 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -68,7 +68,7 @@ Type WithGradientType(const Type&); /*! return an expression that represent differentiation of e (according to WithGradientType). * This version only work on first order code without control flow. */ -Expr FirstOrderGradient(const Expr& e, const IRModule& mod); +Expr FirstOrderGradient(const Expr& e, const Optional& mod); Type WithGradientType(const Type& t) { // TODO(M.K.): stricter checking @@ -78,9 +78,11 @@ Type WithGradientType(const Type& t) { } //! \brief if the expression is a GlobalVar, transform to it's expression. -Expr DeGlobal(const IRModule& mod, const Expr& e) { - if (const auto* x = e.as()) { - BaseFunc base_func = mod->Lookup(GetRef(x)); +Expr DeGlobal(const Optional& mod, const Expr& e) { + const auto* x = e.as(); + + if (mod.defined() && (x)) { + BaseFunc base_func = mod.value()->Lookup(GetRef(x)); if (auto* n = base_func.as()) { return n->body; } else { @@ -214,7 +216,7 @@ Type GradRetType(const Function& f) { return TupleType({f->ret_type, TupleType(vt)}); } -Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { +Expr FirstOrderGradient(const Expr& re, const Optional& mod) { // Currently we first remove any global functions for the first // order case. auto e = DeGlobal(mod, re); @@ -482,7 +484,7 @@ bool MissingGrad(const Expr& e) { return false; } -Expr Gradient(const Expr& re, const IRModule& mod) { +Expr Gradient(const Expr& re, const Optional& mod) { auto e = DeGlobal(mod, re); auto f = e.as(); CHECK(f) << "input need to be a function"; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 9bdb0e235478..98577a79a37f 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -275,7 +275,7 @@ Pass SplitHostDevice() { auto pass_func = [](IRModule mod, PassContext ctx) { IRModuleNode* mod_ptr = mod.CopyOnWrite(); auto* func_dict = mod_ptr->functions.CopyOnWrite(); - IRModule device_mod = IRModule::Empty(); + IRModule device_mod = IRModule(); for (auto& kv : func_dict->data) { if (kv.second->IsInstance()) {