
transforms: add Random Erasing for image augmentation #909

Merged · 16 commits into pytorch:master on Jun 24, 2019

Conversation

@zhunzhong07 (Contributor) commented May 16, 2019

Random Erasing randomly selects a rectangular region in an image and erases its pixels with random values. It can reduce the risk of overfitting and improve CNN baselines in image classification, object detection, and person re-identification.

I found that this augmentation method has been widely used in image classification (CIFAR-10, CIFAR-100) and person re-identification.

Also, it could achieve improvements on ImageNet: +0.7% in Prec@1 for ResNet-50, +0.33% in Prec@1 for ResNet-34.

Therefore, I think it would be valuable to users.

'Random Erasing Data Augmentation' by Zhong et al.: https://arxiv.org/pdf/1708.04896.pdf

A parallel work is "Improved Regularization of Convolutional Neural Networks with Cutout" proposed by DeVries. https://arxiv.org/pdf/1708.04552.pdf

Previous pull requests and issues: #335, #226, #420
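For readers unfamiliar with the method, the core sampling loop from the paper can be sketched in plain Python. This is an illustrative stand-in, not the PR's implementation; parameter names `sl`, `sh`, `r1`, `r2` follow the paper's notation, and the image is modeled as a list of rows of floats.

```python
import math
import random

def random_erase(img, p=0.5, sl=0.02, sh=0.4, r1=0.3, r2=3.33, attempts=100):
    """Erase a random rectangle of `img` (a list of rows of pixel values)
    with uniform random values, following Zhong et al. (2017)."""
    if random.random() > p:          # apply with probability p
        return img
    H, W = len(img), len(img[0])
    area = H * W
    for _ in range(attempts):
        target_area = random.uniform(sl, sh) * area   # erased fraction of the image
        aspect = random.uniform(r1, r2)               # rectangle aspect ratio
        h = int(round(math.sqrt(target_area * aspect)))
        w = int(round(math.sqrt(target_area / aspect)))
        if h < H and w < W:
            top = random.randint(0, H - h)
            left = random.randint(0, W - w)
            for y in range(top, top + h):
                for x in range(left, left + w):
                    img[y][x] = random.random()       # random-value mode
            return img
    return img  # no valid rectangle found; return unchanged
```

The retry loop matters: a sampled rectangle can be larger than the image along one side, so the paper resamples until a feasible region is found.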

@zhunzhong07 zhunzhong07 changed the title add erase function transforms: add Random Erasing for image augmentation May 16, 2019
@zhunzhong07 zhunzhong07 marked this pull request as ready for review May 16, 2019 06:23
@codecov-io commented May 16, 2019

Codecov Report

Merging #909 into master will increase coverage by 2.82%.
The diff coverage is 82.05%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #909      +/-   ##
==========================================
+ Coverage   60.03%   62.85%   +2.82%     
==========================================
  Files          64       65       +1     
  Lines        5054     5140      +86     
  Branches      754      773      +19     
==========================================
+ Hits         3034     3231     +197     
+ Misses       1817     1683     -134     
- Partials      203      226      +23
Impacted Files Coverage Δ
torchvision/transforms/functional.py 70.95% <60%> (-0.17%) ⬇️
torchvision/transforms/transforms.py 82.12% <85.29%> (-0.42%) ⬇️
torchvision/datasets/stl10.py 26.59% <0%> (-3.26%) ⬇️
torchvision/models/mobilenet.py 89.7% <0%> (-2.61%) ⬇️
torchvision/ops/roi_pool.py 67.44% <0%> (-0.86%) ⬇️
torchvision/ops/roi_align.py 65.95% <0%> (-0.71%) ⬇️
torchvision/models/detection/faster_rcnn.py 74.39% <0%> (ø) ⬆️
torchvision/extension.py 38.09% <0%> (ø)
torchvision/ops/boxes.py 94.73% <0%> (+0.14%) ⬆️
... and 9 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 2b3a1b6...429a55e.

@zhunzhong07 (Contributor, Author)
@alykhantejani Could you take a look at this method and advise on adding it to the transforms? I think random erasing is a useful augmentation method that will often be used in vision tasks.

@ekagra-ranjan (Contributor)

This transform can be handy during self-supervision training.

@rwightman (Contributor)

I've used random erasing successfully in metric loss (triplet) training and larger image (224x224+ imagenet like) training with good results. I think it would be a worthwhile addition.

@zhunzhong07 .. I found that the per-pixel version was quite useful for these problems, but you didn't include it in your GitHub impl or this PR? In my experiments, normally distributed pixels (after image normalization) worked well; a uniform distribution caused convergence issues later in training. I perform the RE operation once the tensors are on the GPU, as part of a GPU prefetching loader/collate/normalize.

Feel free to copy any ideas: https://github.com/rwightman/pytorch-image-models/blob/master/data/random_erasing.py

@zhunzhong07 (Contributor, Author) commented May 19, 2019

@rwightman Thanks for your advice. In this PR, I have included a per-pixel mode via v = torch.rand(img.size()[0], h, w). I have checked your code and noticed that the per-pixel values should be normally distributed. I will fix this bug. Thank you!

One request: I don't have enough GPUs to train a model on ImageNet right now. If you already have the results, could you provide results of training with and without random erasing on ImageNet? Thank you!
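The distinction being discussed can be sketched as follows (a hedged pure-Python stand-in, not the PR code): a per-pixel N(0, 1) fill, which matches post-normalization statistics, versus the uniform fill.

```python
import random

def erase_fill(h, w, mode="pixel"):
    """Generate fill values for an h x w erased region.
    'pixel' draws per-pixel N(0, 1) noise (matches the statistics of
    normalized images); 'uniform' is the U(0, 1) variant that rwightman
    reports causing convergence issues late in training."""
    if mode == "pixel":
        return [[random.gauss(0.0, 1.0) for _ in range(w)] for _ in range(h)]
    return [[random.random() for _ in range(w)] for _ in range(h)]
```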

@rwightman (Contributor)

@zhunzhong07 I'll run some imagenet trainings to support this. I don't think I have two historical runs with all the hyper-params and results recorded that didn't have some sort of change in library versions, other hyper params, machines etc... I'll let you know how it goes

@ekagra-ranjan (Contributor) left a comment

The following proposed changes would be appreciated.

torchvision/transforms/functional.py (review comments, resolved)
@@ -1317,6 +1317,23 @@ def test_random_grayscale(self):
# Checking if RandomGrayscale can be printed as string
trans3.__repr__()

def test_random_erasing(self):

Can you make the test a bit stronger by checking that the region around the erased patch is equal to the original image?


Addressed by #1060.
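The stronger check suggested above can be sketched against a simple erase helper (a pure-Python stand-in for the erase operation, not torchvision's implementation; all names are hypothetical):

```python
def erase(img, top, left, h, w, value):
    """Set img[top:top+h][left:left+w] to `value`; img is a list of rows."""
    out = [row[:] for row in img]          # copy so the original survives
    for y in range(top, top + h):
        for x in range(left, left + w):
            out[y][x] = value
    return out

# check both the erased patch and the untouched surroundings
img = [[float(y * 10 + x) for x in range(10)] for y in range(10)]
erased = erase(img, top=3, left=4, h=2, w=3, value=-1.0)
for y in range(10):
    for x in range(10):
        inside = 3 <= y < 5 and 4 <= x < 7
        assert erased[y][x] == (-1.0 if inside else img[y][x])
```

Verifying every pixel outside the rectangle catches off-by-one bugs in the region bounds that a patch-only check would miss.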

def __call__(self, img):
"""
Args:
img (Tensor): Image to be erased.

IMO, having the input and output as PIL Images would be good and consistent with the other transforms, e.g. in order to apply RandomOrder with a list of other transforms.


If you don't do this operation after Normalization (which is a Tensor based transform), you have to duplicate argument passing and pass your dataset stats to both Normalize and the RandomErasing transform. Post norm, you can assume a mean of 0 and consistent std dev. Also, if you do it before norm, mixed up with other transforms, it's much easier to skew the statistics of your data and cause divergence between train and validation.

In my experience, using it on a few projects now, it's generally cleaner, less fussy, and more efficient (integrated with moving data to the GPU) if done after normalization as tensor ops.
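The ordering argument can be sketched in a few lines (a pure-Python stand-in with hypothetical names): normalizing first makes the data zero-mean and unit-variance, so an N(0, 1) fill blends into the surrounding statistics with no extra dataset-stat arguments.

```python
import random

def normalize(img, mean, std):
    """Per-pixel (x - mean) / std on a 2-D list of floats."""
    return [[(v - mean) / std for v in row] for row in img]

def erase_post_norm(img, top, left, h, w):
    """Fill the rectangle with N(0, 1) noise; after normalization the
    data itself is zero-mean unit-variance, so the fill matches it."""
    for y in range(top, top + h):
        for x in range(left, left + w):
            img[y][x] = random.gauss(0.0, 1.0)
    return img

# order matters: normalize first, then erase
img = [[0.5] * 8 for _ in range(8)]
img = normalize(img, mean=0.5, std=0.25)
img = erase_post_norm(img, top=2, left=2, h=3, w=3)
```

Done the other way round, the erase fill would need the dataset mean/std passed in a second time, duplicating what Normalize already knows.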

w = int(round(math.sqrt(target_area / aspect_ratio)))

if w < img.size()[2] and h < img.size()[1]:
x = random.randint(0, img.size()[1] - h)

@zhunzhong07, just a minor optimisation: can't we store values like img.size()[1] instead of computing them again and again?


Addressed by #1087.
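The suggested micro-optimisation, sketched as a stand-alone helper (names hypothetical; the PR's actual code operates on a torch tensor and takes the same square roots):

```python
import math
import random

def sample_region(size, target_area, aspect_ratio):
    """Pick an erase rectangle for an image of `size` (C, H, W),
    caching the dimensions once instead of re-reading size() per use."""
    _, img_h, img_w = size                 # unpack once, reuse below
    h = int(round(math.sqrt(target_area * aspect_ratio)))
    w = int(round(math.sqrt(target_area / aspect_ratio)))
    if w < img_w and h < img_h:
        x = random.randint(0, img_h - h)   # x indexes rows, as in the PR
        y = random.randint(0, img_w - w)
        return x, y, h, w
    return None                            # rectangle doesn't fit
```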

@rwightman (Contributor)

FWIW, I finished several training sessions with different random-erasing settings: no RE, constant (0) RE, and normally distributed (mean 0, std 1) per-pixel RE. I did not do a random-color (solid) run. I'm still running some other tests in this series to validate other impls and hyper-params out of personal interest...

  • All training with cosine LR decay, 5 epoch warmup, 120 epochs until a 10 epoch constant cooldown, label smoothing @ 0.1
  • Standard imagenet preprocessing & augmentation otherwise
  • Thanks to cosine + smoothing, all 3 models I trained are pretty darn good results for a resnet-34 (see torchvision pretrained), including the one without random-erasing
  • There is an improvement with each level of RE applied
  • The improvement is still evident when validating against the 'ImagenetV2 matched-frequency' validation set, where both RE models show an advantage

ImageNet 1K validation

training results
torchvision Prec@1 73.312 (26.688) Prec@5 91.426 (8.574)
me w/ no RE Prec@1 74.772 (25.228) Prec@5 91.988 (8.012)
me w/ const 0 RE Prec@1 74.838 (25.162) Prec@5 92.212 (7.788)
me w/ per-pixel, normal RE Prec@1 75.110 (24.890) Prec@5 92.284 (7.716)

ImagenetV2-matched-frequency validation (https://github.com/modestyachts/ImageNetV2)

training results
torchvision Prec@1 61.190 (38.810) Prec@5 82.710 (17.290)
me w/ no RE Prec@1 62.300 (37.700) Prec@5 83.870 (16.130)
me w/ const 0 RE Prec@1 62.790 (37.210) Prec@5 84.060 (15.940)
me w/ per-pixel, normal RE Prec@1 62.870 (37.130) Prec@5 84.140 (15.860)

@zhunzhong07 (Contributor, Author)

@rwightman Thanks for your experimental results. It is great to see that random erasing could improve the performance of ImageNet.

Did you run these results with your impl at https://github.com/rwightman/pytorch-image-models? If so, could you also provide the training commands (i.e., the parameters for distributed_train.sh), so that we can accurately reproduce the results?

Thank you for your time and effort in running these experiments.

@rwightman (Contributor)

Yeah, using the train script in image-models. I was only running a single GPU for these runs and did them in parallel. I had a local mod experimenting with the warmup and changing its overlap behaviour with the main schedule, but the differences are minor; I extended the epochs here by 5 to compensate. These should reproduce the results closely enough:

No RE:
python train.py /imagenet/ --model resnet34 -b 256 --epochs 125 --warmup-epochs 5 --sched cosine --lr 0.1 --weight-decay 1e-4 --reprob 0.

RE constant:
python train.py /imagenet/ --model resnet34 -b 256 --epochs 125 --warmup-epochs 5 --sched cosine --lr 0.1 --weight-decay 1e-4 --reprob 0.4

RE per-pixel normal:
python train.py /imagenet/ --model resnet34 -b 256 --epochs 125 --warmup-epochs 5 --sched cosine --lr 0.1 --weight-decay 1e-4 --reprob 0.4 --remode pixel

@fmassa (Member) commented Jun 4, 2019

@rwightman thanks a lot for the feedback wrt the usefulness of RandomErasing.

I'll have a closer look at the implementation today

@zhunzhong07 (Contributor, Author) commented Jun 6, 2019

@rwightman Thanks! With your provided scripts, I have obtained similar results for ResNet34.

@fmassa Thank you for your attention. I also ran RandomErasing for ResNet50 and ResNet101, and achieved an improvement (+0.7 in Prec@1 for ResNet50 and +0.55 in Prec@1 for ResNet101).

Results on ImageNet 1K validation

training                                 results
torchvision ResNet50                     Prec@1 76.15 (23.85) Prec@5 92.87 (7.13)
me ResNet50 w/ no RE                     Prec@1 76.33 (23.67) Prec@5 92.96 (7.04)
me ResNet50 w/ per-pixel, normal RE      Prec@1 77.08 (22.92) Prec@5 93.27 (6.73)
torchvision ResNet101                    Prec@1 77.37 (22.63) Prec@5 93.56 (6.44)
me ResNet101 w/ no RE                    Prec@1 79.02 (20.98) Prec@5 94.27 (5.73)
me ResNet101 w/ per-pixel, normal RE     Prec@1 79.57 (20.43) Prec@5 94.7 (5.3)

@fmassa (Member) left a comment

Thanks for the PR!

I have a few comments, let me know what you think.

Also, this is the transform that you used to obtain better results, with value='random', is that right?

Can you also add an entry to the documentation in https://github.com/pytorch/vision/blob/master/docs/source/transforms.rst

torchvision/transforms/functional.py (review comment, resolved)
torchvision/transforms/transforms.py (review comments, resolved)
@zhunzhong07 (Contributor, Author) commented Jun 8, 2019

@fmassa Thank you for your detailed comments. I have modified the PR and added an entry to the documentation, according to your suggestions.

Yes, using the value='random' mode achieves better results. I would like to know your preference: support only the 'random' mode, or multiple modes (random and constant)?

@zhunzhong07 (Contributor, Author)

Hi @fmassa. I've modified the PR according to your comments. I also added the results for ResNet101 above; a consistent improvement is obtained with RandomErasing.

@fmassa (Member) left a comment

Thanks a lot!

@fmassa fmassa merged commit 3254560 into pytorch:master Jun 24, 2019
@Zhaoyi-Yan

@zhunzhong07 Hi, have you tried it for detection?

@zhunzhong07 (Contributor, Author)

@Zhaoyi-Yan Yes. Random Erasing can improve the results of Fast R-CNN on VOC 2007. Please refer to our paper.
