Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Add RetinaNet Implementation #102

Merged
merged 56 commits into from
Feb 15, 2019
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
e185c79
Add RetinetNet parameters in cfg.
Oct 26, 2018
6167fa4
hot fix.
Oct 26, 2018
99920af
Add the retinanet head module now.
Oct 27, 2018
9e82436
Add the function to generate the anchors for RetinaNet.
Oct 27, 2018
69d5d3a
Add the SigmoidFocalLoss cuda operator.
Oct 28, 2018
587ccd8
Fix the bug in the extra layers.
Oct 28, 2018
89e35b2
Change the normalizer for SigmoidFocalLoss
Oct 28, 2018
bd9a817
Support multiscale in training.
Oct 28, 2018
882655a
Add retinannet training script.
Oct 29, 2018
b5ca053
Add the inference part of RetinaNet.
Oct 31, 2018
a1f7365
Fix the bug when building the extra layers in retinanet.
Nov 2, 2018
6cc7264
Add the first version of the inference of RetinaNet.
Nov 2, 2018
dca2453
Remove the retinanet_R-50-FPN_2x.yaml first.
Nov 2, 2018
ce06ecd
Optimize the retinanet postprocessing.
Nov 3, 2018
615af53
Merge branch 'master' of https://github.com/facebookresearch/maskrcnn…
Nov 3, 2018
a859b1e
quick fix.
Nov 3, 2018
2e9881f
Add script for training RetinaNet with ResNet101 backbone.
Nov 3, 2018
21a84e7
Move cfg.RETINANET to cfg.MODEL.RETINANET
Nov 6, 2018
ee7760f
Remove the variables which are not used.
Nov 6, 2018
adb25d6
revert boxlist_ops.
Nov 6, 2018
b84ff0e
Remove the not used commented lines.
Nov 6, 2018
b3af003
remove the not used codes.
Nov 6, 2018
911196f
Move retinanet related files under Modeling/rpn/retinanet
Nov 6, 2018
5fc4b75
Add retinanet_X_101_32x8d_FPN_1x.yaml script.
Nov 6, 2018
c8c4bc7
set RETINANET.PRE_NMS_TOP_N as 0 in level5 (p7), because previous set…
Nov 10, 2018
cfe06d8
Fix the rpn only bug when the training ends.
Nov 10, 2018
bac17d6
Merge branch 'master' of https://github.com/facebookresearch/maskrcnn…
Nov 19, 2018
9e2baa7
Minor improvements
fmassa Nov 23, 2018
a8c919a
Comments and add Python-only implementation
fmassa Nov 23, 2018
190e132
Bugfix and remove commented code
fmassa Nov 23, 2018
8bab238
keep the generalized_rcnn same.
Nov 25, 2018
51bbb17
Merge branch 'master' of https://github.com/facebookresearch/maskrcnn…
Jan 23, 2019
b328d22
Add USE_C5 in the MODEL.RETINANET
Jan 25, 2019
0dac79d
Merge branch 'retinanet' of https://github.com/chengyangfu/maskrcnn-b…
Jan 25, 2019
328ea98
Add two configs using P5 to generate P6.
Jan 25, 2019
0177419
fix the bug when loading the Caffe2 ImageNet pretrained model.
Jan 28, 2019
128c491
Reduce the code depulication of RPN loss and RetinaNet loss.
Jan 29, 2019
dc73a33
Remove the comment which is not used.
Jan 29, 2019
a51abdc
Remove the hard coded number of classes.
Jan 30, 2019
5c7b391
share the foward part of rpn inference.
Jan 31, 2019
bec7cc1
fix the bug in rpn inference.
Jan 31, 2019
77e3626
Remove the conditional part in the inference.
Jan 31, 2019
a8ba755
Bug fix: add the utils file for permute and flatten of the box
Jan 31, 2019
dbbb6f9
Update the comment.
Jan 31, 2019
997ae29
quick fix. Adding import cat.
Feb 1, 2019
8c93b1d
quick fix: forget including import.
Feb 3, 2019
fb3fe10
fix merge.
Feb 4, 2019
096f0d6
Adjust the normalization part according to Detectron's setting.
Feb 4, 2019
da19923
Pull from maskrcnn-benchmark master branch.
Feb 4, 2019
afd5f0b
Use the bbox reg normalization term.
Feb 13, 2019
c235307
Merge https://github.com/facebookresearch/maskrcnn-benchmark into ret…
Feb 13, 2019
6389130
Clean the code according to recent review.
Feb 13, 2019
26c707e
Using CUDA version for training now. And the python version for training
Feb 14, 2019
45372c4
rename the directory to retinanet.
Feb 14, 2019
3bb285d
Make the train and val datasets are consistent with mask r-cnn setting.
Feb 14, 2019
37e9075
add comment.
Feb 14, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions configs/retina/retinanet_R-101-FPN_1x.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-101"
RPN_ONLY: True
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think “PRN_ONLY” should be False, due to this option maybe affect the eval process.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I think this does not affect anything in the evaluation. RPN_ONLY is not used, once the RETINANET_ON is set as True.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm really got some errors.

