Skip to content

Commit

Permalink
[Datumaro] Add tests for dataset examples (cvat-ai#1648)
Browse files Browse the repository at this point in the history
* add dataset examples

* update docs

* update yolo tests

* join voc format test classes

* remplace voc extractor tests with import test

* update tfrecord format tests

* update mot tests

* update labelme tests

* update image dir tests
  • Loading branch information
zhiltsov-max authored and Fernando Martínez González committed Aug 3, 2020
1 parent af2fa19 commit 9c78837
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 316 deletions.
16 changes: 8 additions & 8 deletions datumaro/tests/test_image_dir_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@


class ImageDirFormatTest(TestCase):
class TestExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=1, image=np.ones((10, 6, 3))),
DatasetItem(id=2, image=np.ones((5, 4, 3))),
])

def test_can_load(self):
class TestExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=1, image=np.ones((10, 6, 3))),
DatasetItem(id=2, image=np.ones((5, 4, 3))),
])

with TestDir() as test_dir:
source_dataset = self.TestExtractor()
source_dataset = TestExtractor()

ImageDirConverter()(source_dataset, save_dir=test_dir)

Expand Down
24 changes: 10 additions & 14 deletions datumaro/tests/test_labelme_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from datumaro.components.extractor import (Extractor, DatasetItem,
AnnotationType, Bbox, Mask, Polygon, LabelCategories
)
from datumaro.components.project import Dataset
from datumaro.plugins.labelme_format import LabelMeExtractor, LabelMeImporter, \
from datumaro.components.project import Project
from datumaro.plugins.labelme_format import LabelMeImporter, \
LabelMeConverter
from datumaro.util.test_utils import TestDir, compare_datasets

Expand Down Expand Up @@ -111,8 +111,11 @@ def categories(self):

DUMMY_DATASET_DIR = osp.join(osp.dirname(__file__), 'assets', 'labelme_dataset')

class LabelMeExtractorTest(TestCase):
def test_can_load(self):
class LabelMeImporterTest(TestCase):
def test_can_detect(self):
self.assertTrue(LabelMeImporter.detect(DUMMY_DATASET_DIR))

def test_can_import(self):
class DstExtractor(Extractor):
def __iter__(self):
img1 = np.ones((77, 102, 3)) * 255
Expand Down Expand Up @@ -208,13 +211,6 @@ def categories(self):
AnnotationType.label: label_cat,
}

parsed = Dataset.from_extractors(LabelMeExtractor(DUMMY_DATASET_DIR))
compare_datasets(self, expected=DstExtractor(), actual=parsed)

class LabelMeImporterTest(TestCase):
def test_can_detect(self):
self.assertTrue(LabelMeImporter.detect(DUMMY_DATASET_DIR))

def test_can_import(self):
parsed = LabelMeImporter()(DUMMY_DATASET_DIR).make_dataset()
self.assertEqual(1, len(parsed))
parsed = Project.import_from(DUMMY_DATASET_DIR, 'label_me') \
.make_dataset()
compare_datasets(self, expected=DstExtractor(), actual=parsed)
27 changes: 18 additions & 9 deletions datumaro/tests/test_mot_format.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numpy as np
import os.path as osp

from unittest import TestCase

from datumaro.components.extractor import (Extractor, DatasetItem,
AnnotationType, Bbox, LabelCategories
)
from datumaro.components.project import Project
from datumaro.plugins.mot_format import MotSeqGtConverter, MotSeqImporter
from datumaro.util.test_utils import TestDir, compare_datasets

Expand Down Expand Up @@ -116,15 +118,25 @@ def categories(self):
SrcExtractor(), MotSeqGtConverter(save_images=True),
test_dir, target_dataset=DstExtractor())


DUMMY_DATASET_DIR = osp.join(osp.dirname(__file__), 'assets', 'mot_dataset')

class MotImporterTest(TestCase):
def test_can_detect(self):
class TestExtractor(Extractor):
self.assertTrue(MotSeqImporter.detect(DUMMY_DATASET_DIR))

def test_can_import(self):
class DstExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=1, subset='train',
DatasetItem(id=1,
image=np.ones((16, 16, 3)),
annotations=[
Bbox(0, 4, 4, 8, label=2),
Bbox(0, 4, 4, 8, label=2, attributes={
'occluded': False,
'visibility': 1.0,
'ignored': False,
}),
]
),
])
Expand All @@ -137,10 +149,7 @@ def categories(self):
AnnotationType.label: label_cat,
}

def generate_dummy_dataset(path):
MotSeqGtConverter()(TestExtractor(), save_dir=path)

with TestDir() as test_dir:
generate_dummy_dataset(test_dir)
dataset = Project.import_from(DUMMY_DATASET_DIR, 'mot_seq') \
.make_dataset()

self.assertTrue(MotSeqImporter.detect(test_dir))
compare_datasets(self, DstExtractor(), dataset)
46 changes: 27 additions & 19 deletions datumaro/tests/test_tfrecord_format.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numpy as np
import os.path as osp

from unittest import TestCase, skipIf

from datumaro.components.extractor import (Extractor, DatasetItem,
AnnotationType, Bbox, Mask, LabelCategories
)
from datumaro.components.project import Project
from datumaro.util.image import Image
from datumaro.util.test_utils import TestDir, compare_datasets
from datumaro.util.tf_util import check_import
Expand Down Expand Up @@ -56,17 +58,6 @@ def __iter__(self):
Bbox(2, 4, 4, 4),
]
),

DatasetItem(id=2, subset='val',
image=np.ones((8, 8, 3)),
annotations=[
Bbox(1, 2, 4, 2, label=3),
]
),

DatasetItem(id=3, subset='test',
image=np.ones((5, 4, 3)) * 3,
),
])

def categories(self):
Expand Down Expand Up @@ -188,17 +179,37 @@ def test_labelmap_parsing(self):

self.assertEqual(expected, parsed)


DUMMY_DATASET_DIR = osp.join(osp.dirname(__file__),
'assets', 'tf_detection_api_dataset')

@skipIf(import_failed, "Failed to import tensorflow")
class TfrecordImporterTest(TestCase):
def test_can_detect(self):
class TestExtractor(Extractor):
self.assertTrue(TfDetectionApiImporter.detect(DUMMY_DATASET_DIR))

def test_can_import(self):
class DstExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=1, subset='train',
image=np.ones((16, 16, 3)),
annotations=[
Bbox(0, 4, 4, 8, label=2),
]
Bbox(0, 4, 4, 4, label=3),
Bbox(2, 4, 4, 4),
],
),

DatasetItem(id=2, subset='val',
image=np.ones((8, 8, 3)),
annotations=[
Bbox(1, 2, 4, 2, label=3),
],
),

DatasetItem(id=3, subset='test',
image=np.ones((5, 4, 3)) * 3,
),
])

Expand All @@ -210,10 +221,7 @@ def categories(self):
AnnotationType.label: label_cat,
}

def generate_dummy_tfrecord(path):
TfDetectionApiConverter()(TestExtractor(), save_dir=path)

with TestDir() as test_dir:
generate_dummy_tfrecord(test_dir)
dataset = Project.import_from(DUMMY_DATASET_DIR, 'tf_detection_api') \
.make_dataset()

self.assertTrue(TfDetectionApiImporter.detect(test_dir))
compare_datasets(self, DstExtractor(), dataset)
Loading

0 comments on commit 9c78837

Please sign in to comment.