Refactor export.py #4080

Merged · 4 commits · Jul 20, 2021
export.py — 148 changes: 80 additions & 68 deletions
@@ -24,6 +24,78 @@
from utils.torch_utils import select_device


def export_torchscript(model, img, file, optimize):
    # TorchScript model export
    prefix = colorstr('TorchScript:')
    try:
        print(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript.pt')
        ts = torch.jit.trace(model, img, strict=False)
        (optimize_for_mobile(ts) if optimize else ts).save(f)
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return ts
    except Exception as e:
        print(f'{prefix} export failure: {e}')


def export_onnx(model, img, file, opset_version, train, dynamic, simplify):
    # ONNX model export
    prefix = colorstr('ONNX:')
    try:
        check_requirements(('onnx', 'onnx-simplifier'))
        import onnx

        print(f'{prefix} starting export with onnx {onnx.__version__}...')
        f = file.with_suffix('.onnx')
        torch.onnx.export(model, img, f, verbose=False, opset_version=opset_version,
                          training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                          do_constant_folding=not train,
                          input_names=['images'],
                          output_names=['output'],
                          dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                                        'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                                        } if dynamic else None)

        # Checks
        model_onnx = onnx.load(f)  # load onnx model
        onnx.checker.check_model(model_onnx)  # check onnx model
        # print(onnx.helper.printable_graph(model_onnx.graph))  # print

        # Simplify
        if simplify:
            try:
                import onnxsim

                print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
                model_onnx, check = onnxsim.simplify(
                    model_onnx,
                    dynamic_input_shape=dynamic,
                    input_shapes={'images': list(img.shape)} if dynamic else None)
                assert check, 'assert check failed'
                onnx.save(model_onnx, f)
            except Exception as e:
                print(f'{prefix} simplifier failure: {e}')
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        print(f'{prefix} export failure: {e}')


def export_coreml(ts_model, img, file, train):
    # CoreML model export
    prefix = colorstr('CoreML:')
    try:
        import coremltools as ct

        print(f'{prefix} starting export with coremltools {ct.__version__}...')
        f = file.with_suffix('.mlmodel')
        assert train, 'CoreML exports should be placed in model.train() mode with `python export.py --train`'
        model = ct.convert(ts_model, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
        model.save(f)
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        print(f'{prefix} export failure: {e}')


def run(weights='./yolov5s.pt',  # weights path
        img_size=(640, 640),  # image (height, width)
        batch_size=1,  # batch size
@@ -40,12 +112,13 @@ def run(weights='./yolov5s.pt',  # weights path
    t = time.time()
    include = [x.lower() for x in include]
    img_size *= 2 if len(img_size) == 1 else 1  # expand
    file = Path(weights)

    # Load PyTorch model
    device = select_device(device)
    assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
    model = attempt_load(weights, map_location=device)  # load FP32 model
    labels = model.names
    names = model.names

    # Input
    gs = int(max(model.stride))  # grid size (max stride)
@@ -57,7 +130,6 @@ def run(weights='./yolov5s.pt',  # weights path
    img, model = img.half(), model.half()  # to FP16
    model.train() if train else model.eval()  # training mode = no Detect() layer grid construction
    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
        if isinstance(m, Conv):  # assign export-friendly activations
            if isinstance(m.act, nn.Hardswish):
                m.act = Hardswish()
@@ -72,73 +144,13 @@ def run(weights='./yolov5s.pt',  # weights path
    y = model(img)  # dry runs
    print(f"\n{colorstr('PyTorch:')} starting from {weights} ({file_size(weights):.1f} MB)")

    # TorchScript export -----------------------------------------------------------------------------------------------
    if 'torchscript' in include or 'coreml' in include:
        prefix = colorstr('TorchScript:')
        try:
            print(f'\n{prefix} starting export with torch {torch.__version__}...')
            f = weights.replace('.pt', '.torchscript.pt')  # filename
            ts = torch.jit.trace(model, img, strict=False)
            (optimize_for_mobile(ts) if optimize else ts).save(f)
            print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        except Exception as e:
            print(f'{prefix} export failure: {e}')

    # ONNX export ------------------------------------------------------------------------------------------------------
    # Exports
    if 'onnx' in include:
        prefix = colorstr('ONNX:')
        try:
            import onnx

            print(f'{prefix} starting export with onnx {onnx.__version__}...')
            f = weights.replace('.pt', '.onnx')  # filename
            torch.onnx.export(model, img, f, verbose=False, opset_version=opset_version,
                              training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                              do_constant_folding=not train,
                              input_names=['images'],
                              output_names=['output'],
                              dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                                            'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                                            } if dynamic else None)

            # Checks
            model_onnx = onnx.load(f)  # load onnx model
            onnx.checker.check_model(model_onnx)  # check onnx model
            # print(onnx.helper.printable_graph(model_onnx.graph))  # print

            # Simplify
            if simplify:
                try:
                    check_requirements(['onnx-simplifier'])
                    import onnxsim

                    print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
                    model_onnx, check = onnxsim.simplify(
                        model_onnx,
                        dynamic_input_shape=dynamic,
                        input_shapes={'images': list(img.shape)} if dynamic else None)
                    assert check, 'assert check failed'
                    onnx.save(model_onnx, f)
                except Exception as e:
                    print(f'{prefix} simplifier failure: {e}')
            print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        except Exception as e:
            print(f'{prefix} export failure: {e}')

    # CoreML export ----------------------------------------------------------------------------------------------------
    if 'coreml' in include:
        prefix = colorstr('CoreML:')
        try:
            import coremltools as ct

            print(f'{prefix} starting export with coremltools {ct.__version__}...')
            assert train, 'CoreML exports should be placed in model.train() mode with `python export.py --train`'
            model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
            f = weights.replace('.pt', '.mlmodel')  # filename
            model.save(f)
            print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        except Exception as e:
            print(f'{prefix} export failure: {e}')
        export_onnx(model, img, file, opset_version, train, dynamic, simplify)
    if 'torchscript' in include or 'coreml' in include:
        ts = export_torchscript(model, img, file, optimize)
        if 'coreml' in include:
            export_coreml(ts, img, file, train)

    # Finish
    print(f'\nExport complete ({time.time() - t:.2f}s). Visualize with https://github.com/lutzroeder/netron.')
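
For reference, a minimal usage sketch of the three new helpers follows. It is not part of this diff: the checkpoint path, image size, and opset are illustrative, and it assumes the YOLOv5 repo is on PYTHONPATH and that export.py keeps its CLI behind an `if __name__ == '__main__':` guard so the helpers can be imported.

# Usage sketch (not part of this PR): exercises the refactored helpers directly.
from pathlib import Path

import torch

from export import export_coreml, export_onnx, export_torchscript  # assumed importable
from models.experimental import attempt_load
from utils.torch_utils import select_device

device = select_device('cpu')
model = attempt_load('yolov5s.pt', map_location=device)  # load FP32 model
model.eval()
img = torch.zeros(1, 3, 640, 640, device=device)  # dummy BCHW input for tracing
file = Path('yolov5s.pt')

ts = export_torchscript(model, img, file, optimize=False)  # -> yolov5s.torchscript.pt
export_onnx(model, img, file, opset_version=12, train=False,
            dynamic=False, simplify=False)  # -> yolov5s.onnx
# export_coreml(ts, img, file, train=True)  # CoreML consumes the traced module and asserts train mode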
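The equivalent command-line invocation would look roughly like the line below; flag names are assumed to mirror run()'s keyword arguments, and only --train and --device are confirmed by messages in this diff.

python export.py --weights yolov5s.pt --include torchscript onnx coreml --train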