Dataset caching fixes (cvat-ai#351)
* Fix importing arbitrary file names in COCO subformats

* Optimize subset iteration in a simple scenario

* Fix subset iteration in a dataset with transforms
Maxim Zhiltsov authored Jul 13, 2021
1 parent f9a5a8b commit d6914f5
Showing 4 changed files with 133 additions and 10 deletions.
datumaro/components/dataset.py — 10 additions, 8 deletions

@@ -266,7 +266,8 @@ def is_cache_initialized(self) -> bool:
 
     @property
     def _is_unchanged_wrapper(self) -> bool:
-        return self._source is not None and self._storage.is_empty()
+        return self._source is not None and self._storage.is_empty() and \
+            not self._transforms
 
     def init_cache(self):
         if not self.is_cache_initialized():

@@ -513,16 +514,17 @@ def get_subset(self, name):
         return self._merged().get_subset(name)
 
     def subsets(self):
-        subsets = {}
-        if not self.is_cache_initialized():
-            subsets.update(self._source.subsets())
-        subsets.update(self._storage.subsets())
-        return subsets
+        # TODO: check if this can be optimized in case of transforms
+        # and other cases
+        return self._merged().subsets()
 
     def transform(self, method: Transform, *args, **kwargs):
         # Flush accumulated changes
-        source = self._merged()
-        self._storage = DatasetItemStorage()
+        if not self._storage.is_empty():
+            source = self._merged()
+            self._storage = DatasetItemStorage()
+        else:
+            source = self._source
 
         if not self._transforms:
             # The stack of transforms only needs a single source
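
A brief usage sketch of what the subsets() and transform() changes enable (not part of the commit; it assumes the Datumaro Python API at this revision, and the MoveToA transform below is made up to mirror the new tests): pending transforms are now reflected by subsets(), and stacking a transform on an empty item storage no longer flushes the cache.

# A minimal sketch, assuming the Dataset/Extractor API at this revision.
# MoveToA is illustrative; it mirrors the transforms used in the added tests.
from datumaro.components.dataset import Dataset
from datumaro.components.extractor import DatasetItem, ItemTransform

class MoveToA(ItemTransform):
    # Put every item into the 'a' subset
    def transform_item(self, item):
        return self.wrap_item(item, subset='a')

dataset = Dataset.from_iterable([DatasetItem(1), DatasetItem(2)])
dataset.transform(MoveToA)

# subsets() now goes through _merged(), so the pending transform is
# taken into account and the result reflects the transformed items.
print(set(dataset.subsets()))  # {'a'}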
datumaro/components/extractor.py — 3 additions, 0 deletions

@@ -686,6 +686,9 @@ def get_subset(self, name):
         if self._subsets is None:
             self._init_cache()
         if name in self._subsets:
+            if len(self._subsets) == 1:
+                return self
+
             return self.select(lambda item: item.subset == name)
         else:
             raise Exception("Unknown subset '%s', available subsets: %s" % \
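
The get_subset() shortcut returns an extractor as-is when it contains only one subset, instead of wrapping it in a lazily filtering select(). A minimal sketch (the extractor below is illustrative, not part of the commit):

# A minimal sketch, assuming the base Extractor API at this revision.
# TrainOnlyExtractor is illustrative.
from datumaro.components.extractor import DatasetItem, Extractor

class TrainOnlyExtractor(Extractor):
    def __iter__(self):
        # Every item belongs to the single 'train' subset
        yield DatasetItem(1, subset='train')
        yield DatasetItem(2, subset='train')

extractor = TrainOnlyExtractor()

# 'train' is the only subset, so the extractor itself is returned and
# no per-item filtering wrapper is created.
train = extractor.get_subset('train')
print(train is extractor)  # True after this change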
datumaro/plugins/coco_format/importer.py — 6 additions, 2 deletions

@@ -73,8 +73,12 @@ def __call__(self, path, **extra_params):
 
     @classmethod
     def find_sources(cls, path):
-        if path.endswith('.json') and osp.isfile(path):
-            subset_paths = [path]
+        if osp.isfile(path):
+            if len(cls._TASKS) == 1:
+                return {'': { next(iter(cls._TASKS)): path }}
+
+            if path.endswith('.json'):
+                subset_paths = [path]
         else:
             subset_paths = glob(osp.join(path, '**', '*_*.json'),
                 recursive=True)
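
With this change, a task-specific COCO subformat can import a single annotation file whose name does not follow the <task>_<subset>.json pattern expected by the directory glob. A minimal sketch (not part of the commit; the file name is illustrative, and 'coco_instances' is assumed to be one of the registered COCO subformats):

# A minimal sketch, assuming Datumaro's high-level import API.
# 'arbitrary_name.json' stands for any existing COCO instances file.
from datumaro.components.dataset import Dataset

# For a subformat importer, cls._TASKS holds a single task, so
# find_sources() now maps the file straight to that task instead of
# relying on the '*_*.json' naming convention.
dataset = Dataset.import_from('arbitrary_name.json', 'coco_instances')
print(len(dataset))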
tests/test_dataset.py — 114 additions, 0 deletions

@@ -867,6 +867,120 @@ def transform_item(self, item):
 
         self.assertEqual(iter_called, 1)
 
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    def test_can_get_len_after_local_transforms(self):
+        iter_called = 0
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                nonlocal iter_called
+                iter_called += 1
+                yield from [
+                    DatasetItem(1),
+                    DatasetItem(2),
+                    DatasetItem(3),
+                    DatasetItem(4),
+                ]
+        dataset = Dataset.from_extractors(TestExtractor())
+
+        class TestTransform(ItemTransform):
+            def transform_item(self, item):
+                return self.wrap_item(item, id=int(item.id) + 1)
+
+        dataset.transform(TestTransform)
+        dataset.transform(TestTransform)
+
+        self.assertEqual(iter_called, 0)
+
+        self.assertEqual(4, len(dataset))
+
+        self.assertEqual(iter_called, 1)
+
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    def test_can_get_len_after_nonlocal_transforms(self):
+        iter_called = 0
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                nonlocal iter_called
+                iter_called += 1
+                yield from [
+                    DatasetItem(1),
+                    DatasetItem(2),
+                    DatasetItem(3),
+                    DatasetItem(4),
+                ]
+        dataset = Dataset.from_extractors(TestExtractor())
+
+        class TestTransform(Transform):
+            def __iter__(self):
+                for item in self._extractor:
+                    yield self.wrap_item(item, id=int(item.id) + 1)
+
+        dataset.transform(TestTransform)
+        dataset.transform(TestTransform)
+
+        self.assertEqual(iter_called, 0)
+
+        self.assertEqual(4, len(dataset))
+
+        self.assertEqual(iter_called, 2)
+
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    def test_can_get_subsets_after_local_transforms(self):
+        iter_called = 0
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                nonlocal iter_called
+                iter_called += 1
+                yield from [
+                    DatasetItem(1),
+                    DatasetItem(2),
+                    DatasetItem(3),
+                    DatasetItem(4),
+                ]
+        dataset = Dataset.from_extractors(TestExtractor())
+
+        class TestTransform(ItemTransform):
+            def transform_item(self, item):
+                return self.wrap_item(item, id=int(item.id) + 1, subset='a')
+
+        dataset.transform(TestTransform)
+        dataset.transform(TestTransform)
+
+        self.assertEqual(iter_called, 0)
+
+        self.assertEqual({'a'}, set(dataset.subsets()))
+
+        self.assertEqual(iter_called, 1)
+
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    def test_can_get_subsets_after_nonlocal_transforms(self):
+        iter_called = 0
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                nonlocal iter_called
+                iter_called += 1
+                yield from [
+                    DatasetItem(1),
+                    DatasetItem(2),
+                    DatasetItem(3),
+                    DatasetItem(4),
+                ]
+        dataset = Dataset.from_extractors(TestExtractor())
+
+        class TestTransform(Transform):
+            def __iter__(self):
+                for item in self._extractor:
+                    yield self.wrap_item(item, id=int(item.id) + 1, subset='a')
+
+        dataset.transform(TestTransform)
+        dataset.transform(TestTransform)
+
+        self.assertEqual(iter_called, 0)
+
+        self.assertEqual({'a'}, set(dataset.subsets()))
+
+        self.assertEqual(iter_called, 2)
+
     @mark_requirement(Requirements.DATUM_GENERAL_REQ)
     def test_raises_when_repeated_items_in_source(self):
         dataset = Dataset.from_iterable([DatasetItem(0), DatasetItem(0)])
