From bb48a45bcfc7d8a40dadca0ab7f589f59fdec374 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Mon, 3 Jun 2019 10:40:38 -0700 Subject: [PATCH] [RELAY][TRANSFORM] Migrate buildmodule to transform (#3251) --- include/tvm/relay/module.h | 26 +- include/tvm/relay/pass.h | 20 ++ include/tvm/relay/transform.h | 90 ++++- python/tvm/relay/build_module.py | 94 +----- python/tvm/relay/transform.py | 199 +++++++++++ src/relay/backend/build_module.cc | 370 +++++++-------------- src/relay/pass/alter_op_layout.cc | 27 +- src/relay/pass/canonicalize_ops.cc | 17 + src/relay/pass/combine_parallel_conv2d.cc | 17 + src/relay/pass/dead_code.cc | 5 +- src/relay/pass/device_annotation.cc | 8 +- src/relay/pass/eliminate_common_subexpr.cc | 17 + src/relay/pass/fold_constant.cc | 8 +- src/relay/pass/fold_scale_axis.cc | 42 ++- src/relay/pass/forward_rewrite.cc | 4 +- src/relay/pass/fuse_ops.cc | 7 +- src/relay/pass/partial_eval.cc | 9 +- src/relay/pass/pass_manager.cc | 166 +++++---- src/relay/pass/simplify_inference.cc | 17 + src/relay/pass/to_a_normal_form.cc | 5 +- src/relay/pass/to_graph_normal_form.cc | 5 +- src/relay/pass/type_infer.cc | 19 ++ tests/cpp/relay_transform_sequential.cc | 111 +++++++ tests/python/relay/test_pass_manager.py | 51 ++- 24 files changed, 879 insertions(+), 455 deletions(-) create mode 100644 tests/cpp/relay_transform_sequential.cc diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 6441fb3f5b9c..3966a6258a20 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -87,14 +87,14 @@ class ModuleNode : public RelayNode { * \param update Controls whether you can replace a definition in the * environment. */ - void Add(const GlobalVar& var, const Function& func, bool update = false); + TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false); /*! * \brief Add a type-level definition to the global environment. * \param var The var of the global type definition. * \param type The type definition. */ - void AddDef(const GlobalTypeVar& var, const TypeData& type); + TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type); /*! * \brief Add a function to the global environment. @@ -103,69 +103,69 @@ class ModuleNode : public RelayNode { * * It does not do type inference as Add does. */ - void AddUnchecked(const GlobalVar& var, const Function& func); + TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func); /*! * \brief Update a function in the global environment. * \param var The name of the global function to update. * \param func The new function. */ - void Update(const GlobalVar& var, const Function& func); + TVM_DLL void Update(const GlobalVar& var, const Function& func); /*! * \brief Remove a function from the global environment. * \param var The name of the global function to update. */ - void Remove(const GlobalVar& var); + TVM_DLL void Remove(const GlobalVar& var); /*! * \brief Lookup a global function by its variable. * \param str The unique string specifying the global variable. * \returns The global variable. */ - GlobalVar GetGlobalVar(const std::string& str); + TVM_DLL GlobalVar GetGlobalVar(const std::string& str); /*! * \brief Look up a global function by its name. * \param str The unique string specifying the global variable. * \returns The global variable. */ - GlobalTypeVar GetGlobalTypeVar(const std::string& str); + TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str); /*! * \brief Lookup a global function by its variable. 
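The module.h hunk above only adds TVM_DLL to existing ModuleNode methods so they are exported for callers outside the core library, such as the new C++ test added later in this patch. For orientation, a minimal Python sketch of the same Add/GetGlobalVar/Lookup workflow; the function name "double_fn" and the shape are illustrative, not part of the patch.

from tvm import relay

# Build a trivial function and register it under a global variable
# (ModuleNode::Add runs type inference on the way in).
x = relay.var("x", shape=(1, 3), dtype="float32")
double_fn = relay.Function([x], relay.add(x, x))
gv = relay.GlobalVar("double_fn")
mod = relay.Module({gv: double_fn})

# ModuleNode::GetGlobalVar / ModuleNode::Lookup, by GlobalVar or by name.
print(mod[gv])
print(mod["double_fn"])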
* \param var The global var to lookup. * \returns The function named by the variable argument. */ - Function Lookup(const GlobalVar& var); + TVM_DLL Function Lookup(const GlobalVar& var); /*! * \brief Lookup a global function by its string name * \param name The name of the function. * \returns The function named by the argument. */ - Function Lookup(const std::string& name); + TVM_DLL Function Lookup(const std::string& name); /*! * \brief Lookup a global type definition by its variable. * \param var The var of the global type definition. * \return The type definition. */ - TypeData LookupDef(const GlobalTypeVar& var); + TVM_DLL TypeData LookupDef(const GlobalTypeVar& var); /*! * \brief Lookup a global type definition by its name. * \param var The name of the global type definition. * \return The type definition. */ - TypeData LookupDef(const std::string& var); + TVM_DLL TypeData LookupDef(const std::string& var); /*! * \brief Update the functions inside this environment by * functions in another environment. * \param other The other environment. */ - void Update(const Module& other); + TVM_DLL void Update(const Module& other); /*! \brief Construct a module from a standalone expression. * @@ -177,7 +177,7 @@ class ModuleNode : public RelayNode { * * \returns A module with expr set as the entry point. */ - static Module FromExpr( + TVM_DLL static Module FromExpr( const Expr& expr, const tvm::Map& global_funcs = {}); diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 67cc5df82407..81587339f2ad 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -358,6 +358,15 @@ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); */ TVM_DLL Map CollectDeviceInfo(const Expr& expr); +/*! + * \brief Collect the device anntation operators. + * + * \param expr The expression. + * + * \return The annotated expression to device type mapping for annotation ops. + */ +TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); + /*! * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). * @@ -403,6 +412,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e); */ TVM_DLL Expr PartialEval(const Expr& e); +/*! + * \brief Bind the free variables to a Relay expression. + * + * \param expr The expression. + * \param bind_map The variable to expression map that will be used to help the + * binding. + * + * \return The updated expression. + */ +TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& bind_map); + /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { /*! \brief Hash a Relay type. diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 1c1b60813b78..793bc981ea61 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -58,9 +58,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -292,9 +294,9 @@ class Sequential : public Pass { * \param passes The passes to apply. * \param pass_info The pass metadata. */ - TVM_DLL Sequential(tvm::Array passes, - PassInfo pass_info); -/*! + TVM_DLL Sequential(tvm::Array passes, PassInfo pass_info); + + /*! * \brief The constructor of `Sequential`. * * \param passes The passes to apply. @@ -311,7 +313,6 @@ class Sequential : public Pass { using ContainerType = Sequential; }; - /* * \brief Create a module pass. * @@ -339,7 +340,7 @@ Pass CreateModulePass( * \return The created function pass. 
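CreateModulePass and CreateFunctionPass are the C++ entry points behind the module_pass/function_pass decorators already exposed in python/tvm/relay/transform.py. A hedged sketch of defining custom passes that way, assuming the decorated callables of this version take (mod, ctx) and (func, mod, ctx) respectively; the pass names and bodies are made up for illustration.

from tvm.relay import transform

@transform.module_pass(opt_level=1, name="identity_module_pass")
def identity_module_pass(mod, ctx):
    # Module -> Module transformation; the identity here.
    return mod

@transform.function_pass(opt_level=1, name="identity_function_pass")
def identity_function_pass(func, mod, ctx):
    # Function -> Function transformation, applied to every function in mod.
    return func

# Custom passes compose with built-in ones through the Sequential constructor.
seq = transform.Sequential(passes=[identity_module_pass, identity_function_pass],
                           opt_level=1)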
*/ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< - Function(Function, Module, PassContext)>& pass_func, + Function(Function, Module, PassContext)>& pass_func, int opt_level, const std::string& name, const tvm::Array& required); @@ -451,6 +452,85 @@ TVM_DLL Pass ToGraphNormalForm(); */ TVM_DLL Pass PartialEval(); +/*! + * \brief Simplify certain operators during inference. For example, batch norm + * will be unpacked into a number of simplified operators. + * + * \return The Pass. + */ +TVM_DLL Pass SimplifyInference(); + +/*! + * \brief Infer the type of an expression. + * + * The result of type checking is a new expression with unambigous + * type information filled in, as well as it's checked type field + * populated with the result type. + * + * \return The pass. + */ +TVM_DLL Pass InferType(); + +/*! + * \brief Search and eliminate common subexpression. For example, if there are + * two expressions evaluated to an identical value, a single variable is created + * and these two expressions are replaced by this variable. + * + * \param fskip The callback argument that allows to skip certain expressions. + * + * \return The pass. + */ +TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr); + +/*! + * \brief Combine parallel 2d convolutions into a single convolution if the + * number of branches of this conv2d operator is not less than + * `min_num_branch`. + * + * \param min_num_branches The minimun number of branches. + * + * \return The pass. + */ +TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3); + +/*! + * \brief Backward fold axis scaling into weights of conv/dense operators. + * + * \return The pass. + */ +TVM_DLL Pass BackwardFoldScaleAxis(); + +/*! + * \brief Forward fold axis scaling into weights of conv/dense operators. + * + * \return The pass. + */ +TVM_DLL Pass ForwardFoldScaleAxis(); + +/*! + * \brief A sequential pass that executes ForwardFoldScaleAxis and + * BackwardFoldScaleAxis passes. + * + * \return The pass. + */ +TVM_DLL Pass FoldScaleAxis(); + +/*! + * \brief Canonicalize some operators to the simplified operators. For example, + * bias_add can be canonicalized to expand_dims and broadcast_add. + * + * \return The pass. + */ +TVM_DLL Pass CanonicalizeOps(); + +/*! + * \brief Alternate the layouts of operators or replace primitive operators + * with other expressions. + * + * \return The pass. + */ +TVM_DLL Pass AlterOpLayout(); + } // namespace transform } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 6cee393d5f91..8f9b0481a22c 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -20,7 +20,6 @@ """ import numpy as np -from tvm._ffi.runtime_ctypes import TVMContext from tvm import expr as tvm_expr from .. import nd as _nd, target as _target, autotvm from ..contrib import graph_runtime as _graph_rt @@ -28,7 +27,6 @@ from . import ir_pass from . import ty as _ty from . import expr as _expr -from . 
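The scale-axis passes declared here fold an axis-wise multiplication into the weights of conv2d/dense. A rough Python sketch of the pattern they target, assuming the python wrappers added later in this patch; the shapes and the name "scaled_conv" are illustrative.

from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(1, 4, 8, 8), dtype="float32")
w = relay.var("w", shape=(8, 4, 3, 3), dtype="float32")
scale = relay.var("scale", shape=(4, 1, 1), dtype="float32")

# A per-channel scale feeding a conv2d; ForwardFoldScaleAxis is expected to
# fold `scale` into the convolution weight.
y = relay.nn.conv2d(relay.multiply(x, scale), w,
                    kernel_size=(3, 3), padding=(1, 1), channels=8)
func = relay.Function([x, w, scale], y)

mod = relay.Module({relay.GlobalVar("scaled_conv"): func})
with relay.build_config(opt_level=3):
    # Backward folding, then forward folding, then constant folding.
    mod = transform.FoldScaleAxis()(mod)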
import transform as _transform from .backend import interpreter as _interpreter from .backend.vm import VMExecutor @@ -61,10 +59,6 @@ def __init__(self): self._get_graph_json = self.mod["get_graph_json"] self._get_module = self.mod["get_module"] self._build = self.mod["build"] - self._add_pass = self.mod["add_pass"] - self._disable_pass = self.mod["disable_pass"] - self._set_opt_level = self.mod["set_opt_level"] - self._set_fallback_device = self.mod["set_fallback_device"] self._set_params_func = self.mod["set_params"] self._get_params_func = self.mod["get_params"] @@ -106,8 +100,9 @@ def build(self, func, target=None, target_host=None, params=None): """ target = _update_target(target) - # Setup the build configurations passed in through `with build_config`. - self._setup_build_config(params) + # Setup the params. + if params: + self._set_params(params) # Build the function self._build(func, target, target_host) # Get artifacts @@ -117,41 +112,6 @@ def build(self, func, target=None, target_host=None, params=None): return graph_json, mod, params - def _setup_build_config(self, params): - cfg = _transform.PassContext.current() - - # Set opt_level. - self.set_opt_level(cfg.opt_level) - - # Set fallback device if it is available. - if cfg.fallback_device: - self.set_fallback_device(cfg.fallback_device) - - # Add required passes. - if cfg.required_pass: - passes = set() - if isinstance(cfg.required_pass, (list, tuple, set)): - passes = set(cfg.required_pass) - else: - raise TypeError("add_pass must be list, tuple, or set, but " + - "got {}".format(type(cfg.required_pass))) - for pass_name in passes: - self.add_pass(pass_name) - - # Add disabled passes. - if cfg.disabled_pass: - passes = set() - if isinstance(cfg.disabled_pass, (list, tuple, set)): - passes = set(cfg.disabled_pass) - else: - raise TypeError("disable_pass must be list, tuple, or set, " + - "but got {}".format(type(cfg.disabled_pass))) - for pass_name in passes: - self.disable_pass(pass_name) - - if params: - self._set_params(params) - def _set_params(self, params): inputs = {} for name, param in params.items(): @@ -160,28 +120,6 @@ def _set_params(self, params): inputs[name] = _expr.const(param) self._set_params_func(inputs) - def add_pass(self, pass_name): - """Add a pass to the pass list. - - Parameters - ---------- - pass_name : str - The name of the pass that will be added to the list of passes used - for optimizations. - """ - self._add_pass(pass_name) - - def disable_pass(self, pass_name): - """Add a pass to the disabled pass list. - - Parameters - ---------- - pass_name : str - The name of a pass. This pass will be added to the list of passes - that are disabled during optimization. - """ - self._disable_pass(pass_name) - def get_json(self): """Return the json file of the built program.""" return self._get_graph_json() @@ -198,32 +136,6 @@ def get_params(self): ret[key] = value.data return ret - def set_opt_level(self, level): - """Set the optimization level. - - Parameters - ---------- - level : int - The optimization level for build. - """ - self._set_opt_level(level) - - def set_fallback_device(self, fallback_device): - """Set the fallback device for heterogeneous execution. - - Parameters - ---------- - fallback_device : str or tvm.TVMContext - The fallback device used for heterogeneous execution. 
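With add_pass/disable_pass/set_opt_level/set_fallback_device removed, everything the deleted _setup_build_config used to forward is now read from the PassContext created by relay.build_config. A minimal sketch of the resulting user-side flow; the tiny dense network and the parameter name "w" are placeholders.

import numpy as np
from tvm import relay

x = relay.var("x", shape=(1, 8), dtype="float32")
w = relay.var("w", shape=(8, 8), dtype="float32")
func = relay.Function([x, w], relay.nn.dense(x, w))
params = {"w": np.random.uniform(size=(8, 8)).astype("float32")}

# opt_level, required_pass, disabled_pass and fallback_device are all read
# from the PassContext set up by build_config; params are bound by name.
with relay.build_config(opt_level=3):
    graph_json, lib, out_params = relay.build(func, target="llvm", params=params)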
- """ - if isinstance(fallback_device, (int, str)): - fallback_device = _nd.context(fallback_device) - if not isinstance(fallback_device, TVMContext): - raise TypeError("fallback_device is expected to be str, int, or " + - "TVMContext but received: {}".format(type(fallback_device))) - - self._set_fallback_device(fallback_device.device_type) - def build(func, target=None, target_host=None, params=None): """Helper function that builds a Relay function to run on TVM graph diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index a7887c630c76..38079b010e7d 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=no-else-return # pylint: disable=unidiomatic-typecheck +# pylint: disable=invalid-name """ This file contains the pass manager for Relay which exposes different granularity of interfaces for users to implement and use passes more @@ -394,3 +395,201 @@ def create_function_pass(pass_func): if pass_func: return create_function_pass(pass_func) return create_function_pass + + +def InferType(): + """Infer the type of an expr. + + Returns + ------- + ret : tvm.relay.Pass + The registered type inference pass. + """ + return _transform.InferType() + + +def FoldScaleAxis(): + """Fold the scaling of axis into weights of conv2d/dense. This pass will + invoke both forward and backward scale folding. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass to fold expressions. + + Note + ---- + Internally, we will call backward_fold_scale_axis before using + forward_fold_scale_axis. As backward folding targets common conv-bn + pattern. + """ + return _transform.FoldScaleAxis() + + +def SimplifyInference(): + """Simplify the data-flow graph for inference phase. An simplified expression + which is semantically equal to the input expression will be returned. + + Returns + ------- + ret: tvm.relay.Pass + The registered to perform operator simplification. + """ + return _transform.SimplifyInference() + + +def CanonicalizeOps(): + """ Canonicalize special operators to basic operators. + This can simplify followed analysis. (e.g. expanding bias_add to + expand_dims and broadcast_add.) + + Returns + ------- + ret: tvm.relay.Pass + The registered pass performing the canonicalization. + """ + return _transform.CanonicalizeOps() + + +def DeadCodeElimination(): + """ Remove expressions which does not effect the program result (dead code). + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that eliminates the dead code in a Relay program. + """ + return _transform.DeadCodeElimination() + + +def FoldConstant(): + """Fold the constant expression in expr. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass for constant folding. + """ + return _transform.FoldConstant() + + +def FuseOps(fuse_opt_level=-1): + """Fuse operators in an expr to a larger operator according to some rules. + + Parameters + ---------- + fuse_opt_level : int + The level of fuse optimization. -1 indicates that the level will be + inferred from pass context. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass for operator fusion. + """ + return _transform.FuseOps(fuse_opt_level) + + +def CombineParallelConv2D(min_num_branches=3): + """Combine multiple conv2d operators into one. + + Parameters + ---------- + min_num_branches : int + The minimum number of required parallel branches for performing this + optimization. 
+ + Returns + ------- + ret: tvm.relay.Pass + The registered pass that combines parallel conv2d operators. + """ + return _transform.CombineParallelConv2D(min_num_branches) + + +def AlterOpLayout(): + """Alternate the layouts of operators or replace primitive operators with + other expressions. + This pass can be used for computing convolution in custom layouts or + other general weight pre-transformation. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that alters the layout of operators. + """ + return _transform.AlterOpLayout() + + +def RewriteAnnotatedOps(fallback_device): + """Rewrite the annotated program where annotation operators, e.g. + `on_deivce`, mark which device an expression should be scheduled to. + This pass helps heterogeneous execution where different operators may need + to be allocated on various devices. + + Parameters + ---------- + fallback_device : int + The fallback device type. It is also used as the default device for + operators with no annotated device. + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that rewrites an expression with annotated + `on_device` operators. + """ + return _transform.RewriteDeviceAnnotation(fallback_device) + + +def ToANormalForm(): + """Turn Graph Normal Form expression into A Normal Form Expression. + The scope of the root expression is the global scope. + The scope of any non root expression is the least common ancestor of all it's scope. + Values are ordered by post-DFS order in each scope. + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that transforms an expression into A Normal Form. + """ + return _transform.ToANormalForm() + + +def ToGraphNormalForm(): + """Turn A Normal Form expression into Graph Normal Form expression + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that transforms an expression into Graph Normal Form. + """ + return _transform.ToGraphNormalForm() + + +def EliminateCommonSubexpr(fskip=None): + """Eliminate common subexpressions. + + Parameters + ---------- + fskip: Callable + The callback function that decides whether an expression should be + skipped. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that eliminates common subexpressions. + """ + return _transform.EliminateCommonSubexpr(fskip) + + +def PartialEvaluate(): + """Evaluate the static fragment of the code. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that performs partial evaluation on an expression. + """ + return _transform.PartialEvaluate() diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 57dc256ef6b7..e0014e919089 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -23,12 +23,8 @@ */ #include #include -#include #include -#include -#include -#include -#include +#include #include #include "utils.h" @@ -38,39 +34,7 @@ namespace relay { namespace backend { using TargetsMap = Map; - -/*! - * \brief A data structure to map the names of specific optimizations to - * numeric optimization levels - * - */ -struct OptPassLevel { - static const std::unordered_map _data; - /*! 
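The fskip callback gives callers a veto over which expressions may be deduplicated. A sketch of the Python-side equivalent of the cast-skipping callback used later in build_module.cc; the function body and the name "cse_example" are illustrative only.

from tvm import relay
from tvm.relay import transform

def fskip(expr):
    # Mirror the callback in build_module.cc: never merge `cast` calls.
    return isinstance(expr, relay.expr.Call) and getattr(expr.op, "name", "") == "cast"

x = relay.var("x", shape=(4,), dtype="float32")
y1 = relay.add(x, relay.const(1.0))
y2 = relay.add(x, relay.const(1.0))   # identical subexpression, should be shared
func = relay.Function([x], relay.add(y1, y2))

mod = relay.Module({relay.GlobalVar("cse_example"): func})
mod = transform.EliminateCommonSubexpr(fskip)(mod)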
- * \brief Get level for an optimization pass - * - * \param key pass name - * \return int level - */ - int operator[](const std::string& key) const { - auto it = _data.find(key); - if (it == _data.end()) { - return -1; - } - return it->second; - } -}; - -const std::unordered_map OptPassLevel::_data = { - {"SimplifyInference", 0}, - {"OpFusion", 1}, - {"FoldConstant", 2}, - {"CombineParallelConv2D", 4}, - {"FoldScaleAxis", 3}, - {"AlterOpLayout", 3}, - {"CanonicalizeOps", 3}, - {"EliminateCommonSubexpr", 3} -}; +using namespace tvm::relay::transform; /*! * \brief Output of building module @@ -82,27 +46,6 @@ struct BuildOutput { std::unordered_map params; }; -/*! - * \brief Relay building config - * - */ -struct RelayBuildConfig { - int opt_level{2}; - int fallback_device{static_cast(kDLCPU)}; - std::unordered_set enabled_pass; - std::unordered_set disabled_pass; - OptPassLevel OPT_PASS_LEVEL; - inline bool pass_enabled(const std::string& pass_name) const { - if (disabled_pass.count(pass_name)) { - return false; - } - if (enabled_pass.count(pass_name)) { - return true; - } - return opt_level >= OPT_PASS_LEVEL[pass_name]; - } -}; - /*! * \brief GraphCodegen module wrapper * @@ -156,18 +99,6 @@ struct GraphCodegen { } }; -template -R CallPackedFunc(const std::string &name, Args... args) { - auto pf = GetPackedFunc(name); - return (*pf)(std::forward(args)...); -} - -template -Function CallPackedFunc(const std::string &name, Args... args) { - auto pf = GetPackedFunc(name); - return (*pf)(std::forward(args)...); -} - /*! * \brief Relay build module * @@ -203,28 +134,6 @@ class RelayBuildModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); }); - } else if (name == "set_opt_level") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 1); - int level = args[0]; - this->SetOptLevel(level); - }); - } else if (name == "set_fallback_device") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 1); - int dev = args[0]; - this->SetFallBackDev(dev); - }); - } else if (name == "add_pass") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::string pass_name = args[0]; - this->AddPass(pass_name); - }); - } else if (name == "disable_pass") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::string pass_name = args[0]; - this->DisablePass(pass_name); - }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Map params = args[0]; @@ -246,30 +155,7 @@ class RelayBuildModule : public runtime::ModuleNode { const std::string& GetGraphJSON() { return ret_.graph_json; } - /*! - * \brief Add extra pass into build cfg - * - * \param pass_name name of pass - */ - void AddPass(const std::string& pass_name) { - cfg_.enabled_pass.insert(pass_name); - } - /*! - * \brief Disable a specific pass in cfg - * - * \param pass_name name of pass - */ - void DisablePass(const std::string& pass_name) { - cfg_.disabled_pass.insert(pass_name); - } - /*! - * \brief Set the Fallback device - * - * \param device name - */ - void SetFallBackDev(int dev) { - cfg_.fallback_device = dev; - } + /*! * \brief Get the Module object * @@ -315,15 +201,6 @@ class RelayBuildModule : public runtime::ModuleNode { params_[name] = data_in; } - /*! 
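The add_pass/disable_pass/set_opt_level/set_fallback_device packed functions deleted here are replaced by the PassContext. A sketch, under the same assumptions as the earlier examples, of forcing and disabling individual passes through build_config when building:

from tvm import relay

x = relay.var("x", shape=(1, 4), dtype="float32")
func = relay.Function([x], relay.add(x, x))

# required_pass/disabled_pass replace the removed add_pass/disable_pass hooks,
# and opt_level replaces set_opt_level.
with relay.build_config(opt_level=2,
                        required_pass=["FoldScaleAxis"],
                        disabled_pass=["AlterOpLayout"]):
    graph_json, lib, params = relay.build(func, target="llvm")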
- * \brief Set the optimization level - * - * \param level - */ - void SetOptLevel(char level) { - cfg_.opt_level = level; - } - /*! * \brief type key * @@ -345,7 +222,7 @@ class RelayBuildModule : public runtime::ModuleNode { const tvm::Target& target_host) { targets_ = targets; target_host_ = target_host; - BuildRelay(func, cfg_, params_); + BuildRelay(func, params_); } protected: @@ -378,85 +255,81 @@ class RelayBuildModule : public runtime::ModuleNode { if (repeat_var.count(arg)) { LOG(FATAL) << "Multiple args in the function have name " << kv.first; } - auto e = CallPackedFunc("relay._make.Constant", kv.second); - bind_dict[arg] = e; + bind_dict[arg] = ConstantNode::make(kv.second); } - return CallPackedFunc("relay._expr.Bind", func, tvm::Map(bind_dict)); + Expr bound_expr = relay::Bind(func, bind_dict); + Function ret = Downcast(bound_expr); + CHECK(ret.defined()) + << "The returning type is expected to be a Relay Function." + << "\n"; + return ret; } /*! - * \brief Optimize Relay function + * \brief Optimize a Relay module. * - * \param func Input function - * \param target target device - * \param cfg Relay build config - * \param params params dict - * \return relay::Function + * \param relay_module The input Relay module where optmization will be + * applied on. + * \param targets The device type to `Target` mapping. + * \param params The param name to value mapping. + * + * \return relay::Module The updated Relay module after optimization. */ - relay::Function Optimize(relay::Function func, - const TargetsMap& targets, - const RelayBuildConfig& cfg, - const std::unordered_map& params) { - if (params.size()) { - func = BindParamsByName(func, params); - } - if (cfg.pass_enabled("SimplifyInference")) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.simplify_inference", func); - } - if (cfg.pass_enabled("EliminateCommonSubexpr")) { - auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == HalideIR::Int(32)) { - *rv = true; - } + relay::Module Optimize( + relay::Module relay_module, + const TargetsMap& targets, + const std::unordered_map& params) { + Array pass_seqs; + pass_seqs.push_back(transform::SimplifyInference()); + PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") { + auto attrs = call_node->attrs.as(); + if (attrs->dtype == HalideIR::Int(32)) { + *rv = true; } } - *rv = false; - }); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, fskip); - } - if (cfg.pass_enabled("CombineParallelConv2D")) { - const int min_num_branches = 3; - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.CombineParallelConv2D", func, min_num_branches); - } - if (cfg.pass_enabled("FoldConstant")) { - func = CallPackedFunc("relay._ir_pass.FoldConstant", func); - } - if (cfg.pass_enabled("FoldScaleAxis")) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.backward_fold_scale_axis", func); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = 
CallPackedFunc("relay._ir_pass.forward_fold_scale_axis", func); - func = CallPackedFunc("relay._ir_pass.FoldConstant", func); - } - if (cfg.pass_enabled("CanonicalizeOps")) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.canonicalize_ops", func); + } + *rv = false; + }); + pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::CombineParallelConv2D(3)); + pass_seqs.push_back(transform::FoldConstant()); + pass_seqs.push_back(transform::FoldScaleAxis()); + pass_seqs.push_back(transform::CanonicalizeOps()); + + // Alter layout transformation is only applied to homogeneous execution yet. + if (targets.size() == 1) { + pass_seqs.push_back(transform::AlterOpLayout()); } - if (cfg.pass_enabled("AlterOpLayout")) { - if (targets.size() == 1) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - for (const auto& kv : targets) { - With tctx(kv.second); - func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); - } - } else { - LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous" - << " execution yet."; + pass_seqs.push_back(transform::FoldConstant()); + + // Create a sequential pass and perform optimizations. + transform::Pass seq = transform::Sequential(pass_seqs); + if (targets.size() == 1) { + for (const auto& kv : targets) { + With tctx(kv.second); + relay_module = seq(relay_module); } + } else { + relay_module = seq(relay_module); } - if (cfg.pass_enabled("FoldConstant")) { - func = CallPackedFunc("relay._ir_pass.FoldConstant", func); + + // Handle heterogeneous compilation. + transform::PassContext pass_ctx = PassContext::Current(); + if (targets_.size() > 1) { + relay_module = + RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device); } - return func; + + // Fuse the operations if it is needed. + relay_module = transform::FuseOps()(relay_module); + relay_module = transform::InferType()(relay_module); + + return relay_module; } /*! @@ -470,54 +343,58 @@ class RelayBuildModule : public runtime::ModuleNode { if (name == "gpu") return Target::Create("cuda"); return Target::Create(name); } + /*! * \brief Update the target and fallback device required for heterogeneous * compilation. CPU is used as the fallback device if it wasn't provided. * Meanwhile, a CPU device type and "llvm" pair will be added to the target * dictionary in this case. * - * \param targets dictionary - * \param cfg - * \return Map + * \param fallback_device The fallback device for heterogeneous execution. */ - TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets, - const RelayBuildConfig& cfg) { - TargetsMap device_target = targets; + void UpdateHeterogeneousInputs(int fallback_device) { std::unordered_map tmp_map; - for (const auto& kv : targets) { + for (const auto& kv : targets_) { tmp_map[kv.first->value] = kv.second; } - if (tmp_map.count(cfg.fallback_device) == 0) { - device_target.Set( - cfg.fallback_device, - CreateDefaultTarget(cfg.fallback_device)); + if (tmp_map.count(fallback_device) == 0) { + targets_.Set(fallback_device, CreateDefaultTarget(fallback_device)); } - return device_target; } + /*! * \brief Execute the device annotation passes to update the input program and * target information. * - * \param func - * \param cfg - * \param targets_map_ptr - * \return Function + * \param relay_module The input Relay module. + * \param fallback_device The fallback device for heterogeneous execution. 
+ * + * \return updated_module The updated module after device annotation. */ - Function RunDeviceAnnotationPass(Function func, - const RelayBuildConfig& cfg, - TargetsMap* targets_map_ptr) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, - cfg.fallback_device); - auto device_map = CallPackedFunc >( - "relay._ir_pass.CollectDeviceInfo", func, nullptr); - if (device_map.size() == 0) { - auto annotation_map = CallPackedFunc >( - "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr); - if (annotation_map.size() == 0) { - targets_map_ptr->Set( - 0, CreateDefaultTarget(cfg.fallback_device)); + relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module, + int fallback_device) { + UpdateHeterogeneousInputs(fallback_device); + auto rewrite = transform::RewriteAnnotatedOps(fallback_device); + auto updated_module = rewrite(relay_module); + CHECK(updated_module.defined()); + + tvm::Map device_map; + for (const auto& it : updated_module->functions) { + device_map = relay::CollectDeviceInfo(it.second); + if (!device_map.empty()) break; + } + + if (device_map.empty()) { + tvm::Map annotation_map; + for (const auto& it : relay_module->functions) { + annotation_map = relay::CollectDeviceAnnotationOps(it.second); + if (!annotation_map.empty()) break; + } + // None op is annotated but they are fallen back to the default device. + if (annotation_map.empty()) { + targets_.Set(0, CreateDefaultTarget(fallback_device)); } else { + // All ops are annotated to the same device type. int64_t dev_type = -1; for (auto kv : annotation_map) { dev_type = kv.second->value; @@ -531,47 +408,42 @@ class RelayBuildModule : public runtime::ModuleNode { << "found. Please check the " << "RewriteAnnotation pass."; } - targets_map_ptr->Set(0, CreateDefaultTarget(dev_type)); + targets_.Set(0, CreateDefaultTarget(dev_type)); } } - return func; + return updated_module; } /*! * \brief Build relay function to runtime module * * \param func Relay Function - * \param cfg Relay build config * \param params parameters */ - void BuildRelay(Function func, - const RelayBuildConfig& cfg, - const std::unordered_map ¶ms) { - // convert - tvm_cfg_ = BuildConfig::Create(); - TargetsMap device_target; - if (targets_.size() > 1) { - device_target = UpdateHeterogeneousInputs(targets_, cfg); - } else { - device_target = targets_; - } - func = Optimize(func, targets_, cfg, params); - if (device_target.size() > 1) { - func = RunDeviceAnnotationPass(func, cfg, &device_target); + void BuildRelay( + Function func, + const std::unordered_map& params) { + if (params.size()) { + func = BindParamsByName(func, params); } - // TODO(@jroesch): use the passes directly. - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.FuseOps", func, cfg.opt_level, nullptr); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + // Perform Module->Module optimizations. + relay::Module relay_module = relay::ModuleNode::FromExpr(func); + relay_module = Optimize(relay_module, targets_, params); + CHECK(relay_module.defined()); + // Get the updated function. + func = relay_module->Lookup(relay_module->entry_func->name_hint); + + // Generate code for the updated function. 
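RunDeviceAnnotationPass is exercised when the program carries on_device annotations. A hedged sketch of what such an input looks like from the Python side, assuming the relay.annotation.on_device API of this version; the device choice, shapes, and the name "hetero_example" are illustrative.

import tvm
from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(1, 4), dtype="float32")
y = relay.var("y", shape=(1, 4), dtype="float32")

# Pin the addition to the GPU; everything else stays on the fallback device.
gpu_add = relay.annotation.on_device(relay.add(x, y), tvm.gpu(0))
func = relay.Function([x, y], relay.subtract(gpu_add, y))

mod = relay.Module({relay.GlobalVar("hetero_example"): func})
# 1 is the device type of kDLCPU, i.e. the fallback device in this sketch.
mod = transform.RewriteAnnotatedOps(1)(mod)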
graph_codegen_ = std::unique_ptr(new GraphCodegen()); - graph_codegen_->Init(nullptr, device_target); + graph_codegen_->Init(nullptr, targets_); graph_codegen_->Codegen(func); ret_.graph_json = graph_codegen_->GetJSON(); ret_.params = graph_codegen_->GetParams(); - ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, tvm_cfg_); + ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, + BuildConfig::Current()); } protected: @@ -580,14 +452,10 @@ class RelayBuildModule : public runtime::ModuleNode { TargetsMap targets_; /*! \brief target host device */ tvm::Target target_host_; - /*! \brief frontend optimization configure */ - RelayBuildConfig cfg_; /*! \brief parameters */ std::unordered_map params_; /*! \brief building output */ BuildOutput ret_; - /*! \brief tvm building cfg */ - BuildConfig tvm_cfg_; }; runtime::Module RelayBuildCreate() { diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index f51c201d0b2a..d623393049a6 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -338,17 +339,35 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, // Limiations: // 1. the altered op should have the same number of arguments as the previous one // 2. do not support nested tuple arguments -TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") -.set_body([](TVMArgs args, TVMRetValue *ret) { +Expr AlterOpLayout(const Expr& expr) { TransformMemorizer transformMemorizer(make_node()); auto fcontext = [&](const Call& call) -> NodeRef{ return transformMemorizer; }; - *ret = ForwardRewrite(args[0], AlterOpLayoutRewrite, fcontext); -}); + return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext); +} + +TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") +.set_body_typed(AlterOpLayout); } // namespace alter_op_layout +namespace transform { + +Pass AlterOpLayout() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(relay::alter_op_layout::AlterOpLayout(f)); + }; + return CreateFunctionPass(pass_func, 3, "AlterOpLayout", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.AlterOpLayout") +.set_body_typed(AlterOpLayout); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index 9a4602750195..ff9e2304a3bc 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include "pattern_util.h" namespace tvm { @@ -63,5 +64,21 @@ Expr CanonicalizeOps(const Expr& e) { TVM_REGISTER_API("relay._ir_pass.canonicalize_ops") .set_body_typed(CanonicalizeOps); +namespace transform { + +Pass CanonicalizeOps() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(CanonicalizeOps(f)); + }; + return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.CanonicalizeOps") +.set_body_typed(CanonicalizeOps); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 7e76322d5a2a..c95c1ddf8e16 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -38,6 +38,7 @@ #include #include #include +#include #include 
#include #include "./expr_subst.h" @@ -357,5 +358,21 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) { TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D") .set_body_typed(CombineParallelConv2D); +namespace transform { + +Pass CombineParallelConv2D(uint64_t min_num_branches) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(CombineParallelConv2D(f, min_num_branches)); + }; + return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.CombineParallelConv2D") +.set_body_typed(CombineParallelConv2D); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index dd1ed6240cab..be6774564806 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -158,9 +158,12 @@ Pass DeadCodeElimination() { [=](Function f, Module m, PassContext pc) { return Downcast(DeadCodeElimination(f)); }; - return CreateFunctionPass(pass_func, 1, "dead_code_elimination", {}); + return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); } +TVM_REGISTER_API("relay._transform.DeadCodeElimination") +.set_body_typed(DeadCodeElimination); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index e2d07619cb0f..02d6d9e1fefb 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -35,6 +35,7 @@ #include #include #include +#include #include #include @@ -564,11 +565,14 @@ Pass RewriteAnnotatedOps(int fallback_device) { [=](Function f, Module m, PassContext pc) { return Downcast(RewriteAnnotatedOps(f, fallback_device)); }; - return CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", {}); + return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", + {ir::StringImm::make("InferType")}); } +TVM_REGISTER_API("relay._transform.RewriteDeviceAnnotation") +.set_body_typed(RewriteAnnotatedOps); + } // namespace transform } // namespace relay } // namespace tvm - diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index f8432f671855..883681adcaf4 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -29,6 +29,7 @@ */ #include #include +#include #include #include "./pattern_util.h" @@ -87,5 +88,21 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) { TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr") .set_body_typed(EliminateCommonSubexpr); +namespace transform { + +Pass EliminateCommonSubexpr(PackedFunc fskip) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(EliminateCommonSubexpr(f, fskip)); + }; + return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.EliminateCommonSubexpr") +.set_body_typed(EliminateCommonSubexpr); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 286392ab5d3f..815407038b08 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -26,6 +26,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -220,11 +221,14 @@ namespace transform { Pass FoldConstant() { runtime::TypedPackedFunc 
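Each relay._transform.* registration above makes the corresponding pass constructible from Python by name. As a small example, constant folding a subgraph that has no free variables; the shapes and the name "fold_example" are illustrative.

import numpy as np
from tvm import relay
from tvm.relay import transform

c = relay.const(np.ones((2, 2), dtype="float32"))
x = relay.var("x", shape=(2, 2), dtype="float32")

# `c + c` has no free variables, so it can be evaluated at compile time.
func = relay.Function([x], relay.add(x, relay.add(c, c)))

mod = relay.Module({relay.GlobalVar("fold_example"): func})
mod = transform.FoldConstant()(mod)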
pass_func = [=](Function f, Module m, PassContext pc) { - return Downcast(FoldConstant(f)); + return Downcast(FoldConstant(f)); }; - return CreateFunctionPass(pass_func, 1, "fold_constant", {}); + return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); } +TVM_REGISTER_API("relay._transform.FoldConstant") +.set_body_typed(FoldConstant); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index c738e3e3b731..53089807ace5 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "pattern_util.h" #include "pass_util.h" @@ -530,7 +531,7 @@ RELAY_REGISTER_OP("nn.conv2d") .set_attr("FScaleAxisForwardRewrite", Conv2DForwardRewrite); -Expr ForwardFoldScaleAxis(Expr data) { +Expr ForwardFoldScaleAxis(const Expr& data) { auto message = ForwardPrep().Prepare(data); auto fcontext = [&](const Call& call) -> NodeRef{ auto it = message.find(call.get()); @@ -942,7 +943,7 @@ RELAY_REGISTER_OP("nn.conv2d") RELAY_REGISTER_OP("nn.conv2d") .set_attr("FScaleAxisBackwardTransform", Conv2DBackwardTransform); -Expr BackwardFoldScaleAxis(Expr data) { +Expr BackwardFoldScaleAxis(const Expr& data) { return make_node()->Fold(data); } @@ -950,5 +951,42 @@ TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis") .set_body_typed(BackwardFoldScaleAxis); } // namespace fold_scale_axis + +namespace transform { + +Pass ForwardFoldScaleAxis() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast( + relay::fold_scale_axis::ForwardFoldScaleAxis(f)); + }; + return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", + {ir::StringImm::make("InferType")}); +} + +Pass BackwardFoldScaleAxis() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast( + relay::fold_scale_axis::BackwardFoldScaleAxis(f)); + }; + return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", + {ir::StringImm::make("InferType")}); +} + +Pass FoldScaleAxis() { + // FoldScaleAxis pass contains the following three passes. Therefore, we can + // register it as a sequential pass. 
+ Pass pass = Sequential( + {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()}, + "FoldScaleAxis"); + return pass; +} + +TVM_REGISTER_API("relay._transform.FoldScaleAxis") +.set_body_typed(FoldScaleAxis); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index 2a3aa1612418..8ad61270e33a 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -220,7 +220,7 @@ Pass ForwardRewrite(const std::string& rewrite_map_attr_name, fcontext, fmulti_ref_trigger)); }; - return CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); + return CreateFunctionPass(pass_func, 1, "ForwardRewrite", {}); } Pass ForwardRewrite(const FForwardRewrite& rewrite_func, @@ -233,7 +233,7 @@ Pass ForwardRewrite(const FForwardRewrite& rewrite_func, fcontext, fmulti_ref_trigger)); }; - return CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); + return CreateFunctionPass(pass_func, 1, "ForwardRewriteFunc", {}); } } // namespace transform diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 9277689075c2..9f940e54953b 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "./pattern_util.h" #include "../../common/arena.h" @@ -973,9 +974,13 @@ Pass FuseOps(int fuse_opt_level) { int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; return Downcast(FuseOps(f, opt_level, m)); }; - return CreateFunctionPass(pass_func, 1, "fuse_ops", {}); + return CreateFunctionPass(pass_func, 1, "FuseOps", + {ir::StringImm::make("InferType")}); } +TVM_REGISTER_API("relay._transform.FuseOps") +.set_body_typed(FuseOps); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 3f42c6fce4b2..71ba7cd11cd5 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -797,9 +797,7 @@ Expr PartialEval(const Expr& e) { } TVM_REGISTER_API("relay._ir_pass.partial_evaluate") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = PartialEval(args[0]); - }); +.set_body_typed(PartialEval); namespace transform { @@ -808,9 +806,12 @@ Pass PartialEval() { [=](Function f, Module m, PassContext pc) { return Downcast(PartialEval(f)); }; - return CreateFunctionPass(pass_func, 1, "partial_eval", {}); + return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {}); } +TVM_REGISTER_API("relay._transform.PartialEvaluate") +.set_body_typed(PartialEval); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index a9c671aa163a..13e908d28f7a 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -37,42 +37,46 @@ namespace transform { using tvm::IRPrinter; -/*! - * \brief A data structure to map the names of specific optimizations to - * numeric optimization levels - */ -class OptPassLevel { - public: - /*! - * \brief Get level for an optimization pass - * - * \param key pass name - * \return int level - */ - int operator[](const std::string& key) const { - const auto data = CreateMap(); - auto it = data.find(key); - if (it == data.end()) { - return -1; - } - return it->second; +namespace { + +// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be +// handled because we need to register the pass for Python invocation anyway. 
+Pass GetPass(const std::string& pass_name) { + if (pass_name == "InferType") { + return InferType(); + } else if (pass_name == "AlterOpLayout") { + return AlterOpLayout(); + } else if (pass_name == "CanonicalizeOps") { + return CanonicalizeOps(); + } else if (pass_name == "CombineParallelConv2d") { + return CombineParallelConv2D(); + } else if (pass_name == "DeadCodeElimination") { + return DeadCodeElimination(); + } else if (pass_name == "EliminateCommonSubexpr") { + return DeadCodeElimination(); + } else if (pass_name == "FoldConstant") { + return FoldConstant(); + } else if (pass_name == "BackwardFoldScaleAxis") { + return FoldScaleAxis(); + } else if (pass_name == "ForwardFoldScaleAxis") { + return FoldScaleAxis(); + } else if (pass_name == "FoldScaleAxis") { + return FoldScaleAxis(); + } else if (pass_name == "PartialEvaluate") { + return SimplifyInference(); + } else if (pass_name == "SimplifyInference") { + return SimplifyInference(); + } else if (pass_name == "ToANormalForm") { + return ToANormalForm(); + } else if (pass_name == "ToGraphNormalForm") { + return ToGraphNormalForm(); + } else { + LOG(FATAL) << pass_name << " has not been registered yet." << "\n"; + return Pass(nullptr); } +} - private: - static const std::unordered_map CreateMap() { - const std::unordered_map m = { - {"SimplifyInference", 0}, - {"OpFusion", 1}, - {"FoldConstant", 2}, - {"CombineParallelConv2D", 3}, - {"FoldScaleAxis", 3}, - {"AlterOpLayout", 3}, - {"CanonicalizeOps", 3}, - {"EliminateCommonSubexpr", 3} - }; - return m; - } -}; +} // namespace struct RelayPassContextThreadLocalEntry { /*! \brief The default pass context. */ @@ -246,12 +250,6 @@ class SequentialNode : public PassNode { /* \brief The pass meta data.*/ PassInfo pass_info; - /*! - * \brief A helper struct to get the optimization pass name to opt level - * mapping. - */ - OptPassLevel opt_pass_level; - /*! \brief A list of passes that used to compose a sequential pass. */ tvm::Array passes; void VisitAttrs(tvm::AttrVisitor* v) final { @@ -300,7 +298,7 @@ class SequentialNode : public PassNode { const Array& disabled) const; std::unordered_set RequiredPasses( - const Array& disabled) const; + const Array& required) const; /*! * \brief Perform optimizations on a series of passes. The aforementioned @@ -338,14 +336,25 @@ ModulePass ModulePassNode::make( } // Module -> Module optimizations. -// TODO(zhiics) Check and handle the required passes. Module ModulePassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { PassInfo pass_info = Info(); DLOG(INFO) << "Executing module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level << "\n"; + CHECK(mod.defined()); - auto updated_mod = pass_func(mod, pass_ctx); + Module updated_mod = mod; + // Execute the required passes in a DFS way. + // TODO(zhiics) We may need to pass validation to detect the cyclic + // dependency. 
+ for (const auto& it : pass_info->required) { + const auto* name = it.as(); + CHECK(name); + auto pass = GetPass(name->value); + updated_mod = pass(updated_mod, pass_ctx); + } + + updated_mod = pass_func(updated_mod, pass_ctx); CHECK(updated_mod.defined()); return updated_mod; } @@ -365,12 +374,26 @@ Module FunctionPassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { PassInfo pass_info = Info(); CHECK(mod.defined()); - Module new_mod = ModuleNode::make({}, mod->type_definitions); DLOG(INFO) << "Executing module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level << "\n"; + + Module updated_mod = mod; + // Execute the required passes in a DFS way. + // TODO(zhiics) We may need to pass validation to detect the cyclic + // dependency. + for (const auto& it : pass_info->required) { + const auto* name = it.as(); + CHECK(name); + auto pass = GetPass(name->value); + updated_mod = pass(updated_mod, pass_ctx); + } + + Module new_mod = ModuleNode::make({}, mod->type_definitions); // Execute the pass function and return a new module. for (const auto& it : mod->functions) { - auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx); + auto updated_func = SkipFunction(it.second) + ? it.second + : pass_func(it.second, updated_mod, pass_ctx); new_mod->Add(it.first, updated_func); } @@ -418,7 +441,7 @@ std::unordered_set SequentialNode::DisabledPasses( std::unordered_set ret; for (const auto& it : disabled) { const auto* str = it.as(); - CHECK(str) << "disabled passes must be string."; + CHECK(str) << "Disabled pass name must be string."; ret.emplace(str->value); } return ret; @@ -429,7 +452,7 @@ std::unordered_set SequentialNode::RequiredPasses( std::unordered_set ret; for (const auto& it : required) { const auto* str = it.as(); - CHECK(str) << "disabled passes must be string."; + CHECK(str) << "Required pass name must be string."; ret.emplace(str->value); } return ret; @@ -439,7 +462,7 @@ bool SequentialNode::PassEnabled(const std::string& pass_name) const { PassContext ctx = PassContext::Current(); auto required = RequiredPasses(ctx->required_pass); - auto disabled = DisabledPasses(ctx->required_pass); + auto disabled = DisabledPasses(ctx->disabled_pass); if (disabled.count(pass_name)) { return false; @@ -448,29 +471,27 @@ bool SequentialNode::PassEnabled(const std::string& pass_name) const { if (required.count(pass_name)) { return true; } - return ctx->opt_level >= opt_pass_level[pass_name]; + + const Pass pass = GetPass(pass_name); + PassInfo info = pass->Info(); + return ctx->opt_level >= info->opt_level; } // TODO(zhiics): we currenlty only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase -// ordering problem needed to be handled in the future. +// ordering problem needs to be handled in the future. Module SequentialNode::operator()(const Module& module, const PassContext& pass_ctx) const { - int opt_level = pass_ctx->opt_level; - auto disabled = DisabledPasses(pass_ctx->disabled_pass); Module mod = module; for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; + PassInfo info = pass->Info(); const auto& pass_name = info->name; - const auto& pass_opt_level = info->opt_level; - // Skip the pass if its optimization level is higher that the one of in the - // pass context or if this pass is disabled. - if (pass_opt_level > opt_level || disabled.count(pass_name)) { - continue; + // Execute the pass if it is enabled. 
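PassEnabled now compares the PassContext opt_level against each pass's own registered level instead of a hard-coded table, with required_pass and disabled_pass from the context overriding it. A sketch of the effect from Python, relying on the levels registered elsewhere in this patch (AlterOpLayout at 3, FoldConstant at 2); the module contents are illustrative.

from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(1, 4), dtype="float32")
mod = relay.Module({relay.GlobalVar("gate_example"): relay.Function([x], relay.add(x, x))})

seq = transform.Sequential(passes=[transform.FoldConstant(),
                                   transform.AlterOpLayout()])

with relay.build_config(opt_level=2):
    mod1 = seq(mod)   # AlterOpLayout (registered at level 3) is expected to be skipped

with relay.build_config(opt_level=2, required_pass=["AlterOpLayout"]):
    mod2 = seq(mod)   # required_pass forces it to run regardless of opt_level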
+ if (PassEnabled(pass_name)) { + mod = pass(mod, pass_ctx); } - const auto* pn = pass.operator->(); - mod = (*pn)(mod, pass_ctx); } return mod; } @@ -525,15 +546,17 @@ TVM_REGISTER_API("relay._transform.CreateModulePass") TVM_REGISTER_API("relay._transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Pass()(args[1]); + Pass pass = args[0]; + Module mod = args[1]; + *ret = pass(mod); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ModulePassNode* node, tvm::IRPrinter* p) { - const PassInfoNode* pn = node->Info().operator->(); - p->stream << "Run Module pass: " << pn->name - << " at the optimization level " << pn->opt_level; + const PassInfo info = node->Info(); + p->stream << "Run Module pass: " << info->name + << " at the optimization level " << info->opt_level; }); TVM_REGISTER_NODE_TYPE(FunctionPassNode); @@ -544,9 +567,9 @@ TVM_REGISTER_API("relay._transform.CreateFunctionPass") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionPassNode* node, tvm::IRPrinter* p) { - const PassInfoNode* pn = node->Info().operator->(); - p->stream << "Run Function pass: " << pn->name - << " at the optimization level " << pn->opt_level; + const PassInfo info = node->Info(); + p->stream << "Run Function pass: " << info->name + << " at the optimization level " << info->opt_level; }); TVM_REGISTER_NODE_TYPE(SequentialNode); @@ -564,14 +587,13 @@ TVM_REGISTER_API("relay._transform.Sequential") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const SequentialNode* node, tvm::IRPrinter* p) { - const PassInfoNode* seq_pn = node->Info().operator->(); - p->stream << "Run Sequential pass: " << seq_pn->name - << " at the optimization level " << seq_pn->opt_level << ". "; + const PassInfo info = node->Info(); + p->stream << "Run Sequential pass: " << info->name + << " at the optimization level " << info->opt_level << ". 
"; p->stream << "The passes will be executed are: ["; for (const auto& it : node->passes) { - const PassNode* pn = it.operator->(); - const PassInfoNode* pass_info_node = pn->Info().operator->(); - p->stream << pass_info_node->name << " "; + const PassInfo pass_info = it->Info(); + p->stream << pass_info->name << " "; } p->stream << "]"; }); diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index 8dab0c370853..6d6b24abec20 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include "./pattern_util.h" namespace tvm { @@ -105,5 +106,21 @@ Expr SimplifyInference(const Expr& e) { TVM_REGISTER_API("relay._ir_pass.simplify_inference") .set_body_typed(SimplifyInference); +namespace transform { + +Pass SimplifyInference() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(SimplifyInference(f)); + }; + return CreateFunctionPass(pass_func, 0, "SimplifyInference", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.SimplifyInference") +.set_body_typed(SimplifyInference); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index f9d47f78a6d2..324eddd21c5c 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -340,9 +340,12 @@ Pass ToANormalForm() { [=](Function f, Module m, PassContext pc) { return Downcast(ToANormalForm(f, m)); }; - return CreateFunctionPass(pass_func, 1, "to_a_normal_form", {}); + return CreateFunctionPass(pass_func, 1, "ToANormalForm", {}); } +TVM_REGISTER_API("relay._transform.ToANormalForm") +.set_body_typed(ToANormalForm); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index 50ebb702e4b2..9c166f98c1a5 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -86,9 +86,12 @@ Pass ToGraphNormalForm() { [=](Function f, Module m, PassContext pc) { return Downcast(ToGraphNormalForm(f)); }; - return CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {}); + return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {}); } +TVM_REGISTER_API("relay._transform.ToGraphNormalForm") +.set_body_typed(ToGraphNormalForm); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 482cef3b2c2d..3fde3c7e7b36 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -43,6 +43,7 @@ #include #include #include +#include #include "./pass_util.h" #include "type_solver.h" #include "../ir/type_functor.h" @@ -807,5 +808,23 @@ TVM_REGISTER_API("relay._ir_pass.infer_type") .set_body_typed([](const Expr& expr, const Module& mod_ref) { return InferType(expr, mod_ref); }); + +namespace transform { + +Pass InferType() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(InferType(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "InferType", {}); +} + +TVM_REGISTER_API("relay._transform.InferType") +.set_body_typed([]() { + return InferType(); +}); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc new file mode 100644 index 000000000000..b61a5cc0daad --- 
/dev/null +++ b/tests/cpp/relay_transform_sequential.cc @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +TVM_REGISTER_GLOBAL("schedule") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + *rv = topi::generic::schedule_injective(args[0], args[1]); + }); + +TEST(Relay, Sequential) { + using namespace tvm; + auto tensor_type = relay::TensorTypeNode::make({1, 2, 3}, ::tvm::Float(32)); + auto c_data = + tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + + // Create a function for optimization. + auto c = relay::ConstantNode::make(c_data); + auto a = relay::VarNode::make("a", tensor_type); + auto x = relay::VarNode::make("x", tensor_type); + auto add_op = relay::Op::Get("add"); + auto y = relay::CallNode::make(add_op, {c, c}); + y = relay::CallNode::make(add_op, {x, y}); + auto z = relay::CallNode::make(add_op, {y, c}); + auto z1 = relay::CallNode::make(add_op, {y, c}); + auto z2 = relay::CallNode::make(add_op, {z, z1}); + // Let expression and varaible a should be dead-code eliminated. + auto z3 = relay::LetNode::make(a, c, z2); + relay::Function func = + relay::FunctionNode::make(relay::FreeVars(z3), z3, relay::Type(), {}); + + // Get schedule + auto reg = tvm::runtime::Registry::Get("relay.op._Register"); + auto sch = tvm::runtime::Registry::Get("schedule"); + if (!reg || !sch) { + LOG(FATAL) << "Register/schedule is not defined."; + } + + (*reg)("add", "FTVMSchedule", *sch, 10); + + // Run sequential passes. 
+ tvm::Array pass_seqs{ + relay::transform::InferType(), + relay::transform::DeadCodeElimination(), + relay::transform::EliminateCommonSubexpr(), + relay::transform::AlterOpLayout() + }; + relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); + auto mod = relay::ModuleNode::FromExpr(func); + auto pass_ctx = relay::transform::PassContext::Create(); + pass_ctx->opt_level = 3; + pass_ctx->fallback_device = 1; + { + tvm::With ctx_scope(pass_ctx); + tvm::With tctx(tvm::Target::Create("llvm")); + mod = seq(mod); + } + + CHECK(mod.defined()); + auto entry_func = mod->entry_func; + CHECK(entry_func.defined()); + relay::Function f = mod->Lookup(entry_func->name_hint); + CHECK(f.defined()); + + // Expected function + auto c1 = relay::ConstantNode::make(c_data); + auto x1 = relay::VarNode::make("x", tensor_type); + auto y1 = relay::CallNode::make(add_op, {c1, c1}); + y1 = relay::CallNode::make(add_op, {x1, y1}); + auto zz = relay::CallNode::make(add_op, {y1, c1}); + zz = relay::CallNode::make(add_op, {zz, zz}); + relay::Function expected_func = + relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {}); + + // Infer type for the expected function. + auto expected = relay::InferType(expected_func, relay::Module(nullptr)); + CHECK(relay::AlphaEqual(f, expected)); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 2703e5ce1679..7fdef3fa8b9c 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -327,7 +327,8 @@ def test_no_pass(): def test_only_module_pass(): passes = [module_pass] sequential = _transform.Sequential(opt_level=1, passes=passes) - ret_mod = sequential(mod) + with relay.build_config(required_pass=["mod_transform"]): + ret_mod = sequential(mod) # Check the subtract function. sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint) check_func(new_sub, sub) @@ -341,7 +342,8 @@ def test_only_function_pass(): # Check the subtract function. passes = [function_pass] sequential = _transform.Sequential(opt_level=1, passes=passes) - ret_mod = sequential(mod) + with relay.build_config(required_pass=["func_transform"]): + ret_mod = sequential(mod) _, new_sub = extract_var_func(ret_mod, v_sub.name_hint) check_func(new_sub, get_ref_sub()) @@ -355,7 +357,9 @@ def test_multiple_passes(): mod = relay.Module({v_sub: sub, v_log: log}) passes = [module_pass, function_pass] sequential = _transform.Sequential(opt_level=1, passes=passes) - ret_mod = sequential(mod) + required = ["mod_transform", "func_transform"] + with relay.build_config(required_pass=required): + ret_mod = sequential(mod) # Check the abs function is added. 
     abs_var, abs_func = get_var_func()
@@ -400,7 +404,48 @@ def test_multiple_passes():
     test_multiple_passes()


+def test_sequential_with_scoping():
+    shape = (1, 2, 3)
+    c_data = np.array(shape).astype("float32")
+    tp = relay.TensorType(shape, "float32")
+    def before():
+        c = relay.const(c_data)
+        x = relay.var("x", tp)
+        y = relay.add(c, c)
+        y = relay.multiply(y, relay.const(2, "float32"))
+        y = relay.add(x, y)
+        z = relay.add(y, c)
+        z1 = relay.add(y, c)
+        z2 = relay.add(z, z1)
+        return relay.Function([x], z2)
+
+    def expected():
+        x = relay.var("x", tp)
+        c_folded = (c_data + c_data) * 2
+        y = relay.add(x, relay.const(c_folded))
+        z = relay.add(y, relay.const(c_data))
+        z1 = relay.add(z, z)
+        return relay.Function([x], z1)
+
+    seq = _transform.Sequential([
+        relay.transform.InferType(),
+        relay.transform.FoldConstant(),
+        relay.transform.EliminateCommonSubexpr(),
+        relay.transform.AlterOpLayout()
+    ])
+
+    mod = relay.Module({"main": before()})
+    with relay.build_config(opt_level=3):
+        with tvm.target.create("llvm"):
+            mod = seq(mod)
+
+    zz = mod["main"]
+    zexpected = ir_pass.infer_type(expected())
+    assert relay.ir_pass.alpha_equal(zz, zexpected)
+
+
 if __name__ == "__main__":
     test_module_pass()
     test_function_pass()
     test_sequential_pass()
+    test_sequential_with_scoping()
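
As a quick sanity check on the expected() graph in test_sequential_with_scoping: after FoldConstant collapses (c + c) * 2 into a single constant and EliminateCommonSubexpr merges the two identical add(y, c) nodes, before() and expected() should describe the same computation. The sketch below verifies that with plain NumPy only; it assumes nothing about TVM itself.

import numpy as np

shape = (1, 2, 3)
c = np.array(shape).astype("float32")        # the test's c_data, i.e. [1., 2., 3.]
x = np.random.rand(*shape).astype("float32")

# before(): z2 = (y + c) + (y + c), with y = x + (c + c) * 2
y = x + (c + c) * 2
z2 = (y + c) + (y + c)

# expected(): z1 = z + z, with z = (x + c_folded) + c and c_folded = (c + c) * 2
c_folded = (c + c) * 2
z = (x + c_folded) + c
z1 = z + z

assert np.allclose(z2, z1)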
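
On the pass_manager.cc side, the new PassEnabled() check means a Sequential now skips any pass whose registered opt_level is above the active PassContext unless that pass is listed in required_pass. Below is a minimal Python sketch of that behaviour, assuming the wrappers added in python/tvm/relay/transform.py and that FoldConstant is registered above opt_level 1; make_mod is an illustrative helper, not part of the patch.

import numpy as np
from tvm import relay
from tvm.relay import transform

def make_mod():
    tp = relay.TensorType((1, 2, 3), "float32")
    x = relay.var("x", tp)
    c = relay.const(np.ones((1, 2, 3), dtype="float32"))
    # add(c, c) is foldable; add(x, ...) is not.
    return relay.Module({"main": relay.Function([x], relay.add(x, relay.add(c, c)))})

seq = transform.Sequential([transform.InferType(), transform.FoldConstant()])

# FoldConstant sits above opt_level 1, so it should be skipped here.
with relay.build_config(opt_level=1):
    skipped = seq(make_mod())

# required_pass forces it on even though the opt_level is still 1.
with relay.build_config(opt_level=1, required_pass=["FoldConstant"]):
    forced = seq(make_mod())

print(skipped["main"])  # add(c, c) expected to survive
print(forced["main"])   # add(c, c) expected to be folded into a constant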
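
The passes registered here as relay._transform.* (InferType, SimplifyInference, ToANormalForm, ToGraphNormalForm) are plain Pass objects, so they can also be applied one at a time rather than only through a Sequential. A small sketch, assuming the matching Python wrappers in python/tvm/relay/transform.py:

from tvm import relay
from tvm.relay import transform

tp = relay.TensorType((1, 2, 3), "float32")
x = relay.var("x", tp)
mod = relay.Module({"main": relay.Function([x], relay.add(x, x))})

mod = transform.InferType()(mod)          # declared as required by SimplifyInference
mod = transform.SimplifyInference()(mod)  # expands inference-only ops such as batch_norm
mod = transform.ToANormalForm()(mod)      # the rewrite previously exposed as ir_pass.to_a_normal_form
print(mod["main"])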
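
The new gtest in tests/cpp/relay_transform_sequential.cc drives the same machinery that is reachable from Python, which is often easier to experiment with interactively. The sketch below rebuilds roughly the same graph (including the dead Let binding) and runs an equivalent Sequential under a scoped build_config and target; the graph and pass list follow the gtest, everything else is illustrative.

import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform

tp = relay.TensorType((1, 2, 3), "float32")
c = relay.const(np.random.rand(1, 2, 3).astype("float32"))
a = relay.var("a", tp)
x = relay.var("x", tp)

y = relay.add(x, relay.add(c, c))
z = relay.add(y, c)
z1 = relay.add(y, c)
z2 = relay.add(z, z1)
body = relay.Let(a, c, z2)            # `a` is never used, so the Let is dead code
func = relay.Function([x], body)

seq = transform.Sequential([
    transform.InferType(),
    transform.DeadCodeElimination(),       # drops the unused Let binding
    transform.EliminateCommonSubexpr(),    # merges the two identical add(y, c) nodes
    transform.AlterOpLayout()
])

mod = relay.Module({"main": func})
with relay.build_config(opt_level=3):
    with tvm.target.create("llvm"):        # the gtest likewise runs inside a target scope for AlterOpLayout
        mod = seq(mod)
print(mod["main"])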