Skip to content

Commit

Permalink
Add label groups for hierarchical classification in ImageNet
Browse files Browse the repository at this point in the history
Signed-off-by: Ilya Trushkin <[email protected]>
  • Loading branch information
itrushkin committed Oct 15, 2024
1 parent 964387f commit 3812f9e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
9 changes: 9 additions & 0 deletions src/datumaro/plugins/data_formats/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,17 @@ def _load_categories(self, path):
path = Path(path)
for dirname in sorted(d for d in path.rglob("*") if d.is_dir()):
dirname = dirname.relative_to(path)
level = len(dirname.parts)
if str(dirname) != ImagenetPath.IMAGE_DIR_NO_LABEL:
label_cat.add(str(dirname))
if level > 1:
group_name = str(dirname.parents[0])
if not any([g.name == group_name for g in label_cat.label_groups]):
label_cat.add_label_group(group_name, [str(dirname.name)], group_type=0)
else:
g = next(x for x in label_cat.label_groups if x.name == group_name)
g.labels.append(str(dirname.name))

return {AnnotationType.label: label_cat}

def _load_items(self, path):
Expand Down
13 changes: 8 additions & 5 deletions tests/unit/test_imagenet_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ class ImagenetImporterTest:
IMPORTER_NAME = ImagenetImporter.NAME

def _create_expected_dataset(self):
label_categories = LabelCategories.from_iterable(
("label_0", "label_1", f"{Path('label_1', 'label_1_1')}")
)
label_categories.label_groups = [
LabelCategories.LabelGroup(name="label_1", labels=["label_1_1"]),
]

return Dataset.from_iterable(
[
DatasetItem(
Expand All @@ -204,11 +211,7 @@ def _create_expected_dataset(self):
annotations=[Label(1)],
),
],
categories={
AnnotationType.label: LabelCategories.from_iterable(
("label_0", "label_1", f"{Path('label_1', 'label_1_1')}")
),
},
categories={AnnotationType.label: label_categories},
)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
Expand Down
11 changes: 11 additions & 0 deletions tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ def compare_categories(test, expected, actual):
sorted(expected[AnnotationType.label].items, key=lambda t: t.name),
sorted(actual[AnnotationType.label].items, key=lambda t: t.name),
)
if expected[AnnotationType.label].label_groups:
assert len(expected[AnnotationType.label].label_groups) == len(
actual[AnnotationType.label].label_groups
)
for expected_group, actual_group in zip(
expected[AnnotationType.label].label_groups,
actual[AnnotationType.label].label_groups,
):
test.assertEqual(set(expected_group.labels), set(actual_group.labels))
test.assertEqual(expected_group.group_type, actual_group.group_type)

if AnnotationType.mask in expected:
test.assertEqual(
expected[AnnotationType.mask].colormap,
Expand Down

0 comments on commit 3812f9e

Please sign in to comment.