Improve the accuracy of Classification models by using SOTA recipes and primitives #3995
Comments
@datumbox Can you release the training code, or at least the training configs? The reference training code has already implemented the training tricks.
@xiaohu2015 Of course! I'm in the middle of writing a blogpost that will include the configs, the training methodology, detailed ablations, etc. It should be out next week. :) Edit: Here is the blogpost that documents the training recipe.
@datumbox For the commands that start with …
Hi @datumbox. I have tried your New Recipe (without FixRes mitigations) on ResNet101 and obtained only a peak top-1 accuracy of 81.328 (at epoch 418), which is 0.558 behind your result (81.886). I launched the following command on 64 GPUs:
I only used a batch size of 64 because 128 led to out-of-memory errors (on 16GB GPUs). Thus the effective batch size in my training is 64*64 = 4096. Could you please tell me how many GPUs you used in your training? Or even better, could you please share the training configuration shown in your training logs? For example, mine is:
The number of GPUs is important information because it affects the effective batch size. I would need to scale my learning rate accordingly to match your results (and for that I would need to know the number of GPUs, and the learning rate, that you used). FYI the following file contains the metric values at each epoch of my training. Unfortunately the training log file is too big (700MB) to be shared. It is filled with the following annoying warning message:
(By the way, do you know how to get rid of this kind of message? Should I create a GitHub issue somewhere?) Thank you very much in advance for your reply!
Hi @netw0rkf10w.
This is exactly why it's hard for me as well to share the training log file. We are working on improving the model documentation and figuring out how to share these more easily. Here is the full command used to train the model; it should contain all the information you need to reproduce the result:
Note that we used submitit and a custom script to launch our jobs.
This is probably why you don't match my results. I used an effective batch size of 1024: 8 A100 GPUs with a batch size of 128 per GPU. I would recommend keeping the total batch size at 1024 to avoid having to adapt the rest of the parameters.
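For anyone who does change the effective batch size anyway, a minimal sketch of the linear LR scaling rule alluded to above (the base values here are assumptions for illustration, not necessarily the recipe's exact values):

```python
# Hypothetical helper illustrating the linear LR scaling rule
# (scale the learning rate proportionally to the effective batch size).

def scale_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Scale the learning rate linearly with the effective batch size."""
    return base_lr * new_batch / base_batch

# Example: a recipe tuned for an effective batch size of 1024 (assumed base_lr).
base_lr, base_batch = 0.5, 1024     # assumed reference values
new_batch = 64 * 64                 # 64 GPUs x 64 images/GPU = 4096
print(scale_lr(base_lr, base_batch, new_batch))  # -> 2.0
```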
Concerning the warning message, I would recommend opening a GitHub issue on the main PyTorch repo with the minimum snippet that reproduces it, so it can be investigated further. Let me know if you face further problems reproducing the results.
@datumbox Great, thanks a lot for your reply! I'll try again and keep you informed about the results.
@datumbox As per the discussion in #5084, below is a recipe that achieved the following result on ResNet-50 and ImageNet:
Overview of changes to the current recipe (New Recipe + FixRes mitigations):
@tbennun Great contributions. I guess increasing the number of repetitions also leads to slower training. Could you tell me how much slower it was for your training? I am about to launch a few trainings and if … Thanks in advance for your reply.
@netw0rkf10w Actually, this didn't slow down training at all. The current version of RA in the classifier example uses the DeiT scheme, in which the epoch length is also …
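For readers following along, here is a minimal sketch of a DeiT-scheme repeated-augmentation sampler, modeled loosely on the RASampler used in the classification references (the class name and details below are simplified assumptions, not the exact implementation). The key point is that each index is repeated, but the per-process epoch is truncated back to roughly len(dataset) / num_replicas, so an epoch takes the same time as with a plain DistributedSampler:

```python
import math
import torch

class RASamplerSketch(torch.utils.data.Sampler):
    """Simplified repeated-augmentation sampler (DeiT scheme), for illustration.

    Each image index is repeated `reps` times, but the per-process epoch is
    truncated back to ~len(dataset) / num_replicas samples, so an epoch takes
    the same wall-clock time as with a plain DistributedSampler.
    """

    def __init__(self, dataset, num_replicas, rank, reps=4, seed=0):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.reps = reps
        self.seed = seed
        self.epoch = 0
        self.num_samples = math.ceil(len(dataset) * reps / num_replicas)
        self.total_size = self.num_samples * num_replicas
        # Truncate the epoch back to the plain (non-repeated) length.
        self.num_selected = math.ceil(len(dataset) / num_replicas)

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()
        indices = [i for i in indices for _ in range(self.reps)]  # repeat each index
        indices += indices[: self.total_size - len(indices)]      # pad to divisible size
        indices = indices[self.rank : self.total_size : self.num_replicas]
        return iter(indices[: self.num_selected])

    def __len__(self):
        return self.num_selected

    def set_epoch(self, epoch):
        self.epoch = epoch
```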
@tbennun I see, thanks. Let me try …
I was able to reach …
The effective batch size is … @datumbox You said in #5084 that you were about to launch a new set of trainings with …
@netw0rkf10w Thanks for confirming, good to know you matched the accuracy. No plans to retrain all the models for now. It's very expensive and time-consuming to train everything from scratch, and I'm not sure it makes sense to do this as the improvement is expected to be on the scale of 0.1-0.2 points.
I have modified the scope of the ticket to focus on Classification so that we can conclude phase 1 of our Batteries Included project. We will focus on Detection and Segmentation in phase 2. Big thanks to everyone involved in this project for helping us keep TorchVision fresh!
Hi @tbennun, following your recipe, I tried to reproduce the result. I downloaded the latest pytorch/vision code and ran it as below:
The only modification is the way of loading resnet50.
The result is not ideal; it is even worse than the original training.
I'm not sure whether I need to load the pre-trained resnet50. This result is provided as a reference for everyone.
🚀 Feature
Update the weights of all pre-trained models to improve their accuracy.
Motivation
New Recipe + FixRes mitigations
Using a recipe which includes Warmup, Cosine Annealing, Label Smoothing, Mixup, Cutmix, Random Erasing, TrivialAugment, no BN weight decay, EMA, long training cycles and optional FixRes mitigations, we are able to improve the resnet50 accuracy by over 4.5 points. For more information on the training recipe, check here. (A rough sketch of how these primitives map onto PyTorch APIs follows below.)
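As an illustration only, here is a minimal sketch of how these primitives map onto standard PyTorch/torchvision APIs. All hyperparameter values below are placeholders rather than the tuned recipe values, and Mixup/Cutmix are shown as a minimal batch-level helper rather than the reference implementation:

```python
import torch
from torch import nn, optim
from torchvision import models, transforms

# Data augmentation: TrivialAugment + Random Erasing (values are placeholders).
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(176),
    transforms.RandomHorizontalFlip(),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.1),
])

model = models.resnet50()

# No BN weight decay: 1-D parameters (norm weights and biases) go in a
# decay-free parameter group.
decay, no_decay = [], []
for p in model.parameters():
    (no_decay if p.ndim <= 1 else decay).append(p)
optimizer = optim.SGD(
    [{"params": decay, "weight_decay": 2e-5},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=0.5, momentum=0.9,
)

# Warmup followed by Cosine Annealing over a long training cycle.
warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=5)
cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=595)
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[5])

# Label Smoothing.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# EMA of the model weights (the decay value is an assumption).
ema_model = optim.swa_utils.AveragedModel(
    model, avg_fn=lambda avg, p, n: 0.999 * avg + 0.001 * p)

def mixup(x, y, alpha=0.2, num_classes=1000):
    """Minimal batch-level Mixup; Cutmix is analogous but mixes image patches."""
    lam = torch.distributions.Beta(alpha, alpha).sample()
    perm = torch.randperm(x.size(0))
    y = nn.functional.one_hot(y, num_classes).float()
    return lam * x + (1 - lam) * x[perm], lam * y + (1 - lam) * y[perm]
```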
Running other models through the same recipe achieves the following improved accuracies:
New Recipe (without FixRes mitigations)
Removing the optional FixRes mitigations seems to yield better results for some deeper architectures and variants with larger receptive fields:
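For context, the FixRes mitigation referred to above amounts to training on smaller crops than the resolution used at evaluation, to counter the train/test resolution discrepancy. A minimal sketch, with illustrative sizes:

```python
from torchvision import transforms

# FixRes mitigation sketch: train on smaller crops than the resolution used
# at evaluation (the 176/232/224 sizes are illustrative assumptions).
train_crop = transforms.RandomResizedCrop(176)   # reduced train resolution
val_transform = transforms.Compose([
    transforms.Resize(232),                      # resize the shorter side
    transforms.CenterCrop(224),                  # evaluate at 224x224
])
```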
New Recipe + Regularization tuning
Slightly adjusting the regularization can help us improve the following:
In addition to the regularization adjustment, we can also apply the Repeated Augmentation trick --ra-sampler --ra-reps 4 (see the sampler sketch earlier in this thread):
Post-Training Quantized models
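As a usage note, the post-training quantized models can be consumed through torchvision's quantization subpackage. A minimal sketch using the API available at the time of this issue (newer releases use a weights= enum instead of pretrained=):

```python
import torch
from torchvision.models import quantization

# Load a post-training-quantized ResNet-50 (int8) and run CPU inference.
qmodel = quantization.resnet50(pretrained=True, quantize=True)
qmodel.eval()

with torch.inference_mode():
    logits = qmodel(torch.rand(1, 3, 224, 224))  # runs with quantized kernels
```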
New Recipe (LR+weight_decay+train_crop_size tuning)
Pitch
To be able to improve the pre-trained model accuracy, we need to complete the "Batteries Included" work tracked in #3911. Moreover, we will need to extend our existing model builders to support multiple weights, as described in #4611. Then we will be able to:
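For illustration, once the multi-weight builders from #4611 land, selecting between the old and improved weights could look like the sketch below (the enum names follow the multi-weight API that eventually shipped in torchvision):

```python
from torchvision.models import resnet50, ResNet50_Weights

# Original weights vs. the improved ones trained with the new recipe.
model_v1 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model_v2 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Each weights enum also bundles the matching inference transforms.
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()
```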
cc @datumbox @vfdev-5