diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py index 3be34a4a4398..7072150f2df6 100644 --- a/datumaro/components/dataset.py +++ b/datumaro/components/dataset.py @@ -266,7 +266,8 @@ def is_cache_initialized(self) -> bool: @property def _is_unchanged_wrapper(self) -> bool: - return self._source is not None and self._storage.is_empty() + return self._source is not None and self._storage.is_empty() and \ + not self._transforms def init_cache(self): if not self.is_cache_initialized(): @@ -513,16 +514,17 @@ def get_subset(self, name): return self._merged().get_subset(name) def subsets(self): - subsets = {} - if not self.is_cache_initialized(): - subsets.update(self._source.subsets()) - subsets.update(self._storage.subsets()) - return subsets + # TODO: check if this can be optimized in case of transforms + # and other cases + return self._merged().subsets() def transform(self, method: Transform, *args, **kwargs): # Flush accumulated changes - source = self._merged() - self._storage = DatasetItemStorage() + if not self._storage.is_empty(): + source = self._merged() + self._storage = DatasetItemStorage() + else: + source = self._source if not self._transforms: # The stack of transforms only needs a single source diff --git a/datumaro/components/extractor.py b/datumaro/components/extractor.py index 26506b30021c..17220d04afc4 100644 --- a/datumaro/components/extractor.py +++ b/datumaro/components/extractor.py @@ -686,6 +686,9 @@ def get_subset(self, name): if self._subsets is None: self._init_cache() if name in self._subsets: + if len(self._subsets) == 1: + return self + return self.select(lambda item: item.subset == name) else: raise Exception("Unknown subset '%s', available subsets: %s" % \ diff --git a/datumaro/plugins/coco_format/importer.py b/datumaro/plugins/coco_format/importer.py index ca76dc55db38..44a4c7362190 100644 --- a/datumaro/plugins/coco_format/importer.py +++ b/datumaro/plugins/coco_format/importer.py @@ -73,8 +73,12 @@ def __call__(self, path, **extra_params): @classmethod def find_sources(cls, path): - if path.endswith('.json') and osp.isfile(path): - subset_paths = [path] + if osp.isfile(path): + if len(cls._TASKS) == 1: + return {'': { next(iter(cls._TASKS)): path }} + + if path.endswith('.json'): + subset_paths = [path] else: subset_paths = glob(osp.join(path, '**', '*_*.json'), recursive=True) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index d141d00d8840..e594a87f260e 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -867,6 +867,120 @@ def transform_item(self, item): self.assertEqual(iter_called, 1) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_get_len_after_local_transforms(self): + iter_called = 0 + class TestExtractor(Extractor): + def __iter__(self): + nonlocal iter_called + iter_called += 1 + yield from [ + DatasetItem(1), + DatasetItem(2), + DatasetItem(3), + DatasetItem(4), + ] + dataset = Dataset.from_extractors(TestExtractor()) + + class TestTransform(ItemTransform): + def transform_item(self, item): + return self.wrap_item(item, id=int(item.id) + 1) + + dataset.transform(TestTransform) + dataset.transform(TestTransform) + + self.assertEqual(iter_called, 0) + + self.assertEqual(4, len(dataset)) + + self.assertEqual(iter_called, 1) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_get_len_after_nonlocal_transforms(self): + iter_called = 0 + class TestExtractor(Extractor): + def __iter__(self): + nonlocal iter_called + iter_called += 1 + yield from [ + DatasetItem(1), + DatasetItem(2), + DatasetItem(3), + DatasetItem(4), + ] + dataset = Dataset.from_extractors(TestExtractor()) + + class TestTransform(Transform): + def __iter__(self): + for item in self._extractor: + yield self.wrap_item(item, id=int(item.id) + 1) + + dataset.transform(TestTransform) + dataset.transform(TestTransform) + + self.assertEqual(iter_called, 0) + + self.assertEqual(4, len(dataset)) + + self.assertEqual(iter_called, 2) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_get_subsets_after_local_transforms(self): + iter_called = 0 + class TestExtractor(Extractor): + def __iter__(self): + nonlocal iter_called + iter_called += 1 + yield from [ + DatasetItem(1), + DatasetItem(2), + DatasetItem(3), + DatasetItem(4), + ] + dataset = Dataset.from_extractors(TestExtractor()) + + class TestTransform(ItemTransform): + def transform_item(self, item): + return self.wrap_item(item, id=int(item.id) + 1, subset='a') + + dataset.transform(TestTransform) + dataset.transform(TestTransform) + + self.assertEqual(iter_called, 0) + + self.assertEqual({'a'}, set(dataset.subsets())) + + self.assertEqual(iter_called, 1) + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_get_subsets_after_nonlocal_transforms(self): + iter_called = 0 + class TestExtractor(Extractor): + def __iter__(self): + nonlocal iter_called + iter_called += 1 + yield from [ + DatasetItem(1), + DatasetItem(2), + DatasetItem(3), + DatasetItem(4), + ] + dataset = Dataset.from_extractors(TestExtractor()) + + class TestTransform(Transform): + def __iter__(self): + for item in self._extractor: + yield self.wrap_item(item, id=int(item.id) + 1, subset='a') + + dataset.transform(TestTransform) + dataset.transform(TestTransform) + + self.assertEqual(iter_called, 0) + + self.assertEqual({'a'}, set(dataset.subsets())) + + self.assertEqual(iter_called, 2) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_raises_when_repeated_items_in_source(self): dataset = Dataset.from_iterable([DatasetItem(0), DatasetItem(0)])