From d1724974558c303406d4e7f0d15554702fd92500 Mon Sep 17 00:00:00 2001
From: yasakova-anastasia
Date: Fri, 30 Oct 2020 13:34:52 +0300
Subject: [PATCH] Add ImageNet format

---
Review notes (text in this area is ignored by `git am`):
- _import: the ZipFile is now opened in a `with` block so that its file
  handle is closed as soon as extraction finishes.
- _export/_import: the converter/importer name is selected first, which
  removes the duplicated call in each branch (behaviour unchanged).
- imagenet.py now ends with a newline; a doubled blank line added in the
  test_rest_api.py hunk was reduced to one.
- Line breaks and context-line indentation were reconstructed from a
  whitespace-collapsed copy of this patch. If a hunk does not match the
  tree, verify the context lines or apply with
  `git apply --ignore-whitespace`. The post-image `index` hashes and the
  commit id in the `From` line still refer to the original commit.

 cvat/apps/dataset_manager/formats/imagenet.py | 46 +++++++++++++++++++
 cvat/apps/dataset_manager/formats/registry.py |  3 +-
 .../dataset_manager/tests/test_formats.py     |  3 ++
 cvat/apps/engine/tests/test_rest_api.py       |  3 ++
 4 files changed, 54 insertions(+), 1 deletion(-)
 create mode 100644 cvat/apps/dataset_manager/formats/imagenet.py

diff --git a/cvat/apps/dataset_manager/formats/imagenet.py b/cvat/apps/dataset_manager/formats/imagenet.py
new file mode 100644
index 000000000000..d9847549f9e1
--- /dev/null
+++ b/cvat/apps/dataset_manager/formats/imagenet.py
@@ -0,0 +1,46 @@
+# Copyright (C) 2020 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import os.path as osp
+from glob import glob
+
+import zipfile
+from tempfile import TemporaryDirectory
+
+from datumaro.components.project import Dataset
+from cvat.apps.dataset_manager.bindings import CvatTaskDataExtractor, \
+    import_dm_annotations
+from cvat.apps.dataset_manager.util import make_zip_archive
+
+from .registry import dm_env, exporter, importer
+
+
+# Write the task annotations to dst_file as a ZIP archive. The 'imagenet'
+# converter is used when images are saved, 'imagenet_txt' otherwise.
+@exporter(name='ImageNet', ext='ZIP', version='1.0')
+def _export(dst_file, task_data, save_images=False):
+    extractor = CvatTaskDataExtractor(task_data, include_images=save_images)
+    extractor = Dataset.from_extractors(extractor) # apply lazy transform
+    converter_name = 'imagenet' if save_images else 'imagenet_txt'
+    with TemporaryDirectory() as temp_dir:
+        dm_env.converters.get(converter_name).convert(extractor,
+            save_dir=temp_dir, save_images=save_images)
+
+        make_zip_archive(temp_dir, dst_file)
+
+
+# Unpack the src_file ZIP archive and import its annotations into task_data.
+@importer(name='ImageNet', ext='ZIP', version='1.0')
+def _import(src_file, task_data):
+    with TemporaryDirectory() as tmp_dir:
+        # Close the archive handle as soon as extraction is done.
+        with zipfile.ZipFile(src_file) as archive:
+            archive.extractall(tmp_dir)
+
+        # A *.txt file at the archive root selects the 'imagenet_txt'
+        # importer; otherwise the 'imagenet' importer is used.
+        txt_files = glob(osp.join(tmp_dir, '*.txt'))
+        importer_name = 'imagenet_txt' if txt_files else 'imagenet'
+        dataset = dm_env.make_importer(importer_name)(tmp_dir).make_dataset()
+        import_dm_annotations(dataset, task_data)
diff --git a/cvat/apps/dataset_manager/formats/registry.py b/cvat/apps/dataset_manager/formats/registry.py
index c84d60fc603e..c175a42b7728 100644
--- a/cvat/apps/dataset_manager/formats/registry.py
+++ b/cvat/apps/dataset_manager/formats/registry.py
@@ -90,4 +90,5 @@ def make_exporter(name):
 import cvat.apps.dataset_manager.formats.mots
 import cvat.apps.dataset_manager.formats.pascal_voc
 import cvat.apps.dataset_manager.formats.tfrecord
-import cvat.apps.dataset_manager.formats.yolo
\ No newline at end of file
+import cvat.apps.dataset_manager.formats.yolo
+import cvat.apps.dataset_manager.formats.imagenet
\ No newline at end of file
diff --git a/cvat/apps/dataset_manager/tests/test_formats.py b/cvat/apps/dataset_manager/tests/test_formats.py
index 1eb3e2b56da9..07640a24b3d7 100644
--- a/cvat/apps/dataset_manager/tests/test_formats.py
+++ b/cvat/apps/dataset_manager/tests/test_formats.py
@@ -271,6 +271,7 @@ def test_export_formats_query(self):
             'Segmentation mask 1.1',
             'TFRecord 1.0',
             'YOLO 1.1',
+            'ImageNet 1.0',
         })
 
     def test_import_formats_query(self):
@@ -287,6 +288,7 @@ def test_import_formats_query(self):
             'Segmentation mask 1.1',
             'TFRecord 1.0',
             'YOLO 1.1',
+            'ImageNet 1.0',
         })
 
     def test_exports(self):
@@ -322,6 +324,7 @@ def test_empty_images_are_exported(self):
             ('Segmentation mask 1.1', 'voc'),
             ('TFRecord 1.0', 'tf_detection_api'),
             ('YOLO 1.1', 'yolo'),
+            ('ImageNet 1.0', 'imagenet_txt'),
         ]:
             with self.subTest(format=format_name):
                 if not dm.formats.registry.EXPORT_FORMATS[format_name].ENABLED:
diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py
index b7448fff7749..93399818ecb8 100644
--- a/cvat/apps/engine/tests/test_rest_api.py
+++ b/cvat/apps/engine/tests/test_rest_api.py
@@ -3355,6 +3355,9 @@ def _get_initial_annotation(annotation_format):
                                       + polygon_shapes_with_attrs
                 annotations["tags"] = tags_with_attrs + tags_wo_attrs
 
+            elif annotation_format == "ImageNet 1.0":
+                annotations["tags"] = tags_wo_attrs
+
             else:
                 raise Exception("Unknown format {}".format(annotation_format))
 