Skip to content

Commit

Permalink
[RELAY][TRANSFORM] Migrate buildmodule to transform (#3251)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and tqchen committed Jun 3, 2019
1 parent 0faf731 commit bb48a45
Show file tree
Hide file tree
Showing 24 changed files with 879 additions and 455 deletions.
26 changes: 13 additions & 13 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ class ModuleNode : public RelayNode {
* \param update Controls whether you can replace a definition in the
* environment.
*/
void Add(const GlobalVar& var, const Function& func, bool update = false);
TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);

/*!
* \brief Add a type-level definition to the global environment.
* \param var The var of the global type definition.
* \param type The type definition.
*/
void AddDef(const GlobalTypeVar& var, const TypeData& type);
TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type);

/*!
* \brief Add a function to the global environment.
Expand All @@ -103,69 +103,69 @@ class ModuleNode : public RelayNode {
*
* It does not do type inference as Add does.
*/
void AddUnchecked(const GlobalVar& var, const Function& func);
TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);

/*!
* \brief Update a function in the global environment.
* \param var The name of the global function to update.
* \param func The new function.
*/
void Update(const GlobalVar& var, const Function& func);
TVM_DLL void Update(const GlobalVar& var, const Function& func);

/*!
* \brief Remove a function from the global environment.
* \param var The name of the global function to update.
*/
void Remove(const GlobalVar& var);
TVM_DLL void Remove(const GlobalVar& var);

/*!
* \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalVar GetGlobalVar(const std::string& str);
TVM_DLL GlobalVar GetGlobalVar(const std::string& str);

/*!
* \brief Look up a global function by its name.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalTypeVar GetGlobalTypeVar(const std::string& str);
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str);

/*!
* \brief Lookup a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
Function Lookup(const GlobalVar& var);
TVM_DLL Function Lookup(const GlobalVar& var);

/*!
* \brief Lookup a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
Function Lookup(const std::string& name);
TVM_DLL Function Lookup(const std::string& name);

/*!
* \brief Lookup a global type definition by its variable.
* \param var The var of the global type definition.
* \return The type definition.
*/
TypeData LookupDef(const GlobalTypeVar& var);
TVM_DLL TypeData LookupDef(const GlobalTypeVar& var);

/*!
* \brief Lookup a global type definition by its name.
* \param var The name of the global type definition.
* \return The type definition.
*/
TypeData LookupDef(const std::string& var);
TVM_DLL TypeData LookupDef(const std::string& var);

/*!
* \brief Update the functions inside this environment by
* functions in another environment.
* \param other The other environment.
*/
void Update(const Module& other);
TVM_DLL void Update(const Module& other);

/*! \brief Construct a module from a standalone expression.
*
Expand All @@ -177,7 +177,7 @@ class ModuleNode : public RelayNode {
*
* \returns A module with expr set as the entry point.
*/
static Module FromExpr(
TVM_DLL static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {});

Expand Down
20 changes: 20 additions & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,15 @@ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
*/
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

/*!
* \brief Collect the device anntation operators.
*
* \param expr The expression.
*
* \return The annotated expression to device type mapping for annotation ops.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);

/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
Expand Down Expand Up @@ -403,6 +412,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e);
*/
TVM_DLL Expr PartialEval(const Expr& e);

/*!
* \brief Bind the free variables to a Relay expression.
*
* \param expr The expression.
* \param bind_map The variable to expression map that will be used to help the
* binding.
*
* \return The updated expression.
*/
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& bind_map);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
Expand Down
90 changes: 85 additions & 5 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@

#include <tvm/base.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -292,9 +294,9 @@ class Sequential : public Pass {
* \param passes The passes to apply.
* \param pass_info The pass metadata.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes,
PassInfo pass_info);
/*!
TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info);

/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
Expand All @@ -311,7 +313,6 @@ class Sequential : public Pass {
using ContainerType = Sequential;
};


/*
* \brief Create a module pass.
*
Expand Down Expand Up @@ -339,7 +340,7 @@ Pass CreateModulePass(
* \return The created function pass.
*/
TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
Function(Function, Module, PassContext)>& pass_func,
Function(Function, Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
Expand Down Expand Up @@ -451,6 +452,85 @@ TVM_DLL Pass ToGraphNormalForm();
*/
TVM_DLL Pass PartialEval();

/*!
* \brief Simplify certain operators during inference. For example, batch norm
* will be unpacked into a number of simplified operators.
*
* \return The Pass.
*/
TVM_DLL Pass SimplifyInference();

/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambigous
* type information filled in, as well as it's checked type field
* populated with the result type.
*
* \return The pass.
*/
TVM_DLL Pass InferType();

/*!
* \brief Search and eliminate common subexpression. For example, if there are
* two expressions evaluated to an identical value, a single variable is created
* and these two expressions are replaced by this variable.
*
* \param fskip The callback argument that allows to skip certain expressions.
*
* \return The pass.
*/
TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);

/*!
* \brief Combine parallel 2d convolutions into a single convolution if the
* number of branches of this conv2d operator is not less than
* `min_num_branch`.
*
* \param min_num_branches The minimun number of branches.
*
* \return The pass.
*/
TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);

/*!
* \brief Backward fold axis scaling into weights of conv/dense operators.
*
* \return The pass.
*/
TVM_DLL Pass BackwardFoldScaleAxis();

/*!
* \brief Forward fold axis scaling into weights of conv/dense operators.
*
* \return The pass.
*/
TVM_DLL Pass ForwardFoldScaleAxis();

/*!
* \brief A sequential pass that executes ForwardFoldScaleAxis and
* BackwardFoldScaleAxis passes.
*
* \return The pass.
*/
TVM_DLL Pass FoldScaleAxis();

/*!
* \brief Canonicalize some operators to the simplified operators. For example,
* bias_add can be canonicalized to expand_dims and broadcast_add.
*
* \return The pass.
*/
TVM_DLL Pass CanonicalizeOps();

/*!
* \brief Alternate the layouts of operators or replace primitive operators
* with other expressions.
*
* \return The pass.
*/
TVM_DLL Pass AlterOpLayout();

} // namespace transform
} // namespace relay
} // namespace tvm
Expand Down
Loading

0 comments on commit bb48a45

Please sign in to comment.