diff --git a/detect.py b/detect.py index 14cdf96ca9db..e6e74ea7dfeb 100644 --- a/detect.py +++ b/detect.py @@ -38,6 +38,7 @@ @torch.no_grad() def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam + data=ROOT / 'data/coco128.yaml', # dataset.yaml path imgsz=(640, 640), # inference size (height, width) conf_thres=0.25, # confidence threshold iou_thres=0.45, # NMS IOU threshold @@ -76,7 +77,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s) # Load model device = select_device(device) - model = DetectMultiBackend(weights, device=device, dnn=dnn) + model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data) stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine imgsz = check_img_size(imgsz, s=stride) # check image size @@ -204,6 +205,7 @@ def parse_opt(): parser = argparse.ArgumentParser() parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)') parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam') + parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path') parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w') parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold') parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold') diff --git a/export.py b/export.py index 600e0c318f33..a0758010e816 100644 --- a/export.py +++ b/export.py @@ -248,6 +248,24 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te LOGGER.info(f'\n{prefix} export failure: {e}') +def export_edgetpu(keras_model, im, file, prefix=colorstr('Edge TPU:')): + # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/ + try: + cmd = 'edgetpu_compiler --version' + out = subprocess.run(cmd, shell=True, capture_output=True, check=True) + ver = out.stdout.decode().split()[-1] + LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...') + f = str(file).replace('.pt', '-int8_edgetpu.tflite') + f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model + + cmd = f"edgetpu_compiler -s {f_tfl}" + subprocess.run(cmd, shell=True, check=True) + + LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') + except Exception as e: + LOGGER.info(f'\n{prefix} export failure: {e}') + + def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): # YOLOv5 TensorFlow.js export try: @@ -285,6 +303,7 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')): def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): + # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt try: check_requirements(('tensorrt',)) import tensorrt as trt @@ -356,7 +375,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' ): t = time.time() include = [x.lower() for x in include] - tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports + tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # Checks @@ -405,15 +424,17 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' # TensorFlow Exports if any(tf_exports): - pb, tflite, tfjs = tf_exports[1:] + pb, tflite, edgetpu, tfjs = tf_exports[1:] assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' model = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs, agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class, topk_all=topk_all, conf_thres=conf_thres, iou_thres=iou_thres) # keras model if pb or tfjs: # pb prerequisite to tfjs export_pb(model, im, file) - if tflite: - export_tflite(model, im, file, int8=int8, data=data, ncalib=100) + if tflite or edgetpu: + export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100) + if edgetpu: + export_edgetpu(model, im, file) if tfjs: export_tfjs(model, im, file) diff --git a/models/common.py b/models/common.py index 4fd608f4b3e2..b53de7001454 100644 --- a/models/common.py +++ b/models/common.py @@ -17,6 +17,7 @@ import requests import torch import torch.nn as nn +import yaml from PIL import Image from torch.cuda import amp @@ -276,7 +277,7 @@ def forward(self, x): class DetectMultiBackend(nn.Module): # YOLOv5 MultiBackend class for python inference on various backends - def __init__(self, weights='yolov5s.pt', device=None, dnn=False): + def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None): # Usage: # PyTorch: weights = *.pt # TorchScript: *.torchscript @@ -284,6 +285,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False): # TensorFlow: *_saved_model # TensorFlow: *.pb # TensorFlow Lite: *.tflite + # TensorFlow Edge TPU: *_edgetpu.tflite # ONNX Runtime: *.onnx # OpenCV DNN: *.onnx with dnn=True # TensorRT: *.engine @@ -297,6 +299,9 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False): pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults w = attempt_download(w) # download if not local + if data: # data.yaml path (optional) + with open(data, errors='ignore') as f: + names = yaml.safe_load(f)['names'] # class names if jit: # TorchScript LOGGER.info(f'Loading {w} for TorchScript inference...') @@ -343,7 +348,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False): binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) context = model.create_execution_context() batch_size = bindings['images'].shape[0] - else: # TensorFlow model (TFLite, pb, saved_model) + else: # TensorFlow (TFLite, pb, saved_model) if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...') import tensorflow as tf @@ -425,6 +430,7 @@ def forward(self, im, augment=False, visualize=False, val=False): y[..., 1] *= h # y y[..., 2] *= w # w y[..., 3] *= h # h + y = torch.tensor(y) if isinstance(y, np.ndarray) else y return (y, []) if val else y diff --git a/val.py b/val.py index 4eec499d3029..c1fcf61b468c 100644 --- a/val.py +++ b/val.py @@ -124,7 +124,7 @@ def run(data, (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir # Load model - model = DetectMultiBackend(weights, device=device, dnn=dnn) + model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data) stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine imgsz = check_img_size(imgsz, s=stride) # check image size half &= (pt or jit or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA