Skip to content

Commit

Permalink
Expose relay BindParamsByName to Python (apache#4751)
Browse files Browse the repository at this point in the history
* expose BindParamByName to python

* fixed alpha equal test
  • Loading branch information
masahi authored and zhiics committed Mar 2, 2020
1 parent 61b8461 commit 82f5986
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 45 deletions.
39 changes: 33 additions & 6 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ def _update_target(target):
return tgts


def _convert_param_map(params):
inputs = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
param = _nd.array(param)
inputs[name] = _expr.const(param)
return inputs


class BuildModule(object):
"""Build a Relay function to run on TVM graph runtime. This class is used
to expose the `RelayBuildModule` APIs implemented in C++.
Expand Down Expand Up @@ -151,12 +160,7 @@ def optimize(self, func, target=None, params=None):


def _set_params(self, params):
inputs = {}
for name, param in params.items():
if isinstance(param, np.ndarray):
param = _nd.array(param)
inputs[name] = _expr.const(param)
self._set_params_func(inputs)
self._set_params_func(_convert_param_map(params))

def get_json(self):
"""Return the json file of the built program."""
Expand Down Expand Up @@ -296,6 +300,29 @@ def optimize(mod, target=None, params=None):
return mod, params


def bind_params_by_name(func, params):
"""Bind params to function by name.
This could be useful when assembling custom Relay optimization
passes that involve constant folding.
Parameters
----------
func : relay.Function
The function to bind parameters to.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
func : relay.Function
The function with parameters bound
"""
inputs = _convert_param_map(params)
return _build_module.BindParamsByName(func, inputs)


class GraphExecutor(_interpreter.Executor):
"""Wrapper around Executor interface.
Expand Down
86 changes: 47 additions & 39 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,43 @@ using tir::LoweredFunc;
using TargetsMap = Map<tvm::Integer, tvm::Target>;
using namespace tvm::relay::transform;

/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay::Function BindParamsByName(relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
for (auto arg : func->params) {
const auto& name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}

std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
for (auto& kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}

/*!
* \brief Output of building module
*
Expand Down Expand Up @@ -248,45 +285,6 @@ class RelayBuildModule : public runtime::ModuleNode {
}

protected:
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay::Function BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
for (auto arg : func->params) {
const auto &name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}

std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
for (auto &kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined())
<< "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}

/*!
* \brief Optimize a Relay Function.
*
Expand Down Expand Up @@ -522,6 +520,16 @@ TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
*rv = RelayBuildCreate();
});

TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Map<std::string, Constant> params = args[1];
std::unordered_map<std::string, runtime::NDArray> params_;
for (const auto& kv : params) {
params_[kv.first] = kv.second->data;
}
*rv = BindParamsByName(args[0], params_);
});

} // namespace backend
} // namespace relay
} // namespace tvm
44 changes: 44 additions & 0 deletions tests/python/relay/test_pass_fold_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.testing import run_infer_type, create_workload


def run_opt_pass(expr, opt_pass):
Expand Down Expand Up @@ -161,10 +163,52 @@ def expected():
assert relay.analysis.graph_equal(zz, zexpected)


def test_fold_batch_norm():
def expected():
data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.const(np.zeros((16, 3, 3, 3)))
bias = relay.const(np.zeros((16, 1, 1)))
conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
channels=16, padding=(1, 1))
add = relay.add(conv, bias)
return relay.Function(relay.analysis.free_vars(add), add)

remove_bn_pass = transform.Sequential([
relay.transform.InferType(),
relay.transform.SimplifyInference(),
relay.transform.FoldConstant(),
relay.transform.FoldScaleAxis(),
])

data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("weight")
bn_gamma = relay.var("bn_gamma")
bn_beta = relay.var("bn_beta")
bn_mmean = relay.var("bn_mean")
bn_mvar = relay.var("bn_var")

conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
channels=16, padding=(1, 1))
bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta,
bn_mmean, bn_mvar)
def initializer(_, param):
param = np.zeros(param.shape)

mod, params = create_workload(bn_output[0], initializer)
mod["main"] = bind_params_by_name(mod["main"], params)

with relay.build_config(opt_level=3):
mod = remove_bn_pass(mod)

expect = run_infer_type(expected())
assert relay.analysis.graph_equal(mod["main"], expect)


if __name__ == "__main__":
test_fold_const()
test_fold_let()
test_fold_tuple()
test_fold_concat()
test_fold_shape_of()
test_fold_full()
test_fold_batch_norm()

0 comments on commit 82f5986

Please sign in to comment.