Skip to content

Commit

Permalink
[TOPI] add squeeze (#494)
Browse files Browse the repository at this point in the history
* add squeeze

* should be squeeze
  • Loading branch information
sxjscience authored and tqchen committed Sep 26, 2017
1 parent fd864c5 commit dc7ab96
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
53 changes: 52 additions & 1 deletion topi/python/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
from . import tag
from .util import ravel_index, unravel_index, get_const_int
from .util import ravel_index, unravel_index, get_const_int, get_const_tuple

@tvm.tag_scope(tag=tag.BROADCAST)
def expand_dims(a, axis, num_newaxis=1):
Expand Down Expand Up @@ -77,6 +77,57 @@ def reshape(a, newshape):
lambda *indices: a(*unravel_index(ravel_index(indices, newshape), a_shape)))


@tvm.tag_scope(tag=tag.INJECTIVE)
def squeeze(a, axis=None):
"""Remove single-dimensional entries from the shape of an array.
Parameters
----------
a : tvm.Tensor
axis : None or int or tuple of ints, optional
Selects a subset of the single-dimensional entries in the shape.
If an axis is selected with shape entry greater than one, an error is raised.
Returns
-------
squeezed : tvm.Tensor
"""
a_ndim = len(a.shape)
a_shape = get_const_tuple(a.shape)
if axis is None:
axis = []
for i, ele in enumerate(a_shape):
if ele == 1:
axis.append(i)
else:
if isinstance(axis, int):
axis = axis + a_ndim if axis < 0 else axis
assert a_shape[axis] == 1
axis = [axis]
else:
axis = [ele + a_ndim if ele < 0 else ele for ele in axis]
for ele in axis:
assert a_shape[ele] == 1
out_shape = []
search_axis = set(axis)
for i, a_dim in enumerate(a_shape):
if i not in search_axis:
out_shape.append(a_dim)
def _compute(*indices):
real_indices = []
flag = 0
for i in range(a_ndim):
if i not in search_axis:
real_indices.append(indices[i - flag])
else:
real_indices.append(0)
flag += 1
return a(*real_indices)

return tvm.compute(out_shape, _compute)


@tvm.tag_scope(tag=tag.INJECTIVE)
def concatenate(a_tuple, axis=0):
"""Join a sequence of arrays along an existing axis.
Expand Down
29 changes: 29 additions & 0 deletions topi/tests/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,28 @@ def check_device(device):
check_device("metal")


def verify_squeeze(src_shape, axis):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.squeeze(A, axis=axis)
s = topi.cuda.schedule_injective(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
foo = tvm.build(s, [A, B], device, name="squeeze")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.squeeze(data_npy, axis=axis)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=B.dtype)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

check_device("cuda")
check_device("opencl")
check_device("metal")


def verify_concatenate(shapes, axis):
tensor_l = []
for i, shape in enumerate(shapes):
Expand Down Expand Up @@ -133,6 +155,12 @@ def test_reshape():
verify_reshape((16, ), (2, 2, 2, 2))


def test_squeeze():
verify_squeeze((1, 2, 3, 4), 0)
verify_squeeze((1, 2, 1, 4), None)
verify_squeeze((1, 1, 1, 4), (1, 2))


def test_concatenate():
verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
Expand All @@ -152,6 +180,7 @@ def test_split():
test_tranpose()
test_expand_dims()
test_reshape()
test_squeeze()
test_concatenate()
test_split()

0 comments on commit dc7ab96

Please sign in to comment.