Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[C++][API] Consistent RAII scoping API.
Browse files Browse the repository at this point in the history
tqchen committed May 23, 2019
1 parent 29cfac6 commit 9eba481
Showing 22 changed files with 216 additions and 192 deletions.
25 changes: 15 additions & 10 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
@@ -278,14 +278,14 @@ class CanonicalSimplifier {
};

/*!
* \brief A RAII constraint context.
* \brief Constraint context.
*
* \code
*
* Var("x");
* arith::Analyzer analyzer;
* {
* arith::ConstraintContext cctx(&analyzer, x % 3 == 0);
* With<arith::ConstraintContext> scope(&analyzer, x % 3 == 0);
* CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
* }
* // constraint no longer in effect.
@@ -294,19 +294,24 @@ class CanonicalSimplifier {
* \endcode
*/
class ConstraintContext {
public:
private:
// declare friend to enable with.
friend class With<ConstraintContext>;
/*!
* \brief Construct a constraint context.
* \param analyzer The analyzer.
* \param constraint The constraint to be applied.
*/
ConstraintContext(Analyzer* analyzer, const Expr& constraint) DMLC_THROW_EXCEPTION;
/*! \brief destructor */
~ConstraintContext() DMLC_THROW_EXCEPTION {
exit_();
}

private:
ConstraintContext(Analyzer* analyzer, Expr constraint)
: analyzer_(analyzer), constraint_(constraint) {}
// enter the scope.
void EnterWithScope();
// exit the scope.
void ExitWithScope();
/*! \brief The analyzer */
Analyzer* analyzer_;
/*! \brief The constraint */
Expr constraint_;
/*! \brief function to be called in recovery */
std::function<void()> exit_;
};
44 changes: 44 additions & 0 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
@@ -90,6 +90,50 @@ using ::tvm::AttrVisitor;
};


/*!
* \brief RAII wrapper function to enter and exit a context object
* similar to python's with syntax.
*
* \code
* // context class
* class MyContext {
* private:
* friend class With<MyContext>;
MyContext(arguments);
* void EnterWithScope();
* void ExitWithScope();
* };
*
* {
* With<MyContext> scope(arguments);
* // effect take place.
* }
* \endcode
*
* \tparam ContextType Type of the context object.
*/
template<typename ContextType>
class With {
public:
/*!
* \brief constructor.
* Enter the scope of the context.
*/
template<typename ...Args>
explicit With(Args&& ...args)
: ctx_(std::forward<Args>(args)...) {
ctx_.EnterWithScope();
}
/*! \brief destructor, leaves the scope of the context. */
~With() DMLC_THROW_EXCEPTION {
ctx_.ExitWithScope();
}

private:
/*! \brief internal context type. */
ContextType ctx_;
};

/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
131 changes: 45 additions & 86 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@ namespace tvm {

/*!
* \brief Container for target device information.
* Use target::llvm, target::cuda etc functions instead of constructing directly.
* Use target::llvm, target::cuda etc functions instead of constructing directly.
*/
class TargetNode : public Node {
public:
@@ -89,65 +89,47 @@ class TargetNode : public Node {
mutable std::string str_repr_;
};

/*! \brief reference cpass to the target. */
class Target : public NodeRef {
public:
Target() {}
explicit Target(NodePtr<Node> n) : NodeRef(n) {}

/*!
* \brief Create a Target given a string
* \param target_str the string to parse
*/
TVM_DLL static Target create(const std::string& target_str);

/*!
* \brief Push a new target context onto the thread local stack. The Target on top of
* the stack is used to determine which specialization to use when invoking a GenericFunc.
* \param target The target to set as the current context.
*/
TVM_DLL static void EnterTargetScope(const tvm::Target& target);

/*!
* \brief Pop a target off the thread local context stack, restoring the previous target
* as the current context.
*/
TVM_DLL static void ExitTargetScope();

TVM_DLL static Target Create(const std::string& target_str);
/*!
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
* undefined Target will be returned. Otherwise, an empty context stack will cause a
* runtime error.
* \return The target that is the current context. The target may not be defined if
* allow_not_defined is true.
*/
TVM_DLL static tvm::Target current_target(bool allow_not_defined = true);
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
* undefined Target will be returned. Otherwise, an empty context stack will cause a
* runtime error.
* \return The target that is the current context. The target may not be defined if
* allow_not_defined is true.
*/
TVM_DLL static tvm::Target Current(bool allow_not_defined = true);

inline const TargetNode* operator->() const {
const TargetNode* operator->() const {
return static_cast<const TargetNode*>(node_.get());
}

using ContainerType = TargetNode;
};

/*!
* \brief RAII container to provide a scoped target context. Pushes a target onto the
* context stack when constructed, and pops it when destructed.
*/
struct TargetContext {
class Internal;
private:
// enable with syntax.
friend class Internal;
friend class With<Target>;
/*!
* \brief Enter a new target context. The given target becomes the new current context.
* When the TargetContext is destructed, the previous context is restored.
* \param target The target to set as the new current context.
* \brief Push a new target context onto the thread local stack.
* The Target on top of the stack is used to determine which
* specialization to use when invoking a GenericFunc.
*/
explicit TargetContext(const tvm::Target& target) {
Target::EnterTargetScope(target);
}

/*! \brief Destructor. Pops the context off the thread local stack. */
~TargetContext() {
Target::ExitTargetScope();
}
TVM_DLL void EnterWithScope();
/*!
* \brief Pop a target off the thread local context stack,
* restoring the previous target as the current context.
*/
TVM_DLL void ExitWithScope();
};

/*! \brief This namespace provides functions to construct Target instances */
@@ -190,11 +172,9 @@ TVM_DLL Target stackvm(const std::vector<std::string>& options =

} // namespace target

class BuildConfig;

/*!
* \brief Container for build configuration options
*/
* \brief Container for build configuration options
*/
class BuildConfigNode : public Node {
public:
/*!
@@ -271,69 +251,48 @@ class BuildConfigNode : public Node {
};

/*!
* \brief Container for build configuration options
*/
* \brief Build configuration for compilations.
*/
class BuildConfig : public ::tvm::NodeRef {
public:
BuildConfig() {}
explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {}

const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(node_.get());
}

BuildConfigNode* operator->() {
return static_cast<BuildConfigNode*>(node_.get());
}

/*!
* \brief Push a new BuildConfig context onto the thread local stack.
* \param build_config The configuration to set as the current context.
* \brief Construct a BuildConfig containing a empty build config node.
* \return The new BuildConfig
*/
TVM_DLL static void EnterBuildConfigScope(const tvm::BuildConfig& build_config);

/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
TVM_DLL static void ExitBuildConfigScope();

TVM_DLL static BuildConfig Create();
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
*/
TVM_DLL static tvm::BuildConfig Current();
TVM_DLL static BuildConfig Current();

using ContainerType = BuildConfigNode;
};
class Internal;

/*!
* \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
* context stack when constructed, and pops it when destructed.
*/
struct BuildConfigContext {
private:
// Enable with syntax.
friend class With<BuildConfig>;
/*!
* \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
* context. When the BuildConfigContext is destructed, the previous context is restored.
* \param build_config The BuildConfig to set as the new current context.
* \brief Push a new BuildConfig context onto the thread local stack.
*/
explicit BuildConfigContext(const tvm::BuildConfig& build_config) {
BuildConfig::EnterBuildConfigScope(build_config);
}
TVM_DLL void EnterWithScope();

/*! \brief Destructor. Pops the context off the thread local stack. */
~BuildConfigContext() {
BuildConfig::ExitBuildConfigScope();
}
/*!
* \brief Pop a build config off the thread local context stack,
* restoring the previous configuration as the current context.
*/
TVM_DLL void ExitWithScope();
};

/*!
* \brief Construct a BuildConfig containing a new BuildConfigNode
* \return The new BuildConfig
*/
TVM_DLL BuildConfig build_config();

/*!
* \brief Build a LoweredFunc given a schedule, args and binds
* \param sch The schedule to lower.
2 changes: 1 addition & 1 deletion python/tvm/build_module.py
Original file line number Diff line number Diff line change
@@ -187,7 +187,7 @@ def __enter__(self):
def __exit__(self, ptype, value, trace):
if self.dump_pass_ir:
BuildConfig._dump_ir.exit()
_api_internal._ExitBuildConfigScope()
_api_internal._ExitBuildConfigScope(self)

def __setattr__(self, name, value):
if name in BuildConfig._node_defaults:
2 changes: 1 addition & 1 deletion python/tvm/target.py
Original file line number Diff line number Diff line change
@@ -133,7 +133,7 @@ def __enter__(self):
return self

def __exit__(self, ptype, value, trace):
_api_internal._ExitTargetScope()
_api_internal._ExitTargetScope(self)


@register_node
8 changes: 4 additions & 4 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -116,8 +116,8 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
// can't use make_shared due to noexcept(false) decl in destructor,
// see https://stackoverflow.com/a/43907314
auto ctx =
std::shared_ptr<ConstraintContext>(new ConstraintContext(self.get(), args[0]));
auto ctx = std::shared_ptr<With<ConstraintContext> >(
new With<ConstraintContext>(self.get(), args[0]));
auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
ctx.reset();
};
17 changes: 12 additions & 5 deletions src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -54,17 +54,24 @@ void Analyzer::Bind(const VarExpr& v, const Range& range) {
// skip rewrite simplify
}

ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) {

void ConstraintContext::EnterWithScope() {
CHECK(exit_ == nullptr);
// entering the scope.
auto f0 = analyzer->const_int_bound.EnterConstraint(constraint);
auto f1 = analyzer->modular_set.EnterConstraint(constraint);
auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_);
auto f1 = analyzer_->modular_set.EnterConstraint(constraint_);
// recovery function.
exit_ = [f0, f1]() {
if (f1 != nullptr) f1();
if (f0 != nullptr) f0();
};
}

void ConstraintContext::ExitWithScope() {
CHECK(exit_ != nullptr);
exit_();
}

bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<ir::IntImm>()) {
return ptr->value > lower_bound;
Loading

0 comments on commit 9eba481

Please sign in to comment.