[Fix] dynamic exportable pixel unshuffle #1637
Conversation
Please fix the CI error.
Codecov Report
Base: 83.72% // Head: 83.89% // Increases project coverage by +0.16%.
Additional details and impacted files
@@ Coverage Diff @@
## master #1637 +/- ##
==========================================
+ Coverage 83.72% 83.89% +0.16%
==========================================
Files 227 227
Lines 13027 13027
Branches 2027 2027
==========================================
+ Hits 10907 10929 +22
+ Misses 1760 1742 -18
+ Partials 360 356 -4
Flags with carried forward coverage won't be shown.
☔ View full report at Codecov.
import torch
import onnxruntime
def pixel_unshuffle(x, scale):
    """Down-sample by pixel unshuffle.

    Args:
        x (Tensor): Input tensor.
        scale (int): Scale factor.

    Returns:
        Tensor: Output tensor.
    """
    b, c, h, w = x.shape
    h = h // scale
    w = w // scale
    x = x.view(b, c, h, scale, w, scale)
    x = x.permute(0, 1, 3, 5, 2, 4)
    return x.reshape(b, -1, h, w)
class PixelUnshuffleModel(torch.nn.Module):

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        y = pixel_unshuffle(x, self.scale)
        return y


def create_pixel_unshuffle_model(scale):
    model = PixelUnshuffleModel(scale)
    dummy_input = torch.randn(1, 3, 256, 256)
    inputs = (dummy_input,)
    output = model(*inputs)
    input_names = ['input']
    output_names = ['output']
    dynamic_axes = {
        'input': {0: 'batch_size', 2: 'height', 3: 'width'},
        'output': {0: 'batch_size', 2: 'height', 3: 'width'}
    }
    return model, inputs, output, input_names, output_names, dynamic_axes

model, inputs, output, input_names, output_names, dynamic_axes = create_pixel_unshuffle_model(scale=2)
torch.onnx.export(model, inputs, 'pixel_unshuffle.onnx', input_names=input_names,
                  output_names=output_names, dynamic_axes=dynamic_axes, opset_version=11)
sess = onnxruntime.InferenceSession('pixel_unshuffle.onnx')
x = torch.randn(1, 3, 128, 128)
input_name = sess.get_inputs()[0].name
ort_output = sess.run(None, {input_name: x.numpy()})
torch_output = model(x)
print(torch.allclose(torch_output, torch.tensor(ort_output[0]), rtol=1e-03, atol=1e-05))

Thanks for catching my mistake. The problem is that the lowest PyTorch version you support (1.6) does not support a rounding mode for division, so I changed it to a simple floor division to track the tensor's shape (the assert statement guarantees that it is divisible by scale). The script above applies the ONNX export with dynamic input to the modified model, and it works. Thank you.
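As one more sanity check, here is a minimal sketch (not part of the PR) that reuses sess, input_name, and model from the script above; the 1x3x96x160 shape is an arbitrary choice to exercise a non-square size different from the export dummy input:

x2 = torch.randn(1, 3, 96, 160)
# The exported graph declares height/width as dynamic axes, so a size that
# differs from the 256x256 dummy input used at export time should still work.
ort_output2 = sess.run(None, {input_name: x2.numpy()})
torch_output2 = model(x2)
print(torch.allclose(torch_output2, torch.tensor(ort_output2[0]), rtol=1e-03, atol=1e-05))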
Thanks for your contribution; we appreciate it a lot. The following instructions will help make your pull request healthier and easier to review. If you do not understand some items, don't worry: just open the pull request and ask the maintainers for help.
Motivation
These changes support pixel unshuffle for inputs of varying sizes during ONNX export. Exporting the current implementation produces a tracer warning about converting tensor values to built-in Python numbers; more detail can be found here:
https://pytorch.org/docs/stable/onnx.html#avoid-numpy-and-built-in-python-types
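For illustration, a minimal sketch of the pattern the linked page warns about (StaticUnshuffle is a hypothetical module, not code from this repository): converting a traced shape to a built-in Python number bakes it into the exported graph as a constant, so the dynamic height/width axes stop having any effect.

import torch

class StaticUnshuffle(torch.nn.Module):
    """Hypothetical anti-pattern: shape values become Python numbers during tracing."""

    def forward(self, x):
        b, c, h, w = x.shape
        # int(...) converts the traced shape to a constant Python number, so the
        # exported graph is only correct for the dummy input's height and width.
        h, w = int(h) // 2, int(w) // 2
        x = x.view(b, c, h, 2, w, 2)
        x = x.permute(0, 1, 3, 5, 2, 4)
        return x.reshape(b, -1, h, w)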
Modification
Compute the downsampled height and width in pixel unshuffle with plain floor division on the traced shape, so that the operation can be exported to ONNX with dynamic input sizes. torch.div with rounding_mode is avoided because the minimum supported PyTorch version (1.6) does not provide it; the assert statement already guarantees that the height and width are divisible by the scale. A sketch of the shape arithmetic follows below.
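A minimal sketch of that shape arithmetic (downsampled_hw is a hypothetical helper, not the actual diff in this PR); torch.div's rounding_mode argument only exists from PyTorch 1.8 onward, so plain floor division is used instead:

import torch

def downsampled_hw(x: torch.Tensor, scale: int):
    """Hypothetical helper illustrating the shape computation discussed above."""
    b, c, h, w = x.shape
    assert h % scale == 0 and w % scale == 0, 'height/width must be divisible by scale'
    # torch.div(h, scale, rounding_mode='floor') would behave the same, but the
    # rounding_mode argument requires PyTorch >= 1.8, while the minimum supported
    # version here is 1.6, hence the plain floor division.
    return h // scale, w // scale

print(downsampled_hw(torch.randn(1, 3, 256, 256), scale=2))  # (128, 128)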
Who can help? @ them here!
BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
Checklist
Before PR:
After PR: