Skip to content

Commit

Permalink
Clean up pass.h (#3312)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and tqchen committed Jul 2, 2019
1 parent 0af5c21 commit e3d6074
Show file tree
Hide file tree
Showing 130 changed files with 1,273 additions and 1,369 deletions.
3 changes: 2 additions & 1 deletion docs/api/python/relay/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ compiler stack.
expr
frontend
image
ir_pass
analysis
transform
module
nn
op
Expand Down
117 changes: 9 additions & 108 deletions include/tvm/relay/pass.h → include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,42 +18,21 @@
*/

/*!
* \file tvm/relay/pass.h
* \brief The set of Relay passes written in C++.
*/
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_
* \file tvm/relay/analysis.h
* \brief The set of Relay analysis passes written in C++.
*/
#ifndef TVM_RELAY_ANALYSIS_H_
#define TVM_RELAY_ANALYSIS_H_

#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <string>
#include <vector>

namespace tvm {
namespace relay {

/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambigous
* type information filled in, as well as it's checked type field
* populated with the result type.
*
* \param expr The expression to type check.
* \param mod The module used for referencing global functions, can be
* None.
*
* \return A type checked expression with its checked_type field populated.
*/
TVM_DLL Expr InferType(const Expr& expr, const Module& mod);

/*!
* \brief Infer the type of a function as if it is mapped to var in the mod.
*
Expand All @@ -64,7 +43,8 @@ TVM_DLL Expr InferType(const Expr& expr, const Module& mod);
* \return A type checked Function with its checked_type field populated.
* \note this function mutates mod and is not thread-safe.
*/
TVM_DLL Function InferType(const Function& f, const Module& mod,
TVM_DLL Function InferType(const Function& f,
const Module& mod,
const GlobalVar& var);

/*!
Expand Down Expand Up @@ -271,58 +251,6 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);

/*!
* \brief Fold constant expressions.
*
* \param expr the expression to be optimized.
*
* \return The optimized expression.
*/
TVM_DLL Expr FoldConstant(const Expr& expr);

/*!
* \brief Fuse operations into expr into seperate functions.
*
* \param expr The expression.
* \param fuse_opt_level Optimization level.
* \param mod the module.
*
* \return The optimized expression.
*/
TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param expr The expression.
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param expr The expression.
* \param rewrite_func The rewrite func that will apply to all operators.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr,
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
* \brief Rewrite the annotated program.
*
Expand Down Expand Up @@ -364,19 +292,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
*/
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);

/*!
* \brief Bind the free variables to a Relay expression.
*
* Parameter binding can only happen if expr is a Function.
* binds cannot change internal arguments of internal functions.
*
* \param expr The function to be binded.
* \param binds The map of arguments to
*
* \return The expression with all free vars bound.
*/
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
Expand All @@ -388,7 +303,6 @@ struct StructuralHash {
* \return the hash value.
*/
size_t operator()(const Type& type) const;

/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
Expand All @@ -400,20 +314,7 @@ struct StructuralHash {
size_t operator()(const Expr& expr) const;
};

namespace vm {

/*!
* \brief Compile a module, and construct the virtual machine.
*
* \param mod The module to compile.
*
* \return The constructed virtual machine.
*/
runtime::vm::VirtualMachine CompileModule(const Module& mod);

} // namespace vm

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_PASS_H_
#endif // TVM_RELAY_ANALYSIS_H_
97 changes: 57 additions & 40 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,36 +378,6 @@ TVM_DLL Pass FoldConstant();
*/
TVM_DLL Pass FuseOps(int fuse_opt_level = -1);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The pass.
*/
TVM_DLL Pass ForwardRewrite(const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)>
fmulti_ref_trigger = nullptr);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param rewrite_func The rewrite func that will apply to all operators.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The pass.
*/
TVM_DLL Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
* \brief Rewrite the annotated program.
*
Expand Down Expand Up @@ -554,21 +524,68 @@ TVM_DLL Pass CanonicalizeCast();
*/
TVM_DLL Pass EtaExpand();

} // namespace transform

/*!
* \brief This is a helper function that runs a some optimization passes on
* a certain expression and returns the optimized version. With the help of this
* function, users don't need to manually construct a module, then perform
* passes, and finally and extract the target function/expression from the
* returned module frequently.
* \brief Bind the free variables to a Relay expression. This is a helper
* function usually called by other pass functions to help optimizations.
*
* \param expr The expression to be optimized.
* \param passes The passses that will be applied on the given expression.
* \param expr The input expression.
* \param binds The variable to expression map that will be used to help the
* binding.
*
* \return The optimized expression.
* \return The updated expression.
*/
TVM_DLL Expr OptimizeOnExpr(const Expr& expr, const Array<Pass>& passes);
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

/*!
* \brief Infer the type of a function as if it is mapped to var in the mod.
*
* \param f the function.
* \param mod The module used for referencing global functions.
* \param var The global variable corresponding to the function.
*
* \return A type checked Function with its checked_type field populated.
* \note this function mutates mod and is not thread-safe.
*/
TVM_DLL Function InferType(const Function& f,
const Module& mod,
const GlobalVar& var);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order. This
* function is used as a helper function to rewrtie an expression in a pass.
*
* \param expr The expression.
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr,
const std::string& rewrite_map_attr_name,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order. This
* function is used as a helper function to rewrtie an expression in a pass.
*
* \param expr The expression.
* \param rewrite_func The rewrite func that will apply to all operators.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr,
const FForwardRewrite& rewrite_func,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

} // namespace transform
} // namespace relay
} // namespace tvm

