From 82f59861df19887b79655b9643ed8fc0b11250e6 Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 21 Jan 2020 07:32:22 +0900 Subject: [PATCH] Expose relay BindParamsByName to Python (#4751) * expose BindParamByName to python * fixed alpha equal test --- python/tvm/relay/build_module.py | 39 +++++++-- src/relay/backend/build_module.cc | 86 ++++++++++--------- tests/python/relay/test_pass_fold_constant.py | 44 ++++++++++ 3 files changed, 124 insertions(+), 45 deletions(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 28ce16b9b452..d848d9030c48 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -51,6 +51,15 @@ def _update_target(target): return tgts +def _convert_param_map(params): + inputs = {} + for name, param in params.items(): + if isinstance(param, np.ndarray): + param = _nd.array(param) + inputs[name] = _expr.const(param) + return inputs + + class BuildModule(object): """Build a Relay function to run on TVM graph runtime. This class is used to expose the `RelayBuildModule` APIs implemented in C++. @@ -151,12 +160,7 @@ def optimize(self, func, target=None, params=None): def _set_params(self, params): - inputs = {} - for name, param in params.items(): - if isinstance(param, np.ndarray): - param = _nd.array(param) - inputs[name] = _expr.const(param) - self._set_params_func(inputs) + self._set_params_func(_convert_param_map(params)) def get_json(self): """Return the json file of the built program.""" @@ -296,6 +300,29 @@ def optimize(mod, target=None, params=None): return mod, params +def bind_params_by_name(func, params): + """Bind params to function by name. + This could be useful when assembling custom Relay optimization + passes that involve constant folding. + + Parameters + ---------- + func : relay.Function + The function to bind parameters to. + + params : dict of str to NDArray + Input parameters to the graph that do not change + during inference time. Used for constant folding. + + Returns + ------- + func : relay.Function + The function with parameters bound + """ + inputs = _convert_param_map(params) + return _build_module.BindParamsByName(func, inputs) + + class GraphExecutor(_interpreter.Executor): """Wrapper around Executor interface. diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 480fd9e32074..7b9f1e222611 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -41,6 +41,43 @@ using tir::LoweredFunc; using TargetsMap = Map; using namespace tvm::relay::transform; +/*! + * \brief Bind params to function by using name + * \param func Relay function + * \param params params dict + * \return relay::Function + */ +relay::Function BindParamsByName(relay::Function func, + const std::unordered_map& params) { + std::unordered_map name_dict; + std::unordered_set repeat_var; + for (auto arg : func->params) { + const auto& name = arg->name_hint(); + if (name_dict.count(name)) { + repeat_var.insert(arg); + } else { + name_dict[name] = arg; + } + } + + std::unordered_map bind_dict; + for (auto& kv : params) { + if (name_dict.count(kv.first) == 0) { + continue; + } + auto arg = name_dict.at(kv.first); + if (repeat_var.count(arg)) { + LOG(FATAL) << "Multiple args in the function have name " << kv.first; + } + bind_dict[arg] = ConstantNode::make(kv.second); + } + Expr bound_expr = relay::Bind(func, bind_dict); + Function ret = Downcast(bound_expr); + CHECK(ret.defined()) << "The returning type is expected to be a Relay Function." + << "\n"; + return ret; +} + /*! * \brief Output of building module * @@ -248,45 +285,6 @@ class RelayBuildModule : public runtime::ModuleNode { } protected: - /*! - * \brief Bind params to function by using name - * \param func Relay function - * \param params params dict - * \return relay::Function - */ - relay::Function BindParamsByName( - relay::Function func, - const std::unordered_map& params) { - std::unordered_map name_dict; - std::unordered_set repeat_var; - for (auto arg : func->params) { - const auto &name = arg->name_hint(); - if (name_dict.count(name)) { - repeat_var.insert(arg); - } else { - name_dict[name] = arg; - } - } - - std::unordered_map bind_dict; - for (auto &kv : params) { - if (name_dict.count(kv.first) == 0) { - continue; - } - auto arg = name_dict.at(kv.first); - if (repeat_var.count(arg)) { - LOG(FATAL) << "Multiple args in the function have name " << kv.first; - } - bind_dict[arg] = ConstantNode::make(kv.second); - } - Expr bound_expr = relay::Bind(func, bind_dict); - Function ret = Downcast(bound_expr); - CHECK(ret.defined()) - << "The returning type is expected to be a Relay Function." - << "\n"; - return ret; - } - /*! * \brief Optimize a Relay Function. * @@ -522,6 +520,16 @@ TVM_REGISTER_GLOBAL("relay.build_module._BuildModule") *rv = RelayBuildCreate(); }); +TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName") +.set_body([](TVMArgs args, TVMRetValue* rv) { + Map params = args[1]; + std::unordered_map params_; + for (const auto& kv : params) { + params_[kv.first] = kv.second->data; + } + *rv = BindParamsByName(args[0], params_); +}); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 4752597c828e..ca901b16b842 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -18,6 +18,8 @@ import tvm from tvm import relay from tvm.relay import transform +from tvm.relay.build_module import bind_params_by_name +from tvm.relay.testing import run_infer_type, create_workload def run_opt_pass(expr, opt_pass): @@ -161,6 +163,47 @@ def expected(): assert relay.analysis.graph_equal(zz, zexpected) +def test_fold_batch_norm(): + def expected(): + data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32")) + weight = relay.const(np.zeros((16, 3, 3, 3))) + bias = relay.const(np.zeros((16, 1, 1))) + conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3), + channels=16, padding=(1, 1)) + add = relay.add(conv, bias) + return relay.Function(relay.analysis.free_vars(add), add) + + remove_bn_pass = transform.Sequential([ + relay.transform.InferType(), + relay.transform.SimplifyInference(), + relay.transform.FoldConstant(), + relay.transform.FoldScaleAxis(), + ]) + + data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32")) + weight = relay.var("weight") + bn_gamma = relay.var("bn_gamma") + bn_beta = relay.var("bn_beta") + bn_mmean = relay.var("bn_mean") + bn_mvar = relay.var("bn_var") + + conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3), + channels=16, padding=(1, 1)) + bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, + bn_mmean, bn_mvar) + def initializer(_, param): + param = np.zeros(param.shape) + + mod, params = create_workload(bn_output[0], initializer) + mod["main"] = bind_params_by_name(mod["main"], params) + + with relay.build_config(opt_level=3): + mod = remove_bn_pass(mod) + + expect = run_infer_type(expected()) + assert relay.analysis.graph_equal(mod["main"], expect) + + if __name__ == "__main__": test_fold_const() test_fold_let() @@ -168,3 +211,4 @@ def expected(): test_fold_concat() test_fold_shape_of() test_fold_full() + test_fold_batch_norm()