Skip to content

Commit

Permalink
Add ImageNet format
Browse files Browse the repository at this point in the history
  • Loading branch information
yasakova-anastasia committed Oct 30, 2020
1 parent 373d94c commit d172497
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 1 deletion.
41 changes: 41 additions & 0 deletions cvat/apps/dataset_manager/formats/imagenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os.path as osp
from glob import glob

import zipfile
from tempfile import TemporaryDirectory

from datumaro.components.project import Dataset
from cvat.apps.dataset_manager.bindings import CvatTaskDataExtractor, \
import_dm_annotations
from cvat.apps.dataset_manager.util import make_zip_archive

from .registry import dm_env, exporter, importer


@exporter(name='ImageNet', ext='ZIP', version='1.0')
def _export(dst_file, task_data, save_images=False):
extractor = CvatTaskDataExtractor(task_data, include_images=save_images)
extractor = Dataset.from_extractors(extractor) # apply lazy transform
with TemporaryDirectory() as temp_dir:
if save_images:
dm_env.converters.get('imagenet').convert(extractor,
save_dir=temp_dir, save_images=save_images)
else:
dm_env.converters.get('imagenet_txt').convert(extractor,
save_dir=temp_dir, save_images=save_images)

make_zip_archive(temp_dir, dst_file)

@importer(name='ImageNet', ext='ZIP', version='1.0')
def _import(src_file, task_data):
with TemporaryDirectory() as tmp_dir:
zipfile.ZipFile(src_file).extractall(tmp_dir)
if glob(osp.join(tmp_dir, '*.txt')):
dataset = dm_env.make_importer('imagenet_txt')(tmp_dir).make_dataset()
else:
dataset = dm_env.make_importer('imagenet')(tmp_dir).make_dataset()
import_dm_annotations(dataset, task_data)
3 changes: 2 additions & 1 deletion cvat/apps/dataset_manager/formats/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,5 @@ def make_exporter(name):
import cvat.apps.dataset_manager.formats.mots
import cvat.apps.dataset_manager.formats.pascal_voc
import cvat.apps.dataset_manager.formats.tfrecord
import cvat.apps.dataset_manager.formats.yolo
import cvat.apps.dataset_manager.formats.yolo
import cvat.apps.dataset_manager.formats.imagenet
3 changes: 3 additions & 0 deletions cvat/apps/dataset_manager/tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def test_export_formats_query(self):
'Segmentation mask 1.1',
'TFRecord 1.0',
'YOLO 1.1',
'ImageNet 1.0',
})

def test_import_formats_query(self):
Expand All @@ -287,6 +288,7 @@ def test_import_formats_query(self):
'Segmentation mask 1.1',
'TFRecord 1.0',
'YOLO 1.1',
'ImageNet 1.0',
})

def test_exports(self):
Expand Down Expand Up @@ -322,6 +324,7 @@ def test_empty_images_are_exported(self):
('Segmentation mask 1.1', 'voc'),
('TFRecord 1.0', 'tf_detection_api'),
('YOLO 1.1', 'yolo'),
('ImageNet 1.0', 'imagenet_txt'),
]:
with self.subTest(format=format_name):
if not dm.formats.registry.EXPORT_FORMATS[format_name].ENABLED:
Expand Down
4 changes: 4 additions & 0 deletions cvat/apps/engine/tests/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3355,6 +3355,10 @@ def _get_initial_annotation(annotation_format):
+ polygon_shapes_with_attrs
annotations["tags"] = tags_with_attrs + tags_wo_attrs

elif annotation_format == "ImageNet 1.0":
annotations["tags"] = tags_wo_attrs


else:
raise Exception("Unknown format {}".format(annotation_format))

Expand Down

0 comments on commit d172497

Please sign in to comment.