diff --git a/references/detection/train.py b/references/detection/train.py index c39ae3e723c..7aa71314230 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -96,7 +96,8 @@ def main(args): "trainable_backbone_layers": args.trainable_backbone_layers } if "rcnn" in args.model: - kwargs["rpn_score_thresh"] = args.rpn_score_thresh + if args.rpn_score_thresh is not None: + kwargs["rpn_score_thresh"] = args.rpn_score_thresh model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained, **kwargs) model.to(device) @@ -179,9 +180,9 @@ def main(args): parser.add_argument('--resume', default='', help='resume from checkpoint') parser.add_argument('--start_epoch', default=0, type=int, help='start epoch') parser.add_argument('--aspect-ratio-group-factor', default=3, type=int) - parser.add_argument('--rpn-score-thresh', default=0.0, type=float, help='rpn score threshold for faster-rcnn') + parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn') parser.add_argument('--trainable-backbone-layers', default=None, type=int, - help='number of trainable layers of backbone ') + help='number of trainable layers of backbone') parser.add_argument( "--test-only", dest="test_only", diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 54001fb4f76..8b1d6952271 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -414,7 +414,8 @@ def fasterrcnn_mobilenet_v3_large(pretrained=False, progress=True, num_classes=9 def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, - trainable_backbone_layers=None, min_size=320, max_size=640, **kwargs): + trainable_backbone_layers=None, min_size=320, max_size=640, rpn_score_thresh=0.05, + **kwargs): """ Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details. @@ -435,6 +436,8 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. min_size (int): minimum size of the image to be rescaled before feeding it to the backbone max_size (int): maximum size of the image to be rescaled before feeding it to the backbone + rpn_score_thresh (float): during inference, only return proposals with a classification score + greater than rpn_score_thresh """ trainable_backbone_layers = _validate_trainable_layers( pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3) @@ -448,7 +451,7 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), - min_size=min_size, max_size=max_size, **kwargs) + min_size=min_size, max_size=max_size, rpn_score_thresh=rpn_score_thresh, **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_fpn_coco'], progress=progress) model.load_state_dict(state_dict)