Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add A Single Line to Save Your Life #6867

Merged
merged 1 commit into from
Jan 5, 2022
Merged

Conversation

imyhxy
Copy link
Contributor

@imyhxy imyhxy commented Dec 23, 2021

I know the title is a bit frivolous, but I think it satisfies this PR.

Thanks for your excellent works, the mmdetection is a very flexible and powerful deep learning framework.

Threads number is insane

Currently, I was implementing the popular YOLOv5 model with mmdetection, but when I try to train my new model, I found out that the training speed is much slower than the official implementation of YOLOv5 (almost ~4x slower). I can't afford that time overhead. So I digged into it for some day and notice that the mmdetection always spawn thousands of threads for data preprocessing. Compare to mmdetection, offical YOLOv5 only fork hundreds of threads (~2000 v.s. 400). Because of such insane number of threads, almost half of the CPU workload is occupied by the kernel, which used to manager the threads.

The CPU is too busy to handle the threads making the whole training speed is slow, and the system also get very lagging. The reason is that the cv2 module will activate multi-processing automatically and the processes number is equal to the CPU core. Which means if you have a more powerful CPU, it hurts the performance harder. My server has a 64-core CPU and 8 T4 GPU, if I set the workers_per_gpu to 4, the number of cv2-threads will be 8*4*64 = 2048. If I set the workers_per_gpu to 8, there will be 4096 theads. Each of this will overwhleming my system.

Lower the workers_per_gpu won't help

I have tried to lower the workers_per_gpu, but that just don't fixed the kernel overhead issue, and also slow down the data preprocessing speed. So it't hard to find an optimized workers_per_gpu number for good training speed because there is no one.

Solution

The solution is trival, disable the multi-processsing of cv2 manually, just involve the cv2.setNumThreads(0) anywhere before involve the cv2.

Comparison

My setup:

sys.platform: ubuntu 20.04                                                                                                                                                                                                       
Python: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0]                                                                                                                                                                
CUDA available: True                                                                                                                                                                                                       
GPU 0,1,2,3,4,5,6,7: Tesla T4                                                                                                                                                                                              
CUDA_HOME: /usr/local/cuda                                                                                                                                                                                                                                                                                                                                                                              
PyTorch: 1.10.0a0+git36449ea
TorchVision: 0.11.0a0+fa347eb                                                                                                                                                                                              
OpenCV: 4.5.4                                                                                                                                                                                                              
MMCV: 1.4.0                                                                                                                                                                                                                
MMCV Compiler: GCC 9.3                                                                                                                                                                                                     
MMCV CUDA Compiler: 11.3                                                                                                                                                                                                   
MMDetection: 2.19.0+cf9cd06
Model name: Intel(R) Xeon(R) Gold 5218 CPU @ 2.30GHz

I have do some before-after experiments:

RetinaNet

The preprocessing of retinanet is so slim, and the default workers_per_gpu for retinanet is 2, so the speed gain is none. If people use retinanet to do performance test, they will not notice the threads and kernel overhead problem.

before:

# ps
%Cpu(s): 12.4 us,  7.6 sy,  0.0 ni, 80.0 id,  0.0 wa,  0.0 hi,  0.0 si,  0.0 st

# ps auxH | grep -E "(train.py|spawn)" | wc -l
1228

# train log
2021-12-23 15:01:12,119 - mmdet - INFO - Epoch [1][500/7330]    lr: 9.980e-03, eta: 19:57:07, time: 0.718, data_time: 0.009, memory: 3512, loss_cls: 1.1594, loss_bbox: 0.5733, loss: 1.7326
2021-12-23 15:01:47,805 - mmdet - INFO - Epoch [1][550/7330]    lr: 1.000e-02, eta: 19:42:19, time: 0.715, data_time: 0.010, memory: 3512, loss_cls: 1.2114, loss_bbox: 0.6517, loss: 1.8631
2021-12-23 15:02:23,515 - mmdet - INFO - Epoch [1][600/7330]    lr: 1.000e-02, eta: 19:29:41, time: 0.713, data_time: 0.009, memory: 3512, loss_cls: 1.2083, loss_bbox: 0.6279, loss: 1.8361
2021-12-23 15:02:59,285 - mmdet - INFO - Epoch [1][650/7330]    lr: 1.000e-02, eta: 19:19:13, time: 0.716, data_time: 0.010, memory: 3512, loss_cls: 1.1998, loss_bbox: 0.6108, loss: 1.8106
2021-12-23 15:03:34,852 - mmdet - INFO - Epoch [1][700/7330]    lr: 1.000e-02, eta: 19:09:40, time: 0.711, data_time: 0.010, memory: 3512, loss_cls: 1.1975, loss_bbox: 0.5894, loss: 1.7868
2021-12-23 15:04:10,733 - mmdet - INFO - Epoch [1][750/7330]    lr: 1.000e-02, eta: 19:01:58, time: 0.718, data_time: 0.010, memory: 3512, loss_cls: 1.1416, loss_bbox: 0.5770, loss: 1.7186

after:

# ps
%Cpu(s): 10.3 us,  3.9 sy,  0.0 ni, 85.8 id,  0.0 wa,  0.0 hi,  0.0 si,  0.0 st

# ps auxH | grep -E "(train.py|spawn)" | wc -l
220

# train log
2021-12-23 14:49:24,475 - mmdet - INFO - Epoch [1][550/7330]    lr: 1.000e-02, eta: 19:30:46, time: 0.723, data_time: 0.009, memory: 3479, loss_cls: 1.0979, loss_bbox: 0.5628, loss: 1.6607
2021-12-23 14:50:00,236 - mmdet - INFO - Epoch [1][600/7330]    lr: 1.000e-02, eta: 19:19:18, time: 0.715, data_time: 0.009, memory: 3479, loss_cls: 1.1219, loss_bbox: 0.5471, loss: 1.6689
2021-12-23 14:50:36,342 - mmdet - INFO - Epoch [1][650/7330]    lr: 1.000e-02, eta: 19:10:24, time: 0.723, data_time: 0.009, memory: 3479, loss_cls: 1.0238, loss_bbox: 0.5310, loss: 1.5548
2021-12-23 14:51:12,477 - mmdet - INFO - Epoch [1][700/7330]    lr: 1.000e-02, eta: 19:02:43, time: 0.723, data_time: 0.008, memory: 3479, loss_cls: 0.9660, loss_bbox: 0.5207, loss: 1.4867
2021-12-23 14:51:48,642 - mmdet - INFO - Epoch [1][750/7330]    lr: 1.000e-02, eta: 18:55:59, time: 0.723, data_time: 0.008, memory: 3479, loss_cls: 0.9161, loss_bbox: 0.4946, loss: 1.4108

YOLOX:

before:

# top
%Cpu(s): 57.0 us, 40.6 sy,  0.0 ni,  2.4 id,  0.0 wa,  0.0 hi,  0.0 si,  0.0 st

# ps auxH | grep -E "(train.py|spawn)" | wc -l
2316

# train log
2021-12-23 15:15:28,974 - mmdet - INFO - Epoch [1][500/1849]    lr: 2.925e-05, eta: 4 days, 12:34:37, time: 0.522, data_time: 0.027, memory: 7104, loss_cls: 2.0750, loss_bbox: 4.5399, loss_obj: 7.3216, loss: 13.9366
2021-12-23 15:15:52,379 - mmdet - INFO - Epoch [1][550/1849]    lr: 3.539e-05, eta: 4 days, 9:14:50, time: 0.468, data_time: 0.027, memory: 7104, loss_cls: 2.0856, loss_bbox: 4.4934, loss_obj: 6.9628, loss: 13.5418
2021-12-23 15:16:18,378 - mmdet - INFO - Epoch [1][600/1849]    lr: 4.212e-05, eta: 4 days, 7:08:09, time: 0.520, data_time: 0.027, memory: 7104, loss_cls: 2.1352, loss_bbox: 4.4188, loss_obj: 6.8577, loss: 13.4117
2021-12-23 15:16:45,188 - mmdet - INFO - Epoch [1][650/1849]    lr: 4.943e-05, eta: 4 days, 5:32:27, time: 0.536, data_time: 0.025, memory: 7189, loss_cls: 2.1809, loss_bbox: 4.3384, loss_obj: 6.6560, loss: 13.1754
2021-12-23 15:17:08,857 - mmdet - INFO - Epoch [1][700/1849]    lr: 5.733e-05, eta: 4 days, 3:29:02, time: 0.473, data_time: 0.029, memory: 7189, loss_cls: 2.2099, loss_bbox: 4.2915, loss_obj: 5.9493, loss: 12.4506
2021-12-23 15:17:34,397 - mmdet - INFO - Epoch [1][750/1849]    lr: 6.581e-05, eta: 4 days, 2:05:01, time: 0.511, data_time: 0.027, memory: 7189, loss_cls: 2.2673, loss_bbox: 4.2156, loss_obj: 5.7262, loss: 12.2090

