Skip to content

Commit

Permalink
fix more docs
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed May 24, 2019
1 parent f148e30 commit ef9f5c6
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 12 deletions.
2 changes: 0 additions & 2 deletions docs/api/python/relay/transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ tvm.relay.transform

.. autofunction:: tvm.relay.transform.function_pass

.. autofunction:: tvm.relay.transform.current_pass_context

.. autoclass:: tvm.relay.transform.Pass
:members:

Expand Down
37 changes: 35 additions & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,19 @@ class PassContext : public NodeRef {
PassContext() {}
explicit PassContext(tvm::NodePtr<Node> n) : NodeRef(n) {}

TVM_DLL PassContext(int opt_level, int fallback_device,
/*
* \brief Constructor of a `PassContext` object.
*
* \param opt_level The optimization level that will be applied.
* \param fallback_device The fallback device used for heterogeneous
* execution.
* \param required_pass The passes that are required for a context to execute
* other passes.
* \param required_pass The passes that will be disabled during the
* optimization under a context.
*/
TVM_DLL PassContext(int opt_level,
int fallback_device,
tvm::Array<tvm::Expr> required_pass,
tvm::Array<tvm::Expr> disabled_pass);

Expand Down Expand Up @@ -191,7 +203,8 @@ class PassNode : public RelayNode {
virtual PassInfo Info() const = 0;

/*!
* \brief Execute the optimization pass using a functor.
* \brief Execute the optimization pass using a functor. This functor
* internally uses a current pass context.
*
* \param mod The module that an optimization pass runs on.
*
Expand All @@ -201,6 +214,15 @@ class PassNode : public RelayNode {
return this->operator()(mod, PassContext::Current());
}

/*!
* \brief Execute the optimization pass using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that will be used to help the execution of
* optimizations.
*
* \return The updated module.
*/
virtual Module operator()(const Module& mod,
const PassContext& pass_ctx) const = 0;

Expand Down Expand Up @@ -228,11 +250,22 @@ class Sequential : public Pass {
public:
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param pass_info The pass metadata.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes,
PassInfo pass_info);
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param name The name of a sequential pass. It's defaulted to "sequential".
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");

Sequential() = default;
explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def build(self, func, target=None, target_host=None, params=None):
return graph_json, mod, params

def _setup_build_config(self, params):
cfg = _transform.current_pass_context()
cfg = _transform.PassContext.current()

# Set opt_level.
self.set_opt_level(cfg.opt_level)
Expand Down
14 changes: 7 additions & 7 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ class PassContext(RelayNode):
opt_level : Optional[int]
The optimization level of this pass.
fallback_device : Optional[int]
fallback_device : Optional[Union[int, str, TVMContext]]
The fallback device type. It is also used as the default device for
operators that are not annotated during heterogeneous execution.
required_pass : Optional[List[str]]
required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are required by a certain pass.
disabled_pass : Optional[List[str]]
disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are disabled.
"""
def __init__(self,
Expand Down Expand Up @@ -107,10 +107,10 @@ def __enter__(self):
def __exit__(self, ptype, value, trace):
_transform.ExitPassContext(self)


def current_pass_context():
"""Return the current pass context."""
return _transform.GetCurrentPassContext()
@staticmethod
def current():
"""Return the current pass context."""
return _transform.GetCurrentPassContext()


def build_config(opt_level=2,
Expand Down
8 changes: 8 additions & 0 deletions src/relay/pass/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,14 @@ Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
node_ = std::move(n);
}

Sequential::Sequential(tvm::Array<Pass> passes, std::string name) {
auto n = make_node<SequentialNode>();
n->passes = std::move(passes);
PassInfo pass_info = PassInfoNode::make(2, std::move(name), {});
n->pass_info = std::move(pass_info);
node_ = std::move(n);
}

const SequentialNode* Sequential::operator->() const {
return static_cast<const SequentialNode*>(this->node_.get());
}
Expand Down

0 comments on commit ef9f5c6

Please sign in to comment.