diff --git a/.gitignore b/.gitignore index df990f7..92f2bb4 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ /common/data.cpp /common/data.h /data/others.txt +/data/*-opt*.onnx # CCS-generated files .launches/ diff --git a/data/KWS-DNN_S-opt.onnx b/data/KWS-DNN_S-opt.onnx deleted file mode 100644 index 75b2751..0000000 Binary files a/data/KWS-DNN_S-opt.onnx and /dev/null differ diff --git a/data/mnist-8-opt.onnx b/data/mnist-8-opt.onnx deleted file mode 100644 index 4fcafe7..0000000 Binary files a/data/mnist-8-opt.onnx and /dev/null differ diff --git a/data/squeezenet_cifar10-opt.onnx b/data/squeezenet_cifar10-opt.onnx deleted file mode 100644 index ab5da69..0000000 Binary files a/data/squeezenet_cifar10-opt.onnx and /dev/null differ diff --git a/transform.py b/transform.py index a87b338..340091d 100755 --- a/transform.py +++ b/transform.py @@ -15,6 +15,7 @@ import onnx import onnx.helper +import onnxoptimizer import numpy as np from utils import ( @@ -264,22 +265,15 @@ def get_prev_node(n): Constants.LEA_BUFFER_SIZE = lea_buffer_size[args.target] onnx_opt_model_name = config['onnx_model'].replace('.onnx', '-opt.onnx') -if os.path.exists(onnx_opt_model_name): - onnx_model = onnx.load(onnx_opt_model_name) -else: - onnx_model = onnx.load(config['onnx_model']) - try: - import onnx.optimizer - # https://zhuanlan.zhihu.com/p/41255090 - onnx_model = onnx.optimizer.optimize(onnx_model, [ - 'fuse_add_bias_into_conv', - 'fuse_matmul_add_bias_into_gemm', - ]) - except IndexError: - # Somehow the optimizer cannot handle models transformed from keras2onnx - pass - onnx_model = onnx.shape_inference.infer_shapes(onnx_model) - onnx.save_model(onnx_model, onnx_opt_model_name) +onnx_model = onnx.load(config['onnx_model']) +# https://zhuanlan.zhihu.com/p/41255090 +onnx_model = onnxoptimizer.optimize(onnx_model, [ + 'fuse_add_bias_into_conv', + 'fuse_matmul_add_bias_into_gemm', +]) + +onnx_model = onnx.shape_inference.infer_shapes(onnx_model) +onnx.save_model(onnx_model, onnx_opt_model_name) g = onnx_model.graph names = {} @@ -301,8 +295,12 @@ def find_initializer(name): return initializer def replace_squeeze(node, inp): - axes_name = node.input[1] - axes = find_initializer(axes_name).int64_data + # Since opset 13, axes is an input instead of an attribute + try: + axes_name = node.input[1] + axes = find_initializer(axes_name).int64_data + except IndexError: + axes = get_attr(node, 'axes') new_dims = [dim for dim_idx, dim in enumerate(inp.dims) if dim_idx not in axes] # Repeated fields cannot be assigned directly # https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-fields