Skip to content

Commit

Permalink
[Datumaro] Add DatasetItem attributes (#1639)
Browse files Browse the repository at this point in the history
* Add DatasetItem attributes

* Update tests

* Update datumaro format
  • Loading branch information
zhiltsov-max authored Jun 8, 2020
1 parent b2503c6 commit ba309c8
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 37 deletions.
18 changes: 15 additions & 3 deletions datumaro/datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def __eq__(self, other):
class DatasetItem:
# pylint: disable=redefined-builtin
def __init__(self, id=None, annotations=None,
subset=None, path=None, image=None):
subset=None, path=None, image=None, attributes=None):
assert id is not None
self._id = str(id)

Expand Down Expand Up @@ -604,6 +604,12 @@ def __init__(self, id=None, annotations=None,
image = Image(path=image)
assert image is None or isinstance(image, Image)
self._image = image

if attributes is None:
attributes = {}
else:
attributes = dict(attributes)
self._attributes = attributes
# pylint: enable=redefined-builtin

@property
Expand All @@ -630,6 +636,10 @@ def image(self):
def has_image(self):
return self._image is not None

@property
def attributes(self):
return self._attributes

def __eq__(self, other):
if not isinstance(other, __class__):
return False
Expand All @@ -638,10 +648,12 @@ def __eq__(self, other):
(self.subset == other.subset) and \
(self.path == other.path) and \
(self.annotations == other.annotations) and \
(self.image == other.image)
(self.image == other.image) and \
(self.attributes == other.attributes)

def wrap(item, **kwargs):
expected_args = {'id', 'annotations', 'subset', 'path', 'image'}
expected_args = {'id', 'annotations', 'subset',
'path', 'image', 'attributes'}
for k in expected_args:
if k not in kwargs:
kwargs[k] = getattr(item, k)
Expand Down
2 changes: 2 additions & 0 deletions datumaro/datumaro/plugins/datumaro_format/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def write_item(self, item):
'id': item.id,
'annotations': annotations,
}
if item.attributes:
item_desc['attr'] = item.attributes
if item.path:
item_desc['path'] = item.path
if item.has_image:
Expand Down
6 changes: 4 additions & 2 deletions datumaro/datumaro/plugins/datumaro_format/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,15 @@ def _load_items(self, parsed):
annotations = self._load_annotations(item_desc)

item = DatasetItem(id=item_id, subset=self._subset,
annotations=annotations, image=image)
annotations=annotations, image=image,
attributes=item_desc.get('attr'))

items.append(item)

return items

def _load_annotations(self, item):
@staticmethod
def _load_annotations(item):
parsed = item['annotations']
loaded = []

Expand Down
19 changes: 18 additions & 1 deletion datumaro/datumaro/util/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,21 @@ def compare_datasets(test, expected, actual):

ann_b = find(ann_b_matches, lambda x: x == ann_a)
test.assertEqual(ann_a, ann_b, 'ann: %s' % ann_to_str(ann_a))
item_b.annotations.remove(ann_b) # avoid repeats
item_b.annotations.remove(ann_b) # avoid repeats

def compare_datasets_strict(test, expected, actual):
# Compares datasets for strong equality

test.assertEqual(expected.categories(), actual.categories())

test.assertListEqual(sorted(expected.subsets()), sorted(actual.subsets()))
test.assertEqual(len(expected), len(actual))

for subset_name in expected.subsets():
e_subset = expected.get_subset(subset_name)
a_subset = actual.get_subset(subset_name)
test.assertEqual(len(e_subset), len(a_subset))
for idx, (item_a, item_b) in enumerate(zip(e_subset, a_subset)):
test.assertEqual(item_a, item_b,
'%s:\n%s\nvs.\n%s\n' % \
(idx, item_to_str(item_a), item_to_str(item_b)))
52 changes: 21 additions & 31 deletions datumaro/tests/test_datumaro_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,24 @@
from datumaro.plugins.datumaro_format.converter import DatumaroConverter
from datumaro.util.mask_tools import generate_colormap
from datumaro.util.image import Image
from datumaro.util.test_utils import TestDir, item_to_str

from datumaro.util.test_utils import TestDir, compare_datasets_strict

class DatumaroConverterTest(TestCase):
def _test_save_and_load(self, source_dataset, converter, test_dir,
target_dataset=None, importer_args=None):
converter(source_dataset, test_dir)

if importer_args is None:
importer_args = {}
parsed_dataset = Project.import_from(
test_dir, 'datumaro', **importer_args).make_dataset()

if target_dataset is None:
target_dataset = source_dataset

compare_datasets_strict(self,
expected=target_dataset, actual=parsed_dataset)

class TestExtractor(Extractor):
def __iter__(self):
return iter([
Expand Down Expand Up @@ -47,7 +61,8 @@ def __iter__(self):
Polygon([1, 2, 3, 4, 5, 6, 7, 8], id=12, z_order=4),
]),

DatasetItem(id=42, subset='test'),
DatasetItem(id=42, subset='test',
attributes={'a1': 5, 'a2': '42'}),

DatasetItem(id=42),
DatasetItem(id=43, image=Image(path='1/b/c.qq', size=(2, 4))),
Expand All @@ -73,36 +88,11 @@ def categories(self):

def test_can_save_and_load(self):
with TestDir() as test_dir:
source_dataset = self.TestExtractor()

converter = DatumaroConverter(save_images=True)
converter(source_dataset, test_dir)

project = Project.import_from(test_dir, 'datumaro')
parsed_dataset = project.make_dataset()

self.assertListEqual(
sorted(source_dataset.subsets()),
sorted(parsed_dataset.subsets()),
)

self.assertEqual(len(source_dataset), len(parsed_dataset))

for subset_name in source_dataset.subsets():
source_subset = source_dataset.get_subset(subset_name)
parsed_subset = parsed_dataset.get_subset(subset_name)
self.assertEqual(len(source_subset), len(parsed_subset))
for idx, (item_a, item_b) in enumerate(
zip(source_subset, parsed_subset)):
self.assertEqual(item_a, item_b, '%s:\n%s\nvs.\n%s\n' % \
(idx, item_to_str(item_a), item_to_str(item_b)))

self.assertEqual(
source_dataset.categories(),
parsed_dataset.categories())
self._test_save_and_load(self.TestExtractor(),
DatumaroConverter(save_images=True), test_dir)

def test_can_detect(self):
with TestDir() as test_dir:
DatumaroConverter()(self.TestExtractor(), save_dir=test_dir)

self.assertTrue(DatumaroImporter.detect(test_dir))
self.assertTrue(DatumaroImporter.detect(test_dir))

0 comments on commit ba309c8

Please sign in to comment.