diff --git a/export.py b/export.py index 9d6d04967c80..b2f42142e16c 100644 --- a/export.py +++ b/export.py @@ -276,7 +276,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F assert onnx.exists(), f'failed to export ONNX file: {onnx}' LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') - f = str(file).replace('.pt', '.engine') # TensorRT engine file + f = file.with_suffix('.engine') # TensorRT engine file logger = trt.Logger(trt.Logger.INFO) if verbose: logger.min_severity = trt.Logger.Severity.VERBOSE @@ -310,6 +310,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F except Exception as e: LOGGER.info(f'\n{prefix} export failure: {e}') + @torch.no_grad() def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' weights=ROOT / 'yolov5s.pt', # weights path diff --git a/models/common.py b/models/common.py index 72549809c8c3..cbd4ff479885 100644 --- a/models/common.py +++ b/models/common.py @@ -7,7 +7,7 @@ import math import platform import warnings -from collections import namedtuple +from collections import OrderedDict, namedtuple from copy import copy from pathlib import Path @@ -326,14 +326,14 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True): logger = trt.Logger(trt.Logger.INFO) with open(w, 'rb') as f, trt.Runtime(logger) as runtime: model = runtime.deserialize_cuda_engine(f.read()) - bindings = dict() + bindings = OrderedDict() for index in range(model.num_bindings): name = model.get_binding_name(index) dtype = trt.nptype(model.get_binding_dtype(index)) shape = tuple(model.get_binding_shape(index)) data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device) bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr())) - binding_addrs = {n: d.ptr for n, d in bindings.items()} + binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) context = model.create_execution_context() batch_size = bindings['images'].shape[0] else: # TensorFlow model (TFLite, pb, saved_model)