Skip to content

Commit

Permalink
bug fixed
Browse files (browse the repository at this point in the history)
  • Loading branch information
Laurawly committed Mar 16, 2019
1 parent 958a929 commit 0ab436d
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 13 deletions.
8 changes: 6 additions & 2 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,12 @@ def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res,
check_type_only=False):
x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32"))
x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int"))
z = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k, return_indices=False)
z_indices = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k)
z = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \
iou_threshold = iou_threshold, force_suppress = force_suppress, \
top_k = top_k, return_indices=False)
z_indices = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \
iou_threshold = iou_threshold, force_suppress = force_suppress, \
top_k = top_k)
assert "iou_threshold" in z.astext()
assert "iou_threshold" in z_indices.astext()
zz = relay.ir_pass.infer_type(z)
Expand Down
4 changes: 4 additions & 0 deletions topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
flag[tid] = 0
idx[tid] = 0

ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))

with ib.if_scope(tid < batch_size):
with ib.for_range(0, num_anchors) as k:
with ib.if_scope(k > 0):
Expand Down
6 changes: 3 additions & 3 deletions topi/python/topi/cuda/ssd/multibox.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
"""
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
out = non_max_suppression(
inter_out[0], inter_out[1], -1, nms_threshold, force_suppress, \
nms_topk, return_indices=False)
out = non_max_suppression(inter_out[0], inter_out[1], max_output_size = -1,
iou_threshold = nms_threshold, force_suppress = force_suppress,
top_k = nms_topk, return_indices=False)
return out
6 changes: 1 addition & 5 deletions topi/python/topi/cuda/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ def _default_schedule(outs):
def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)"""
if op.tag in ["nms", "invalid_to_bottom"]:
if op.name in ['nms']:
sort = op.input_tensors[1]
else:
out = op.input_tensors[0]
sort = s[out].op.input_tensors[1]
sort = op.input_tensors[1]
score = s[sort].op.input_tensors[0]
fused = s[score].fuse(*s[score].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
Expand Down
6 changes: 3 additions & 3 deletions topi/python/topi/vision/ssd/multibox.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm
"""
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
out = non_max_suppression(inter_out[0], inter_out[1], -1,
nms_threshold, force_suppress, nms_topk,
return_indices=False)
out = non_max_suppression(inter_out[0], inter_out[1], max_output_size = -1,
iou_threshold = nms_threshold, force_suppress = force_suppress,
top_k = nms_topk, return_indices=False)
return out

0 comments on commit 0ab436d

Please sign in to comment.