Skip to content

Commit

Permalink
TensorRT 7 anchor_grid compatibility fix (ultralytics#6185)
Browse files Browse the repository at this point in the history
* fix: TensorRT 7 incompatiable

* Add comment

* Add if: else and comment

Co-authored-by: Glenn Jocher <[email protected]>
  • Loading branch information
imyhxy and glenn-jocher authored Jan 4, 2022
1 parent 42d1a94 commit 132fc5d
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,13 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
import tensorrt as trt

opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x
export_onnx(model, im, file, opset, train, False, simplify)
if opset == 12: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
grid = model.model[-1].anchor_grid
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
export_onnx(model, im, file, opset, train, False, simplify)
model.model[-1].anchor_grid = grid
else: # TensorRT >= 8
export_onnx(model, im, file, opset, train, False, simplify)
onnx = file.with_suffix('.onnx')
assert onnx.exists(), f'failed to export ONNX file: {onnx}'

Expand Down Expand Up @@ -418,12 +424,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
# Exports
if 'torchscript' in include:
export_torchscript(model, im, file, optimize)
if 'engine' in include: # TensorRT required before ONNX
export_engine(model, im, file, train, half, simplify, workspace, verbose)
if ('onnx' in include) or ('openvino' in include): # OpenVINO requires ONNX
export_onnx(model, im, file, opset, train, dynamic, simplify)
if 'openvino' in include:
export_openvino(model, im, file)
if 'engine' in include:
export_engine(model, im, file, train, half, simplify, workspace, verbose)
if 'coreml' in include:
export_coreml(model, im, file)

Expand Down

0 comments on commit 132fc5d

Please sign in to comment.