From ac409f29467fd772d1de7e6983d19de762852c4a Mon Sep 17 00:00:00 2001 From: DavidBaldsiefen Date: Sun, 27 Feb 2022 12:29:39 +0700 Subject: [PATCH] Assert engine precision #6777 --- detect.py | 2 ++ models/common.py | 3 +++ val.py | 1 + 3 files changed, 6 insertions(+) diff --git a/detect.py b/detect.py index 76f67bea1b90..d913921060d5 100644 --- a/detect.py +++ b/detect.py @@ -95,6 +95,8 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) # Half half &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16 supported on limited backends with CUDA + if engine: + assert (model.trt_fp16_input == half), 'model ' + ('requires' if model.trt_fp16_input else 'incompatible with') + ' --half' if pt or jit: model.model.half() if half else model.model.float() diff --git a/models/common.py b/models/common.py index 0dae0244e932..9e265144435f 100644 --- a/models/common.py +++ b/models/common.py @@ -296,6 +296,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None): w = str(weights[0] if isinstance(weights, list) else weights) pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w) # get backend stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults + trt_fp16_input = False w = attempt_download(w) # download if not local if data: # data.yaml path (optional) with open(data, errors='ignore') as f: @@ -348,6 +349,8 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None): shape = tuple(model.get_binding_shape(index)) data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device) bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr())) + if model.binding_is_input(index) and dtype == np.float16: + trt_fp16_input = dtype == np.float16 binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) context = model.create_execution_context() batch_size = bindings['images'].shape[0] diff --git a/val.py b/val.py index 78abbda8231a..68e5d9ff0159 100644 --- a/val.py +++ b/val.py @@ -143,6 +143,7 @@ def run(data, if pt or jit: model.model.half() if half else model.model.float() elif engine: + assert (model.trt_fp16_input == half), 'model ' + ('requires' if model.trt_fp16_input else 'incompatible with') + ' --half' batch_size = model.batch_size else: half = False