Skip to content

Commit

Permalink
Support multiple batches for exporting ONNX with pre-processing (#320)
Browse files Browse the repository at this point in the history
* Add batch_size argument for exporting ONNX with pre-processing

* Fix input_sample setter

* Fix input & output names and dynamic axes for multi-batch exporting

* Add unit-test for multi-batches exporting

* Add more test
  • Loading branch information
zhiqwang authored Feb 16, 2022
1 parent 61a6e62 commit 724f60e
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 30 deletions.
64 changes: 50 additions & 14 deletions test/test_runtime_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def run_model(self, model, inputs_list):
onnx_io = io.BytesIO()

# export to onnx models
export_onnx(onnx_io, model=model, opset_version=_onnx_opset_version)
batch_size = len(inputs_list[0])
export_onnx(onnx_io, model=model, opset_version=_onnx_opset_version, batch_size=batch_size)

# validate the exported model with onnx runtime
for test_inputs in inputs_list:
Expand Down Expand Up @@ -65,35 +66,70 @@ def get_image(self, img_name):
return image

def get_test_images(self):
return [self.get_image("bus.jpg")], [self.get_image("zidane.jpg")]
return self.get_image("bus.jpg"), self.get_image("zidane.jpg")

@pytest.mark.parametrize(
"arch, fixed_size, upstream_version",
[
("yolov5s", False, "r3.1"),
("yolov5m", True, "r4.0"),
("yolov5s", True, "r3.1"),
("yolov5m", False, "r4.0"),
("yolov5n", True, "r6.0"),
("yolov5n", False, "r6.0"),
("yolov5n6", True, "r6.0"),
("yolov5n6", False, "r6.0"),
],
)
def test_onnx_export(self, arch, fixed_size, upstream_version):
images_one, images_two = self.get_test_images()
images_dummy = [torch.ones(3, 1080, 720) * 0.3]
def test_onnx_export_single_image(self, arch, fixed_size, upstream_version):
img_one, img_two = self.get_test_images()
img_dummy = torch.ones(3, 1080, 720) * 0.3

size = (640, 640) if arch[-1] == "6" else (320, 320)
model = models.__dict__[arch](
upstream_version=upstream_version,
pretrained=True,
size=(640, 640),
fixed_shape=(640, 640) if fixed_size else None,
size=size,
fixed_shape=size if fixed_size else None,
score_thresh=0.45,
)
model = model.eval()
model(images_one)
model([img_one])
# Test exported model on images of different size, or dummy input
self.run_model(model, [(images_one,), (images_two,), (images_dummy,)])
self.run_model(model, [[img_one], [img_two], [img_dummy]])

# Test exported model for an image with no detections on other images
self.run_model(model, [(images_dummy,), (images_one,)])
@pytest.mark.parametrize("arch", ["yolov5n6"])
def test_onnx_export_multi_batches(self, arch):
img_one, img_two = self.get_test_images()
img_dummy = torch.ones(3, 1080, 720) * 0.3

size = (640, 640) if arch[-1] == "6" else (320, 320)
model = models.__dict__[arch](pretrained=True, size=size, score_thresh=0.45)
model = model.eval()
model([img_one, img_two])

# Test exported model on images of different size, or dummy input
inputs_list = [
[img_one, img_two],
[img_two, img_one],
[img_dummy, img_one],
[img_one, img_one],
[img_two, img_dummy],
[img_dummy, img_two],
]
self.run_model(model, inputs_list)

@pytest.mark.parametrize("arch", ["yolov5n"])
def test_onnx_export_misbatch(self, arch):
img_one, img_two = self.get_test_images()
img_dummy = torch.ones(3, 640, 480) * 0.3

size = (640, 640) if arch[-1] == "6" else (320, 320)
model = models.__dict__[arch](pretrained=True, size=size, score_thresh=0.45)
model = model.eval()
model([img_one, img_two])

# Test exported model on images of misbatch
with pytest.raises(IndexError, match="list index out of range"):
self.run_model(model, [[img_one, img_two], [img_two, img_one, img_dummy]])

# Test exported model on images of misbatch
with pytest.raises(ValueError, match="Model requires 3 inputs. Input Feed contains 2"):
self.run_model(model, [[img_two, img_one, img_dummy], [img_one, img_two]])
5 changes: 3 additions & 2 deletions tools/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def get_parser():
default=[640, 640],
help="Image size for inferencing (default: 640, 640).",
)
parser.add_argument("--size_divisible", type=int, default=32, help="Stride for the preprocessing.")
parser.add_argument("--batch_size", default=1, type=int, help="Batch size for YOLOv5.")
parser.add_argument("--size_divisible", type=int, default=32, help="Stride for pre-processing.")
parser.add_argument("--batch_size", default=1, type=int, help="Batch size for pre-processing.")
parser.add_argument("--opset", default=11, type=int, help="Opset version for exporing ONNX models")
parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model.")

Expand Down Expand Up @@ -79,6 +79,7 @@ def cli_main():
version=args.version,
skip_preprocess=args.skip_preprocess,
opset_version=args.opset,
batch_size=args.batch_size,
)


Expand Down
72 changes: 58 additions & 14 deletions yolort/runtime/ort_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

