Skip to content

Commit

Permalink
Enable control deps in API (#55)
Browse files Browse the repository at this point in the history
* [SYMBOL] support control deps in API

* enable more generic tuple list

* fix

* fix
  • Loading branch information
tqchen committed May 29, 2018
1 parent 5575338 commit 3a469d5
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 8 deletions.
7 changes: 7 additions & 0 deletions nnvm/include/nnvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out);
NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols,
SymbolHandle *symbols,
SymbolHandle *out);
/*!
* \brief Add src_dep to the handle as control dep.
* \param handle The symbol to add dependency edges on.
* \param src_dep the source handles.
*/
NNVM_DLL int NNAddControlDeps(SymbolHandle handle,
SymbolHandle src_dep);
/*!
* \brief Free the symbol handle.
* \param symbol the symbol
Expand Down
6 changes: 4 additions & 2 deletions nnvm/include/nnvm/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,10 @@ inline Op& Op::set_attr( // NOLINT(*)
std::vector<std::pair<ValueType, int> >& vec =
nnvm::get<OpMap<ValueType> >(*pmap).data_;
// resize the value type.
vec.resize(index_ + 1,
std::make_pair(ValueType(), 0));
if (vec.size() <= index_) {
vec.resize(index_ + 1,
std::make_pair(ValueType(), 0));
}
std::pair<ValueType, int>& p = vec[index_];
CHECK(p.second == 0)
<< "Attribute " << attr_name
Expand Down
8 changes: 4 additions & 4 deletions nnvm/include/nnvm/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class Tuple {
return is;
}
is.get();
if (ch == '(') break;
if (ch == '(' || ch == '[') break;
if (!isspace(ch)) {
is.setstate(std::ios::failbit);
return is;
Expand All @@ -250,13 +250,13 @@ class Tuple {
if (isspace(ch)) {
is.get(); continue;
}
if (ch == ')') {
if (ch == ')' || ch == ']') {
is.get(); break;
}
break;
}
if (ch == ')') break;
} else if (ch == ')') {
if (ch == ')' || ch == ']') break;
} else if (ch == ')' || ch == ']') {
break;
} else {
is.setstate(std::ios::failbit);
Expand Down
16 changes: 16 additions & 0 deletions nnvm/python/nnvm/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,22 @@ def debug_str(self):
self.handle, _ctypes.byref(debug_str)))
return _base.py_str(debug_str.value)

def _add_control_deps(self, deps):
"""Add control flow dependencies.
This makes current op depend on the deps.
Only use when necessary,
this function mutate the current symbol node.
Returns
-------
deps : Symbol for list of symbol
The dependencies
"""
if isinstance(deps, list):
deps = Group(deps)
_check_call(_LIB.NNAddControlDeps(
self.handle, deps.handle))


def Variable(name, **kwargs):
"""Create a symbolic variable with specified name.
Expand Down
10 changes: 9 additions & 1 deletion nnvm/src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,22 @@ int NNGetOpHandle(const char* op_name,
}

int NNListUniqueOps(nn_uint *out_size,
OpHandle **out_array) {
OpHandle **out_array) {
API_BEGIN();
auto &vec = dmlc::Registry<Op>::List();
*out_size = static_cast<nn_uint>(vec.size());
*out_array = (OpHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*)
API_END();
}

int NNAddControlDeps(SymbolHandle handle,
SymbolHandle src_dep) {
API_BEGIN();
static_cast<Symbol*>(handle)->AddControlDeps(
*static_cast<Symbol*>(src_dep));
API_END();
}

int NNGetOpInfo(OpHandle handle,
const char **name,
const char **description,
Expand Down
3 changes: 2 additions & 1 deletion nnvm/tests/python/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ def test_infer_shape_known_partial():


def test_infer_type():
x = sym.Variable('x')
x = sym.Variable('x', dtype=0)
y = sym.add(x, x, name='add1')
y = sym.cast(y, dtype=1, name="cast1")
g = graph.create(y)
g._set_json_attr("dtype_attr_key", "dtype")
g = g.apply('InferType')
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
jnodes = jgraph['nodes']
Expand Down
8 changes: 8 additions & 0 deletions nnvm/tests/python/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,16 @@ def test_copy():
name='exp', gpu=1, attr={"kk": "1"})
assert y.__copy__().debug_str() == y.debug_str()

def test_control_dep():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv')
z = sym.assign(x, y)
t = sym.add(x, x)
t._add_control_deps([z, y])

if __name__ == "__main__":
test_copy()
test_default_input()
test_compose()
test_mutate_input()
test_control_dep()

0 comments on commit 3a469d5

Please sign in to comment.