diff --git a/tools/torch/trace_torchvision.py b/tools/torch/trace_torchvision.py
index b1336acd8..1adc0525a 100755
--- a/tools/torch/trace_torchvision.py
+++ b/tools/torch/trace_torchvision.py
@@ -37,6 +37,7 @@
 parser.add_argument('--print-models', action='store_true', help="Print all the available models names and exit")
 parser.add_argument('--to-dd-native', action='store_true', help="Prepare the model so that the weights can be loaded on native model with dede")
 parser.add_argument('--to-onnx', action="store_true", help="If specified, export to onnx instead of jit.")
+parser.add_argument('--onnx_out', type=str, default="prob", help="Name of onnx output")
 parser.add_argument('--weights', type=str, help="If not None, these weights will be embedded in the model before exporting")
 parser.add_argument('-a', "--all", action='store_true', help="Export all available models")
 parser.add_argument('-v', "--verbose", action='store_true', help="Set logging level to INFO")
@@ -336,11 +337,13 @@ def get_detection_input(batch_size=1, img_width=224, img_height=224):
         # remove extension
         filename = filename[:-3] + ".onnx"
         example = get_image_input(args.batch_size, args.img_width, args.img_height)
+
+        # change for detection
         torch.onnx.export(
             model, example, filename,
             export_params=True, verbose=args.verbose,
             opset_version=11, do_constant_folding=True,
-            input_names=["input"], output_names=["output"])
+            input_names=["input"], output_names=[args.onnx_out])
         # dynamic_axes={"input":{0:"batch_size"},"output":{0:"batch_size"}}
     else:
         logging.info("Saving to %s", filename)
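
The patch makes the ONNX output tensor name configurable via --onnx_out instead of the hard-coded "output", defaulting to "prob". A minimal standalone sketch of what the export path now does (not part of the patch; the resnet18 model, file name and onnx check are illustrative assumptions, and "prob" mirrors the new default):

    # Sketch: export a torchvision model with a named ONNX output, then verify the name.
    import torch
    import torchvision
    import onnx

    model = torchvision.models.resnet18().eval()
    example = torch.randn(1, 3, 224, 224)
    torch.onnx.export(
        model, example, "resnet18.onnx",
        export_params=True, opset_version=11, do_constant_folding=True,
        input_names=["input"], output_names=["prob"])

    graph = onnx.load("resnet18.onnx").graph
    print([o.name for o in graph.output])  # -> ['prob']

Consumers that look up the output blob by name (e.g. a serving config expecting "prob") can then match the exported graph without renaming it afterwards.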