From be3596aad80cc58dcc996e2d920d32895b8e07d2 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Tue, 2 Jun 2020 23:30:54 +0530 Subject: [PATCH] [REFACTOR][IR] Migrate Pass/PassContext ObjectRef to not-null --- include/tvm/ir/transform.h | 10 +++------- src/ir/transform.cc | 6 ++---- src/relay/backend/compile_engine.cc | 4 ++-- src/relay/ir/transform.cc | 2 +- src/relay/transforms/fold_constant.cc | 2 +- src/relay/transforms/partial_eval.cc | 2 +- src/tir/ir/transform.cc | 2 +- 7 files changed, 11 insertions(+), 17 deletions(-) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 4c36c7c940419..a1f5216db0ed5 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -161,7 +161,7 @@ class PassContextNode : public Object { */ class PassContext : public ObjectRef { public: - PassContext() {} + PassContext() : PassContext(make_object()) {} explicit PassContext(ObjectPtr n) : ObjectRef(n) {} /*! * \brief const accessor. @@ -180,11 +180,6 @@ class PassContext : public ObjectRef { return static_cast(get_mutable()); } - /*! - * \brief Construct a PassContext containing the default configurations. - * \return The new PassContext. - */ - TVM_DLL static PassContext Create(); /*! * \brief Get the default pass context in the current scope. * \return The pass context. @@ -216,6 +211,7 @@ class PassContext : public ObjectRef { // accessor. using ContainerType = PassContextNode; + static constexpr bool _type_is_nullable = false; class Internal; private: @@ -388,11 +384,11 @@ class Sequential : public Pass { */ TVM_DLL Sequential(Array passes, String name = "sequential"); - Sequential() = default; explicit Sequential(ObjectPtr n) : Pass(n) {} const SequentialNode* operator->() const; using ContainerType = Sequential; + static constexpr bool _type_is_nullable = false; }; /* diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 9eb327dafe353..99955fe7cbd1c 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -138,8 +138,6 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde PassConfigManager::Global()->Register(key, value_type_index); } -PassContext PassContext::Create() { return PassContext(make_object()); } - void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { auto pass_ctx_node = this->operator->(); if (pass_ctx_node->trace_func != nullptr) { @@ -195,7 +193,7 @@ class ModulePass : public Pass { ModulePass(runtime::TypedPackedFunc pass_func, PassInfo pass_info); - TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); }; /*! @@ -456,7 +454,7 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, TraceFunc trace_func, Optional> config) { - auto pctx = PassContext::Create(); + auto pctx = PassContext(); pctx->opt_level = opt_level; pctx->required_pass = std::move(required); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 182207ad7bb59..17833e5c3d498 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -689,7 +689,7 @@ class CompileEngineImpl : public CompileEngineNode { cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func); } else { using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); + With fresh_pass_ctx_scope(PassContext()); std::unordered_map binds; cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds); @@ -726,7 +726,7 @@ class CompileEngineImpl : public CompileEngineNode { } using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); + With fresh_pass_ctx_scope(PassContext()); std::unordered_map binds; cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 184ee58009d7b..8623d959295d5 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -99,7 +99,7 @@ class FunctionPass : public Pass { runtime::TypedPackedFunc pass_func, PassInfo pass_info); - TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); }; FunctionPass::FunctionPass( diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index b41284b12b386..b058601e26a2a 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -275,7 +275,7 @@ Expr FoldConstant(const Expr& expr, const IRModule& mod) { Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. - With fresh_build_ctx(PassContext::Create()); + With fresh_build_ctx(PassContext()); return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr); } diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 9a002ef1319cb..48f331ddcc3a2 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -538,7 +538,7 @@ FInterpreter CPUInterpreter() { Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. - With fresh_build_ctx(PassContext::Create()); + With fresh_build_ctx(PassContext()); return CreateInterpreter(IRModule(nullptr), CPUContext(), target); } diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 30d5f0f507741..500136c5ecc19 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -73,7 +73,7 @@ class PrimFuncPass : public Pass { runtime::TypedPackedFunc pass_func, PassInfo pass_info); - TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode); }; PrimFuncPass::PrimFuncPass(