From 3f599a60bed40568722273ab3d7715e25adf5c42 Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Thu, 28 Sep 2017 13:06:26 -0500 Subject: [PATCH] add squeeze (#52) * add transform * fix * update doc * Update tvm --- nnvm/docs/top.rst | 2 + nnvm/include/nnvm/top/tensor.h | 10 +++ nnvm/python/nnvm/compiler/compile_engine.py | 2 +- nnvm/python/nnvm/top/transform.py | 10 +++ nnvm/src/top/op_common.h | 1 + nnvm/src/top/tensor/transform.cc | 74 +++++++++++++++++++ nnvm/tests/python/compiler/test_top_level1.py | 25 +++++++ nnvm/tests/python/compiler/test_top_level4.py | 1 - 8 files changed, 123 insertions(+), 2 deletions(-) diff --git a/nnvm/docs/top.rst b/nnvm/docs/top.rst index 89af46509d02..65beccde0c5c 100644 --- a/nnvm/docs/top.rst +++ b/nnvm/docs/top.rst @@ -41,6 +41,7 @@ This level enables fully connected multi-layer perceptron. nnvm.symbol.flatten nnvm.symbol.concatenate nnvm.symbol.expand_dims + nnvm.symbol.squeeze nnvm.symbol.split nnvm.symbol.dropout nnvm.symbol.batch_norm @@ -112,6 +113,7 @@ Detailed Definitions .. autofunction:: nnvm.symbol.flatten .. autofunction:: nnvm.symbol.concatenate .. autofunction:: nnvm.symbol.expand_dims +.. autofunction:: nnvm.symbol.squeeze .. autofunction:: nnvm.symbol.split .. autofunction:: nnvm.symbol.dropout .. autofunction:: nnvm.symbol.batch_norm diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h index 23fe9a1e9644..16ed259dac69 100644 --- a/nnvm/include/nnvm/top/tensor.h +++ b/nnvm/include/nnvm/top/tensor.h @@ -79,6 +79,16 @@ struct ReshapeParam : public dmlc::Parameter { } }; +struct SqueezeParam : public dmlc::Parameter { + TShape axis; + + DMLC_DECLARE_PARAMETER(SqueezeParam) { + DMLC_DECLARE_FIELD(axis).set_default(TShape()) + .describe("The axis to squeeze in the input tensor." + " If set to None, all size=1 axes will be squeezed"); + } +}; + struct ScalarParam : public dmlc::Parameter { double scalar; diff --git a/nnvm/python/nnvm/compiler/compile_engine.py b/nnvm/python/nnvm/compiler/compile_engine.py index b215abf6eac2..289f09deb280 100644 --- a/nnvm/python/nnvm/compiler/compile_engine.py +++ b/nnvm/python/nnvm/compiler/compile_engine.py @@ -47,7 +47,7 @@ def items(self): """ res = _list_cache_items() assert len(res) % 2 == 0 - return [(res[2*i], res[2*i+1]) for i in range(len(res)/2)] + return [(res[2*i], res[2*i+1]) for i in range(len(res) // 2)] def clear_cache(self): """Clear the existing cached functions.""" diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py index a82511943ebf..21c10f5cccf4 100644 --- a/nnvm/python/nnvm/top/transform.py +++ b/nnvm/python/nnvm/top/transform.py @@ -36,6 +36,16 @@ def compute_reshape(attrs, inputs, out_info): reg.register_pattern("reshape", OpPattern.INJECTIVE) reg.register_schedule("reshape", _fschedule_injective) +# reshape +@reg.register_compute("squeeze") +def compute_squeeze(attrs, inputs, out_info): + """Compute definition of reshape""" + axis = attrs.get_int_tuple("axis") + axis = tuple(axis) if axis else None + return topi.squeeze(inputs[0], axis) +reg.register_pattern("squeeze", OpPattern.INJECTIVE) +reg.register_schedule("squeeze", _fschedule_injective) + # concatenate @reg.register_compute("concatenate") def compute_concatenate(attrs, inputs, out_info): diff --git a/nnvm/src/top/op_common.h b/nnvm/src/top/op_common.h index 110e08a5841e..e23059a88c42 100644 --- a/nnvm/src/top/op_common.h +++ b/nnvm/src/top/op_common.h @@ -10,6 +10,7 @@ #include #include #include +#include namespace nnvm { namespace top { diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 0fe772e28e47..12ae06731145 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -445,6 +445,80 @@ The significance of each is explained below: .set_num_outputs(1) .set_support_level(3); +// squeeze +DMLC_REGISTER_PARAMETER(SqueezeParam); + +inline bool SqueezeShape(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + const SqueezeParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + const TShape& shp = (*in_attrs)[0]; + if (shp.ndim() == 0) return false; + + std::vector oshape; + if (param.axis.ndim() == 0) { + for (dim_t i = 0; i < shp.ndim(); ++i) { + if(shp[i] != 1) { + oshape.emplace_back(shp[i]); + } + } + } else { + std::unordered_set axis_checker; + for (size_t i = 0; i < param.axis.ndim(); ++i) { + if(param.axis[i] < 0) { + int real_axis = param.axis[i] + static_cast(shp.ndim()); + CHECK(real_axis < static_cast(shp.ndim()) && real_axis >= 0); + axis_checker.insert(real_axis); + } + } + for (size_t i = 0; i < shp.ndim(); ++i) { + if(axis_checker.find(i) == axis_checker.end()) { + oshape.emplace_back(shp[i]); + } else { + CHECK_EQ(shp[i], 1) << "The squeezed axis must have shape 1!" + << "Want to squeeze " << i + << ", which has shape" << shp[i]; + } + } + } + if(oshape.size() == 0) { + // Handles the case where all axes are squeezed. + oshape.push_back(1); + } + TShape out_shape(oshape.begin(), oshape.end()); + CHECK_EQ(out_shape.Size(), shp.Size()) + << "Target shape size is different to source. " + << "Target: " << out_shape + << "\nSource: " << shp; + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, out_shape); + return true; +} + +NNVM_REGISTER_OP(squeeze) +.describe(R"code(Squeeze axises in the array. + +Examples:: + + x = [[[0], [1], [2]]] + + squeeze(x) = [0, 1, 2] + + squeeze(x, 0) = [[0], [1], [2]] + + squeeze(x, (0, 2)) = [0, 1, 2] +)code" NNVM_ADD_FILELINE) +.add_argument("data", "Tensor", "Source input") +.add_arguments(SqueezeParam::__FIELDS__()) +.set_attr_parser(ParamParser) +.set_attr("FGetAttrDict", ParamGetAttrDict) +.set_attr("FInferShape", SqueezeShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_num_inputs(1) +.set_num_outputs(1) +.set_support_level(1); + // tranpose DMLC_REGISTER_PARAMETER(TransposeParam); diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py index 068ea14c28b7..26ff34eb7ecb 100644 --- a/nnvm/tests/python/compiler/test_top_level1.py +++ b/nnvm/tests/python/compiler/test_top_level1.py @@ -220,6 +220,30 @@ def test_split(): verify_split((5, 3), [3], axis=0) verify_split((5, 9, 3), [3, 4], axis=1) + +def verify_squeeze(dshape, axis): + x = sym.Variable("x") + if axis: + y = sym.squeeze(x, axis=axis) + else: + y = sym.squeeze(x) + y = y + 1 + dtype = "float32" + for target, ctx in ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) + m = graph_runtime.create(graph, lib, ctx) + # set input + data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + m.run(x=data) + out_np = np.squeeze(data.asnumpy(), axis=axis) + 1 + out = m.get_output(0, tvm.nd.empty(out_np.shape)) + np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5) + +def test_squeeze(): + verify_squeeze((1, 3, 2, 5), None) + verify_squeeze((1, 3, 1), axis=0) + verify_squeeze((1, 3, 2, 5, 1), axis=-1) + if __name__ == "__main__": test_split() test_concatenate() @@ -232,3 +256,4 @@ def test_split(): test_tanh() test_sigmoid() test_softmax() + test_squeeze() diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index 1b005e27d9ef..ad09d73ec28d 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -71,7 +71,6 @@ def test_reshape(): verify_reshape((2, 3, 4), (8, 3)) verify_reshape((4, 7), (2, 7, 2)) - if __name__ == "__main__": test_reshape() test_reduce()