diff --git a/CHANGELOG.md b/CHANGELOG.md index f356129b09ca..14ea733fc6da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - TBD ### Fixed +- Application of `remap_labels` to dataset categories of different length () - Patching of datasets in formats () - Unsafe unpickling in CIFAR import () - Improved Cityscapes export performance () diff --git a/datumaro/components/extractor.py b/datumaro/components/extractor.py index 793bcdc5afdb..d4e7823de242 100644 --- a/datumaro/components/extractor.py +++ b/datumaro/components/extractor.py @@ -4,7 +4,7 @@ from enum import Enum, auto from glob import iglob -from typing import Callable, Dict, Iterable, List, Optional +from typing import Callable, Dict, Iterable, List, Optional, Tuple import os import os.path as osp @@ -159,6 +159,15 @@ def inverse_colormap(self): self._inverse_colormap = invert_colormap(self.colormap) return self._inverse_colormap + def __contains__(self, idx: int) -> bool: + return idx in self.colormap + + def __getitem__(self, idx: int) -> Tuple[int, int, int]: + return self.colormap[idx] + + def __len__(self): + return len(self.colormap) + def __eq__(self, other): if not super().__eq__(other): return False @@ -530,6 +539,16 @@ def add(self, label_id, labels=None, joints=None): joints = set(map(tuple, joints)) self.items[label_id] = self.Category(labels, joints) + def __contains__(self, idx: int) -> bool: + return idx in self.items + + def __getitem__(self, idx: int) -> 'PointsCategories.Category': + return self.items[idx] + + def __len__(self): + return len(self.items) + + @attrs class Points(_Shape): class Visibility(Enum): diff --git a/datumaro/plugins/transforms.py b/datumaro/plugins/transforms.py index 602d519a31d3..8fb2cf997637 100644 --- a/datumaro/plugins/transforms.py +++ b/datumaro/plugins/transforms.py @@ -501,20 +501,20 @@ def __init__(self, extractor, mapping, default=None): assert src_label_cat is not None 
dst_mask_cat = MaskCategories(attributes=src_mask_cat.attributes) dst_mask_cat.colormap = { - id: src_mask_cat.colormap[id] + id: src_mask_cat[id] for id, _ in enumerate(src_label_cat.items) - if self._map_id(id) or id == 0 + if id in src_mask_cat and (self._map_id(id) or id == 0) } self._categories[AnnotationType.mask] = dst_mask_cat - src_points_cat = self._extractor.categories().get(AnnotationType.points) - if src_points_cat is not None: + src_point_cat = self._extractor.categories().get(AnnotationType.points) + if src_point_cat is not None: assert src_label_cat is not None - dst_points_cat = PointsCategories(attributes=src_points_cat.attributes) + dst_points_cat = PointsCategories(attributes=src_point_cat.attributes) dst_points_cat.items = { - id: src_points_cat.items[id] - for id, item in enumerate(src_label_cat.items) - if self._map_id(id) or id == 0 + id: src_point_cat[id] + for id, _ in enumerate(src_label_cat.items) + if id in src_point_cat and (self._map_id(id) or id == 0) } self._categories[AnnotationType.points] = dst_points_cat diff --git a/tests/requirements.py b/tests/requirements.py index d1f535f336e7..f7dc2ba798b4 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -30,6 +30,7 @@ class Requirements: DATUM_BUG_219 = "Return format is not uniform" DATUM_BUG_257 = "Dataset.filter doesn't count removed items" DATUM_BUG_259 = "Dataset.filter fails on merged datasets" + DATUM_BUG_314 = "Unsuccessful remap_labels" DATUM_BUG_402 = "Troubles running 'remap_labels' on ProjectDataset" DATUM_BUG_404 = "custom importer/extractor not loading" diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 2bc4fbcfb3b5..31656f083f45 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -5,7 +5,7 @@ from datumaro.components.extractor import ( AnnotationType, Bbox, DatasetItem, Label, LabelCategories, Mask, - MaskCategories, Points, Polygon, PolyLine, + MaskCategories, Points, PointsCategories, Polygon, PolyLine, ) from 
datumaro.components.project import Dataset from datumaro.util.test_utils import compare_datasets @@ -415,6 +415,33 @@ def test_remap_labels_delete_unspecified(self): compare_datasets(self, target_dataset, actual) + @mark_requirement(Requirements.DATUM_BUG_314) + def test_remap_labels_ignore_missing_labels_in_secondary_categories(self): + source_dataset = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ + Label(0), + ]) + ], categories={ + AnnotationType.label: LabelCategories.from_iterable(['a', 'b', 'c']), + AnnotationType.points: PointsCategories.from_iterable([]), # all missing + AnnotationType.mask: MaskCategories.generate(2) # no c color + }) + + target_dataset = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ + Label(0), + ]), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable(['d', 'e', 'f']), + AnnotationType.points: PointsCategories.from_iterable([]), + AnnotationType.mask: MaskCategories.generate(2) + }) + + actual = transforms.RemapLabels(source_dataset, + mapping={ 'a': 'd', 'b': 'e', 'c': 'f' }, default='delete') + + compare_datasets(self, target_dataset, actual) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_transform_labels(self): src_dataset = Dataset.from_iterable([