[TOP][COMPILER] Add expand_dims, change graph_compare to not compare input optionally (apache#25)
tqchen committed May 29, 2018
1 parent 40bc10f commit 2e9b6b9
Showing 17 changed files with 155 additions and 48 deletions.
11 changes: 11 additions & 0 deletions nnvm/include/nnvm/top/tensor.h
@@ -21,6 +21,17 @@ struct ConcatenateParam : public dmlc::Parameter<ConcatenateParam> {
}
};

struct ExpandDimsParam : public dmlc::Parameter<ExpandDimsParam> {
int axis;
int num_newaxis;
DMLC_DECLARE_PARAMETER(ExpandDimsParam) {
DMLC_DECLARE_FIELD(axis)
.describe("the axis to be expanded.");
DMLC_DECLARE_FIELD(num_newaxis).set_lower_bound(1).set_default(1)
.describe("Number of new axis to be inserted.");
}
};

struct SplitParam : public dmlc::Parameter<SplitParam> {
// numpy convention, only support indices, not support list.
Tuple<int> indices_or_sections;
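
For readers new to the operator: `axis` picks the insertion point and `num_newaxis` says how many size-1 dimensions are inserted there. A minimal numpy illustration of the intended semantics (numpy is used only for illustration; it is not part of this change):

    import numpy as np

    x = np.zeros((2, 3, 4))
    # axis=1, num_newaxis=2 behaves like applying numpy's expand_dims twice at axis 1
    y = np.expand_dims(np.expand_dims(x, 1), 1)
    print(y.shape)  # (2, 1, 1, 3, 4)
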
2 changes: 2 additions & 0 deletions nnvm/python/nnvm/compiler/build_module.py
@@ -2,6 +2,7 @@
"""Namespace for building operators."""
from __future__ import absolute_import as _abs

import logging
import tvm
from . import graph_attr, graph_util
from .. import graph as _graph
@@ -74,6 +75,7 @@ def build_config(**kwargs):
@tvm.register_func("nnvm.compiler.lower")
def _lower(sch, inputs, func_name):
f = tvm.lower(sch, inputs, name=func_name)
logging.debug("lower function %s", func_name)
return f if isinstance(
f, (tvm.container.Array, tuple, list)) else [f]

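
The new `logging.debug` call is silent unless the caller enables DEBUG-level logging; a sketch of how one might surface the messages (standard library `logging`, not part of the diff):

    import logging

    # Print "lower function <name>" for every schedule lowered during nnvm.compiler.build
    logging.basicConfig(level=logging.DEBUG)
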
9 changes: 7 additions & 2 deletions nnvm/python/nnvm/compiler/graph_util.py
@@ -59,7 +59,7 @@ def infer_dtype(graph, **dtype):

_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")

-def check_graph_equal(grapha, graphb):
+def check_graph_equal(grapha, graphb, compare_variable_attrs=False):
"""Check if two graphs have equal structure.
Parameters
@@ -70,11 +70,16 @@ def check_graph_equal(grapha, graphb):
graphb : Graph
The second graph
compare_variable_attrs : bool, optional
Whether we want to compare attributes(names) on variables.
Usually it is safe to skip it unless we want input name
to exactly match
Raises
------
ValueError
ValueError is raised with error message when graph not equal
"""
-err = _deep_compare(grapha, graphb)
+err = _deep_compare(grapha, graphb, compare_variable_attrs)
if err:
raise ValueError("Graph compare error: " + err)
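
With `compare_variable_attrs` defaulting to `False`, two graphs that differ only in variable names now compare equal, and passing `True` restores the strict behaviour. A usage sketch (the variable names here are made up):

    import nnvm.symbol as sym
    import nnvm.graph as graph
    from nnvm.compiler.graph_util import check_graph_equal

    # Identical structure, different input names.
    g1 = graph.create(sym.relu(sym.Variable("x")))
    g2 = graph.create(sym.relu(sym.Variable("y")))

    check_graph_equal(g1, g2)  # passes: variable names are ignored by default
    # check_graph_equal(g1, g2, compare_variable_attrs=True) would raise ValueError
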
1 change: 1 addition & 0 deletions nnvm/python/nnvm/testing/__init__.py
@@ -0,0 +1 @@
"""Utilities for testcase"""
12 changes: 12 additions & 0 deletions nnvm/python/nnvm/testing/config.py
@@ -0,0 +1,12 @@
"""Configuration about tests"""
import os
import tvm

def test_ctx_list():
"""Get context list for testcases"""
device_list = os.environ.get("NNVM_TEST_TARGETS", "")
device_list = (device_list.split(",") if device_list
else ["llvm", "cuda"])
device_list = set(device_list)
res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
return [x for x in res if x[1].exist and x[0] in device_list]
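
Test files are expected to iterate over this helper, so the device set can be narrowed without editing the tests; for example (usage sketch):

    # e.g. run CPU-only with: NNVM_TEST_TARGETS=llvm python test_top_level1.py
    from nnvm.testing.config import test_ctx_list

    for target, ctx in test_ctx_list():
        # Each entry is a (target string, TVMContext) pair whose device actually exists.
        print(target, ctx)
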
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/top/nn.py
@@ -90,7 +90,7 @@ def compute_conv2d(attrs, inputs, _):
raise ValueError("not support arbitrary group number for now")
if attrs.get_bool("use_bias"):
bias = inputs[2]
-bias = topi.broadcast_to(bias, (1, bias.shape[0], 1, 1))
+bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.broadcast_add(out, bias)
return out

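
The shape effect of this change: the old code broadcast the 1-D bias to `(1, C, 1, 1)`, while `expand_dims(bias, axis=1, num_newaxis=2)` produces `(C, 1, 1)` and leaves the batch dimension to `broadcast_add`. In numpy terms (illustration only, not the TOPI call itself):

    import numpy as np

    bias = np.zeros(16)                     # conv2d bias, shape (C,) = (16,)
    bias = bias[:, np.newaxis, np.newaxis]  # expand_dims(axis=1, num_newaxis=2) -> (16, 1, 1)
    out = np.zeros((8, 16, 32, 32))         # NCHW convolution output
    print((out + bias).shape)               # (8, 16, 32, 32); broadcasts over N, H, W
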
12 changes: 12 additions & 0 deletions nnvm/python/nnvm/top/tensor.py
@@ -115,6 +115,18 @@ def _compute(attrs, x, _):
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)

# pow_scalar
reg.register_compute("__pow_scalar__",
_compute_binary_scalar(tvm.power))
reg.register_pattern("__pow_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__pow_scalar__", _fschedule_broadcast)

# rpow_scalar
reg.register_compute("__rpow_scalar__",
_compute_binary_scalar(lambda x, y: tvm.power(y, x)))
reg.register_pattern("__rpow_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)

# elemwise_add
reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
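
The two new registrations differ only in operand order: `__pow_scalar__` keeps the tensor as the base, while `__rpow_scalar__` swaps the operands so the scalar becomes the base. In numpy terms (an illustration of the semantics, not the TVM compute):

    import numpy as np

    x = np.array([1.0, 2.0, 3.0])
    s = 3.0
    print(x ** s)  # __pow_scalar__  : [ 1.  8. 27.]
    print(s ** x)  # __rpow_scalar__ : [ 3.  9. 27.]
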
10 changes: 10 additions & 0 deletions nnvm/python/nnvm/top/transform.py
@@ -3,11 +3,21 @@
from __future__ import absolute_import

import tvm
import topi
from .tensor import _fschedule_broadcast
from ..compiler import registry as reg
from ..compiler import OpPattern

# Need add reshape, transpose
@reg.register_compute("expand_dims")
def compute_expand_dims(attrs, inputs, out_info):
"""Compute definition of expand_dims"""
return topi.expand_dims(
inputs[0], attrs.get_int("axis"),
num_newaxis=attrs.get_int("num_newaxis"))
reg.register_pattern("expand_dims", OpPattern.BROADCAST)
reg.register_schedule("expand_dims", _fschedule_broadcast)


def _flatten_index(indices, shape):
"""flatten the index to 1D"""
10 changes: 8 additions & 2 deletions nnvm/src/compiler/graph_deep_compare.cc
@@ -16,7 +16,9 @@ namespace compiler {
// not considering the graph attributes
// return non-empty error message if the graph mismatch.
// the comparator won't match name of intermediate node.
-std::string DeepCompare(Graph a, Graph b) {
+// compare_var_attr
+std::string DeepCompare(Graph a, Graph b,
+                        bool compare_variable_attr) {
const IndexedGraph& idxa = a.indexed_graph();
const IndexedGraph& idxb = b.indexed_graph();
std::ostringstream err;
@@ -51,6 +53,10 @@ std::string DeepCompare(Graph a, Graph b) {
err << "Node mismatch ";
return err.str();
}
if (anode.source->is_variable()) {
CHECK(bnode.source->is_variable());
if (!compare_variable_attr) continue;
}
AttrDict adict = GetAttrDict(anode.source->attrs);
AttrDict bdict = GetAttrDict(bnode.source->attrs);

@@ -107,7 +113,7 @@

TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
-*rv = DeepCompare(args[0], args[1]);
+*rv = DeepCompare(args[0], args[1], args[2]);
});
} // namespace compiler
} // namespace nnvm
17 changes: 8 additions & 9 deletions nnvm/src/compiler/simplify_inference.cc
@@ -58,16 +58,15 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
shift = MakeNode(
"elemwise_add", bn_name + "_add_beta", {shift, beta});
}
-// use broaodcast to reshape
-std::ostringstream oshape;
-for (dim_t i = 0; i < dshape.ndim(); ++i) {
-dshape[i] = (i != param.axis) ? 1 : -1;
+// use expand dims to pad lower dims to 1
+int num_pad_axis = static_cast<int>(dshape.ndim() - param.axis) - 1;
+if (num_pad_axis != 0) {
+std::unordered_map<std::string, std::string> kwargs{
+{"axis", std::to_string(param.axis)},
+{"num_newaxis", std::to_string(num_pad_axis)}};
+scale = MakeNode("expand_dims", bn_name + "_sc_expand", {scale}, kwargs);
+shift = MakeNode("expand_dims", bn_name + "_sh_expand", {shift}, kwargs);
}
-oshape << dshape;
-scale = MakeNode("reshape", bn_name + "_sc_reshape",
-{scale}, {{"shape", oshape.str()}});
-shift = MakeNode("reshape", bn_name + "_sh_reshape",
-{shift}, {{"shape", oshape.str()}});
NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data",
{data, scale});
out = MakeNode("broadcast_add", bn_name + "_out",
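
The padding arithmetic: for NCHW data (`ndim = 4`) with `axis = 1`, `num_pad_axis = 4 - 1 - 1 = 2`, so the per-channel scale and shift of shape `(C,)` become `(C, 1, 1)` and broadcast cleanly against the data; when the channel axis is already last, no `expand_dims` node is emitted. A plain-Python sketch of the rule:

    def num_pad_axis(ndim, axis):
        """Trailing unit axes expand_dims must append to scale/shift (sketch)."""
        return ndim - axis - 1

    print(num_pad_axis(ndim=4, axis=1))  # 2 -> (C,) becomes (C, 1, 1) for NCHW data
    print(num_pad_axis(ndim=2, axis=1))  # 0 -> expand_dims is skipped entirely
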
45 changes: 44 additions & 1 deletion nnvm/src/top/tensor/transform.cc
@@ -142,8 +142,51 @@ Example::
.set_num_inputs(kVarg)
.set_support_level(1);

// expand_dims
DMLC_REGISTER_PARAMETER(ExpandDimsParam);

inline bool ExpandDimsInferShape(const NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
const ExpandDimsParam& param = nnvm::get<ExpandDimsParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 1U);
const TShape& dshape = in_shape->at(0);
int ndim = static_cast<int>(dshape.ndim());
CHECK(param.axis >= -ndim - 1 && param.axis <= ndim);
int axis = param.axis < 0 ? ndim + param.axis + 1 : param.axis;
std::vector<dim_t> oshape;
for (int i = 0; i < axis; ++i) {
oshape.push_back(dshape[i]);
}
for (int i = 0; i < param.num_newaxis; ++i) {
oshape.push_back(1);
}
for (int i = axis; i < ndim; ++i) {
oshape.push_back(dshape[i]);
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0,
TShape(oshape.begin(), oshape.end()));
return true;
}

-// concatenate
NNVM_REGISTER_OP(expand_dims)
.describe(R"code(Inserts a new axis of size 1 into the array shape
For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1)``
will return a new array with shape ``(2,1,3,4)``.
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input tensor")
.add_arguments(ExpandDimsParam::__FIELDS__())
.set_attr_parser(ParamParser<ExpandDimsParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ExpandDimsParam>)
.set_attr<FInferShape>("FInferShape", ExpandDimsInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_support_level(1);

+// split
DMLC_REGISTER_PARAMETER(SplitParam);

inline void SplitParamParser(nnvm::NodeAttrs* attrs) {
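
The shape rule above, including the normalisation of negative `axis` values (any value in `[-ndim-1, ndim]` is accepted), can be mirrored in a few lines of Python (a sketch for reference, not the registered code):

    def expand_dims_out_shape(in_shape, axis, num_newaxis=1):
        """Python mirror of ExpandDimsInferShape (sketch)."""
        ndim = len(in_shape)
        assert -ndim - 1 <= axis <= ndim
        if axis < 0:
            axis += ndim + 1  # e.g. axis=-1 inserts after the last existing dim
        return in_shape[:axis] + (1,) * num_newaxis + in_shape[axis:]

    print(expand_dims_out_shape((2, 3, 4), axis=1))                  # (2, 1, 3, 4)
    print(expand_dims_out_shape((2, 3, 4), axis=-1, num_newaxis=2))  # (2, 3, 4, 1, 1)
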
3 changes: 0 additions & 3 deletions nnvm/tests/python/compiler/test_build.py
@@ -40,9 +40,6 @@ def verify(graph, lib):
assert graph.index.num_nodes == 4
verify(graph, lib)




def test_run():
x = sym.Variable("x")
y = sym.Variable("y")
9 changes: 6 additions & 3 deletions nnvm/tests/python/compiler/test_simplify_inference.py
@@ -12,8 +12,10 @@ def simple_bn(x, gamma, beta, moving_mean, moving_var,
sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
shape = [-1 if i == axis else 1 for i in range(len(shape))]
# for 2D
-scale = sym.reshape(scale, shape=shape)
-shift = sym.reshape(shift, shape=shape)
+num_newaxis=len(shape) - axis - 1
+if num_newaxis:
+scale = sym.expand_dims(scale, axis=axis, num_newaxis=num_newaxis)
+shift = sym.expand_dims(shift, axis=axis, num_newaxis=num_newaxis)
return x * scale + shift


@@ -25,7 +27,7 @@ def check(dim, axis, nstep):
gamma = sym.Variable("gamma")
moving_var = sym.Variable("moving_var")
moving_mean = sym.Variable("moving_mean")
-y1, y2 = x, x
+y1, y2 = x, sym.Variable("xx") + 1
ishape = {"x": tuple(10 for i in range(dim))}
for i in range(nstep):
y1 = sym.batch_norm(
@@ -44,6 +46,7 @@ def check(dim, axis, nstep):

check(2, 1, 1)
check(4, 0, 3)
check(4, 1, 2)

if __name__ == "__main__":
test_simplify_batchnorm()
21 changes: 8 additions & 13 deletions nnvm/tests/python/compiler/test_top_level1.py
@@ -1,23 +1,18 @@
import numpy as np

import tvm
import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime

-def ctx_list():
-res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
-return [x for x in res if x[1].exist]

+from nnvm.testing.config import test_ctx_list

def test_relu():
x = sym.Variable("x")
y = sym.relu(x)
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
-for target, ctx in ctx_list():
+for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
@@ -40,7 +35,7 @@ def test_exp():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
-for target, ctx in ctx_list():
+for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
@@ -63,7 +58,7 @@ def test_log():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
-for target, ctx in ctx_list():
+for target, ctx in test_ctx_list():
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
@@ -87,7 +82,7 @@ def test_tanh():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
-for target, ctx in ctx_list():
+for target, ctx in test_ctx_list():
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
@@ -111,7 +106,7 @@ def test_sigmoid():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
-for target, ctx in ctx_list():
+for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
@@ -134,7 +129,7 @@ def test_softmax():
dtype = "float32"
dshape = (10, 1000)
oshape = dshape
-for target, ctx in ctx_list():
+for target, ctx in test_ctx_list():
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
@@ -187,7 +182,7 @@ def test_batchnorm():
y = sym.batch_norm(
x, gamma, beta, moving_mean, moving_var, epsilon=eps)

-for target, ctx in ctx_list():
+for target, ctx in test_ctx_list():
graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": shape})
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
x_np = np.random.uniform(size=shape).astype(dtype)