Problem maybe occur on here:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for letting me know. I am on vacation now but will check it once I am back.

RETINANET_ON: True
BACKBONE:
CONV_BODY: "R-101-FPN-RETINANET"
OUT_CHANNELS: 256
RPN:
USE_FPN: True
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
RETINANET:
SCALES_PER_OCTAVE: 3
STRADDLE_THRESH: -1
DATASETS:
TRAIN: ("coco_2017_train",)
TEST: ("coco_2017_val",)
INPUT:
MIN_SIZE_TRAIN: (800, )
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 4 gpus
BASE_LR: 0.005
WEIGHT_DECAY: 0.0001
STEPS: (120000, 160000)
MAX_ITER: 180000
IMS_PER_BATCH: 8


46 changes: 46 additions & 0 deletions configs/retina/retinanet_R-50-FPN_1x.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
RPN_ONLY: True
RETINANET_ON: True
BACKBONE:
CONV_BODY: "R-50-FPN-RETINANET"
OUT_CHANNELS: 256
RPN:
USE_FPN: True
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
RETINANET:
SCALES_PER_OCTAVE: 3
STRADDLE_THRESH: -1
DATASETS:
TRAIN: ("coco_2017_train",)
TEST: ("coco_2017_val",)
INPUT:
MIN_SIZE_TRAIN: (800,)
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 4 gpus
BASE_LR: 0.01
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
IMS_PER_BATCH: 16
52 changes: 52 additions & 0 deletions configs/retina/retinanet_X_101_32x8d_FPN_1x.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/FAIR/20171220/X-101-32x8d"
RPN_ONLY: True
RETINANET_ON: True
BACKBONE:
CONV_BODY: "R-101-FPN-RETINANET"
OUT_CHANNELS: 256
RPN:
USE_FPN: True
FG_IOU_THRESHOLD: 0.5
BG_IOU_THRESHOLD: 0.4
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
RESNETS:
STRIDE_IN_1X1: False
NUM_GROUPS: 32
WIDTH_PER_GROUP: 8
RETINANET:
SCALES_PER_OCTAVE: 3
STRADDLE_THRESH: -1
DATASETS:
TRAIN: ("coco_2017_train",)
TEST: ("coco_2017_val",)
INPUT:
MIN_SIZE_TRAIN: (800, )
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
# Assume 4 gpus
BASE_LR: 0.005
WEIGHT_DECAY: 0.0001
STEPS: (120000, 160000)
MAX_ITER: 180000
IMS_PER_BATCH: 8


63 changes: 62 additions & 1 deletion maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_C.MODEL = CN()
_C.MODEL.RPN_ONLY = False
_C.MODEL.MASK_ON = False
_C.MODEL.RETINANET_ON = False
_C.MODEL.DEVICE = "cuda"
_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"

Expand All @@ -37,7 +38,7 @@
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the smallest side of the image during training
_C.INPUT.MIN_SIZE_TRAIN = 800 # (800,)
_C.INPUT.MIN_SIZE_TRAIN = (800,) # 800
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are currently not using this I believe, and it conflicts with the changes in the Keypoints.
Maybe I'll just revert this part for now, or add it as a separate PR, are you ok with that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am totally okay with that.

# Maximum size of the side of the image during training
_C.INPUT.MAX_SIZE_TRAIN = 1333
# Size of the smallest side of the image during testing
Expand Down Expand Up @@ -223,6 +224,64 @@
_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64


# ---------------------------------------------------------------------------- #
# RetinaNet Options (Follow the Detectron version)
# ---------------------------------------------------------------------------- #
_C.MODEL.RETINANET = CN()

# This is the number of foreground classes, background is not included.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment seems strange given that we have 80 classes in COCO

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, do we need to keep in two different places the number of classes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the comment is wrong. For the second question, definitely, one place is much better. But I would like it is under MODEL instead of MODEL.ROI_BOX_HEAD.NUM_CLASSES.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, MODEL.ROI_BOX_HEAD.NUM_CLASSES is not a good place to take RetinaNet into account. Let me think about it.