after:

# top
%Cpu(s): 33.0 us,  4.8 sy,  0.0 ni, 62.2 id,  0.0 wa,  0.0 hi,  0.0 si,  0.0 st

# ps auxH | grep -E "(train.py|spawn)" | wc -l
300

# train log
2021-12-23 15:24:26,022 - mmdet - INFO - Epoch [1][500/1849]    lr: 2.925e-05, eta: 4 days, 0:38:47, time: 0.385, data_time: 0.018, memory: 7096, loss_cls: 1.9213, loss_bbox: 4.6419, loss_obj: 7.6198, loss: 14.1830
2021-12-23 15:24:45,570 - mmdet - INFO - Epoch [1][550/1849]    lr: 3.539e-05, eta: 3 days, 21:19:21, time: 0.391, data_time: 0.018, memory: 7096, loss_cls: 1.9475, loss_bbox: 4.6070, loss_obj: 7.4917, loss: 14.0463
2021-12-23 15:25:06,711 - mmdet - INFO - Epoch [1][600/1849]    lr: 4.212e-05, eta: 3 days, 18:57:41, time: 0.423, data_time: 0.017, memory: 7102, loss_cls: 1.9989, loss_bbox: 4.5558, loss_obj: 7.2311, loss: 13.7858
2021-12-23 15:25:29,712 - mmdet - INFO - Epoch [1][650/1849]    lr: 4.943e-05, eta: 3 days, 17:24:11, time: 0.460, data_time: 0.017, memory: 7102, loss_cls: 2.0798, loss_bbox: 4.4725, loss_obj: 7.0267, loss: 13.5790
2021-12-23 15:25:48,532 - mmdet - INFO - Epoch [1][700/1849]    lr: 5.733e-05, eta: 3 days, 15:08:43, time: 0.376, data_time: 0.017, memory: 7102, loss_cls: 2.1350, loss_bbox: 4.4055, loss_obj: 6.1632, loss: 12.7037
2021-12-23 15:26:06,272 - mmdet - INFO - Epoch [1][750/1849]    lr: 6.581e-05, eta: 3 days, 12:58:07, time: 0.355, data_time: 0.018, memory: 7102, loss_cls: 2.1839, loss_bbox: 4.3448, loss_obj: 5.6600, loss: 12.1887

YOLOV5 (Custom, data preprocessing very similar to YOLOX, but have efficent backbone and loss calculation)

I configure the YOLOv5 with workers_per_gpu=8, and never run YOLOv5 in spawn because it requires more than 256GB memory to spawn all the threads, so I just run it in fork mode.

before:

# top
%Cpu(s): 53.1 us, 46.5 sy,  0.0 ni,  0.4 id,  0.0 wa,  0.0 hi,  0.0 si,  0.0 st

# ps auxH | grep -E "(train.py|spawn)" | wc -l
4428

