From e5c11ff632cfb3bc663d5f9a08d14078d2d80f21 Mon Sep 17 00:00:00 2001 From: Jiacong Fang Date: Fri, 18 Feb 2022 23:36:45 +0800 Subject: [PATCH] Fix TF exports >= 2GB (#6292) * Fix exporting saved_model: pb exceeds 2GB * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Replace TF v1.x API with TF v2.x API for saved_model export * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Clean up * Remove lambda in tf.function() * Revert "Remove lambda in tf.function()" to be compatible with TF v2.4 This reverts commit 46c7931f11dfdea6ae340c77287c35c30b9e0779. * Fix for pre-commit.ci * Cleanup1 * Cleanup2 * Backwards compatibility update * Update common.py * Update common.py * Cleanup3 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- export.py | 98 ++++++++++++++++++++++-------------------------- models/common.py | 5 ++- 2 files changed, 48 insertions(+), 55 deletions(-) diff --git a/export.py b/export.py index 32c8622c08df..6a8c4f6f94a0 100644 --- a/export.py +++ b/export.py @@ -16,6 +16,10 @@ TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite TensorFlow.js | `tfjs` | yolov5s_web_model/ +Requirements: + $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU + $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU + Usage: $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx openvino engine coreml tflite ... @@ -45,6 +49,7 @@ import subprocess import sys import time +import warnings from pathlib import Path import pandas as pd @@ -239,41 +244,14 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F except Exception as e: LOGGER.info(f'\n{prefix} export failure: {e}') -def export_keras(model, im, file, dynamic, prefix=colorstr('Keras:')): - # YOLOv5 TensorFlow SavedModel export - try: - import tensorflow as tf - from tensorflow import keras - - from models.keras import TFDetect, KerasModel - - LOGGER.info(f'\n{prefix} starting export with keras {tf.__version__}...') - f = str(file).replace('.pt', '.h5') - batch_size, ch, *imgsz = list(im.shape) # BCHW - - model = KerasModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) - im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for Keras - _ = model.predict(im) # first call to create weights - inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) - outputs = model.predict(inputs) - keras_model = keras.Model(inputs=inputs, outputs=outputs, name="yolov5n") - keras_model.trainable = False - keras_model.summary() - keras_model.save(f, save_format='h5') - - LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') - return keras_model, f - except Exception as e: - LOGGER.info(f'\n{prefix} export failure: {e}') - return None, None def export_saved_model(model, im, file, dynamic, tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45, - conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')): + conf_thres=0.25, keras=False, prefix=colorstr('TensorFlow SavedModel:')): # YOLOv5 TensorFlow SavedModel export try: import tensorflow as tf - from tensorflow import keras + from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 from models.tf import TFDetect, TFModel @@ -282,16 +260,28 @@ def export_saved_model(model, im, file, dynamic, batch_size, ch, *imgsz = list(im.shape) # BCHW tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) - im = tf.ones((batch_size, *imgsz, 3)) # BHWC order for TensorFlow - y = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) - y = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) - inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) + im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow + _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) + inputs = tf.keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) - keras_model = keras.Model(inputs=inputs, outputs=outputs) + keras_model = tf.keras.Model(inputs=inputs, outputs=outputs) keras_model.trainable = False keras_model.summary() - keras_model.save(f, save_format='tf') - + if keras: + keras_model.save(f, save_format='tf') + else: + m = tf.function(lambda x: keras_model(x)) # full model + spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype) + m = m.get_concrete_function(spec) + frozen_func = convert_variables_to_constants_v2(m) + tfm = tf.Module() + tfm.__call__ = tf.function(lambda x: frozen_func(x), [spec]) + tfm.__call__(im) + tf.saved_model.save( + tfm, + f, + options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if + check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions()) LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') return keras_model, f except Exception as e: @@ -358,13 +348,14 @@ def export_edgetpu(keras_model, im, file, prefix=colorstr('Edge TPU:')): cmd = 'edgetpu_compiler --version' help_url = 'https://coral.ai/docs/edgetpu/compiler/' assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}' - if subprocess.run(cmd, shell=True).returncode != 0: + if subprocess.run(cmd + ' >/dev/null', shell=True).returncode != 0: LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}') + sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system for c in ['curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -', 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list', 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler']: - subprocess.run(c, shell=True, check=True) + subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True) ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1] LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...') @@ -446,16 +437,17 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) - # Checks - imgsz *= 2 if len(imgsz) == 1 else 1 # expand - opset = 12 if ('openvino' in include) else opset # OpenVINO requires opset <= 12 - # Load PyTorch model device = select_device(device) assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0' model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model nc, names = model.nc, model.names # number of classes, class names + # Checks + imgsz *= 2 if len(imgsz) == 1 else 1 # expand + opset = 12 if ('openvino' in include) else opset # OpenVINO requires opset <= 12 + assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}' + # Input gs = int(max(model.stride)) # grid size (max stride) imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples @@ -477,10 +469,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' for _ in range(2): y = model(im) # dry runs - LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)") + shape = tuple(y[0].shape) # model output shape + LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)") # Exports - f = [''] * 11 # exported filenames + f = [''] * 10 # exported filenames + warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning if 'torchscript' in include: f[0] = export_torchscript(model, im, file, optimize) if 'engine' in include: # TensorRT required before ONNX @@ -510,17 +504,15 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' if tfjs: f[9] = export_tfjs(model, im, file) - if 'keras' in include: - _, f[10] = export_keras(model, im, file, dynamic) - # Finish f = [str(x) for x in f if x] # filter out '' and None - LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)' - f"\nResults saved to {colorstr('bold', file.parent.resolve())}" - f"\nVisualize with https://netron.app" - f"\nDetect with `python detect.py --weights {f[-1]}`" - f" or `model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')" - f"\nValidate with `python val.py --weights {f[-1]}`") + if any(f): + LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)' + f"\nResults saved to {colorstr('bold', file.parent.resolve())}" + f"\nDetect: python detect.py --weights {f[-1]}" + f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')" + f"\nValidate: python val.py --weights {f[-1]}" + f"\nVisualize: https://netron.app") return f # return list of exported files/dirs diff --git a/models/common.py b/models/common.py index 38b94129e274..8831723ffa25 100644 --- a/models/common.py +++ b/models/common.py @@ -359,7 +359,8 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None): if saved_model: # SavedModel LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...') import tensorflow as tf - model = tf.keras.models.load_model(w) + keras = False # assume TF1 saved_model + model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w) elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...') import tensorflow as tf @@ -431,7 +432,7 @@ def forward(self, im, augment=False, visualize=False, val=False): else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) if self.saved_model: # SavedModel - y = self.model(im, training=False).numpy() + y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy() elif self.pb: # GraphDef y = self.frozen_func(x=self.tf.constant(im)).numpy() elif self.tflite: # Lite