Skip to content

Commit

Permalink
error fixed
Browse files · Browse the repository at this point in the history
  • Loading branch information
Laurawly committed Mar 18, 2019
1 parent 96648c3 commit d6830a2
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
10 changes: 3 additions & 7 deletions topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
idx = ib.buffer_ptr(idx)
score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)

max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
bx = tvm.thread_axis("vthread")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
ib.scope_attr(bx, "virtual_thread", nthread_bx)
tid = bx * max_threads + tx

with ib.if_scope(tid < batch_size * num_anchors):
Expand All @@ -64,10 +64,6 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
flag[tid] = 0
idx[tid] = 0

ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))

with ib.if_scope(tid < batch_size):
with ib.for_range(0, num_anchors) as k:
with ib.if_scope(k > 0):
Expand Down
4 changes: 2 additions & 2 deletions topi/python/topi/cuda/ssd/multibox.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp
threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold)

max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
nthread_tx = max_threads
nthread_bx = (batch_size * num_anchors) // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
Expand Down Expand Up @@ -293,7 +293,7 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw,
out_loc = ib.buffer_ptr(out)

max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
nthread_tx = max_threads
nthread_bx = (batch_size * num_anchors) // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
Expand Down
8 changes: 6 additions & 2 deletions topi/python/topi/cuda/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@ def _default_schedule(outs):
def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)"""
if op.tag in ["nms", "invalid_to_bottom"]:
sort = op.input_tensors[1]
if op.tag == "nms":
sort = op.input_tensors[1]
else:
out = op.input_tensors[0]
sort = s[out].op.input_tensors[1]
score = s[sort].op.input_tensors[0]
fused = s[score].fuse(*s[score].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
num_thread = int(tvm.target.current_target(allow_none=False).max_num_threads)
bx, tx = s[score].split(fused, factor=num_thread)
s[score].bind(bx, tvm.thread_axis("blockIdx.x"))
s[score].bind(tx, tvm.thread_axis("threadIdx.x"))
Expand Down

0 comments on commit d6830a2

Please sign in to comment.