Skip to content

Commit

Permalink
Relay support
Browse files Browse the repository at this point in the history
  • Loading branch information
Wang committed Jan 8, 2019
1 parent e329898 commit 7f2712f
Show file tree
Hide file tree
Showing 19 changed files with 434 additions and 53 deletions.
14 changes: 14 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,20 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
}
};

struct SliceAxisAttrs : public tvm::AttrsNode<SliceAxisAttrs> {
int axis;
int begin;
int end;

TVM_DECLARE_ATTRS(SliceAxisAttrs, "relay.attrs.SliceAxisAttrs") {
TVM_ATTR_FIELD(axis)
.describe("Axis along which to be sliced.");
TVM_ATTR_FIELD(begin)
.describe("Index for begin of slice");
TVM_ATTR_FIELD(end).set_default(0)
.describe("Index for end of the slice");
}
};

struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
Array<Integer> axes;
Expand Down
32 changes: 24 additions & 8 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,35 @@ struct MultiBoxTransformLocAttrs
}
};

/*! \brief Attributes used in non_maximum_suppression operators */
/*! \brief Attributes used in get_valid_counts operator */
struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs>{
double score_threshold;

TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") {
TVM_ATTR_FIELD(score_threshold).set_default(0.0)
.describe("Lower limit of score for valid bounding boxes.");
}
};

/*! \brief Attributes used in non_maximum_suppression operator */
struct NMSAttrs : public tvm::AttrsNode<NMSAttrs>{
double overlap_threshold;
double iou_threshold;
bool force_suppress;
int topk;
int id_index;
bool do_rearrange;

TVM_DECLARE_ATTRS(NMSAttrs, "relay.attrs.NMSAttrs") {
TVM_ATTR_FIELD(overlap_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold.");
TVM_ATTR_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id.");
TVM_ATTR_FIELD(topk).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
TVM_ATTR_FIELD(iou_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold.");
TVM_ATTR_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id.");
TVM_ATTR_FIELD(topk).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
TVM_ATTR_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
TVM_ATTR_FIELD(do_rearrange).set_default(false)
.describe("Whether to move all valid bounding boxes to the top.");
}
};

Expand Down
18 changes: 9 additions & 9 deletions nnvm/include/nnvm/top/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,9 @@ struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {

DMLC_DECLARE_PARAMETER(LayoutTransformParam) {
DMLC_DECLARE_FIELD(src_layout).set_default("__undef__")
.describe("Dimension ordering of data");
.describe("Dimension ordering of data");
DMLC_DECLARE_FIELD(dst_layout).set_default("__undef__")
.describe("Dimension ordering of data.");
.describe("Dimension ordering of data.");
}
};

Expand All @@ -419,13 +419,13 @@ struct MultiBoxPriorParam : public dmlc::Parameter<MultiBoxPriorParam> {
DMLC_DECLARE_FIELD(sizes).set_default(Tuple<float>({1.0}))
.describe("List of sizes of generated MultiBoxPriores.");
DMLC_DECLARE_FIELD(ratios).set_default(Tuple<float>({1.0}))
.describe("List of aspect ratios of generated MultiBoxPriores.");
.describe("List of aspect ratios of generated MultiBoxPriores.");
DMLC_DECLARE_FIELD(steps).set_default(Tuple<float>({-1.0, -1.0}))
.describe("Priorbox step across y and x, -1 for auto calculation.");
.describe("Priorbox step across y and x, -1 for auto calculation.");
DMLC_DECLARE_FIELD(offsets).set_default(Tuple<float>({0.5, 0.5}))
.describe("Priorbox center offsets, y and x respectively.");
.describe("Priorbox center offsets, y and x respectively.");
DMLC_DECLARE_FIELD(clip).set_default(false)
.describe("Whether to clip out-of-boundary boxes.");
.describe("Whether to clip out-of-boundary boxes.");
}
};

Expand Down Expand Up @@ -461,11 +461,11 @@ struct NMSParam : public dmlc::Parameter<NMSParam> {
DMLC_DECLARE_FIELD(iou_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold.");
DMLC_DECLARE_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id.");
.describe("Suppress all detections regardless of class_id.");
DMLC_DECLARE_FIELD(topk).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
DMLC_DECLARE_FIELD(id_index).set_default(0)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
DMLC_DECLARE_FIELD(id_index).set_default(0)
.describe("Axis index for id.");
DMLC_DECLARE_FIELD(do_rearrange).set_default(false)
.describe("Whether to move all valid bounding boxes to the top.");
}
Expand Down
1 change: 0 additions & 1 deletion nnvm/src/top/vision/nms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"

namespace nnvm {
namespace top {
Expand Down
10 changes: 1 addition & 9 deletions nnvm/tests/python/compiler/test_top_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,16 +700,8 @@ def test_slice_like():
def verify_slice_axis(dshape, axis, begin, end):
data = sym.Variable("data")
net = sym.slice_axis(data, axis=axis, begin=begin, end=end)
if axis < 0:
axis += len(dshape)
if begin < 0:
begin += dshape[axis]
if end <= 0:
end += dshape[axis]
np_data = np.random.uniform(size=dshape)
slc = [slice(None)] * len(dshape)
slc[axis] = slice(begin, end)
np_out = np_data[slc]
np_out = topi.testing.slice_axis_python(np_data, axis, begin, end)

dtype = "float32"
for target, ctx in ctx_list():
Expand Down
50 changes: 50 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,54 @@ def _mx_multibox_detection(inputs, attrs):
return _op.vision.nms(ret[0], ret[1], **new_attrs1)


def _mx_box_nms(inputs, attrs):
force_suppress = attrs.get_bool("force_suppress", False)
overlap_thresh = attrs.get_float('overlap_thresh', 0.5)
topk = attrs.get_int('topk', -1)
valid_thresh = attrs.get_float('valid_thresh', 0)
coord_start = attrs.get_int('coord_start', 2)
score_index = attrs.get_int('score_index', 1)
id_index = attrs.get_int('id_index', -1)
in_format = attrs.get_str('in_format', 'corner')
out_format = attrs.get_str('out_format', 'corner')
if coord_start != 2:
raise RuntimeError('coord_start %s is not supported.' % coord_start)
if score_index != 1:
raise RuntimeError('score_index %s is not supported.' % score_index)
if id_index != -1 and int(id_index) != 0:
raise RuntimeError('id_index %s is not supported.' % id_index)
if in_format != 'corner':
raise RuntimeError('in_format %s is not supported.' % in_format)
if out_format != 'corner':
raise RuntimeError('out_format %s is not supported.' % out_format)

valid_counts, inter_out = \
_op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh)
nms_out = _op.vision.nms(inter_out, valid_counts,
iou_threshold=overlap_thresh,
force_suppress=force_suppress,
topk=topk, id_index=id_index,
do_rearrange=True)
return nms_out


def _mx_slice_axis(inputs, attrs):
new_attrs = {}
new_attrs['axis'] = attrs.get_int('axis')
new_attrs['begin'] = attrs.get_int('begin')
new_attrs['end'] = attrs.get_int('end', 0)
return _op.slice_axis(inputs[0], **new_attrs)

def _mx_l2_normalize(inputs, attrs):
new_attrs = {}
mode = attrs.get_str('mode', 'instance')
if mode != 'channel':
raise RuntimeError('mode %s is not supported.' % mode)
new_attrs['eps'] = attrs.get_float('eps', 1e-10)
new_attrs['axis'] = 1
return _op.nn.l2_normalize(inputs[0], **new_attrs)


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
Expand Down Expand Up @@ -346,7 +394,9 @@ def _mx_multibox_detection(inputs, attrs):
"BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm,
"LRN" : _mx_lrn,
"L2Normalization" : _mx_l2_normalize,
"SliceChannel" : _mx_split,
"slice_axis" : _mx_slice_axis,
"split" : _mx_split,
"expand_dims" : _mx_expand_dims,
"Concat" : _mx_concat,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("cast", schedule_injective)
_reg.register_schedule("strided_slice", schedule_injective)
_reg.register_schedule("slice_axis", schedule_injective)
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_schedule("split", schedule_injective)
_reg.register_schedule("take", schedule_injective)
Expand Down
28 changes: 27 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def strided_slice(data, begin, end, strides=None):
The indices to begin with in the slicing.
end: list of int
Indicies indicating end of the slice.
Indices indicating end of the slice.
strides: list of int, optional
Specifies the stride values, it can be negative in that case,
Expand All @@ -403,6 +403,32 @@ def strided_slice(data, begin, end, strides=None):
return _make.strided_slice(data, list(begin), list(end), list(strides))


def slice_axis(data, axis, begin, end=None):
"""Slice input array along specific axis.
Parameters
----------
data : relay.Expr
The source array to be sliced.
axis : int
Axis to be sliced.
begin: int
The index to begin with in the slicing.
end: int, optional
The index indicating end of the slice.
Returns
-------
ret : relay.Expr
The computed result.
"""
end = end or 0
return _make.slice_axis(data, axis, begin, end)


def slice_like(data, shape_like, axes=None):
"""Slice the first input with respect to the second input.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

from .multibox import *
from .nms import *
from . import _multibox
from . import _vision
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,23 @@ def compute_multibox_transform_loc(attrs, inputs, _, target):
reg.register_pattern("vision.multibox_detection", OpPattern.OPAQUE)


# Get counts of valid boxes
@reg.register_schedule("vision.get_valid_counts")
def schedule_get_valid_counts(_, outs, target):
"""Schedule definition of get_valid_counts"""
with target:
return topi.generic.schedule_nms(outs)


@reg.register_compute("vision.get_valid_counts")
def compute_get_valid_counts(attrs, inputs, _, target):
"""Compute definition of get_valid_counts"""
score_threshold = get_const_float(attrs.score_threshold)
return topi.vision.get_valid_counts(inputs[0], score_threshold)

reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE)


# non-maximum suppression
@reg.register_schedule("vision.nms")
def schedule_nms(_, outs, target):
Expand All @@ -65,12 +82,14 @@ def schedule_nms(_, outs, target):
@reg.register_compute("vision.nms")
def compute_nms(attrs, inputs, _, target):
"""Compute definition of nms"""
overlap_threshold = get_const_float(attrs.overlap_threshold)
iou_threshold = get_const_float(attrs.iou_threshold)
force_suppress = bool(get_const_int(attrs.force_suppress))
topk = get_const_int(attrs.topk)
id_index = get_const_int(attrs.id_index)
do_rearrange = bool(get_const_int(attrs.do_rearrange))
return [
topi.vision.nms(inputs[0], inputs[1], overlap_threshold,
force_suppress, topk)
topi.vision.nms(inputs[0], inputs[1], iou_threshold,
force_suppress, topk, id_index, do_rearrange)
]


Expand Down
41 changes: 37 additions & 4 deletions python/tvm/relay/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,37 @@
from __future__ import absolute_import as _abs
from . import _make

def get_valid_counts(data,
score_threshold):
"""Get valid count of bounding boxes given a score threshold.
Also moves valid boxes to the top of input data.
Parameters
----------
data : relay.Expr
Input data. 3-D tensor with shape [batch_size, num_anchors, 6].
score_threshold : optional, float
Lower limit of score for valid bounding boxes.
Returns
-------
out_tensor : relay.Expr
Rearranged data tensor.
valid_count : relay.Expr
1-D tensor for valid number of boxes.
"""
return _make.get_valid_counts(data, score_threshold)


def nms(data,
valid_count,
overlap_threshold=0.5,
iou_threshold=0.5,
force_suppress=False,
topk=-1):
topk=-1,
id_index=0,
do_rearrange=False):
"""Non-maximum suppression operator for object detection.
Parameters
Expand All @@ -19,7 +45,7 @@ def nms(data,
valid_count : relay.Expr
1-D tensor for valid number of boxes.
overlap_threshold : float, optional
iou_threshold : float, optional
Non-maximum suppression threshold.
force_suppress : bool, optional
Expand All @@ -28,9 +54,16 @@ def nms(data,
topk : int, optional
Keep maximum top k detections before nms, -1 for no limit.
id_index : optional, int
index of the class categories, -1 to disable.
do_rearrange : optional, boolean
Whether to move all valid bounding boxes to the top.
Returns
-------
out : relay.Expr
3-D tensor with shape [batch_size, num_anchors, 6].
"""
return _make.nms(data, valid_count, overlap_threshold, force_suppress, topk)
return _make.nms(data, valid_count, iou_threshold,
force_suppress, topk, id_index, do_rearrange)
Loading

0 comments on commit 7f2712f

Please sign in to comment.