diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py
index 4a9b7346c459..f1b052c8d31f 100644
--- a/topi/python/topi/transform.py
+++ b/topi/python/topi/transform.py
@@ -3,7 +3,7 @@ from __future__ import absolute_import as _abs
 import tvm
 from . import tag
-from .util import ravel_index, unravel_index, get_const_int
+from .util import ravel_index, unravel_index, get_const_int, get_const_tuple
 
 
 @tvm.tag_scope(tag=tag.BROADCAST)
 def expand_dims(a, axis, num_newaxis=1):
@@ -77,6 +77,57 @@ def reshape(a, newshape):
                        lambda *indices: a(*unravel_index(ravel_index(indices, newshape), a_shape)))
 
 
+@tvm.tag_scope(tag=tag.INJECTIVE)
+def squeeze(a, axis=None):
+    """Remove single-dimensional entries from the shape of an array.
+
+    Parameters
+    ----------
+    a : tvm.Tensor
+
+    axis : None or int or tuple of ints, optional
+        Selects a subset of the single-dimensional entries in the shape.
+        If an axis is selected with shape entry greater than one, an error is raised.
+
+    Returns
+    -------
+    squeezed : tvm.Tensor
+    """
+    a_ndim = len(a.shape)
+    a_shape = get_const_tuple(a.shape)
+    if axis is None:
+        axis = []
+        for i, ele in enumerate(a_shape):
+            if ele == 1:
+                axis.append(i)
+    else:
+        if isinstance(axis, int):
+            axis = axis + a_ndim if axis < 0 else axis
+            assert a_shape[axis] == 1
+            axis = [axis]
+        else:
+            axis = [ele + a_ndim if ele < 0 else ele for ele in axis]
+            for ele in axis:
+                assert a_shape[ele] == 1
+    out_shape = []
+    search_axis = set(axis)
+    for i, a_dim in enumerate(a_shape):
+        if i not in search_axis:
+            out_shape.append(a_dim)
+    def _compute(*indices):
+        real_indices = []
+        flag = 0
+        for i in range(a_ndim):
+            if i not in search_axis:
+                real_indices.append(indices[i - flag])
+            else:
+                real_indices.append(0)
+                flag += 1
+        return a(*real_indices)
+
+    return tvm.compute(out_shape, _compute)
+
+
 @tvm.tag_scope(tag=tag.INJECTIVE)
 def concatenate(a_tuple, axis=0):
     """Join a sequence of arrays along an existing axis.
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index 14f767f5c760..8acb4b4f5a1b 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -69,6 +69,28 @@ def check_device(device):
     check_device("metal")
 
 
+def verify_squeeze(src_shape, axis):
+    A = tvm.placeholder(shape=src_shape, name="A")
+    B = topi.squeeze(A, axis=axis)
+    s = topi.cuda.schedule_injective(B)
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        foo = tvm.build(s, [A, B], device, name="squeeze")
+        data_npy = np.random.normal(size=src_shape).astype(A.dtype)
+        out_npy = np.squeeze(data_npy, axis=axis)
+        data_nd = tvm.nd.array(data_npy, ctx)
+        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=B.dtype)
+        foo(data_nd, out_nd)
+        np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
+
+    check_device("cuda")
+    check_device("opencl")
+    check_device("metal")
+
+
 def verify_concatenate(shapes, axis):
     tensor_l = []
     for i, shape in enumerate(shapes):
@@ -133,6 +155,12 @@ def test_reshape():
     verify_reshape((16, ), (2, 2, 2, 2))
 
 
+def test_squeeze():
+    verify_squeeze((1, 2, 3, 4), 0)
+    verify_squeeze((1, 2, 1, 4), None)
+    verify_squeeze((1, 1, 1, 4), (1, 2))
+
+
 def test_concatenate():
     verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
     verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
@@ -152,6 +180,7 @@ def test_split():
     test_tranpose()
     test_expand_dims()
     test_reshape()
+    test_squeeze()
     test_concatenate()
     test_split()
 