diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index ffe37074ecae..f9061bcf4efa 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -161,16 +161,8 @@ class PassContextNode : public Object { */ class PassContext : public ObjectRef { public: - PassContext() {} - explicit PassContext(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief const accessor. - * \return const access pointer. - */ - const PassContextNode* operator->() const { - CHECK(get() != nullptr); - return static_cast(get()); - } + PassContext() : PassContext(make_object()) {} + /*! * \brief mutable accessor. * \return mutable access pointer. @@ -180,11 +172,6 @@ class PassContext : public ObjectRef { return static_cast(get_mutable()); } - /*! - * \brief Construct a PassContext containing the default configurations. - * \return The new PassContext. - */ - TVM_DLL static PassContext Create(); /*! * \brief Get the default pass context in the current scope. * \return The pass context. @@ -214,8 +201,7 @@ class PassContext : public ObjectRef { return tindex; } - // accessor. - using ContainerType = PassContextNode; + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PassContext, ObjectRef, PassContextNode); class Internal; private: @@ -388,11 +374,11 @@ class Sequential : public Pass { */ TVM_DLL Sequential(Array passes, String name = "sequential"); - Sequential() = default; explicit Sequential(ObjectPtr n) : Pass(n) {} const SequentialNode* operator->() const; using ContainerType = Sequential; + static constexpr bool _type_is_nullable = false; }; /* diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 322c1ef59ac6..58ada5b3ad97 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -138,8 +138,6 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde PassConfigManager::Global()->Register(key, value_type_index); } -PassContext PassContext::Create() { return PassContext(make_object()); } - void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { auto pass_ctx_node = this->operator->(); if (pass_ctx_node->trace_func != nullptr) { @@ -195,7 +193,7 @@ class ModulePass : public Pass { ModulePass(runtime::TypedPackedFunc pass_func, PassInfo pass_info); - TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); }; /*! @@ -456,7 +454,7 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, TraceFunc trace_func, Optional> config) { - auto pctx = PassContext::Create(); + auto pctx = PassContext(); pctx->opt_level = opt_level; pctx->required_pass = std::move(required); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index be749fdd3a97..3da49390d4d2 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -689,7 +689,7 @@ class CompileEngineImpl : public CompileEngineNode { cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func); } else { using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); + With fresh_pass_ctx_scope(static_cast(PassContext())); std::unordered_map binds; cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds); @@ -726,7 +726,7 @@ class CompileEngineImpl : public CompileEngineNode { } using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); + With fresh_pass_ctx_scope(static_cast(PassContext())); std::unordered_map binds; cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 184ee58009d7..8623d959295d 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -99,7 +99,7 @@ class FunctionPass : public Pass { runtime::TypedPackedFunc pass_func, PassInfo pass_info); - TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); }; FunctionPass::FunctionPass( diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index b2eab8f96987..0300482fa343 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -275,7 +275,7 @@ Expr FoldConstant(const Expr& expr, const IRModule& mod) { Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. - With fresh_build_ctx(PassContext::Create()); + With fresh_build_ctx(static_cast(PassContext())); return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr); } diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 371142ad76a2..acd161a917bd 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -538,7 +538,7 @@ FInterpreter CPUInterpreter() { Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. - With fresh_build_ctx(PassContext::Create()); + With fresh_build_ctx(static_cast(PassContext())); return CreateInterpreter(IRModule(nullptr), CPUContext(), target); } diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 30d5f0f50774..500136c5ecc1 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -73,7 +73,7 @@ class PrimFuncPass : public Pass { runtime::TypedPackedFunc pass_func, PassInfo pass_info); - TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode); }; PrimFuncPass::PrimFuncPass( diff --git a/tests/cpp/relay_transform_sequential_test.cc b/tests/cpp/relay_transform_sequential_test.cc index bb4bf928b018..79f040f18575 100644 --- a/tests/cpp/relay_transform_sequential_test.cc +++ b/tests/cpp/relay_transform_sequential_test.cc @@ -90,7 +90,7 @@ TEST(Relay, Sequential) { relay::transform::EliminateCommonSubexpr(), relay::transform::AlterOpLayout()}; relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); auto mod = IRModule::FromExpr(func); - auto pass_ctx = relay::transform::PassContext::Create(); + auto pass_ctx = relay::transform::PassContext(); pass_ctx->opt_level = 3; pass_ctx->config.Set("relay.fallback_device_type", Integer(1)); {