[YOLOX-TI] ERROR: onnx_op_name: /head/ScatterND #269
Comments
Knowing that TI's model is rather verbose, I optimized it independently and created a script to replace all of them: https://github.com/PINTO0309/PINTO_model_zoo/tree/main/363_YOLO-6D-Pose |
Thank you for your quick response |
I will be home with my parents today, tomorrow, and the day after, so I will not be able to provide detailed testing or assistance. |
Thanks for the heads up! Testing this on my own on a detection model, not on pose. Let's see if I manage to get it working. The eval result on both models is as follows:
|
Ok. As I didn't see ScatterND in the original model, I checked what the differences were. I found out that this:

```python
def meshgrid(*tensors):
    if _TORCH_VER >= [1, 10]:
        return torch.meshgrid(*tensors, indexing="ij")
    else:
        return torch.meshgrid(*tensors)

def decode_outputs(self, outputs, dtype):
    grids = []
    strides = []
    for (hsize, wsize), stride in zip(self.hw, self.strides):
        yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
        grid = torch.stack((xv, yv), 2).view(1, -1, 2)
        grids.append(grid)
        shape = grid.shape[:2]
        strides.append(torch.full((*shape, 1), stride))
    grids = torch.cat(grids, dim=1).type(dtype)
    strides = torch.cat(strides, dim=1).type(dtype)
    outputs = torch.cat([
        (outputs[..., 0:2] + grids) * strides,
        torch.exp(outputs[..., 2:4]) * strides,
        outputs[..., 4:]
    ], dim=-1)
    return outputs
```

gives a graph without ScatterND, while this:

```python
def decode_outputs(self, outputs, dtype):
    grids = []
    strides = []
    for (hsize, wsize), stride in zip(self.hw, self.strides):
        yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
        grid = torch.stack((xv, yv), 2).view(1, -1, 2)
        grids.append(grid)
        shape = grid.shape[:2]
        strides.append(torch.full((*shape, 1), stride))
    grids = torch.cat(grids, dim=1).type(dtype)
    strides = torch.cat(strides, dim=1).type(dtype)
    outputs[..., :2] = (outputs[..., :2] + grids) * strides
    outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
    return outputs
```

gives the ScatterND ops (the in-place slice assignments at the end are exported as ScatterND). This, as well as some other minor fixes, makes it possible to get rid of ScatterND completely. |
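For reference, a quick sketch (the filename is a placeholder) to confirm that a re-exported ONNX graph no longer contains any ScatterND nodes:

```python
import onnx

# Load the re-exported model and list any remaining ScatterND nodes.
model = onnx.load("yolox_nano_ti_lite_modified.onnx")  # placeholder filename
scatternd_nodes = [n.name for n in model.graph.node if n.op_type == "ScatterND"]
print(scatternd_nodes if scatternd_nodes else "No ScatterND nodes left")
```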
Excellent. The overall size of the model should also be significantly smaller now. 64-bit index values are almost always overly precise. However, since the computational efficiency of |
The model performance did not decrease after the changes, and for the first time I got results on one of the quantized models ( |
But still nothing for the |
Feel free to play around with it 😄 |
I can't see the structure of the model today, but I believe there were a couple of those. What if the model transformation is stopped just before post-processing? However, it is then difficult to measure mAP. e.g.:
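Something like the following — a sketch using onnx2tf's -onimc option, which interrupts the conversion at the named output tensor (the filename and tensor name here are the ones tried later in this thread):

```
onnx2tf -i yolox_nano.onnx -b 1 -onimc /head/Concat_6_output_0
```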
It's an interesting topic and I'd like to try it myself, but I can't easily try it right now. |
You are right @PINTO0309. I missed this:

```python
output = torch.cat(
    [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
)
```

which in the ONNX model is represented as the final Concat; then in the TFLite models these … But why is the dynamic range quantized model working and not the rest of the quantized models? |
If I remember correctly, dynamic range quantization is less prone to accuracy degradation because it recalculates the quantization range of activations each time at runtime; compared to full INT8 quantization, its inference would be very slow in exchange for maintaining accuracy. I may be wrong, because I do not have an accurate grasp of recent quantization specifications. By the way, |
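For context, a minimal sketch of how the two quantization variants are produced with the TFLiteConverter API (the saved_model path, input shape, and calibration data are placeholders; onnx2tf generates these variants itself, so this is only to illustrate the difference):

```python
import numpy as np
import tensorflow as tf

# Dynamic range quantization: only weights are stored as INT8;
# activation ranges are computed on the fly at inference time.
converter = tf.lite.TFLiteConverter.from_saved_model("saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_dynamic_range = converter.convert()

# Full INT8 quantization: activations are quantized too, with fixed ranges
# calibrated from a representative dataset.
def representative_dataset():
    for _ in range(100):
        # Replace with real, preprocessed images.
        yield [np.random.rand(1, 416, 416, 3).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_saved_model("saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_full_int8 = converter.convert()
```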
Maybe a bit off topic. Anyway, I am using the official TFLite benchmark tool for the exported models, and on the specific Android device I am running this on, the Float32 model is much faster than the dynamic-range quantized one. |
People are running into the same quantization problems with YOLOv8 (ultralytics/ultralytics#1447): |
But then I guess that the only option we have is to perform the |
@mikel-brostrom Just before the last Concat, xywh seems to have a distribution of (min, max) ~ (0.0, 416.0). On the other hand, the scores have a much narrower distribution of (min, max) = (0.0, 1.0) because of the sigmoid. In TFLite quantization, activations are quantized in a per-tensor manner. That is, the combined distribution of xywh and scores, (min, max) = (0.0, 416.0), is mapped to integer values of (min, max) = (0, 255) after the Concat. As a result, even if the score is 1.0, after quantization it is mapped to int(1.0 / 416 * 255) = int(0.61) = 0, resulting in all scores being zero! A possible solution is to divide the xywh tensors by the image size (416) to keep them in the range (min, max) ~ (0.0, 1.0) and then concat with the score tensor, so that the scores are not "collapsed" by the per-tensor quantization. The same workaround is done in YOLOv5: |
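To make the arithmetic above concrete, a small sketch (following the (0, 255) mapping described in the comment above) of what happens to a score after the Concat, and after the proposed normalization of xywh:

```python
import numpy as np

# Per-tensor quantization of the concatenated output: the combined range of
# xywh (0..416) and scores (0..1) is (0.0, 416.0), mapped onto (0, 255).
scale = 416.0 / 255.0  # ~1.63 real units per integer step

def quantize(x):
    # q = round(real_value / scale), clipped to the 8-bit range (0, 255)
    return np.clip(np.round(np.asarray(x) / scale), 0, 255)

print(quantize(1.0))    # -> 1.0: a perfect score barely survives as a single step
print(quantize(0.5))    # -> 0.0: any score below ~0.82 collapses to zero
print(quantize(0.82))   # -> 1.0

# After dividing xywh by the image size (416), the combined range is ~(0, 1):
scale_fixed = 1.0 / 255.0
print(np.round(0.5 / scale_fixed))  # -> 128.0: scores keep their resolution
```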
This was super helpful @motokimura! Will try this out |
I hope this helps.. |
Get it! |
No change on the INT8 models @motokimura after implementing what you suggested... Still the same results for all the TFLite models, so the problem may primarily be in an operation or set of operations |
hmm.. It may be helpful to export the ONNX without the '--export-det' option and compare the int8 and float outputs. |
First, let me tell you that your results will vary greatly depending on the architecture of the CPU you are using for your verification. If you are using an Intel x64 (x86) or AMD x64 (x86) architecture CPU, the Float32 model should run inference about 10 times faster than the INT8 model; INT8 models are very slow on the x64 architecture. Perhaps on a Raspberry Pi's ARM64 CPU with 4 threads it would be about 10 times faster. The keyword XNNPACK is a good way to search for information. In the case of Intel's x64 architecture, CPUs of the 10th generation or later differ from CPUs of the 9th generation or earlier in whether they have an optimization mechanism for integer processing; if you are using a 10th-generation or later CPU, it should run about 20% faster. Therefore, when benchmarking with the benchmark tools, it is recommended to do so on ARM64 devices. The benchmarking in the discussion on the ultralytics thread is not appropriate. Next, let's look at dynamic range quantization.
Next, we discuss post-quantization accuracy degradation. I think Motoki's point is mostly correct. I think you should first try to split the model at the red line and see how the accuracy changes. If the |
I just cut the model at the point you suggested by:

```
onnx2tf -i /datadrive/mikel/yolox_tflite_export/yolox_nano.onnx -b 1 -cotof -cotoa 1e-1 -onimc /head/Concat_6_output_0
```

But I get the following error:

```
File "/datadrive/mikel/yolox_tflite_export/env/lib/python3.8/site-packages/onnx2tf/utils/common_functions.py", line 3071, in onnx_tf_tensor_validation
    onnx_tensor_shape = onnx_tensor.shape
AttributeError: 'NoneType' object has no attribute 'shape'
```

I couldn't find a similar issue and I had the same problem when I tried to cut YOLOX in our previous discussion. I probably misinterpreted how the tool is supposed to be used... |
I compiled the benchmark binary for android_arm64. The device has an Exynos 9810, which is an ARM 64-bit CPU, and a Mali-G72 MP18 GPU. However, I am running the model without GPU accelerators, so the INT8 model must be running on the CPU. The CPU was released in 2018, so that may explain why the quantized model is that slow... |
I came home and tried the same conversion as you. The following command did not generate an error. It is a little strange that the situation differs between your environment and mine. Since |
|
In/out quantization parameters, from top-left to bottom-right of the operations you pointed at:

```
quantization: -3.1056954860687256 ≤ 0.00014265520439948887 * q ≤ 4.674383163452148
quantization: -3.1056954860687256 ≤ 0.00014265520439948887 * q ≤ 4.674383163452148
quantization: -2.3114538192749023 ≤ 0.00010453650611452758 * q ≤ 3.4253478050231934
quantization: 0.00014265520439948887 * q
quantization: -2.2470905780792236 ≤ 0.00011867172725033015 * q ≤ 3.888516426086426
quantization: 0.00014265520439948887 * q
quantization: 0.00014265520439948887 * q
quantization: -3.1056954860687256 ≤ 0.00014265520439948887 * q ≤ 4.674383163452148
```
|
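For anyone reproducing this: per-tensor parameters like the ones above can be read out with the TFLite interpreter — a minimal sketch, with the model filename as a placeholder:

```python
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="yolox_nano_integer_quant.tflite")  # placeholder
interpreter.allocate_tensors()

for detail in interpreter.get_tensor_details():
    scales = detail["quantization_parameters"]["scales"]
    zero_points = detail["quantization_parameters"]["zero_points"]
    if scales.size:
        # Dequantization follows real_value = scale * (q - zero_point).
        print(detail["name"], scales, zero_points)
```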
It looks fine to me. |
Going for a full COCO eval now 🚀 |
Great! 🚀🚀 |
Great that we can get this into YOLOv8 as well, @motokimura! Thank you both for this joint effort ❤️
|
congratulations! 👍 |
I will close this issue, as the original problem has been solved and the INT8 quantization problem seems to have been resolved. |
Sorry for bothering you again, but one thing is still unclear to me. Even when bringing everything into the same value range, the results are much worse than when using separate outputs. From our lengthy discussion I recall this:
and this:
Which makes total sense to me, especially given the disparity of the value ranges within the same output. But why are the quantization results much worse for the model with a single concatenated output, given that the values now all share the same range? Does this make sense to you?
|
There is no part of the model that I can explain in more detail than Motoki's explanation, but again, take a good look at the quantization parameters around the final output of the model. I think you can see why all the values diverge when inverse quantization (
Perhaps that is why TI used |
In your inference code posted in this comment,
The first dim of … However, this should decrease the accuracy of the float models as well.. |
Yup, sorry @motokimura, that's a typo. It should be:

```python
outputs[:, :, 0:4] = outputs[:, :, 0:4] * 416
```
|
I have no idea what is happening in the Concat.. As I posted, you may find something if you compare the distributions of the outputs from the float/int8 models. |
@mikel-brostrom
Assumption: xy and/or wh may have a few outliers which make the quantization range much wider than we expect. In particular, wh can have such outliers because Exp is used as its activation function. |
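One way to check this assumption is to run the float ONNX model over a few images and look at the tails of the wh distribution — a sketch, where the model path, the input preprocessing, and the (1, N, 85) output layout are assumptions:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("yolox_nano_ti_lite.onnx")  # placeholder path
input_name = sess.get_inputs()[0].name

# Replace with real, preprocessed images.
images = np.random.rand(16, 1, 3, 416, 416).astype(np.float32)

wh = []
for img in images:
    outputs = sess.run(None, {input_name: img})[0]  # assumed shape: (1, N, 85)
    wh.append(outputs[..., 2:4].ravel())            # decoded box widths/heights

wh = np.concatenate(wh)
# If the max is far above the 99.9th percentile, a handful of Exp outliers are
# stretching the range that quantization calibration picks up.
print("max:", wh.max(), "p99.9:", np.percentile(wh, 99.9), "p99:", np.percentile(wh, 99))
```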
Good point @motokimura. Reporting back on Monday 😊 |
Interesting. It actually made it worse...
|
At this point I have no idea beyond this comment about the quantization of Concat and what kind of quantization errors are actually happening inside it. This Concat is not necessary by nature and has no benefit for model quantization, so I think we don't need to go any deeper into this. All I can say at this point is that tensors with very different value ranges should not be concatenated, especially in the post-processing of a model. Thank you for doing the experiment and sharing your results! |
Agree, let's close this. Enough experimentation on this topic 😄. Again, thank you both, @motokimura and @PINTO0309, for your time and guidance during this quantization journey. I learnt a lot; hopefully you got something out of the experiment results posted here as well 🙏 |
Issue Type
Others
onnx2tf version number
1.8.1
onnx version number
1.13.1
tensorflow version number
2.12.0
Download URL for ONNX
yolox_nano_ti_lite_26p1_41p8.zip
Parameter Replacement JSON
Description
Hi @PINTO0309. After our lengthy discussion regarding INT8 YOLOX export, I decided to try out TI's version of these models (https://github.com/TexasInstruments/edgeai-yolox/tree/main/pretrained_models). It looked to me like you managed to INT8-export those, so maybe you could provide some hints 😄. I just downloaded the ONNX version of YOLOX-nano. For this model, the following fails:
The error I get: