From 3812f9e6832bc2b11bea7585f2189a3fdd8c0767 Mon Sep 17 00:00:00 2001 From: Ilya Trushkin Date: Tue, 15 Oct 2024 17:51:13 +0300 Subject: [PATCH] Add label groups for hierarchical classification in ImageNet Signed-off-by: Ilya Trushkin --- src/datumaro/plugins/data_formats/imagenet.py | 9 +++++++++ tests/unit/test_imagenet_format.py | 13 ++++++++----- tests/utils/test_utils.py | 11 +++++++++++ 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/datumaro/plugins/data_formats/imagenet.py b/src/datumaro/plugins/data_formats/imagenet.py index 673cfd132d..6fdee52507 100644 --- a/src/datumaro/plugins/data_formats/imagenet.py +++ b/src/datumaro/plugins/data_formats/imagenet.py @@ -48,8 +48,17 @@ def _load_categories(self, path): path = Path(path) for dirname in sorted(d for d in path.rglob("*") if d.is_dir()): dirname = dirname.relative_to(path) + level = len(dirname.parts) if str(dirname) != ImagenetPath.IMAGE_DIR_NO_LABEL: label_cat.add(str(dirname)) + if level > 1: + group_name = str(dirname.parents[0]) + if not any([g.name == group_name for g in label_cat.label_groups]): + label_cat.add_label_group(group_name, [str(dirname.name)], group_type=0) + else: + g = next(x for x in label_cat.label_groups if x.name == group_name) + g.labels.append(str(dirname.name)) + return {AnnotationType.label: label_cat} def _load_items(self, path): diff --git a/tests/unit/test_imagenet_format.py b/tests/unit/test_imagenet_format.py index e84b9406ea..da96e3f2c3 100644 --- a/tests/unit/test_imagenet_format.py +++ b/tests/unit/test_imagenet_format.py @@ -182,6 +182,13 @@ class ImagenetImporterTest: IMPORTER_NAME = ImagenetImporter.NAME def _create_expected_dataset(self): + label_categories = LabelCategories.from_iterable( + ("label_0", "label_1", f"{Path('label_1', 'label_1_1')}") + ) + label_categories.label_groups = [ + LabelCategories.LabelGroup(name="label_1", labels=["label_1_1"]), + ] + return Dataset.from_iterable( [ DatasetItem( @@ -204,11 +211,7 @@ def _create_expected_dataset(self): annotations=[Label(1)], ), ], - categories={ - AnnotationType.label: LabelCategories.from_iterable( - ("label_0", "label_1", f"{Path('label_1', 'label_1_1')}") - ), - }, + categories={AnnotationType.label: label_categories}, ) @mark_requirement(Requirements.DATUM_GENERAL_REQ) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index f9a2e72d59..3a25c1d158 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -108,6 +108,17 @@ def compare_categories(test, expected, actual): sorted(expected[AnnotationType.label].items, key=lambda t: t.name), sorted(actual[AnnotationType.label].items, key=lambda t: t.name), ) + if expected[AnnotationType.label].label_groups: + assert len(expected[AnnotationType.label].label_groups) == len( + actual[AnnotationType.label].label_groups + ) + for expected_group, actual_group in zip( + expected[AnnotationType.label].label_groups, + actual[AnnotationType.label].label_groups, + ): + test.assertEqual(set(expected_group.labels), set(actual_group.labels)) + test.assertEqual(expected_group.group_type, actual_group.group_type) + if AnnotationType.mask in expected: test.assertEqual( expected[AnnotationType.mask].colormap,