Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][IR] Migrate Pass/PassContext ObjectRef to not-null #5717

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,8 @@ class PassContextNode : public Object {
*/
class PassContext : public ObjectRef {
public:
PassContext() {}
explicit PassContext(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief const accessor.
* \return const access pointer.
*/
const PassContextNode* operator->() const {
CHECK(get() != nullptr);
return static_cast<const PassContextNode*>(get());
}
PassContext() : PassContext(make_object<PassContextNode>()) {}
ANSHUMAN87 marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief mutable accessor.
* \return mutable access pointer.
Expand All @@ -180,11 +172,6 @@ class PassContext : public ObjectRef {
return static_cast<PassContextNode*>(get_mutable());
}

/*!
* \brief Construct a PassContext containing the default configurations.
* \return The new PassContext.
*/
TVM_DLL static PassContext Create();
/*!
* \brief Get the default pass context in the current scope.
* \return The pass context.
Expand Down Expand Up @@ -214,8 +201,7 @@ class PassContext : public ObjectRef {
return tindex;
}

// accessor.
using ContainerType = PassContextNode;
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PassContext, ObjectRef, PassContextNode);
class Internal;

private:
Expand Down Expand Up @@ -388,11 +374,11 @@ class Sequential : public Pass {
*/
TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");

Sequential() = default;
explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}

const SequentialNode* operator->() const;
using ContainerType = Sequential;
static constexpr bool _type_is_nullable = false;
};

/*
Expand Down
6 changes: 2 additions & 4 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde
PassConfigManager::Global()->Register(key, value_type_index);
}

PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); }

void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const {
auto pass_ctx_node = this->operator->();
if (pass_ctx_node->trace_func != nullptr) {
Expand Down Expand Up @@ -195,7 +193,7 @@ class ModulePass : public Pass {
ModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info);

TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode);
};

/*!
Expand Down Expand Up @@ -456,7 +454,7 @@ TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_GLOBAL("transform.PassContext")
.set_body_typed([](int opt_level, Array<String> required, Array<String> disabled,
TraceFunc trace_func, Optional<Map<String, ObjectRef>> config) {
auto pctx = PassContext::Create();
auto pctx = PassContext();
pctx->opt_level = opt_level;

pctx->required_pass = std::move(required);
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ class CompileEngineImpl : public CompileEngineNode {
cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func);
} else {
using tvm::transform::PassContext;
With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
With<PassContext> fresh_pass_ctx_scope(static_cast<PassContext>(PassContext()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think static cast would be necessary here, With fresh_pass_ctx_scope(PassContext()); should suffice once we defined the copy constructor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen : Thanks! I tried with copy constructor using, but it did not work out.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be interesting to see what error message you got. As copy constructor is clearly defined in there.
We can also use With<PassContext> fresh_pass_ctx_scope();, which invokes the default constructor(in this case the new pass context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen : Sorry for late reply! If i dont use static_cast the testcase fails with additional assert count.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen : I have changed it to default constructor now, please check. Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen : I have changed it to default constructor now, please check. Thanks!

@tqchen : Sorry! This change does not work. I made a mistake in code build, cause of that it got passed last time! So i have to use static_cast for the test case to pass.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen : Any thoughts on this comment?


std::unordered_map<te::Tensor, tir::Buffer> binds;
cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds);
Expand Down Expand Up @@ -726,7 +726,7 @@ class CompileEngineImpl : public CompileEngineNode {
}

using tvm::transform::PassContext;
With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
With<PassContext> fresh_pass_ctx_scope(static_cast<PassContext>(PassContext()));

std::unordered_map<te::Tensor, tir::Buffer> binds;
cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class FunctionPass : public Pass {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info);

TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode);
};

FunctionPass::FunctionPass(
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ Expr FoldConstant(const Expr& expr, const IRModule& mod) {
Target target = Target::Create("llvm");
// use a fresh build context
// in case we are already in a build context.
With<PassContext> fresh_build_ctx(PassContext::Create());
With<PassContext> fresh_build_ctx(static_cast<PassContext>(PassContext()));

return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ FInterpreter CPUInterpreter() {
Target target = Target::Create("llvm");
// use a fresh build context
// in case we are already in a build context.
With<PassContext> fresh_build_ctx(PassContext::Create());
With<PassContext> fresh_build_ctx(static_cast<PassContext>(PassContext()));

return CreateInterpreter(IRModule(nullptr), CPUContext(), target);
}
Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class PrimFuncPass : public Pass {
runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
PassInfo pass_info);

TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode);
};

PrimFuncPass::PrimFuncPass(
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/relay_transform_sequential_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ TEST(Relay, Sequential) {
relay::transform::EliminateCommonSubexpr(), relay::transform::AlterOpLayout()};
relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
auto mod = IRModule::FromExpr(func);
auto pass_ctx = relay::transform::PassContext::Create();
auto pass_ctx = relay::transform::PassContext();
pass_ctx->opt_level = 3;
pass_ctx->config.Set("relay.fallback_device_type", Integer(1));
{
Expand Down