From 862347d5383e7982fc3bdf0b5d82c359db3e7dcc Mon Sep 17 00:00:00 2001 From: CoinCheung <867153576@qq.com> Date: Fri, 19 Apr 2019 23:29:30 +0800 Subject: [PATCH] add color jitter augmentation (#680) * add color jitter augmentation * fix spelling --- maskrcnn_benchmark/config/defaults.py | 6 ++++++ maskrcnn_benchmark/data/transforms/build.py | 7 +++++++ .../data/transforms/transforms.py | 18 ++++++++++++++++++ 3 files changed, 31 insertions(+) diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py index 23a599ef7..2f5a88009 100644 --- a/maskrcnn_benchmark/config/defaults.py +++ b/maskrcnn_benchmark/config/defaults.py @@ -54,6 +54,12 @@ # Convert image to BGR format (for Caffe2 models), in range 0-255 _C.INPUT.TO_BGR255 = True +# Image ColorJitter +_C.INPUT.BRIGHTNESS = 0.0 +_C.INPUT.CONTRAST = 0.0 +_C.INPUT.SATURATION = 0.0 +_C.INPUT.HUE = 0.0 + # ----------------------------------------------------------------------------- # Dataset diff --git a/maskrcnn_benchmark/data/transforms/build.py b/maskrcnn_benchmark/data/transforms/build.py index 8645d4df4..825c3ee42 100644 --- a/maskrcnn_benchmark/data/transforms/build.py +++ b/maskrcnn_benchmark/data/transforms/build.py @@ -16,9 +16,16 @@ def build_transforms(cfg, is_train=True): normalize_transform = T.Normalize( mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=to_bgr255 ) + color_jitter = T.ColorJitter( + brightness=cfg.INPUT.BRIGHTNESS, + contrast=cfg.INPUT.CONTRAST, + saturation=cfg.INPUT.SATURATION, + hue=cfg.INPUT.HUE, + ) transform = T.Compose( [ + color_jitter, T.Resize(min_size, max_size), T.RandomHorizontalFlip(flip_prob), T.ToTensor(), diff --git a/maskrcnn_benchmark/data/transforms/transforms.py b/maskrcnn_benchmark/data/transforms/transforms.py index 7e3ebbd6c..1c322f8ba 100644 --- a/maskrcnn_benchmark/data/transforms/transforms.py +++ b/maskrcnn_benchmark/data/transforms/transforms.py @@ -72,6 +72,24 @@ def __call__(self, image, target): return image, target +class ColorJitter(object): + def __init__(self, + brightness=None, + contrast=None, + saturation=None, + hue=None, + ): + self.color_jitter = torchvision.transforms.ColorJitter( + brightness=brightness, + contrast=contrast, + saturation=saturation, + hue=hue,) + + def __call__(self, image, target): + image = self.color_jitter(image) + return image, target + + class ToTensor(object): def __call__(self, image, target): return F.to_tensor(image), target