From 2e9b6b9959fee29051aa29985f29c46cea51d8cd Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Fri, 22 Sep 2017 15:46:56 -0700
Subject: [PATCH] [TOP][COMPILER] Add expand_dims, make graph_compare optionally skip comparing inputs (#25)

---
 nnvm/include/nnvm/top/tensor.h                  | 11 +++++
 nnvm/python/nnvm/compiler/build_module.py       |  2 +
 nnvm/python/nnvm/compiler/graph_util.py         |  9 +++-
 nnvm/python/nnvm/testing/__init__.py            |  1 +
 nnvm/python/nnvm/testing/config.py              | 12 +++++
 nnvm/python/nnvm/top/nn.py                      |  2 +-
 nnvm/python/nnvm/top/tensor.py                  | 12 +++++
 nnvm/python/nnvm/top/transform.py               | 10 +++++
 nnvm/src/compiler/graph_deep_compare.cc         | 10 ++++-
 nnvm/src/compiler/simplify_inference.cc         | 17 ++++---
 nnvm/src/top/tensor/transform.cc                | 45 ++++++++++++++++++-
 nnvm/tests/python/compiler/test_build.py        |  3 --
 .../compiler/test_simplify_inference.py         |  9 ++--
 nnvm/tests/python/compiler/test_top_level1.py   | 21 ++++-----
 nnvm/tests/python/compiler/test_top_level2.py   | 22 ++++-----
 .../tests/python/unittest/test_infer_shape.py   | 11 +++++
 nnvm/tests/python/unittest/test_top_level1.py   |  6 +++
 17 files changed, 155 insertions(+), 48 deletions(-)
 create mode 100644 nnvm/python/nnvm/testing/__init__.py
 create mode 100644 nnvm/python/nnvm/testing/config.py

diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h
index 45541a151d57..23fe9a1e9644 100644
--- a/nnvm/include/nnvm/top/tensor.h
+++ b/nnvm/include/nnvm/top/tensor.h
@@ -21,6 +21,17 @@ struct ConcatenateParam : public dmlc::Parameter<ConcatenateParam> {
   }
 };
 
+struct ExpandDimsParam : public dmlc::Parameter<ExpandDimsParam> {
+  int axis;
+  int num_newaxis;
+  DMLC_DECLARE_PARAMETER(ExpandDimsParam) {
+    DMLC_DECLARE_FIELD(axis)
+    .describe("The axis to be expanded.");
+    DMLC_DECLARE_FIELD(num_newaxis).set_lower_bound(1).set_default(1)
+    .describe("Number of new axes to be inserted.");
+  }
+};
+
 struct SplitParam : public dmlc::Parameter<SplitParam> {
   // numpy convention, only support indices, not support list.
   Tuple<int> indices_or_sections;
diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py
index 24d7a85dea99..27d86e9a6967 100644
--- a/nnvm/python/nnvm/compiler/build_module.py
+++ b/nnvm/python/nnvm/compiler/build_module.py
@@ -2,6 +2,7 @@
 """Namespace for building operators."""
 from __future__ import absolute_import as _abs
 
+import logging
 import tvm
 from . import graph_attr, graph_util
 from .. import graph as _graph
@@ -74,6 +75,7 @@ def build_config(**kwargs):
 @tvm.register_func("nnvm.compiler.lower")
 def _lower(sch, inputs, func_name):
     f = tvm.lower(sch, inputs, name=func_name)
+    logging.debug("lower function %s", func_name)
     return f if isinstance(
         f, (tvm.container.Array, tuple, list)) else [f]
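Note: with the logging import above, the per-function lowering trace can be surfaced from client code. A minimal sketch (the graph and shapes here are illustrative, not part of the patch):

    import logging
    import nnvm.symbol as sym
    import nnvm.compiler

    # Raise the root logger level so the new "lower function ..." debug
    # message emitted by nnvm.compiler.lower becomes visible.
    logging.basicConfig(level=logging.DEBUG)

    x = sym.Variable("x")
    y = sym.exp(x)
    graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": (1, 3, 32, 32)})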
diff --git a/nnvm/python/nnvm/compiler/graph_util.py b/nnvm/python/nnvm/compiler/graph_util.py
index fcca00b0abe0..68ce857264b4 100644
--- a/nnvm/python/nnvm/compiler/graph_util.py
+++ b/nnvm/python/nnvm/compiler/graph_util.py
@@ -59,7 +59,7 @@ def infer_dtype(graph, **dtype):
 
 _deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
 
-def check_graph_equal(grapha, graphb):
+def check_graph_equal(grapha, graphb, compare_variable_attrs=False):
     """Check if two graphs have equal structure.
 
     Parameters
@@ -70,11 +70,16 @@ def infer_dtype(graph, **dtype):
     graphb : Graph
         The second graph
 
+    compare_variable_attrs : bool, optional
+        Whether to compare attributes (names) on variables.
+        Usually it is safe to skip this unless we want input
+        names to exactly match.
+
     Raises
     ------
     ValueError
         ValueError is raised with error message when graph not equal
     """
-    err = _deep_compare(grapha, graphb)
+    err = _deep_compare(grapha, graphb, compare_variable_attrs)
     if err:
         raise ValueError("Graph compare error: " + err)
diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py
new file mode 100644
index 000000000000..1241e403b1a0
--- /dev/null
+++ b/nnvm/python/nnvm/testing/__init__.py
@@ -0,0 +1 @@
+"""Utilities for testcases"""
diff --git a/nnvm/python/nnvm/testing/config.py b/nnvm/python/nnvm/testing/config.py
new file mode 100644
index 000000000000..26d1d41014cf
--- /dev/null
+++ b/nnvm/python/nnvm/testing/config.py
@@ -0,0 +1,12 @@
+"""Configuration for tests"""
+import os
+import tvm
+
+def test_ctx_list():
+    """Get context list for testcases"""
+    device_list = os.environ.get("NNVM_TEST_TARGETS", "")
+    device_list = (device_list.split(",") if device_list
+                   else ["llvm", "cuda"])
+    device_list = set(device_list)
+    res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
+    return [x for x in res if x[1].exist and x[0] in device_list]
diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index 5b0dfe2fe145..71246bc0823e 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -90,7 +90,7 @@ def compute_conv2d(attrs, inputs, _):
         raise ValueError("not support arbitrary group number for now")
     if attrs.get_bool("use_bias"):
         bias = inputs[2]
-        bias = topi.broadcast_to(bias, (1, bias.shape[0], 1, 1))
+        bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
         out = topi.broadcast_add(out, bias)
     return out
diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py
index 05396d65abeb..c427f49ab8d2 100644
--- a/nnvm/python/nnvm/top/tensor.py
+++ b/nnvm/python/nnvm/top/tensor.py
@@ -115,6 +115,18 @@ def _compute(attrs, x, _):
 reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
 reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
 
+# pow_scalar
+reg.register_compute("__pow_scalar__",
+                     _compute_binary_scalar(tvm.power))
+reg.register_pattern("__pow_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__pow_scalar__", _fschedule_broadcast)
+
+# rpow_scalar
+reg.register_compute("__rpow_scalar__",
+                     _compute_binary_scalar(lambda x, y: tvm.power(y, x)))
+reg.register_pattern("__rpow_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__rpow_scalar__", _fschedule_broadcast)
+
 # elemwise_add
 reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
 reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py
index 89e3c64a05ce..e7419c030df4 100644
--- a/nnvm/python/nnvm/top/transform.py
+++ b/nnvm/python/nnvm/top/transform.py
@@ -3,11 +3,21 @@
 from __future__ import absolute_import
 
 import tvm
+import topi
 from .tensor import _fschedule_broadcast
 from ..compiler import registry as reg
 from ..compiler import OpPattern
 
 # Need add reshape, transpose
+@reg.register_compute("expand_dims")
+def compute_expand_dims(attrs, inputs, out_info):
+    """Compute definition of expand_dims"""
+    return topi.expand_dims(
+        inputs[0], attrs.get_int("axis"),
+        num_newaxis=attrs.get_int("num_newaxis"))
+reg.register_pattern("expand_dims", OpPattern.BROADCAST)
+reg.register_schedule("expand_dims", _fschedule_broadcast)
+
 def _flatten_index(indices, shape):
     """flatten the index to 1D"""
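For reference, the compute registered above gives expand_dims the following semantics: insert num_newaxis axes of size 1 starting at position axis. A NumPy sketch of the same behavior (the helper name is illustrative):

    import numpy as np

    def expand_dims_ref(x, axis, num_newaxis=1):
        # Negative axis counts from the end; same normalization as the
        # C++ shape inference below: axis = ndim + axis + 1.
        if axis < 0:
            axis = x.ndim + axis + 1
        new_shape = x.shape[:axis] + (1,) * num_newaxis + x.shape[axis:]
        return x.reshape(new_shape)

    x = np.zeros((2, 3, 4))
    assert expand_dims_ref(x, axis=1, num_newaxis=2).shape == (2, 1, 1, 3, 4)
    assert expand_dims_ref(x, axis=-1).shape == (2, 3, 4, 1)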
diff --git a/nnvm/src/compiler/graph_deep_compare.cc b/nnvm/src/compiler/graph_deep_compare.cc
index dd64f0e3b062..df578165ed6b 100644
--- a/nnvm/src/compiler/graph_deep_compare.cc
+++ b/nnvm/src/compiler/graph_deep_compare.cc
@@ -16,7 +16,9 @@ namespace compiler {
 // not considering the graph attributes
 // return non-empty error message if the graph mismatch.
 // the comparator won't match name of intermediate node.
-std::string DeepCompare(Graph a, Graph b) {
+// if compare_variable_attr is true, also compare attributes of variables.
+std::string DeepCompare(Graph a, Graph b,
+                        bool compare_variable_attr) {
   const IndexedGraph& idxa = a.indexed_graph();
   const IndexedGraph& idxb = b.indexed_graph();
   std::ostringstream err;
@@ -51,6 +53,10 @@ std::string DeepCompare(Graph a, Graph b) {
       err << "Node mismatch ";
       return err.str();
     }
+    if (anode.source->is_variable()) {
+      CHECK(bnode.source->is_variable());
+      if (!compare_variable_attr) continue;
+    }
     AttrDict adict = GetAttrDict(anode.source->attrs);
     AttrDict bdict = GetAttrDict(bnode.source->attrs);
 
@@ -107,7 +113,7 @@ std::string DeepCompare(Graph a, Graph b) {
 
 TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
 .set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
-    *rv = DeepCompare(args[0], args[1]);
+    *rv = DeepCompare(args[0], args[1], args[2]);
   });
 }  // namespace compiler
 }  // namespace nnvm
diff --git a/nnvm/src/compiler/simplify_inference.cc b/nnvm/src/compiler/simplify_inference.cc
index ad27d885f68e..7dd9ade0ac96 100644
--- a/nnvm/src/compiler/simplify_inference.cc
+++ b/nnvm/src/compiler/simplify_inference.cc
@@ -58,16 +58,15 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
     shift = MakeNode(
         "elemwise_add", bn_name + "_add_beta", {shift, beta});
   }
-  // use broaodcast to reshape
-  std::ostringstream oshape;
-  for (dim_t i = 0; i < dshape.ndim(); ++i) {
-    dshape[i] = (i != param.axis) ? 1 : -1;
+  // use expand_dims to pad lower dims to 1
+  int num_pad_axis = static_cast<int>(dshape.ndim() - param.axis) - 1;
+  if (num_pad_axis != 0) {
+    std::unordered_map<std::string, std::string> kwargs{
+      {"axis", std::to_string(param.axis)},
+      {"num_newaxis", std::to_string(num_pad_axis)}};
+    scale = MakeNode("expand_dims", bn_name + "_sc_expand", {scale}, kwargs);
+    shift = MakeNode("expand_dims", bn_name + "_sh_expand", {shift}, kwargs);
   }
-  oshape << dshape;
-  scale = MakeNode("reshape", bn_name + "_sc_reshape",
-                   {scale}, {{"shape", oshape.str()}});
-  shift = MakeNode("reshape", bn_name + "_sh_reshape",
-                   {shift}, {{"shape", oshape.str()}});
   NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data",
                            {data, scale});
   out = MakeNode("broadcast_add", bn_name + "_out",
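The rewrite above replaces reshape with expand_dims when unpacking batch norm for inference. Numerically it computes the following (a NumPy sketch mirroring simple_bn in the tests further down; names are illustrative):

    import numpy as np

    def batch_norm_infer_ref(x, gamma, beta, mean, var, eps, axis=1):
        scale = gamma / np.sqrt(var + eps)
        shift = beta - mean * scale
        # pad the trailing dims with 1 so scale/shift broadcast along `axis`,
        # mirroring expand_dims(axis, num_newaxis=ndim - axis - 1)
        pad = (1,) * (x.ndim - axis - 1)
        return (x * scale.reshape(scale.shape + pad)
                + shift.reshape(shift.shape + pad))

    x = np.random.uniform(size=(2, 3, 4, 4))
    g, b, m, v = (np.random.uniform(size=(3,)) for _ in range(4))
    out = batch_norm_infer_ref(x, g, b, m, v, eps=1e-5)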
diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc
index c9eaf0c0ee6e..da93fe4cf26f 100644
--- a/nnvm/src/top/tensor/transform.cc
+++ b/nnvm/src/top/tensor/transform.cc
@@ -142,8 +142,51 @@ Example::
 .set_num_inputs(kVarg)
 .set_support_level(1);
 
+// expand_dims
+DMLC_REGISTER_PARAMETER(ExpandDimsParam);
+
+inline bool ExpandDimsInferShape(const NodeAttrs& attrs,
+                                 std::vector<TShape>* in_shape,
+                                 std::vector<TShape>* out_shape) {
+  const ExpandDimsParam& param = nnvm::get<ExpandDimsParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 1U);
+  const TShape& dshape = in_shape->at(0);
+  int ndim = static_cast<int>(dshape.ndim());
+  CHECK(param.axis >= -ndim - 1 && param.axis <= ndim);
+  int axis = param.axis < 0 ? ndim + param.axis + 1 : param.axis;
+  std::vector<dim_t> oshape;
+  for (int i = 0; i < axis; ++i) {
+    oshape.push_back(dshape[i]);
+  }
+  for (int i = 0; i < param.num_newaxis; ++i) {
+    oshape.push_back(1);
+  }
+  for (int i = axis; i < ndim; ++i) {
+    oshape.push_back(dshape[i]);
+  }
+  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0,
+                           TShape(oshape.begin(), oshape.end()));
+  return true;
+}
 
-// concatenate
+NNVM_REGISTER_OP(expand_dims)
+.describe(R"code(Inserts a new axis of size 1 into the array shape.
+
+For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1)``
+will return a new array with shape ``(2,1,3,4)``.
+
+)code" NNVM_ADD_FILELINE)
+.add_argument("data", "Tensor", "Input tensor")
+.add_arguments(ExpandDimsParam::__FIELDS__())
+.set_attr_parser(ParamParser<ExpandDimsParam>)
+.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ExpandDimsParam>)
+.set_attr<FInferShape>("FInferShape", ExpandDimsInferShape)
+.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_support_level(1);
+
+// split
 DMLC_REGISTER_PARAMETER(SplitParam);
 
 inline void SplitParamParser(nnvm::NodeAttrs* attrs) {
diff --git a/nnvm/tests/python/compiler/test_build.py b/nnvm/tests/python/compiler/test_build.py
index bb10e8400bb1..379975d2d6a4 100644
--- a/nnvm/tests/python/compiler/test_build.py
+++ b/nnvm/tests/python/compiler/test_build.py
@@ -40,9 +40,6 @@ def verify(graph, lib):
     assert graph.index.num_nodes == 4
     verify(graph, lib)
 
-
-
-
 def test_run():
     x = sym.Variable("x")
     y = sym.Variable("y")
diff --git a/nnvm/tests/python/compiler/test_simplify_inference.py b/nnvm/tests/python/compiler/test_simplify_inference.py
index eaff669a904b..5da46608081a 100644
--- a/nnvm/tests/python/compiler/test_simplify_inference.py
+++ b/nnvm/tests/python/compiler/test_simplify_inference.py
@@ -12,8 +12,10 @@ def simple_bn(x, gamma, beta, moving_mean, moving_var,
         sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
     shape = [-1 if i == axis else 1 for i in range(len(shape))]  # for 2D
-    scale = sym.reshape(scale, shape=shape)
-    shift = sym.reshape(shift, shape=shape)
+    num_newaxis = len(shape) - axis - 1
+    if num_newaxis:
+        scale = sym.expand_dims(scale, axis=axis, num_newaxis=num_newaxis)
+        shift = sym.expand_dims(shift, axis=axis, num_newaxis=num_newaxis)
     return x * scale + shift
 
 
@@ -25,7 +27,7 @@ def check(dim, axis, nstep):
     gamma = sym.Variable("gamma")
     moving_var = sym.Variable("moving_var")
     moving_mean = sym.Variable("moving_mean")
-    y1, y2 = x, x
+    y1, y2 = x, sym.Variable("xx") + 1
     ishape = {"x": tuple(10 for i in range(dim))}
     for i in range(nstep):
         y1 = sym.batch_norm(
@@ -44,6 +46,7 @@ def check(dim, axis, nstep):
 
     check(2, 1, 1)
     check(4, 0, 3)
+    check(4, 1, 2)
 
 if __name__ == "__main__":
     test_simplify_batchnorm()
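As exercised by the updated test above (y2 now reads from a differently named input), the relaxed comparison lets structurally identical graphs match regardless of variable names. A usage sketch (the graphs here are illustrative):

    import nnvm.graph as graph
    import nnvm.symbol as sym
    from nnvm.compiler import graph_util

    # Structurally identical graphs whose input variables are named differently.
    y1 = sym.Variable("x") + 1
    y2 = sym.Variable("xx") + 1

    # Passes: variable attributes (names) are skipped by default.
    graph_util.check_graph_equal(graph.create(y1), graph.create(y2))

    # Requesting exact input-name matching would raise ValueError:
    # graph_util.check_graph_equal(graph.create(y1), graph.create(y2),
    #                              compare_variable_attrs=True)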
diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py
index 5822e58f995b..dc3423e67322 100644
--- a/nnvm/tests/python/compiler/test_top_level1.py
+++ b/nnvm/tests/python/compiler/test_top_level1.py
@@ -1,15 +1,10 @@
 import numpy as np
-
 import tvm
 import topi
 import nnvm.symbol as sym
 import nnvm.compiler
 import nnvm.runtime
-
-def ctx_list():
-    res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
-    return [x for x in res if x[1].exist]
-
+from nnvm.testing.config import test_ctx_list
 
 def test_relu():
     x = sym.Variable("x")
@@ -17,7 +12,7 @@ def test_relu():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
         # get member functions
@@ -40,7 +35,7 @@ def test_exp():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
         # get member functions
@@ -63,7 +58,7 @@ def test_log():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         with nnvm.compiler.build_config(opt_level=1):
             graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
@@ -87,7 +82,7 @@ def test_tanh():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         with nnvm.compiler.build_config(opt_level=1):
             graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
@@ -111,7 +106,7 @@ def test_sigmoid():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
         # get member functions
@@ -134,7 +129,7 @@ def test_softmax():
     dtype = "float32"
     dshape = (10, 1000)
     oshape = dshape
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         with nnvm.compiler.build_config(opt_level=1):
             graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
@@ -187,7 +182,7 @@ def test_batchnorm():
     y = sym.batch_norm(
         x, gamma, beta, moving_mean, moving_var, epsilon=eps)
 
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": shape})
         m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
         x_np = np.random.uniform(size=shape).astype(dtype)
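The shared test_ctx_list helper makes the target set of all of these tests controllable from the environment. A small sketch (the environment value is illustrative):

    import os

    # Restrict the NNVM tests to LLVM; unset, the default is "llvm,cuda".
    # Contexts that do not exist on this machine are filtered out as well.
    os.environ["NNVM_TEST_TARGETS"] = "llvm"

    from nnvm.testing.config import test_ctx_list
    print(test_ctx_list())  # e.g. [("llvm", cpu(0))] on a CPU-only machine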
diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py
index 7e39bff4017f..793d6d3e955f 100644
--- a/nnvm/tests/python/compiler/test_top_level2.py
+++ b/nnvm/tests/python/compiler/test_top_level2.py
@@ -5,10 +5,8 @@
 import nnvm.symbol as sym
 import nnvm.compiler
 import nnvm.runtime
+from nnvm.testing.config import test_ctx_list
 
-def ctx_list():
-    res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
-    return [x for x in res if x[1].exist]
 
 def test_conv2d():
     x = sym.Variable("x")
@@ -19,7 +17,7 @@ def test_conv2d():
     kshape = (10, 3, 3, 3)
     oshape = (1, 10, 18, 18)
     shape_dict = {"x": dshape}
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
         m = nnvm.runtime.create(graph, lib, ctx)
         # get member functions
@@ -42,29 +40,25 @@ def test_conv2d():
 def test_grouped_conv2d():
     x = sym.Variable("x")
     y = sym.conv2d(x, channels=32, kernel_size=(3, 3), groups=32,
-                   name="y", use_bias=False, padding=(1,1))
+                   name="y", padding=(1,1))
     dtype = "float32"
     dshape = (1, 32, 18, 18)
     kshape = (32, 1, 3, 3)
     oshape = (1, 32, 18, 18)
     shape_dict = {"x": dshape}
-    for target, ctx in ctx_list():
+    for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
         m = nnvm.runtime.create(graph, lib, ctx)
-        # get member functions
-        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
         # set input
         data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
         kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
-        set_input("x", data)
-        set_input("y_weight", kernel)
-        # execute
-        run()
+        bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
+        m.run(x=data, y_weight=kernel, y_bias=bias)
         # get output
-        out = tvm.nd.empty(oshape, dtype)
-        get_output(0, out)
+        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
         c_np = topi.testing.depthwise_conv2d_python_nchw(
             data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
+        c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
         np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
diff --git a/nnvm/tests/python/unittest/test_infer_shape.py b/nnvm/tests/python/unittest/test_infer_shape.py
index 77f14573b74e..251ea36f6a53 100644
--- a/nnvm/tests/python/unittest/test_infer_shape.py
+++ b/nnvm/tests/python/unittest/test_infer_shape.py
@@ -34,6 +34,16 @@ def test_concatenate():
     assert(sdict["concat"][0] == [20, 20])
 
 
+def test_expand_dims():
+    x = sym.Variable("x", shape=(10, 20))
+    y = sym.expand_dims(x, axis=1, name="y")
+    sdict = infer_shape(y)
+    assert(sdict["y"][0] == [10, 1, 20])
+    y = sym.expand_dims(x, axis=-1, name="y", num_newaxis=2)
+    sdict = infer_shape(y)
+    assert(sdict["y"][0] == [10, 20, 1, 1])
+
+
 def test_split():
     x1 = sym.Variable("x", shape=(10, 20))
     z = sym.split(x1, indices_or_sections=[11], name="y")
@@ -247,6 +257,7 @@ def check(in_shape, out_shape, **kwargs):
 
 
 if __name__ == "__main__":
+    test_expand_dims()
     test_dense()
     test_concatenate()
     test_split()
diff --git a/nnvm/tests/python/unittest/test_top_level1.py b/nnvm/tests/python/unittest/test_top_level1.py
index 6b75cf1d83f1..cf0baf4c3696 100644
--- a/nnvm/tests/python/unittest/test_top_level1.py
+++ b/nnvm/tests/python/unittest/test_top_level1.py
@@ -19,6 +19,11 @@ def test_concatenate_split():
     z = sym.split(y, indices_or_sections=[10, 20])
     assert len(z.list_output_names()) == 3
 
+def test_expand_dims():
+    x = sym.Variable('x')
+    y = sym.expand_dims(x, axis=1, num_newaxis=2)
+    assert y.list_input_names() == ['x']
+
 
 def test_unary():
     x = sym.Variable('x')
@@ -39,6 +44,7 @@ def test_batchnorm():
 
 if __name__ == "__main__":
     test_concatenate_split()
+    test_expand_dims()
     test_dense()
     test_unary()
     test_batchnorm()
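Taken together, the new op registration, compute, schedule, and shape inference make expand_dims usable end to end. A minimal sketch built from the APIs exercised above (shapes are illustrative):

    import numpy as np
    import tvm
    import nnvm.symbol as sym
    import nnvm.compiler
    import nnvm.runtime
    from nnvm.testing.config import test_ctx_list

    x = sym.Variable("x")
    y = sym.expand_dims(x, axis=1, num_newaxis=2)
    dshape = (2, 3, 4)
    for target, ctx in test_ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
        m = nnvm.runtime.create(graph, lib, ctx)
        data = tvm.nd.array(np.random.uniform(size=dshape).astype("float32"))
        m.run(x=data)
        out = m.get_output(0, tvm.nd.empty((2, 1, 1, 3, 4), "float32"))
        # expand_dims only inserts size-1 axes; values are unchanged
        np.testing.assert_allclose(out.asnumpy(),
                                   data.asnumpy().reshape(2, 1, 1, 3, 4))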