ssd gluoncv ops

merge with master

ssd gluoncv gpu op updated

tutorials and tests modified

fix lint

address comment

multibox bug fixed

blank line added

use fewer threads per block

less threads per block for get valid count

Revert "less threads per block for get valid count"

This reverts commit 08896cf.

typo fixed

elem length made to a variable

fix lint error

fix lint error

lint fixed

bug fixed

lint fixed

error fixed

test ci

separate argsort into an independent op

fix lint

fix lint

remove unsupported models

ssd gluoncv gpu op updated

tutorials and tests modified

fix lint

use fewer threads per block

less threads per block for get valid count

Revert "less threads per block for get valid count"

This reverts commit 08896cf.

bug fixed

error fixed

test ci

separate argsort into an independent op

typo fixed

argsort added to relay

solve conflicts with master

fix lint

fix lint

test push

Revert "test push"

This reverts commit 6db0088.

fix lint error

fix more lint

cpu test_sort updated

debug ci

nms fixed

expose argsort to relay frontend

test ci

fix lint

sort register error fixed

fix nnvm

adaptive pooling added to relay

nms type fixed

Revert "adaptive pooling added to relay"

This reverts commit 1119f1f.

fix lint

expose argsort op

fix lint

fix lint

fix lint

sort test updated

sort bug fixed

nnvm error fixed

fix argsort default returned data type to be float instead of int

fix lint

fix lint

test fixed

fix valid count

fix titanx bug

tutorial add both targets

titanx error fixed

try to fix CI old gpu error

try to solve CI GPU error

get_valid_count added

[AutoTVM] fix argument type for curve feature (apache#3004)
Laurawly authored and icemelon committed Apr 26, 2019
1 parent 6c6273c commit c0e3e5e
Showing 28 changed files with 1,331 additions and 376 deletions.
24 changes: 24 additions & 0 deletions include/tvm/relay/attrs/vision.h
@@ -30,6 +30,24 @@
namespace tvm {
namespace relay {

/*! \brief Attributes used in argsort operators */
struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
int axis;
bool is_ascend;
std::string dtype;

TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs") {
TVM_ATTR_FIELD(axis).set_default(-1)
.describe("Axis along which to sort the input tensor."
"If not given, the flattened array is used.");
TVM_ATTR_FIELD(is_ascend).set_default(true)
.describe("Whether to sort in ascending or descending order."
"By default, sort in ascending order");
TVM_ATTR_FIELD(dtype).set_default("float32")
.describe("DType of the output indices.");
}
};

/*! \brief Attributes used in multibox_prior operators */
struct MultiBoxPriorAttrs : public tvm::AttrsNode<MultiBoxPriorAttrs> {
Array<IndexExpr> sizes;
@@ -92,6 +110,8 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
double iou_threshold;
bool force_suppress;
int top_k;
int coord_start;
int score_index;
int id_index;
bool return_indices;
bool invalid_to_bottom;
@@ -106,6 +126,10 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
.describe("Suppress all detections regardless of class_id.");
TVM_ATTR_FIELD(top_k).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
TVM_ATTR_FIELD(coord_start).set_default(2)
.describe("Start index of the consecutive 4 coordinates.");
TVM_ATTR_FIELD(score_index).set_default(1)
.describe("Index of the scores/confidence of boxes.");
TVM_ATTR_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
TVM_ATTR_FIELD(return_indices).set_default(true)
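The two hunks above define where each field of a detection row lives: with the defaults, a row reads `[class_id, score, x1, y1, x2, y2]`, so `id_index=0`, `score_index=1`, and the four coordinates start at `coord_start=2`. Below is a minimal sketch of how these attributes surface in the Relay Python API at this commit (`argsort` is registered under `relay.vision` here; treat the exact call shapes as an assumption from this vintage, not a stable contract):

```python
from tvm import relay

# Detections laid out as [class_id, score, x1, y1, x2, y2] per box,
# matching the new defaults: id_index=0, score_index=1, coord_start=2.
data = relay.var("data", shape=(1, 5, 6), dtype="float32")
valid_count = relay.var("valid_count", shape=(1,), dtype="int32")

nms_out = relay.vision.non_max_suppression(
    data, valid_count,
    iou_threshold=0.5,      # suppress boxes overlapping more than this
    force_suppress=False,   # only suppress within the same class_id
    top_k=-1,               # no pre-NMS limit
    coord_start=2,          # first of the four consecutive coordinates
    score_index=1,          # confidence column
    id_index=0,             # class-id column
    return_indices=False,
    invalid_to_bottom=True)

# ArgsortAttrs drives the new sort op: indices of `scores` along axis 1.
scores = relay.var("scores", shape=(1, 5), dtype="float32")
order = relay.vision.argsort(scores, axis=1, is_ascend=False, dtype="float32")
```
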
6 changes: 6 additions & 0 deletions nnvm/include/nnvm/top/nn.h
@@ -488,6 +488,8 @@ struct NonMaximumSuppressionParam : public dmlc::Parameter<NonMaximumSuppression
bool force_suppress;
int top_k;
int id_index;
int coord_start;
int score_index;
int max_output_size;
bool invalid_to_bottom;
DMLC_DECLARE_PARAMETER(NonMaximumSuppressionParam) {
@@ -500,6 +502,10 @@ struct NonMaximumSuppressionParam : public dmlc::Parameter<NonMaximumSuppression
.describe("Suppress all detections regardless of class_id.");
DMLC_DECLARE_FIELD(top_k).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
DMLC_DECLARE_FIELD(coord_start).set_default(2)
.describe("Start index of the consecutive 4 coordinates.");
DMLC_DECLARE_FIELD(score_index).set_default(1)
.describe("Index of the scores/confidence of boxes.");
DMLC_DECLARE_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
DMLC_DECLARE_FIELD(return_indices).set_default(true)
10 changes: 7 additions & 3 deletions nnvm/python/nnvm/top/vision.py
@@ -94,8 +94,12 @@ def compute_nms(attrs, inputs, _):
id_index = attrs.get_int('id_index')
invalid_to_bottom = attrs.get_bool('invalid_to_bottom')

return topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
return topi.vision.non_max_suppression(inputs[0], inputs[1],
max_output_size=max_output_size,
iou_threshold=iou_threshold,
force_suppress=force_suppress,
top_k=top_k, id_index=id_index,
return_indices=return_indices,
invalid_to_bottom=invalid_to_bottom)

reg.register_pattern("non_max_suppression", OpPattern.OPAQUE)
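
Passing everything by keyword here is deliberate: this PR inserts `coord_start` and `score_index` into the middle of the TOPI signature, and the old positional call would have silently shifted the later arguments. A self-contained illustration with a hypothetical stand-in function (not the real TOPI code):

```python
def nms(data, valid_count, max_output_size=-1, iou_threshold=0.5,
        force_suppress=False, top_k=-1, coord_start=2, score_index=1,
        id_index=0, return_indices=True, invalid_to_bottom=False):
    """Stand-in with the post-PR parameter order."""
    return coord_start, score_index, id_index

# Positional call written against the pre-PR order
# (..., top_k, id_index, return_indices, invalid_to_bottom):
print(nms(None, None, -1, 0.5, False, -1, 0, True, False))
# -> (0, True, False): the value meant for id_index landed in coord_start.

# Keyword call keeps its meaning no matter what is inserted mid-signature:
print(nms(None, None, top_k=-1, id_index=0))
# -> (2, 1, 0)
```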
51 changes: 24 additions & 27 deletions nnvm/tests/python/compiler/test_top_level4.py
@@ -543,14 +543,13 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1),
if clip:
np_out = np.clip(np_out, 0, 1)

target = "llvm"
ctx = tvm.cpu()
graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape})
m = graph_runtime.create(graph, lib, ctx)
m.set_input("data", np.random.uniform(size=dshape).astype(dtype))
m.run()
out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype))
tvm.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape})
m = graph_runtime.create(graph, lib, ctx)
m.set_input("data", np.random.uniform(size=dshape).astype(dtype))
m.run()
tvm_out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype))
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)

def test_multibox_prior():
verify_multibox_prior((1, 3, 50, 50))
@@ -577,17 +576,16 @@ def test_multibox_transform_loc():
[0, 0.44999999, 1, 1, 1, 1],
[0, 0.30000001, 0, 0, 0.22903419, 0.20435292]]])

target = "llvm"
dtype = "float32"
ctx = tvm.cpu()
graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes),
"loc_preds": (batch_size, num_anchors * 4),
"anchors": (1, num_anchors, 4)})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)})
m.run()
out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
tvm.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes),
"loc_preds": (batch_size, num_anchors * 4),
"anchors": (1, num_anchors, 4)})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)})
m.run()
tvm_out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
tvm.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)

def test_non_max_suppression():
dshape = (1, 5, 6)
@@ -607,15 +605,14 @@ def test_non_max_suppression():
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])

target = "llvm"
ctx = tvm.cpu()
graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)},
dtype={"data": "float32", "valid_count": "int32"})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"data": np_data, "valid_count": np_valid_count})
m.run()
out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
tvm.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)},
dtype={"data": "float32", "valid_count": "int32"})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"data": np_data, "valid_count": np_valid_count})
m.run()
tvm_out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)

def np_slice_like(np_data, np_shape_like, axis=[]):
begin_idx = [0 for _ in np_data.shape]
46 changes: 29 additions & 17 deletions python/tvm/relay/frontend/mxnet.py
@@ -188,6 +188,13 @@ def _pool2d(new_op, is_avg):
'Operator {} Pooling is not supported for frontend MXNet.'.format(pool_type.capitalize()))


def _mx_adaptive_pooling(inputs, attrs):
output_size = attrs.get_int_tuple("output_size", [])
if output_size != (1,):
raise RuntimeError("AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
return _op.nn.global_avg_pool2d(inputs[0])


def _mx_dropout(inputs, attrs):
rate = attrs.get_float("p", 0.5)
return _op.nn.dropout(inputs[0], rate=rate)
@@ -539,15 +546,6 @@ def _mx_box_nms(inputs, attrs):
id_index = attrs.get_int('id_index', -1)
in_format = attrs.get_str('in_format', 'corner')
out_format = attrs.get_str('out_format', 'corner')
if coord_start != 2:
raise tvm.error.OpAttributeInvalid(
'Value of attribute "coord_start" must equal 2 for operator box_nms.')
if score_index != 1:
raise tvm.error.OpAttributeInvalid(
'Value of attribute "score_index" must equal 1 for operator box_nms.')
if id_index != -1 and int(id_index) != 0:
raise tvm.error.OpAttributeInvalid(
'Value of attribute "id_index" must equal either -1 or 0 for operator box_nms.')
if in_format != 'corner':
raise tvm.error.OpAttributeInvalid(
'Value of attribute "in_format" must equal "corner" for operator box_nms.')
@@ -561,6 +559,8 @@
iou_threshold=iou_thresh,
force_suppress=force_suppress,
top_k=top_k,
coord_start=coord_start,
score_index=score_index,
id_index=id_index,
return_indices=False,
invalid_to_bottom=True)
@@ -658,6 +658,15 @@ def _mx_deformable_convolution(inputs, attrs):
return res


def _mx_argsort(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
new_attrs["axis"] = attrs.get_int("axis", -1)
new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True)
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.vision.argsort(inputs[0], **new_attrs)


def _mx_contrib_div_sqrt_dim(inputs, attrs):
assert len(inputs) == 1
ndim = len(ir_pass.infer_type(inputs[0])._checked_type_.shape)
@@ -673,7 +682,7 @@ def _mx_foreach(inputs, attrs, subgraphs, dtype_info, mod):
nil = p.nil
cons = p.cons
l = p.l

assert len(subgraphs) == 1
in_data_locs = json.loads(attrs.get_str('in_data_locs'))
in_state_locs = json.loads(attrs.get_str('in_state_locs'))
@@ -705,7 +714,7 @@ def _mx_foreach(inputs, attrs, subgraphs, dtype_info, mod):
for k, v in enumerate(remain_locs):
assert loop_body_args[v] is None
loop_body_args[v] = params[k]
loop_body_arg_shapes = [ir_pass.infer_type(arg).checked_type.shape
for arg in loop_body_args]
loop_body = _from_mxnet_impl(mod, subgraphs[0], loop_body_arg_shapes, dtype_info)
loop_body_ret = _expr.Call(loop_body, loop_body_args)
@@ -742,20 +751,20 @@ def _mx_while_loop(inputs, attrs, subgraphs, dtype_info, mod):
nil = p.nil
cons = p.cons
l = p.l

assert len(subgraphs) == 2
input_args = []
for i, arg in enumerate(inputs):
var = _expr.var("arg%s" % i, ir_pass.infer_type(arg).checked_type)
input_args.append(var)

cond_input_locs = attrs.get_int_tuple("cond_input_locs")
func_input_locs = attrs.get_int_tuple("func_input_locs")
# indices of state vars in the func_input_locs
func_var_locs = attrs.get_int_tuple("func_var_locs")
num_out_data = attrs.get_int("num_out_data")
num_outputs = attrs.get_int("num_outputs")

all_outs = _expr.var("all_outs")
while_loop = _expr.GlobalVar("while_loop")
prev_states = [input_args[func_input_locs[j]] for j in func_var_locs]
@@ -765,7 +774,7 @@ def _mx_while_loop(inputs, attrs, subgraphs, dtype_info, mod):
cond_body = _from_mxnet_impl(mod, subgraphs[0], cond_arg_shapes, dtype_info)
cond_ret = _expr.Call(cond_body, cond_args)
cond = _op.take(cond_ret, _expr.const(0)).astype("bool")

sb = _scope_builder.ScopeBuilder()
with sb.if_scope(cond):
func_args = [input_args[j] for j in func_input_locs]
@@ -785,7 +794,7 @@ def _mx_while_loop(inputs, attrs, subgraphs, dtype_info, mod):
sb.ret(recur_ret)
with sb.else_scope():
sb.ret(_expr.Tuple([all_outs] + prev_states))

body = sb.get()
while_args = input_args + [all_outs]
# print(while_args)
@@ -808,6 +817,7 @@ def _mx_layer_norm(inputs, attrs):
def _mx_sequence_mask(inputs, attrs):
return inputs[0]


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
Expand Down Expand Up @@ -944,6 +954,7 @@ def _mx_sequence_mask(inputs, attrs):
"BlockGrad" : _mx_BlockGrad,
"shape_array" : _mx_shape_array,
"Embedding" : _mx_embedding,
"argsort" : _mx_argsort,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
"smooth_l1" : _mx_smooth_l1,
@@ -958,6 +969,7 @@ def _mx_sequence_mask(inputs, attrs):
"_contrib_MultiProposal" : _mx_proposal,
"_contrib_box_nms" : _mx_box_nms,
"_contrib_DeformableConvolution" : _mx_deformable_convolution,
"_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_pooling,
# control flow
"_foreach" : _mx_foreach,
"_while_loop" : _mx_while_loop,
@@ -1028,7 +1040,7 @@ def _from_mxnet_impl(mod, symbol, shape_dict, dtype_info):
subgraphs = node['subgraphs']
res = _convert_map[op_name](children, attrs, subgraphs, dtype_info, mod)
else:
res = _convert_map[op_name](children, attrs)
if isinstance(res, (_expr.TupleWrapper, tuple, list)):
pass
elif isinstance(res, _expr.Expr):
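With `argsort` and `_contrib_AdaptiveAvgPooling2D` added to `_convert_map`, MXNet graphs using those operators now import directly instead of failing as unsupported. A sketch against the frontend API of this era (`relay.frontend.from_mxnet` returns a function/params pair here, and the keyword names are assumptions from this vintage; later TVM versions return an IRModule):

```python
import mxnet as mx
from tvm import relay

# Toy MXNet symbol exercising the newly mapped operator.
data = mx.sym.Variable("data")
sym = mx.sym.argsort(data, axis=-1, is_ascend=True)

# argsort is routed through _mx_argsort above and becomes
# relay.vision.argsort in the imported function.
func, params = relay.frontend.from_mxnet(sym, shape={"data": (2, 8)})
print(func)
```
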
24 changes: 24 additions & 0 deletions python/tvm/relay/op/tensor.py
@@ -690,6 +690,30 @@ def concatenate(data, axis):
return _make.concatenate(Tuple(data), axis)


def stack(data, axis):
"""Join a sequence of arrays along a new axis.
Parameters
----------
data : Union(List[relay.Expr], Tuple(relay.Expr))
A list of tensors.
axis : int
The axis in the result array along which the input arrays are stacked.
Returns
-------
ret : relay.Expr
The stacked tensor.
"""
data = list(data)
if not data:
raise ValueError("relay.stack requires data to be non-empty.")
if not isinstance(axis, int):
raise ValueError("For now, we only support integer axis")
return _make.stack(Tuple(data), axis)


def copy(data):
"""Copy a tensor.
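`stack` moves out of `transform.py` (see the deletion below) and picks up input validation, so an empty sequence or a non-integer axis now fails fast with a clear message rather than deep inside `_make.stack`. A small usage sketch:

```python
from tvm import relay

x = relay.var("x", shape=(2, 3))
y = relay.var("y", shape=(2, 3))

# Two (2, 3) tensors stacked along a new leading axis -> shape (2, 2, 3).
z = relay.stack([x, y], axis=0)

# The new checks reject bad inputs up front:
# relay.stack([], axis=0)   # ValueError: data must be non-empty
# relay.stack([x, y], "0")  # ValueError: integer axis only
```
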
23 changes: 0 additions & 23 deletions python/tvm/relay/op/transform.py
@@ -315,28 +315,6 @@ def arange(start, stop=None, step=const(1, dtype="int32"), dtype="float32"):
return _make.arange(start, stop, step, dtype)


def stack(data, axis):
"""Join a sequence of arrays along a new axis.
Parameters
----------
data : relay.Expr
The input data to the operator.
axis : int
The axis in the result array along which the input arrays are stacked.
.. note::
Each array in the input array sequence must have the same shape.
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.stack(data, axis)


def repeat(data, repeats, axis):
"""Repeats elements of an array.
By default, repeat flattens the input array into 1-D and then repeats the elements.
@@ -698,5 +676,4 @@ def gather_nd(data, indices):
indices = [[0, 1], [1, 0]]
relay.gather_nd(data, indices) = [[3, 4], [5, 6]]
"""

return _make.gather_nd(data, indices)
2 changes: 2 additions & 0 deletions python/tvm/relay/op/vision/__init__.py
@@ -22,6 +22,8 @@
from .nms import *
from .rcnn import *
from .yolo import *
from .sort import *
from . import _rcnn
from . import _yolo
from . import _vision
from . import _sort
(Diffs for the remaining changed files are not shown.)
