From 61a6e622f9bffc895de1b0c0d13c91dbf6ad26a3 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Tue, 15 Feb 2022 17:42:20 +0800 Subject: [PATCH] Refactor ONNX export and update tutorial (#319) * Remove onnxsim in ONNX export CLI tools * Add ONNXBuidler for exporting ONNX * Remove lr attribute in YOLOv5 * Fixing export_onnx function * Update tutorials * Update ONNX expoter CLI tool * Cleanup unit test --- deployment/onnxruntime/README.md | 12 +- .../export-onnx-inference-onnxruntime.ipynb | 447 +++++++++++------- test/test_runtime.py | 0 test/{test_onnx.py => test_runtime_ort.py} | 59 +-- tools/export_model.py | 158 ++----- yolort/models/yolov5.py | 3 - yolort/runtime/ort_helper.py | 173 +++++++ 7 files changed, 506 insertions(+), 346 deletions(-) delete mode 100644 test/test_runtime.py rename test/{test_onnx.py => test_runtime_ort.py} (67%) create mode 100644 yolort/runtime/ort_helper.py diff --git a/deployment/onnxruntime/README.md b/deployment/onnxruntime/README.md index fc726277..e32526d4 100644 --- a/deployment/onnxruntime/README.md +++ b/deployment/onnxruntime/README.md @@ -13,11 +13,11 @@ The ONNXRuntime inference for `yolort`, both GPU and CPU are supported. ## Features -The `ONNX` model exported with `yolort` differs from the official one in the following three ways. +The ONNX model exported by yolort differs from other pipeline in the following three ways. -- The exported `ONNX` graph now supports dynamic shapes, and we use `(3, H, W)` as the input shape (for example `(3, 640, 640)`). -- We embed the pre-processing ([`letterbox`](https://github.com/ultralytics/yolov5/blob/9ef94940aa5e9618e7e804f0758f9a6cebfc63a9/utils/augmentations.py#L88-L118)) into the graph as well. We only require the input image to be in the `RGB` channel, and to be rescaled to `float32 [0-1]` from general `uint [0-255]`. The main logic we use to implement this mechanism is below. (And [this](https://github.com/zhiqwang/yolov5-rt-stack/blob/b9c67205a61fa0e9d7e6696372c133ea0d36d9db/yolort/models/transform.py#L210-L234) plays the same role of the official `letterbox`, but there will be a little difference in accuracy now.) -- We embed the post-processing (`nms`) into the model graph, which performs the same task as [`non_max_suppression`](https://github.com/ultralytics/yolov5/blob/fad57c29cd27c0fcbc0038b7b7312b9b6ef922a8/utils/general.py#L532-L623) except for the format of the inputs. (And here the `ONNX` graph is required to be dynamic.) +- We embed the pre-processing into the graph (mainly composed of `letterbox`). and the exported model expects a `Tensor[C, H, W]`, which is in `RGB` channel and is rescaled to range `float32 [0-1]`. +- We embed the post-processing into the model graph with `torchvision.ops.batched_nms`. So the outputs of the exported model are straightforward `boxes`, `labels` and `scores` fields of this image. +- We adopt the dynamic shape mechanism to export the ONNX models. ## Usage @@ -39,10 +39,10 @@ The `ONNX` model exported with `yolort` differs from the official one in the fol 1. Export your custom model to ONNX. ```bash - python tools/export_model.py [--checkpoint_path path/to/custom/best.pt] + python tools/export_model.py --checkpoint_path [path/to/your/best.pt] ``` - And then, you can find that a new pair of ONNX models ("best.onnx" and "best.sim.onnx") has been generated in the directory of "best.pt". + And then, you can find that a ONNX model ("best.onnx") have been generated in the directory of "best.pt". 1. \[Optional\] Quick test with the ONNXRuntime Python interface. diff --git a/notebooks/export-onnx-inference-onnxruntime.ipynb b/notebooks/export-onnx-inference-onnxruntime.ipynb index b0d3a14b..7b845c18 100644 --- a/notebooks/export-onnx-inference-onnxruntime.ipynb +++ b/notebooks/export-onnx-inference-onnxruntime.ipynb @@ -4,7 +4,24 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Deploying yolort on ONNX Runtime" + "# Deploying yolort on ONNX Runtime\n", + "\n", + "\n", + "The ONNX model exported by yolort differs from other pipeline in the following three ways.\n", + "\n", + "- We embed the pre-processing into the graph (mainly composed of `letterbox`). and the exported model expects a `Tensor[C, H, W]`, which is in `RGB` channel and is rescaled to range `float32 [0-1]`.\n", + "- We embed the post-processing into the model graph with `torchvision.ops.batched_nms`. So the outputs of the exported model are straightforward `boxes`, `labels` and `scores` fields of this image.\n", + "- We adopt the dynamic shape mechanism to export the ONNX models.\n", + "\n", + "## Set up environment and function utilities\n", + "\n", + "First you should install ONNX Runtime first to run this tutorial. See the ONNX Runtime [installation matrix](https://onnxruntime.ai) for recommended instructions for desired combinations of target operating system, hardware, accelerator, and language.\n", + "\n", + "A quick solution is to install via pip on X64:\n", + "\n", + "```bash\n", + "pip install onnxruntime\n", + "```" ] }, { @@ -13,15 +30,13 @@ "metadata": {}, "outputs": [], "source": [ - "import cv2\n", - "\n", + "import os\n", "import torch\n", - "import onnx\n", - "import onnxruntime\n", "\n", - "from yolort.models import yolov5s\n", + "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", "\n", - "from yolort.utils import get_image_from_url, read_image_to_tensor" + "device = torch.device('cpu')" ] }, { @@ -30,19 +45,22 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", + "import cv2\n", + "import onnx\n", + "import onnxruntime\n", "\n", - "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", + "from yolort.models import YOLOv5\n", + "from yolort.v5 import attempt_download\n", "\n", - "device = torch.device('cpu')" + "from yolort.utils import get_image_from_url, read_image_to_tensor\n", + "from yolort.utils.image_utils import to_numpy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Model Definition and Initialization" + "Define some parameters used for defining the model, exporting ONNX models and inferencing on ONNX Runtime." ] }, { @@ -51,17 +69,19 @@ "metadata": {}, "outputs": [], "source": [ - "model = yolov5s(export_friendly=True, pretrained=True, score_thresh=0.45)\n", - "\n", - "model = model.eval()\n", - "model = model.to(device)" + "img_size = 640\n", + "size = (img_size, img_size) # Used for pre-processing\n", + "size_divisible = 64\n", + "score_thresh = 0.35\n", + "nms_thresh = 0.45\n", + "opset_version = 11" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Load images to infer" + "Get images for inferenceing." ] }, { @@ -70,43 +90,126 @@ "metadata": {}, "outputs": [], "source": [ - "img_one = get_image_from_url(\"https://gitee.com/zhiqwang/yolov5-rt-stack/raw/master/test/assets/bus.jpg\")\n", - "# img_one = cv2.imread('../test/assets/bus.jpg')\n", + "img_src1 = \"https://huggingface.co/spaces/zhiqwang/assets/resolve/main/bus.jpg\"\n", + "img_one = get_image_from_url(img_src1)\n", "img_one = read_image_to_tensor(img_one, is_half=False)\n", "img_one = img_one.to(device)\n", "\n", - "img_two = get_image_from_url(\"https://gitee.com/zhiqwang/yolov5-rt-stack/raw/master/test/assets/zidane.jpg\")\n", - "# img_two = cv2.imread('../test/assets/zidane.jpg')\n", + "img_src2 = \"https://huggingface.co/spaces/zhiqwang/assets/resolve/main/zidane.jpg\"\n", + "img_two = get_image_from_url(img_src2)\n", "img_two = read_image_to_tensor(img_two, is_half=False)\n", - "img_two = img_two.to(device)\n", - "\n", - "# images = [img_one, img_two]\n", - "# Uncomment the above line and comment the next line if you want to\n", - "# use the multi-batch inferencing on onnxruntime\n", - "images = [img_one]" + "img_two = img_two.to(device)" ] }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "tags": [] + }, "source": [ - "### Inference on PyTorch backend" + "## Load the model trained from yolov5\n", + "\n", + "The model used below is officially released by yolov5 and trained on COCO 2017 datasets." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, + "outputs": [], + "source": [ + "# yolov5n6.pt is downloaded from 'https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5n6.pt'\n", + "model_path = \"yolov5n6.pt\"\n", + "onnx_path = \"yolov5n6.onnx\"\n", + "checkpoint_path = attempt_download(model_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n", - " return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n" + "\n", + " from n params module arguments \n", + " 0 -1 1 1760 yolort.v5.models.common.Conv [3, 16, 6, 2, 2] \n", + " 1 -1 1 4672 yolort.v5.models.common.Conv [16, 32, 3, 2] \n", + " 2 -1 1 4800 yolort.v5.models.common.C3 [32, 32, 1] \n", + " 3 -1 1 18560 yolort.v5.models.common.Conv [32, 64, 3, 2] \n", + " 4 -1 2 29184 yolort.v5.models.common.C3 [64, 64, 2] \n", + " 5 -1 1 73984 yolort.v5.models.common.Conv [64, 128, 3, 2] \n", + " 6 -1 3 156928 yolort.v5.models.common.C3 [128, 128, 3] \n", + " 7 -1 1 221568 yolort.v5.models.common.Conv [128, 192, 3, 2] \n", + " 8 -1 1 167040 yolort.v5.models.common.C3 [192, 192, 1] \n", + " 9 -1 1 442880 yolort.v5.models.common.Conv [192, 256, 3, 2] \n", + " 10 -1 1 296448 yolort.v5.models.common.C3 [256, 256, 1] \n", + " 11 -1 1 164608 yolort.v5.models.common.SPPF [256, 256, 5] \n", + " 12 -1 1 49536 yolort.v5.models.common.Conv [256, 192, 1, 1] \n", + " 13 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 14 [-1, 8] 1 0 yolort.v5.models.common.Concat [1] \n", + " 15 -1 1 203904 yolort.v5.models.common.C3 [384, 192, 1, False] \n", + " 16 -1 1 24832 yolort.v5.models.common.Conv [192, 128, 1, 1] \n", + " 17 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 18 [-1, 6] 1 0 yolort.v5.models.common.Concat [1] \n", + " 19 -1 1 90880 yolort.v5.models.common.C3 [256, 128, 1, False] \n", + " 20 -1 1 8320 yolort.v5.models.common.Conv [128, 64, 1, 1] \n", + " 21 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 22 [-1, 4] 1 0 yolort.v5.models.common.Concat [1] \n", + " 23 -1 1 22912 yolort.v5.models.common.C3 [128, 64, 1, False] \n", + " 24 -1 1 36992 yolort.v5.models.common.Conv [64, 64, 3, 2] \n", + " 25 [-1, 20] 1 0 yolort.v5.models.common.Concat [1] \n", + " 26 -1 1 74496 yolort.v5.models.common.C3 [128, 128, 1, False] \n", + " 27 -1 1 147712 yolort.v5.models.common.Conv [128, 128, 3, 2] \n", + " 28 [-1, 16] 1 0 yolort.v5.models.common.Concat [1] \n", + " 29 -1 1 179328 yolort.v5.models.common.C3 [256, 192, 1, False] \n", + " 30 -1 1 332160 yolort.v5.models.common.Conv [192, 192, 3, 2] \n", + " 31 [-1, 12] 1 0 yolort.v5.models.common.Concat [1] \n", + " 32 -1 1 329216 yolort.v5.models.common.C3 [384, 256, 1, False] \n", + " 33 [23, 26, 29, 32] 1 164220 yolort.v5.models.yolo.Detect [80, [[19, 27, 44, 40, 38, 94], [96, 68, 86, 152, 180, 137], [140, 301, 303, 264, 238, 542], [436, 615, 739, 380, 925, 792]], [64, 128, 192, 256]]\n", + "/opt/conda/lib/python3.8/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:2157.)\n", + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", + "Model Summary: 355 layers, 3246940 parameters, 3246940 gradients, 4.6 GFLOPs\n", + "\n" ] } ], + "source": [ + "model = YOLOv5.load_from_yolov5(\n", + " model_path,\n", + " size=size,\n", + " size_divisible=size_divisible,\n", + " score_thresh=score_thresh,\n", + " nms_thresh=nms_thresh,\n", + ")\n", + "\n", + "model = model.eval()\n", + "model = model.to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inference on PyTorch backend" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "images = [img_one]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], "source": [ "with torch.no_grad():\n", " model_out = model(images)" @@ -114,15 +217,15 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 3.44 s, sys: 20 ms, total: 3.46 s\n", - "Wall time: 96.9 ms\n" + "CPU times: user 44.1 s, sys: 160 ms, total: 44.3 s\n", + "Wall time: 1.61 s\n" ] } ], @@ -134,19 +237,19 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[669.26556, 391.30249, 809.86627, 885.23444],\n", - " [ 54.06350, 397.83176, 235.95316, 901.37323],\n", - " [222.88336, 406.81192, 341.55716, 854.77924],\n", - " [ 18.63205, 232.97676, 810.97394, 760.11700]])" + "tensor([[ 32.27846, 225.15266, 811.47729, 740.91071],\n", + " [ 50.42178, 387.48898, 241.54399, 897.61041],\n", + " [219.03331, 386.14346, 345.77689, 869.02582],\n", + " [678.05023, 374.65326, 809.80334, 874.80621]])" ] }, - "execution_count": 7, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -157,16 +260,16 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([0.89005, 0.87333, 0.85366, 0.72340])" + "tensor([0.88238, 0.84486, 0.72629, 0.70077])" ] }, - "execution_count": 8, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -177,16 +280,16 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([0, 0, 0, 5])" + "tensor([5, 0, 0, 0])" ] }, - "execution_count": 9, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -204,16 +307,16 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ - "from torchvision.ops._register_onnx_ops import _onnx_opset_version" + "from yolort.runtime.ort_helper import export_onnx" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -225,248 +328,276 @@ } ], "source": [ - "export_onnx_name = 'yolov5s.onnx' # path of the exported ONNX models\n", - "\n", - "print(f'We are using opset version: {_onnx_opset_version}')" + "print(f'We are using opset version: {opset_version}')" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py:1192: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input images_tensors\n", - " 'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))\n", - "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py:1192: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input boxes\n", - " 'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))\n", - "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py:1192: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input labels\n", - " 'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))\n", - "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py:1192: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input scores\n", - " 'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))\n", - "/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:31: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", - " stride = torch.as_tensor([stride], dtype=dtype, device=device)\n", - "/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:50: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", - " anchor_grid = torch.as_tensor(anchor_grid, dtype=dtype, device=device)\n", - "/data/wangzq/yolov5-rt-stack/yolort/models/anchor_utils.py:79: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", - " shifts = shifts - torch.tensor(0.5, dtype=shifts.dtype, device=device)\n", - "/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:298: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", - " for s, s_orig in zip(new_size, original_size)\n", - "/data/wangzq/yolov5-rt-stack/yolort/models/transform.py:298: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " for s, s_orig in zip(new_size, original_size)\n", - "/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_opset9.py:2766: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.\n", - " \"If indices include negative values, the exported graph will produce incorrect results.\")\n", - "/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_opset9.py:701: UserWarning: This model contains a squeeze operation on dimension 1 on an input with unknown shape. Note that if the size of dimension 1 of the input is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on non-singleton dimensions, it is recommended to export this model using opset version 11 or higher.\n", - " \"version 11 or higher.\")\n" + "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3701: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float()))\n", + "/coding/yolov5-rt-stack/yolort/models/transform.py:282: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).\n", + " img_h, img_w = _get_shape_onnx(img)\n", + "/coding/yolov5-rt-stack/yolort/models/anchor_utils.py:45: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", + " anchors = torch.as_tensor(self.anchor_grids, dtype=torch.float32, device=device).to(dtype=dtype)\n", + "/coding/yolov5-rt-stack/yolort/models/anchor_utils.py:46: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", + " strides = torch.as_tensor(self.strides, dtype=torch.float32, device=device).to(dtype=dtype)\n", + "/coding/yolov5-rt-stack/yolort/models/box_head.py:402: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", + " strides = torch.as_tensor(self.strides, dtype=torch.float32, device=device).to(dtype=dtype)\n", + "/coding/yolov5-rt-stack/yolort/models/box_head.py:333: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).\n", + " for head_output, grid, shift, stride in zip(head_outputs, grids, shifts, strides):\n", + "/opt/conda/lib/python3.8/site-packages/torch/onnx/symbolic_opset9.py:2815: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.\n", + " warnings.warn(\"Exporting aten::index operator of advanced indexing in opset \" +\n" ] } ], "source": [ - "# Export to ONNX model\n", - "torch.onnx.export(\n", - " model,\n", - " (images,),\n", - " export_onnx_name,\n", - " do_constant_folding=True,\n", - " opset_version=_onnx_opset_version, \n", - " input_names=[\"images_tensors\"],\n", - " output_names=[\"scores\", \"labels\", \"boxes\"],\n", - " dynamic_axes={\n", - " \"images_tensors\": [0, 1, 2],\n", - " \"boxes\": [0, 1],\n", - " \"labels\": [0],\n", - " \"scores\": [0],\n", - " },\n", - ")" + "export_onnx(model=model, onnx_path=onnx_path, opset_version=opset_version)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Simplify the exported ONNX model (Optional)\n", + "Check the exported ONNX model is well formed" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the ONNX model\n", + "onnx_model = onnx.load(onnx_path)\n", "\n", - "*ONNX* is great, but sometimes too complicated. And thanks to @daquexian for providing a powerful tool named [`onnxsim`](https://github.com/daquexian/onnx-simplifier/) to eliminate some redundant operators.\n", + "# Check that the model is well formed\n", + "onnx.checker.check_model(onnx_model)\n", "\n", - "First of all, let's install `onnx-simplifier` with following script." + "# Print a human readable representation of the graph\n", + "# print(onnx.helper.printable_graph(model.graph))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "```shell\n", - "pip install -U onnx-simplifier\n", - "```" + "## Inference on ONNX Runtime backend\n", + "\n", + "Load the exported ONNX model." ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Starting simplifing with onnxsim 0.3.6\n" + "Starting with onnx 1.10.2, onnxruntime 1.10.0...\n" ] } ], "source": [ - "import onnxsim\n", + "print(f'Starting with onnx {onnx.__version__}, onnxruntime {onnxruntime.__version__}...')\n", "\n", - "# onnx-simplifier version\n", - "print(f'Starting simplifing with onnxsim {onnxsim.__version__}')" + "ort_session = onnxruntime.InferenceSession(onnx_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Prepare the inputs for ONNX Runtime." ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ - "onnx_simp_name = 'yolov5s.simp.onnx' # path of the simplified ONNX models" + "inputs, _ = torch.jit._flatten(images)\n", + "outputs, _ = torch.jit._flatten(model_out)" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ - "# load your predefined ONNX model\n", - "onnx_model = onnx.load(export_onnx_name)\n", - "\n", - "# convert model\n", - "model_simp, check = onnxsim.simplify(\n", - " onnx_model,\n", - " input_shapes={\"images_tensors\": [3, 640, 640]},\n", - " dynamic_input_shape=True,\n", - ")\n", - "\n", - "assert check, \"Simplified ONNX model could not be validated\"\n", - "\n", - "# use model_simp as a standard ONNX model object\n", - "onnx.save(model_simp, onnx_simp_name)" + "inputs = list(map(to_numpy, inputs))\n", + "outputs = list(map(to_numpy, outputs))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Inference on ONNXRuntime Backend\n", - "\n", - "Now, We begin to verify whether the inference results are consistent with PyTorch's, similarly, install `onnxruntime` first." + "Compute onnxruntime output prediction." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))\n", + "ort_outs = ort_session.run(None, ort_inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "48.7 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit\n", + "ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))\n", + "ort_outs = ort_session.run(None, ort_inputs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "```shell\n", - "pip install -U onnxruntime\n", - "```" + "### Verify whether the inference results are consistent with PyTorch's" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Starting with onnx 1.9.0, onnxruntime 1.8.1...\n" + "Exported model has been tested with ONNXRuntime, and the result looks good!\n" ] } ], "source": [ - "print(f'Starting with onnx {onnx.__version__}, onnxruntime {onnxruntime.__version__}...')" + "for i in range(0, len(outputs)):\n", + " torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-04, atol=1e-07)\n", + "\n", + "print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "### Verify another image\n", + "\n", + "When using dynamic shape inference in trace mode, the shape inference mechanism for some operators may not work, so we verify it once for another image with a different shape as well." ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ - "images, _ = torch.jit._flatten(images)\n", - "outputs, _ = torch.jit._flatten(model_out)" + "images = [img_two]" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ - "def to_numpy(tensor):\n", - " if tensor.requires_grad:\n", - " return tensor.detach().cpu().numpy()\n", - " else:\n", - " return tensor.cpu().numpy()" + "with torch.no_grad():\n", + " out_pytorch = model(images)" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ - "inputs = list(map(to_numpy, images))\n", - "outputs = list(map(to_numpy, outputs))" + "inputs, _ = torch.jit._flatten(images)\n", + "outputs, _ = torch.jit._flatten(out_pytorch)" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ - "# ort_session = onnxruntime.InferenceSession(export_onnx_name)\n", - "ort_session = onnxruntime.InferenceSession(onnx_simp_name)" + "inputs = list(map(to_numpy, inputs))\n", + "outputs = list(map(to_numpy, outputs))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compute onnxruntime output prediction." ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ - "# compute onnxruntime output prediction\n", - "ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))\n", - "ort_outs = ort_session.run(None, ort_inputs)" + "input_ort = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))\n", + "out_ort = ort_session.run(None, input_ort)" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 2.38 s, sys: 20 ms, total: 2.4 s\n", - "Wall time: 65.1 ms\n" + "37.4 ms ± 775 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], "source": [ - "%%time\n", - "# compute onnxruntime output prediction\n", - "ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))\n", - "ort_outs = ort_session.run(None, ort_inputs)" + "%%timeit\n", + "input_ort = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))\n", + "out_ort = ort_session.run(None, input_ort)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Verify whether the inference results are consistent with PyTorch's." ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -479,7 +610,7 @@ ], "source": [ "for i in range(0, len(outputs)):\n", - " torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-04, atol=1e-07)\n", + " torch.testing.assert_allclose(outputs[i], out_ort[i], rtol=1e-04, atol=1e-07)\n", "\n", "print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")" ] @@ -495,7 +626,7 @@ "toc_visible": true }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -509,7 +640,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.8.12" } }, "nbformat": 4, diff --git a/test/test_runtime.py b/test/test_runtime.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/test_onnx.py b/test/test_runtime_ort.py similarity index 67% rename from test/test_onnx.py rename to test/test_runtime_ort.py index 6ff14a97..c7ec6cbf 100644 --- a/test/test_onnx.py +++ b/test/test_runtime_ort.py @@ -10,6 +10,7 @@ from torchvision.io import read_image from torchvision.ops._register_onnx_ops import _onnx_opset_version from yolort import models +from yolort.runtime.ort_helper import export_onnx from yolort.utils.image_utils import to_numpy # In environments without onnxruntime we prefer to @@ -18,15 +19,7 @@ class TestONNXExporter: - def run_model( - self, - model, - inputs_list, - do_constant_folding=True, - input_names=None, - output_names=None, - dynamic_axes=None, - ): + def run_model(self, model, inputs_list): """ The core part of exporting model to ONNX and inference with ONNX Runtime Copy-paste from @@ -34,21 +27,10 @@ def run_model( model = model.eval() onnx_io = io.BytesIO() - if isinstance(inputs_list[0][-1], dict): - torch_onnx_input = inputs_list[0] + ({},) - else: - torch_onnx_input = inputs_list[0] - # export to onnx with the first input - torch.onnx.export( - model, - torch_onnx_input, - onnx_io, - do_constant_folding=do_constant_folding, - opset_version=_onnx_opset_version, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - ) + + # export to onnx models + export_onnx(onnx_io, model=model, opset_version=_onnx_opset_version) + # validate the exported model with onnx runtime for test_inputs in inputs_list: with torch.no_grad(): @@ -97,7 +79,7 @@ def get_test_images(self): ("yolov5n6", False, "r6.0"), ], ) - def test_yolort_onnx_export(self, arch, fixed_size, upstream_version): + def test_onnx_export(self, arch, fixed_size, upstream_version): images_one, images_two = self.get_test_images() images_dummy = [torch.ones(3, 1080, 720) * 0.3] @@ -111,28 +93,7 @@ def test_yolort_onnx_export(self, arch, fixed_size, upstream_version): model = model.eval() model(images_one) # Test exported model on images of different size, or dummy input - self.run_model( - model, - [(images_one,), (images_two,), (images_dummy,)], - input_names=["images"], - output_names=["scores", "labels", "boxes"], - dynamic_axes={ - "images": [1, 2], - "boxes": [0, 1], - "labels": [0], - "scores": [0], - }, - ) + self.run_model(model, [(images_one,), (images_two,), (images_dummy,)]) + # Test exported model for an image with no detections on other images - self.run_model( - model, - [(images_dummy,), (images_one,)], - input_names=["images"], - output_names=["scores", "labels", "boxes"], - dynamic_axes={ - "images": [1, 2], - "boxes": [0, 1], - "labels": [0], - "scores": [0], - }, - ) + self.run_model(model, [(images_dummy,), (images_one,)]) diff --git a/tools/export_model.py b/tools/export_model.py index 363af516..e7e95d83 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -1,16 +1,7 @@ import argparse from pathlib import Path -import onnx -import torch -from torchvision.ops._register_onnx_ops import _onnx_opset_version as DEFAULT_OPSET - -try: - import onnxsim -except ImportError: - onnxsim = None - -from yolort.models import YOLO, YOLOv5 +from yolort.runtime.ort_helper import export_onnx def get_parser(): @@ -22,6 +13,12 @@ def get_parser(): required=True, help="The path of checkpoint weights", ) + parser.add_argument( + "--onnx_path", + type=str, + default=None, + help="The path of the exported ONNX models", + ) parser.add_argument( "--skip_preprocess", action="store_true", @@ -33,6 +30,12 @@ def get_parser(): type=float, help="Score threshold used for postprocessing the detections.", ) + parser.add_argument( + "--nms_thresh", + default=0.45, + type=float, + help="IOU threshold used for doing the NMS.", + ) parser.add_argument( "--version", type=str, @@ -40,88 +43,21 @@ def get_parser(): help="Upstream version released by the ultralytics/yolov5, Possible " "values are ['r3.1', 'r4.0', 'r6.0']. Default: 'r6.0'.", ) - parser.add_argument( - "--export_friendly", - action="store_true", - help="Replace torch.nn.SiLU with SiLU.", - ) parser.add_argument( "--image_size", nargs="+", type=int, default=[640, 640], - help="Image size for evaluation (default: 640, 640).", + help="Image size for inferencing (default: 640, 640).", ) - parser.add_argument("--batch_size", default=1, type=int, help="Batch size.") - parser.add_argument("--opset", default=DEFAULT_OPSET, type=int, help="opset_version") + parser.add_argument("--size_divisible", type=int, default=32, help="Stride for the preprocessing.") + parser.add_argument("--batch_size", default=1, type=int, help="Batch size for YOLOv5.") + parser.add_argument("--opset", default=11, type=int, help="Opset version for exporing ONNX models") parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model.") return parser -def export_onnx( - model, - inputs, - export_onnx_path, - dynamic_axes, - input_names=["images_tensors"], - output_names=["scores", "labels", "boxes"], - opset_version=11, - enable_simplify=False, -): - """ - Export the yolort models. - - Args: - model (nn.Module): The model to be exported. - inputs (Tuple[torch.Tensor]): The inputs to the model. - export_onnx_path (str): A string containing a file name. A binary Protobuf - will be written to this file. - dynamic_axes (dict): A dictionary of dynamic axes. - input_names (str): A names list of input names. - output_names (str): A names list of output names. - opset_version (int, default is 11): By default we export the model to the - opset version of the onnx submodule. - enable_simplify (bool, default is False): Whether to enable simplification - of the ONNX model. - """ - torch.onnx.export( - model, - inputs, - export_onnx_path, - do_constant_folding=True, - opset_version=opset_version, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - ) - - if enable_simplify: - input_shapes = {input_names[0]: list(inputs[0][0].shape)} - simplify_onnx(export_onnx_path, input_shapes) - - -def simplify_onnx(onnx_path, input_shapes): - if onnxsim is None: - raise ImportError("onnx-simplifier not found and is required by yolort.") - - print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...") - - # Load onnx mode - onnx_model = onnx.load(onnx_path) - - # Simplify the ONNX model - model_sim, check = onnxsim.simplify( - onnx_model, - input_shapes=input_shapes, - dynamic_input_shape=True, - ) - - assert check, "There is something error when simplifying ONNX model" - export_onnx_sim_path = onnx_path.with_suffix(".sim.onnx") - onnx.save(model_sim, export_onnx_sim_path) - - def cli_main(): parser = get_parser() args = parser.parse_args() @@ -129,58 +65,20 @@ def cli_main(): checkpoint_path = Path(args.checkpoint_path) assert checkpoint_path.exists(), f"Not found checkpoint file at '{checkpoint_path}'" - image_size = args.image_size - image_size *= 2 if len(args.image_size) == 1 else 1 # auto expand - - if args.skip_preprocess: - # input data - inputs = torch.rand(args.batch_size, 3, *image_size) - dynamic_axes = { - "images_tensors": {0: "batch", 2: "height", 3: "width"}, - "boxes": {0: "batch", 1: "num_objects"}, - "labels": {0: "batch", 1: "num_objects"}, - "scores": {0: "batch", 1: "num_objects"}, - } - input_names = ["images_tensors"] - output_names = ["scores", "labels", "boxes"] - model = YOLO.load_from_yolov5( - checkpoint_path, - score_thresh=args.score_thresh, - version=args.version, - ) - model.eval() - else: - # input data - images = [torch.rand(3, *image_size)] - inputs = (images,) - dynamic_axes = { - "images_tensors": {1: "height", 2: "width"}, - "boxes": {0: "num_objects"}, - "labels": {0: "num_objects"}, - "scores": {0: "num_objects"}, - } - input_names = ["images_tensors"] - output_names = ["scores", "labels", "boxes"] - model = YOLOv5.load_from_yolov5( - checkpoint_path, - size=tuple(image_size), - score_thresh=args.score_thresh, - version=args.version, - ) - model.eval() - - # export ONNX models - export_onnx_path = checkpoint_path.with_suffix(".onnx") + # Save the ONNX model path in the same directory of the checkpoint if not determined + onnx_path = args.onnx_path + onnx_path = onnx_path or checkpoint_path.with_suffix(".onnx") export_onnx( - model, - inputs, - export_onnx_path, - dynamic_axes, - input_names=input_names, - output_names=output_names, + onnx_path, + checkpoint_path=checkpoint_path, + size=tuple(args.image_size), + size_divisible=args.size_divisible, + score_thresh=args.score_thresh, + nms_thresh=args.nms_thresh, + version=args.version, + skip_preprocess=args.skip_preprocess, opset_version=args.opset, - enable_simplify=args.simplify, ) diff --git a/yolort/models/yolov5.py b/yolort/models/yolov5.py index 33f7ee38..c941d011 100644 --- a/yolort/models/yolov5.py +++ b/yolort/models/yolov5.py @@ -261,7 +261,6 @@ def load_from_yolov5( cls, checkpoint_path: str, *, - lr: float = 0.01, size: Tuple[int, int] = (640, 640), size_divisible: int = 32, fixed_shape: Optional[Tuple[int, int]] = None, @@ -273,7 +272,6 @@ def load_from_yolov5( Args: checkpoint_path (str): Path of the YOLOv5 checkpoint model. - lr (float): The initial learning rate size: (Tuple[int, int]): the minimum and maximum size of the image to be rescaled. Default: (640, 640) size_divisible (int): stride of the models. Default: 32 @@ -285,7 +283,6 @@ def load_from_yolov5( """ model = YOLO.load_from_yolov5(checkpoint_path, **kwargs) yolov5 = cls( - lr=lr, model=model, size=size, size_divisible=size_divisible, diff --git a/yolort/runtime/ort_helper.py b/yolort/runtime/ort_helper.py new file mode 100644 index 00000000..ff613f3a --- /dev/null +++ b/yolort/runtime/ort_helper.py @@ -0,0 +1,173 @@ +# Copyright (c) 2022, yolort team. All rights reserved. + +from typing import Optional, Tuple + +import torch +from torch import nn +from yolort.models import YOLO, YOLOv5 + + +def export_onnx( + onnx_path: str, + checkpoint_path: Optional[str] = None, + model: Optional[nn.Module] = None, + size: Tuple[int, int] = (640, 640), + size_divisible: int = 32, + score_thresh: float = 0.25, + nms_thresh: float = 0.45, + version: str = "r6.0", + skip_preprocess: bool = False, + opset_version: int = 11, +) -> None: + """ + Export to ONNX models that can be used for ONNX Runtime inferencing. + + Args: + onnx_path (string): The path to the ONNX graph to be exported. + checkpoint_path (string, optional): Path of the custom trained YOLOv5 checkpoint. + Default: None + model (nn.Module): The defined PyTorch module to be exported. Default: None + size: (Tuple[int, int]): the minimum and maximum size of the image to be rescaled. + Default: (640, 640) + size_divisible (int): Stride in the preprocessing. Default: 32 + score_thresh (float): Score threshold used for postprocessing the detections. + Default: 0.25 + nms_thresh (float): NMS threshold used for postprocessing the detections. Default: 0.45 + version (string): Upstream YOLOv5 version. Default: 'r6.0' + skip_preprocess (bool): Skip the preprocessing transformation when exporting the ONNX + models. Default: False + opset_version (int): Opset version for exporting ONNX models. Default: 11 + """ + + onnx_builder = ONNXBuilder( + checkpoint_path=checkpoint_path, + model=model, + size=size, + size_divisible=size_divisible, + score_thresh=score_thresh, + nms_thresh=nms_thresh, + version=version, + skip_preprocess=skip_preprocess, + opset_version=opset_version, + ) + + onnx_builder.to_onnx(onnx_path) + + +class ONNXBuilder: + """ + YOLOv5 wrapper for exporting ONNX models. + + Args: + checkpoint_path (string): Path of the custom trained YOLOv5 checkpoint. + model (nn.Module): The defined PyTorch module to be exported. Default: None + size: (Tuple[int, int]): the minimum and maximum size of the image to be rescaled. + Default: (640, 640) + size_divisible (int): Stride in the preprocessing. Default: 32 + score_thresh (float): Score threshold used for postprocessing the detections. + Default: 0.25 + nms_thresh (float): NMS threshold used for postprocessing the detections. Default: 0.45 + version (string): Upstream YOLOv5 version. Default: 'r6.0' + skip_preprocess (bool): Skip the preprocessing transformation when exporting the ONNX + models. Default: False + opset_version (int): Opset version for exporting ONNX models. Default: 11 + """ + + def __init__( + self, + checkpoint_path: Optional[str] = None, + model: Optional[nn.Module] = None, + size: Tuple[int, int] = (640, 640), + size_divisible: int = 32, + score_thresh: float = 0.25, + nms_thresh: float = 0.45, + version: str = "r6.0", + skip_preprocess: bool = False, + opset_version: int = 11, + ) -> None: + + super().__init__() + self._checkpoint_path = checkpoint_path + self._version = version + # For post-processing + self._score_thresh = score_thresh + self._nms_thresh = nms_thresh + self._skip_preprocess = skip_preprocess + # For pre-processing + self._size = size + self._size_divisible = size_divisible + # Define the module + if model is None: + model = self._build_model() + self.model = model + + self.opset_version = opset_version + self.input_names = ["images"] + self.output_names = ["scores", "labels", "boxes"] + self.input_sample = self._get_input_sample() + self.dynamic_axes = self._get_dynamic_axes() + + def _build_model(self): + if self._skip_preprocess: + model = YOLO.load_from_yolov5( + self._checkpoint_path, + score_thresh=self._score_thresh, + nms_thresh=self._nms_thresh, + version=self._version, + ) + else: + model = YOLOv5.load_from_yolov5( + self._checkpoint_path, + size=self._size, + size_divisible=self._size_divisible, + score_thresh=self._score_thresh, + nms_thresh=self._nms_thresh, + version=self._version, + ) + + model = model.eval() + return model + + def _get_dynamic_axes(self): + if self._skip_preprocess: + return { + "images": {0: "batch", 2: "height", 3: "width"}, + "boxes": {0: "batch", 1: "num_objects"}, + "labels": {0: "batch", 1: "num_objects"}, + "scores": {0: "batch", 1: "num_objects"}, + } + else: + return { + "images": {1: "height", 2: "width"}, + "boxes": {0: "num_objects"}, + "labels": {0: "num_objects"}, + "scores": {0: "num_objects"}, + } + + def _get_input_sample(self): + if self._skip_preprocess: + return torch.rand(1, 3, 640, 640) + else: + return [torch.rand(3, 640, 640)] + + @torch.no_grad() + def to_onnx(self, onnx_path: str, **kwargs): + """ + Saves the model in ONNX format. + + Args: + onnx_path (string): The path to the ONNX graph to be exported. + **kwargs: Will be passed to torch.onnx.export function. + """ + + torch.onnx.export( + self.model, + self.input_sample, + onnx_path, + do_constant_folding=True, + opset_version=self.opset_version, + input_names=self.input_names, + output_names=self.output_names, + dynamic_axes=self.dynamic_axes, + **kwargs, + )