Skip to content

Commit

Permalink
[Fix] Fix MixUp transform filter boxes failing case. Added test case
Browse files Browse the repository at this point in the history
  • Loading branch information
dvansa committed Jan 28, 2022
1 parent 4bdb312 commit d567d14
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
5 changes: 2 additions & 3 deletions mmdet/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2426,9 +2426,8 @@ def _mixup_transform(self, results):
keep_list = self._filter_box_candidates(retrieve_gt_bboxes.T,
cp_retrieve_gt_bboxes.T)

if keep_list.sum() >= 1.0:
retrieve_gt_labels = retrieve_gt_labels[keep_list]
cp_retrieve_gt_bboxes = cp_retrieve_gt_bboxes[keep_list]
retrieve_gt_labels = retrieve_gt_labels[keep_list]
cp_retrieve_gt_bboxes = cp_retrieve_gt_bboxes[keep_list]

mixup_gt_bboxes = np.concatenate(
(results['gt_bboxes'], cp_retrieve_gt_bboxes), axis=0)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_data/test_pipelines/test_transform/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,33 @@ def test_mixup():
assert results['gt_bboxes'].dtype == np.float32
assert results['gt_bboxes_ignore'].dtype == np.float32

# test filter bbox :
# 2 boxes with sides 1 and 3 are filtered as min_bbox_size=5
gt_bboxes = np.array([[0, 0, 1, 1], [0, 0, 3, 3]], dtype=np.float32)
results['gt_labels'] = np.ones(gt_bboxes.shape[0], dtype=np.int64)
results['gt_bboxes'] = gt_bboxes
results['gt_bboxes_ignore'] = np.array([], dtype=np.float32)
mixresults = results['mix_results'][0]
mixresults['gt_labels'] = copy.deepcopy(results['gt_labels'])
mixresults['gt_bboxes'] = copy.deepcopy(results['gt_bboxes'])
mixresults['gt_bboxes_ignore'] = copy.deepcopy(results['gt_bboxes_ignore'])
transform = dict(
type='MixUp',
img_scale=(10, 12),
ratio_range=(1.5, 1.5),
min_bbox_size=5,
skip_filter=False)
mixup_module = build_from_cfg(transform, PIPELINES)

results = mixup_module(results)

assert results['gt_bboxes'].shape[0] == 2
assert results['gt_labels'].shape[0] == 2
assert results['gt_labels'].shape[0] == results['gt_bboxes'].shape[0]
assert results['gt_labels'].dtype == np.int64
assert results['gt_bboxes'].dtype == np.float32
assert results['gt_bboxes_ignore'].dtype == np.float32


def test_photo_metric_distortion():
img = mmcv.imread(
Expand Down

0 comments on commit d567d14

Please sign in to comment.