Skip to content

Commit

Permalink
Fix remap_labels for incomplete categories (cvat-ai#408)
Browse files Browse the repository at this point in the history
* Fix remap labels for missing secondary categories

* Update changelog
  • Loading branch information
Maxim Zhiltsov authored Aug 18, 2021
1 parent 6b91d77 commit b510e05
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- TBD

### Fixed
- Application of `remap_labels` to dataset categories of different length (<https://github.com/openvinotoolkit/datumaro/issues/314>)
- Patching of datasets in formats (<https://github.com/openvinotoolkit/datumaro/issues/348>)
- Unsafe unpickling in CIFAR import (<https://github.com/openvinotoolkit/datumaro/pull/362>)
- Improved Cityscapes export performance (<https://github.com/openvinotoolkit/datumaro/pull/367>)
Expand Down
21 changes: 20 additions & 1 deletion datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from enum import Enum, auto
from glob import iglob
from typing import Callable, Dict, Iterable, List, Optional
from typing import Callable, Dict, Iterable, List, Optional, Tuple
import os
import os.path as osp

Expand Down Expand Up @@ -159,6 +159,15 @@ def inverse_colormap(self):
self._inverse_colormap = invert_colormap(self.colormap)
return self._inverse_colormap

def __contains__(self, idx: int) -> bool:
    """Return True if `idx` is a key of `self.colormap`, enabling `idx in obj`."""
    return idx in self.colormap

def __getitem__(self, idx: int) -> Tuple[int, int, int]:
    """
    Return the `self.colormap` entry stored under `idx`, enabling `obj[idx]`.

    No fallback is applied: a missing `idx` propagates whatever lookup
    error `self.colormap` raises.
    """
    return self.colormap[idx]

def __len__(self) -> int:
    """Return the number of entries in `self.colormap`."""
    return len(self.colormap)

def __eq__(self, other):
if not super().__eq__(other):
return False
Expand Down Expand Up @@ -530,6 +539,16 @@ def add(self, label_id, labels=None, joints=None):
joints = set(map(tuple, joints))
self.items[label_id] = self.Category(labels, joints)

def __contains__(self, idx: int) -> bool:
    """Return True if `idx` is a key of `self.items`, enabling `idx in obj`."""
    return idx in self.items

def __getitem__(self, idx: int) -> "PointsCategories.Category":
    """
    Return the entry stored in `self.items` under `idx`, enabling `obj[idx]`.

    The values of `self.items` are the `Category` objects created by `add()`,
    so the return annotation names that type rather than an RGB tuple.
    No fallback is applied: a missing `idx` propagates whatever lookup
    error `self.items` raises.
    """
    # The annotation is a string so it is never evaluated while the class
    # body is still being executed.
    # NOTE(review): assumes the enclosing class is PointsCategories - confirm.
    return self.items[idx]

def __len__(self) -> int:
    """Return the number of entries in `self.items`."""
    return len(self.items)


@attrs
class Points(_Shape):
class Visibility(Enum):
Expand Down
16 changes: 8 additions & 8 deletions datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,20 +501,20 @@ def __init__(self, extractor, mapping, default=None):
assert src_label_cat is not None
dst_mask_cat = MaskCategories(attributes=src_mask_cat.attributes)
dst_mask_cat.colormap = {
id: src_mask_cat.colormap[id]
id: src_mask_cat[id]
for id, _ in enumerate(src_label_cat.items)
if self._map_id(id) or id == 0
if id in src_mask_cat and (self._map_id(id) or id == 0)
}
self._categories[AnnotationType.mask] = dst_mask_cat

src_points_cat = self._extractor.categories().get(AnnotationType.points)
if src_points_cat is not None:
src_point_cat = self._extractor.categories().get(AnnotationType.points)
if src_point_cat is not None:
assert src_label_cat is not None
dst_points_cat = PointsCategories(attributes=src_points_cat.attributes)
dst_points_cat = PointsCategories(attributes=src_point_cat.attributes)
dst_points_cat.items = {
id: src_points_cat.items[id]
for id, item in enumerate(src_label_cat.items)
if self._map_id(id) or id == 0
id: src_point_cat[id]
for id, _ in enumerate(src_label_cat.items)
if id in src_point_cat and (self._map_id(id) or id == 0)
}
self._categories[AnnotationType.points] = dst_points_cat

Expand Down
1 change: 1 addition & 0 deletions tests/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Requirements:
DATUM_BUG_219 = "Return format is not uniform"
DATUM_BUG_257 = "Dataset.filter doesn't count removed items"
DATUM_BUG_259 = "Dataset.filter fails on merged datasets"
DATUM_BUG_314 = "Unsuccessful remap_labels"
DATUM_BUG_402 = "Troubles running 'remap_labels' on ProjectDataset"
DATUM_BUG_404 = "custom importer/extractor not loading"

Expand Down
29 changes: 28 additions & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from datumaro.components.extractor import (
AnnotationType, Bbox, DatasetItem, Label, LabelCategories, Mask,
MaskCategories, Points, Polygon, PolyLine,
MaskCategories, Points, PointsCategories, Polygon, PolyLine,
)
from datumaro.components.project import Dataset
from datumaro.util.test_utils import compare_datasets
Expand Down Expand Up @@ -415,6 +415,33 @@ def test_remap_labels_delete_unspecified(self):

compare_datasets(self, target_dataset, actual)

@mark_requirement(Requirements.DATUM_BUG_314)
def test_remap_labels_ignore_missing_labels_in_secondary_categories(self):
    """
    Check RemapLabels when the points and mask categories do not cover
    every label of the label categories.

    The source has three labels, an empty points category set and a
    two-entry mask colormap. Renaming every label must leave the points
    and mask categories as they were.
    """
    source_dataset = Dataset.from_iterable([
        DatasetItem(id=1, annotations=[
            Label(0),
        ])
    ], categories={
        AnnotationType.label: LabelCategories.from_iterable(['a', 'b', 'c']),
        AnnotationType.points: PointsCategories.from_iterable([]), # no label has a points entry
        AnnotationType.mask: MaskCategories.generate(2) # no color for label 'c'
    })

    # Expected result: same annotation and secondary categories, only the
    # label names differ.
    target_dataset = Dataset.from_iterable([
        DatasetItem(id=1, annotations=[
            Label(0),
        ]),
    ], categories={
        AnnotationType.label: LabelCategories.from_iterable(['d', 'e', 'f']),
        AnnotationType.points: PointsCategories.from_iterable([]),
        AnnotationType.mask: MaskCategories.generate(2)
    })

    # Every source label is mapped, so default='delete' has nothing to drop.
    actual = transforms.RemapLabels(source_dataset,
        mapping={ 'a': 'd', 'b': 'e', 'c': 'f' }, default='delete')

    compare_datasets(self, target_dataset, actual)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_transform_labels(self):
src_dataset = Dataset.from_iterable([
Expand Down

0 comments on commit b510e05

Please sign in to comment.