Also, I haven't looked closely at the retinanet implementation of Detectron, but don't we have a background class in the classifier?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think MODEL.NUM_CLASSES would be a good choice. But it needs changing many parts of current Faster/Mask R-CNN. I think it would make this PR be too complicated.
Detctron version supports two types. The first one is softmax focal loss(foreground classes + background) and another is sigmoid focal loss(only foreground classes). Sigmoid Focal Loss is used in the original paper and all results are reported based on this version. So, I didn't implement the softmax focal loss.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather keep the ROI_BOX_HEAD.NUM_CLASSES for now, as when it was defined it was a property of ROI_HEAD, and different heads can have different number of classes.

Thanks for the explanation on the second question!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that makes sense.

_C.MODEL.RETINANET.NUM_CLASSES = 81

# Anchor aspect ratios to use
_C.MODEL.RETINANET.ANCHOR_SIZES = (32, 64, 128, 256, 512)
_C.MODEL.RETINANET.ASPECT_RATIOS = (0.5, 1.0, 2.0)
_C.MODEL.RETINANET.ANCHOR_STRIDES = (8, 16, 32, 64, 128)
_C.MODEL.RETINANET.STRADDLE_THRESH = 0

# Anchor scales per octave
_C.MODEL.RETINANET.OCTAVE = 2.0
_C.MODEL.RETINANET.SCALES_PER_OCTAVE = 3

# Convolutions to use in the cls and bbox tower
# NOTE: this doesn't include the last conv for logits
_C.MODEL.RETINANET.NUM_CONVS = 4

# Weight for bbox_regression loss
_C.MODEL.RETINANET.BBOX_REG_WEIGHT = 1.0

# Smooth L1 loss beta for bbox regression
_C.MODEL.RETINANET.BBOX_REG_BETA = 0.11

# During inference, #locs to select based on cls score before NMS is performed
# per FPN level
_C.MODEL.RETINANET.PRE_NMS_TOP_N = 1000

# IoU overlap ratio for labeling an anchor as positive
# Anchors with >= iou overlap are labeled positive
_C.MODEL.RETINANET.POSITIVE_OVERLAP = 0.5

# IoU overlap ratio for labeling an anchor as negative
# Anchors with < iou overlap are labeled negative
_C.MODEL.RETINANET.NEGATIVE_OVERLAP = 0.4

# Focal loss parameter: alpha
_C.MODEL.RETINANET.LOSS_ALPHA = 0.25

# Focal loss parameter: gamma
_C.MODEL.RETINANET.LOSS_GAMMA = 2.0

# Prior prob for the positives at the beginning of training. This is used to set
# the bias init for the logits layer
_C.MODEL.RETINANET.PRIOR_PROB = 0.01

# Inference cls score threshold, anchors with score > INFERENCE_TH are
# considered for inference
_C.MODEL.RETINANET.INFERENCE_TH = 0.05

# NMS threshold used in RetinaNet
_C.MODEL.RETINANET.NMS_TH = 0.4

# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
Expand Down Expand Up @@ -261,6 +320,8 @@
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
_C.TEST.IMS_PER_BATCH = 8
# Number of detections per image
_C.TEST.DETECTIONS_PER_IMG = 100


# ---------------------------------------------------------------------------- #
Expand Down
8 changes: 8 additions & 0 deletions maskrcnn_benchmark/config/paths_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ class DatasetCatalog(object):
DATA_DIR = "datasets"

DATASETS = {
"coco_2017_train": (
"coco/train2017",
"coco/annotations/instances_train2017.json",
),
"coco_2017_val": (
"coco/val2017",
"coco/annotations/instances_val2017.json",
),
"coco_2014_train": (
"coco/train2014",
"coco/annotations/instances_train2014.json",
Expand Down
41 changes: 41 additions & 0 deletions maskrcnn_benchmark/csrc/SigmoidFocalLoss.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#pragma once

#include "cpu/vision.h"

#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif

// Interface for Python
at::Tensor SigmoidFocalLoss_forward(
const at::Tensor& logits,
const at::Tensor& targets,
const int num_classes,
const float gamma,
const float alpha) {
if (logits.type().is_cuda()) {
#ifdef WITH_CUDA
return SigmoidFocalLoss_forward_cuda(logits, targets, num_classes, gamma, alpha);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}

at::Tensor SigmoidFocalLoss_backward(
const at::Tensor& logits,
const at::Tensor& targets,
const at::Tensor& d_losses,
const int num_classes,
const float gamma,
const float alpha) {
if (logits.type().is_cuda()) {
#ifdef WITH_CUDA
return SigmoidFocalLoss_backward_cuda(logits, targets, d_losses, num_classes, gamma, alpha);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
Loading