diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 0bbad3f0b0e3f..0133bcb1b4d85 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -53,9 +53,12 @@ struct NMSAttrs : public tvm::AttrsNode{ .describe("Suppress all detections regardless of class_id."); TVM_ATTR_FIELD(topk).set_default(-1) .describe("Keep maximum top k detections before nms, -1 for no limit."); + } +}; + /*! \brief Attributes used in yolo reorg operators */ struct YoloReorgAttrs : public tvm::AttrsNode { - IndexExpr stride; + Integer stride; TVM_DECLARE_ATTRS(YoloReorgAttrs, "relay.attrs.YoloReorgAttrs") { TVM_ATTR_FIELD(stride) diff --git a/python/tvm/relay/op/vision/__init__.py b/python/tvm/relay/op/vision/__init__.py index 14c49a37e8f62..8dd876275e6a7 100644 --- a/python/tvm/relay/op/vision/__init__.py +++ b/python/tvm/relay/op/vision/__init__.py @@ -5,3 +5,4 @@ from .multibox import * from .nms import * from .yolo import * +from ._yolo import * diff --git a/python/tvm/relay/op/vision/_yolo.py b/python/tvm/relay/op/vision/_yolo.py new file mode 100644 index 0000000000000..749ebfa26dd09 --- /dev/null +++ b/python/tvm/relay/op/vision/_yolo.py @@ -0,0 +1,9 @@ +#pylint: disable=invalid-name, unused-argument +"""Backend compiler related feature registration""" +from __future__ import absolute_import +from ..op import register_schedule, register_pattern +from ..op import schedule_injective, OpPattern + +# reorg +register_pattern("vision.yolo_reorg", OpPattern.INJECTIVE) +register_schedule("vision.yolo_reorg", schedule_injective) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index c86f20be9e609..33d525029b310 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -1,7 +1,10 @@ """ Support level5 operator test cases. """ +import numpy as np import tvm from tvm import relay +from tvm.relay.testing import ctx_list +import topi.testing def test_resize_infer_type(): n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") @@ -70,22 +73,46 @@ def test_nms(): zz = relay.ir_pass.infer_type(z) assert zz.checked_type == relay.ty.TensorType( (n, num_anchors, 6), "float32") -def test_yolo_reorg(): + + +def test_yolo_reorg_infer_shape(): + def verify_yolo_reorg(shape, stride, out_shape): + x = relay.var("x", relay.TensorType(shape, "float32")) + z = relay.vision.yolo_reorg(x, stride=stride) + zz = relay.ir_pass.infer_type(z) + assert "stride=" in z.astext() + assert zz.checked_type == relay.ty.TensorType(out_shape, "float32") + n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") - x = relay.var("x", relay.TensorType((n, c, 20, 20), "float32")) - z = relay.vision.yolo_reorg(x, stride=10) - zz = relay.ir_pass.infer_type(z) - assert "stride=10" in z.astext() - assert zz.checked_type == relay.ty.TensorType((n, c*10*10, 2, 2), "float32") + verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2)) + verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2)) - x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) - z = relay.vision.yolo_reorg(x, stride=2) - assert "stride=2" in z.astext() - zz = relay.ir_pass.infer_type(z) - assert zz.checked_type == relay.ty.TensorType((n, c*2*2, h/2, w/2), "float32") +def test_yolo_reorg(): + def verify_yolo_reorg(shape, stride): + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + ref_res = topi.testing.reorg_python(x_data, stride) + + x = relay.var("x", relay.TensorType(shape, "float32")) + z = relay.vision.yolo_reorg(x, stride=stride) + zz = relay.ir_pass.infer_type(z) + assert "stride=" in z.astext() + assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32") + + func = relay.Function([x], z) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + + n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + verify_yolo_reorg((1, 100, 20, 20), 10) + verify_yolo_reorg((1, 4, 6, 6), 2) if __name__ == "__main__": test_resize_infer_type() test_multibox_prior() test_nms() + test_yolo_reorg_infer_shape() test_yolo_reorg()