Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][TRANSFORM] Migrate buildmodule to transform #3251

Merged
merged 12 commits into from
Jun 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ class ModuleNode : public RelayNode {
* \param update Controls whether you can replace a definition in the
* environment.
*/
void Add(const GlobalVar& var, const Function& func, bool update = false);
TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);

/*!
* \brief Add a type-level definition to the global environment.
* \param var The var of the global type definition.
* \param type The type definition.
*/
void AddDef(const GlobalTypeVar& var, const TypeData& type);
TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type);

/*!
* \brief Add a function to the global environment.
Expand All @@ -103,69 +103,69 @@ class ModuleNode : public RelayNode {
*
* It does not do type inference as Add does.
*/
void AddUnchecked(const GlobalVar& var, const Function& func);
TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);

/*!
* \brief Update a function in the global environment.
* \param var The name of the global function to update.
* \param func The new function.
*/
void Update(const GlobalVar& var, const Function& func);
TVM_DLL void Update(const GlobalVar& var, const Function& func);

/*!
* \brief Remove a function from the global environment.
* \param var The name of the global function to update.
*/
void Remove(const GlobalVar& var);
TVM_DLL void Remove(const GlobalVar& var);

/*!
* \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalVar GetGlobalVar(const std::string& str);
TVM_DLL GlobalVar GetGlobalVar(const std::string& str);

/*!
* \brief Look up a global function by its name.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalTypeVar GetGlobalTypeVar(const std::string& str);
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str);

/*!
* \brief Lookup a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
Function Lookup(const GlobalVar& var);
TVM_DLL Function Lookup(const GlobalVar& var);

/*!
* \brief Lookup a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
Function Lookup(const std::string& name);
TVM_DLL Function Lookup(const std::string& name);

/*!
* \brief Lookup a global type definition by its variable.
* \param var The var of the global type definition.
* \return The type definition.
*/
TypeData LookupDef(const GlobalTypeVar& var);
TVM_DLL TypeData LookupDef(const GlobalTypeVar& var);

/*!
* \brief Lookup a global type definition by its name.
* \param var The name of the global type definition.
* \return The type definition.
*/
TypeData LookupDef(const std::string& var);
TVM_DLL TypeData LookupDef(const std::string& var);

/*!
* \brief Update the functions inside this environment by
* functions in another environment.
* \param other The other environment.
*/
void Update(const Module& other);
TVM_DLL void Update(const Module& other);

/*! \brief Construct a module from a standalone expression.
*
Expand All @@ -177,7 +177,7 @@ class ModuleNode : public RelayNode {
*
* \returns A module with expr set as the entry point.
*/
static Module FromExpr(
TVM_DLL static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {});

Expand Down
20 changes: 20 additions & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,15 @@ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
*/
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

/*!
* \brief Collect the device anntation operators.
*
* \param expr The expression.
*
* \return The annotated expression to device type mapping for annotation ops.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);

/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
Expand Down Expand Up @@ -403,6 +412,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e);
*/
TVM_DLL Expr PartialEval(const Expr& e);

/*!
* \brief Bind the free variables to a Relay expression.
*
* \param expr The expression.
* \param bind_map The variable to expression map that will be used to help the
* binding.
*
* \return The updated expression.
*/
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& bind_map);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
Expand Down
90 changes: 85 additions & 5 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@

#include <tvm/base.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -292,9 +294,9 @@ class Sequential : public Pass {
* \param passes The passes to apply.
* \param pass_info The pass metadata.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes,
PassInfo pass_info);
/*!
TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info);

/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
Expand All @@ -311,7 +313,6 @@ class Sequential : public Pass {
using ContainerType = Sequential;
};


/*
* \brief Create a module pass.
*
Expand Down Expand Up @@ -339,7 +340,7 @@ Pass CreateModulePass(
* \return The created function pass.
*/
TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
Function(Function, Module, PassContext)>& pass_func,
Function(Function, Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
Expand Down Expand Up @@ -451,6 +452,85 @@ TVM_DLL Pass ToGraphNormalForm();
*/
TVM_DLL Pass PartialEval();

/*!
* \brief Simplify certain operators during inference. For example, batch norm
* will be unpacked into a number of simplified operators.
*
* \return The Pass.
*/
TVM_DLL Pass SimplifyInference();

/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambigous
* type information filled in, as well as it's checked type field
* populated with the result type.
*
* \return The pass.
*/
TVM_DLL Pass InferType();

/*!
* \brief Search and eliminate common subexpression. For example, if there are
* two expressions evaluated to an identical value, a single variable is created
* and these two expressions are replaced by this variable.
*
* \param fskip The callback argument that allows to skip certain expressions.
*
* \return The pass.
*/
TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);

/*!
* \brief Combine parallel 2d convolutions into a single convolution if the
* number of branches of this conv2d operator is not less than
* `min_num_branch`.
*
* \param min_num_branches The minimun number of branches.
*
* \return The pass.
*/
TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);

/*!
* \brief Backward fold axis scaling into weights of conv/dense operators.
*
* \return The pass.
*/
TVM_DLL Pass BackwardFoldScaleAxis();

/*!
* \brief Forward fold axis scaling into weights of conv/dense operators.
*
* \return The pass.
*/
TVM_DLL Pass ForwardFoldScaleAxis();

/*!
* \brief A sequential pass that executes ForwardFoldScaleAxis and
* BackwardFoldScaleAxis passes.
*
* \return The pass.
*/
TVM_DLL Pass FoldScaleAxis();

/*!
* \brief Canonicalize some operators to the simplified operators. For example,
* bias_add can be canonicalized to expand_dims and broadcast_add.
*
* \return The pass.
*/
TVM_DLL Pass CanonicalizeOps();

/*!
* \brief Alternate the layouts of operators or replace primitive operators
* with other expressions.
*
* \return The pass.
*/
TVM_DLL Pass AlterOpLayout();

} // namespace transform
} // namespace relay
} // namespace tvm
Expand Down
Loading