Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Add two augmentation functions: rotate and light change #1170

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,16 @@

# Flips
_C.INPUT.HORIZONTAL_FLIP_PROB_TRAIN = 0.5
_C.INPUT.VERTICAL_FLIP_PROB_TRAIN = 0.0
_C.INPUT.VERTICAL_FLIP_PROB_TRAIN = 0.5

# Rotate
_C.INPUT.ANGLE_TRAIN = 10

# Linear Light Shade
_C.INPUT.VERTICAL_LIGHT_PROB_TRAIN = 0.5
_C.INPUT.VERTICAL_LIGHT_SCALE_TRAIN = 20
_C.INPUT.HORIZONTAL_LIGHT_PROB_TRAIN = 0.5
_C.INPUT.HORIZONTAL_LIGHT_SCALE_TRAIN = 20

# -----------------------------------------------------------------------------
# Dataset
Expand Down
13 changes: 13 additions & 0 deletions maskrcnn_benchmark/data/transforms/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ def build_transforms(cfg, is_train=True):
contrast = cfg.INPUT.CONTRAST
saturation = cfg.INPUT.SATURATION
hue = cfg.INPUT.HUE
angle = cfg.INPUT.ANGLE_TRAIN
light_vertical_prob = cfg.INPUT.VERTICAL_LIGHT_PROB_TRAIN
light_vertical_scale = cfg.INPUT.VERTICAL_LIGHT_SCALE_TRAIN
light_horizontal_prob = cfg.INPUT.HORIZONTAL_LIGHT_PROB_TRAIN
light_horizontal_scale = cfg.INPUT.HORIZONTAL_LIGHT_SCALE_TRAIN
else:
min_size = cfg.INPUT.MIN_SIZE_TEST
max_size = cfg.INPUT.MAX_SIZE_TEST
Expand All @@ -21,6 +26,11 @@ def build_transforms(cfg, is_train=True):
contrast = 0.0
saturation = 0.0
hue = 0.0
angle = 0.0
light_vertical_prob = 0.0
light_vertical_scale = 0.0
light_horizontal_prob = 0.0
light_horizontal_scale = 0.0

to_bgr255 = cfg.INPUT.TO_BGR255
normalize_transform = T.Normalize(
Expand All @@ -35,7 +45,10 @@ def build_transforms(cfg, is_train=True):

transform = T.Compose(
[
T.SmallAngleRotate(angle),
color_jitter,
T.HorizontalLinearLight(light_horizontal_prob, light_horizontal_scale),
T.VerticalLinearLight(light_vertical_prob, light_vertical_scale),
T.Resize(min_size, max_size),
T.RandomHorizontalFlip(flip_horizontal_prob),
T.RandomVerticalFlip(flip_vertical_prob),
Expand Down
57 changes: 57 additions & 0 deletions maskrcnn_benchmark/data/transforms/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import random

from PIL import Image
import numpy as np
import cv2
import torch
import torchvision
from torchvision.transforms import functional as F
Expand Down Expand Up @@ -83,6 +86,59 @@ def __call__(self, image, target):
target = target.transpose(1)
return image, target

class HorizontalLinearLight(object):
def __init__(self, prob=0.5, lightsacle=50):
self.prob = prob
self.lightsacle = lightsacle

def __call__(self, image, target):
if random.random() < self.prob:
image = np.asarray(image.copy())
h, w, c = image.shape
x = np.linspace(-1*self.lightsacle, self.lightsacle, w)
weight = np.expand_dims(x, axis=1)
weight = weight.repeat(c, axis=1)
image = image + weight
image[image < 0] = 0
image[image > 255] = 255
image = Image.fromarray(image.astype(np.uint8))
return image, target

class VerticalLinearLight(object):
def __init__(self, prob=0.5, lightsacle=50):
self.prob = prob
self.lightsacle = lightsacle

def __call__(self, image, target):
if random.random() < self.prob:
image = np.asarray(image.copy())
h, w, c = image.shape
x = np.linspace(-1*self.lightsacle, self.lightsacle, h)
weight = np.expand_dims(x, axis=1)
weight = weight.repeat(c, axis=1)
weight = np.expand_dims(weight, axis=1)
image = image + weight
image[image < 0] = 0
image[image > 255] = 255
image = Image.fromarray(image.astype(np.uint8))
return image, target

class SmallAngleRotate(object):
def __init__(self,angle=10):
self.angle_range = angle

def __call__(self, image, target):
self.angle = random.randint(-1*self.angle_range, self.angle_range)
image = np.asarray(image.copy())
h, w, _ = image.shape
cx = w / 2
cy = h / 2
M = cv2.getRotationMatrix2D((cx, cy), self.angle, 1.0)
image = cv2.warpAffine(image, M, (w, h))
target = target.rotate(M)
image = Image.fromarray(image)
return image, target

class ColorJitter(object):
def __init__(self,
brightness=None,
Expand Down Expand Up @@ -119,3 +175,4 @@ def __call__(self, image, target=None):
if target is None:
return image
return image, target

36 changes: 28 additions & 8 deletions maskrcnn_benchmark/structures/bounding_box.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import numpy as np

# transpose
FLIP_LEFT_RIGHT = 0
Expand All @@ -25,7 +26,7 @@ def __init__(self, bbox, image_size, mode="xyxy"):
)
if bbox.size(-1) != 4:
raise ValueError(
"last dimension of bbox should have a "
"last dimenion of bbox should have a "
"size of 4, got {}".format(bbox.size(-1))
)
if mode not in ("xyxy", "xywh"):
Expand Down Expand Up @@ -164,9 +165,31 @@ def transpose(self, method):
bbox.add_field(k, v)
return bbox.convert(self.mode)

def rotate(self, matrix):
"""
Returns a rotated copy of this bounding box

:param matrix: Rotation matrix caculated
according to certain degree by Opencv CV2::getRotationMatrix2D
"""
bbox = BoxList(self.bbox, self.size, mode="xyxy")
# bbox._copy_extra_fields(self)
for k, v in self.extra_fields.items():
if not isinstance(v, torch.Tensor):
v = v.rotate(matrix)
newbox = []
for polygon in v.polygons:
coods = polygon.polygons[0].numpy()
allx = coods[[0, 2, 4, 6]]
ally = coods[[1, 3, 5, 7]]
newbox.append([allx.min(), ally.min(), allx.max(), ally.max()])
bbox.bbox = torch.from_numpy(np.asarray(newbox))
bbox.add_field(k, v)
return bbox.convert(self.mode)

def crop(self, box):
"""
Crops a rectangular region from this bounding box. The box is a
Cropss a rectangular region from this bounding box. The box is a
4-tuple defining the left, upper, right, and lower pixel
coordinate.
"""
Expand Down Expand Up @@ -232,18 +255,15 @@ def area(self):
area = box[:, 2] * box[:, 3]
else:
raise RuntimeError("Should not be here")

return area

def copy_with_fields(self, fields, skip_missing=False):
def copy_with_fields(self, fields):
bbox = BoxList(self.bbox, self.size, self.mode)
if not isinstance(fields, (list, tuple)):
fields = [fields]
for field in fields:
if self.has_field(field):
bbox.add_field(field, self.get_field(field))
elif not skip_missing:
raise KeyError("Field '{}' not found in {}".format(field, self))
bbox.add_field(field, self.get_field(field))
return bbox

def __repr__(self):
Expand Down
Loading