Expand Down
9 changes: 5 additions & 4 deletions nnvm/tests/python/compiler/test_to_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from nnvm import testing
from nnvm import to_relay
import tvm
from tvm.relay import ir_pass
from tvm.relay import transform
from tvm.relay import create_executor
from tvm.contrib import graph_runtime
import numpy as np
Expand All @@ -41,10 +41,11 @@ def check_model(sym, shapes, dtypes, params):
nnvm_rts.run(**inputs)
nnvm_out = nnvm_rts.get_output(0)
relay_model, params = to_relay.to_relay(net, shapes, dtypes, params)
relay_model = ir_pass.infer_type(relay_model)
relay_rts = create_executor(kind='graph', ctx=tvm.cpu(0), target='llvm')
mod = tvm.relay.Module.from_expr(relay_model)
mod = transform.InferType()(mod)
relay_rts = create_executor(kind='graph', mod=mod, ctx=tvm.cpu(0), target='llvm')
inputs.update(params)
relay_out = relay_rts.evaluate(relay_model)(*list(inputs.values()))
relay_out = relay_rts.evaluate()(*list(inputs.values()))
np.testing.assert_allclose(nnvm_out.asnumpy(), relay_out.asnumpy())

# def test_mlp():
Expand Down
13 changes: 11 additions & 2 deletions python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import topi

from tvm import relay, autotvm
from tvm.relay import transform
from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
from tvm.relay.ty import TupleType, TensorType
from tvm.autotvm.task import TaskExtractEnv
Expand Down Expand Up @@ -80,6 +81,14 @@ def expr2graph(expr, target_ops, node_dict, node_list):
task_pos += 1


def _infer_type(node):
"""A method to infer the type of a relay expression."""
mod = relay.Module.from_expr(node)
mod = transform.InferType()(mod)
entry = mod[mod.entry_func]
return entry if isinstance(node, relay.Function) else entry.body


def _expr2graph_impl(expr, target_ops, node_dict, node_list):
"""Implementation to convert relay expr to graph data structure
"""
Expand All @@ -99,7 +108,7 @@ def _traverse_expr(node):
node_entry["inputs"] += node_list[in_node_idx]["inputs"]
else:
node_entry["inputs"].append([in_node_idx, 0, 0])
infer_out = relay.ir_pass.infer_type(node)
infer_out = _infer_type(node)
out_type = infer_out._checked_type_
if isinstance(out_type, TensorType):
node_entry["types"].append(out_type)
Expand Down Expand Up @@ -168,7 +177,7 @@ def _traverse_expr(node):
node_dict[node] = node_index
node_list.append(node_entry)

relay.ir_pass.post_order_visit(expr, _traverse_expr)
relay.analysis.post_order_visit(expr, _traverse_expr)


def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_names):
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/autotvm/graph_tuner/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=eval-used,invalid-name,too-many-arguments
"""Utility functions"""
from tvm import relay
from tvm.relay import transform


def has_multiple_inputs(node_list, node_idx, input_names):
Expand Down Expand Up @@ -107,4 +108,7 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
rebind_dict[var] = updated_input_dict[var.name_hint]
updated_expr = relay.expr.bind(expr, rebind_dict)

return relay.ir_pass.infer_type(updated_expr)
mod = relay.Module.from_expr(updated_expr)
mod = transform.InferType()(mod)
entry = mod[mod.entry_func]
return entry if isinstance(updated_expr, relay.Function) else entry.body
Loading

0 comments on commit e3d6074

Please sign in to comment.