Skip to content

Commit

Permalink
fix resnet divisible issue (#1658)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhreshold authored May 13, 2021
1 parent 40c5c0e commit ca18cec
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gluoncv/torch/model_zoo/pose/directpose_resnet_fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,8 +540,8 @@ def forward(self, x):
prev_features = self.lateral_convs[0](x[0])
results.append(self.output_convs[0](prev_features))
for features, lateral_conv, output_conv in zip(x[1:], self.lateral_convs[1:], self.output_convs[1:]):
top_down_features = F.interpolate(prev_features, scale_factor=2, mode="nearest")
lateral_features = lateral_conv(features)
top_down_features = F.interpolate(prev_features, size=lateral_features.shape[-2:], mode="nearest")
prev_features = lateral_features + top_down_features
if self._fuse_type == "avg":
prev_features /= 2
Expand Down
5 changes: 5 additions & 0 deletions scripts/pose/directpose/demo_directpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def parse_args():
parser.add_argument('--use-gpu', type=int, default=1)
parser.add_argument("--image-width", type=int, default=1280)
parser.add_argument("--image-height", type=int, default=736)
parser.add_argument("--score-threshold", type=float, default=0.5)
parser.add_argument("--visualize", type=int, default=1, help="enable visualizer")
parser.add_argument("--save-output", type=str, default='visualize_output.png', help='Save visualize result to image')
parser.add_argument("--verbose", type=int, default=1)
Expand Down Expand Up @@ -77,6 +78,10 @@ def get_transforms():
img_width=args.image_width, img_height=args.image_height)
with torch.no_grad():
predictions = net(images.to(device))[0]
if args.score_threshold > 0:
print(f"Applying score threshold: {args.score_threshold} to detected instances...")
valid_indices = predictions.scores > args.score_threshold
predictions = predictions[valid_indices]
print(f"Detected {list(predictions.scores.size())} instances with scores: {predictions.scores}")
# visualize
if args.visualize or args.save_output:
Expand Down

0 comments on commit ca18cec

Please sign in to comment.