Skip to content

Commit

Permalink
Parallelize cumsum in get_valid_counts (apache#7123)
Browse files Browse the repository at this point in the history
* Parallelize cumsum in get_valid_counts

* make the scan loop exclusive

* switch to directly using exclusive scan

* perform inner loop of final writes on anchor threads

* fix flaky test

fix lint

* remove final cuda kernel

Co-authored-by: masa <[email protected]>
  • Loading branch information
2 people authored and Tushar Dey committed Jan 20, 2021
1 parent 3c5a91c commit db2c7fa
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 52 deletions.
134 changes: 101 additions & 33 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
Expand Down Expand Up @@ -151,31 +151,103 @@ def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
valid_indices = ib.buffer_ptr(valid_indices)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)

# Copy boxes to valid_indices
with ib.new_scope():
nthread_tx = max_threads
nthread_bx = batch_size // max_threads + 1
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
# TODO(mbrookhart): Parallelize the sum and cumsum here
current_index = ib.allocate("int32", (1,), name="current_index", scope="local")
with ib.if_scope(tid < batch_size):
current_index[0] = 0
valid_count[tid] = 0
with ib.for_range(0, num_anchors) as j:
idx = tid * num_anchors + j
valid_count[tid] = valid_count[tid] + valid_boxes[idx]
with ib.if_scope(valid_boxes[idx] == 1):
valid_indices[idx] = current_index[0]
current_index[0] = current_index[0] + 1
with ib.else_scope():
valid_indices[idx] = -1
ib.scope_attr(by, "thread_extent", nthread_by)
tid = bx * nthread_tx + tx
with ib.if_scope(tid < num_anchors):
valid_indices[by, tid] = valid_boxes[by, tid]

nthread_tx = max_threads
nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size

## The following algorithm performs parallel exclusive scan to get
## a tensor that can later be used to select valid indices
# Up Sweep of exclusive scan
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
)
with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << l2_width

with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(start[0] < num_anchors):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
end[0] = tvm.te.min(start[0] + width, num_anchors)
with ib.if_scope(middle[0] < num_anchors):
valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
by * num_anchors + middle[0] - 1
]

# Down Sweep of exclusive scan
with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", batch_size)
with ib.if_scope(bx < batch_size):
valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
valid_indices[(bx + 1) * num_anchors - 1] = 0

with ib.for_range(0, lim, dtype="int64") as l2_width:
width = 2 << (lim - l2_width - 1)

with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(
bx,
"thread_extent",
tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
)
tid = bx * nthread_tx + tx

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
tmp = ib.allocate("int32", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(tvm.tir.all(start[0] < num_anchors)):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
end[0] = tvm.tir.min(start[0] + width, num_anchors)
with ib.if_scope(middle[0] < num_anchors):
tmp[0] = valid_indices[by * num_anchors + middle[0] - 1]
valid_indices[by * num_anchors + middle[0] - 1] = valid_indices[
by * num_anchors + end[0] - 1
]
valid_indices[by * num_anchors + end[0] - 1] += tmp[0]

return ib.get()


def get_valid_counts_ir(data, valid_indices, out, out_indices):
def get_valid_counts_ir(data, valid_indices, valid_boxes, out, out_indices):
"""Low level IR to get valid count of bounding boxes
given a score threshold. Also prepares to move valid boxes to the
top of input data.
Expand Down Expand Up @@ -203,8 +275,9 @@ def get_valid_counts_ir(data, valid_indices, out, out_indices):
ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data)

valid_indices = ib.buffer_ptr(valid_indices)
valid_boxes = ib.buffer_ptr(valid_boxes)

out = ib.buffer_ptr(out)
out_indices = ib.buffer_ptr(out_indices)
one = tvm.tir.const(1, dtype=out.dtype)
Expand All @@ -213,41 +286,36 @@ def get_valid_counts_ir(data, valid_indices, out, out_indices):
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
nthread_by = batch_size
nthread_bz = elem_length
with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
tid = bx * max_threads + tx
with ib.if_scope(tid < num_anchors):
i = by
j = tid
k = bz
out[(i * num_anchors + j) * elem_length + k] = -one
with ib.for_range(0, elem_length) as k:
out[(i * num_anchors + j) * elem_length + k] = -one
out_indices[i * num_anchors + j] = -1
with ib.new_scope():
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
bz = te.thread_axis("blockIdx.z")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(by, "thread_extent", nthread_by)
ib.scope_attr(bz, "thread_extent", nthread_bz)
tid = bx * max_threads + tx
with ib.if_scope(tid < num_anchors):
i = by
j = tid
k = bz
with ib.if_scope(valid_indices[i, tid] >= 0):
out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
(i * num_anchors + j) * elem_length + k
]
with ib.if_scope(valid_boxes[i, tid] > 0):
with ib.for_range(0, elem_length) as k:
out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
(i * num_anchors + j) * elem_length + k
]
out_indices[i * num_anchors + valid_indices[i, tid]] = j
return ib.get()

Expand Down Expand Up @@ -321,10 +389,10 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):

out, out_indices = te.extern(
[data.shape, (batch_size, num_anchors)],
[data, valid_indices],
lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], outs[0], outs[1]),
[data, valid_indices, valid_boxes],
lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], ins[2], outs[0], outs[1]),
dtype=["int32", data.dtype],
in_buffers=[data_buf, valid_indices_buf],
in_buffers=[data_buf, valid_indices_buf, valid_boxes_buf],
out_buffers=[out_buf, out_indices_buf],
name="get_valid_counts",
tag="get_valid_counts_gpu",
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/topi/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
out_indices: tvm.te.Tensor or numpy NDArray
Related index in input data.
"""
if isinstance(score_threshold, float):
if isinstance(score_threshold, (float, int)):
score_threshold = tvm.tir.const(score_threshold, dtype=data.dtype)
id_index_const = tvm.tir.const(id_index, "int32")
score_index_const = tvm.tir.const(score_index, "int32")
Expand Down
4 changes: 1 addition & 3 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,8 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
for target, ctx in tvm.testing.enabled_targets():
intrp = relay.create_executor("debug", ctx=ctx, target=target)
out = intrp.evaluate(func)(np_data)

tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
# get_valid_count for opencl doesn't do data rearrangement
if target in ["opencl"]:
return
tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04)
tvm.testing.assert_allclose(out[2].asnumpy(), np_out3, rtol=1e-3, atol=1e-04)

Expand Down
21 changes: 6 additions & 15 deletions tests/python/topi/python/test_topi_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,27 +105,18 @@ def check_device(device):
tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx)
tvm_out2 = tvm.nd.array(np.zeros(np_out2.shape, dtype=dtype), ctx)
tvm_out3 = tvm.nd.array(np.zeros(np_out3.shape, dtype="int32"), ctx)
if device == "llvm":
f = tvm.build(s, [data, outs[0], outs[1], outs[2]], device)
f(tvm_input_data, tvm_out1, tvm_out2, tvm_out3)
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3)
else:
f = tvm.build(s, [data, outs[0], outs[1]], device)
f(tvm_input_data, tvm_out1, tvm_out2)
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)

f = tvm.build(s, [data, outs[0], outs[1], outs[2]], device)
f(tvm_input_data, tvm_out1, tvm_out2, tvm_out3)
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3)

for device in ["llvm", "cuda", "opencl"]:
check_device(device)


@tvm.testing.uses_gpu
@pytest.mark.skip(
"Skip this test as it is intermittent."
"See https://github.com/apache/tvm/pull/4901#issuecomment-595040094"
)
def test_get_valid_counts():
verify_get_valid_counts((1, 1000, 5), 0.5, -1, 0)
verify_get_valid_counts((1, 2500, 6), 0, 0, 1)
Expand Down

0 comments on commit db2c7fa

Please sign in to comment.