From 289e8f8cd5cb39542d2ec2acef998d037f226d67 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 24 May 2019 23:47:26 +0000 Subject: [PATCH 01/12] [relay][transform] Migrate buildmodule to transform --- include/tvm/relay/pass.h | 100 ++++++ include/tvm/relay/transform.h | 144 ++++++++- python/tvm/relay/build_module.py | 94 +----- python/tvm/relay/transform.py | 232 ++++++++++++++ src/relay/backend/build_module.cc | 335 +++++++-------------- src/relay/pass/alter_op_layout.cc | 28 +- src/relay/pass/canonicalize_ops.cc | 18 ++ src/relay/pass/combine_parallel_conv2d.cc | 18 ++ src/relay/pass/dead_code.cc | 6 +- src/relay/pass/device_annotation.cc | 9 +- src/relay/pass/eliminate_common_subexpr.cc | 19 ++ src/relay/pass/fold_constant.cc | 9 +- src/relay/pass/fold_scale_axis.cc | 51 +++- src/relay/pass/forward_rewrite.cc | 6 +- src/relay/pass/fuse_ops.cc | 8 +- src/relay/pass/partial_eval.cc | 10 +- src/relay/pass/pass_manager.cc | 171 +++++++---- src/relay/pass/simplify_inference.cc | 18 ++ src/relay/pass/to_a_normal_form.cc | 6 +- src/relay/pass/to_graph_normal_form.cc | 6 +- src/relay/pass/type_infer.cc | 29 ++ tests/cpp/relay_transform_sequential.cc | 119 ++++++++ tests/python/relay/test_pass_manager.py | 41 +++ 23 files changed, 1086 insertions(+), 391 deletions(-) create mode 100644 tests/cpp/relay_transform_sequential.cc diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 67cc5df82407..067f21cbffdc 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -358,6 +358,15 @@ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); */ TVM_DLL Map CollectDeviceInfo(const Expr& expr); +/*! + * \brief Collect the device anntation operators. + * + * \param expr The expression. + * + * \return The annotated expression to device type mapping for annotation ops. + */ +TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); + /*! * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). * @@ -403,6 +412,97 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e); */ TVM_DLL Expr PartialEval(const Expr& e); +/*! + * \brief Bind the free variables to a Relay expression. + * + * \param expr The expression. + * \param bind_map The variable to expression map that will be used to help the + * binding. + * + * \return The updated expression. + */ +TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& bind_map); + +/*! + * \brief Simplify certain operators during inference. For example, batch norm + * will be unpacked into a number of simplified operators. + * + * \param expr The expression. + * + * \return The updated expression. + */ +TVM_DLL Expr SimplifyInference(const Expr& e); + +/*! + * \brief Search and eliminate common subexpression. For example, if there are + * two expressions evaluated to an identical value, a single variable is created + * and these two expressions are replaced by this variable. + * + * \param expr The expression. + * \param fskip The callback argument that allows to skip certain expressions. + * + * \return The updated expression. + */ +TVM_DLL Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc fskip); + +/*! + * \brief Combine parallel 2d convolutions into a single convolution if the + * number of branches of this conv2d operator is not less than + * `min_num_branch`. + * + * \param expr The expression. + * \param min_num_branch The minimun number of branches. + * + * \return The updated expression. 
+ */
+TVM_DLL Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches);
+
+namespace fold_scale_axis {
+
+/*!
+ * \brief Backward fold axis scaling into weights of conv/dense operators.
+ *
+ * \param expr The input expression.
+ *
+ * \return The updated expression.
+ */
+TVM_DLL Expr BackwardFoldScaleAxis(const Expr& expr);
+
+/*!
+ * \brief Forward fold axis scaling into weights of conv/dense operators.
+ *
+ * \param expr The input expression.
+ *
+ * \return The updated expression.
+ */
+TVM_DLL Expr ForwardFoldScaleAxis(const Expr& expr);
+
+}  // namespace fold_scale_axis
+
+/*!
+ * \brief Canonicalize some operators to the simplified operators. For example,
+ * bias_add can be canonicalized to expand_dims and broadcast_add.
+ *
+ * \param expr The input expression.
+ *
+ * \return The updated expression.
+ */
+TVM_DLL Expr CanonicalizeOps(const Expr& expr);
+
+namespace alter_op_layout {
+
+/*!
+ * \brief Alter the layouts of operators or replace primitive operators with
+ * other expressions.
+ *
+ * \param expr The input expression.
+ *
+ * \return The updated expression.
+ */
+TVM_DLL Expr AlterOpLayout(const Expr& expr);
+
+}  // namespace alter_op_layout
+
 /*! \brief A hashing structure in the style of std::hash. */
 struct StructuralHash {
   /*! \brief Hash a Relay type.
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 1c1b60813b78..09f6a684e913 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -292,9 +292,9 @@ class Sequential : public Pass {
    * \param passes The passes to apply.
    * \param pass_info The pass metadata.
    */
-  TVM_DLL Sequential(tvm::Array<Pass> passes,
-                     PassInfo pass_info);
-/*!
+  TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
+
+  /*!
    * \brief The constructor of `Sequential`.
    *
    * \param passes The passes to apply.
@@ -311,6 +311,52 @@ class Sequential : public Pass {
   using ContainerType = Sequential;
 };
 
+/*!
+ * \brief A class to save registered passes.
+ */
+class PassRegistry {
+ public:
+  /*
+   * \brief Get the global pass registry.
+   */
+  static PassRegistry& Global();
+
+  /*
+   * \brief Look up a pass in the registered pass map using the pass name.
+   *
+   * \param name The name of the pass to be looked up.
+   *
+   * \return The corresponding pass if it is found. Otherwise, an empty pass.
+   */
+  const Pass Lookup(const std::string& name) const;
+
+  /*
+   * \brief Look up a pass in the registered pass map using the pass itself.
+   *
+   * \param pass The pass whose name is used for the lookup.
+   *
+   * \return The corresponding pass if it is found. Otherwise, an empty pass.
+   */
+  const Pass Lookup(const Pass& pass) const;
+
+  /*
+   * \brief Register a pass.
+   *
+   * \param pass The pass to be registered.
+   *
+   * \return The registered pass, or the previously registered pass that
+   * shares its name.
+   */
+  const Pass RegisterPass(const Pass& pass);
+
+ private:
+  PassRegistry() = default;
+  ~PassRegistry() = default;
+
+  PassRegistry(const PassRegistry&) = delete;
+  PassRegistry(PassRegistry&&) = delete;
+  PassRegistry& operator=(const PassRegistry&) = delete;
+  PassRegistry& operator=(PassRegistry&&) = delete;
+
+  std::unordered_map<std::string, Pass> registered_pass_map_;
+};
 
 /*
  * \brief Create a module pass.
@@ -451,6 +497,98 @@ TVM_DLL Pass ToGraphNormalForm();
  */
 TVM_DLL Pass PartialEval();
 
+/*!
+ * \brief Simplify certain operators during inference. For example, batch norm
+ * will be unpacked into a number of simplified operators.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass SimplifyInference();
+
+/*!
+ * \brief Infer the type of an expression.
+ *
+ * The result of type checking is a new expression with unambiguous
+ * type information filled in, as well as its checked_type field
+ * populated with the result type.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass InferType();
+
+/*!
+ * \brief Infer the type of an expression.
+ *
+ * The result of type checking is a new expression with unambiguous
+ * type information filled in, as well as its checked_type field
+ * populated with the result type.
+ *
+ * \param var The global variable corresponding to the function being optimized.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass InferType(const GlobalVar& var);
+
+/*!
+ * \brief Search for and eliminate common subexpressions. For example, if there
+ * are two expressions that evaluate to an identical value, a single variable
+ * is created and these two expressions are replaced by this variable.
+ *
+ * \param fskip The callback used to skip certain expressions.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip);
+
+/*!
+ * \brief Combine parallel 2d convolutions into a single convolution if the
+ * number of branches of this conv2d operator is not less than
+ * `min_num_branches`.
+ *
+ * \param min_num_branches The minimum number of branches.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
+
+/*!
+ * \brief Backward fold axis scaling into weights of conv/dense operators.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass BackwardFoldScaleAxis();
+
+/*!
+ * \brief Forward fold axis scaling into weights of conv/dense operators.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass ForwardFoldScaleAxis();
+
+/*!
+ * \brief A sequential pass that executes the ForwardFoldScaleAxis and
+ * BackwardFoldScaleAxis passes.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass FoldScaleAxis();
+
+/*!
+ * \brief Canonicalize some operators to the simplified operators. For example,
+ * bias_add can be canonicalized to expand_dims and broadcast_add.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CanonicalizeOps();
+
+/*!
+ * \brief Alter the layouts of operators or replace primitive operators
+ * with other expressions.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass AlterOpLayout();
+
 }  // namespace transform
 }  // namespace relay
 }  // namespace tvm
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 6cee393d5f91..8f9b0481a22c 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -20,7 +20,6 @@
 """
 import numpy as np
 
-from tvm._ffi.runtime_ctypes import TVMContext
 from tvm import expr as tvm_expr
 from .. import nd as _nd, target as _target, autotvm
 from ..contrib import graph_runtime as _graph_rt
@@ -28,7 +27,6 @@
 from . import ir_pass
 from . import ty as _ty
 from . import expr as _expr
-from . 
import transform as _transform from .backend import interpreter as _interpreter from .backend.vm import VMExecutor @@ -61,10 +59,6 @@ def __init__(self): self._get_graph_json = self.mod["get_graph_json"] self._get_module = self.mod["get_module"] self._build = self.mod["build"] - self._add_pass = self.mod["add_pass"] - self._disable_pass = self.mod["disable_pass"] - self._set_opt_level = self.mod["set_opt_level"] - self._set_fallback_device = self.mod["set_fallback_device"] self._set_params_func = self.mod["set_params"] self._get_params_func = self.mod["get_params"] @@ -106,8 +100,9 @@ def build(self, func, target=None, target_host=None, params=None): """ target = _update_target(target) - # Setup the build configurations passed in through `with build_config`. - self._setup_build_config(params) + # Setup the params. + if params: + self._set_params(params) # Build the function self._build(func, target, target_host) # Get artifacts @@ -117,41 +112,6 @@ def build(self, func, target=None, target_host=None, params=None): return graph_json, mod, params - def _setup_build_config(self, params): - cfg = _transform.PassContext.current() - - # Set opt_level. - self.set_opt_level(cfg.opt_level) - - # Set fallback device if it is available. - if cfg.fallback_device: - self.set_fallback_device(cfg.fallback_device) - - # Add required passes. - if cfg.required_pass: - passes = set() - if isinstance(cfg.required_pass, (list, tuple, set)): - passes = set(cfg.required_pass) - else: - raise TypeError("add_pass must be list, tuple, or set, but " + - "got {}".format(type(cfg.required_pass))) - for pass_name in passes: - self.add_pass(pass_name) - - # Add disabled passes. - if cfg.disabled_pass: - passes = set() - if isinstance(cfg.disabled_pass, (list, tuple, set)): - passes = set(cfg.disabled_pass) - else: - raise TypeError("disable_pass must be list, tuple, or set, " + - "but got {}".format(type(cfg.disabled_pass))) - for pass_name in passes: - self.disable_pass(pass_name) - - if params: - self._set_params(params) - def _set_params(self, params): inputs = {} for name, param in params.items(): @@ -160,28 +120,6 @@ def _set_params(self, params): inputs[name] = _expr.const(param) self._set_params_func(inputs) - def add_pass(self, pass_name): - """Add a pass to the pass list. - - Parameters - ---------- - pass_name : str - The name of the pass that will be added to the list of passes used - for optimizations. - """ - self._add_pass(pass_name) - - def disable_pass(self, pass_name): - """Add a pass to the disabled pass list. - - Parameters - ---------- - pass_name : str - The name of a pass. This pass will be added to the list of passes - that are disabled during optimization. - """ - self._disable_pass(pass_name) - def get_json(self): """Return the json file of the built program.""" return self._get_graph_json() @@ -198,32 +136,6 @@ def get_params(self): ret[key] = value.data return ret - def set_opt_level(self, level): - """Set the optimization level. - - Parameters - ---------- - level : int - The optimization level for build. - """ - self._set_opt_level(level) - - def set_fallback_device(self, fallback_device): - """Set the fallback device for heterogeneous execution. - - Parameters - ---------- - fallback_device : str or tvm.TVMContext - The fallback device used for heterogeneous execution. 
- """ - if isinstance(fallback_device, (int, str)): - fallback_device = _nd.context(fallback_device) - if not isinstance(fallback_device, TVMContext): - raise TypeError("fallback_device is expected to be str, int, or " + - "TVMContext but received: {}".format(type(fallback_device))) - - self._set_fallback_device(fallback_device.device_type) - def build(func, target=None, target_host=None, params=None): """Helper function that builds a Relay function to run on TVM graph diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index a7887c630c76..a4f277026558 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -394,3 +394,235 @@ def create_function_pass(pass_func): if pass_func: return create_function_pass(pass_func) return create_function_pass + + +def infer_type(): + """Infer the type of an expr. + + Returns + ------- + ret : tvm.relay.Pass + The registered type inference pass. + """ + return _transform.InferType() + + +def backward_fold_scale_axis(): + """Backward fold axis scaling into weights of conv2d/dense. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass to backward fold expressions. + + Note + ---- + It is recommended to call backward_fold_scale_axis before using + forward_fold_scale_axis. As backward folding targets common conv-bn + pattern. + """ + return _transform.BackwardFoldScaleAxis() + + +def forward_fold_scale_axis(): + """Fold the scaling of axis into weights of conv2d/dense. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass to forward fold expressions. + + Note + ---- + It is recommended to call backward_fold_scale_axis before using + forward_fold_scale_axis. As backward folding targets common conv-bn + pattern. + """ + return _transform.ForwardFoldScaleAxis() + + +def fold_scale_axis(): + """Fold the scaling of axis into weights of conv2d/dense. This pass will + invoke both forward and backward scale folding. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass to fold expressions. + + Note + ---- + It is recommended to call backward_fold_scale_axis before using + forward_fold_scale_axis. As backward folding targets common conv-bn + pattern. + """ + return _transform.FoldScaleAxis() + + +def simplify_inference(): + """Simplify the data-flow graph for inference phase. An simplified expression + which is semantically equal to the input expression will be returned. + + Returns + ------- + ret: tvm.relay.Pass + The registered to perform operator simplification. + """ + return _transform.SimplifyInference() + + +def canonicalize_ops(): + """ Canonicalize special operators to basic operators. + This can simplify followed analysis. (e.g. expanding bias_add to + expand_dims and broadcast_add.) + + Returns + ------- + ret: tvm.relay.Pass + The registered pass performing the canonicalization. + """ + return _transform.CanonicalizeOps() + + +def dead_code_elimination(): + """ Remove expressions which does not effect the program result (dead code). + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that eliminates the dead code in a Relay program. + """ + return _transform.DeadCodeElimination() + + +def fold_constant(): + """Fold the constant expression in expr. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass for constant folding. + """ + return _transform.FoldConstant() + + +def fuse_ops(fuse_opt_level=-1): + """Fuse operators in an expr to a larger operator according to some rules. 
+
+    Parameters
+    ----------
+    fuse_opt_level : int
+        The level of fuse optimization. -1 indicates that the level will be
+        inferred from the pass context.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass for operator fusion.
+    """
+    return _transform.FuseOps(fuse_opt_level)
+
+
+def combine_parallel_conv2d(min_num_branches=3):
+    """Combine multiple conv2d operators into one.
+
+    Parameters
+    ----------
+    min_num_branches : int
+        The minimum number of required parallel branches for performing this
+        optimization.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass that combines parallel conv2d operators.
+    """
+    return _transform.CombineParallelConv2D(min_num_branches)
+
+
+def alter_op_layout():
+    """Alter the layouts of operators or replace primitive operators with
+    other expressions.
+    This pass can be used for computing convolution in custom layouts or
+    other general weight pre-transformation.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that alters the layout of operators.
+    """
+    return _transform.AlterOpLayout()
+
+
+def rewrite_annotated_ops(fallback_device):
+    """Rewrite the annotated program where annotation operators, e.g.
+    `on_device`, mark which device an expression should be scheduled to.
+    This pass helps heterogeneous execution where different operators may need
+    to be allocated on various devices.
+
+    Parameters
+    ----------
+    fallback_device : int
+        The fallback device type. It is also used as the default device for
+        operators without an annotated device.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass that rewrites an expression with annotated
+        `on_device` operators.
+    """
+    return _transform.RewriteDeviceAnnotation(fallback_device)
+
+
+def to_a_normal_form():
+    """Turn a Graph Normal Form expression into an A Normal Form expression.
+    The scope of the root expression is the global scope.
+    The scope of any non-root expression is the least common ancestor of all
+    of its scopes.
+    Values are ordered by post-DFS order in each scope.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass that transforms an expression into A Normal Form.
+    """
+    return _transform.ToANormalForm()
+
+
+def to_graph_normal_form():
+    """Turn an A Normal Form expression into a Graph Normal Form expression.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that transforms an expression into Graph Normal Form.
+    """
+    return _transform.ToGraphNormalForm()
+
+
+def eliminate_common_subexpr(fskip=None):
+    """Eliminate common subexpressions.
+
+    Parameters
+    ----------
+    fskip: Callable
+        The callback function that decides whether an expression should be
+        skipped.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that eliminates common subexpressions.
+    """
+    return _transform.EliminateCommonSubexpr(fskip)
+
+
+def partial_evaluate():
+    """Evaluate the static fragment of the code.
+
+    Returns
+    -------
+    ret : tvm.relay.Pass
+        The registered pass that performs partial evaluation on an expression.
+    """
+    return _transform.PartialEval()
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 57dc256ef6b7..d76f9c786a87 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -25,10 +25,8 @@
 #include 
 #include 
 #include 
-#include 
 #include 
-#include 
-#include 
+#include <tvm/relay/transform.h>
 #include 
 
 #include "utils.h"
@@ -38,39 +36,7 @@ namespace relay {
 namespace backend {
 
 using TargetsMap = Map<tvm::Integer, tvm::Target>;
-
-/*!
- * \brief A data structure to map the names of specific optimizations to - * numeric optimization levels - * - */ -struct OptPassLevel { - static const std::unordered_map _data; - /*! - * \brief Get level for an optimization pass - * - * \param key pass name - * \return int level - */ - int operator[](const std::string& key) const { - auto it = _data.find(key); - if (it == _data.end()) { - return -1; - } - return it->second; - } -}; - -const std::unordered_map OptPassLevel::_data = { - {"SimplifyInference", 0}, - {"OpFusion", 1}, - {"FoldConstant", 2}, - {"CombineParallelConv2D", 4}, - {"FoldScaleAxis", 3}, - {"AlterOpLayout", 3}, - {"CanonicalizeOps", 3}, - {"EliminateCommonSubexpr", 3} -}; +using namespace tvm::relay::transform; /*! * \brief Output of building module @@ -82,27 +48,6 @@ struct BuildOutput { std::unordered_map params; }; -/*! - * \brief Relay building config - * - */ -struct RelayBuildConfig { - int opt_level{2}; - int fallback_device{static_cast(kDLCPU)}; - std::unordered_set enabled_pass; - std::unordered_set disabled_pass; - OptPassLevel OPT_PASS_LEVEL; - inline bool pass_enabled(const std::string& pass_name) const { - if (disabled_pass.count(pass_name)) { - return false; - } - if (enabled_pass.count(pass_name)) { - return true; - } - return opt_level >= OPT_PASS_LEVEL[pass_name]; - } -}; - /*! * \brief GraphCodegen module wrapper * @@ -156,18 +101,6 @@ struct GraphCodegen { } }; -template -R CallPackedFunc(const std::string &name, Args... args) { - auto pf = GetPackedFunc(name); - return (*pf)(std::forward(args)...); -} - -template -Function CallPackedFunc(const std::string &name, Args... args) { - auto pf = GetPackedFunc(name); - return (*pf)(std::forward(args)...); -} - /*! * \brief Relay build module * @@ -203,28 +136,6 @@ class RelayBuildModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); }); - } else if (name == "set_opt_level") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 1); - int level = args[0]; - this->SetOptLevel(level); - }); - } else if (name == "set_fallback_device") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 1); - int dev = args[0]; - this->SetFallBackDev(dev); - }); - } else if (name == "add_pass") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::string pass_name = args[0]; - this->AddPass(pass_name); - }); - } else if (name == "disable_pass") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::string pass_name = args[0]; - this->DisablePass(pass_name); - }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Map params = args[0]; @@ -246,30 +157,7 @@ class RelayBuildModule : public runtime::ModuleNode { const std::string& GetGraphJSON() { return ret_.graph_json; } - /*! - * \brief Add extra pass into build cfg - * - * \param pass_name name of pass - */ - void AddPass(const std::string& pass_name) { - cfg_.enabled_pass.insert(pass_name); - } - /*! - * \brief Disable a specific pass in cfg - * - * \param pass_name name of pass - */ - void DisablePass(const std::string& pass_name) { - cfg_.disabled_pass.insert(pass_name); - } - /*! - * \brief Set the Fallback device - * - * \param device name - */ - void SetFallBackDev(int dev) { - cfg_.fallback_device = dev; - } + /*! 
* \brief Get the Module object * @@ -315,15 +203,6 @@ class RelayBuildModule : public runtime::ModuleNode { params_[name] = data_in; } - /*! - * \brief Set the optimization level - * - * \param level - */ - void SetOptLevel(char level) { - cfg_.opt_level = level; - } - /*! * \brief type key * @@ -345,7 +224,7 @@ class RelayBuildModule : public runtime::ModuleNode { const tvm::Target& target_host) { targets_ = targets; target_host_ = target_host; - BuildRelay(func, cfg_, params_); + BuildRelay(func, params_); } protected: @@ -378,85 +257,95 @@ class RelayBuildModule : public runtime::ModuleNode { if (repeat_var.count(arg)) { LOG(FATAL) << "Multiple args in the function have name " << kv.first; } - auto e = CallPackedFunc("relay._make.Constant", kv.second); - bind_dict[arg] = e; + bind_dict[arg] = ConstantNode::make(kv.second); } - return CallPackedFunc("relay._expr.Bind", func, tvm::Map(bind_dict)); + Expr bound_expr = relay::Bind(func, bind_dict); + Function ret = Downcast(bound_expr); + CHECK(ret.defined()) + << "The returning type is expected to be a Relay Function." + << "\n"; + return ret; } /*! - * \brief Optimize Relay function + * \brief Optimize a Relay function. * - * \param func Input function - * \param target target device - * \param cfg Relay build config - * \param params params dict - * \return relay::Function + * \param func The input Relay function for optimization. + * \param targets The device type to `Target` mapping. + * \param params The param name to value mapping. + * + * \return func The updated Relay function after optimization. */ - relay::Function Optimize(relay::Function func, - const TargetsMap& targets, - const RelayBuildConfig& cfg, - const std::unordered_map& params) { + relay::Function Optimize( + relay::Function func, + const TargetsMap& targets, + const std::unordered_map& params) { if (params.size()) { func = BindParamsByName(func, params); } - if (cfg.pass_enabled("SimplifyInference")) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.simplify_inference", func); - } - if (cfg.pass_enabled("EliminateCommonSubexpr")) { - auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == HalideIR::Int(32)) { - *rv = true; - } + + relay::Module relay_module = relay::ModuleNode::FromExpr(func); + relay_module = Optimize(relay_module, targets_, params); + CHECK(relay_module.defined()); + GlobalVar var = relay_module->entry_func; + return relay_module->Lookup(var->name_hint); + } + + /*! + * \brief Optimize a Relay module. + * + * \param relay_module The input Relay module where optmization will be + * applied on. + * \param targets The device type to `Target` mapping. + * \param params The param name to value mapping. + * + * \return relay::Module The updated Relay module after optimization. + */ + relay::Module Optimize( + relay::Module relay_module, + const TargetsMap& targets, + const std::unordered_map& params) { + Array pass_seqs; + pass_seqs.push_back(transform::SimplifyInference()); + + // Can we move to the pass implementation file and make it as default? 
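+    // This callback returns true for `cast` calls that produce int32, so
+    // that EliminateCommonSubexpr skips them instead of deduplicating them.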
+ auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") { + auto attrs = call_node->attrs.as(); + if (attrs->dtype == HalideIR::Int(32)) { + *rv = true; } } - *rv = false; - }); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, fskip); - } - if (cfg.pass_enabled("CombineParallelConv2D")) { - const int min_num_branches = 3; - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.CombineParallelConv2D", func, min_num_branches); - } - if (cfg.pass_enabled("FoldConstant")) { - func = CallPackedFunc("relay._ir_pass.FoldConstant", func); - } - if (cfg.pass_enabled("FoldScaleAxis")) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.backward_fold_scale_axis", func); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.forward_fold_scale_axis", func); - func = CallPackedFunc("relay._ir_pass.FoldConstant", func); - } - if (cfg.pass_enabled("CanonicalizeOps")) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.canonicalize_ops", func); - } - if (cfg.pass_enabled("AlterOpLayout")) { - if (targets.size() == 1) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - for (const auto& kv : targets) { - With tctx(kv.second); - func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); - } - } else { - LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous" - << " execution yet."; } + *rv = false; + }); + + pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::CombineParallelConv2D(3)); + pass_seqs.push_back(transform::FoldScaleAxis()); + pass_seqs.push_back(transform::CanonicalizeOps()); + + // Alter layout transformation is only applied to homogeneous execution yet. + if (targets.size() == 1) { + pass_seqs.push_back(transform::AlterOpLayout()); } - if (cfg.pass_enabled("FoldConstant")) { - func = CallPackedFunc("relay._ir_pass.FoldConstant", func); + pass_seqs.push_back(transform::FoldConstant()); + + // Create a sequential pass and perform optimizations. + transform::Pass seq = transform::Sequential(pass_seqs); + if (targets.size() == 1) { + for (const auto& kv : targets) { + With tctx(kv.second); + relay_module = seq->operator()(relay_module); + } + } else { + relay_module = seq->operator()(relay_module); } - return func; + return relay_module; } /*! @@ -477,20 +366,20 @@ class RelayBuildModule : public runtime::ModuleNode { * dictionary in this case. * * \param targets dictionary - * \param cfg + * \param fallback_device The fallback device for heterogeneous execution. 
* \return Map */ TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets, - const RelayBuildConfig& cfg) { + int fallback_device) { TargetsMap device_target = targets; std::unordered_map tmp_map; for (const auto& kv : targets) { tmp_map[kv.first->value] = kv.second; } - if (tmp_map.count(cfg.fallback_device) == 0) { + if (tmp_map.count(fallback_device) == 0) { device_target.Set( - cfg.fallback_device, - CreateDefaultTarget(cfg.fallback_device)); + fallback_device, + CreateDefaultTarget(fallback_device)); } return device_target; } @@ -498,25 +387,24 @@ class RelayBuildModule : public runtime::ModuleNode { * \brief Execute the device annotation passes to update the input program and * target information. * - * \param func - * \param cfg - * \param targets_map_ptr - * \return Function + * \param func The input Relay function. + * \param fallback_device The fallback device for heterogeneous execution. + * \param targets_map_ptr The device type to `Target` map pointer. + * + * \return func The updated function after device annotation. */ Function RunDeviceAnnotationPass(Function func, - const RelayBuildConfig& cfg, + int fallback_device, TargetsMap* targets_map_ptr) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, - cfg.fallback_device); - auto device_map = CallPackedFunc >( - "relay._ir_pass.CollectDeviceInfo", func, nullptr); + auto new_func = relay::InferType(func, Module(nullptr)); + new_func = relay::RewriteAnnotatedOps(new_func, fallback_device); + func = Downcast(new_func); + CHECK(func.defined()); + auto device_map = relay::CollectDeviceInfo(func); if (device_map.size() == 0) { - auto annotation_map = CallPackedFunc >( - "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr); + auto annotation_map = relay::CollectDeviceAnnotationOps(func); if (annotation_map.size() == 0) { - targets_map_ptr->Set( - 0, CreateDefaultTarget(cfg.fallback_device)); + targets_map_ptr->Set(0, CreateDefaultTarget(fallback_device)); } else { int64_t dev_type = -1; for (auto kv : annotation_map) { @@ -541,28 +429,33 @@ class RelayBuildModule : public runtime::ModuleNode { * \brief Build relay function to runtime module * * \param func Relay Function - * \param cfg Relay build config * \param params parameters */ - void BuildRelay(Function func, - const RelayBuildConfig& cfg, - const std::unordered_map ¶ms) { + void BuildRelay( + Function func, + const std::unordered_map& params) { + transform::PassContext pass_ctx = PassContext::Current(); + // convert tvm_cfg_ = BuildConfig::Create(); TargetsMap device_target; if (targets_.size() > 1) { - device_target = UpdateHeterogeneousInputs(targets_, cfg); + device_target = + UpdateHeterogeneousInputs(targets_, pass_ctx->fallback_device); } else { device_target = targets_; } - func = Optimize(func, targets_, cfg, params); + func = Optimize(func, targets_, params); if (device_target.size() > 1) { - func = RunDeviceAnnotationPass(func, cfg, &device_target); + func = RunDeviceAnnotationPass(func, pass_ctx->fallback_device, + &device_target); } - // TODO(@jroesch): use the passes directly. 
- func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.FuseOps", func, cfg.opt_level, nullptr); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + + auto new_func = relay::InferType(func, Module(nullptr)); + new_func = relay::FuseOps(new_func, pass_ctx->opt_level, Module(nullptr)); + new_func = relay::InferType(new_func, Module(nullptr)); + func = Downcast(new_func); + CHECK(func.defined()); graph_codegen_ = std::unique_ptr(new GraphCodegen()); graph_codegen_->Init(nullptr, device_target); @@ -580,8 +473,6 @@ class RelayBuildModule : public runtime::ModuleNode { TargetsMap targets_; /*! \brief target host device */ tvm::Target target_host_; - /*! \brief frontend optimization configure */ - RelayBuildConfig cfg_; /*! \brief parameters */ std::unordered_map params_; /*! \brief building output */ diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index f51c201d0b2a..213bc03388bc 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -338,17 +339,36 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, // Limiations: // 1. the altered op should have the same number of arguments as the previous one // 2. do not support nested tuple arguments -TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") -.set_body([](TVMArgs args, TVMRetValue *ret) { +Expr AlterOpLayout(const Expr& expr) { TransformMemorizer transformMemorizer(make_node()); auto fcontext = [&](const Call& call) -> NodeRef{ return transformMemorizer; }; - *ret = ForwardRewrite(args[0], AlterOpLayoutRewrite, fcontext); -}); + return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext); +} + +TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") +.set_body_typed(AlterOpLayout); } // namespace alter_op_layout +namespace transform { + +Pass AlterOpLayout() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(relay::alter_op_layout::AlterOpLayout(f)); + }; + Pass pass = CreateFunctionPass(pass_func, 3, "alter_op_layout", + {ir::StringImm::make("infer_type")}); + return PassRegistry::Global().RegisterPass(pass); +} + +TVM_REGISTER_API("relay._transform.AlterOpLayout") +.set_body_typed(AlterOpLayout); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index 9a4602750195..8f94485894d5 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include "pattern_util.h" namespace tvm { @@ -63,5 +64,22 @@ Expr CanonicalizeOps(const Expr& e) { TVM_REGISTER_API("relay._ir_pass.canonicalize_ops") .set_body_typed(CanonicalizeOps); +namespace transform { + +Pass CanonicalizeOps() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(CanonicalizeOps(f)); + }; + Pass pass = CreateFunctionPass(pass_func, 3, "canonicalize_ops", + {ir::StringImm::make("infer_type")}); + return PassRegistry::Global().RegisterPass(pass); +} + +TVM_REGISTER_API("relay._transform.CanonicalizeOps") +.set_body_typed(CanonicalizeOps); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 7e76322d5a2a..2c3800d60f6e 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc 
+++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include "./expr_subst.h" @@ -357,5 +358,22 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) { TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D") .set_body_typed(CombineParallelConv2D); +namespace transform { + +Pass CombineParallelConv2D(uint64_t min_num_branches) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(CombineParallelConv2D(f, min_num_branches)); + }; + Pass pass = CreateFunctionPass(pass_func, 4, "combine_parallel_conv2d", + {ir::StringImm::make("infer_type")}); + return PassRegistry::Global().RegisterPass(pass); +} + +TVM_REGISTER_API("relay._transform.CombineParallelConv2D") +.set_body_typed(CombineParallelConv2D); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index dd1ed6240cab..0766743e5c49 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -158,9 +158,13 @@ Pass DeadCodeElimination() { [=](Function f, Module m, PassContext pc) { return Downcast(DeadCodeElimination(f)); }; - return CreateFunctionPass(pass_func, 1, "dead_code_elimination", {}); + Pass dec = CreateFunctionPass(pass_func, 1, "dead_code_elimination", {}); + return PassRegistry::Global().RegisterPass(dec); } +TVM_REGISTER_API("relay._transform.DeadCodeElimination") +.set_body_typed(DeadCodeElimination); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index fa656dbf489e..21da33f51901 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -35,6 +35,7 @@ #include #include #include +#include #include #include @@ -557,11 +558,15 @@ Pass RewriteAnnotatedOps(int fallback_device) { [=](Function f, Module m, PassContext pc) { return Downcast(RewriteAnnotatedOps(f, fallback_device)); }; - return CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", {}); + Pass pass = CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", + {ir::StringImm::make("infer_type")}); + return PassRegistry::Global().RegisterPass(pass); } +TVM_REGISTER_API("relay._transform.RewriteDeviceAnnotation") +.set_body_typed(RewriteAnnotatedOps); + } // namespace transform } // namespace relay } // namespace tvm - diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index f8432f671855..3deebdf37c00 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -29,6 +29,7 @@ */ #include #include +#include #include #include "./pattern_util.h" @@ -87,5 +88,23 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) { TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr") .set_body_typed(EliminateCommonSubexpr); +namespace transform { + +Pass EliminateCommonSubexpr(PackedFunc fskip) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(EliminateCommonSubexpr(f, fskip)); + }; + Pass pass = CreateFunctionPass(pass_func, 3, "eliminate_common_subexpr", + {ir::StringImm::make("infer_type")}); + PassRegistry::Global().RegisterPass(pass); + return pass; +} + +TVM_REGISTER_API("relay._transform.EliminateCommonSubexpr") +.set_body_typed(EliminateCommonSubexpr); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git 
a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 286392ab5d3f..036533265d8a 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -26,6 +26,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -220,11 +221,15 @@ namespace transform { Pass FoldConstant() { runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { - return Downcast(FoldConstant(f)); + return Downcast(FoldConstant(f)); }; - return CreateFunctionPass(pass_func, 1, "fold_constant", {}); + Pass pass = CreateFunctionPass(pass_func, 2, "fold_constant", {}); + return PassRegistry::Global().RegisterPass(pass); } +TVM_REGISTER_API("relay._transform.FoldConstant") +.set_body_typed(FoldConstant); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index c738e3e3b731..db9340c93d3b 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "pattern_util.h" #include "pass_util.h" @@ -530,7 +531,7 @@ RELAY_REGISTER_OP("nn.conv2d") .set_attr("FScaleAxisForwardRewrite", Conv2DForwardRewrite); -Expr ForwardFoldScaleAxis(Expr data) { +Expr ForwardFoldScaleAxis(const Expr& data) { auto message = ForwardPrep().Prepare(data); auto fcontext = [&](const Call& call) -> NodeRef{ auto it = message.find(call.get()); @@ -942,7 +943,7 @@ RELAY_REGISTER_OP("nn.conv2d") RELAY_REGISTER_OP("nn.conv2d") .set_attr("FScaleAxisBackwardTransform", Conv2DBackwardTransform); -Expr BackwardFoldScaleAxis(Expr data) { +Expr BackwardFoldScaleAxis(const Expr& data) { return make_node()->Fold(data); } @@ -950,5 +951,51 @@ TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis") .set_body_typed(BackwardFoldScaleAxis); } // namespace fold_scale_axis + +namespace transform { + +Pass ForwardFoldScaleAxis() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast( + relay::fold_scale_axis::ForwardFoldScaleAxis(f)); + }; + Pass pass = CreateFunctionPass(pass_func, 3, "forward_fold_scale_axis", + {ir::StringImm::make("infer_type")}); + return PassRegistry::Global().RegisterPass(pass); +} + +Pass BackwardFoldScaleAxis() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast( + relay::fold_scale_axis::BackwardFoldScaleAxis(f)); + }; + Pass pass = CreateFunctionPass(pass_func, 3, "backward_fold_scale_axis", + {ir::StringImm::make("infer_type")}); + return PassRegistry::Global().RegisterPass(pass); +} + +Pass FoldScaleAxis() { + // FoldScaleAxis pass contains the following three passes. Therefore, we can + // register it as a sequential pass. 
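+  // Backward folding is scheduled before forward folding, as it targets the
+  // common conv-bn pattern (see the fold-scale-axis docstrings in
+  // python/tvm/relay/transform.py).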
+ Pass pass = Sequential( + {FoldConstant(), BackwardFoldScaleAxis(), ForwardFoldScaleAxis()}, + "fold_scale_axis"); + + return PassRegistry::Global().RegisterPass(pass); +} + +TVM_REGISTER_API("relay._transform.ForwardFoldScaleAxis") +.set_body_typed(ForwardFoldScaleAxis); + +TVM_REGISTER_API("relay._transform.BackwardFoldScaleAxis") +.set_body_typed(BackwardFoldScaleAxis); + +TVM_REGISTER_API("relay._transform.FoldScaleAxis") +.set_body_typed(FoldScaleAxis); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index 2a3aa1612418..c915c604d722 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -220,7 +220,8 @@ Pass ForwardRewrite(const std::string& rewrite_map_attr_name, fcontext, fmulti_ref_trigger)); }; - return CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); + Pass pass = CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); + return PassRegistry::Global().RegisterPass(pass); } Pass ForwardRewrite(const FForwardRewrite& rewrite_func, @@ -233,7 +234,8 @@ Pass ForwardRewrite(const FForwardRewrite& rewrite_func, fcontext, fmulti_ref_trigger)); }; - return CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); + Pass pass = CreateFunctionPass(pass_func, 1, "forward_rewrite_fun", {}); + return PassRegistry::Global().RegisterPass(pass); } } // namespace transform diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 9277689075c2..1422ab91fa9c 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "./pattern_util.h" #include "../../common/arena.h" @@ -973,9 +974,14 @@ Pass FuseOps(int fuse_opt_level) { int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; return Downcast(FuseOps(f, opt_level, m)); }; - return CreateFunctionPass(pass_func, 1, "fuse_ops", {}); + Pass pass = CreateFunctionPass(pass_func, 1, "fuse_ops", + {ir::StringImm::make("infer_type")}); + return PassRegistry::Global().RegisterPass(pass); } +TVM_REGISTER_API("relay._transform.FuseOps") +.set_body_typed(FuseOps); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 3f42c6fce4b2..78b2b217df04 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -797,9 +797,7 @@ Expr PartialEval(const Expr& e) { } TVM_REGISTER_API("relay._ir_pass.partial_evaluate") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = PartialEval(args[0]); - }); +.set_body_typed(PartialEval); namespace transform { @@ -808,9 +806,13 @@ Pass PartialEval() { [=](Function f, Module m, PassContext pc) { return Downcast(PartialEval(f)); }; - return CreateFunctionPass(pass_func, 1, "partial_eval", {}); + Pass pass = CreateFunctionPass(pass_func, 1, "partial_eval", {}); + return PassRegistry::Global().RegisterPass(pass); } +TVM_REGISTER_API("relay._transform.PartialEval") +.set_body_typed(PartialEval); + } // namespace transform } // namespace relay diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index a9c671aa163a..1525a0744abb 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -37,42 +37,69 @@ namespace transform { using tvm::IRPrinter; -/*! - * \brief A data structure to map the names of specific optimizations to - * numeric optimization levels - */ -class OptPassLevel { - public: - /*! 
- * \brief Get level for an optimization pass - * - * \param key pass name - * \return int level - */ - int operator[](const std::string& key) const { - const auto data = CreateMap(); - auto it = data.find(key); - if (it == data.end()) { - return -1; +namespace { + +// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be +// handled because we need to register the pass for Python invocation anyway. +Pass GetPass(const std::string& pass_name) { + if (pass_name == "infer_type") { + return InferType(); + } else if (pass_name == "simplify_inference") { + return SimplifyInference(); + } else if (pass_name == "alter_op_layout") { + return AlterOpLayout(); + } else if (pass_name == "canonicalize_ops") { + return CanonicalizeOps(); + } else if (pass_name == "combine_parallel_conv2d") { + return CombineParallelConv2D(); + } else if (pass_name == "fold_constant") { + return FoldConstant(); + } else if (pass_name == "fold_scale_axis") { + return FoldScaleAxis(); + } else if (pass_name == "to_a_normal_form") { + return ToANormalForm(); + } else if (pass_name == "to_graph_normal_form") { + return ToGraphNormalForm(); + } else { + LOG(FATAL) << pass_name << " has not been registered yet." << "\n"; + return Pass(nullptr); + } +} + +} // namespace + +PassRegistry& PassRegistry::Global() { + static PassRegistry registry; + return registry; +} + +const Pass PassRegistry::Lookup(const std::string& name) const { + auto it = registered_pass_map_.find(name); + if (it == registered_pass_map_.end()) { + return Pass(nullptr); } return it->second; - } +} - private: - static const std::unordered_map CreateMap() { - const std::unordered_map m = { - {"SimplifyInference", 0}, - {"OpFusion", 1}, - {"FoldConstant", 2}, - {"CombineParallelConv2D", 3}, - {"FoldScaleAxis", 3}, - {"AlterOpLayout", 3}, - {"CanonicalizeOps", 3}, - {"EliminateCommonSubexpr", 3} - }; - return m; +const Pass PassRegistry::Lookup(const Pass& pass) const { + PassInfo info = pass->Info(); + return Lookup(info->name); +} + +const Pass PassRegistry::RegisterPass(const Pass& pass) { + CHECK(pass.defined()) << "Undefined passes are not allowed to be registered." + << "\n"; + + PassInfo info = pass->Info(); + std::string name = info->name; + const Pass pa = Lookup(name); + if (pa.defined()) { + return pa; + } else { + registered_pass_map_.insert({name, pass}); + return pass; } -}; +} struct RelayPassContextThreadLocalEntry { /*! \brief The default pass context. */ @@ -246,12 +273,6 @@ class SequentialNode : public PassNode { /* \brief The pass meta data.*/ PassInfo pass_info; - /*! - * \brief A helper struct to get the optimization pass name to opt level - * mapping. - */ - OptPassLevel opt_pass_level; - /*! \brief A list of passes that used to compose a sequential pass. */ tvm::Array passes; void VisitAttrs(tvm::AttrVisitor* v) final { @@ -300,7 +321,7 @@ class SequentialNode : public PassNode { const Array& disabled) const; std::unordered_set RequiredPasses( - const Array& disabled) const; + const Array& required) const; /*! * \brief Perform optimizations on a series of passes. The aforementioned @@ -338,14 +359,28 @@ ModulePass ModulePassNode::make( } // Module -> Module optimizations. -// TODO(zhiics) Check and handle the required passes. 
Module ModulePassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { PassInfo pass_info = Info(); DLOG(INFO) << "Executing module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level << "\n"; + CHECK(mod.defined()); - auto updated_mod = pass_func(mod, pass_ctx); + Module updated_mod = mod; + // Execute the required passes in a DFS way. + // TODO(zhiics) We may need to pass validation to detect the cyclic + // dependency. + PassRegistry& registry = PassRegistry::Global(); + for (const auto& it : pass_info->required) { + const auto* name = it.as(); + CHECK(name); + Pass pass = registry.Lookup(name->value); + pass = pass.defined() ? pass : GetPass(name->value); + const auto* pass_node = pass.operator->(); + updated_mod = (*pass_node)(updated_mod, pass_ctx); + } + + updated_mod = pass_func(updated_mod, pass_ctx); CHECK(updated_mod.defined()); return updated_mod; } @@ -365,12 +400,29 @@ Module FunctionPassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { PassInfo pass_info = Info(); CHECK(mod.defined()); - Module new_mod = ModuleNode::make({}, mod->type_definitions); DLOG(INFO) << "Executing module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level << "\n"; + + Module updated_mod = mod; + // Execute the required passes in a DFS way. + // TODO(zhiics) We may need to pass validation to detect the cyclic + // dependency. + PassRegistry& registry = PassRegistry::Global(); + for (const auto& it : pass_info->required) { + const auto* name = it.as(); + CHECK(name); + Pass pass = registry.Lookup(name->value); + pass = pass.defined() ? pass : GetPass(name->value); + const auto* pass_node = pass.operator->(); + updated_mod = (*pass_node)(updated_mod, pass_ctx); + } + + Module new_mod = ModuleNode::make({}, mod->type_definitions); // Execute the pass function and return a new module. for (const auto& it : mod->functions) { - auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx); + auto updated_func = SkipFunction(it.second) + ? it.second + : pass_func(it.second, updated_mod, pass_ctx); new_mod->Add(it.first, updated_func); } @@ -439,7 +491,7 @@ bool SequentialNode::PassEnabled(const std::string& pass_name) const { PassContext ctx = PassContext::Current(); auto required = RequiredPasses(ctx->required_pass); - auto disabled = DisabledPasses(ctx->required_pass); + auto disabled = DisabledPasses(ctx->disabled_pass); if (disabled.count(pass_name)) { return false; @@ -448,29 +500,38 @@ bool SequentialNode::PassEnabled(const std::string& pass_name) const { if (required.count(pass_name)) { return true; } - return ctx->opt_level >= opt_pass_level[pass_name]; + + PassRegistry& registry = PassRegistry::Global(); + const Pass registered_pass = registry.Lookup(pass_name); + + if (!registered_pass.defined()) { + LOG(WARNING) << pass_name + << " is not registered to the pass registry, it will be " + "forced to execute." + << "\n"; + return true; + } + + PassInfo info = registered_pass->Info(); + return ctx->opt_level >= info->opt_level; } // TODO(zhiics): we currenlty only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase -// ordering problem needed to be handled in the future. +// ordering problem needs to be handled in the future. 
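+// Run each member pass in order; a pass is executed only when PassEnabled
+// reports it as enabled under the current PassContext.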
Module SequentialNode::operator()(const Module& module,
                                  const PassContext& pass_ctx) const {
-  int opt_level = pass_ctx->opt_level;
-  auto disabled = DisabledPasses(pass_ctx->disabled_pass);
   Module mod = module;
   for (const Pass& pass : passes) {
     CHECK(pass.defined()) << "Found undefined pass for optimization.";
+
     PassInfo info = pass->Info();
     const auto& pass_name = info->name;
-    const auto& pass_opt_level = info->opt_level;
-    // Skip the pass if its optimization level is higher that the one of in the
-    // pass context or if this pass is disabled.
-    if (pass_opt_level > opt_level || disabled.count(pass_name)) {
-      continue;
+    // Execute the pass if it is enabled.
+    if (PassEnabled(pass_name)) {
+      const auto* pn = pass.operator->();
+      mod = (*pn)(mod, pass_ctx);
     }
-    const auto* pn = pass.operator->();
-    mod = (*pn)(mod, pass_ctx);
   }
   return mod;
 }
diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc
index cecebc5c04ed..ff7d9bb1957c 100644
--- a/src/relay/pass/simplify_inference.cc
+++ b/src/relay/pass/simplify_inference.cc
@@ -24,6 +24,7 @@
 #include 
 #include 
 #include 
+#include <tvm/relay/transform.h>
 #include "./pattern_util.h"
 
 namespace tvm {
 namespace relay {
@@ -105,5 +106,22 @@ Expr SimplifyInference(const Expr& e) {
 TVM_REGISTER_API("relay._ir_pass.simplify_inference")
 .set_body_typed(SimplifyInference);
 
+namespace transform {
+
+Pass SimplifyInference() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(SimplifyInference(f));
+  };
+  Pass pass = CreateFunctionPass(pass_func, 0, "simplify_inference",
+                                 {ir::StringImm::make("infer_type")});
+  return PassRegistry::Global().RegisterPass(pass);
+}
+
+TVM_REGISTER_API("relay._transform.SimplifyInference")
+.set_body_typed(SimplifyInference);
+
+}  // namespace transform
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc
index f9d47f78a6d2..e2637865b2c3 100644
--- a/src/relay/pass/to_a_normal_form.cc
+++ b/src/relay/pass/to_a_normal_form.cc
@@ -340,9 +340,13 @@ Pass ToANormalForm() {
     [=](Function f, Module m, PassContext pc) {
       return Downcast<Function>(ToANormalForm(f, m));
   };
-  return CreateFunctionPass(pass_func, 1, "to_a_normal_form", {});
+  Pass pass = CreateFunctionPass(pass_func, 1, "to_a_normal_form", {});
+  return PassRegistry::Global().RegisterPass(pass);
 }
 
+TVM_REGISTER_API("relay._transform.ToANormalForm")
+.set_body_typed(ToANormalForm);
+
 }  // namespace transform
 
 }  // namespace relay
diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc
index 50ebb702e4b2..79d9c3d9ba61 100644
--- a/src/relay/pass/to_graph_normal_form.cc
+++ b/src/relay/pass/to_graph_normal_form.cc
@@ -86,9 +86,13 @@ Pass ToGraphNormalForm() {
     [=](Function f, Module m, PassContext pc) {
       return Downcast<Function>(ToGraphNormalForm(f));
   };
-  return CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {});
+  Pass pass = CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {});
+  return PassRegistry::Global().RegisterPass(pass);
 }
 
+TVM_REGISTER_API("relay._transform.ToGraphNormalForm")
+.set_body_typed(ToGraphNormalForm);
+
 }  // namespace transform
 
 }  // namespace relay
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 482cef3b2c2d..0dbc04cf259e 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -43,6 +43,7 @@
 #include 
 #include 
 #include 
+#include <tvm/relay/transform.h>
 #include "./pass_util.h"
 #include "type_solver.h"
 #include "../ir/type_functor.h"
@@ -807,5 +808,33 @@
TVM_REGISTER_API("relay._ir_pass.infer_type") .set_body_typed([](const Expr& expr, const Module& mod_ref) { return InferType(expr, mod_ref); }); + +namespace transform { + +Pass InferType() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(InferType(f, m)); + }; + Pass pass = CreateFunctionPass(pass_func, 0, "infer_type", {}); + return PassRegistry::Global().RegisterPass(pass); +} + +Pass InferType(const GlobalVar& var) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(InferType(f, m, var)); + }; + Pass pass = CreateFunctionPass(pass_func, 0, "infer_type_var", {}); + return PassRegistry::Global().RegisterPass(pass); +} + +TVM_REGISTER_API("relay._transform.InferType") +.set_body_typed([]() { + return InferType(); +}); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc new file mode 100644 index 000000000000..0b3ba8a0ec2a --- /dev/null +++ b/tests/cpp/relay_transform_sequential.cc @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +TVM_REGISTER_GLOBAL("schedule") + .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { + *rv = topi::generic::schedule_injective(args[0], args[1]); + }); + +TEST(Relay, Sequential) { + using namespace tvm; + auto tensor_type = relay::TensorTypeNode::make({1, 2, 3}, ::tvm::Float(32)); + auto c_data = + tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + + // Create a function for optimization. + auto c = relay::ConstantNode::make(c_data); + auto a = relay::VarNode::make("a", tensor_type); + auto x = relay::VarNode::make("x", tensor_type); + auto add_op = relay::Op::Get("add"); + auto y = relay::CallNode::make(add_op, {c, c}); + y = relay::CallNode::make(add_op, {x, y}); + auto z = relay::CallNode::make(add_op, {y, c}); + auto z1 = relay::CallNode::make(add_op, {y, c}); + auto z2 = relay::CallNode::make(add_op, {z, z1}); + // Let expression and varaible a should be dead-code eliminated. + auto z3 = relay::LetNode::make(a, c, z2); + relay::Function func = + relay::FunctionNode::make(relay::FreeVars(z3), z3, relay::Type(), {}); + + // Get schedule + auto reg = tvm::runtime::Registry::Get("relay.op._Register"); + auto sch = tvm::runtime::Registry::Get("schedule"); + if (!reg || !sch) { + LOG(FATAL) << "Register/schedule is not defined."; + } + + (*reg)("add", "FTVMSchedule", *sch, 10); + + // Run sequential passes. 
+ tvm::Array pass_seqs; + pass_seqs.push_back(relay::transform::InferType()); + pass_seqs.push_back(relay::transform::DeadCodeElimination()); + pass_seqs.push_back( + relay::transform::EliminateCommonSubexpr(tvm::PackedFunc(nullptr))); + pass_seqs.push_back(relay::transform::AlterOpLayout()); + + auto pfb = tvm::runtime::Registry::Get("relay._module.Module_FromExpr"); + relay::Module mod = (*pfb)(func); + relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); + { + tvm::With pass_ctx( + relay::transform::PassContext(3, 1, {}, {})); + tvm::With tctx(tvm::Target::Create("llvm")); + mod = seq->operator()(mod); + } + + CHECK(mod.defined()); + + relay::GlobalVar var = mod->entry_func; + relay::Function f; + for (const auto& kv : mod->functions) { + f = kv.second; + } + CHECK(f.defined()); + + // Expected function + auto c1 = relay::ConstantNode::make(c_data); + auto x1 = relay::VarNode::make("x", tensor_type); + auto y1 = relay::CallNode::make(add_op, {c1, c1}); + y1 = relay::CallNode::make(add_op, {x1, y1}); + auto zz = relay::CallNode::make(add_op, {y1, c1}); + zz = relay::CallNode::make(add_op, {zz, zz}); + relay::Function expected = + relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {}); + + // Infer type for the expected function. + auto infer = tvm::runtime::Registry::Get("relay._ir_pass.infer_type"); + if (!infer) { + LOG(FATAL) << "infer_type pass is not registered"; + } + expected = (*infer)(expected, nullptr); + + CHECK(relay::AlphaEqual(f, expected)); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 2703e5ce1679..0035b72507df 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -400,7 +400,48 @@ def test_multiple_passes(): test_multiple_passes() +def test_sequential_with_scoping(): + shape = (1, 2, 3) + c_data = np.array(shape).astype("float32") + tp = relay.TensorType(shape, "float32") + def before(): + c = relay.const(c_data) + x = relay.var("x", tp) + y = relay.add(c, c) + y = relay.multiply(y, relay.const(2, "float32")) + y = relay.add(x, y) + z = relay.add(y, c) + z1 = relay.add(y, c) + z2 = relay.add(z, z1) + return relay.Function([x], z2) + + def expected(): + x = relay.var("x", tp) + c_folded = (c_data + c_data) * 2 + y = relay.add(x, relay.const(c_folded)) + z = relay.add(y, relay.const(c_data)) + z1 = relay.add(z, z) + return relay.Function([x], z1) + + seq = _transform.Sequential([ + relay.transform.infer_type(), + relay.transform.fold_constant(), + relay.transform.eliminate_common_subexpr(), + relay.transform.alter_op_layout() + ]) + + mod = relay.Module({"main": before()}) + with relay.build_config(opt_level=3): + with tvm.target.create("llvm"): + mod = seq(mod) + + zz = mod["main"] + zexpected = ir_pass.infer_type(expected()) + assert relay.ir_pass.alpha_equal(zz, zexpected) + + if __name__ == "__main__": test_module_pass() test_function_pass() test_sequential_pass() + test_sequential_with_scoping() From 38e96821f6a82a8f27fb2127a7a910063dc9129f Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 28 May 2019 17:55:19 +0000 Subject: [PATCH 02/12] fix lint --- include/tvm/relay/pass.h | 4 ++-- include/tvm/relay/transform.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 067f21cbffdc..8fdeda947e20 100644 
--- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -431,7 +431,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& bind_map); * * \return The updated expression. */ -TVM_DLL Expr SimplifyInference(const Expr& e); +TVM_DLL Expr SimplifyInference(const Expr& expr); /*! * \brief Search and eliminate common subexpression. For example, if there are @@ -451,7 +451,7 @@ TVM_DLL Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc fskip); * `min_num_branch`. * * \param expr The expression. - * \param min_num_branch The minimun number of branches. + * \param min_num_branches The minimum number of branches. * * \return The updated expression. */ diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 09f6a684e913..94ae7a4b09b0 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -545,7 +545,7 @@ TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip); * number of branches of this conv2d operator is not less than * `min_num_branch`. * - * \param min_num_branch The minimun number of branches. + * \param min_num_branches The minimum number of branches. * * \return The pass. */ From 9cc93a230400c7f39d244e6ec0cff2e780560794 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 28 May 2019 21:02:02 +0000 Subject: [PATCH 03/12] fix comments --- include/tvm/relay/pass.h | 80 ---------------- include/tvm/relay/transform.h | 60 ---------- python/tvm/relay/transform.py | 65 +++++------------ src/relay/pass/alter_op_layout.cc | 5 +- src/relay/pass/canonicalize_ops.cc | 5 +- src/relay/pass/combine_parallel_conv2d.cc | 5 +- src/relay/pass/dead_code.cc | 3 +- src/relay/pass/device_annotation.cc | 5 +- src/relay/pass/eliminate_common_subexpr.cc | 6 +- src/relay/pass/fold_constant.cc | 3 +- src/relay/pass/fold_scale_axis.cc | 19 ++--- src/relay/pass/forward_rewrite.cc | 6 +- src/relay/pass/fuse_ops.cc | 5 +- src/relay/pass/partial_eval.cc | 3 +- src/relay/pass/pass_manager.cc | 67 +++++------------ src/relay/pass/simplify_inference.cc | 5 +- src/relay/pass/to_a_normal_form.cc | 3 +- src/relay/pass/to_graph_normal_form.cc | 3 +- src/relay/pass/type_infer.cc | 12 +--- tests/cpp/relay_transform_sequential.cc | 16 ++--- tests/python/relay/test_pass_manager.py | 8 +-- 21 files changed, 72 insertions(+), 312 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 8fdeda947e20..81587339f2ad 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -423,86 +423,6 @@ TVM_DLL Expr PartialEval(const Expr& e); */ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& bind_map); -/*! - * \brief Simplify certain operators during inference. For example, batch norm - * will be unpacked into a number of simplified operators. - * - * \param expr The expression. - * - * \return The updated expression. - */ -TVM_DLL Expr SimplifyInference(const Expr& expr); - -/*! - * \brief Search and eliminate common subexpression. For example, if there are - * two expressions evaluated to an identical value, a single variable is created - * and these two expressions are replaced by this variable. - * - * \param expr The expression. - * \param fskip The callback argument that allows to skip certain expressions. - * - * \return The updated expression. - */ -TVM_DLL Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc fskip); - -/*! - * \brief Combine parallel 2d convolutions into a single convolution if the - * number of branches of this conv2d operator is not less than - * `min_num_branch`.
- * - * \param expr The expression. - * \param min_num_branches The minimum number of branches. - * - * \return The updated expression. - */ -TVM_DLL Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches); - -namespace fold_scale_axis { - -/*! - * \brief Backward fold axis scaling into weights of conv/dense operators. - * - * \param expr The input expression. - * - * \return The updated expression. - */ -TVM_DLL Expr BackwardFoldScaleAxis(const Expr& expr); - -/*! - * \brief Forward fold axis scaling into weights of conv/dense operators. - * - * \param expr The input expression. - * - * \return The updated expression. - */ -TVM_DLL Expr ForwardFoldScaleAxis(const Expr& expr); - -} // namespace fold_scale_axis - -/*! - * \brief Canonicalize some operators to the simplified operators. For example, - * bias_add can be canonicalized to expand_dims and broadcast_add. - * - * \param expr The input expression. - * - * \return The updated expression. - */ -TVM_DLL Expr CanonicalizeOps(const Expr& expr); - -namespace alter_op_layout { - -/*! - * \brief Alternate the layouts of operators or replace primitive operators with - * other expressions. - * - * \param expr The input expression. - * - * \return The updated expression. - */ -TVM_DLL Expr AlterOpLayout(const Expr& expr); - -} // namespace alter_op_layout - /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { /*! \brief Hash a Relay type. diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 94ae7a4b09b0..2103253f8d95 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -311,53 +311,6 @@ class Sequential : public Pass { using ContainerType = Sequential; }; -/*! - * \brief A class to save registered passes. - */ -class PassRegistry { - public: - /* - * \brief Get the global pass registry. - */ - static PassRegistry& Global(); - - /* - * \brief Look up a pass in the registered pass map using the pass name. - * - * \param name The name of the pass to be looked up. - * - * \return The correponding pass if it is found. Otherwise, an empty pass. - */ - const Pass Lookup(const std::string& name) const; - - /* - * \brief Look up a pass in the registered pass map. - * - * \param name The name of the pass to be looked up. - * - * \return The correponding pass if it is found. Otherwise, an empty pass. - */ - const Pass Lookup(const Pass& pass) const; - - /* - * \brief Register a pass. - * - * \param pass The pass to be required. - */ - const Pass RegisterPass(const Pass& pass); - - private: - PassRegistry() = default; - ~PassRegistry() = default; - - PassRegistry(const PassRegistry&) = delete; - PassRegistry(PassRegistry&&) = delete; - PassRegistry& operator=(const PassRegistry&) = delete; - PassRegistry& operator=(PassRegistry&&) = delete; - - std::unordered_map registered_pass_map_; -}; - /* * \brief Create a module pass. * @@ -516,19 +469,6 @@ TVM_DLL Pass SimplifyInference(); */ TVM_DLL Pass InferType(); -/*! - * \brief Infer the type of an expression. - * - * The result of type checking is a new expression with unambigous - * type information filled in, as well as it's checked type field - * populated with the result type. - * - * \param var The global variable corresponding to the function being optimized. - * - * \return The pass. - */ -TVM_DLL Pass InferType(const GlobalVar& var); - /*! * \brief Search and eliminate common subexpression.
For example, if there are * two expressions evaluated to an identical value, a single variable is created diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index a4f277026558..da1f7ff2b2c3 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=no-else-return # pylint: disable=unidiomatic-typecheck +# pylint: disable=invalid-name """ This file contains the pass manager for Relay which exposes different granularity of interfaces for users to implement and use passes more @@ -396,7 +397,7 @@ def create_function_pass(pass_func): return create_function_pass -def infer_type(): +def InferType(): """Infer the type of an expr. Returns ------- ret : tvm.relay.Pass The registered type inference pass. """ return _transform.InferType() -def backward_fold_scale_axis(): - """Backward fold axis scaling into weights of conv2d/dense. - - Returns - ------- - ret : tvm.relay.Pass - The registered pass to backward fold expressions. - - Note - ---- - It is recommended to call backward_fold_scale_axis before using - forward_fold_scale_axis. As backward folding targets common conv-bn - pattern. - """ - return _transform.BackwardFoldScaleAxis() - - -def forward_fold_scale_axis(): - """Fold the scaling of axis into weights of conv2d/dense. - - Returns - ------- - ret : tvm.relay.Pass - The registered pass to forward fold expressions. - - Note - ---- - It is recommended to call backward_fold_scale_axis before using - forward_fold_scale_axis. As backward folding targets common conv-bn - pattern. - """ - return _transform.ForwardFoldScaleAxis() - - -def fold_scale_axis(): +def FoldScaleAxis(): """Fold the scaling of axis into weights of conv2d/dense. This pass will invoke both forward and backward scale folding. @@ -452,14 +419,14 @@ Note ---- - It is recommended to call backward_fold_scale_axis before using + Internally, we will call backward_fold_scale_axis before using forward_fold_scale_axis. As backward folding targets common conv-bn pattern. """ return _transform.FoldScaleAxis() -def simplify_inference(): +def SimplifyInference(): """Simplify the data-flow graph for inference phase. A simplified expression which is semantically equal to the input expression will be returned. Returns ------- ret: tvm.relay.Pass The registered pass to perform operator simplification. """ return _transform.SimplifyInference() -def canonicalize_ops(): +def CanonicalizeOps(): """ Canonicalize special operators to basic operators. This can simplify subsequent analysis. (e.g. expanding bias_add to expand_dims and broadcast_add.) Returns ------- ret: tvm.relay.Pass The registered pass performing the canonicalization. """ return _transform.CanonicalizeOps() -def dead_code_elimination(): +def DeadCodeElimination(): """ Remove expressions which do not affect the program result (dead code). Returns ------- ret: tvm.relay.Pass The registered pass that eliminates the dead code in a Relay program. """ return _transform.DeadCodeElimination() -def fold_constant(): +def FoldConstant(): """Fold the constant expression in expr. Returns ------- ret : tvm.relay.Pass The registered pass for constant folding. """ return _transform.FoldConstant() -def fuse_ops(fuse_opt_level=-1): +def FuseOps(fuse_opt_level=-1): """Fuse operators in an expr to a larger operator according to some rules. Parameters ---------- fuse_opt_level : int The level of fuse optimization. -1 indicates that the level will be inferred from pass context. Returns ------- ret : tvm.relay.Pass The registered pass for operator fusion. """ return _transform.FuseOps(fuse_opt_level) -def combine_parallel_conv2d(min_num_branches=3): +def CombineParallelConv2D(min_num_branches=3): """Combine multiple conv2d operators into one.
Parameters @@ -540,7 +507,7 @@ def combine_parallel_conv2d(min_num_branches=3): return _transform.CombineParallelConv2D(min_num_branches) -def alter_op_layout(): +def AlterOpLayout(): """Alternate the layouts of operators or replace primitive operators with other expressions. This pass can be used for computing convolution in custom layouts or @@ -554,7 +521,7 @@ def alter_op_layout(): return _transform.AlterOpLayout() -def rewrite_annotated_ops(fallback_device): +def RewriteAnnotatedOps(fallback_device): """Rewrite the annotated program where annotation operators, e.g. `on_device`, mark which device an expression should be scheduled to. This pass helps heterogeneous execution where different operators may need @@ -575,7 +542,7 @@ def rewrite_annotated_ops(fallback_device): return _transform.RewriteDeviceAnnotation(fallback_device) -def to_a_normal_form(): +def ToANormalForm(): """Turn Graph Normal Form expression into A Normal Form Expression. The scope of the root expression is the global scope. The scope of any non-root expression is the least common ancestor of all its scopes. @@ -589,7 +556,7 @@ def to_a_normal_form(): return _transform.ToANormalForm() -def to_graph_normal_form(): +def ToGraphNormalForm(): """Turn A Normal Form expression into Graph Normal Form expression Returns @@ -600,7 +567,7 @@ def to_graph_normal_form(): return _transform.ToGraphNormalForm() -def eliminate_common_subexpr(fskip=None): +def EliminateCommonSubexpr(fskip=None): """Eliminate common subexpressions. Parameters @@ -617,7 +584,7 @@ def eliminate_common_subexpr(fskip=None): return _transform.EliminateCommonSubexpr(fskip) -def partial_evaluate(): +def PartialEval(): """Evaluate the static fragment of the code. Returns diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index 213bc03388bc..2077394ff2e5 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -360,9 +360,8 @@ Pass AlterOpLayout() { [=](Function f, Module m, PassContext pc) { return Downcast(relay::alter_op_layout::AlterOpLayout(f)); }; - Pass pass = CreateFunctionPass(pass_func, 3, "alter_op_layout", - {ir::StringImm::make("infer_type")}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 3, "alter_op_layout", + {ir::StringImm::make("infer_type")}); } TVM_REGISTER_API("relay._transform.AlterOpLayout") diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index 8f94485894d5..fab7ebac81b6 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -71,9 +71,8 @@ Pass CanonicalizeOps() { [=](Function f, Module m, PassContext pc) { return Downcast(CanonicalizeOps(f)); }; - Pass pass = CreateFunctionPass(pass_func, 3, "canonicalize_ops", - {ir::StringImm::make("infer_type")}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 3, "canonicalize_ops", + {ir::StringImm::make("infer_type")}); } TVM_REGISTER_API("relay._transform.CanonicalizeOps") diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 2c3800d60f6e..3d8ea29b8c5d 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -365,9 +365,8 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) { [=](Function f, Module m, PassContext pc) { return Downcast(CombineParallelConv2D(f, min_num_branches)); }; - Pass pass = CreateFunctionPass(pass_func, 4, "combine_parallel_conv2d",
{ir::StringImm::make("infer_type")}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 4, "combine_parallel_conv2d", + {ir::StringImm::make("infer_type")}); } TVM_REGISTER_API("relay._transform.CombineParallelConv2D") diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 0766743e5c49..687608259d64 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -158,8 +158,7 @@ Pass DeadCodeElimination() { [=](Function f, Module m, PassContext pc) { return Downcast(DeadCodeElimination(f)); }; - Pass dec = CreateFunctionPass(pass_func, 1, "dead_code_elimination", {}); - return PassRegistry::Global().RegisterPass(dec); + return CreateFunctionPass(pass_func, 1, "dead_code_elimination", {}); } TVM_REGISTER_API("relay._transform.DeadCodeElimination") diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 21da33f51901..bd0ba970695f 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -558,9 +558,8 @@ Pass RewriteAnnotatedOps(int fallback_device) { [=](Function f, Module m, PassContext pc) { return Downcast(RewriteAnnotatedOps(f, fallback_device)); }; - Pass pass = CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", - {ir::StringImm::make("infer_type")}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", + {ir::StringImm::make("infer_type")}); } TVM_REGISTER_API("relay._transform.RewriteDeviceAnnotation") diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index 3deebdf37c00..5921201986a8 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -95,10 +95,8 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) { [=](Function f, Module m, PassContext pc) { return Downcast(EliminateCommonSubexpr(f, fskip)); }; - Pass pass = CreateFunctionPass(pass_func, 3, "eliminate_common_subexpr", - {ir::StringImm::make("infer_type")}); - PassRegistry::Global().RegisterPass(pass); - return pass; + return CreateFunctionPass(pass_func, 3, "eliminate_common_subexpr", + {ir::StringImm::make("infer_type")}); } TVM_REGISTER_API("relay._transform.EliminateCommonSubexpr") diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 036533265d8a..f782e8e846b5 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -223,8 +223,7 @@ Pass FoldConstant() { [=](Function f, Module m, PassContext pc) { return Downcast(FoldConstant(f)); }; - Pass pass = CreateFunctionPass(pass_func, 2, "fold_constant", {}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 2, "fold_constant", {}); } TVM_REGISTER_API("relay._transform.FoldConstant") diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index db9340c93d3b..23fe391f30fd 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -960,9 +960,8 @@ Pass ForwardFoldScaleAxis() { return Downcast( relay::fold_scale_axis::ForwardFoldScaleAxis(f)); }; - Pass pass = CreateFunctionPass(pass_func, 3, "forward_fold_scale_axis", - {ir::StringImm::make("infer_type")}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 3, "forward_fold_scale_axis", + {ir::StringImm::make("infer_type")}); } Pass BackwardFoldScaleAxis() { @@ -971,9 +970,8 @@ Pass BackwardFoldScaleAxis() { return 
Downcast( relay::fold_scale_axis::BackwardFoldScaleAxis(f)); }; - Pass pass = CreateFunctionPass(pass_func, 3, "backward_fold_scale_axis", - {ir::StringImm::make("infer_type")}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 3, "backward_fold_scale_axis", + {ir::StringImm::make("infer_type")}); } Pass FoldScaleAxis() { @@ -982,16 +980,9 @@ Pass FoldScaleAxis() { Pass pass = Sequential( {FoldConstant(), BackwardFoldScaleAxis(), ForwardFoldScaleAxis()}, "fold_scale_axis"); - - return PassRegistry::Global().RegisterPass(pass); + return pass; } -TVM_REGISTER_API("relay._transform.ForwardFoldScaleAxis") -.set_body_typed(ForwardFoldScaleAxis); - -TVM_REGISTER_API("relay._transform.BackwardFoldScaleAxis") -.set_body_typed(BackwardFoldScaleAxis); - TVM_REGISTER_API("relay._transform.FoldScaleAxis") .set_body_typed(FoldScaleAxis); diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index c915c604d722..34b1d5b3a38c 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -220,8 +220,7 @@ Pass ForwardRewrite(const std::string& rewrite_map_attr_name, fcontext, fmulti_ref_trigger)); }; - Pass pass = CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 1, "forward_rewrite", {}); } Pass ForwardRewrite(const FForwardRewrite& rewrite_func, @@ -234,8 +233,7 @@ Pass ForwardRewrite(const FForwardRewrite& rewrite_func, fcontext, fmulti_ref_trigger)); }; - Pass pass = CreateFunctionPass(pass_func, 1, "forward_rewrite_fun", {}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 1, "forward_rewrite_fun", {}); } } // namespace transform diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 1422ab91fa9c..adedbd2d3d24 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -974,9 +974,8 @@ Pass FuseOps(int fuse_opt_level) { int opt_level = fuse_opt_level == -1 ? 
pc->opt_level : fuse_opt_level; return Downcast(FuseOps(f, opt_level, m)); }; - Pass pass = CreateFunctionPass(pass_func, 1, "fuse_ops", - {ir::StringImm::make("infer_type")}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 1, "fuse_ops", + {ir::StringImm::make("infer_type")}); } TVM_REGISTER_API("relay._transform.FuseOps") diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 78b2b217df04..caf4d1d3bbd1 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -806,8 +806,7 @@ Pass PartialEval() { [=](Function f, Module m, PassContext pc) { return Downcast(PartialEval(f)); }; - Pass pass = CreateFunctionPass(pass_func, 1, "partial_eval", {}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 1, "partial_eval", {}); } TVM_REGISTER_API("relay._transform.PartialEval") diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 1525a0744abb..a9ad24842589 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -44,63 +44,40 @@ namespace { Pass GetPass(const std::string& pass_name) { if (pass_name == "infer_type") { return InferType(); - } else if (pass_name == "simplify_inference") { - return SimplifyInference(); } else if (pass_name == "alter_op_layout") { return AlterOpLayout(); } else if (pass_name == "canonicalize_ops") { return CanonicalizeOps(); } else if (pass_name == "combine_parallel_conv2d") { return CombineParallelConv2D(); + } else if (pass_name == "dead_code_elimination") { + return DeadCodeElimination(); + } else if (pass_name == "eliminate_common_subexpr") { + return EliminateCommonSubexpr(PackedFunc(nullptr)); } else if (pass_name == "fold_constant") { return FoldConstant(); + } else if (pass_name == "backward_fold_scale_axis") { + return FoldScaleAxis(); + } else if (pass_name == "forward_fold_scale_axis") { + return FoldScaleAxis(); } else if (pass_name == "fold_scale_axis") { return FoldScaleAxis(); + } else if (pass_name == "partial_eval") { + return PartialEval(); + } else if (pass_name == "simplify_inference") { + return SimplifyInference(); } else if (pass_name == "to_a_normal_form") { return ToANormalForm(); } else if (pass_name == "to_graph_normal_form") { return ToGraphNormalForm(); } else { - LOG(FATAL) << pass_name << " has not been registered yet." << "\n"; + LOG(WARNING) << pass_name << " has not been registered yet." << "\n"; return Pass(nullptr); } } } // namespace -PassRegistry& PassRegistry::Global() { - static PassRegistry registry; - return registry; -} - -const Pass PassRegistry::Lookup(const std::string& name) const { - auto it = registered_pass_map_.find(name); - if (it == registered_pass_map_.end()) { - return Pass(nullptr); - } - return it->second; -} - -const Pass PassRegistry::Lookup(const Pass& pass) const { - PassInfo info = pass->Info(); - return Lookup(info->name); -} - -const Pass PassRegistry::RegisterPass(const Pass& pass) { - CHECK(pass.defined()) << "Undefined passes are not allowed to be registered." - << "\n"; - - PassInfo info = pass->Info(); - std::string name = info->name; - const Pass pa = Lookup(name); - if (pa.defined()) { - return pa; - } else { - registered_pass_map_.insert({name, pass}); - return pass; - } -} - struct RelayPassContextThreadLocalEntry { /*! \brief The default pass context. */ PassContext default_context; @@ -370,12 +347,10 @@ Module ModulePassNode::operator()(const Module& mod, // Execute the required passes in a DFS way.
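GetPass is also what backs this required-pass resolution: every name listed in a pass's required array is resolved and executed before the pass itself runs. As a sketch, a pass declaring a dependency on type inference (placeholder body, illustrative name):

// Sketch: declare "infer_type" as required. The pass manager resolves the
// name through GetPass and runs the resulting pass before pass_func.
Pass MyTypedPass() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return f;  // placeholder; a real pass would transform f here
      };
  return CreateFunctionPass(pass_func, 2, "my_typed_pass",
                            {ir::StringImm::make("infer_type")});
}

This is the same {ir::StringImm::make("infer_type")} idiom used by alter_op_layout and the other passes in this series.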
// TODO(zhiics) We may need to pass validation to detect the cyclic // dependency. - PassRegistry& registry = PassRegistry::Global(); for (const auto& it : pass_info->required) { const auto* name = it.as(); CHECK(name); - Pass pass = registry.Lookup(name->value); - pass = pass.defined() ? pass : GetPass(name->value); + auto pass = GetPass(name->value); const auto* pass_node = pass.operator->(); updated_mod = (*pass_node)(updated_mod, pass_ctx); } @@ -407,12 +382,10 @@ Module FunctionPassNode::operator()(const Module& mod, // Execute the required passes in a DFS way. // TODO(zhiics) We may need to pass validation to detect the cyclic // dependency. - PassRegistry& registry = PassRegistry::Global(); for (const auto& it : pass_info->required) { const auto* name = it.as(); CHECK(name); - Pass pass = registry.Lookup(name->value); - pass = pass.defined() ? pass : GetPass(name->value); + auto pass = GetPass(name->value); const auto* pass_node = pass.operator->(); updated_mod = (*pass_node)(updated_mod, pass_ctx); } @@ -501,18 +474,16 @@ bool SequentialNode::PassEnabled(const std::string& pass_name) const { return true; } - PassRegistry& registry = PassRegistry::Global(); - const Pass registered_pass = registry.Lookup(pass_name); + const Pass pass = GetPass(pass_name); - if (!registered_pass.defined()) { + if (!pass.defined()) { LOG(WARNING) << pass_name - << " is not registered to the pass registry, it will be " - "forced to execute." + << " is not registered yet, it will be forced to execute." << "\n"; return true; } - PassInfo info = registered_pass->Info(); + PassInfo info = pass->Info(); return ctx->opt_level >= info->opt_level; } diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index ff7d9bb1957c..d0e866d9a6f8 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -113,9 +113,8 @@ Pass SimplifyInference() { [=](Function f, Module m, PassContext pc) { return Downcast(SimplifyInference(f)); }; - Pass pass = CreateFunctionPass(pass_func, 0, "simplify_inference", - {ir::StringImm::make("infer_type")}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 0, "simplify_inference", + {ir::StringImm::make("infer_type")}); } TVM_REGISTER_API("relay._transform.SimplifyInference") diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index e2637865b2c3..313d1905f767 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -340,8 +340,7 @@ Pass ToANormalForm() { [=](Function f, Module m, PassContext pc) { return Downcast(ToANormalForm(f, m)); }; - Pass pass = CreateFunctionPass(pass_func, 1, "to_a_normal_form", {}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 1, "to_a_normal_form", {}); } TVM_REGISTER_API("relay._transform.ToANormalForm") diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index 79d9c3d9ba61..7cd4431c8f53 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -86,8 +86,7 @@ Pass ToGraphNormalForm() { [=](Function f, Module m, PassContext pc) { return Downcast(ToGraphNormalForm(f)); }; - Pass pass = CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {}); } TVM_REGISTER_API("relay._transform.ToGraphNormalForm") diff --git 
a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 0dbc04cf259e..54df9af7d71e 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -816,17 +816,7 @@ Pass InferType() { [=](Function f, Module m, PassContext pc) { return Downcast(InferType(f, m)); }; - Pass pass = CreateFunctionPass(pass_func, 0, "infer_type", {}); - return PassRegistry::Global().RegisterPass(pass); -} - -Pass InferType(const GlobalVar& var) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast(InferType(f, m, var)); - }; - Pass pass = CreateFunctionPass(pass_func, 0, "infer_type_var", {}); - return PassRegistry::Global().RegisterPass(pass); + return CreateFunctionPass(pass_func, 0, "infer_type", {}); } TVM_REGISTER_API("relay._transform.InferType") diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index 0b3ba8a0ec2a..40011a278009 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -73,8 +73,10 @@ TEST(Relay, Sequential) { relay::transform::EliminateCommonSubexpr(tvm::PackedFunc(nullptr))); pass_seqs.push_back(relay::transform::AlterOpLayout()); - auto pfb = tvm::runtime::Registry::Get("relay._module.Module_FromExpr"); - relay::Module mod = (*pfb)(func); + relay::GlobalVar var = relay::GlobalVarNode::make("main"); + tvm::Map m; + m.Set(var, func); + auto mod = relay::ModuleNode::make(m, {}); relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); { tvm::With pass_ctx( @@ -85,7 +87,6 @@ TEST(Relay, Sequential) { CHECK(mod.defined()); - relay::GlobalVar var = mod->entry_func; relay::Function f; for (const auto& kv : mod->functions) { f = kv.second; @@ -99,16 +100,11 @@ TEST(Relay, Sequential) { y1 = relay::CallNode::make(add_op, {x1, y1}); auto zz = relay::CallNode::make(add_op, {y1, c1}); zz = relay::CallNode::make(add_op, {zz, zz}); - relay::Function expected = + relay::Function expected_func = relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {}); // Infer type for the expected function. 
- auto infer = tvm::runtime::Registry::Get("relay._ir_pass.infer_type"); - if (!infer) { - LOG(FATAL) << "infer_type pass is not registered"; - } - expected = (*infer)(expected, nullptr); - + auto expected = relay::InferType(expected_func, relay::Module(nullptr)); CHECK(relay::AlphaEqual(f, expected)); } diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 0035b72507df..85e6e051bfc8 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -424,10 +424,10 @@ def expected(): return relay.Function([x], z1) seq = _transform.Sequential([ - relay.transform.infer_type(), - relay.transform.fold_constant(), - relay.transform.eliminate_common_subexpr(), - relay.transform.alter_op_layout() + relay.transform.InferType(), + relay.transform.FoldConstant(), + relay.transform.EliminateCommonSubexpr(), + relay.transform.AlterOpLayout() ]) mod = relay.Module({"main": before()}) From 93ed0044732f39f45c285dcfcc6acfce6455169a Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 28 May 2019 22:48:38 +0000 Subject: [PATCH 04/12] default fskip for cse --- include/tvm/relay/transform.h | 21 +++++++++++++++++++-- src/relay/backend/build_module.cc | 21 +-------------------- src/relay/pass/pass_manager.cc | 2 +- tests/cpp/relay_transform_sequential.cc | 3 +-- 4 files changed, 22 insertions(+), 25 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 2103253f8d95..ea38f785eabb 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -58,9 +58,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -338,7 +340,7 @@ Pass CreateModulePass( * \return The created function pass. */ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< - Function(Function, Module, PassContext)>& pass_func, + Function(Function, Module, PassContext)>& pass_func, int opt_level, const std::string& name, const tvm::Array& required); @@ -478,7 +480,22 @@ TVM_DLL Pass InferType(); * * \return The pass. */ -TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip); +TVM_DLL Pass EliminateCommonSubexpr( + PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") { + auto attrs = call_node->attrs.as(); + if (attrs->dtype == HalideIR::Int(32)) { + *rv = true; + } + } + } + *rv = false; + }) +); /*! * \brief Combine parallel 2d convolutions into a single convolution if the diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index d76f9c786a87..4c39cfa1f91e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -23,9 +23,7 @@ */ #include #include -#include #include -#include #include #include @@ -307,24 +305,7 @@ class RelayBuildModule : public runtime::ModuleNode { const std::unordered_map& params) { Array pass_seqs; pass_seqs.push_back(transform::SimplifyInference()); - - // Can we move to the pass implementation file and make it as default? 
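The transform.h hunk above answers the question in this removed comment: the cast-skipping callback becomes the default value of fskip, so most call sites can simply write EliminateCommonSubexpr(). Callers that need different behavior can still pass their own predicate; a minimal sketch of a custom fskip that never skips a candidate:

// Sketch: a user-supplied skip predicate for common subexpression
// elimination. Writing true into rv tells CSE to skip the expression;
// this one keeps every candidate.
PackedFunc never_skip([](TVMArgs args, TVMRetValue* rv) {
  *rv = false;
});
Pass cse = relay::transform::EliminateCommonSubexpr(never_skip);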
- auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == HalideIR::Int(32)) { - *rv = true; - } - } - } - *rv = false; - }); - - pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::EliminateCommonSubexpr()); pass_seqs.push_back(transform::CombineParallelConv2D(3)); pass_seqs.push_back(transform::FoldScaleAxis()); pass_seqs.push_back(transform::CanonicalizeOps()); diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index a9ad24842589..747709583f08 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -499,7 +499,7 @@ Module SequentialNode::operator()(const Module& module, PassInfo info = pass->Info(); const auto& pass_name = info->name; // Execute the pass if it is enabled. - if (pass_enabled(pass_name)) { + if (PassEnabled(pass_name)) { const auto* pn = pass.operator->(); mod = (*pn)(mod, pass_ctx); } diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index 40011a278009..a6ed0069c1e0 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -69,8 +69,7 @@ TEST(Relay, Sequential) { tvm::Array pass_seqs; pass_seqs.push_back(relay::transform::InferType()); pass_seqs.push_back(relay::transform::DeadCodeElimination()); - pass_seqs.push_back( - relay::transform::EliminateCommonSubexpr(tvm::PackedFunc(nullptr))); + pass_seqs.push_back(relay::transform::EliminateCommonSubexpr()); pass_seqs.push_back(relay::transform::AlterOpLayout()); relay::GlobalVar var = relay::GlobalVarNode::make("main"); From 30de2069aa339d4cb4e18aa45038cca76e0a5573 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 28 May 2019 23:56:24 +0000 Subject: [PATCH 05/12] fix fromexpr --- include/tvm/relay/module.h | 26 ++++++++++++------------- tests/cpp/relay_transform_sequential.cc | 25 ++++++++++-------------- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 6441fb3f5b9c..3966a6258a20 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -87,14 +87,14 @@ class ModuleNode : public RelayNode { * \param update Controls whether you can replace a definition in the * environment. */ - void Add(const GlobalVar& var, const Function& func, bool update = false); + TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false); /*! * \brief Add a type-level definition to the global environment. * \param var The var of the global type definition. * \param type The type definition. */ - void AddDef(const GlobalTypeVar& var, const TypeData& type); + TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type); /*! * \brief Add a function to the global environment. @@ -103,69 +103,69 @@ class ModuleNode : public RelayNode { * * It does not do type inference as Add does. */ - void AddUnchecked(const GlobalVar& var, const Function& func); + TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func); /*! * \brief Update a function in the global environment. * \param var The name of the global function to update. * \param func The new function. */ - void Update(const GlobalVar& var, const Function& func); + TVM_DLL void Update(const GlobalVar& var, const Function& func); /*! 
* \brief Remove a function from the global environment. * \param var The name of the global function to update. */ - void Remove(const GlobalVar& var); + TVM_DLL void Remove(const GlobalVar& var); /*! * \brief Lookup a global function by its variable. * \param str The unique string specifying the global variable. * \returns The global variable. */ - GlobalVar GetGlobalVar(const std::string& str); + TVM_DLL GlobalVar GetGlobalVar(const std::string& str); /*! * \brief Look up a global function by its name. * \param str The unique string specifying the global variable. * \returns The global variable. */ - GlobalTypeVar GetGlobalTypeVar(const std::string& str); + TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str); /*! * \brief Lookup a global function by its variable. * \param var The global var to lookup. * \returns The function named by the variable argument. */ - Function Lookup(const GlobalVar& var); + TVM_DLL Function Lookup(const GlobalVar& var); /*! * \brief Lookup a global function by its string name * \param name The name of the function. * \returns The function named by the argument. */ - Function Lookup(const std::string& name); + TVM_DLL Function Lookup(const std::string& name); /*! * \brief Lookup a global type definition by its variable. * \param var The var of the global type definition. * \return The type definition. */ - TypeData LookupDef(const GlobalTypeVar& var); + TVM_DLL TypeData LookupDef(const GlobalTypeVar& var); /*! * \brief Lookup a global type definition by its name. * \param var The name of the global type definition. * \return The type definition. */ - TypeData LookupDef(const std::string& var); + TVM_DLL TypeData LookupDef(const std::string& var); /*! * \brief Update the functions inside this environment by * functions in another environment. * \param other The other environment. */ - void Update(const Module& other); + TVM_DLL void Update(const Module& other); /*! \brief Construct a module from a standalone expression. * @@ -177,7 +177,7 @@ class ModuleNode : public RelayNode { * * \returns A module with expr set as the entry point. */ - static Module FromExpr( + TVM_DLL static Module FromExpr( const Expr& expr, const tvm::Map& global_funcs = {}); diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index a6ed0069c1e0..eab89388970a 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -66,17 +66,14 @@ TEST(Relay, Sequential) { (*reg)("add", "FTVMSchedule", *sch, 10); // Run sequential passes. 
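The test refactor below drops the hand-built GlobalVar and Map in favor of ModuleNode::FromExpr, one of the members that the TVM_DLL annotations above export for exactly this kind of out-of-library use. The resulting construct-and-look-up pattern, as a sketch:

// Sketch: build a Module straight from a function; the expression is bound
// to the module's entry point, which can be looked up again after passes run.
relay::Module mod = relay::ModuleNode::FromExpr(func);
CHECK(mod->entry_func.defined());
relay::Function f = mod->Lookup(mod->entry_func->name_hint);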
- tvm::Array pass_seqs; - pass_seqs.push_back(relay::transform::InferType()); - pass_seqs.push_back(relay::transform::DeadCodeElimination()); - pass_seqs.push_back(relay::transform::EliminateCommonSubexpr()); - pass_seqs.push_back(relay::transform::AlterOpLayout()); - - relay::GlobalVar var = relay::GlobalVarNode::make("main"); - tvm::Map m; - m.Set(var, func); - auto mod = relay::ModuleNode::make(m, {}); + tvm::Array pass_seqs{ + relay::transform::InferType(), + relay::transform::DeadCodeElimination(), + relay::transform::EliminateCommonSubexpr(), + relay::transform::AlterOpLayout() + }; relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); + auto mod = relay::ModuleNode::FromExpr(func); { tvm::With pass_ctx( relay::transform::PassContext(3, 1, {}, {})); @@ -85,11 +82,9 @@ TEST(Relay, Sequential) { } CHECK(mod.defined()); - - relay::Function f; - for (const auto& kv : mod->functions) { - f = kv.second; - } + auto entry_func = mod->entry_func; + CHECK(entry_func.defined()); + relay::Function f = mod->Lookup(entry_func->name_hint); CHECK(f.defined()); // Expected function From 031bfe99324ec4273c0b6d54e026f5dd400d3fbd Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 29 May 2019 02:41:14 +0000 Subject: [PATCH 06/12] rebase --- src/relay/backend/build_module.cc | 4 +-- src/relay/pass/pass_manager.cc | 36 ++++++++++++------------- tests/cpp/relay_transform_sequential.cc | 8 +++--- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 4c39cfa1f91e..d0ab6d79abdb 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -321,10 +321,10 @@ class RelayBuildModule : public runtime::ModuleNode { if (targets.size() == 1) { for (const auto& kv : targets) { With tctx(kv.second); - relay_module = seq->operator()(relay_module); + relay_module = seq(relay_module); } } else { - relay_module = seq->operator()(relay_module); + relay_module = seq(relay_module); } return relay_module; } diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 747709583f08..dbd1496baffd 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -351,8 +351,7 @@ Module ModulePassNode::operator()(const Module& mod, const auto* name = it.as(); CHECK(name); auto pass = GetPass(name->value); - const auto* pass_node = pass.operator->(); - updated_mod = (*pass_node)(updated_mod, pass_ctx); + updated_mod = pass(updated_mod, pass_ctx); } updated_mod = pass_func(updated_mod, pass_ctx); @@ -386,8 +385,7 @@ Module FunctionPassNode::operator()(const Module& mod, const auto* name = it.as(); CHECK(name); auto pass = GetPass(name->value); - const auto* pass_node = pass.operator->(); - updated_mod = (*pass_node)(updated_mod, pass_ctx); + updated_mod = pass(updated_mod, pass_ctx); } Module new_mod = ModuleNode::make({}, mod->type_definitions); @@ -500,8 +498,7 @@ Module SequentialNode::operator()(const Module& module, const auto& pass_name = info->name; // Execute the pass if it is enabled. 
if (PassEnabled(pass_name)) { - const auto* pn = pass.operator->(); - mod = (*pn)(mod, pass_ctx); + mod = pass(mod, pass_ctx); } } return mod; @@ -557,15 +554,17 @@ TVM_REGISTER_API("relay._transform.CreateModulePass") TVM_REGISTER_API("relay._transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Pass()(args[1]); + Pass pass = args[0]; + Module mod = args[1]; + *ret = pass(mod); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ModulePassNode* node, tvm::IRPrinter* p) { - const PassInfoNode* pn = node->Info().operator->(); - p->stream << "Run Module pass: " << pn->name - << " at the optimization level " << pn->opt_level; + const PassInfo info = node->Info(); + p->stream << "Run Module pass: " << info->name + << " at the optimization level " << info->opt_level; }); TVM_REGISTER_NODE_TYPE(FunctionPassNode); @@ -576,9 +575,9 @@ TVM_REGISTER_API("relay._transform.CreateFunctionPass") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionPassNode* node, tvm::IRPrinter* p) { - const PassInfoNode* pn = node->Info().operator->(); - p->stream << "Run Function pass: " << pn->name - << " at the optimization level " << pn->opt_level; + const PassInfo info = node->Info(); + p->stream << "Run Function pass: " << info->name + << " at the optimization level " << info->opt_level; }); TVM_REGISTER_NODE_TYPE(SequentialNode); @@ -596,14 +595,13 @@ TVM_REGISTER_API("relay._transform.Sequential") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const SequentialNode* node, tvm::IRPrinter* p) { - const PassInfoNode* seq_pn = node->Info().operator->(); - p->stream << "Run Sequential pass: " << seq_pn->name - << " at the optimization level " << seq_pn->opt_level << ". "; + const PassInfo info = node->Info(); + p->stream << "Run Sequential pass: " << info->name + << " at the optimization level " << info->opt_level << ". 
"; p->stream << "The passes will be executed are: ["; for (const auto& it : node->passes) { - const PassNode* pn = it.operator->(); - const PassInfoNode* pass_info_node = pn->Info().operator->(); - p->stream << pass_info_node->name << " "; + const PassInfo pass_info = it->Info(); + p->stream << pass_info->name << " "; } p->stream << "]"; }); diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index eab89388970a..b61a5cc0daad 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -74,11 +74,13 @@ TEST(Relay, Sequential) { }; relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); auto mod = relay::ModuleNode::FromExpr(func); + auto pass_ctx = relay::transform::PassContext::Create(); + pass_ctx->opt_level = 3; + pass_ctx->fallback_device = 1; { - tvm::With pass_ctx( - relay::transform::PassContext(3, 1, {}, {})); + tvm::With ctx_scope(pass_ctx); tvm::With tctx(tvm::Target::Create("llvm")); - mod = seq->operator()(mod); + mod = seq(mod); } CHECK(mod.defined()); From 83c25274e8336c7d44c61662576c6a3f1af4bff0 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 30 May 2019 17:35:53 +0000 Subject: [PATCH 07/12] move more to transform --- include/tvm/relay/transform.h | 17 +-------------- src/relay/backend/build_module.cc | 35 +++++++++++++++++++++++-------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index ea38f785eabb..793bc981ea61 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -480,22 +480,7 @@ TVM_DLL Pass InferType(); * * \return The pass. */ -TVM_DLL Pass EliminateCommonSubexpr( - PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == HalideIR::Int(32)) { - *rv = true; - } - } - } - *rv = false; - }) -); +TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr); /*! 
* \brief Combine parallel 2d convolutions into a single convolution if the diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index d0ab6d79abdb..eb323a94d38a 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -305,7 +305,21 @@ class RelayBuildModule : public runtime::ModuleNode { const std::unordered_map& params) { Array pass_seqs; pass_seqs.push_back(transform::SimplifyInference()); - pass_seqs.push_back(transform::EliminateCommonSubexpr()); + PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + *rv = false; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") { + auto attrs = call_node->attrs.as(); + if (attrs->dtype == HalideIR::Int(32)) { + *rv = true; + } + } + } + }); + pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); pass_seqs.push_back(transform::CombineParallelConv2D(3)); pass_seqs.push_back(transform::FoldScaleAxis()); pass_seqs.push_back(transform::CanonicalizeOps()); @@ -377,9 +391,11 @@ class RelayBuildModule : public runtime::ModuleNode { Function RunDeviceAnnotationPass(Function func, int fallback_device, TargetsMap* targets_map_ptr) { - auto new_func = relay::InferType(func, Module(nullptr)); - new_func = relay::RewriteAnnotatedOps(new_func, fallback_device); - func = Downcast(new_func); + relay::Module relay_module = relay::ModuleNode::FromExpr(func); + auto rewrite = transform::RewriteAnnotatedOps(fallback_device); + relay_module = rewrite(relay_module); + CHECK(relay_module.defined()); + func = relay_module->Lookup(relay_module->entry_func->name_hint); CHECK(func.defined()); auto device_map = relay::CollectDeviceInfo(func); if (device_map.size() == 0) { auto annotation_map = relay::CollectDeviceAnnotationOps(func); if (annotation_map.size() == 0) { targets_map_ptr->Set(0, CreateDefaultTarget(fallback_device)); } else { @@ -432,11 +448,12 @@ class RelayBuildModule : public runtime::ModuleNode { &device_target); } - auto new_func = relay::InferType(func, Module(nullptr)); - new_func = relay::FuseOps(new_func, pass_ctx->opt_level, Module(nullptr)); - new_func = relay::InferType(new_func, Module(nullptr)); - func = Downcast(new_func); - CHECK(func.defined()); + relay::Module relay_module = relay::ModuleNode::FromExpr(func); + relay_module = transform::InferType()(relay_module); + relay_module = transform::FuseOps()(relay_module); + relay_module = transform::InferType()(relay_module); + CHECK(relay_module.defined()); + func = relay_module->Lookup(relay_module->entry_func->name_hint); graph_codegen_ = std::unique_ptr(new GraphCodegen()); graph_codegen_->Init(nullptr, device_target); From a845b0f778d2e30d295fd0e30626617fa586c18a Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 31 May 2019 06:26:44 +0000 Subject: [PATCH 08/12] move all passes to optimize --- src/relay/backend/build_module.cc | 126 ++++++++++++------------------ 1 file changed, 52 insertions(+), 74 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index eb323a94d38a..2b8e168310d6 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -265,30 +265,6 @@ class RelayBuildModule : public runtime::ModuleNode { return ret; } - /*! - * \brief Optimize a Relay function. - * - * \param func The input Relay function for optimization. - * \param targets The device type to `Target` mapping. - * \param params The param name to value mapping. - * - * \return func The updated Relay function after optimization.
- */ - relay::Function Optimize( - relay::Function func, - const TargetsMap& targets, - const std::unordered_map& params) { - if (params.size()) { - func = BindParamsByName(func, params); - } - - relay::Module relay_module = relay::ModuleNode::FromExpr(func); - relay_module = Optimize(relay_module, targets_, params); - CHECK(relay_module.defined()); - GlobalVar var = relay_module->entry_func; - return relay_module->Lookup(var->name_hint); - } - /*! * \brief Optimize a Relay module. * @@ -340,6 +316,18 @@ class RelayBuildModule : public runtime::ModuleNode { } else { relay_module = seq(relay_module); } + + // Handle heterogeneous compilation. + transform::PassContext pass_ctx = PassContext::Current(); + if (targets_.size() > 1) { + relay_module = + RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device); + } + + // Fuse the operations if it is needed. + relay_module = transform::FuseOps()(relay_module); + relay_module = transform::InferType()(relay_module); + return relay_module; } @@ -354,55 +342,58 @@ class RelayBuildModule : public runtime::ModuleNode { if (name == "gpu") return Target::Create("cuda"); return Target::Create(name); } + /*! * \brief Update the target and fallback device required for heterogeneous * compilation. CPU is used as the fallback device if it wasn't provided. * Meanwhile, a CPU device type and "llvm" pair will be added to the target * dictionary in this case. * - * \param targets dictionary * \param fallback_device The fallback device for heterogeneous execution. - * \return Map */ - TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets, - int fallback_device) { - TargetsMap device_target = targets; + void UpdateHeterogeneousInputs(int fallback_device) { std::unordered_map tmp_map; - for (const auto& kv : targets) { + for (const auto& kv : targets_) { tmp_map[kv.first->value] = kv.second; } if (tmp_map.count(fallback_device) == 0) { - device_target.Set( - fallback_device, - CreateDefaultTarget(fallback_device)); + targets_.Set(fallback_device, CreateDefaultTarget(fallback_device)); } - return device_target; } + /*! * \brief Execute the device annotation passes to update the input program and * target information. * - * \param func The input Relay function. + * \param relay_module The input Relay module. * \param fallback_device The fallback device for heterogeneous execution. - * \param targets_map_ptr The device type to `Target` map pointer. * - * \return func The updated function after device annotation. + * \return updated_module The updated module after device annotation. 
*/ - Function RunDeviceAnnotationPass(Function func, - int fallback_device, - TargetsMap* targets_map_ptr) { - relay::Module relay_module = relay::ModuleNode::FromExpr(func); + relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module, + int fallback_device) { + UpdateHeterogeneousInputs(fallback_device); auto rewrite = transform::RewriteAnnotatedOps(fallback_device); - relay_module = rewrite(relay_module); - CHECK(relay_module.defined()); - func = relay_module->Lookup(relay_module->entry_func->name_hint); - CHECK(func.defined()); - auto device_map = relay::CollectDeviceInfo(func); - if (device_map.size() == 0) { - auto annotation_map = relay::CollectDeviceAnnotationOps(func); - if (annotation_map.size() == 0) { - targets_map_ptr->Set(0, CreateDefaultTarget(fallback_device)); + auto updated_module = rewrite(relay_module); + CHECK(updated_module.defined()); + + tvm::Map device_map; + for (const auto& it : updated_module->functions) { + device_map = relay::CollectDeviceInfo(it.second); + if (!device_map.empty()) break; + } + + if (device_map.empty()) { + tvm::Map annotation_map; + for (const auto& it : relay_module->functions) { + annotation_map = relay::CollectDeviceAnnotationOps(it.second); + if (!annotation_map.empty()) break; + } + // None op is annotated but they are fallen back to the default device. + if (annotation_map.empty()) { + targets_.Set(0, CreateDefaultTarget(fallback_device)); } else { + // All ops are annotated to the same device type. int64_t dev_type = -1; for (auto kv : annotation_map) { dev_type = kv.second->value; @@ -416,10 +407,10 @@ class RelayBuildModule : public runtime::ModuleNode { << "found. Please check the " << "RewriteAnnotation pass."; } - targets_map_ptr->Set(0, CreateDefaultTarget(dev_type)); + targets_.Set(0, CreateDefaultTarget(dev_type)); } } - return func; + return updated_module; } /*! @@ -431,38 +422,27 @@ class RelayBuildModule : public runtime::ModuleNode { void BuildRelay( Function func, const std::unordered_map& params) { - transform::PassContext pass_ctx = PassContext::Current(); - - // convert - tvm_cfg_ = BuildConfig::Create(); - TargetsMap device_target; - if (targets_.size() > 1) { - device_target = - UpdateHeterogeneousInputs(targets_, pass_ctx->fallback_device); - } else { - device_target = targets_; - } - func = Optimize(func, targets_, params); - if (device_target.size() > 1) { - func = RunDeviceAnnotationPass(func, pass_ctx->fallback_device, - &device_target); + if (params.size()) { + func = BindParamsByName(func, params); } + // Perform Module->Module optimizations. relay::Module relay_module = relay::ModuleNode::FromExpr(func); - relay_module = transform::InferType()(relay_module); - relay_module = transform::FuseOps()(relay_module); - relay_module = transform::InferType()(relay_module); + relay_module = Optimize(relay_module, targets_, params); CHECK(relay_module.defined()); + // Get the updated function. func = relay_module->Lookup(relay_module->entry_func->name_hint); + // Generate code for the updated function. 
@@ -475,8 +455,6 @@ class RelayBuildModule : public runtime::ModuleNode {
   std::unordered_map<std::string, runtime::NDArray> params_;
   /*! \brief building output */
   BuildOutput ret_;
-  /*! \brief tvm building cfg */
-  BuildConfig tvm_cfg_;
 };
 
 runtime::Module RelayBuildCreate() {

From 4647df4c40d18c23f924e99c3e4b5ea6e827bf7b Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Fri, 31 May 2019 14:55:27 +0000
Subject: [PATCH 09/12] indentation

---
 src/relay/backend/build_module.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 2b8e168310d6..1368c2e9ebbf 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -371,7 +371,7 @@ class RelayBuildModule : public runtime::ModuleNode {
    * \return updated_module The updated module after device annotation.
    */
   relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module,
-                                      int fallback_device) {
+                                        int fallback_device) {
     UpdateHeterogeneousInputs(fallback_device);
     auto rewrite = transform::RewriteAnnotatedOps(fallback_device);
     auto updated_module = rewrite(relay_module);
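
Before the renaming patch that follows, one note on why the names matter: required passes are resolved by exact string match (see GetPass further down), so the names baked into CreateFunctionPass must agree with what Sequential looks up. A hedged sketch of the name-driven flow:

    from tvm import relay
    from tvm.relay import transform as _transform

    # SimplifyInference lists "InferType" as a required pass; Sequential
    # resolves that dependency by name before running it.
    seq = _transform.Sequential(
        passes=[_transform.SimplifyInference(), _transform.FoldConstant()],
        opt_level=2)

    x = relay.var("x", shape=(2, 2))
    mod = relay.Module.from_expr(relay.Function([x], relay.add(x, x)))
    mod = seq(mod)
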
""" - return _transform.PartialEval() + return _transform.PartialEvaluate() diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index 2077394ff2e5..d623393049a6 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -360,8 +360,8 @@ Pass AlterOpLayout() { [=](Function f, Module m, PassContext pc) { return Downcast(relay::alter_op_layout::AlterOpLayout(f)); }; - return CreateFunctionPass(pass_func, 3, "alter_op_layout", - {ir::StringImm::make("infer_type")}); + return CreateFunctionPass(pass_func, 3, "AlterOpLayout", + {ir::StringImm::make("InferType")}); } TVM_REGISTER_API("relay._transform.AlterOpLayout") diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index fab7ebac81b6..ff9e2304a3bc 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -71,8 +71,8 @@ Pass CanonicalizeOps() { [=](Function f, Module m, PassContext pc) { return Downcast(CanonicalizeOps(f)); }; - return CreateFunctionPass(pass_func, 3, "canonicalize_ops", - {ir::StringImm::make("infer_type")}); + return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", + {ir::StringImm::make("InferType")}); } TVM_REGISTER_API("relay._transform.CanonicalizeOps") diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 3d8ea29b8c5d..c95c1ddf8e16 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -365,8 +365,8 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) { [=](Function f, Module m, PassContext pc) { return Downcast(CombineParallelConv2D(f, min_num_branches)); }; - return CreateFunctionPass(pass_func, 4, "combine_parallel_conv2d", - {ir::StringImm::make("infer_type")}); + return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", + {ir::StringImm::make("InferType")}); } TVM_REGISTER_API("relay._transform.CombineParallelConv2D") diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 687608259d64..be6774564806 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -158,7 +158,7 @@ Pass DeadCodeElimination() { [=](Function f, Module m, PassContext pc) { return Downcast(DeadCodeElimination(f)); }; - return CreateFunctionPass(pass_func, 1, "dead_code_elimination", {}); + return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); } TVM_REGISTER_API("relay._transform.DeadCodeElimination") diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index bd0ba970695f..5d72e337f586 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -558,8 +558,8 @@ Pass RewriteAnnotatedOps(int fallback_device) { [=](Function f, Module m, PassContext pc) { return Downcast(RewriteAnnotatedOps(f, fallback_device)); }; - return CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", - {ir::StringImm::make("infer_type")}); + return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", + {ir::StringImm::make("InferType")}); } TVM_REGISTER_API("relay._transform.RewriteDeviceAnnotation") diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index 5921201986a8..883681adcaf4 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -95,8 +95,8 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) { [=](Function f, Module m, PassContext pc) { return Downcast(EliminateCommonSubexpr(f, fskip)); }; - return 
diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc
index 5921201986a8..883681adcaf4 100644
--- a/src/relay/pass/eliminate_common_subexpr.cc
+++ b/src/relay/pass/eliminate_common_subexpr.cc
@@ -95,8 +95,8 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) {
       [=](Function f, Module m, PassContext pc) {
         return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
       };
-  return CreateFunctionPass(pass_func, 3, "eliminate_common_subexpr",
-                            {ir::StringImm::make("infer_type")});
+  return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
+                            {ir::StringImm::make("InferType")});
 }
 
 TVM_REGISTER_API("relay._transform.EliminateCommonSubexpr")
diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc
index f782e8e846b5..815407038b08 100644
--- a/src/relay/pass/fold_constant.cc
+++ b/src/relay/pass/fold_constant.cc
@@ -223,7 +223,7 @@ Pass FoldConstant() {
       [=](Function f, Module m, PassContext pc) {
         return Downcast<Function>(FoldConstant(f));
       };
-  return CreateFunctionPass(pass_func, 2, "fold_constant", {});
+  return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
 }
 
 TVM_REGISTER_API("relay._transform.FoldConstant")
diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc
index 23fe391f30fd..7d47c3ba7eff 100644
--- a/src/relay/pass/fold_scale_axis.cc
+++ b/src/relay/pass/fold_scale_axis.cc
@@ -960,8 +960,8 @@ Pass ForwardFoldScaleAxis() {
         return Downcast<Function>(
             relay::fold_scale_axis::ForwardFoldScaleAxis(f));
       };
-  return CreateFunctionPass(pass_func, 3, "forward_fold_scale_axis",
-                            {ir::StringImm::make("infer_type")});
+  return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
+                            {ir::StringImm::make("InferType")});
 }
 
 Pass BackwardFoldScaleAxis() {
@@ -970,8 +970,8 @@ Pass BackwardFoldScaleAxis() {
         return Downcast<Function>(
             relay::fold_scale_axis::BackwardFoldScaleAxis(f));
       };
-  return CreateFunctionPass(pass_func, 3, "backward_fold_scale_axis",
-                            {ir::StringImm::make("infer_type")});
+  return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
+                            {ir::StringImm::make("InferType")});
 }
 
 Pass FoldScaleAxis() {
@@ -979,7 +979,7 @@ Pass FoldScaleAxis() {
   // register it as a sequential pass.
   Pass pass = Sequential(
       {FoldConstant(), BackwardFoldScaleAxis(), ForwardFoldScaleAxis()},
-      "fold_scale_axis");
+      "FoldScaleAxis");
   return pass;
 }
diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc
index 34b1d5b3a38c..8ad61270e33a 100644
--- a/src/relay/pass/forward_rewrite.cc
+++ b/src/relay/pass/forward_rewrite.cc
@@ -220,7 +220,7 @@ Pass ForwardRewrite(const std::string& rewrite_map_attr_name,
                                              fcontext,
                                              fmulti_ref_trigger));
       };
-  return CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
+  return CreateFunctionPass(pass_func, 1, "ForwardRewrite", {});
 }
 
 Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
@@ -233,7 +233,7 @@ Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
                                              fcontext,
                                              fmulti_ref_trigger));
       };
-  return CreateFunctionPass(pass_func, 1, "forward_rewrite_fun", {});
+  return CreateFunctionPass(pass_func, 1, "ForwardRewriteFunc", {});
 }
 
 }  // namespace transform
diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index adedbd2d3d24..9f940e54953b 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -974,8 +974,8 @@ Pass FuseOps(int fuse_opt_level) {
         int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
         return Downcast<Function>(FuseOps(f, opt_level, m));
       };
-  return CreateFunctionPass(pass_func, 1, "fuse_ops",
-                            {ir::StringImm::make("infer_type")});
+  return CreateFunctionPass(pass_func, 1, "FuseOps",
+                            {ir::StringImm::make("InferType")});
 }
 
 TVM_REGISTER_API("relay._transform.FuseOps")
diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc
index caf4d1d3bbd1..71ba7cd11cd5 100644
--- a/src/relay/pass/partial_eval.cc
+++ b/src/relay/pass/partial_eval.cc
@@ -806,10 +806,10 @@ Pass PartialEval() {
       [=](Function f, Module m, PassContext pc) {
         return Downcast<Function>(PartialEval(f));
       };
-  return CreateFunctionPass(pass_func, 1, "partial_eval", {});
+  return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
 }
 
-TVM_REGISTER_API("relay._transform.PartialEval")
+TVM_REGISTER_API("relay._transform.PartialEvaluate")
 .set_body_typed(PartialEval);
 
 }  // namespace transform
diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc
index dbd1496baffd..31f30b618fc7 100644
--- a/src/relay/pass/pass_manager.cc
+++ b/src/relay/pass/pass_manager.cc
@@ -42,33 +42,33 @@ namespace {
 // TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be
 // handled because we need to register the pass for Python invocation anyway.
 Pass GetPass(const std::string& pass_name) {
-  if (pass_name == "infer_type") {
+  if (pass_name == "InferType") {
     return InferType();
-  } else if (pass_name == "alter_op_layout") {
+  } else if (pass_name == "AlterOpLayout") {
     return AlterOpLayout();
-  } else if (pass_name == "canonicalize_ops") {
+  } else if (pass_name == "CanonicalizeOps") {
     return CanonicalizeOps();
-  } else if (pass_name == "combine_parallel_conv2d") {
+  } else if (pass_name == "CombineParallelConv2d") {
     return CombineParallelConv2D();
-  } else if (pass_name == "dead_code_elimination") {
+  } else if (pass_name == "DeadCodeElimination") {
     return DeadCodeElimination();
-  } else if (pass_name == "eliminate_common_subexpr") {
+  } else if (pass_name == "EliminateCommonSubexpr") {
     return EliminateCommonSubexpr(nullptr);
-  } else if (pass_name == "fold_constant") {
+  } else if (pass_name == "FoldConstant") {
     return FoldConstant();
-  } else if (pass_name == "backward_fold_scale_axis") {
+  } else if (pass_name == "BackwardFoldScaleAxis") {
     return BackwardFoldScaleAxis();
-  } else if (pass_name == "forward_fold_scale_axis") {
+  } else if (pass_name == "ForwardFoldScaleAxis") {
     return ForwardFoldScaleAxis();
-  } else if (pass_name == "fold_scale_axis") {
+  } else if (pass_name == "FoldScaleAxis") {
     return FoldScaleAxis();
-  } else if (pass_name == "partial_eval") {
+  } else if (pass_name == "PartialEvaluate") {
     return PartialEval();
-  } else if (pass_name == "simplify_inference") {
+  } else if (pass_name == "SimplifyInference") {
     return SimplifyInference();
-  } else if (pass_name == "to_a_normal_form") {
+  } else if (pass_name == "ToANormalForm") {
     return ToANormalForm();
-  } else if (pass_name == "to_graph_normal_form") {
+  } else if (pass_name == "ToGraphNormalForm") {
     return ToGraphNormalForm();
   } else {
     LOG(WARNING) << pass_name << " has not been registered yet."
                  << "\n";
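
With the rename, the Python wrapper, the C++ registration, and this GetPass table all agree on "PartialEvaluate". A quick usage sketch:

    from tvm import relay
    from tvm.relay import transform as _transform

    one = relay.const(1.0)
    func = relay.Function([], relay.add(one, one))
    mod = relay.Module.from_expr(func)
    # The fully static add is evaluated at compile time.
    mod = _transform.PartialEvaluate()(mod)
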
<< "\n"; diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index d0e866d9a6f8..f861e9860a59 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -113,8 +113,8 @@ Pass SimplifyInference() { [=](Function f, Module m, PassContext pc) { return Downcast(SimplifyInference(f)); }; - return CreateFunctionPass(pass_func, 0, "simplify_inference", - {ir::StringImm::make("infer_type")}); + return CreateFunctionPass(pass_func, 0, "SimplifyInference", + {ir::StringImm::make("InferType")}); } TVM_REGISTER_API("relay._transform.SimplifyInference") diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 313d1905f767..324eddd21c5c 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -340,7 +340,7 @@ Pass ToANormalForm() { [=](Function f, Module m, PassContext pc) { return Downcast(ToANormalForm(f, m)); }; - return CreateFunctionPass(pass_func, 1, "to_a_normal_form", {}); + return CreateFunctionPass(pass_func, 1, "ToANormalForm", {}); } TVM_REGISTER_API("relay._transform.ToANormalForm") diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index 7cd4431c8f53..9c166f98c1a5 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -86,7 +86,7 @@ Pass ToGraphNormalForm() { [=](Function f, Module m, PassContext pc) { return Downcast(ToGraphNormalForm(f)); }; - return CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {}); + return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {}); } TVM_REGISTER_API("relay._transform.ToGraphNormalForm") diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 54df9af7d71e..3fde3c7e7b36 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -816,7 +816,7 @@ Pass InferType() { [=](Function f, Module m, PassContext pc) { return Downcast(InferType(f, m)); }; - return CreateFunctionPass(pass_func, 0, "infer_type", {}); + return CreateFunctionPass(pass_func, 0, "InferType", {}); } TVM_REGISTER_API("relay._transform.InferType") From 6f83f5686f8dd2bf6125ef0786dae57fd83cb3cb Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 31 May 2019 20:35:10 +0000 Subject: [PATCH 11/12] log fatal for unregistered passes --- src/relay/pass/pass_manager.cc | 14 +++----------- tests/python/relay/test_pass_manager.py | 10 +++++++--- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 31f30b618fc7..13e908d28f7a 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -71,7 +71,7 @@ Pass GetPass(const std::string& pass_name) { } else if (pass_name == "ToGraphNormalForm") { return ToGraphNormalForm(); } else { - LOG(WARNING) << pass_name << " has not been registered yet." << "\n"; + LOG(FATAL) << pass_name << " has not been registered yet." 
<< "\n"; return Pass(nullptr); } } @@ -441,7 +441,7 @@ std::unordered_set SequentialNode::DisabledPasses( std::unordered_set ret; for (const auto& it : disabled) { const auto* str = it.as(); - CHECK(str) << "disabled passes must be string."; + CHECK(str) << "Disabled pass name must be string."; ret.emplace(str->value); } return ret; @@ -452,7 +452,7 @@ std::unordered_set SequentialNode::RequiredPasses( std::unordered_set ret; for (const auto& it : required) { const auto* str = it.as(); - CHECK(str) << "disabled passes must be string."; + CHECK(str) << "Required pass name must be string."; ret.emplace(str->value); } return ret; @@ -473,14 +473,6 @@ bool SequentialNode::PassEnabled(const std::string& pass_name) const { } const Pass pass = GetPass(pass_name); - - if (!pass.defined()) { - LOG(WARNING) << pass_name - << " is not registered yet, it will be forced to execute." - << "\n"; - return true; - } - PassInfo info = pass->Info(); return ctx->opt_level >= info->opt_level; } diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 85e6e051bfc8..7fdef3fa8b9c 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -327,7 +327,8 @@ def test_no_pass(): def test_only_module_pass(): passes = [module_pass] sequential = _transform.Sequential(opt_level=1, passes=passes) - ret_mod = sequential(mod) + with relay.build_config(required_pass=["mod_transform"]): + ret_mod = sequential(mod) # Check the subtract function. sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint) check_func(new_sub, sub) @@ -341,7 +342,8 @@ def test_only_function_pass(): # Check the subtract function. passes = [function_pass] sequential = _transform.Sequential(opt_level=1, passes=passes) - ret_mod = sequential(mod) + with relay.build_config(required_pass=["func_transform"]): + ret_mod = sequential(mod) _, new_sub = extract_var_func(ret_mod, v_sub.name_hint) check_func(new_sub, get_ref_sub()) @@ -355,7 +357,9 @@ def test_multiple_passes(): mod = relay.Module({v_sub: sub, v_log: log}) passes = [module_pass, function_pass] sequential = _transform.Sequential(opt_level=1, passes=passes) - ret_mod = sequential(mod) + required = ["mod_transform", "func_transform"] + with relay.build_config(required_pass=required): + ret_mod = sequential(mod) # Check the abs function is added. abs_var, abs_func = get_var_func() From 940eb9681495c97d41faadfadf98348c0fba137b Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sat, 1 Jun 2019 15:38:35 +0000 Subject: [PATCH 12/12] fold constant --- src/relay/backend/build_module.cc | 1 + src/relay/pass/fold_scale_axis.cc | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 1368c2e9ebbf..e0014e919089 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -297,6 +297,7 @@ class RelayBuildModule : public runtime::ModuleNode { }); pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); pass_seqs.push_back(transform::CombineParallelConv2D(3)); + pass_seqs.push_back(transform::FoldConstant()); pass_seqs.push_back(transform::FoldScaleAxis()); pass_seqs.push_back(transform::CanonicalizeOps()); diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index 7d47c3ba7eff..53089807ace5 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -978,7 +978,7 @@ Pass FoldScaleAxis() { // FoldScaleAxis pass contains the following three passes. 
diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc
index 7d47c3ba7eff..53089807ace5 100644
--- a/src/relay/pass/fold_scale_axis.cc
+++ b/src/relay/pass/fold_scale_axis.cc
@@ -978,7 +978,7 @@ Pass FoldScaleAxis() {
   // FoldScaleAxis pass contains the following three passes. Therefore, we can
   // register it as a sequential pass.
   Pass pass = Sequential(
-      {FoldConstant(), BackwardFoldScaleAxis(), ForwardFoldScaleAxis()},
+      {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()},
       "FoldScaleAxis");
   return pass;
 }
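
And a sketch of the second half, assuming FoldScaleAxis is exposed in python/tvm/relay/transform.py: with FoldConstant now last inside the sequential, the folded weights come out as plain constant tensors rather than constant expressions:

    import numpy as np
    from tvm import relay
    from tvm.relay import transform as _transform

    x = relay.var("x", shape=(1, 3, 8, 8))
    weight = relay.const(np.ones((8, 3, 3, 3), dtype="float32"))
    scale = relay.const(np.full((3, 1, 1), 2.0, dtype="float32"))
    conv = relay.nn.conv2d(relay.multiply(x, scale), weight,
                           kernel_size=(3, 3), channels=8)
    mod = relay.Module.from_expr(relay.Function([x], conv))

    with relay.build_config(opt_level=3):
        mod = _transform.FoldScaleAxis()(mod)
    # The multiply by `scale` is gone and conv2d reads a single
    # pre-scaled constant weight.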