Skip to content

Commit

Permalink
[OP] Experimental assign op (apache#389)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent 74a3f74 commit 6a7d1c4
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 5 deletions.
6 changes: 6 additions & 0 deletions nnvm/python/nnvm/top/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ def _compute(attrs, x, _):
_fschedule_broadcast = _fschedule_injective
_fschedule_elemwise = _fschedule_injective

# Assign requires special treatment in the compiler
# The compute and schedule are designed as
# copy from rhs to output
reg.register_pattern("_assign", OpPattern.OPAQUE)
reg.register_schedule("_assign", _fschedule_broadcast)

# copy
reg.register_pattern("copy", OpPattern.ELEMWISE)
reg.register_schedule("copy", _fschedule_broadcast)
Expand Down
92 changes: 89 additions & 3 deletions nnvm/src/compiler/graph_fuse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,48 @@ NNVM_REGISTER_PASS(GraphFusePartition)
.depend_graph_attr("shape")
.depend_graph_attr("dtype");


// Decorate the result of PlanMemory
// This function does two things:
// - Give separate memory to each variable
// - Tie the memory of output/lhs in assign node properly
// so the execution of assign can have side effect.
nnvm::Graph DecorateMemoryPlan(
nnvm::Graph g,
const std::vector<int>& assign_flag) {
// setup ref counter
const IndexedGraph& idx = g.indexed_graph();
StorageVector storage_vec = g.MoveCopyAttr<StorageVector>("storage_id");
g.attrs.erase("storage_allocated_bytes");
g.attrs.erase("storage_inplace_index");
size_t num_not_allocated = g.MoveCopyAttr<size_t>(
"storage_num_not_allocated");
CHECK_EQ(num_not_allocated, 0U)
<< "Can only build inference graph with all statically allocated memory";

// reassign variable id so that they are different.
int max_id = 0;
for (size_t i = 0; i < storage_vec.size(); ++i) {
max_id = std::max(storage_vec[i] + 1, max_id);
}
for (uint32_t nid : idx.input_nodes()) {
storage_vec[idx.entry_id(nid, 0)] = max_id++;
}
// tie up the assign node storage properly
for (uint32_t nid = 0 ; nid < idx.num_nodes(); ++nid) {
if (assign_flag[nid] == 0) continue;
const auto& inode = idx[nid];
int var_storage_id = storage_vec[idx.entry_id(inode.inputs[0])];
storage_vec[idx.entry_id(nid, 0)] = var_storage_id;

if (assign_flag[nid] == 2) {
storage_vec[idx.entry_id(inode.inputs[1])] = var_storage_id;
}
}
g.attrs["storage_id"] = std::make_shared<any>(std::move(storage_vec));
return g;
}

struct INodeEntryHash {
size_t operator()(const IndexedGraph::NodeEntry& e) const {
return e.node_id;
Expand Down Expand Up @@ -218,8 +260,12 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
g.GetAttr<std::vector<TOpPattern> >("pattern");
std::string target = g.GetAttr<std::string>("target");
std::string target_host;
if (g.HasAttr("target_host"))

if (g.HasAttr("target_host")) {
target_host = g.GetAttr<std::string>("target_host");
}
// specially handle assign
const nnvm::Op* assign_op = nnvm::Op::Get("_assign");

std::vector<FuseEntry> fuse_vec(idx.num_nodes());
// setup inputs and placeholder.
Expand All @@ -229,7 +275,8 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
CHECK_GE(group_vec[nid], 0);
int root_id = group_vec[nid];
FuseEntry& fe = fuse_vec[root_id];
fe.flatten_data = (pattern_vec[root_id] == kElemWise);
fe.flatten_data = (pattern_vec[root_id] == kElemWise ||
inode.source->op() == assign_op);
for (const auto& e : inode.inputs) {
if (group_vec[e.node_id] != root_id && fe.imap.count(e) == 0) {
Array<Expr> shape;
Expand Down Expand Up @@ -331,8 +378,9 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
}
}
}
// Rebuild the fused graph

const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op");

std::unordered_map<uint32_t, nnvm::NodePtr> old_new;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
Expand All @@ -345,6 +393,8 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
}
int root_id = group_vec[nid];
if (static_cast<int>(nid) != root_id) continue;

// Handle normal op
FuseEntry& fe = fuse_vec[root_id];
const IndexedGraph& subidx = fe.subgraph.indexed_graph();
nnvm::NodePtr np = nnvm::Node::Create();
Expand Down Expand Up @@ -385,13 +435,48 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
nnvm::NodeEntry{it->second, e.index, e.version});
}

// Reference counter of each op node
// For now, always store result when an op is referred more than once.
std::vector<uint32_t> ref_count = GetNodeRefCounts(idx);
for (const auto& e : idx.outputs()) {
// this line will realize all the outputs
ref_count[e.node_id] += 1;
}

const IndexedGraph& new_idx = ret.indexed_graph();

// Handling assign:
//
// assign is a special operator that mutates the variable.
// Currently assign is implemented as output = copy(input[1])
// Then we run DecorageMemoryPlan to force
// output.storage = input[0].storage
//
std::vector<int> assign_flag(new_idx.num_nodes(), 0);
ShapeVector new_shape_vec = ShapeVector(new_idx.num_node_entries(), TShape());
DTypeVector new_dtype_vec = DTypeVector(new_idx.num_node_entries());
std::vector<std::string> new_dltype_vec(new_idx.num_node_entries());

for (const auto& kv : old_new) {
uint32_t nid = kv.first;
const auto& inode = idx[nid];
uint32_t new_nid = new_idx.node_id(kv.second.get());
if (inode.source->op() == assign_op) {
// Check if rhs of assign can be comute inplace
// If yes, we can simply set that memory to be assign target
// and change assign to nop
const IndexedGraph::NodeEntry& rhs = inode.inputs[1];
if (ref_count[rhs.node_id] <= 1 &&
!(idx[rhs.node_id].source->is_variable()) &&
pattern_vec[group_vec[rhs.node_id]] <= kBroadcast) {
assign_flag[new_nid] = 2;
TVMOpParam& param = dmlc::get<TVMOpParam>(kv.second->attrs.parsed);
param.func_name = "__nop";
param.UpdateDict(&(kv.second->attrs.dict));
} else {
assign_flag[new_nid] = 1;
}
}
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
uint32_t new_eid = new_idx.entry_id(new_idx.node_id(kv.second.get()), i);
uint32_t old_eid = idx.entry_id(nid, i);
Expand All @@ -409,6 +494,7 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
tvm::runtime::Module module = fbuild(func_list, target, target_host);
ret.attrs["module"] = std::make_shared<any>(std::move(module));
ret = nnvm::ApplyPass(ret, "PlanMemory");
ret = DecorateMemoryPlan(ret, assign_flag);
return ret;
}

Expand Down
6 changes: 4 additions & 2 deletions nnvm/src/pass/plan_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,11 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
for (auto rit = eids.rbegin(); rit != eids.rend(); ++rit) {
uint32_t eid = rit->second;
auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid);
storage_ref_count[sid] = entry_ref_count[eid];
if (sid >= 0) {
storage_ref_count[sid] = entry_ref_count[eid];
}
storage[eid] = sid;
}

// check if certain inputs is ignored.
static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs");
std::vector<uint32_t> ignore_inputs;
Expand Down Expand Up @@ -330,6 +331,7 @@ Graph PlanMemory(Graph ret) {
AllocMemory(ret, idx, node_range, &storage_vec, &storage_inplace_index,
ref_count, &allocator);
size_t storage_allocated_bytes = allocator.TotalAllocBytes();

// Choose the plan which leads to minimal memory usage
if (min_allocated_bytes > storage_allocated_bytes) {
ret.attrs["storage_id"] = std::make_shared<any>(std::move(storage_vec));
Expand Down
2 changes: 2 additions & 0 deletions nnvm/src/top/tensor/elemwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

namespace nnvm {
namespace top {

using namespace tvm;
using namespace nnvm::compiler;

// undefined op
NNVM_REGISTER_ELEMWISE_UNARY_OP(__undef__)
.describe(R"code(undefined op.
Expand Down
54 changes: 54 additions & 0 deletions nnvm/src/top/tensor/state_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*!
* Copyright (c) 2018 by Contributors
* \file state_op.cc
* \brief Experimental operators
* Currently we only support assign
*/
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/tensor.h>
#include <topi/elemwise.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"

namespace nnvm {
namespace top {

using namespace tvm;
using namespace nnvm::compiler;

NNVM_REGISTER_OP(_assign)
.describe(R"doc(Assign rhs to the lhs.
lhs must be a Variable.
This is an experimental operator.
)doc" NNVM_ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<FMutateInputs>(
"FMutateInputs", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0};
})
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
// This implementation is needed for the special
// logic handling assign in the compiler
// It simply copies the result of rhs the output
// The later decoration in compiler will change
// the memory assignment of assign to tie
// the lhs to the output.
return Array<Tensor>{ topi::identity(inputs[1]) };
})
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInplaceOption>(
"FInplaceOption", [](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{1, 0}};
});

} // namespace top
} // namespace nnvm
41 changes: 41 additions & 0 deletions nnvm/tests/python/compiler/test_top_assign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np

import tvm
from tvm.contrib import graph_runtime

import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list


def test_update():
w = sym.Variable("w")
w2 = sym.Variable("w2")
w = sym._assign(w, w + 1)
w2 = sym._assign(w2, w + 1)

dshape = (5, 3, 18, 18)
shape_dict = {"w": dshape, "w2":dshape}
dtype = "float32"

def check(target, ctx):
graph, lib, _ = nnvm.compiler.build(w2, target, shape_dict)

m = graph_runtime.create(graph, lib, ctx)

data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
m.set_input("w", data)
m.run()
out = m.get_input("w2", tvm.nd.empty(dshape, dtype))
np.testing.assert_allclose(out.asnumpy(), data.asnumpy() + 2, rtol=1e-5)

m.run()
out = m.get_input("w2", tvm.nd.empty(dshape, dtype))
np.testing.assert_allclose(out.asnumpy(), data.asnumpy() + 3, rtol=1e-5)

for target, ctx in ctx_list():
check(target, ctx)


if __name__ == "__main__":
test_update()

0 comments on commit 6a7d1c4

Please sign in to comment.