Skip to content

Commit

Permalink
[TOPI] Fix traverse function not inline zero-input op (#3623)
Browse files Browse the repository at this point in the history
* Fix traverse_inline not inline zero input op properly

* Add where to python and set tag to broadcast

* Fix inline

* test

* fix test target

* fix
  • Loading branch information
vinx13 authored and tqchen committed Jul 30, 2019
1 parent d4a5175 commit 9d583cf
Show file tree
Hide file tree
Showing 22 changed files with 114 additions and 31 deletions.
2 changes: 1 addition & 1 deletion topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ inline Tensor where(const Tensor& condition,
const Tensor& x,
const Tensor& y,
std::string name = "T_where",
std::string tag = kInjective) {
std::string tag = kBroadcast) {
CHECK_EQ(x->shape.size(), y->shape.size())
<< "x and y must have the same shape.Got different number of dimension: "
<< x->shape.size() << " vs " << y->shape.size();
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/arm_cpu/bitserial_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def traverse(op):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)

if 'spatial_bitserial_conv_nhwc' in op.tag:
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/arm_cpu/bitserial_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def traverse(op):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)

elif op.tag == 'bitserial_dense' or 'bitserial_dense_unipolar':
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/conv2d_hwcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def traverse(operator):
if operator not in sch.outputs:
sch[operator].compute_inline()
for tensor in operator.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
elif operator.tag == 'conv2d_hwcn':
Apad = operator.input_tensors[0]
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule depthwise_conv2d
if OP.tag == 'depthwise_conv2d_nhwc':
Expand Down
4 changes: 2 additions & 2 deletions topi/python/topi/cuda/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith('adaptive_pool'):
Expand Down Expand Up @@ -137,7 +137,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
Expand Down
10 changes: 5 additions & 5 deletions topi/python/topi/hls/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
# schedule conv2d
elif OP.tag.find("conv2d") >= 0:
Expand Down Expand Up @@ -220,7 +220,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
elif OP.tag in ["comm_reduce", "comm_reduce_idx"]:
if OP.tag == "comm_reduce":
Expand Down Expand Up @@ -298,7 +298,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
Expand Down Expand Up @@ -342,7 +342,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
Expand Down Expand Up @@ -386,7 +386,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith('adaptive_pool'):
Expand Down
4 changes: 2 additions & 2 deletions topi/python/topi/intel_graphics/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def traverse(op):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d' in op.tag:
_schedule_cl_spatialpack_NCHWc(s, op)
Expand Down Expand Up @@ -378,7 +378,7 @@ def traverse(op):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv2d' in op.tag:
_schedule_cl_spatialpack(s, op)
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/opengl/conv2d_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].opengl()
for tensor in OP.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule conv2d_nchw
elif OP.tag.startswith('conv2d_nchw'):
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/opengl/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
Expand Down
4 changes: 2 additions & 2 deletions topi/python/topi/opengl/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].opengl()
for tensor in OP.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith('adaptive_pool'):
Expand Down Expand Up @@ -108,7 +108,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op not in scheduled_ops and tensor.op.input_tensors:
if tensor.op not in scheduled_ops and isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
Expand Down
22 changes: 22 additions & 0 deletions topi/python/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,25 @@ def ndarray_size(array, dtype="int32"):
The resulting tensor.
"""
return cpp.ndarray_size(array, dtype)


def where(condition, x, y):
    """Select elements from ``x`` or ``y`` according to ``condition``.

    This is a thin Python entry point that forwards its arguments,
    unchanged and in order, to ``cpp.where``.

    Parameters
    ----------
    condition : tvm.Tensor
        The condition array.

    x : tvm.Tensor
        First array to be selected.

    y : tvm.Tensor
        Second array to be selected.

    Returns
    -------
    result : tvm.Tensor
        A Tensor selected from x or y depending on condition.
    """
    selected = cpp.where(condition, x, y)
    return selected
2 changes: 1 addition & 1 deletion topi/python/topi/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _traverse(op):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp):
_traverse(tensor.op)
callback(op)

Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/x86/binary_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule binary_dense
elif OP.tag == 'binary_dense':
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/x86/bitserial_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def traverse(op):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors and tensor.op not in scheduled_ops:
if tensor.op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)

elif 'spatial_bitserial_conv_nchw' in op.tag or 'spatial_bitserial_conv_nhwc' in op.tag:
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/x86/bitserial_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def traverse(op):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp):
traverse(tensor.op)

elif op.tag == 'bitserial_dense' or 'bitserial_dense_unipolar':
Expand Down
8 changes: 4 additions & 4 deletions topi/python/topi/x86/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def traverse(op):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)

if 'conv2d_nchw' in op.tag:
Expand Down Expand Up @@ -284,7 +284,7 @@ def traverse(op):
s[op].parallel(fused)
s[op].vectorize(c)
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)

if 'conv2d_nhwc_pack_int8' in op.tag:
Expand Down Expand Up @@ -335,7 +335,7 @@ def traverse(op):
s[op].parallel(fused)
s[op].vectorize(c)
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)

if 'conv2d_nhwc' in op.tag:
Expand Down Expand Up @@ -648,7 +648,7 @@ def traverse(op):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)

if 'conv2d_NCHWc' in op.tag:
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/x86/conv2d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def traverse(op):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)

if 'conv2d_transpose_nchw' in op.tag:
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/x86/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def traverse(op):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'depthwise_conv2d_NCHWc' in op.tag:
conv_out = op.output(0)
Expand Down
4 changes: 2 additions & 2 deletions topi/python/topi/x86/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
Expand Down Expand Up @@ -136,7 +136,7 @@ def traverse(OP):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('adaptive_pool'):
Expand Down
61 changes: 61 additions & 0 deletions topi/tests/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,35 @@ def check_device(device):
for device in get_all_backend():
check_device(device)

def verify_where(in_shape):
    """Compare ``topi.where`` with ``np.where`` on every available backend.

    Three placeholders of shape ``in_shape`` (condition and two value
    tensors) are combined with ``topi.where``; the result is scheduled with
    the generic broadcast schedule, built, run on random data and checked
    against NumPy.

    Parameters
    ----------
    in_shape : tuple of int
        Shape shared by the condition tensor and both value tensors.
    """
    cond_ph = tvm.placeholder(shape=in_shape, name="cond")
    dtype = cond_ph.dtype
    lhs_ph = tvm.placeholder(shape=in_shape, name="A")
    rhs_ph = tvm.placeholder(shape=in_shape, name="B")
    out_tensor = topi.where(cond_ph, lhs_ph, rhs_ph)

    def check_device(device):
        # Backends that are not enabled in this build are skipped.
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)

        with tvm.target.create(device):
            sched = topi.generic.schedule_broadcast(out_tensor)
        func = tvm.build(
            sched, [cond_ph, lhs_ph, rhs_ph, out_tensor], device, name="where")

        # Condition values are drawn from [-1, 1); np.where treats any
        # non-zero entry as true.
        cond_npy = np.random.uniform(low=-1, high=1, size=in_shape).astype(dtype)
        x_npy = np.random.uniform(size=in_shape).astype(dtype)
        y_npy = np.random.uniform(size=in_shape).astype(dtype)
        expected = np.where(cond_npy, x_npy, y_npy)

        cond_nd = tvm.nd.array(cond_npy, ctx)
        x_nd = tvm.nd.array(x_npy, ctx)
        y_nd = tvm.nd.array(y_npy, ctx)
        result_nd = tvm.nd.array(
            np.empty(expected.shape).astype(out_tensor.dtype), ctx)

        func(cond_nd, x_nd, y_nd, result_nd)
        tvm.testing.assert_allclose(result_nd.asnumpy(), expected)

    for device in get_all_backend():
        check_device(device)

def test_strided_slice():
verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
Expand Down Expand Up @@ -483,6 +512,10 @@ def test_reshape():
verify_reshape((16, ), (2, 2, 2, 2))


def test_where():
    """Run the ``where`` verification on a single 4-D shape."""
    shape = (1, 2, 3, 4)
    verify_where(shape)


def test_squeeze():
verify_squeeze((1, 2, 3, 4), 0)
verify_squeeze((1, 2, 1, 4), None)
Expand Down Expand Up @@ -712,13 +745,40 @@ def check_device(device):
check_device(backend)


def test_where_fusion():
    """Integration test that where and zeros should be properly inlined.

    Builds conv2d -> greater_equal(zeros) -> where(one, two) -> add, where
    the ``topi.full`` operands take no input tensors, then schedules the
    graph with the generic NCHW conv2d schedule and builds it for every
    backend. The test passes if scheduling and building do not raise.
    """
    def check_device(device):
        with tvm.target.create(device):
            # Backends that are not enabled in this build are skipped.
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                return
            print("Running on target: %s" % device)

            data = tvm.placeholder((2, 1, 2, 4), 'int8', 'data')
            w = tvm.placeholder((3, 1, 2, 2), 'int8', 'w')
            conv1 = topi.nn.conv2d(data, w, 1, 0, 1, out_dtype='int32')

            # Constant tensors produced by zero-input ops.
            zeros = topi.full((2, 3, 1, 3), 'int32', tvm.const(0, dtype='int32'))
            gt = topi.greater_equal(conv1, zeros)
            one = topi.full((2, 3, 1, 3), 'int32', tvm.const(1, dtype='int32'))
            two = topi.full((2, 3, 1, 3), 'int32', tvm.const(2, dtype='int32'))
            selected = topi.where(gt, one, two)
            add = topi.add(conv1, selected)

            outs = [add]
            s = topi.generic.schedule_conv2d_nchw(outs)
            # Build for the device this helper was called with; the original
            # read the enclosing loop variable `backend` instead.
            tvm.build(s, [data, w, add], target=device)

    for backend in get_all_backend():
        check_device(backend)


if __name__ == "__main__":
test_strided_slice()
test_concatenate()
test_stack()
test_transpose()
test_expand_dims()
test_reshape()
test_where()
test_squeeze()
test_split()
test_flip()
Expand All @@ -732,3 +792,4 @@ def check_device(device):
test_shape()
test_sequence_mask()
test_ndarray_size()
test_where_fusion()

0 comments on commit 9d583cf

Please sign in to comment.