Skip to content

Commit

Permalink
add new config
Browse files Browse the repository at this point in the history
  • Loading branch information
triple-Mu committed Apr 7, 2024
1 parent f5293b0 commit 2043b45
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
9 changes: 4 additions & 5 deletions configs/rotated_rtmdet/rotated_rtmdet_l-comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@
data_root = 'data/DOTA/'
dataset_type = 'DOTADataset'
metainfo = {
'classes':
('A', 'B', 'C', 'D', 'E', 'F'),
'palette': [(165, 42, 42), (189, 183, 107), (0, 255, 0),
(255, 0, 0), (138, 43, 226), (255, 128, 0)]
'classes': ('A', 'B', 'C', 'D', 'E', 'F'),
'palette': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0),
(138, 43, 226), (255, 128, 0)]
}

file_client_args = dict(backend='disk')
angle_version = 'le90'

batch_size = 16
batch_size = 8
num_workers = 0

model = dict(
Expand Down
39 changes: 26 additions & 13 deletions mmrotate/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import mmcv
import numpy as np
import torch

from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmdet.structures.bbox import BaseBoxes, get_box_type
Expand Down Expand Up @@ -447,15 +446,19 @@ def __repr__(self):

@TRANSFORMS.register_module()
class CacheCopyPaste(BaseTransform):
def __init__(self, num_copy_thres: int = 20, max_capacity: int = 1024) -> None:

def __init__(self,
num_copy_thres: int = 20,
max_capacity: int = 1024) -> None:
self.num_copy_thres = num_copy_thres
self.max_capacity = max_capacity
self.cache = []

def transform(self, results: dict) -> dict:
"""The transform function."""
img = results['img'].copy()
for bbox in results['gt_bboxes'].tensor:
for bbox, label in zip(results['gt_bboxes'].tensor,
results['gt_bboxes_labels']):
if len(self.cache) < self.max_capacity:
xc, yc, w, h = bbox[:4].round().int().tolist()
angle = bbox[4].item() / np.pi * 180
Expand All @@ -469,8 +472,14 @@ def transform(self, results: dict) -> dict:
if crop.size == 0:
continue
xc_new, yc_new = xc - x_min, yc - y_min
bbox_new = np.array([xc_new, yc_new, w, h, angle], dtype=np.float32)
self.cache.append((crop, bbox_new))
bbox_new = np.array([xc_new, yc_new, w, h, angle],
dtype=np.float32)
if label in (1, 2, 3, 4, 5):
self.cache.append((crop, bbox_new, label))
if label in (4, 5):
self.cache.append((crop, bbox_new, label))
self.cache.append((crop, bbox_new, label))
self.cache.append((crop, bbox_new, label))
else:
random.shuffle(self.cache)
self.cache = self.cache[:self.max_capacity]
Expand All @@ -489,9 +498,10 @@ def transform(self, results: dict) -> dict:
selects = self.cache[:n_copy]
self.cache = self.cache[n_copy:]
for select in selects:
im, rbox = select
im, rbox, label = select
xc, yc, w, h, angle = rbox
box = cv2.boxPoints(((xc, yc), (w, h), angle)).round().astype(np.int32)
box = cv2.boxPoints(
((xc, yc), (w, h), angle)).round().astype(np.int32)
h, w = im.shape[:2]
mask = np.zeros((h, w), dtype=np.uint8)
cv2.drawContours(mask, [box], 0, (255), thickness=cv2.FILLED)
Expand All @@ -500,17 +510,20 @@ def transform(self, results: dict) -> dict:
p_at_w = random.randint(0, img_w - w - 1)
# img[p_at_h:p_at_h + h, p_at_w:p_at_w + w] = im
im = cv2.seamlessClone(
im,
img[p_at_h:p_at_h + h, p_at_w:p_at_w + w],
mask,
(w // 2, h // 2),
cv2.NORMAL_CLONE)
im, img[p_at_h:p_at_h + h, p_at_w:p_at_w + w], mask,
(w // 2, h // 2), cv2.NORMAL_CLONE)
img[p_at_h:p_at_h + h, p_at_w:p_at_w + w] = im
rbox[0] += p_at_w
rbox[1] += p_at_h
rbox[4] = rbox[4] / 180 * np.pi
rbox = RotatedBoxes(rbox[None], dtype=torch.float32)
results['gt_bboxes'] = RotatedBoxes.cat([results['gt_bboxes'], rbox])
results['gt_bboxes'] = RotatedBoxes.cat(
[results['gt_bboxes'], rbox])
results['gt_bboxes_labels'] = np.append(
results['gt_bboxes_labels'], label)
results['gt_ignore_flags'] = np.append(
results['gt_ignore_flags'], False)

results['img'] = img
return results

Expand Down

0 comments on commit 2043b45

Please sign in to comment.