Skip to content

Commit

Permalink
add vertical flip (facebookresearch#818)
Browse files Browse the repository at this point in the history
* keep the resize function the same in test time the same with training time

* add vertical flip
  • Loading branch information
hcx1231 authored and fmassa committed May 28, 2019
1 parent d7e3c65 commit 4c6cd1a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
1 change: 1 addition & 0 deletions maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
_C.INPUT.SATURATION = 0.0
_C.INPUT.HUE = 0.0

_C.INPUT.VERTICAL_FLIP_PROB_TRAIN = 0.0

# -----------------------------------------------------------------------------
# Dataset
Expand Down
9 changes: 6 additions & 3 deletions maskrcnn_benchmark/data/transforms/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@ def build_transforms(cfg, is_train=True):
if is_train:
min_size = cfg.INPUT.MIN_SIZE_TRAIN
max_size = cfg.INPUT.MAX_SIZE_TRAIN
flip_prob = 0.5 # cfg.INPUT.FLIP_PROB_TRAIN
flip_horizontal_prob = 0.5 # cfg.INPUT.FLIP_PROB_TRAIN
flip_vertical_prob = cfg.INPUT.VERTICAL_FLIP_PROB_TRAIN
brightness = cfg.INPUT.BRIGHTNESS
contrast = cfg.INPUT.CONTRAST
saturation = cfg.INPUT.SATURATION
hue = cfg.INPUT.HUE
else:
min_size = cfg.INPUT.MIN_SIZE_TEST
max_size = cfg.INPUT.MAX_SIZE_TEST
flip_prob = 0
flip_horizontal_prob = 0.0
flip_vertical_prob = 0.0
brightness = 0.0
contrast = 0.0
saturation = 0.0
Expand All @@ -35,7 +37,8 @@ def build_transforms(cfg, is_train=True):
[
color_jitter,
T.Resize(min_size, max_size),
T.RandomHorizontalFlip(flip_prob),
T.RandomHorizontalFlip(flip_horizontal_prob),
T.RandomVerticalFlip(flip_vertical_prob),
T.ToTensor(),
normalize_transform,
]
Expand Down
9 changes: 9 additions & 0 deletions maskrcnn_benchmark/data/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ def __call__(self, image, target):
target = target.transpose(0)
return image, target

class RandomVerticalFlip(object):
def __init__(self, prob=0.5):
self.prob = prob

def __call__(self, image, target):
if random.random() < self.prob:
image = F.vflip(image)
target = target.transpose(1)
return image, target

class ColorJitter(object):
def __init__(self,
Expand Down

0 comments on commit 4c6cd1a

Please sign in to comment.