diff --git a/tools/export_model.py b/tools/export_model.py index eff81cca..997a674d 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -116,9 +116,10 @@ def export_tensorrt( checkpoint_path: str, score_thresh: float, nms_thresh: float, + version: str, onnx_path: str, engine_path: str, - input_sample: torch.Tensor, + input_sample: torch.Tensor = None, detections_per_img: int = 100, workspace: int = 12, ): @@ -128,6 +129,7 @@ def export_tensorrt( checkpoint_path, score_thresh=score_thresh, nms_thresh=nms_thresh, + version=version, onnx_path=onnx_path, engine_path=engine_path, input_sample=input_sample, @@ -177,6 +179,7 @@ def cli_main(): checkpoint_path, score_thresh=args.score_thresh, nms_thresh=args.nms_thresh, + version=args.version, onnx_path=str(onnx_path), engine_path=str(tensorrt_path), input_sample=input_sample,