Skip to content

Commit

Permalink
[Datumaro] Merge with different categories (#2098)
Browse files Browse the repository at this point in the history
* Add category merging

* Update error message

* Add category merging test

* update changelog

* Fix field access

* remove import

* Update CHANGELOG.md

Co-authored-by: Nikita Manovich <[email protected]>
  • Loading branch information
zhiltsov-max and Nikita Manovich authored Sep 4, 2020
1 parent 4dbfa3b commit ffb71fb
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 57 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Ability to work with data on the fly (https://github.com/opencv/cvat/pull/2007)
- Annotation in process outline color wheel (<https://github.com/opencv/cvat/pull/2084>)
- [Datumaro] CLI command for dataset equality comparison (<https://github.com/opencv/cvat/pull/1989>)
- [Datumaro] Merging of datasets with different labels (<https://github.com/opencv/cvat/pull/2098>)

### Changed
- UI models (like DEXTR) were redesigned to be more interactive (<https://github.com/opencv/cvat/pull/2054>)
Expand Down
33 changes: 11 additions & 22 deletions datumaro/datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ class Categories:

@attrs
class LabelCategories(Categories):
Category = namedtuple('Category', ['name', 'parent', 'attributes'])
@attrs(repr_ns='LabelCategories')
class Category:
name = attrib(converter=str, validator=not_empty)
parent = attrib(default='', validator=default_if_none(str))
attributes = attrib(factory=set, validator=default_if_none(set))

items = attrib(factory=list, validator=default_if_none(list))
_indices = attrib(factory=dict, init=False, eq=False)
Expand Down Expand Up @@ -93,15 +97,6 @@ def _reindex(self):

def add(self, name: str, parent: str = None, attributes: dict = None):
assert name not in self._indices, name
if attributes is None:
attributes = set()
else:
if not isinstance(attributes, set):
attributes = set(attributes)
for attr in attributes:
assert isinstance(attr, str)
if parent is None:
parent = ''

index = len(self.items)
self.items.append(self.Category(name, parent, attributes))
Expand Down Expand Up @@ -386,7 +381,10 @@ def wrap(item, **kwargs):

@attrs
class PointsCategories(Categories):
Category = namedtuple('Category', ['labels', 'joints'])
@attrs(repr_ns="PointsCategories")
class Category:
labels = attrib(factory=list, validator=default_if_none(list))
joints = attrib(factory=set, validator=default_if_none(set))

items = attrib(factory=dict, validator=default_if_none(dict))

Expand All @@ -396,28 +394,19 @@ def from_iterable(cls, iterable):
Args:
iterable ([type]): This iterable object can be:
1)simple int - will generate one Category with int as label
2)list of int - will interpreted as list of Category labels
3)list of positional argumetns - will generate Categories
with this arguments
1) list of positional argumetns - will generate Categories
with these arguments
Returns:
PointsCategories: PointsCategories object
"""
temp_categories = cls()

if isinstance(iterable, int):
iterable = [[iterable]]

for category in iterable:
if isinstance(category, int):
category = [category]
temp_categories.add(*category)
return temp_categories

def add(self, label_id, labels=None, joints=None):
if labels is None:
labels = []
if joints is None:
joints = []
joints = set(map(tuple, joints))
Expand Down
Loading

0 comments on commit ffb71fb

Please sign in to comment.