Skip to content

Commit

Permalink
Fix multibox_transform_loc
Browse files Browse the repository at this point in the history
  • Loading branch information
Wang committed Jan 14, 2019
1 parent 81bb789 commit d2f9601
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/top/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def compute_multibox_transform_loc(attrs, inputs, _):
return topi.vision.ssd.multibox_transform_loc(inputs[0], inputs[1], inputs[2],
clip, threshold, variance)

reg.register_pattern("multibox_detection", OpPattern.OPAQUE)
reg.register_pattern("multibox_transform_loc", OpPattern.OPAQUE)

# Get valid number of anchor boxes
@reg.register_schedule("get_valid_counts")
Expand Down
3 changes: 2 additions & 1 deletion nnvm/tests/python/compiler/test_top_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tvm
from tvm.contrib import graph_runtime
import topi
import topi.testing
import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list
Expand Down Expand Up @@ -657,7 +658,7 @@ def np_slice_like(np_data, np_shape_like, axis=[]):
slice_idx = []
for b, e in zip(begin_idx, end_idx):
slice_idx.append(slice(b, e))
np_result = np_data[slice_idx]
np_result = np_data[tuple(slice_idx)]
return np_result

def verify_slice_like(np_data, np_shape_like, axis=[]):
Expand Down
9 changes: 4 additions & 5 deletions topi/python/topi/vision/ssd/multibox.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,10 @@ def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01
-------
ret : tuple of tvm.Tensor
"""
out, valid_count = hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
tvm.const(clip, "bool"),
tvm.const(threshold, "float32"),
tvm.convert(variances))
return out, valid_count
return hybrid_multibox_transform_loc(cls_prob, loc_pred, anchor,
tvm.const(clip, "bool"),
tvm.const(threshold, "float32"),
tvm.convert(variances))

@tvm.target.generic_func
def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nms_threshold=0.5,
Expand Down

0 comments on commit d2f9601

Please sign in to comment.