Skip to content

Commit

Permalink
[TOPI] cuda reduction schedule (#7131)
Browse files Browse the repository at this point in the history
* complex reduce

* fix

* fix

* fix
  • Loading branch information
hzfan authored Dec 23, 2020
1 parent e51bcdd commit c000631
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
9 changes: 6 additions & 3 deletions python/tvm/topi/cuda/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,14 @@ def traverse_after_reduce(operator):
for tensor in operator.input_tensors:
traverse_after_reduce(tensor.op)
elif operator.tag == "comm_reduce":
_schedule_reduce(operator, sch, is_idx_reduce=False)
if operator not in scheduled_ops:
_schedule_reduce(operator, sch, is_idx_reduce=False)
for tensor in operator.input_tensors:
if tensor.op not in scheduled_ops:
traverse_before_reduce(tensor.op)
elif operator.tag == "comm_reduce_idx":
_schedule_reduce(operator, sch, is_idx_reduce=True)
if operator not in scheduled_ops:
_schedule_reduce(operator, sch, is_idx_reduce=True)
input_tensors = operator.input_tensors[0].op.input_tensors
for tensor in input_tensors:
if tensor.op not in scheduled_ops:
Expand All @@ -147,5 +149,6 @@ def traverse_after_reduce(operator):

scheduled_ops.append(operator)

traverse_after_reduce(outs[0].op)
for out in outs:
traverse_after_reduce(out.op)
return sch
26 changes: 26 additions & 0 deletions tests/python/topi/python/test_topi_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,31 @@ def test_reduce_map():
)


@tvm.testing.uses_gpu
def test_complex_reduce():
in_shape = (2, 3)
dtype = "float32"
axis = 0
keepdims = False
A = te.placeholder(shape=in_shape, name="A", dtype=dtype)
B = topi.sum(A, axis=axis, keepdims=keepdims)
C = topi.add(B, B)
D = topi.multiply(B, B)
E = topi.add(C, D)
for device, ctx in tvm.testing.enabled_targets():
print("Running on target: %s" % device)
with tvm.target.Target(device):
s = tvm.topi.testing.get_reduce_schedule(device)(E)
foo = tvm.build(s, [A, E], device, name="sum")
in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
sum_npy = in_npy.sum(axis=axis, keepdims=keepdims)
out_npy = sum_npy * 2 + sum_npy * sum_npy
data_tvm = tvm.nd.array(in_npy, ctx=ctx)
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=dtype)
foo(data_tvm, out_tvm)
tvm.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1e-3, 1e-3)


if __name__ == "__main__":
test_reduce_map()
test_complex_reduce()

0 comments on commit c000631

Please sign in to comment.