diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b1b7ed29e1..84dbdf61a1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to work with data on the fly (https://github.com/opencv/cvat/pull/2007) - Annotation in process outline color wheel () - [Datumaro] CLI command for dataset equality comparison () +- [Datumaro] Merging of datasets with different labels () ### Changed - UI models (like DEXTR) were redesigned to be more interactive () diff --git a/datumaro/datumaro/components/extractor.py b/datumaro/datumaro/components/extractor.py index b213b6231c5..dcb7b036c04 100644 --- a/datumaro/datumaro/components/extractor.py +++ b/datumaro/datumaro/components/extractor.py @@ -49,7 +49,11 @@ class Categories: @attrs class LabelCategories(Categories): - Category = namedtuple('Category', ['name', 'parent', 'attributes']) + @attrs(repr_ns='LabelCategories') + class Category: + name = attrib(converter=str, validator=not_empty) + parent = attrib(default='', validator=default_if_none(str)) + attributes = attrib(factory=set, validator=default_if_none(set)) items = attrib(factory=list, validator=default_if_none(list)) _indices = attrib(factory=dict, init=False, eq=False) @@ -93,15 +97,6 @@ def _reindex(self): def add(self, name: str, parent: str = None, attributes: dict = None): assert name not in self._indices, name - if attributes is None: - attributes = set() - else: - if not isinstance(attributes, set): - attributes = set(attributes) - for attr in attributes: - assert isinstance(attr, str) - if parent is None: - parent = '' index = len(self.items) self.items.append(self.Category(name, parent, attributes)) @@ -386,7 +381,10 @@ def wrap(item, **kwargs): @attrs class PointsCategories(Categories): - Category = namedtuple('Category', ['labels', 'joints']) + @attrs(repr_ns="PointsCategories") + class Category: + labels = attrib(factory=list, validator=default_if_none(list)) + joints = attrib(factory=set, validator=default_if_none(set)) items = attrib(factory=dict, validator=default_if_none(dict)) @@ -396,28 +394,19 @@ def from_iterable(cls, iterable): Args: iterable ([type]): This iterable object can be: - 1)simple int - will generate one Category with int as label - 2)list of int - will interpreted as list of Category labels - 3)list of positional argumetns - will generate Categories - with this arguments + 1) list of positional argumetns - will generate Categories + with these arguments Returns: PointsCategories: PointsCategories object """ temp_categories = cls() - if isinstance(iterable, int): - iterable = [[iterable]] - for category in iterable: - if isinstance(category, int): - category = [category] temp_categories.add(*category) return temp_categories def add(self, label_id, labels=None, joints=None): - if labels is None: - labels = [] if joints is None: joints = [] joints = set(map(tuple, joints)) diff --git a/datumaro/datumaro/components/operations.py b/datumaro/datumaro/components/operations.py index 2e3a68136db..d887add9551 100644 --- a/datumaro/datumaro/components/operations.py +++ b/datumaro/datumaro/components/operations.py @@ -15,7 +15,8 @@ from unittest import TestCase from datumaro.components.cli_plugin import CliPlugin -from datumaro.components.extractor import AnnotationType, Bbox, Label +from datumaro.components.extractor import (AnnotationType, Bbox, Label, + LabelCategories, PointsCategories, MaskCategories) from datumaro.components.project import Dataset from datumaro.util import find, filter_dict from datumaro.util.attrs_util import ensure_cls, default_if_none @@ -53,7 +54,8 @@ def merge_categories(sources): for cat_type, source_cat in source.items(): if not categories[cat_type] == source_cat: raise NotImplementedError( - "Merging different categories is not implemented yet") + "Merging of datasets with different categories is " + "only allowed in 'merge' command.") return categories class MergingStrategy(CliPlugin): @@ -180,7 +182,8 @@ def add_item_error(self, error, *args, **kwargs): _categories = attrib(init=False) # merged categories def __call__(self, datasets): - self._categories = merge_categories(d.categories() for d in datasets) + self._categories = self._merge_categories( + [d.categories() for d in datasets]) merged = Dataset(categories=self._categories) self._check_groups_definition() @@ -283,6 +286,126 @@ def match_items(datasets): return matches, item_map + def _merge_label_categories(self, sources): + same = True + common = None + for src_categories in sources: + src_cat = src_categories.get(AnnotationType.label) + if common is None: + common = src_cat + elif common != src_cat: + same = False + break + + if same: + return common + + dst_cat = LabelCategories() + for src_id, src_categories in enumerate(sources): + src_cat = src_categories.get(AnnotationType.label) + if src_cat is None: + continue + + for src_label in src_cat.items: + dst_label = dst_cat.find(src_label.name)[1] + if dst_label is not None: + if dst_label != src_label: + if src_label.parent and dst_label.parent and \ + src_label.parent != dst_label.parent: + raise ValueError("Can't merge label category " + "%s (from #%s): " + "parent label conflict: %s vs. %s" % \ + (src_label.name, src_id, + src_label.parent, dst_label.parent) + ) + dst_label.parent = dst_label.parent or src_label.parent + dst_label.attributes |= src_label.attributes + else: + pass + else: + dst_cat.add(src_label.name, + src_label.parent, src_label.attributes) + + return dst_cat + + def _merge_point_categories(self, sources, label_cat): + dst_point_cat = PointsCategories() + + for src_id, src_categories in enumerate(sources): + src_label_cat = src_categories.get(AnnotationType.label) + src_point_cat = src_categories.get(AnnotationType.points) + if src_label_cat is None or src_point_cat is None: + continue + + for src_label_id, src_cat in src_point_cat.items.items(): + src_label = src_label_cat.items[src_label_id].name + dst_label_id = label_cat.find(src_label)[0] + dst_cat = dst_point_cat.items.get(dst_label_id) + if dst_cat is not None: + if dst_cat != src_cat: + raise ValueError("Can't merge point category for label " + "%s (from #%s): %s vs. %s" % \ + (src_label, src_id, src_cat, dst_cat) + ) + else: + pass + else: + dst_point_cat.add(dst_label_id, + src_cat.labels, src_cat.joints) + + if len(dst_point_cat.items) == 0: + return None + + return dst_point_cat + + def _merge_mask_categories(self, sources, label_cat): + dst_mask_cat = MaskCategories() + + for src_id, src_categories in enumerate(sources): + src_label_cat = src_categories.get(AnnotationType.label) + src_mask_cat = src_categories.get(AnnotationType.mask) + if src_label_cat is None or src_mask_cat is None: + continue + + for src_label_id, src_cat in src_mask_cat.colormap.items(): + src_label = src_label_cat.items[src_label_id].name + dst_label_id = label_cat.find(src_label)[0] + dst_cat = dst_mask_cat.colormap.get(dst_label_id) + if dst_cat is not None: + if dst_cat != src_cat: + raise ValueError("Can't merge mask category for label " + "%s (from #%s): %s vs. %s" % \ + (src_label, src_id, src_cat, dst_cat) + ) + else: + pass + else: + dst_mask_cat.colormap[dst_label_id] = src_cat + + if len(dst_mask_cat.colormap) == 0: + return None + + return dst_mask_cat + + def _merge_categories(self, sources): + dst_categories = {} + + label_cat = self._merge_label_categories(sources) + if label_cat is None: + return dst_categories + + dst_categories[AnnotationType.label] = label_cat + + points_cat = self._merge_point_categories(sources, label_cat) + if points_cat is not None: + dst_categories[AnnotationType.points] = points_cat + + mask_cat = self._merge_mask_categories(sources, label_cat) + if mask_cat is not None: + dst_categories[AnnotationType.mask] = mask_cat + + return dst_categories + def _match_annotations(self, sources): all_by_type = {} for s in sources: @@ -473,8 +596,29 @@ def _check_group(group_labels, group): _check_group(group_labels, group) def _get_label_name(self, label_id): + if label_id is None: + return None return self._categories[AnnotationType.label].items[label_id].name + def _get_label_id(self, label): + return self._categories[AnnotationType.label].find(label)[0] + + def _get_src_label_name(self, ann, label_id): + if label_id is None: + return None + item_id = self._ann_map[id(ann)][1] + dataset_id = self._item_map[item_id][1] + return self._dataset_map[dataset_id][0] \ + .categories()[AnnotationType.label].items[label_id].name + + def _get_any_label_name(self, ann, label_id): + if label_id is None: + return None + try: + return self._get_src_label_name(ann, label_id) + except KeyError: + return self._get_label_name(label_id) + def _check_groups_definition(self): for group in self.conf.groups: for label, _ in group: @@ -486,16 +630,19 @@ def _check_groups_definition(self): self._categories[AnnotationType.label].items]) ) -@attrs +@attrs(kw_only=True) class AnnotationMatcher: + _context = attrib(type=IntersectMerge, default=None) + def match_annotations(self, sources): raise NotImplementedError() @attrs class LabelMatcher(AnnotationMatcher): - @staticmethod - def distance(a, b): - return a.label == b.label + def distance(self, a, b): + a_label = self._context._get_any_label_name(a, a.label) + b_label = self._context._get_any_label_name(b, b.label) + return a_label == b_label def match_annotations(self, sources): return [sum(sources, [])] @@ -507,6 +654,7 @@ class _ShapeMatcher(AnnotationMatcher): def match_annotations(self, sources): distance = self.distance + label_matcher = self.label_matcher pairwise_dist = self.pairwise_dist cluster_dist = self.cluster_dist @@ -537,9 +685,10 @@ def _has_same_source(cluster, extra_id): for a_idx, src_a in enumerate(sources): for src_b in sources[a_idx+1 :]: matches, _, _, _ = match_segments(src_a, src_b, - dist_thresh=pairwise_dist, distance=distance) - for m in matches: - adjacent[id(m[0])].append(id(m[1])) + dist_thresh=pairwise_dist, + distance=distance, label_matcher=label_matcher) + for a, b in matches: + adjacent[id(a)].append(id(b)) # join all segments into matching clusters clusters = [] @@ -573,6 +722,11 @@ def _has_same_source(cluster, extra_id): def distance(a, b): return segment_iou(a, b) + def label_matcher(self, a, b): + a_label = self._context._get_any_label_name(a, a.label) + b_label = self._context._get_any_label_name(b, b.label) + return a_label == b_label + @attrs class BboxMatcher(_ShapeMatcher): pass @@ -626,8 +780,6 @@ def match_annotations(self, sources): @attrs(kw_only=True) class AnnotationMerger: - _context = attrib(type=IntersectMerge, default=None) - def merge_clusters(self, clusters): raise NotImplementedError() @@ -641,20 +793,22 @@ def merge_clusters(self, clusters): return [] votes = {} # label -> score - for label_ann in clusters[0]: - votes[label_ann.label] = 1 + votes.get(label_ann.label, 0) + for ann in clusters[0]: + label = self._context._get_src_label_name(ann, ann.label) + votes[label] = 1 + votes.get(label, 0) merged = [] for label, count in votes.items(): if count < self.quorum: sources = set(self.get_ann_source(id(a)) for a in clusters[0] - if label not in [l.label for l in a]) + if label not in [self._context._get_src_label_name(l, l.label) + for l in a]) sources = [self._context._dataset_map[s][1] for s in sources] self._context.add_item_error(FailedLabelVotingError, sources, votes) continue - merged.append(Label(label, attributes={ + merged.append(Label(self._context._get_label_id(label), attributes={ 'score': count / len(self._context._dataset_map) })) @@ -682,14 +836,17 @@ def merge_clusters(self, clusters): def find_cluster_label(self, cluster): votes = {} for s in cluster: - state = votes.setdefault(s.label, [0, 0]) + label = self._context._get_src_label_name(s, s.label) + state = votes.setdefault(label, [0, 0]) state[0] += s.attributes.get('score', 1.0) state[1] += 1 label, (score, count) = max(votes.items(), key=lambda e: e[1][0]) if count < self.quorum: self._context.add_item_error(FailedLabelVotingError, votes) - score = score / count if count else None + label = None + score = score / len(self._context._dataset_map) + label = self._context._get_label_id(label) return label, score @staticmethod @@ -729,11 +886,10 @@ class LineMerger(_ShapeMerger, LineMatcher): class CaptionsMerger(AnnotationMerger, CaptionsMatcher): pass -def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0): - if distance == 'iou': - distance = segment_iou - else: - assert callable(distance) +def match_segments(a_segms, b_segms, distance=segment_iou, dist_thresh=1.0, + label_matcher=lambda a, b: a.label == b.label): + assert callable(distance), distance + assert callable(label_matcher), label_matcher a_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1)) b_segms.sort(key=lambda ann: 1 - ann.attributes.get('score', 1)) @@ -753,13 +909,16 @@ def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0): for a_idx, a_segm in enumerate(a_segms): if len(b_segms) == 0: break - matched_b = a_matches[a_idx] - max_dist = max(distances[a_idx, matched_b], dist_thresh) - for b_idx, b_segm in enumerate(b_segms): + matched_b = -1 + max_dist = -1 + b_indices = np.argsort([not label_matcher(a_segm, b_segm) + for b_segm in b_segms], + kind='stable') # prioritize those with same label, keep score order + for b_idx in b_indices: if 0 <= b_matches[b_idx]: # assign a_segm with max conf continue d = distances[a_idx, b_idx] - if d < max_dist: + if d < dist_thresh or d <= max_dist: continue max_dist = d matched_b = b_idx @@ -771,7 +930,7 @@ def match_segments(a_segms, b_segms, distance='iou', dist_thresh=1.0): b_segm = b_segms[matched_b] - if a_segm.label == b_segm.label: + if label_matcher(a_segm, b_segm): matches.append( (a_segm, b_segm) ) else: mispred.append( (a_segm, b_segm) ) diff --git a/datumaro/datumaro/util/annotation_util.py b/datumaro/datumaro/util/annotation_util.py index 63950a14343..3daa313f3fb 100644 --- a/datumaro/datumaro/util/annotation_util.py +++ b/datumaro/datumaro/util/annotation_util.py @@ -118,7 +118,7 @@ def _to_rle(ann): if ann.type == AnnotationType.polygon: return mask_utils.frPyObjects([ann.points], h, w) elif isinstance(ann, RleMask): - return [ann._rle] + return [ann.rle] elif ann.type == AnnotationType.mask: return mask_utils.frPyObjects([mask_to_rle(ann.image)], h, w) else: diff --git a/datumaro/requirements.txt b/datumaro/requirements.txt index b5142853b21..6bc3c7ee799 100644 --- a/datumaro/requirements.txt +++ b/datumaro/requirements.txt @@ -7,6 +7,6 @@ matplotlib>=3.3.1 opencv-python-headless>=4.1.0.25 Pillow>=6.1.0 pycocotools>=2.0.0 -PyYAML>=5.1.1 +PyYAML>=5.3.1 scikit-image>=0.15.0 tensorboardX>=1.8 diff --git a/datumaro/tests/test_ops.py b/datumaro/tests/test_ops.py index dd4520b52f6..5b7355bf79c 100644 --- a/datumaro/tests/test_ops.py +++ b/datumaro/tests/test_ops.py @@ -3,7 +3,8 @@ import numpy as np from datumaro.components.extractor import (Bbox, Caption, DatasetItem, - Extractor, Label, Mask, Points, Polygon, PolyLine) + Extractor, Label, Mask, Points, Polygon, PolyLine, + LabelCategories, PointsCategories, MaskCategories, AnnotationType) from datumaro.components.operations import (FailedAttrVotingError, IntersectMerge, NoMatchingAnnError, NoMatchingItemError, WrongGroupError, compute_ann_statistics, mean_std) @@ -198,7 +199,7 @@ def test_can_match_shapes(self): Bbox(1, 2, 3, 4, label=1), # common - Mask(label=3, z_order=2, image=np.array([ + Mask(label=2, z_order=2, image=np.array([ [0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 0], @@ -218,7 +219,7 @@ def test_can_match_shapes(self): source1 = Dataset.from_iterable([ DatasetItem(1, annotations=[ # common - Mask(label=3, image=np.array([ + Mask(label=2, image=np.array([ [0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 1, 1], @@ -238,7 +239,7 @@ def test_can_match_shapes(self): source2 = Dataset.from_iterable([ DatasetItem(1, annotations=[ # common - Mask(label=3, z_order=3, image=np.array([ + Mask(label=2, z_order=3, image=np.array([ [0, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1], @@ -261,7 +262,7 @@ def test_can_match_shapes(self): # common # nearest to mean bbox - Mask(label=3, z_order=3, image=np.array([ + Mask(label=2, z_order=3, image=np.array([ [0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 1, 1], @@ -365,3 +366,86 @@ def test_group_checks(self): self.assertEqual(3, len([e for e in merger.errors if isinstance(e, WrongGroupError)]), merger.errors ) + + def test_can_merge_classes(self): + source0 = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + Label(0), + Label(1), + Bbox(0, 0, 1, 1, label=1), + ]), + ], categories=['a', 'b']) + + source1 = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + Label(0), + Label(1), + Bbox(0, 0, 1, 1, label=0), + Bbox(0, 0, 1, 1, label=1), + ]), + ], categories=['b', 'c']) + + expected = Dataset.from_iterable([ + DatasetItem(1, annotations=[ + Label(0), + Label(1), + Label(2), + Bbox(0, 0, 1, 1, label=1), + Bbox(0, 0, 1, 1, label=2), + ]), + ], categories=['a', 'b', 'c']) + + merger = IntersectMerge() + merged = merger([source0, source1]) + + compare_datasets(self, expected, merged, ignored_attrs={'score'}) + + def test_can_merge_categories(self): + source0 = Dataset.from_iterable([ + DatasetItem(1, annotations=[ Label(0), ]), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable(['a', 'b']), + AnnotationType.points: PointsCategories.from_iterable([ + (0, ['l0', 'l1']), + (1, ['l2', 'l3']), + ]), + AnnotationType.mask: MaskCategories({ + 0: (0, 1, 2), + 1: (1, 2, 3), + }), + }) + + source1 = Dataset.from_iterable([ + DatasetItem(1, annotations=[ Label(0), ]), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable(['c', 'b']), + AnnotationType.points: PointsCategories.from_iterable([ + (0, []), + (1, ['l2', 'l3']), + ]), + AnnotationType.mask: MaskCategories({ + 0: (0, 2, 4), + 1: (1, 2, 3), + }), + }) + + expected = Dataset.from_iterable([ + DatasetItem(1, annotations=[ Label(0), Label(2), ]), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable(['a', 'b', 'c']), + AnnotationType.points: PointsCategories.from_iterable([ + (0, ['l0', 'l1']), + (1, ['l2', 'l3']), + (2, []), + ]), + AnnotationType.mask: MaskCategories({ + 0: (0, 1, 2), + 1: (1, 2, 3), + 2: (0, 2, 4), + }), + }) + + merger = IntersectMerge() + merged = merger([source0, source1]) + + compare_datasets(self, expected, merged, ignored_attrs={'score'}) \ No newline at end of file