def export_onnx(
onnx_path: str,
*,
checkpoint_path: Optional[str] = None,
model: Optional[nn.Module] = None,
size: Tuple[int, int] = (640, 640),
Expand All @@ -18,6 +19,7 @@ def export_onnx(
version: str = "r6.0",
skip_preprocess: bool = False,
opset_version: int = 11,
batch_size: int = 1,
) -> None:
"""
Export to ONNX models that can be used for ONNX Runtime inferencing.
Expand All @@ -37,6 +39,9 @@ def export_onnx(
skip_preprocess (bool): Skip the preprocessing transformation when exporting the ONNX
models. Default: False
opset_version (int): Opset version for exporting ONNX models. Default: 11
batch_size (int): Only used for models that include pre-processing, you need to specify
the batch sizes and ensure that the number of input images is the same as the batches
when inferring if you want to export multiple batches ONNX models. Default: 1
"""

onnx_builder = ONNXBuilder(
Expand All @@ -49,6 +54,7 @@ def export_onnx(
version=version,
skip_preprocess=skip_preprocess,
opset_version=opset_version,
batch_size=batch_size,
)

onnx_builder.to_onnx(onnx_path)
Expand All @@ -71,6 +77,9 @@ class ONNXBuilder:
skip_preprocess (bool): Skip the preprocessing transformation when exporting the ONNX
models. Default: False
opset_version (int): Opset version for exporting ONNX models. Default: 11
batch_size (int): Only used for models that include pre-processing, you need to specify
the batch sizes and ensure that the number of input images is the same as the batches
when inferring if you want to export multiple batches ONNX models. Default: 1
"""

def __init__(
Expand All @@ -84,6 +93,7 @@ def __init__(
version: str = "r6.0",
skip_preprocess: bool = False,
opset_version: int = 11,
batch_size: int = 1,
) -> None:

super().__init__()
Expand All @@ -96,16 +106,18 @@ def __init__(
# For pre-processing
self._size = size
self._size_divisible = size_divisible
self._batch_size = batch_size
# Define the module
if model is None:
model = self._build_model()
self.model = model

self.opset_version = opset_version
self.input_names = ["images"]
self.output_names = ["scores", "labels", "boxes"]
self.input_sample = self._get_input_sample()
self.dynamic_axes = self._get_dynamic_axes()
# For exporting ONNX model
self._opset_version = opset_version
self.input_names = self._set_input_names()
self.output_names = self._set_output_names()
self.input_sample = self._set_input_sample()
self.dynamic_axes = self._set_dynamic_axes()

def _build_model(self):
if self._skip_preprocess:
Expand All @@ -128,28 +140,60 @@ def _build_model(self):
model = model.eval()
return model

def _get_dynamic_axes(self):
def _set_input_names(self):
if self._skip_preprocess:
return ["images"]
if self._batch_size == 1:
return ["image"]

input_names = []
for i in range(self._batch_size):
input_names.append(f"image{i + 1}")
return input_names

def _set_output_names(self):
if self._skip_preprocess:
return ["scores", "labels", "boxes"]
if self._batch_size == 1:
return ["score", "label", "box"]

output_names = []
for i in range(self._batch_size):
output_names.extend([f"score{i + 1}", f"label{i + 1}", f"box{i + 1}"])
return output_names

def _set_dynamic_axes(self):
if self._skip_preprocess:
return {
"images": {0: "batch", 2: "height", 3: "width"},
"boxes": {0: "batch", 1: "num_objects"},
"labels": {0: "batch", 1: "num_objects"},
"scores": {0: "batch", 1: "num_objects"},
}
else:
if self._batch_size == 1:
return {
"images": {1: "height", 2: "width"},
"boxes": {0: "num_objects"},
"labels": {0: "num_objects"},
"scores": {0: "num_objects"},
"image": {1: "height", 2: "width"},
"box": {0: "num_objects"},
"label": {0: "num_objects"},
"score": {0: "num_objects"},
}

def _get_input_sample(self):
dynamic_axes = {}
for i in range(self._batch_size):
dynamic_axes[f"image{i + 1}"] = {1: "height", 2: "width"}
dynamic_axes[f"box{i + 1}"] = {0: "num_objects"}
dynamic_axes[f"label{i + 1}"] = {0: "num_objects"}
dynamic_axes[f"score{i + 1}"] = {0: "num_objects"}
return dynamic_axes

def _set_input_sample(self):
if self._skip_preprocess:
return torch.rand(1, 3, 640, 640)
else:
if self._batch_size == 1:
return [torch.rand(3, 640, 640)]

return [torch.rand(3, 640, 640)] * self._batch_size

@torch.no_grad()
def to_onnx(self, onnx_path: str, **kwargs):
"""
Expand All @@ -165,7 +209,7 @@ def to_onnx(self, onnx_path: str, **kwargs):
self.input_sample,
onnx_path,
do_constant_folding=True,
opset_version=self.opset_version,
opset_version=self._opset_version,
input_names=self.input_names,
output_names=self.output_names,
dynamic_axes=self.dynamic_axes,
Expand Down

0 comments on commit 724f60e

Please sign in to comment.