Skip to content

Commit

Permalink
Fix arguments of image_size
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Dec 22, 2021
1 parent f5f6655 commit 600d25a
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tools/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ def get_parser():
)
parser.add_argument(
"--image_size",
default=(320,320),
type=tupe,
help="Image size for evaluation (default: 640).",
nargs="+",
type=int,
default=[640, 640],
help="Image size for evaluation (default: 640, 640).",
)
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
parser.add_argument("--opset", default=DEFAULT_OPSET, type=int, help="opset_version")
Expand Down Expand Up @@ -128,9 +129,10 @@ def cli_main():
checkpoint_path = Path(args.checkpoint_path)
assert checkpoint_path.exists(), f"Not found checkpoint file at '{checkpoint_path}'"

image_size = args.image_size * 2 if len(args.image_size) == 1 else 1 # expand
if args.skip_preprocess:
# input data
inputs = torch.rand(args.batch_size, 3, args.image_size, args.image_size)
inputs = torch.rand(args.batch_size, 3, *image_size)
dynamic_axes = {
"images_tensors": {0: "batch", 2: "height", 3: "width"},
"boxes": {0: "batch", 1: "num_objects"},
Expand All @@ -147,7 +149,7 @@ def cli_main():
model.eval()
else:
# input data
images = [torch.rand(3, args.image_size, args.image_size)]
images = [torch.rand(3, *image_size)]
inputs = (images,)
dynamic_axes = {
"images_tensors": {1: "height", 2: "width"},
Expand All @@ -159,7 +161,7 @@ def cli_main():
output_names = ["scores", "labels", "boxes"]
model = YOLOv5.load_from_yolov5(
checkpoint_path,
size=args.image_size,
size=tuple(image_size),
core_thresh=args.score_thresh,
version=args.version,
)
Expand Down

0 comments on commit 600d25a

Please sign in to comment.