diff --git a/CHANGELOG.md b/CHANGELOG.md index 560e81d12e4b..c6a57f84dc5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,25 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] +### Added +- + +### Changed +- + +### Deprecated +- + +### Removed +- + +### Fixed +- Allowed explicit label removal in `remap_labels` transform () + +### Security +- + ## 31/03/2021 - Release v0.1.8 ### Added - diff --git a/datumaro/plugins/transforms.py b/datumaro/plugins/transforms.py index e634794ff75c..dfecb25a990c 100644 --- a/datumaro/plugins/transforms.py +++ b/datumaro/plugins/transforms.py @@ -18,6 +18,7 @@ ) from datumaro.components.cli_plugin import CliPlugin import datumaro.util.mask_tools as mask_tools +from datumaro.util import parse_str_enum_value, NOTSET from datumaro.util.annotation_util import find_group_leader, find_instances @@ -433,7 +434,22 @@ def transform_item(self, item): class RemapLabels(Transform, CliPlugin): """ Changes labels in the dataset.|n + |n + A label can be:|n + - renamed (and joined with existing) -|n + |s|swhen specified '--label <old_name>:<new_name>'|n + - deleted - when specified '--label <name>:' or default action is 'delete'|n + |s|sand the label is not mentioned in the list. 
When a label|n + |s|sis deleted, all the associated annotations are removed|n + - kept unchanged - when specified '--label <name>:<name>'|n + |s|sor default action is 'keep' and the label is not mentioned in the list|n + Annotations with no label are managed by the default action policy.|n + |n Examples:|n + - Remove the 'person' label (and corresponding annotations):|n + |s|sremap_labels -l person: --default keep|n + - Rename 'person' to 'pedestrian' and 'human' to 'pedestrian', join:|n + |s|sremap_labels -l person:pedestrian -l human:pedestrian --default keep|n - Rename 'person' to 'car' and 'cat' to 'dog', keep 'bus', remove others:|n |s|sremap_labels -l person:car -l bus:bus -l cat:dog --default delete """ @@ -463,9 +479,9 @@ def build_cmdline_parser(cls, **kwargs): def __init__(self, extractor, mapping, default=None): super().__init__(extractor) - assert isinstance(default, (str, self.DefaultAction)) - if isinstance(default, str): - default = self.DefaultAction[default] + default = parse_str_enum_value(default, self.DefaultAction, + self.DefaultAction.keep) + self._default_action = default assert isinstance(mapping, (dict, list)) if isinstance(mapping, list): @@ -503,10 +519,10 @@ def _make_label_id_map(self, src_label_cat, label_mapping, default_action): dst_label_cat = LabelCategories(attributes=src_label_cat.attributes) id_mapping = {} for src_index, src_label in enumerate(src_label_cat.items): - dst_label = label_mapping.get(src_label.name) - if not dst_label and default_action == self.DefaultAction.keep: + dst_label = label_mapping.get(src_label.name, NOTSET) + if dst_label is NOTSET and default_action == self.DefaultAction.keep: dst_label = src_label.name # keep unspecified as is - if not dst_label: + elif not dst_label or dst_label is NOTSET: continue dst_index = dst_label_cat.find(dst_label)[0] @@ -518,7 +534,7 @@ def _make_label_id_map(self, src_label_cat, label_mapping, default_action): if log.getLogger().isEnabledFor(log.DEBUG): log.debug("Label mapping:") for 
src_id, src_label in enumerate(src_label_cat.items): - if id_mapping.get(src_id): + if id_mapping.get(src_id) is not None: log.debug("#%s '%s' -> #%s '%s'", src_id, src_label.name, id_mapping[src_id], dst_label_cat.items[id_mapping[src_id]].name @@ -535,14 +551,11 @@ def categories(self): def transform_item(self, item): annotations = [] for ann in item.annotations: - if ann.type in { AnnotationType.label, AnnotationType.mask, - AnnotationType.points, AnnotationType.polygon, - AnnotationType.polyline, AnnotationType.bbox - } and ann.label is not None: + if getattr(ann, 'label') is not None: conv_label = self._map_id(ann.label) if conv_label is not None: annotations.append(ann.wrap(label=conv_label)) - else: + elif self._default_action is self.DefaultAction.keep: annotations.append(ann.wrap()) return item.wrap(annotations=annotations) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 5098d03634de..1e310c1832a4 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -336,15 +336,18 @@ def test_remap_labels(self): Bbox(1, 2, 3, 4, label=2), Mask(image=np.array([1]), label=3), - # Should be kept + # Should be deleted Polygon([1, 1, 2, 2, 3, 4], label=4), - PolyLine([1, 3, 4, 2, 5, 6]) + + # Should be kept + PolyLine([1, 3, 4, 2, 5, 6]), + Bbox(4, 3, 2, 1, label=5), ]) ], categories={ AnnotationType.label: LabelCategories.from_iterable( - 'label%s' % i for i in range(5)), + 'label%s' % i for i in range(6)), AnnotationType.mask: MaskCategories( - colormap=mask_tools.generate_colormap(5)), + colormap=mask_tools.generate_colormap(6)), }) dst_dataset = Dataset.from_iterable([ @@ -353,37 +356,45 @@ def test_remap_labels(self): Bbox(1, 2, 3, 4, label=0), Mask(image=np.array([1]), label=1), - Polygon([1, 1, 2, 2, 3, 4], label=2), - PolyLine([1, 3, 4, 2, 5, 6], label=None) + PolyLine([1, 3, 4, 2, 5, 6], label=None), + Bbox(4, 3, 2, 1, label=2), ]), ], categories={ AnnotationType.label: LabelCategories.from_iterable( - ['label0', 'label9', 
'label4']), + ['label0', 'label9', 'label5']), AnnotationType.mask: MaskCategories(colormap={ - k: v for k, v in mask_tools.generate_colormap(5).items() - if k in { 0, 1, 3, 4 } + k: v for k, v in mask_tools.generate_colormap(6).items() + if k in { 0, 1, 3, 5 } }) }) actual = transforms.RemapLabels(src_dataset, mapping={ - 'label1': 'label9', - 'label2': 'label0', - 'label3': 'label9', + 'label1': 'label9', # rename & join with new label9 (from label3) + 'label2': 'label0', # rename & join with existing label0 + 'label3': 'label9', # rename & join with new label9 (from label1) + 'label4': '', # delete the label and associated annotations + # 'label5' - unchanged }, default='keep') compare_datasets(self, dst_dataset, actual) def test_remap_labels_delete_unspecified(self): source_dataset = Dataset.from_iterable([ - DatasetItem(id=1, annotations=[ Label(0) ]) - ], categories=['label0']) + DatasetItem(id=1, annotations=[ + Label(0, id=0), # will be removed + Label(1, id=1), + Bbox(1, 2, 3, 4, label=None), + ]) + ], categories=['label0', 'label1']) target_dataset = Dataset.from_iterable([ - DatasetItem(id=1), - ], categories=[]) + DatasetItem(id=1, annotations=[ + Label(0, id=1), + ]), + ], categories=['label1']) actual = transforms.RemapLabels(source_dataset, - mapping={}, default='delete') + mapping={ 'label1': 'label1' }, default='delete') compare_datasets(self, target_dataset, actual)