Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ONNX dynamic axes export support with onnx simplifier, make onnx simplifier optional #2856

Merged
merged 3 commits into from
Apr 20, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions models/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path') # from yolov5/models/
parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') # ONNX-only
parser.add_argument('--simplify', action='store_true', help='simplify ONNX model') # ONNX-only
opt = parser.parse_args()
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
print(opt)
Expand Down Expand Up @@ -58,7 +59,7 @@
model.model[-1].export = not opt.grid # set Detect() layer grid export
y = model(img) # dry run

# TorchScript export
# TorchScript export -----------------------------------------------------------------------------------------------
prefix = colorstr('TorchScript:')
try:
print(f'\n{prefix} starting export with torch {torch.__version__}...')
Expand All @@ -69,7 +70,7 @@
except Exception as e:
print(f'{prefix} export failure: {e}')

# ONNX export
# ONNX export ------------------------------------------------------------------------------------------------------
prefix = colorstr('ONNX:')
try:
import onnx
Expand All @@ -87,21 +88,24 @@
# print(onnx.helper.printable_graph(model_onnx.graph)) # print

# Simplify
try:
check_requirements(['onnx-simplifier'])
import onnxsim

print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
print(f'{prefix} simplifier failure: {e}')
if opt.simplify:
try:
check_requirements(['onnx-simplifier'])
import onnxsim

print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(model_onnx,
dynamic_input_shape=opt.dynamic,
input_shapes={'images': list(img.shape)} if opt.dynamic else None)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
print(f'{prefix} simplifier failure: {e}')
print(f'{prefix} export success, saved as {f}')
except Exception as e:
print(f'{prefix} export failure: {e}')

# CoreML export
# CoreML export ----------------------------------------------------------------------------------------------------
prefix = colorstr('CoreML:')
try:
import coremltools as ct
Expand Down