Skip to content

Commit

Permalink
Done
Browse files Browse the repository at this point in the history
  • Loading branch information
wweic committed Nov 5, 2019
1 parent dd0f1d5 commit b901e86
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 30 deletions.
63 changes: 44 additions & 19 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
params_[name] = data_in;
}

class BindParams : public ExprMutator {
class BindParams {
public:
explicit BindParams(const std::unordered_map<std::string, runtime::NDArray>& params)
: params_(params) {}
Expand Down Expand Up @@ -891,29 +891,41 @@ class BindParams : public ExprMutator {
}


Expr VisitExpr_(const FunctionNode* op) {
auto f = GetRef<Function>(op);
auto it = visited.find(f);
if (it != visited.end()) {
return it->second;
}
/*
auto func = ExprMutator::VisitExpr(f);
auto res = BindParamsByName(Downcast<Function>(func), params_);
visited[f] = res;
return res;
}
Function FunctionNode::make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> type_params,
tvm::Attrs attrs) {
NodePtr<FunctionNode> n = make_node<FunctionNode>();
*/
// Expr VisitExpr_(const FunctionNode* op) {
// auto f = GetRef<Function>(op);
// auto it = visited.find(f);
// if (it != visited.end()) {
// return it->second;
// }

// std::cout << "Bind Visit " << AsText(f, false) << "\n";
// auto func_body = ExprMutator::VisitExpr(f->body);
// auto new_f = FunctionNode::make(FreeVars(func_body), func_body, f->ret_type, f->type_params, f->attrs);
// auto res = BindParamsByName(new_f, params_);
// visited[f] = res;
// return res;
// }

private:
std::unordered_map<std::string, runtime::NDArray> params_;
std::unordered_map<Expr, Expr, NodeHash> visited;
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> visited;
};

relay::Function VMCompiler::BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::cout << "BindParams " << func << "\n";
return Downcast<Function>(BindParams(params).Mutate(func));
std::cout << "BindParams " << AsText(func, false) << "\n";
return Downcast<Function>(BindParams(params).BindParamsByName(func, params));
}

void VMCompiler::Compile(Module mod,
Expand All @@ -923,11 +935,24 @@ void VMCompiler::Compile(Module mod,
<< "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
auto functions = mod->functions;
for (auto p : functions) {
std::cout << "Binding " << p.first << "\n";
auto f = BindParamsByName(p.second, params_);
mod->Add(p.first, f);

// for (auto p : functions) {
{
// std::cout << "Binding " << p.first << "\n";
auto gvar = mod->GetGlobalVar("foreach");
auto f = BindParamsByName(mod->Lookup("foreach"), params_);
std::cout << "After bind params " << AsText(f, false) << "\n";
mod->Add(gvar, f);
}

{
// std::cout << "Binding " << p.first << "\n";
auto gvar = mod->GetGlobalVar("main");
auto f = BindParamsByName(mod->Lookup("main"), params_);
std::cout << "After bind params " << AsText(f, false) << "\n";
mod->Add(gvar, f);
}

} else {
std::cout << "Params empty\n";
}
Expand Down
83 changes: 74 additions & 9 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -354,30 +354,93 @@ TVM_REGISTER_API("relay._analysis.post_order_visit")
});
});

Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map);

// Implement bind.
class ExprBinder : public ExprMutator, PatternMutator {
public:
explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
: args_map_(args_map) {
explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) {
for (auto it : args_map) {
args_map_[it.first->name_hint()] = it.second;
args_map2_[it.first] = it.second;
}
}

Expr BindTwice(Expr op) {
auto op1 = VisitExpr(op);
// auto op2 = VisitExpr(op1);
return op1;
}

Expr VisitExpr_(const LetNode* op) final {
CHECK(!args_map_.count(op->var))
CHECK(!args_map_.count(op->var->name_hint()))
<< "Cannot bind an internel variable in let";
return ExprMutator::VisitExpr_(op);
}

Expr VisitExpr_(const FunctionNode* op) final {
std::cout << "Visit Func " << AsText(GetRef<Function>(op), false) << "\n";
bool rebind = false;
Array<Var> new_params;
for (Var param : op->params) {
CHECK(!args_map_.count(param))
<< "Cannnot bind an internal function parameter";
if (args_map_.count(param->name_hint()) > 0) {
rebind = true;
} else {
new_params.push_back(param);
}
}
if (rebind) {
auto new_body = Bind(op->body, args_map2_);
std::cout << "Visit Func End " << AsText(GetRef<Function>(op), false) << "\n";
return FunctionNode::make(new_params, new_body, op->ret_type, op->type_params, op->attrs);
}
return ExprMutator::VisitExpr_(op);
}

Expr VisitExpr_(const CallNode* op) final {
std::cout << "Visit Call " << AsText(GetRef<Call>(op), false) << "\n";
auto func = op->op.as<FunctionNode>();
auto gv_node = op->op.as<GlobalVarNode>();
Array<Expr> new_args;
for (size_t i = 0; i < op->args.size(); ++i) {
auto new_arg = ExprMutator::VisitExpr(op->args[i]);
auto n = op->args[i].as<VarNode>();
if (n) {
auto var = GetRef<Var>(n);
std::cout << "call arg " << i << ": " << var << "\n";
auto it = args_map_.find(var->name_hint());
if (it != args_map_.end()) {
std::cout << "Line 404\n";
if (func) {
args_map_.emplace(func->params[i]->name_hint(), it->second);
args_map2_.emplace(func->params[i], it->second);
std::cout << "Emplace " << func->params[i] << "\n";
} else if (gv_node) {
// new_args.push_back(new_arg);
} else {
std::cout << "Line 409 " << op->op << "\n";
new_args.push_back(new_arg);
}
} else {
new_args.push_back(new_arg);
}
} else {
new_args.push_back(new_arg);
}
}

if (func) {
auto new_f = ExprMutator::VisitExpr(op->op);
std::cout << "Visit Call End " << AsText(GetRef<Call>(op), false) << "\n";
return CallNode::make(new_f, new_args, op->attrs, op->type_args);
} else {
return CallNode::make(op->op, new_args, op->attrs, op->type_args);
}
}

Expr VisitExpr_(const VarNode* op) final {
auto id = GetRef<Var>(op);
auto it = args_map_.find(id);
auto it = args_map_.find(id->name_hint());
if (it != args_map_.end()) {
return (*it).second;
} else {
Expand All @@ -395,13 +458,15 @@ class ExprBinder : public ExprMutator, PatternMutator {
}

Var VisitVar(const Var& v) final {
CHECK(!args_map_.count(v))
CHECK(!args_map_.count(v->name_hint()))
<< "Cannnot bind an internal pattern variable";
return v;
}

private:
const tvm::Map<Var, Expr>& args_map_;
std::unordered_map<std::string, Expr> args_map_;
std::unordered_map<Var, Expr, NodeHash, NodeEqual> args_map2_;
//const tvm::Map<Var, Expr>& args_map_;
};

Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
Expand Down Expand Up @@ -439,7 +504,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
return std::move(ret);
} else {
return ExprBinder(args_map).VisitExpr(expr);
return ExprBinder(args_map).BindTwice(expr);
}
}

Expand Down
7 changes: 5 additions & 2 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm", params=None):
exe = relay.vm.compile(mod, target, params=params)
vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
vm.set_inputs('main', *args, **params)
vm.set_inputs('main', *args)
print("Running")
ret = vm.invoke("main", *args)
return ret
Expand Down Expand Up @@ -671,6 +671,9 @@ def helper_rnn(cell_type):
for name, _ in params:
inputs.append(relay.var(name))
mod['main'] = relay.Function(inputs, relay.Call(relay_net, inputs))
import os
print("pid {}".format(os.getpid()))
input("hi")

l = 5
data_v = np.random.rand(l, batch, 128).astype('float32')
Expand All @@ -680,8 +683,8 @@ def helper_rnn(cell_type):
# pdb.set_trace()
aas = [data_v] + states_v
res = veval(mod, *aas, params=params_new)
res = vmobj_to_list(res)
print("Relay result is {}".format(res))
result = _eval_vm(mod, tvm.cpu(), data_v, *states_v)

mx_inputs = [mx.nd.array(x) for x in [data_v, *states_v]]
mx_outputs = mx_net(*mx_inputs)
Expand Down

0 comments on commit b901e86

Please sign in to comment.