Skip to content

Commit

Permalink
Support for CIFAR-10/100 format (cvat-ai#225)
Browse files Browse the repository at this point in the history
* add CIFAR dataset format

* add CIFAR to documentation

* update Changelog
  • Loading branch information
yasakova-anastasia authored Apr 26, 2021
1 parent 00d167c commit b9469d9
Show file tree
Hide file tree
Showing 8 changed files with 335 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Support for escaping in attribiute values in LabelMe format (<https://github.com/openvinotoolkit/datumaro/issues/49>)
- Support for Segmentation Splitting (<https://github.com/openvinotoolkit/datumaro/pull/223>)
- Support for CIFAR-10/100 dataset format (<https://github.com/openvinotoolkit/datumaro/pull/225>)

### Changed
- LabelMe format saves dataset items with their relative paths by subsets without changing names (<https://github.com/openvinotoolkit/datumaro/pull/200>)
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ CVAT annotations ---> Publication, statistics etc.
- [MOT sequences](https://arxiv.org/pdf/1906.04567.pdf)
- [MOTS PNG](https://www.vision.rwth-aachen.de/page/mots)
- [ImageNet](http://image-net.org/)
- [CIFAR-10/100](https://www.cs.toronto.edu/~kriz/cifar.html) (`classification`)
- [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/)
- [CVAT](https://github.com/opencv/cvat/blob/develop/cvat/apps/documentation/xml_format.md)
- [LabelMe](http://labelme.csail.mit.edu/Release3.0)
Expand Down
181 changes: 181 additions & 0 deletions datumaro/plugins/cifar_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os.path as osp
import pickle

import numpy as np
from datumaro.components.converter import Converter
from datumaro.components.extractor import (AnnotationType, DatasetItem,
Importer, Label, LabelCategories, SourceExtractor)
from datumaro.util import cast


class CifarPath:
BATCHES_META = 'batches.meta'
TRAIN_ANNOTATION_FILE = 'data_batch_'
IMAGES_DIR = 'images'
IMAGE_SIZE = 32

CifarLabel = ['airplane', 'automobile', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Support for Python version CIFAR-10/100

class CifarExtractor(SourceExtractor):
def __init__(self, path, subset=None):
if not osp.isfile(path):
raise FileNotFoundError("Can't read annotation file '%s'" % path)

if not subset:
file_name = osp.splitext(osp.basename(path))[0]
if file_name.startswith(CifarPath.TRAIN_ANNOTATION_FILE):
subset = 'train_%s' % file_name.split('_')[-1]
else:
subset = file_name.rsplit('_', maxsplit=1)[0]

super().__init__(subset=subset)

batches_meta_file = osp.join(osp.dirname(path), CifarPath.BATCHES_META)
self._categories = self._load_categories(batches_meta_file)

self._items = list(self._load_items(path).values())

def _load_categories(self, path):
label_cat = LabelCategories()

if osp.isfile(path):
# num_cases_per_batch: 1000
# label_names: ['airplane', 'automobile', 'bird', 'cat', 'deer',
# 'dog', 'frog', 'horse', 'ship', 'truck']
# num_vis: 3072
with open(path, 'rb') as labels_file:
data = pickle.load(labels_file)
for label in data['label_names']:
label_cat.add(label)
else:
for label in CifarLabel:
label_cat.add(label)

return { AnnotationType.label: label_cat }

def _load_items(self, path):
items = {}

# 'batch_label': 'training batch 1 of 5'
# 'data': ndarray
# 'filenames': list
# 'labels': list
with open(path, 'rb') as anno_file:
annotation_dict = pickle.load(anno_file)

labels = annotation_dict.get('labels', [])
filenames = annotation_dict.get('filenames', [])
images_data = annotation_dict.get('data')
size = annotation_dict.get('image_sizes')

if len(labels) != len(filenames):
raise Exception("The sizes of the arrays 'filenames', " \
"'labels' don't match.")

if 0 < len(images_data) and len(images_data) != len(filenames):
raise Exception("The sizes of the arrays 'data', " \
"'filenames', 'labels' don't match.")

for i, (filename, label) in enumerate(zip(filenames, labels)):
item_id = osp.splitext(filename)[0]
annotations = []
if label != None:
annotations.append(Label(label))

image = None
if 0 < len(images_data):
image = images_data[i]
if size is not None and image is not None:
image = image.reshape(size[i][0],
size[i][1], 3).astype(np.uint8)
elif image is not None:
image = image.reshape(CifarPath.IMAGE_SIZE,
CifarPath.IMAGE_SIZE, 3).astype(np.uint8)

items[item_id] = DatasetItem(id=item_id, subset=self._subset,
image=image, annotations=annotations)

return items


class CifarImporter(Importer):
@classmethod
def find_sources(cls, path):
return cls._find_sources_recursive(path, '', 'cifar',
file_filter=lambda p: osp.basename(p) not in
{CifarPath.BATCHES_META, CifarPath.IMAGES_DIR})


class CifarConverter(Converter):
DEFAULT_IMAGE_EXT = '.png'

def apply(self):
label_categories = self._extractor.categories()[AnnotationType.label]

label_names = []
for label in label_categories:
label_names.append(label.name)
labels_dict = { 'label_names': label_names }
batches_meta_file = osp.join(self._save_dir, CifarPath.BATCHES_META)
with open(batches_meta_file, 'wb') as labels_file:
pickle.dump(labels_dict, labels_file)

for subset_name, subset in self._extractor.subsets().items():
filenames = []
labels = []
data = []
image_sizes = {}
for item in subset:
filenames.append(item.id + self._find_image_ext(item))

anns = [a.label for a in item.annotations
if a.type == AnnotationType.label]
label = None
if anns:
label = anns[0]
labels.append(label)

if item.has_image and self._save_images:
image = item.image
if not image.has_data:
data.append(None)
else:
image = image.data
data.append(image.reshape(-1).astype(np.uint8))
if image.shape[0] != CifarPath.IMAGE_SIZE or \
image.shape[1] != CifarPath.IMAGE_SIZE:
image_sizes[len(data) - 1] = (image.shape[0], image.shape[1])

annotation_dict = {}
annotation_dict['filenames'] = filenames
annotation_dict['labels'] = labels
annotation_dict['data'] = np.array(data)
if len(image_sizes):
size = (CifarPath.IMAGE_SIZE, CifarPath.IMAGE_SIZE)
# 'image_sizes' isn't included in the standart format,
# needed for different image sizes
annotation_dict['image_sizes'] = [image_sizes.get(p, size)
for p in range(len(data))]

filename = '%s_batch' % subset_name
batch_label = None
if subset_name.startswith('train_') and \
cast(subset_name.split('_')[1], int) is not None:
num = subset_name.split('_')[1]
filename = CifarPath.TRAIN_ANNOTATION_FILE + num
batch_label = 'training batch %s of 5' % (num, )
if subset_name == 'test':
batch_label = 'testing batch 1 of 1'
if batch_label:
annotation_dict['batch_label'] = batch_label

annotation_file = osp.join(self._save_dir, filename)
with open(annotation_file, 'wb') as labels_file:
pickle.dump(annotation_dict, labels_file)
3 changes: 3 additions & 0 deletions docs/user_manual.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ List of supported formats:
- [Dataset example](../tests/assets/imagenet_dataset)
- [Dataset example (txt for classification)](../tests/assets/imagenet_txt_dataset)
- Detection format is the same as in PASCAL VOC
- CIFAR-10/100 (`classification` (python version))
- [Format specification](https://www.cs.toronto.edu/~kriz/cifar.html)
- [Dataset example](../tests/assets/cifar_dataset)
- CamVid (`segmentation`)
- [Format specification](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/)
- [Dataset example](../tests/assets/camvid_dataset)
Expand Down
Binary file added tests/assets/cifar_dataset/batches.meta
Binary file not shown.
Binary file added tests/assets/cifar_dataset/data_batch_1
Binary file not shown.
Binary file added tests/assets/cifar_dataset/test_batch
Binary file not shown.
149 changes: 149 additions & 0 deletions tests/test_cifar_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import os.path as osp
from unittest import TestCase

import numpy as np
from datumaro.components.dataset import Dataset
from datumaro.components.extractor import (AnnotationType, DatasetItem, Label,
LabelCategories)
from datumaro.plugins.cifar_format import CifarConverter, CifarImporter
from datumaro.util.image import Image
from datumaro.util.test_utils import TestDir, compare_datasets


class CifarFormatTest(TestCase):
def test_can_save_and_load(self):
source_dataset = Dataset.from_iterable([
DatasetItem(id='image_2', subset='test',
image=np.ones((32, 32, 3)),
annotations=[Label(0)]
),
DatasetItem(id='image_3', subset='test',
image=np.ones((32, 32, 3))
),
DatasetItem(id='image_4', subset='test',
image=np.ones((32, 32, 3)),
annotations=[Label(1)]
)
], categories=['label_0', 'label_1'])

with TestDir() as test_dir:
CifarConverter.convert(source_dataset, test_dir, save_images=True)
parsed_dataset = Dataset.import_from(test_dir, 'cifar')

compare_datasets(self, source_dataset, parsed_dataset,
require_images=True)

def test_can_save_and_load_without_saving_images(self):
source_dataset = Dataset.from_iterable([
DatasetItem(id='a', subset='train_1',
annotations=[Label(0)]
),
DatasetItem(id='b', subset='train_first',
annotations=[Label(1)]
),
], categories={
AnnotationType.label: LabelCategories.from_iterable(
'label' + str(label) for label in range(2)),
})

with TestDir() as test_dir:
CifarConverter.convert(source_dataset, test_dir, save_images=False)
parsed_dataset = Dataset.import_from(test_dir, 'cifar')

compare_datasets(self, source_dataset, parsed_dataset,
require_images=True)

def test_can_save_and_load_with_different_image_size(self):
source_dataset = Dataset.from_iterable([
DatasetItem(id='image_1',
image=np.ones((10, 8, 3)),
annotations=[Label(0)]
),
DatasetItem(id='image_2',
image=np.ones((32, 32, 3)),
annotations=[Label(1)]
),
], categories={
AnnotationType.label: LabelCategories.from_iterable(
'label' + str(label) for label in range(2)),
})

with TestDir() as test_dir:
CifarConverter.convert(source_dataset, test_dir, save_images=True)
parsed_dataset = Dataset.import_from(test_dir, 'cifar')

compare_datasets(self, source_dataset, parsed_dataset,
require_images=True)

def test_can_save_dataset_with_cyrillic_and_spaces_in_filename(self):
source_dataset = Dataset.from_iterable([
DatasetItem(id="кириллица с пробелом",
image=np.ones((32, 32, 3)),
annotations=[Label(0)]
),
], categories=['label_0'])

with TestDir() as test_dir:
CifarConverter.convert(source_dataset, test_dir, save_images=True)
parsed_dataset = Dataset.import_from(test_dir, 'cifar')

compare_datasets(self, source_dataset, parsed_dataset,
require_images=True)

def test_can_save_and_load_image_with_arbitrary_extension(self):
dataset = Dataset.from_iterable([
DatasetItem(id='q/1', image=Image(path='q/1.JPEG',
data=np.zeros((32, 32, 3)))),
DatasetItem(id='a/b/c/2', image=Image(path='a/b/c/2.bmp',
data=np.zeros((32, 32, 3)))),
], categories=[])

with TestDir() as test_dir:
CifarConverter.convert(dataset, test_dir, save_images=True)
parsed_dataset = Dataset.import_from(test_dir, 'cifar')

compare_datasets(self, dataset, parsed_dataset,
require_images=True)

def test_can_save_and_load_empty_image(self):
dataset = Dataset.from_iterable([
DatasetItem(id='a', annotations=[Label(0)]),
DatasetItem(id='b')
], categories=['label_0'])

with TestDir() as test_dir:
CifarConverter.convert(dataset, test_dir, save_images=True)
parsed_dataset = Dataset.import_from(test_dir, 'cifar')

compare_datasets(self, dataset, parsed_dataset,
require_images=True)

DUMMY_DATASET_DIR = osp.join(osp.dirname(__file__), 'assets', 'cifar_dataset')

class CifarImporterTest(TestCase):
def test_can_import(self):
expected_dataset = Dataset.from_iterable([
DatasetItem(id='image_1', subset='train_1',
image=np.ones((32, 32, 3)),
annotations=[Label(0)]
),
DatasetItem(id='image_2', subset='test',
image=np.ones((32, 32, 3)),
annotations=[Label(1)]
),
DatasetItem(id='image_3', subset='test',
image=np.ones((32, 32, 3)),
annotations=[Label(3)]
),
DatasetItem(id='image_4', subset='test',
image=np.ones((32, 32, 3)),
annotations=[Label(2)]
)
], categories=['airplane', 'automobile', 'bird', 'cat'])

dataset = Dataset.import_from(DUMMY_DATASET_DIR, 'cifar')

compare_datasets(self, expected_dataset, dataset)

def test_can_detect(self):
self.assertTrue(CifarImporter.detect(DUMMY_DATASET_DIR))

0 comments on commit b9469d9

Please sign in to comment.