# train log
2021-12-23 15:59:06,145 - mmdet - INFO - Epoch [1][500/1849]    lr: 1.810e-03, eta: 3 days, 13:15:31, time: 0.513, data_time: 0.034, memory: 3593, loss_box: 0.8057, loss_obj: 0.5071, loss_cls: 0.8061, num_gts: 78.6700, loss: 2.1188
2021-12-23 15:59:32,292 - mmdet - INFO - Epoch [1][550/1849]    lr: 1.891e-03, eta: 3 days, 12:50:06, time: 0.524, data_time: 0.034, memory: 3593, loss_box: 0.8012, loss_obj: 0.4968, loss_cls: 0.7990, num_gts: 76.7650, loss: 2.0970
2021-12-23 15:59:58,584 - mmdet - INFO - Epoch [1][600/1849]    lr: 1.972e-03, eta: 3 days, 12:31:45, time: 0.528, data_time: 0.031, memory: 3593, loss_box: 0.7989, loss_obj: 0.5017, loss_cls: 0.7883, num_gts: 78.0125, loss: 2.0889
2021-12-23 16:00:25,410 - mmdet - INFO - Epoch [1][650/1849]    lr: 2.053e-03, eta: 3 days, 12:20:15, time: 0.534, data_time: 0.027, memory: 3593, loss_box: 0.7978, loss_obj: 0.4972, loss_cls: 0.7779, num_gts: 78.3200, loss: 2.0729
2021-12-23 16:00:51,846 - mmdet - INFO - Epoch [1][700/1849]    lr: 2.134e-03, eta: 3 days, 12:08:56, time: 0.532, data_time: 0.042, memory: 3593, loss_box: 0.7938, loss_obj: 0.4977, loss_cls: 0.7649, num_gts: 78.7475, loss: 2.0564
2021-12-23 16:01:17,904 - mmdet - INFO - Epoch [1][750/1849]    lr: 2.215e-03, eta: 3 days, 11:51:11, time: 0.519, data_time: 0.039, memory: 3593, loss_box: 0.7883, loss_obj: 0.4998, loss_cls: 0.7531, num_gts: 78.5925, loss: 2.0412

after:

# top
%Cpu(s): 59.1 us,  4.9 sy,  0.0 ni, 35.9 id,  0.1 wa,  0.0 hi,  0.0 si,  0.0 st

# ps auxH | grep -E "(train.py|spawn)" | wc -l
396

# train log
2021-12-23 15:46:41,565 - mmdet - INFO - Epoch [1][500/1849]    lr: 1.810e-03, eta: 1 day, 9:20:16, time: 0.160, data_time: 0.019, memory: 3593, loss_box: 0.8103, loss_obj: 0.5051, loss_cls: 0.8027, num_gts: 77.7750, lo
ss: 2.1181                                                                                                                                                                                                                 
2021-12-23 15:46:49,532 - mmdet - INFO - Epoch [1][550/1849]    lr: 1.891e-03, eta: 1 day, 8:31:53, time: 0.159, data_time: 0.019, memory: 3593, loss_box: 0.8082, loss_obj: 0.5058, loss_cls: 0.7942, num_gts: 78.7750, lo
ss: 2.1081                                                                                                                                                                                                                 
2021-12-23 15:46:57,486 - mmdet - INFO - Epoch [1][600/1849]    lr: 1.972e-03, eta: 1 day, 7:51:36, time: 0.159, data_time: 0.019, memory: 3593, loss_box: 0.8052, loss_obj: 0.5103, loss_cls: 0.7838, num_gts: 79.8150, lo
ss: 2.0993                                                                                                                                                                                                                 
2021-12-23 15:47:05,384 - mmdet - INFO - Epoch [1][650/1849]    lr: 2.053e-03, eta: 1 day, 7:16:29, time: 0.158, data_time: 0.018, memory: 3593, loss_box: 0.7994, loss_obj: 0.4936, loss_cls: 0.7766, num_gts: 76.8900, lo
ss: 2.0696                                                                                                                                                                                                                 
2021-12-23 15:47:13,279 - mmdet - INFO - Epoch [1][700/1849]    lr: 2.134e-03, eta: 1 day, 6:46:35, time: 0.158, data_time: 0.020, memory: 3593, loss_box: 0.7970, loss_obj: 0.5138, loss_cls: 0.7634, num_gts: 80.3275, lo
ss: 2.0741                                                                                                                                                                                                                 
2021-12-23 15:47:21,129 - mmdet - INFO - Epoch [1][750/1849]    lr: 2.215e-03, eta: 1 day, 6:19:44, time: 0.157, data_time: 0.020, memory: 3593, loss_box: 0.7940, loss_obj: 0.5021, loss_cls: 0.7531, num_gts: 79.5925, lo
ss: 2.0492

Other recommandation

  1. Change default multiprocessing start method from spawn to fork. I know there maybe some thread or namespace safety problem when use fork, but fork is so fast, and it's fast enough for me to take that risk and give it a test first. Another reson is it requires less resource (CPU, memory). So I think it good to give it a position in the config file.

  2. Disable all unnecessary data preprocessing stages and tweak the workers_per_gpu or cv2.setNumThread() to glance what is the optimized training speed for your model, after that, add the preprocessing stage back and see which one hinder your training speed. In my opinion, the preprocessing should not increase the training speed at all if the preprocessing not increase the training data (e.g. Mosaic will increase the trainging time because it add more instances to calculate the loss)

Conclusion

  1. A data-hungrey model with complicated data-preprocessing stages will get benefit form this PR.
  2. This PR reduces the number of threads massively, and the CPU finnally can take a breath.
  3. Change multi-processing start method from spawn to fork can reduce startup time and resources requirement.

@CLAassistant
Copy link

CLAassistant commented Dec 23, 2021

CLA assistant check
All committers have signed the CLA.

@imyhxy imyhxy changed the title Add Single Line to Save Your Life Add A Single Line to Save Your Life Dec 23, 2021
@orangeccc
Copy link

@imyhxy well done!

@fcakyon
Copy link
Contributor

fcakyon commented Dec 24, 2021

Amazing fix @imyhxy! Do you have any intentions to open a pr with your yolov5 implementation?

@imyhxy
Copy link
Contributor Author

imyhxy commented Dec 24, 2021

@fcakyon Yes, after I finished the training experiments and if everything OK I will open a PR for it.

@ZwwWayne
Copy link
Collaborator

Hi @imyhxy ,
Thanks for your kind PR. Would you like to also check the case of numthread=1? Originally there are some OpenMMLab codebases using numthread=1. If 0 is the best, we can use that in MMDet.

@imyhxy
Copy link
Contributor Author

imyhxy commented Dec 24, 2021

@ZwwWayne Merry Christmas 🎄

I have done some test on another server which is the same as before but have 4 GPU.

setNumThreads Threads %CPU us %CPU sy Times
0 188 33.1 4.3 0.213, 0.214, 0.213
1 188 33.4 4.3 0.215, 0.215, 0.216
2 220 36.8 5.4 0.215, 0.214, 0.213
4 284 39.3 7.0 0.214, 0.213, 0.215
8 412 44.5 11.4 0.223, 0.221, 0.223
16 668 50.8 23.3 0.232, 0.232, 0.232

When setting setNumThreads to 0 or 1, there is no significant difference on Thread Number, User CPU Time,
System CPU Time and Training Time, so they are pretty much is the same (every process should have at least one thread). Meanwhile, according to the document of opencv, the behaviour of setNumThreads(1) maybe differ from framework, so use setNumThreads(0) is more stable and safety.

If threads == 0, OpenCV will disable threading optimizations and run all it's functions sequentially. Passing threads < 0 will reset threads number to system default. This function must be called outside of parallel region.

OpenCV will try to run its functions with specified threads number, but some behaviour differs from framework:

    * TBB - User-defined parallel constructions will run with the same threads number, if another is not specified. If later on user creates his own scheduler, OpenCV will use it.
    * OpenMP - No special defined behaviour.
    * Concurrency - If threads == 1, OpenCV will disable threading optimizations and run its functions sequentially.
    * GCD - Supports only values <= 0.
    * C= - No special defined behaviour.

The Thread Number and CPU Time (us, sy) increase as the setNumThreads increases, but the training speed is slowing down only when the total CPU workload exceed some point (~55% on my system). On the above experiments, the training speed gets some significant slow down when setNumThreads() to some number between 4-8 (I have not done such fine-grain setup).

@ZwwWayne
Copy link
Collaborator

ZwwWayne commented Dec 27, 2021

Thank you @imyhxy !
The results look pretty solid to me. We plan to merge this PR in our release in the middle of January. Thanks again for your detailed experiments.

