diff --git a/gluoncv/torch/model_zoo/pose/directpose_resnet_fpn.py b/gluoncv/torch/model_zoo/pose/directpose_resnet_fpn.py index d2b174a140..387f134d9b 100644 --- a/gluoncv/torch/model_zoo/pose/directpose_resnet_fpn.py +++ b/gluoncv/torch/model_zoo/pose/directpose_resnet_fpn.py @@ -540,8 +540,8 @@ def forward(self, x): prev_features = self.lateral_convs[0](x[0]) results.append(self.output_convs[0](prev_features)) for features, lateral_conv, output_conv in zip(x[1:], self.lateral_convs[1:], self.output_convs[1:]): - top_down_features = F.interpolate(prev_features, scale_factor=2, mode="nearest") lateral_features = lateral_conv(features) + top_down_features = F.interpolate(prev_features, size=lateral_features.shape[-2:], mode="nearest") prev_features = lateral_features + top_down_features if self._fuse_type == "avg": prev_features /= 2 diff --git a/scripts/pose/directpose/demo_directpose.py b/scripts/pose/directpose/demo_directpose.py index 13645d1893..768ac4a7c5 100644 --- a/scripts/pose/directpose/demo_directpose.py +++ b/scripts/pose/directpose/demo_directpose.py @@ -18,6 +18,7 @@ def parse_args(): parser.add_argument('--use-gpu', type=int, default=1) parser.add_argument("--image-width", type=int, default=1280) parser.add_argument("--image-height", type=int, default=736) + parser.add_argument("--score-threshold", type=float, default=0.5) parser.add_argument("--visualize", type=int, default=1, help="enable visualizer") parser.add_argument("--save-output", type=str, default='visualize_output.png', help='Save visualize result to image') parser.add_argument("--verbose", type=int, default=1) @@ -77,6 +78,10 @@ def get_transforms(): img_width=args.image_width, img_height=args.image_height) with torch.no_grad(): predictions = net(images.to(device))[0] + if args.score_threshold > 0: + print(f"Applying score threshold: {args.score_threshold} to detected instances...") + valid_indices = predictions.scores > args.score_threshold + predictions = predictions[valid_indices] print(f"Detected {list(predictions.scores.size())} instances with scores: {predictions.scores}") # visualize if args.visualize or args.save_output: