Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Add RetinaNet Implementation #102

Merged
merged 56 commits into from
Feb 15, 2019
Merged

Conversation

chengyangfu
Copy link
Contributor

@chengyangfu chengyangfu commented Nov 3, 2018

Hi,
This PR contains the RetinaNet implementation. The following table contains the models which use ResNet50 and ResNet101 as the backbones.

GPU : Pascal Titan X
PyTorch : v1.0.0
Inference time is measured when setting batch size as 1.

Model Detectron Accuracy Current Accuracy inference time(s/im) download
RetinaNet_R-50-FPN_1x 35.7 36.3 0.102 model
RetinaNet_R-101-FPN_1x 37.7 38.5 0.123 model
RetinaNet_X-101-32x8d-FPN_1x 39.5 39.8 0.200 model
RetinaNet_R-50-FPN_P5_1x 35.7 36.2 0.097 model
RetinaNet_R-101-FPN_P5_1x 37.7 38.5 0.121 model

Add _C.TEST.DETECTIONS_PER_IMG = 100.
After using DETECTIONS_PER_IMG, the mAP drops 0.1.

Not Implemented parts.

  • Class specific bbox prediction.
  • Softmax Focal Loss

Updated 02/02/2018
Identify the reason why this branch gets higher AP.

Branch Accuracy Difference
This 36.3/55.2/38.9/
19.7/39.9/49.0
BoxCoder(10, 10, 5, 5),
add *4 in classification loss normalization
retinanet-detectron 35.6/55.8/37.7/
19.6/39.3/48.2
BoxCoder(1, 1, 1, 1)

Updated 01/30/2018
After updating PyTorch to v1.0.0, the inference time reduced around 15~20%.
Update the inference time in the table.


Updated 01/26/2018
Add RetinaNet_X-101-32x8d-FPN_1x model.
AP : 39.8
Inferece time : 0.200 second.


Updated 01/25/2018
In my first version, I accidentally used P5 to generate P6 instead of C5 which was used in Detectron and paper.
The following table compares the performances in these two settings.

Model C5 P5
RetinaNet_R-50-FPN_1x 36.3/55.2/38.9/19.7/39.9/49.0 36.2/55.1/38.7/19.7/39.5/48.6
RetinaNet_R-101-FPN_1x 38.5/57.6/41.0/20.8/42.3/51.7 38.5/57.9/41.3/21.0/42.8/51.3

Updated 01/23/2018

Train the model without "divide by 4" in the regression loss.
Performance:

Model AP AP50 AP75 APs APm APl
RetinaNet_R-50-FPN_1x 29.6 45.0 31.5 13.9 31.4 41.2

Updated 11/20/2018

The matching part is slightly different from the Detectron version.
In Detectron matching, anchors with IOU >= 0.5 are considered as positive examples, and anchors with IOU <=0.4 are negative examples. Then for the low qualities matches(best prediction for each gt), Detectron only uses the low-quality examples >= 0.4.

P.S.In Detectron, there are some cases occur in both fg_inds and bg_inds. Although in Line230, Detectron removes all the low-qualities positive examples < 0.4. I think the in Line231, num_fg calculation is not correct.

I also test the threshold used for low-qualities positive examples from 0.5 to 0.0.

threshold AP AP50 AP75 P.S.
0.5 35.5 53.7 38.1
0.4 36.0 54.1 38.6 Detectron version.
0.3 36.1 54.5 38.9
0.2 36.1 54.5 38.7
0.0 36.2 55.0 38.7 This branch.

@facebook-github-bot
Copy link

Thank you for your pull request and welcome to our community. We require contributors to sign our Contributor License Agreement, and we don't seem to have you on file. In order for us to review and merge your code, please sign up at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need the corporate CLA signed.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot facebook-github-bot added the CLA Signed Do not delete this pull request or issue due to inactivity. label Nov 3, 2018
@facebook-github-bot
Copy link

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!

@fmassa
Copy link
Contributor

fmassa commented Nov 3, 2018

This is really awesome, thanks a lot for the PR!

I'll have a closer look at it next week, let us know the result of the training!

@chengyangfu
Copy link
Contributor Author

Finishing the training.
RetinaNet with X_101_32x8d backbone model costs too much time for training now. Due to the CVPR submission deadline is coming, our lab does not have extra machines for training this one. If anyone can train this one, I will highly appreciate it!

@fmassa
Copy link
Contributor

fmassa commented Nov 5, 2018

No worries about X_101_32x8d training, we can do it on our side.

Copy link
Contributor

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once again thanks a lot for this awesome PR!

This is not a complete review yet.

One question I have is that I think we might want to move _C.RETINANET into _C.MODEL.RETINANET, but let's wait until @rbgirshick comment on that.

maskrcnn_benchmark/config/defaults.py Outdated Show resolved Hide resolved
maskrcnn_benchmark/modeling/backbone/fpn.py Outdated Show resolved Hide resolved
maskrcnn_benchmark/modeling/rpn/anchor_generator.py Outdated Show resolved Hide resolved
maskrcnn_benchmark/modeling/rpn/retinanet.py Outdated Show resolved Hide resolved
maskrcnn_benchmark/modeling/rpn/retinanet.py Outdated Show resolved Hide resolved
maskrcnn_benchmark/structures/boxlist_ops.py Outdated Show resolved Hide resolved
@rbgirshick
Copy link

@chengyangfu nice work! Do you know what implementation differences might have caused the improvement in box AP relative to the Detectron implementation?

I'm also curious if you need to use a C++ implementation of sigmoid focal loss or if you can simply use a Python implementation using torch.nn.functional? Ideally it could be simplified to the Python version.

Generate Empty BoxLists instead of [] in retinanet_infer
@chengyangfu
Copy link
Contributor Author

chengyangfu commented Nov 6, 2018

@rbgirshick
For the C++ implementation of sigmoid focal loss, I have tested this in my another project. Python version needs to discard ignored examples first and then calculate focal loss, but C++/CUDA version can combine focal loss and selection. The C++/CUDA version definitely has lower memory footprints and runs faster. The critical part is the selection. If I called labels >=0 and use it to get the positive and negative examples, the performance will drop quickly. Due to a large number of positive and negative examples in the training, I think that's reasonable.

The following is the python version of Focal Loss I tested.

   def forward(self, inputs, targets):
       N = inputs.size(0)
       C = inputs.size(1)
       class_mask = inputs.new_zeros((N, C))
       ids = targets.view(-1, 1)
       class_mask.scatter_(1, ids, 1.)

       class_mask = class_mask[:, 1:]
       inputs = inputs[:, 1:]

       P = torch.sigmoid(inputs)
       PC = P*class_mask + (1-P)*(1-class_mask)
       alpha = self.alpha * class_mask + (1 - self.alpha) * (1 - class_mask)
       focal_weight = alpha * (1 - PC).pow(self.gamma)
       loss = F.binary_cross_entropy_with_logits(inputs, class_mask,
                                                     focal_weight)
       return loss

Add NUM_DETECTIONS_PER_IMAGE
@zimenglan-sysu-512
Copy link
Contributor

hi @laibe
i try to run your yaml files, it encounter OOM (IMS_PER_BATCH=2)

@fmassa
Copy link
Contributor

fmassa commented Feb 14, 2019

Hi @chengyangfu ,

Thanks for the benchmark!

I believe this is a consequence of the operations not being fused in PyTorch 1.0.0. I think there have been some improvements recently that made it better, but I'd need to check.

cc @ailzhang for the performance timings and memory

Let's keep the CUDA implementation for now then, and dispatch to the Python implementation if we the tensor is on CPU, how does that sound?

Then this will be ready to be merged.

Once again, thanks a lot for this awesome contribution!

@laibe
Copy link

laibe commented Feb 14, 2019

@zimenglan-sysu-512 what's your training setup? I was using:
1x GeForce GTX 1080 Ti
CUDA runtime version: 9.0.176
with max mem during training of 5132

@zimenglan-sysu-512
Copy link
Contributor

zimenglan-sysu-512 commented Feb 14, 2019

thanks @laibe
i make a mistake resulting in OOM.
btw, do u use CUDA implementation or Python implementation?

after using cuda version, retinanet_MobileNetV2-96-FPN_1x needs 5146 mem, while retinanet_MobileNetV2-FPN_1x.pth needs 5486 mem.

@ailzhang
Copy link

The numbers might make sense given the current fusion logic in jit, @waochaol @zou3519 could you also help check on the JIT numbers? Thanks!

@chengyangfu
Copy link
Contributor Author

@fmassa
It sounds good to me. I just updated the maskrcnn_benchmark/layers/sigmoid_focal_loss.py according to the suggestion.

Copy link
Contributor

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks!

@fmassa fmassa merged commit 6b1ab01 into facebookresearch:master Feb 15, 2019
@fmassa
Copy link
Contributor

fmassa commented Feb 15, 2019

That's an awesome contribution @chengyangfu , thanks a lot for all your effort!

@buaaMars
Copy link

@chengyangfu
Hi,
Great work!
I have a question on rpn/retinanet/inference.py(77)
Why do you reshape the box_regression when it has been permute_and_flatten just in the last line? According to rpn/utils.py(13), (N, -1, 4) is exactly the shape of box_regression. Do you reshape it again in order to handle with some spacial cases?

thks a lot

@as1392
Copy link

as1392 commented Jul 24, 2019

Ah... I finally realized why model zoo does not have these trained weights... Removing OUT_CHANNELS: 256 from backbone destroyed trained networks? I hope someone update/convert these weights :(

Edit : OK, never mind this comment. It was just giving scores lower than 0.7(on the whole image). Try predeictions on COCO_val2014_000000355257.jpg.

@dedoogong
Copy link

could you please support Sigmoid Focal Loss cuda implementation to run on FP16?

Thank you

@simaiden
Copy link

simaiden commented Sep 9, 2019

Can I use this model to train my custom dataset as in #521 ?

@chenjoya
Copy link
Contributor

supplement performance of retinanet_r101fpn_2x on COCO minival:
AP, AP50, AP75, APs, APm, APl
0.3878, 0.5811, 0.4132, 0.2081, 0.4256, 0.5183

@adizhol
Copy link

adizhol commented Dec 2, 2019

Hi @chengyangfu :)
Why is the focal loss (the sum of the losses) divided by the number of positive labels plus number of labels (N = len(labels))?

retinanet_cls_loss = self.box_cls_loss_func(
            box_cls,
            labels
        ) / (pos_inds.numel() + N)

Lyears pushed a commit to Lyears/maskrcnn-benchmark that referenced this pull request Jun 28, 2020
* Add RetinetNet parameters in cfg.

* hot fix.

* Add the retinanet head module now.

* Add the function to generate the anchors for RetinaNet.

* Add the SigmoidFocalLoss cuda operator.

* Fix the bug in the extra layers.

* Change the normalizer for SigmoidFocalLoss

* Support multiscale in training.

* Add retinannet  training script.

* Add the inference part of RetinaNet.

* Fix the bug when building the extra layers in retinanet.
Update the matching part in retinanet_loss.

* Add the first version of the inference of RetinaNet.
Need to check it again to see if is there any room for speed
improvement.

* Remove the  retinanet_R-50-FPN_2x.yaml first.

* Optimize the retinanet postprocessing.

* quick fix.

* Add script for training RetinaNet with ResNet101 backbone.

* Move cfg.RETINANET to cfg.MODEL.RETINANET

* Remove the variables which are not used.

* revert boxlist_ops.
Generate Empty BoxLists instead of [] in retinanet_infer

* Remove the not used commented lines.
Add NUM_DETECTIONS_PER_IMAGE

* remove the not used codes.

* Move retinanet related files under Modeling/rpn/retinanet

* Add retinanet_X_101_32x8d_FPN_1x.yaml script.
This model is not fully validated. I only trained it around 5000
iterations and everything is fine.

* set RETINANET.PRE_NMS_TOP_N as 0 in level5 (p7), because previous setting may generate zero detections and could cause
the program break.
This part is used in original Detectron setting.

* Fix the rpn only bug when the training ends.

* Minor improvements

* Comments and add Python-only implementation

* Bugfix and remove commented code

* keep the generalized_rcnn same.
Move the build_retinanet inside build_rpn.

* Add USE_C5 in the MODEL.RETINANET

* Add two configs using P5 to generate P6.

* fix the bug when loading the Caffe2 ImageNet pretrained model.

* Reduce the code depulication of RPN loss and RetinaNet loss.

* Remove the comment which is not used.

* Remove the hard coded number of classes.

* share the foward part of rpn inference.

* fix the bug in rpn inference.

* Remove the conditional part in the inference.

* Bug fix: add the utils file for permute and flatten of the box
prediction layers.

* Update the comment.

* quick fix. Adding import cat.

* quick fix: forget including import.

* Adjust the normalization part according to Detectron's setting.

* Use the bbox reg normalization term.

* Clean the code according to recent review.

* Using CUDA version for training now. And the python version for training
on cpu.

* rename the directory to retinanet.

* Make the train and val datasets are consistent with mask r-cnn setting.

* add comment.
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
CLA Signed Do not delete this pull request or issue due to inactivity.
Projects
None yet
Development

Successfully merging this pull request may close these issues.