@ZwwWayne ZwwWayne changed the base branch from master to dev January 5, 2022 16:19
@ZwwWayne ZwwWayne merged commit 05a3fbe into open-mmlab:dev Jan 5, 2022
@imyhxy imyhxy deleted the fix-opencv branch January 6, 2022 01:18
shinya7y added a commit to shinya7y/mmdetection that referenced this pull request Jan 12, 2022
chhluo pushed a commit to chhluo/mmdetection that referenced this pull request Feb 21, 2022
ZwwWayne pushed a commit that referenced this pull request Feb 22, 2022
* add DyHead

* move and update DYReLU

* update

* replace stack with sum to reduce memory

* clean and update

* update to align inference accuracy (incomplete)

* fix pad

* update to align training accuracy and pick #6867

* add README and metafile

* update docs

* resolve comments

* revert picking 6867

* update README.md

* update metafile.yml

* resolve comments and update urls
ZwwWayne added a commit that referenced this pull request Feb 26, 2022
* [Enhancement] Upgrade isort in pre-commit hook (#7130)

* upgrade isort to v5.10.1

* replace known_standard_library with extra_standard_library

* upgrade isort to v5.10.1

replace known_standard_library with extra_standard_library

* imports order changes

* [Fix] cannot to save the best checkpoint when the key_score is None (#7101)

* [Fix] Fix MixUp transform filter boxes failing case. Added test case (#7080)

* [Fix] Update the version limitation of mmcv-full and pytorch in CI. (#7133)

* Update

* Update build.yml

* Update build.yml

* [Feature] Support TIMMBackbone (#7020)

* add TIMMBackbone

based on
open-mmlab/mmpretrain#427
open-mmlab/mmsegmentation#998

* update and clean

* fix unit test

* Revert

* add example configs

* Create 2_new_data_model.md (#6476)

fix some typo

Co-authored-by: PJLAB\huanghaian <[email protected]>

* [FIX] add Ci of pytorch 1.10 and comments for bbox clamp (#7081) (#7083)

* add comments for bbox clamp

* add CI of pytorch1.10

* add ci of pytorch1.10.1

* mmcv1.9.0->mmcv1.9

* add ci of pytorch1.10

* Add daily issue owners (#7163)

* Add code owners

Signed-off-by: del-zhenwu <[email protected]>

* Update code owners

Signed-off-by: del-zhenwu <[email protected]>

* [Feature] Support visualization for Panoptic Segmentation (#7041)

* First commit of v2

* split the functions

* Support to show panoptic result

* temp

* Support to show gt

* support show gt

* fix lint

* Support to browse datasets

* Fix unit tests

* Fix findContours

* fix comments

* Fix pre-commit

* fix lint

* Add the type of an argument

* [Fix] confusion_matrix.py analysis tool handling NaNs (#7147)

* [Fix] Added missing property in SABLHead (#7091)

* Added missing property in SABLHead

* set pre-commit-hooks to v0.1.0

* set maskdownlint to v0.11.0

* pre-commit-hooks

Co-authored-by: Cedric Luo <[email protected]>

* Update config.md (#7215)

* [Fix] Fix wrong img name in onnx2tensorrt.py (#7157)

* [Docs] fix albumentations installed way (#7143)

* Update config.md

fix some typos

Co-authored-by: Jamie <[email protected]>
Co-authored-by: BigDong <[email protected]>

* [Feature] Support DyHead (#6823)

* add DyHead

* move and update DYReLU

* update

* replace stack with sum to reduce memory

* clean and update

* update to align inference accuracy (incomplete)

* fix pad

* update to align training accuracy and pick #6867

* add README and metafile

* update docs

* resolve comments

* revert picking 6867

* update README.md

* update metafile.yml

* resolve comments and update urls

* Fix broken colab link (#7218)

* [Fix] Fix wrong img name in onnx2tensorrt.py (#7157)

* [Docs] fix albumentations installed way (#7143)

* Fix broken colab link

Co-authored-by: Jamie <[email protected]>
Co-authored-by: BigDong <[email protected]>

* Remove the inplace addition in `FPN` (#7175)

* [Fix] Fix wrong img name in onnx2tensorrt.py (#7157)

* [Docs] fix albumentations installed way (#7143)

* Remove the inplace addition in `FPN`

* update

Co-authored-by: Jamie <[email protected]>
Co-authored-by: BigDong <[email protected]>
Co-authored-by: PJLAB\huanghaian <[email protected]>

* [Feature] Support OpenImages Dataset (#6331)

* [Feature] support openimage group of eval

* [Feature] support openimage group of eval

* support openimage dataset

* support openimage challenge dataset

* fully support OpenImages-V6 and OpenImages Challenge 2019

* Fix some logic error

* update config file

* fix get data_infos error

* fully support OpenImages evaluation

* update OpenImages config files

* [Feature] support OpenImages datasets

* fix bug

* support load image metas from pipeline

* fix bug

* fix get classes logic error

* update code

* support get image metas

* support openimags

* support collect image metas

* support Open Images

* fix openimages logic

* minor fix

* add a new function to compute openimages tpfp

* minor fix

* fix ci error

* minor fix

* fix indication

* minor fix

* fix returns

* fix returns

* fix returns

* fix returns

* fix returns

* minor fix

* update readme

* support loading image level labels and fix some logic

* minor fix

* minor fix

* add class names

* minor fix

* minor fix

* minor fix

* add openimages test unit

* minor fix

* minor fix

* fix test unit

* minor fix

* fix logic error

* minor fix

* fully support openimages

* minor fix

* fix docstring

* fix docstrings in readthedocs

* update get image metas script

* label_description_file -> label_file

* update openimages readme

* fix test unit

* fix test unit

* minor fix

* update readme file

* Update get_image_metas.py

* [Enhance] Speed up SimOTA matching. (#7098)

* [Feature] Add Maskformer to mmdet (#7212)

* first commit

* add README

* move model description from config to readme

add description for binary_input

add description for dice loss

add a independent panoptic gt processing function

add a independent panoptic gt processing function

remove compatibility of pretrain in maskformer

* update comments in maskformer_head

* update docs format

* Add deprecation message for deploy tools (#7242)

* Add CI for windows (#7228)

* [Fix] Fix wrong img name in onnx2tensorrt.py (#7157)

* [Docs] fix albumentations installed way (#7143)

* Add mmrotate (#7216)

* fix description for args in CSPDarknet (#7187)

* Update build.yml

* Update build.yml

* Update build.yml

* Update build.yml

* Update build.yml

* Update build.yml

* Update build.yml

* Update build.yml

* Update build.yml

* Update build.yml

* Update build.yml

* Update build.yml

* fix test_find_latest_checkpoint

* fix data_infos__default_db_directories

* fix test_custom_classes_override_default

* Update test_custom_dataset.py

* Update test_common.py

* Update assign_result.py

* use os.path.join

* fix bug

* Update test_common.py

* Update assign_result.py

* Update sampling_result.py

* os.path -> osp

* os.path -> osp

Co-authored-by: Jamie <[email protected]>
Co-authored-by: BigDong <[email protected]>
Co-authored-by: Hyeokjoon Kwon <[email protected]>

* add Chinese version of init_cfg (#7188)

* [Fix] Fix wrong img name in onnx2tensorrt.py (#7157)

* [Docs] fix albumentations installed way (#7143)

* Create init_cfg.md

* Update docs/zh_cn/tutorials/init_cfg.md

Co-authored-by: Haian Huang(深度眸) <[email protected]>

* update init_cfg.md

* update init_cfg.md

* update init_cfg.md

* update init_cfg.md

Co-authored-by: Jamie <[email protected]>
Co-authored-by: BigDong <[email protected]>
Co-authored-by: Haian Huang(深度眸) <[email protected]>

* update MaskFormer readme and docs (#7241)

* update docs for maskformer

* update readme

* update readme format

* update link

* update json link

* update format of ConfigDict

* update format of function returns

* uncomment main in deployment/test.py

* [Feature] ResNet Strikes Back. (#7001)

* [Feature] ResNet Strikes Back.

* add more cfg

* add readme

* update

* update

* update

* update

* update

* update

* Maskformer Visualization (#7247)

* maskformer visualization

* compatible with models that do not contain arg of pretrained

* compatible with models that do not contain arg of pretrained

* Bump versions to v2.22.0 (#7240)

* Bump versions to v2.22.0

* Fix comments and add the latest PRs

* fix the id of contributor

* relax the version of mmcv

* Add ResNet Strikes Back

* Update README_zh-CN.md

* Update README.md

* fix typo

* Update README_zh-CN.md

Co-authored-by: Wenwei Zhang <[email protected]>

* Maskformer metafile and rsb readme format (#7250)

* [Fix] Fix Open Images testunit to avoid error in Windows CI (#7252)

* [Feature] support openimage group of eval

* [Feature] support openimage group of eval

* support openimage dataset

* support openimage challenge dataset

* fully support OpenImages-V6 and OpenImages Challenge 2019

* Fix some logic error

* update config file

* fix get data_infos error

* fully support OpenImages evaluation

* update OpenImages config files

* [Feature] support OpenImages datasets

* fix bug

* support load image metas from pipeline

* fix bug

* fix get classes logic error

* update code

* support get image metas

* support openimags

* support collect image metas

* support Open Images

* fix openimages logic

* minor fix

* add a new function to compute openimages tpfp

* minor fix

* fix ci error

* minor fix

* fix indication

* minor fix

* fix returns

* fix returns

* fix returns

* fix returns

* fix returns

* minor fix

* update readme

* support loading image level labels and fix some logic

* minor fix

* minor fix

* add class names

* minor fix

* minor fix

* minor fix

* add openimages test unit

* minor fix

* minor fix

* fix test unit

* minor fix

* fix logic error

* minor fix

* fully support openimages

* minor fix

* fix docstring

* fix docstrings in readthedocs

* update get image metas script

* label_description_file -> label_file

* update openimages readme

* fix test unit

* fix test unit

* minor fix

* update readme file

* Update get_image_metas.py

* fix oid testunit to avoid some error in windows

* minor fix to avoid some error in windows

* minor fix

* add comments in oid test unit

* minor fix

Co-authored-by: Cedric Luo <[email protected]>
Co-authored-by: LuooChen <[email protected]>
Co-authored-by: Daniel van Sabben Alsina <[email protected]>
Co-authored-by: jbwang1997 <[email protected]>
Co-authored-by: Yosuke Shinya <[email protected]>
Co-authored-by: siatwangmin <[email protected]>
Co-authored-by: PJLAB\huanghaian <[email protected]>
Co-authored-by: del-zhenwu <[email protected]>
Co-authored-by: Guangchen Lin <[email protected]>
Co-authored-by: VIKASH RANJAN <[email protected]>
Co-authored-by: Range King <[email protected]>
Co-authored-by: Jamie <[email protected]>
Co-authored-by: BigDong <[email protected]>
Co-authored-by: Haofan Wang <[email protected]>
Co-authored-by: Zhijian Liu <[email protected]>
Co-authored-by: BigDong <[email protected]>
Co-authored-by: RangiLyu <[email protected]>
Co-authored-by: Yue Zhou <[email protected]>
Co-authored-by: Hyeokjoon Kwon <[email protected]>
Co-authored-by: Kevin Ye <[email protected]>
ZwwWayne pushed a commit that referenced this pull request Jul 18, 2022
ZwwWayne pushed a commit that referenced this pull request Jul 18, 2022
* add DyHead

* move and update DYReLU

* update

* replace stack with sum to reduce memory

* clean and update

* update to align inference accuracy (incomplete)

* fix pad

* update to align training accuracy and pick #6867

* add README and metafile

* update docs

* resolve comments

* revert picking 6867

* update README.md

* update metafile.yml

* resolve comments and update urls
ZwwWayne pushed a commit to ZwwWayne/mmdetection that referenced this pull request Jul 19, 2022
ZwwWayne pushed a commit to ZwwWayne/mmdetection that referenced this pull request Jul 19, 2022
* add DyHead

* move and update DYReLU

* update

* replace stack with sum to reduce memory

* clean and update

* update to align inference accuracy (incomplete)

* fix pad

* update to align training accuracy and pick open-mmlab#6867

* add README and metafile

* update docs

* resolve comments

* revert picking 6867

* update README.md

* update metafile.yml

* resolve comments and update urls
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants