Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
compute and schedule updated for yolo reorg
Browse files Browse the repository at this point in the history
siju-samuel committed Nov 22, 2018

Verified

This commit was signed with the committer’s verified signature.
hlts2 Hiroto Funakoshi
1 parent 5a81ae0 commit 68a637f
Showing 4 changed files with 52 additions and 12 deletions.
5 changes: 4 additions & 1 deletion include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
@@ -53,9 +53,12 @@ struct NMSAttrs : public tvm::AttrsNode<NMSAttrs>{
.describe("Suppress all detections regardless of class_id.");
TVM_ATTR_FIELD(topk).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
}
};

/*! \brief Attributes used in yolo reorg operators */
struct YoloReorgAttrs : public tvm::AttrsNode<YoloReorgAttrs> {
IndexExpr stride;
Integer stride;

TVM_DECLARE_ATTRS(YoloReorgAttrs, "relay.attrs.YoloReorgAttrs") {
TVM_ATTR_FIELD(stride)
1 change: 1 addition & 0 deletions python/tvm/relay/op/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -5,3 +5,4 @@
from .multibox import *
from .nms import *
from .yolo import *
from ._yolo import *
9 changes: 9 additions & 0 deletions python/tvm/relay/op/vision/_yolo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import
from ..op import register_schedule, register_pattern
from ..op import schedule_injective, OpPattern

# reorg
register_pattern("vision.yolo_reorg", OpPattern.INJECTIVE)
register_schedule("vision.yolo_reorg", schedule_injective)
49 changes: 38 additions & 11 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
""" Support level5 operator test cases.
"""
import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import ctx_list
import topi.testing

def test_resize_infer_type():
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
@@ -70,22 +73,46 @@ def test_nms():
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType(
(n, num_anchors, 6), "float32")
def test_yolo_reorg():


def test_yolo_reorg_infer_shape():
def verify_yolo_reorg(shape, stride, out_shape):
x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.vision.yolo_reorg(x, stride=stride)
zz = relay.ir_pass.infer_type(z)
assert "stride=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")

n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, 20, 20), "float32"))
z = relay.vision.yolo_reorg(x, stride=10)
zz = relay.ir_pass.infer_type(z)
assert "stride=10" in z.astext()
assert zz.checked_type == relay.ty.TensorType((n, c*10*10, 2, 2), "float32")
verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2))

x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
z = relay.vision.yolo_reorg(x, stride=2)
assert "stride=2" in z.astext()
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((n, c*2*2, h/2, w/2), "float32")
def test_yolo_reorg():
def verify_yolo_reorg(shape, stride):
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
ref_res = topi.testing.reorg_python(x_data, stride)

x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.vision.yolo_reorg(x, stride=stride)
zz = relay.ir_pass.infer_type(z)
assert "stride=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32")

func = relay.Function([x], z)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
verify_yolo_reorg((1, 100, 20, 20), 10)
verify_yolo_reorg((1, 4, 6, 6), 2)

if __name__ == "__main__":
test_resize_infer_type()
test_multibox_prior()
test_nms()
test_yolo_reorg_infer_shape()
test_yolo_reorg()

0 comments on commit 68a637f

Please sign in to comment.