Skip to content

Commit

Permalink
[Datumaro] Change alignment in mask parsing (#1547)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max authored May 17, 2020
1 parent 4299090 commit 98a9718
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Task/Job buttons has no "Open in new tab" option (<https://github.com/opencv/cvat/pull/1419>)
- Delete point context menu option has no shortcut hint (<https://github.com/opencv/cvat/pull/1416>)
- Fixed issue with unnecessary tag activation in cvat-canvas (<https://github.com/opencv/cvat/issues/1540>)
- Fixed an issue with large number of instances in instance mask (https://github.com/opencv/cvat/issues/1539)
- Fixed full COCO dataset import error with conflicting labels in keypoints and detection (https://github.com/opencv/cvat/pull/1548)
- Fixed COCO keypoints skeleton parsing and saving (https://github.com/opencv/cvat/issues/1539)

Expand Down
18 changes: 8 additions & 10 deletions datumaro/datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def from_instance_masks(instance_masks,
if instance_ids is not None:
assert len(instance_ids) == len(instance_masks)
else:
instance_ids = [1 + i for i in range(len(instance_masks))]
instance_ids = range(1, len(instance_masks) + 1)

if instance_labels is not None:
assert len(instance_labels) == len(instance_masks)
Expand Down Expand Up @@ -310,15 +310,13 @@ def instance_mask(self):
def instance_count(self):
return int(self.instance_mask.max())

def get_instance_labels(self, class_count=None):
if class_count is None:
class_count = np.max(self.class_mask) + 1

m = self.class_mask * class_count + self.instance_mask
m = m.astype(int)
def get_instance_labels(self):
class_shift = 16
m = (self.class_mask.astype(np.uint32) << class_shift) \
+ self.instance_mask.astype(np.uint32)
keys = np.unique(m)
instance_labels = {k % class_count: k // class_count
for k in keys if k % class_count != 0
instance_labels = {k & ((1 << class_shift) - 1): k >> class_shift
for k in keys if k & ((1 << class_shift) - 1) != 0
}
return instance_labels

Expand Down Expand Up @@ -783,4 +781,4 @@ def categories(self):
return self._extractor.categories()

def transform_item(self, item):
raise NotImplementedError()
raise NotImplementedError()
3 changes: 1 addition & 2 deletions datumaro/datumaro/plugins/voc_format/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,7 @@ def _load_annotations(self, item_id):

if class_mask is not None:
label_cat = self._categories[AnnotationType.label]
instance_labels = compiled_mask.get_instance_labels(
class_count=len(label_cat.items))
instance_labels = compiled_mask.get_instance_labels()
else:
instance_labels = {i: None
for i in range(compiled_mask.instance_count)}
Expand Down
13 changes: 12 additions & 1 deletion datumaro/tests/test_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest import TestCase

import datumaro.util.mask_tools as mask_tools
from datumaro.components.extractor import CompiledMask


class PolygonConversionsTest(TestCase):
Expand Down Expand Up @@ -183,4 +184,14 @@ def test_can_merge_masks(self):
actual = mask_tools.merge_masks(masks)

self.assertTrue(np.array_equal(expected, actual),
'%s\nvs.\n%s' % (expected, actual))
'%s\nvs.\n%s' % (expected, actual))

def test_can_decode_compiled_mask(self):
class_idx = 1000
instance_idx = 10000
mask = np.array([1])
compiled_mask = CompiledMask(mask * class_idx, mask * instance_idx)

labels = compiled_mask.get_instance_labels()

self.assertEqual({instance_idx: class_idx}, labels)
35 changes: 35 additions & 0 deletions datumaro/tests/test_voc_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,41 @@ def __iter__(self):
VocSegmentationConverter(label_map='voc'), test_dir,
target_dataset=DstExtractor())

def test_can_save_voc_segm_with_many_instances(self):
def bit(x, y, shape):
mask = np.zeros(shape)
mask[y, x] = 1
return mask

class TestExtractor(TestExtractorBase):
def __iter__(self):
return iter([
DatasetItem(id=1, subset='a', annotations=[
Mask(image=bit(x, y, shape=[10, 10]),
label=self._label(VOC.VocLabel(3).name),
z_order=10 * y + x + 1
)
for y in range(10) for x in range(10)
]),
])

class DstExtractor(TestExtractorBase):
def __iter__(self):
return iter([
DatasetItem(id=1, subset='a', annotations=[
Mask(image=bit(x, y, shape=[10, 10]),
label=self._label(VOC.VocLabel(3).name),
group=10 * y + x + 1
)
for y in range(10) for x in range(10)
]),
])

with TestDir() as test_dir:
self._test_save_and_load(TestExtractor(),
VocSegmentationConverter(label_map='voc'), test_dir,
target_dataset=DstExtractor())

def test_can_save_voc_layout(self):
class TestExtractor(TestExtractorBase):
def __iter__(self):
Expand Down

0 comments on commit 98a9718

Please sign in to comment.