From 21a5306dd1d9f5c416f63d343523ad1f60f00730 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Wed, 16 Feb 2022 11:49:10 +0800 Subject: [PATCH 1/5] Add batch_size argument for exporting ONNX with pre-processing --- tools/export_model.py | 5 +-- yolort/runtime/ort_helper.py | 67 ++++++++++++++++++++++++++++-------- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/tools/export_model.py b/tools/export_model.py index e7e95d83..ca48f9b6 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -50,8 +50,8 @@ def get_parser(): default=[640, 640], help="Image size for inferencing (default: 640, 640).", ) - parser.add_argument("--size_divisible", type=int, default=32, help="Stride for the preprocessing.") - parser.add_argument("--batch_size", default=1, type=int, help="Batch size for YOLOv5.") + parser.add_argument("--size_divisible", type=int, default=32, help="Stride for pre-processing.") + parser.add_argument("--batch_size", default=1, type=int, help="Batch size for pre-processing.") parser.add_argument("--opset", default=11, type=int, help="Opset version for exporing ONNX models") parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model.") @@ -79,6 +79,7 @@ def cli_main(): version=args.version, skip_preprocess=args.skip_preprocess, opset_version=args.opset, + batch_size=args.batch_size, ) diff --git a/yolort/runtime/ort_helper.py b/yolort/runtime/ort_helper.py index ff613f3a..515ed67a 100644 --- a/yolort/runtime/ort_helper.py +++ b/yolort/runtime/ort_helper.py @@ -9,6 +9,7 @@ def export_onnx( onnx_path: str, + *, checkpoint_path: Optional[str] = None, model: Optional[nn.Module] = None, size: Tuple[int, int] = (640, 640), @@ -18,6 +19,7 @@ def export_onnx( version: str = "r6.0", skip_preprocess: bool = False, opset_version: int = 11, + batch_size: int = 1, ) -> None: """ Export to ONNX models that can be used for ONNX Runtime inferencing. @@ -37,6 +39,9 @@ def export_onnx( skip_preprocess (bool): Skip the preprocessing transformation when exporting the ONNX models. Default: False opset_version (int): Opset version for exporting ONNX models. Default: 11 + batch_size (int): Only used for models that include pre-processing, you need to specify + the batch sizes and ensure that the number of input images is the same as the batches + when inferring if you want to export multiple batches ONNX models. Default: 1 """ onnx_builder = ONNXBuilder( @@ -49,6 +54,7 @@ def export_onnx( version=version, skip_preprocess=skip_preprocess, opset_version=opset_version, + batch_size=batch_size, ) onnx_builder.to_onnx(onnx_path) @@ -71,6 +77,9 @@ class ONNXBuilder: skip_preprocess (bool): Skip the preprocessing transformation when exporting the ONNX models. Default: False opset_version (int): Opset version for exporting ONNX models. Default: 11 + batch_size (int): Only used for models that include pre-processing, you need to specify + the batch sizes and ensure that the number of input images is the same as the batches + when inferring if you want to export multiple batches ONNX models. Default: 1 """ def __init__( @@ -84,6 +93,7 @@ def __init__( version: str = "r6.0", skip_preprocess: bool = False, opset_version: int = 11, + batch_size: int = 1, ) -> None: super().__init__() @@ -96,16 +106,18 @@ def __init__( # For pre-processing self._size = size self._size_divisible = size_divisible + self._batch_size = batch_size # Define the module if model is None: model = self._build_model() self.model = model - self.opset_version = opset_version - self.input_names = ["images"] - self.output_names = ["scores", "labels", "boxes"] - self.input_sample = self._get_input_sample() - self.dynamic_axes = self._get_dynamic_axes() + # For exporting ONNX model + self._opset_version = opset_version + self.input_names = self._set_input_names() + self.output_names = self._set_output_names() + self.input_sample = self._set_input_sample() + self.dynamic_axes = self._set_dynamic_axes() def _build_model(self): if self._skip_preprocess: @@ -128,7 +140,23 @@ def _build_model(self): model = model.eval() return model - def _get_dynamic_axes(self): + def _set_input_names(self): + if self._skip_preprocess: + return ["images"] + if self._batch_size == 1: + return ["image"] + + return ["images1", "images2"] + + def _set_output_names(self): + if self._skip_preprocess: + return ["scores", "labels", "boxes"] + if self._batch_size == 1: + return ["score", "label", "box"] + + return ["scores1", "labels1", "boxes1", "scores2", "labels2", "boxes2"] + + def _set_dynamic_axes(self): if self._skip_preprocess: return { "images": {0: "batch", 2: "height", 3: "width"}, @@ -136,19 +164,30 @@ def _get_dynamic_axes(self): "labels": {0: "batch", 1: "num_objects"}, "scores": {0: "batch", 1: "num_objects"}, } - else: + if self._batch_size == 1: return { - "images": {1: "height", 2: "width"}, - "boxes": {0: "num_objects"}, - "labels": {0: "num_objects"}, - "scores": {0: "num_objects"}, + "image": {1: "height", 2: "width"}, + "box": {0: "num_objects"}, + "label": {0: "num_objects"}, + "score": {0: "num_objects"}, } - def _get_input_sample(self): + return { + "images1": {1: "height", 2: "width"}, + "images2": {1: "height", 2: "width"}, + "boxes1": {0: "num_objects"}, + "labels1": {0: "num_objects"}, + "scores1": {0: "num_objects"}, + "boxes2": {0: "num_objects"}, + "labels2": {0: "num_objects"}, + "scores2": {0: "num_objects"}, + } + + def _set_input_sample(self): if self._skip_preprocess: return torch.rand(1, 3, 640, 640) else: - return [torch.rand(3, 640, 640)] + return [torch.rand(3, 640, 640)] * 2 @torch.no_grad() def to_onnx(self, onnx_path: str, **kwargs): @@ -165,7 +204,7 @@ def to_onnx(self, onnx_path: str, **kwargs): self.input_sample, onnx_path, do_constant_folding=True, - opset_version=self.opset_version, + opset_version=self._opset_version, input_names=self.input_names, output_names=self.output_names, dynamic_axes=self.dynamic_axes, From 3e9aa497a1d565790045c1b556fe87ab4f8adfb0 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Wed, 16 Feb 2022 11:54:11 +0800 Subject: [PATCH 2/5] Fix input_sample setter --- yolort/runtime/ort_helper.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/yolort/runtime/ort_helper.py b/yolort/runtime/ort_helper.py index 515ed67a..353fc7e4 100644 --- a/yolort/runtime/ort_helper.py +++ b/yolort/runtime/ort_helper.py @@ -186,8 +186,10 @@ def _set_dynamic_axes(self): def _set_input_sample(self): if self._skip_preprocess: return torch.rand(1, 3, 640, 640) - else: - return [torch.rand(3, 640, 640)] * 2 + if self._batch_size == 1: + return [torch.rand(3, 640, 640)] + + return [torch.rand(3, 640, 640)] * self._batch_size @torch.no_grad() def to_onnx(self, onnx_path: str, **kwargs): From c0c98ab3d9f509e4972e5624ff2f5cbaa42c8521 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Wed, 16 Feb 2022 12:38:06 +0800 Subject: [PATCH 3/5] Fix input & output names and dynamic axes for multi-batch exporting --- yolort/runtime/ort_helper.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/yolort/runtime/ort_helper.py b/yolort/runtime/ort_helper.py index 353fc7e4..5ba21cc0 100644 --- a/yolort/runtime/ort_helper.py +++ b/yolort/runtime/ort_helper.py @@ -146,7 +146,10 @@ def _set_input_names(self): if self._batch_size == 1: return ["image"] - return ["images1", "images2"] + input_names = [] + for i in range(self._batch_size): + input_names.append(f"image{i + 1}") + return input_names def _set_output_names(self): if self._skip_preprocess: @@ -154,7 +157,10 @@ def _set_output_names(self): if self._batch_size == 1: return ["score", "label", "box"] - return ["scores1", "labels1", "boxes1", "scores2", "labels2", "boxes2"] + output_names = [] + for i in range(self._batch_size): + output_names.extend([f"score{i + 1}", f"label{i + 1}", f"box{i + 1}"]) + return output_names def _set_dynamic_axes(self): if self._skip_preprocess: @@ -172,16 +178,13 @@ def _set_dynamic_axes(self): "score": {0: "num_objects"}, } - return { - "images1": {1: "height", 2: "width"}, - "images2": {1: "height", 2: "width"}, - "boxes1": {0: "num_objects"}, - "labels1": {0: "num_objects"}, - "scores1": {0: "num_objects"}, - "boxes2": {0: "num_objects"}, - "labels2": {0: "num_objects"}, - "scores2": {0: "num_objects"}, - } + dynamic_axes = {} + for i in range(self._batch_size): + dynamic_axes[f"image{i + 1}"] = {1: "height", 2: "width"} + dynamic_axes[f"box{i + 1}"] = {0: "num_objects"} + dynamic_axes[f"label{i + 1}"] = {0: "num_objects"} + dynamic_axes[f"score{i + 1}"] = {0: "num_objects"} + return dynamic_axes def _set_input_sample(self): if self._skip_preprocess: From b6928b14ada30c21fd645b08463a0126070c65dc Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Wed, 16 Feb 2022 13:07:01 +0800 Subject: [PATCH 4/5] Add unit-test for multi-batches exporting --- test/test_runtime_ort.py | 46 ++++++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/test/test_runtime_ort.py b/test/test_runtime_ort.py index c7ec6cbf..d9aed5d0 100644 --- a/test/test_runtime_ort.py +++ b/test/test_runtime_ort.py @@ -29,7 +29,8 @@ def run_model(self, model, inputs_list): onnx_io = io.BytesIO() # export to onnx models - export_onnx(onnx_io, model=model, opset_version=_onnx_opset_version) + batch_size = len(inputs_list[0]) + export_onnx(onnx_io, model=model, opset_version=_onnx_opset_version, batch_size=batch_size) # validate the exported model with onnx runtime for test_inputs in inputs_list: @@ -65,35 +66,52 @@ def get_image(self, img_name): return image def get_test_images(self): - return [self.get_image("bus.jpg")], [self.get_image("zidane.jpg")] + return self.get_image("bus.jpg"), self.get_image("zidane.jpg") @pytest.mark.parametrize( "arch, fixed_size, upstream_version", [ - ("yolov5s", False, "r3.1"), - ("yolov5m", True, "r4.0"), + ("yolov5s", True, "r3.1"), ("yolov5m", False, "r4.0"), ("yolov5n", True, "r6.0"), ("yolov5n", False, "r6.0"), - ("yolov5n6", True, "r6.0"), ("yolov5n6", False, "r6.0"), ], ) - def test_onnx_export(self, arch, fixed_size, upstream_version): - images_one, images_two = self.get_test_images() - images_dummy = [torch.ones(3, 1080, 720) * 0.3] + def test_onnx_export_single_image(self, arch, fixed_size, upstream_version): + img_one, img_two = self.get_test_images() + img_dummy = torch.ones(3, 1080, 720) * 0.3 + size = (640, 640) if arch[-1] == "6" else (320, 320) model = models.__dict__[arch]( upstream_version=upstream_version, pretrained=True, - size=(640, 640), - fixed_shape=(640, 640) if fixed_size else None, + size=size, + fixed_shape=size if fixed_size else None, score_thresh=0.45, ) model = model.eval() - model(images_one) + model([img_one]) # Test exported model on images of different size, or dummy input - self.run_model(model, [(images_one,), (images_two,), (images_dummy,)]) + self.run_model(model, [[img_one], [img_two], [img_dummy]]) - # Test exported model for an image with no detections on other images - self.run_model(model, [(images_dummy,), (images_one,)]) + @pytest.mark.parametrize("arch", ["yolov5n6"]) + def test_onnx_export_multi_batches(self, arch): + img_one, img_two = self.get_test_images() + img_dummy = torch.ones(3, 1080, 720) * 0.3 + + size = (640, 640) if arch[-1] == "6" else (320, 320) + model = models.__dict__[arch](pretrained=True, size=size, score_thresh=0.45) + model = model.eval() + model([img_one, img_two]) + + # Test exported model on images of different size, or dummy input + inputs_list = [ + [img_one, img_two], + [img_two, img_one], + [img_dummy, img_one], + [img_one, img_one], + [img_two, img_dummy], + [img_dummy, img_two], + ] + self.run_model(model, inputs_list=inputs_list) From 2e0ffe3784aa1a61190d58e82237e21338e5d1a0 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Wed, 16 Feb 2022 13:44:53 +0800 Subject: [PATCH 5/5] Add more test --- test/test_runtime_ort.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/test/test_runtime_ort.py b/test/test_runtime_ort.py index d9aed5d0..deab6779 100644 --- a/test/test_runtime_ort.py +++ b/test/test_runtime_ort.py @@ -114,4 +114,22 @@ def test_onnx_export_multi_batches(self, arch): [img_two, img_dummy], [img_dummy, img_two], ] - self.run_model(model, inputs_list=inputs_list) + self.run_model(model, inputs_list) + + @pytest.mark.parametrize("arch", ["yolov5n"]) + def test_onnx_export_misbatch(self, arch): + img_one, img_two = self.get_test_images() + img_dummy = torch.ones(3, 640, 480) * 0.3 + + size = (640, 640) if arch[-1] == "6" else (320, 320) + model = models.__dict__[arch](pretrained=True, size=size, score_thresh=0.45) + model = model.eval() + model([img_one, img_two]) + + # Test exported model on images of misbatch + with pytest.raises(IndexError, match="list index out of range"): + self.run_model(model, [[img_one, img_two], [img_two, img_one, img_dummy]]) + + # Test exported model on images of misbatch + with pytest.raises(ValueError, match="Model requires 3 inputs. Input Feed contains 2"): + self.run_model(model, [[img_two, img_one, img_dummy], [img_one, img_two]])