Skip to content

Commit

Permalink
[REFACTOR][IR] Migrate Pass/PassContext ObjectRef to not-null
Browse files Browse the repository at this point in the history
  • Loading branch information
ANSHUMAN TRIPATHY authored and ANSHUMAN TRIPATHY committed Jun 3, 2020
1 parent 4347b41 commit 2891365
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 17 deletions.
10 changes: 3 additions & 7 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class PassContextNode : public Object {
*/
class PassContext : public ObjectRef {
public:
PassContext() {}
PassContext() : PassContext(make_object<PassContextNode>()) {}
explicit PassContext(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief const accessor.
Expand All @@ -180,11 +180,6 @@ class PassContext : public ObjectRef {
return static_cast<PassContextNode*>(get_mutable());
}

/*!
* \brief Construct a PassContext containing the default configurations.
* \return The new PassContext.
*/
TVM_DLL static PassContext Create();
/*!
* \brief Get the default pass context in the current scope.
* \return The pass context.
Expand Down Expand Up @@ -216,6 +211,7 @@ class PassContext : public ObjectRef {

// accessor.
using ContainerType = PassContextNode;
static constexpr bool _type_is_nullable = false;
class Internal;

private:
Expand Down Expand Up @@ -388,11 +384,11 @@ class Sequential : public Pass {
*/
TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");

Sequential() = default;
explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}

const SequentialNode* operator->() const;
using ContainerType = Sequential;
static constexpr bool _type_is_nullable = false;
};

/*
Expand Down
6 changes: 2 additions & 4 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde
PassConfigManager::Global()->Register(key, value_type_index);
}

PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); }

void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const {
auto pass_ctx_node = this->operator->();
if (pass_ctx_node->trace_func != nullptr) {
Expand Down Expand Up @@ -195,7 +193,7 @@ class ModulePass : public Pass {
ModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info);

TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode);
};

/*!
Expand Down Expand Up @@ -456,7 +454,7 @@ TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_GLOBAL("transform.PassContext")
.set_body_typed([](int opt_level, Array<String> required, Array<String> disabled,
TraceFunc trace_func, Optional<Map<String, ObjectRef>> config) {
auto pctx = PassContext::Create();
auto pctx = PassContext();
pctx->opt_level = opt_level;

pctx->required_pass = std::move(required);
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ class CompileEngineImpl : public CompileEngineNode {
cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func);
} else {
using tvm::transform::PassContext;
With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
With<PassContext> fresh_pass_ctx_scope(PassContext());

std::unordered_map<te::Tensor, tir::Buffer> binds;
cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds);
Expand Down Expand Up @@ -726,7 +726,7 @@ class CompileEngineImpl : public CompileEngineNode {
}

using tvm::transform::PassContext;
With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
With<PassContext> fresh_pass_ctx_scope(PassContext());

std::unordered_map<te::Tensor, tir::Buffer> binds;
cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class FunctionPass : public Pass {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info);

TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode);
};

FunctionPass::FunctionPass(
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ Expr FoldConstant(const Expr& expr, const IRModule& mod) {
Target target = Target::Create("llvm");
// use a fresh build context
// in case we are already in a build context.
With<PassContext> fresh_build_ctx(PassContext::Create());
With<PassContext> fresh_build_ctx(PassContext());

return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ FInterpreter CPUInterpreter() {
Target target = Target::Create("llvm");
// use a fresh build context
// in case we are already in a build context.
With<PassContext> fresh_build_ctx(PassContext::Create());
With<PassContext> fresh_build_ctx(PassContext());

return CreateInterpreter(IRModule(nullptr), CPUContext(), target);
}
Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class PrimFuncPass : public Pass {
runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
PassInfo pass_info);

TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode);
};

PrimFuncPass::PrimFuncPass(
Expand Down

0 comments on commit 2891365

Please sign in to comment.