diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6946398f34..b48c33b7d7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 
 ### Added
--
+- `ByteImage` class to represent encoded images in memory and avoid recoding on save ()
 
 ### Changed
 - Implementation of format plugins simplified ()
@@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 -
 
 ### Removed
--
+- `image/depth` value from VOC export ()
 
 ### Fixed
 - Zero division errors in dataset statistics ()
diff --git a/datumaro/components/converter.py b/datumaro/components/converter.py
index 12224489a3..086e5d0d02 100644
--- a/datumaro/components/converter.py
+++ b/datumaro/components/converter.py
@@ -9,7 +9,7 @@
 import shutil
 
 from datumaro.components.cli_plugin import CliPlugin
-from datumaro.util.image import save_image
+from datumaro.util.image import save_image, ByteImage
 
 
 class Converter(CliPlugin):
@@ -49,7 +49,7 @@ def __init__(self, extractor, save_dir, save_images=False,
     def _find_image_ext(self, item):
         src_ext = None
         if item.has_image:
-            src_ext = osp.splitext(osp.basename(item.image.path))[1]
+            src_ext = item.image.ext
 
         return self._image_ext or src_ext or self._default_image_ext
 
@@ -57,18 +57,20 @@ def _make_image_filename(self, item):
         return item.id + self._find_image_ext(item)
 
     def _save_image(self, item, path=None):
-        image = item.image.data
-        if image is None:
+        if not item.image.has_data:
             log.warning("Item '%s' has no image", item.id)
             return
 
         path = path or self._make_image_filename(item)
-        src_ext = osp.splitext(osp.basename(item.image.path))[1]
-        dst_ext = osp.splitext(osp.basename(path))[1]
+        src_ext = item.image.ext.lower()
+        dst_ext = osp.splitext(osp.basename(path))[1].lower()
 
         os.makedirs(osp.dirname(path), exist_ok=True)
         if src_ext == dst_ext and osp.isfile(item.image.path):
             shutil.copyfile(item.image.path, path)
+        elif src_ext == dst_ext and isinstance(item.image, ByteImage):
+            with open(path, 'wb') as f:
+                f.write(item.image.get_bytes())
         else:
-            save_image(path, image)
+            save_image(path, item.image.data)
 
diff --git a/datumaro/plugins/tf_detection_api_format/converter.py b/datumaro/plugins/tf_detection_api_format/converter.py
index 1c91dff20c..cdc1e5fe60 100644
--- a/datumaro/plugins/tf_detection_api_format/converter.py
+++ b/datumaro/plugins/tf_detection_api_format/converter.py
@@ -15,7 +15,7 @@
     LabelCategories
 )
 from datumaro.components.converter import Converter
-from datumaro.util.image import encode_image
+from datumaro.util.image import encode_image, ByteImage
 from datumaro.util.annotation_util import (max_bbox, find_group_leader,
     find_instances)
 from datumaro.util.mask_tools import merge_masks
@@ -197,11 +197,16 @@ def _make_tf_example(self, item):
         return tf_example
 
     def _save_image(self, item, path=None):
-        dst_ext = osp.splitext(osp.basename(path))[1]
+        src_ext = item.image.ext.lower()
+        dst_ext = osp.splitext(osp.basename(path))[1].lower()
         fmt = DetectionApiPath.IMAGE_EXT_FORMAT.get(dst_ext)
         if not fmt:
             log.warning("Item '%s': can't find format string for the '%s' "
                 "image extension, the corresponding field will be empty." % \
                 (item.id, dst_ext))
-        buffer = encode_image(item.image.data, dst_ext)
+
+        if src_ext == dst_ext and isinstance(item.image, ByteImage):
+            buffer = item.image.get_bytes()
+        else:
+            buffer = encode_image(item.image.data, dst_ext)
         return buffer, fmt
\ No newline at end of file
diff --git a/datumaro/plugins/tf_detection_api_format/extractor.py b/datumaro/plugins/tf_detection_api_format/extractor.py
index b37037bb52..75d560453b 100644
--- a/datumaro/plugins/tf_detection_api_format/extractor.py
+++ b/datumaro/plugins/tf_detection_api_format/extractor.py
@@ -11,7 +11,7 @@
 from datumaro.components.extractor import (SourceExtractor, DatasetItem,
     AnnotationType, Bbox, Mask, LabelCategories, Importer
 )
-from datumaro.util.image import Image, decode_image, lazy_image
+from datumaro.util.image import ByteImage, decode_image, lazy_image
 from datumaro.util.tf_util import import_tf as _import_tf
 
 from .format import DetectionApiPath
@@ -167,13 +167,13 @@ def _parse_tfrecord_file(cls, filepath, subset, images_dir):
 
             image_params = {}
             if frame_image:
-                image_params['data'] = lazy_image(frame_image, decode_image)
+                image_params['data'] = frame_image
             if frame_filename:
                 image_params['path'] = osp.join(images_dir, frame_filename)
 
             image = None
             if image_params:
-                image = Image(**image_params, size=image_size)
+                image = ByteImage(**image_params, size=image_size)
 
             dataset_items.append(DatasetItem(id=item_id, subset=subset,
                 image=image, annotations=annotations,
diff --git a/datumaro/plugins/voc_format/converter.py b/datumaro/plugins/voc_format/converter.py
index 4d0a05d2de..bcd0c91454 100644
--- a/datumaro/plugins/voc_format/converter.py
+++ b/datumaro/plugins/voc_format/converter.py
@@ -198,15 +198,10 @@ def save_subsets(self):
                 if item.has_image:
                     h, w = item.image.size
-                    if item.image.has_data:
-                        image_shape = item.image.data.shape
-                        c = 1 if len(image_shape) == 2 else image_shape[2]
-                    else:
-                        c = 3
                     size_elem = ET.SubElement(root_elem, 'size')
                     ET.SubElement(size_elem, 'width').text = str(w)
                     ET.SubElement(size_elem, 'height').text = str(h)
-                    ET.SubElement(size_elem, 'depth').text = str(c)
+                    ET.SubElement(size_elem, 'depth').text = ''
 
                 item_segmented = 0 < len(masks)
                 ET.SubElement(root_elem, 'segmented').text = \
                     '%d' % item_segmented
@@ -337,17 +332,17 @@ def save_subsets(self):
                     action_list[item.id] = None
                     segm_list[item.id] = None
 
-                if self._tasks & {VocTask.classification, VocTask.detection,
-                        VocTask.action_classification, VocTask.person_layout}:
-                    self.save_clsdet_lists(subset_name, clsdet_list)
-                if self._tasks & {VocTask.classification}:
-                    self.save_class_lists(subset_name, class_lists)
-                if self._tasks & {VocTask.action_classification}:
-                    self.save_action_lists(subset_name, action_list)
-                if self._tasks & {VocTask.person_layout}:
-                    self.save_layout_lists(subset_name, layout_list)
-                if self._tasks & {VocTask.segmentation}:
-                    self.save_segm_lists(subset_name, segm_list)
+            if self._tasks & {VocTask.classification, VocTask.detection,
+                    VocTask.action_classification, VocTask.person_layout}:
+                self.save_clsdet_lists(subset_name, clsdet_list)
+            if self._tasks & {VocTask.classification}:
+                self.save_class_lists(subset_name, class_lists)
+            if self._tasks & {VocTask.action_classification}:
+                self.save_action_lists(subset_name, action_list)
+            if self._tasks & {VocTask.person_layout}:
+                self.save_layout_lists(subset_name, layout_list)
+            if self._tasks & {VocTask.segmentation}:
+                self.save_segm_lists(subset_name, segm_list)
 
     def save_action_lists(self, subset_name, action_list):
         if not action_list:
diff --git a/datumaro/util/image.py b/datumaro/util/image.py
index 625424be5b..c653adf687 100644
--- a/datumaro/util/image.py
+++ b/datumaro/util/image.py
@@ -219,6 +219,10 @@ def __init__(self, data=None, path=None, loader=None, cache=None,
     def path(self):
         return self._path
 
+    @property
+    def ext(self):
+        return osp.splitext(osp.basename(self.path))[1]
+
     @property
     def data(self):
         if callable(self._data):
@@ -247,4 +251,45 @@ def __eq__(self, other):
             (np.array_equal(self.size, other.size)) and \
             (self.has_data == other.has_data) and \
             (self.has_data and np.array_equal(self.data, other.data) or \
-            not self.has_data)
\ No newline at end of file
+            not self.has_data)
+
+class ByteImage(Image):
+    def __init__(self, data=None, path=None, ext=None, cache=None, size=None):
+        loader = None
+        if data is not None:
+            if callable(data) and not isinstance(data, lazy_image):
+                data = lazy_image(path, loader=data, cache=cache)
+            loader = lambda _: decode_image(self.get_bytes())
+
+        super().__init__(path=path, size=size, loader=loader, cache=cache)
+        if data is None and loader is None:
+            # unset defaults for regular images
+            # to avoid random file reading to bytes
+            self._data = None
+
+        self._bytes_data = data
+        if ext:
+            ext = ext.lower()
+            if not ext.startswith('.'):
+                ext = '.' + ext
+        self._ext = ext
+
+    def get_bytes(self):
+        if callable(self._bytes_data):
+            return self._bytes_data()
+        return self._bytes_data
+
+    @property
+    def ext(self):
+        if self._ext:
+            return self._ext
+        return super().ext
+
+    def __eq__(self, other):
+        if not isinstance(other, __class__):
+            return super().__eq__(other)
+        return \
+            (np.array_equal(self.size, other.size)) and \
+            (self.has_data == other.has_data) and \
+            (self.has_data and self.get_bytes() == other.get_bytes() or \
+            not self.has_data)
\ No newline at end of file
diff --git a/datumaro/util/test_utils.py b/datumaro/util/test_utils.py
index f0ee641ea7..adad8f4916 100644
--- a/datumaro/util/test_utils.py
+++ b/datumaro/util/test_utils.py
@@ -81,7 +81,8 @@ def _compare_annotations(expected, actual, ignored_attrs=None):
     actual.attributes = b_attr
     return r
 
-def compare_datasets(test, expected, actual, ignored_attrs=None):
+def compare_datasets(test, expected, actual, ignored_attrs=None,
+        require_images=False):
     compare_categories(test, expected.categories(), actual.categories())
 
     test.assertEqual(sorted(expected.subsets()), sorted(actual.subsets()))
@@ -91,6 +92,10 @@ def compare_datasets(test, expected, actual, ignored_attrs=None):
             x.subset == item_a.subset)
         test.assertFalse(item_b is None, item_a.id)
         test.assertEqual(item_a.attributes, item_b.attributes)
+        if require_images or \
+                item_a.has_image and item_a.image.has_data and \
+                item_b.has_image and item_b.image.has_data:
+            test.assertEqual(item_a.image, item_b.image, item_a.id)
         test.assertEqual(len(item_a.annotations), len(item_b.annotations))
         for ann_a in item_a.annotations:
             # We might find few corresponding items, so check them all
diff --git a/tests/test_cvat_format.py b/tests/test_cvat_format.py
index 7a97320656..fb8aea5513 100644
--- a/tests/test_cvat_format.py
+++ b/tests/test_cvat_format.py
@@ -63,7 +63,7 @@ def test_can_load_image(self):
    def test_can_load_video(self):
        expected_dataset = Dataset.from_iterable([
            DatasetItem(id='frame_000010', subset='annotations',
-                image=np.ones((20, 25, 3)),
+                image=255 * np.ones((20, 25, 3)),
                annotations=[
                    Bbox(3, 4, 7, 1, label=2, id=0,
                        attributes={
@@ -82,7 +82,7 @@ def test_can_load_video(self):
                        }),
                ], attributes={'frame': 10}),
            DatasetItem(id='frame_000013', subset='annotations',
-                image=np.ones((20, 25, 3)),
+                image=255 * np.ones((20, 25, 3)),
                annotations=[
                    Bbox(7, 6, 7, 2, label=2, id=0,
                        attributes={
diff --git a/tests/test_images.py b/tests/test_images.py
index c8ae3274e9..a003b8d426 100644
--- a/tests/test_images.py
+++ b/tests/test_images.py
@@ -4,7 +4,8 @@
 from unittest import TestCase
 
 from datumaro.util.test_utils import TestDir
-from datumaro.util.image import lazy_image, load_image, save_image, Image
+from datumaro.util.image import (lazy_image, load_image, save_image, \
+    Image, ByteImage, encode_image)
 from datumaro.util.image_cache import ImageCache
 
 
@@ -47,7 +48,7 @@ def test_global_cache_is_accessible(self):
 
 class ImageTest(TestCase):
     def test_lazy_image_shape(self):
-        data = np.ones((5, 6, 7))
+        data = np.ones((5, 6, 3))
 
         image_lazy = Image(data=data, size=(2, 4))
         image_eager = Image(data=data)
@@ -75,7 +76,47 @@ def test_ctors(self):
             with self.subTest(**args):
                 img = Image(**args)
                 # pylint: disable=pointless-statement
+                self.assertTrue(img.has_data)
+                self.assertEqual(img, image)
+                self.assertEqual(img.size, tuple(image.shape[:2]))
+                # pylint: enable=pointless-statement
+
+class BytesImageTest(TestCase):
+    def test_lazy_image_shape(self):
+        data = encode_image(np.ones((5, 6, 3)), 'png')
+
+        image_lazy = ByteImage(data=data, size=(2, 4))
+        image_eager = ByteImage(data=data)
+
+        self.assertEqual((2, 4), image_lazy.size)
+        self.assertEqual((5, 6), image_eager.size)
+
+    def test_ctors(self):
+        with TestDir() as test_dir:
+            path = osp.join(test_dir, 'path.png')
+            image = np.ones([2, 4, 3])
+            image_bytes = encode_image(image, 'png')
+
+            for args in [
+                { 'data': image_bytes },
+                { 'data': lambda _: image_bytes },
+                { 'data': lambda _: image_bytes, 'ext': '.jpg' },
+                { 'data': image_bytes, 'path': path },
+                { 'data': image_bytes, 'path': path, 'size': (2, 4) },
+                { 'data': image_bytes, 'path': path, 'size': (2, 4) },
+                { 'path': path },
+                { 'path': path, 'size': (2, 4) },
+            ]:
+                with self.subTest(**args):
+                    img = ByteImage(**args)
+                    # pylint: disable=pointless-statement
+                    self.assertEqual('data' in args, img.has_data)
                     if img.has_data:
-                        img.data
+                        self.assertEqual(img, image)
+                        self.assertEqual(img.get_bytes(), image_bytes)
                     img.size
+                    if 'size' in args:
+                        self.assertEqual(img.size, args['size'])
+                    if 'ext' in args or 'path' in args:
+                        self.assertEqual(img.ext, args.get('ext', '.png'))
                     # pylint: enable=pointless-statement
diff --git a/tests/test_voc_format.py b/tests/test_voc_format.py
index d7c9038578..b33aaa125f 100644
--- a/tests/test_voc_format.py
+++ b/tests/test_voc_format.py
@@ -82,7 +82,7 @@ class DstExtractor(TestExtractorBase):
            def __iter__(self):
                return iter([
                    DatasetItem(id='2007_000001', subset='train',
-                        image=Image(path='2007_000001.jpg', size=(20, 10)),
+                        image=Image(path='2007_000001.jpg', size=(10, 20)),
                        annotations=[
                            Label(self._label(l.name))
                            for l in VOC.VocLabel if l.value % 2 == 1
@@ -119,7 +119,7 @@ def __iter__(self):
                        ]
                    ),
                    DatasetItem(id='2007_000002', subset='test',
-                        image=np.zeros((20, 10, 3))),
+                        image=np.ones((10, 20, 3))),
                ])
 
        dataset = Project.import_from(DUMMY_DATASET_DIR, 'voc').make_dataset()
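
Usage note, not part of the patch itself: the sketch below shows how the `ByteImage` API introduced in `datumaro/util/image.py` above is intended to be used, assuming only what this diff adds (the `data`/`path`/`ext`/`size` constructor arguments, `get_bytes()`, the `ext` property, and lazy decoding through `data`). The file name and temporary directory are illustrative.

import os.path as osp
import tempfile

import numpy as np

from datumaro.util.image import ByteImage, encode_image

# Encode a small image to PNG bytes once, e.g. as it would arrive from a TFRecord.
pixels = np.zeros((4, 6, 3), dtype=np.uint8)
png_bytes = encode_image(pixels, 'png')

# ByteImage keeps the encoded bytes; pixels are decoded lazily only when .data is read.
image = ByteImage(data=png_bytes, ext='png', size=(4, 6))
assert image.ext == '.png'
assert image.get_bytes() == png_bytes   # raw bytes are returned as-is, no recoding
assert image.data.shape == (4, 6, 3)    # decoded on demand

# When source and destination extensions match, a converter can write the bytes
# verbatim instead of re-encoding, as Converter._save_image does in this patch.
with tempfile.TemporaryDirectory() as test_dir:
    dst_path = osp.join(test_dir, 'frame_000010.png')
    with open(dst_path, 'wb') as f:
        f.write(image.get_bytes())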