diff --git a/CHANGELOG.md b/CHANGELOG.md index b8957fc106..21c110c5c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- Task-specific Splitter () - `WiderFace` dataset format () - Function to transform annotations to labels () - `VGGFace2` dataset format () diff --git a/datumaro/plugins/splitter.py b/datumaro/plugins/splitter.py new file mode 100644 index 0000000000..704e8c0966 --- /dev/null +++ b/datumaro/plugins/splitter.py @@ -0,0 +1,522 @@ +# Copyright (C) 2020 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import logging as log +import numpy as np + +from datumaro.components.extractor import (Transform, AnnotationType, + DEFAULT_SUBSET_NAME) + +NEAR_ZERO = 1e-7 + + +class _TaskSpecificSplit(Transform): + def __init__(self, dataset, splits, seed): + super().__init__(dataset) + + snames, sratio = self._validate_splits(splits) + + self._snames = snames + self._sratio = sratio + + self._seed = seed + + self._subsets = {"train", "val", "test"} # output subset names + self._parts = [] + self._length = "parent" + + self._initialized = False + + def _set_parts(self, by_splits): + self._parts = [] + for subset in self._subsets: + self._parts.append((set(by_splits[subset]), subset)) + + @staticmethod + def _get_uniq_annotations(dataset): + annotations = [] + for item in dataset: + labels = [a for a in item.annotations + if a.type == AnnotationType.label] + if len(labels) != 1: + raise Exception("Item '%s' contains %s labels, " + "but exactly one is expected" % (item.id, len(labels))) + annotations.append(labels[0]) + return annotations + + @staticmethod + def _validate_splits(splits, valid=None): + snames = [] + ratios = [] + if valid is None: + valid = ["train", "val", "test"] + for subset, ratio in splits: + assert subset in valid, \ + "Subset name must be one of %s, but got %s" % (valid, subset) + assert 0.0 <= ratio and ratio <= 1.0, \ + "Ratio is expected to be in the range " \ + "[0, 1], but got %s for %s" % (ratio, subset) + snames.append(subset) + ratios.append(float(ratio)) + ratios = np.array(ratios) + + total_ratio = np.sum(ratios) + if not abs(total_ratio - 1.0) <= NEAR_ZERO: + raise Exception( + "Sum of ratios is expected to be 1, got %s, which is %s" + % (splits, total_ratio) + ) + return snames, ratios + + @staticmethod + def _get_required(ratio): + min_value = np.max(ratio) + for i in ratio: + if NEAR_ZERO < i and i < min_value: + min_value = i + required = int(np.around(1.0) / min_value) + return required + + @staticmethod + def _get_sections(dataset_size, ratio): + n_splits = [int(np.around(dataset_size * r)) for r in ratio[:-1]] + n_splits.append(dataset_size - np.sum(n_splits)) + + # if there are splits with zero samples even if ratio is not 0, + # borrow one from the split who has one or more. + for ii, num_split in enumerate(n_splits): + if num_split == 0 and NEAR_ZERO < ratio[ii]: + midx = np.argmax(n_splits) + if n_splits[midx] > 0: + n_splits[ii] += 1 + n_splits[midx] -= 1 + sections = np.add.accumulate(n_splits[:-1]) + return sections + + @staticmethod + def _group_by_attr(items): + """ + Args: + items: list of (idx, ann). ann is the annotation from Label object. + Returns: + by_attributes: dict of { combination-of-attrs : list of index } + """ + # group by attributes + by_attributes = dict() + for idx, ann in items: + attributes = tuple(sorted(ann.attributes.items())) + if attributes not in by_attributes: + by_attributes[attributes] = [] + by_attributes[attributes].append(idx) + return by_attributes + + def _split_by_attr(self, datasets, snames, ratio, out_splits, + dataset_key="label"): + required = self._get_required(ratio) + for key, items in datasets.items(): + np.random.shuffle(items) + by_attributes = self._group_by_attr(items) + for attributes, indice in by_attributes.items(): + gname = "%s: %s, attrs: %s" % (dataset_key, key, attributes) + splits = self._split_indice(indice, gname, ratio, required) + for subset, split in zip(snames, splits): + if 0 < len(split): + out_splits[subset].extend(split) + + def _split_indice(self, indice, group_name, ratio, required): + filtered_size = len(indice) + if filtered_size < required: + log.warning("Not enough samples for a group, '%s'" % group_name) + sections = self._get_sections(filtered_size, ratio) + splits = np.array_split(indice, sections) + return splits + + def _find_split(self, index): + for subset_indices, subset in self._parts: + if index in subset_indices: + return subset + return DEFAULT_SUBSET_NAME # all the possible remainder --> default + + def _split_dataset(self): + raise NotImplementedError() + + def __iter__(self): + # lazy splitting + if self._initialized is False: + self._split_dataset() + self._initialized = True + for i, item in enumerate(self._extractor): + yield self.wrap_item(item, subset=self._find_split(i)) + + +class ClassificationSplit(_TaskSpecificSplit): + """ + Splits dataset into train/val/test set in class-wise manner. |n + |n + Notes:|n + - Single label is expected for each DatasetItem.|n + - If there are not enough images in some class or attributes group, + the split ratio can't be guaranteed.|n + """ + + def __init__(self, dataset, splits, seed=None): + """ + Parameters + ---------- + dataset : Dataset + splits : list + A list of (subset(str), ratio(float)) + Subset is expected to be one of ["train", "val", "test"]. + The sum of ratios is expected to be 1. + seed : int, optional + """ + super().__init__(dataset, splits, seed) + + def _split_dataset(self): + np.random.seed(self._seed) + + # support only single label for a DatasetItem + # 1. group by label + by_labels = dict() + annotations = self._get_uniq_annotations(self._extractor) + for idx, ann in enumerate(annotations): + label = getattr(ann, 'label', None) + if label not in by_labels: + by_labels[label] = [] + by_labels[label].append((idx, ann)) + + by_splits = dict() + for subset in self._subsets: + by_splits[subset] = [] + + # 2. group by attributes + self._split_by_attr(by_labels, self._snames, self._sratio, by_splits) + self._set_parts(by_splits) + + +class MatchingReIDSplit(_TaskSpecificSplit): + """ + Splits dataset for matching, especially re-id task.|n + First, splits dataset into 'train+val' and 'test' sets by person id.|n + Note that this splitting is not by DatasetItem. |n + Then, tags 'test' into 'gallery'/'query' in class-wise random manner.|n + Then, splits 'train+val' into 'train'/'val' sets in the same way.|n + Therefore, the final subsets would be 'train', 'val', 'test'. |n + And 'gallery', 'query' are tagged using anntoation group.|n + You can get the 'gallery' and 'query' sets using 'get_subset_by_group'.|n + Notes:|n + - Single label is expected for each DatasetItem.|n + - Each label is expected to have attribute representing the person id. |n + """ + + _group_map = dict() + + def __init__(self, dataset, splits, test_splits, pid_name="PID", seed=None): + """ + Parameters + ---------- + dataset : Dataset + splits : list + A list of (subset(str), ratio(float)) + Subset is expected to be one of ["train", "val", "test"]. + The sum of ratios is expected to be 1. + test_splits : list + A list of (subset(str), ratio(float)) + Subset is expected to be one of ["gallery", "query"]. + The sum of ratios is expected to be 1. + pid_name: str + attribute name representing the person id. (default: PID) + seed : int, optional + """ + super().__init__(dataset, splits, seed) + + self._test_splits = test_splits + self._pid_name = pid_name + + def _split_dataset(self): + np.random.seed(self._seed) + + id_snames, id_ratio = self._snames, self._sratio + + pid_name = self._pid_name + dataset = self._extractor + + groups = set() + + # group by PID(pid_name) + by_pid = dict() + annotations = self._get_uniq_annotations(dataset) + for idx, ann in enumerate(annotations): + attributes = dict(ann.attributes.items()) + assert pid_name in attributes, \ + "'%s' is expected as an attribute name" % pid_name + person_id = attributes[pid_name] + if person_id not in by_pid: + by_pid[person_id] = [] + by_pid[person_id].append((idx, ann)) + groups.add(ann.group) + + max_group_id = max(groups) + self._group_map["gallery"] = max_group_id + 1 + self._group_map["query"] = max_group_id + 2 + + required = self._get_required(id_ratio) + if len(by_pid) < required: + log.warning("There's not enough IDs, which is %s, " + "so train/val/test ratio can't be guaranteed." + % len(by_pid) + ) + + # 1. split dataset into trval and test + # IDs in test set should not exist in train/val set. + test = id_ratio[id_snames.index("test")] if "test" in id_snames else 0 + if NEAR_ZERO < test: # has testset + split_ratio = np.array([test, 1.0 - test]) + person_ids = list(by_pid.keys()) + np.random.shuffle(person_ids) + sections = self._get_sections(len(person_ids), split_ratio) + splits = np.array_split(person_ids, sections) + testset = {pid: by_pid[pid] for pid in splits[0]} + trval = {pid: by_pid[pid] for pid in splits[1]} + + # follow the ratio of datasetitems as possible. + # naive heuristic: exchange the best item one by one. + expected_count = int(len(self._extractor) * split_ratio[0]) + testset_total = int(np.sum([len(v) for v in testset.values()])) + self._rebalancing(testset, trval, expected_count, testset_total) + else: + testset = dict() + trval = by_pid + + by_splits = dict() + for subset in self._subsets: + by_splits[subset] = [] + + # 2. split 'test' into 'gallery' and 'query' + if 0 < len(testset): + for person_id, items in testset.items(): + indice = [idx for idx, _ in items] + by_splits["test"].extend(indice) + + valid = ["gallery", "query"] + test_splits = self._test_splits + test_snames, test_ratio = self._validate_splits(test_splits, valid) + by_groups = {s: [] for s in test_snames} + self._split_by_attr(testset, test_snames, test_ratio, by_groups, + dataset_key=pid_name) + + # tag using group + for idx, item in enumerate(self._extractor): + for subset, split in by_groups.items(): + if idx in split: + group_id = self._group_map[subset] + item.annotations[0].group = group_id + break + + # 3. split 'trval' into 'train' and 'val' + trval_snames = ["train", "val"] + trval_ratio = [] + for subset in trval_snames: + if subset in id_snames: + val = id_ratio[id_snames.index(subset)] + else: + val = 0.0 + trval_ratio.append(val) + trval_ratio = np.array(trval_ratio) + total_ratio = np.sum(trval_ratio) + if total_ratio < NEAR_ZERO: + trval_splits = list(zip(["train", "val"], trval_ratio)) + log.warning("Sum of ratios is expected to be positive, " + "got %s, which is %s" + % (trval_splits, total_ratio) + ) + else: + trval_ratio /= total_ratio # normalize + self._split_by_attr(trval, trval_snames, trval_ratio, by_splits, + dataset_key=pid_name) + + self._set_parts(by_splits) + + @staticmethod + def _rebalancing(test, trval, expected_count, testset_total): + diffs = dict() + for id_test, items_test in test.items(): + count_test = len(items_test) + for id_trval, items_trval in trval.items(): + count_trval = len(items_trval) + diff = count_trval - count_test + if diff == 0: + continue # exchange has no effect + if diff not in diffs: + diffs[diff] = [(id_test, id_trval)] + else: + diffs[diff].append((id_test, id_trval)) + exchanges = [] + while True: + target_diff = expected_count - testset_total + # find nearest diff. + keys = np.array(list(diffs.keys())) + idx = (np.abs(keys - target_diff)).argmin() + nearest = keys[idx] + if abs(target_diff) <= abs(target_diff - nearest): + break + choice = np.random.choice(range(len(diffs[nearest]))) + pid_test, pid_trval = diffs[nearest][choice] + testset_total += nearest + new_diffs = dict() + for diff, person_ids in diffs.items(): + new_list = [] + for id1, id2 in person_ids: + if id1 == pid_test or id2 == pid_trval: + continue + new_list.append((id1, id2)) + if 0 < len(new_list): + new_diffs[diff] = new_list + diffs = new_diffs + exchanges.append((pid_test, pid_trval)) + # exchange + for pid_test, pid_trval in exchanges: + test[pid_trval] = trval.pop(pid_trval) + trval[pid_test] = test.pop(pid_test) + + def get_subset_by_group(self, group: str): + available = list(self._group_map.keys()) + assert group in self._group_map, \ + "Unknown group '%s', available groups: %s" \ + % (group, available) + group_id = self._group_map[group] + return self.select(lambda item: item.annotations[0].group == group_id) + + +class DetectionSplit(_TaskSpecificSplit): + """ + Splits dataset into train/val/test set for detection task.|n + For detection dataset, each image can have multiple bbox annotations.|n + Since one DataItem can't be included in multiple subsets at the same time, + the dataset can't be divided according to the bbox annotations.|n + Thus, we split dataset based on DatasetItem + while preserving label distribution as possible.|n + |n + Notes:|n + - Each DatsetItem is expected to have one or more Bbox annotations.|n + - Label annotations are ignored. We only focus on the Bbox annotations.|n + """ + + def __init__(self, dataset, splits, seed=None): + """ + Parameters + ---------- + dataset : Dataset + splits : list + A list of (subset(str), ratio(float)) + Subset is expected to be one of ["train", "val", "test"]. + The sum of ratios is expected to be 1. + seed : int, optional + """ + super().__init__(dataset, splits, seed) + + @staticmethod + def _group_by_bbox_labels(dataset): + by_labels = dict() + for idx, item in enumerate(dataset): + bbox_anns = [a for a in item.annotations + if a.type == AnnotationType.bbox] + assert 0 < len(bbox_anns), \ + "Expected more than one bbox annotation in the dataset" + for ann in bbox_anns: + label = getattr(ann, 'label', None) + if label not in by_labels: + by_labels[label] = [(idx, ann)] + else: + by_labels[label].append((idx, ann)) + return by_labels + + def _split_dataset(self): + np.random.seed(self._seed) + + subsets, sratio = self._snames, self._sratio + + # 1. group by bbox label + by_labels = self._group_by_bbox_labels(self._extractor) + + # 2. group by attributes + by_combinations = dict() + for label, items in by_labels.items(): + by_attributes = self._group_by_attr(items) + for attributes, indice in by_attributes.items(): + gname = "label: %s, attributes: %s" % (label, attributes) + by_combinations[gname] = indice + + # total number of GT samples per label-attr combinations + n_combs = {k: len(v) for k, v in by_combinations.items()} + + # 3-1. initially count per-image GT samples + scores_all = {} + init_scores = {} + for idx, _ in enumerate(self._extractor): + counts = {k: v.count(idx) for k, v in by_combinations.items()} + scores_all[idx] = counts + init_scores[idx] = np.sum( + [v / n_combs[k] for k, v in counts.items()] + ) + + by_splits = dict() + for sname in self._subsets: + by_splits[sname] = [] + + total = len(self._extractor) + target_size = dict() + expected = [] # expected numbers of per split GT samples + for sname, ratio in zip(subsets, sratio): + target_size[sname] = total * ratio + expected.append( + (sname, {k: v * ratio for k, v in n_combs.items()}) + ) + + ## + # functions for keep the # of annotations not exceed the expected num + def compute_penalty(counts, n_combs): + p = 0 + for k, v in counts.items(): + p += max(0, (v / n_combs[k]) - 1.0) + return p + + def update_nc(counts, n_combs): + for k, v in counts.items(): + n_combs[k] = max(0, n_combs[k] - v) + if n_combs[k] == 0: + n_combs[k] = -1 + return n_combs + + ## + + # 3-2. assign each DatasetItem to a split, one by one + for idx, _ in sorted( + init_scores.items(), key=lambda item: item[1], reverse=True + ): + counts = scores_all[idx] + + # shuffling split order to add randomness + # when two or more splits have the same penalty value + np.random.shuffle(expected) + + pp = [] + for sname, nc in expected: + if target_size[sname] <= len(by_splits[sname]): + # the split has enough images, + # stop adding more images to this split + pp.append(1e08) + else: + # compute penalty based on the number of GT samples + # added in the split + pp.append(compute_penalty(counts, nc)) + + # we push an image to a split with the minimum penalty + midx = np.argmin(pp) + + sname, nc = expected[midx] + by_splits[sname].append(idx) + update_nc(counts, nc) + + self._set_parts(by_splits) diff --git a/tests/test_splitter.py b/tests/test_splitter.py new file mode 100644 index 0000000000..276ed5f557 --- /dev/null +++ b/tests/test_splitter.py @@ -0,0 +1,610 @@ +import numpy as np + +from unittest import TestCase + +from datumaro.components.project import Dataset +from datumaro.components.extractor import (DatasetItem, Label, Bbox, + LabelCategories, AnnotationType) + +import datumaro.plugins.splitter as splitter +from datumaro.components.operations import compute_ann_statistics + + +class SplitterTest(TestCase): + @staticmethod + def _get_subset(idx): + subsets = ["", "a", "b", "", "", "a", "", "b", "", "a"] + return subsets[idx % len(subsets)] + + def _generate_dataset(self, config): + # counts = {(0,0):20, (0,1):20, (0,2):30, (1,0):20, (1,1):10, (1,2):20} + # attr1 = ['attr1', 'attr2'] + # attr2 = ['attr1', 'attr3'] + # config = { "label1": { "attrs": attr1, "counts": counts }, + # "label2": { "attrs": attr2, "counts": counts }} + iterable = [] + label_cat = LabelCategories() + idx = 0 + for label_id, label in enumerate(config.keys()): + anames = config[label]["attrs"] + counts = config[label]["counts"] + label_cat.add(label, attributes=anames) + if isinstance(counts, dict): + for attrs, count in counts.items(): + attributes = dict() + if isinstance(attrs, tuple): + for aname, value in zip(anames, attrs): + attributes[aname] = value + else: + attributes[anames[0]] = attrs + for _ in range(count): + idx += 1 + iterable.append( + DatasetItem(idx, subset=self._get_subset(idx), + annotations=[ + Label(label_id, attributes=attributes) + ], + ) + ) + else: + for _ in range(counts): + idx += 1 + iterable.append( + DatasetItem(idx, subset=self._get_subset(idx), + annotations=[Label(label_id)]) + ) + categories = {AnnotationType.label: label_cat} + dataset = Dataset.from_iterable(iterable, categories) + return dataset + + def test_split_for_classification_multi_class_no_attr(self): + config = { + "label1": {"attrs": None, "counts": 10}, + "label2": {"attrs": None, "counts": 20}, + "label3": {"attrs": None, "counts": 30}, + } + source = self._generate_dataset(config) + + splits = [("train", 0.7), ("test", 0.3)] + actual = splitter.ClassificationSplit(source, splits) + + self.assertEqual(42, len(actual.get_subset("train"))) + self.assertEqual(18, len(actual.get_subset("test"))) + + # check stats for train + stat_train = compute_ann_statistics(actual.get_subset("train")) + dist_train = stat_train["annotations"]["labels"]["distribution"] + self.assertEqual(7, dist_train["label1"][0]) + self.assertEqual(14, dist_train["label2"][0]) + self.assertEqual(21, dist_train["label3"][0]) + + # check stats for test + stat_test = compute_ann_statistics(actual.get_subset("test")) + dist_test = stat_test["annotations"]["labels"]["distribution"] + self.assertEqual(3, dist_test["label1"][0]) + self.assertEqual(6, dist_test["label2"][0]) + self.assertEqual(9, dist_test["label3"][0]) + + def test_split_for_classification_single_class_single_attr(self): + counts = {0: 10, 1: 20, 2: 30} + config = {"label": {"attrs": ["attr"], "counts": counts}} + source = self._generate_dataset(config) + + splits = [("train", 0.7), ("test", 0.3)] + actual = splitter.ClassificationSplit(source, splits) + + self.assertEqual(42, len(actual.get_subset("train"))) + self.assertEqual(18, len(actual.get_subset("test"))) + + # check stats for train + stat_train = compute_ann_statistics(actual.get_subset("train")) + attr_train = stat_train["annotations"]["labels"]["attributes"] + self.assertEqual(7, attr_train["attr"]["distribution"]["0"][0]) + self.assertEqual(14, attr_train["attr"]["distribution"]["1"][0]) + self.assertEqual(21, attr_train["attr"]["distribution"]["2"][0]) + + # check stats for test + stat_test = compute_ann_statistics(actual.get_subset("test")) + attr_test = stat_test["annotations"]["labels"]["attributes"] + self.assertEqual(3, attr_test["attr"]["distribution"]["0"][0]) + self.assertEqual(6, attr_test["attr"]["distribution"]["1"][0]) + self.assertEqual(9, attr_test["attr"]["distribution"]["2"][0]) + + def test_split_for_classification_single_class_multi_attr(self): + counts = { + (0, 0): 20, + (0, 1): 20, + (0, 2): 30, + (1, 0): 20, + (1, 1): 10, + (1, 2): 20, + } + attrs = ["attr1", "attr2"] + config = {"label": {"attrs": attrs, "counts": counts}} + source = self._generate_dataset(config) + + splits = [("train", 0.7), ("test", 0.3)] + actual = splitter.ClassificationSplit(source, splits) + + self.assertEqual(84, len(actual.get_subset("train"))) + self.assertEqual(36, len(actual.get_subset("test"))) + + # check stats for train + stat_train = compute_ann_statistics(actual.get_subset("train")) + attr_train = stat_train["annotations"]["labels"]["attributes"] + self.assertEqual(49, attr_train["attr1"]["distribution"]["0"][0]) + self.assertEqual(35, attr_train["attr1"]["distribution"]["1"][0]) + self.assertEqual(28, attr_train["attr2"]["distribution"]["0"][0]) + self.assertEqual(21, attr_train["attr2"]["distribution"]["1"][0]) + self.assertEqual(35, attr_train["attr2"]["distribution"]["2"][0]) + + # check stats for test + stat_test = compute_ann_statistics(actual.get_subset("test")) + attr_test = stat_test["annotations"]["labels"]["attributes"] + self.assertEqual(21, attr_test["attr1"]["distribution"]["0"][0]) + self.assertEqual(15, attr_test["attr1"]["distribution"]["1"][0]) + self.assertEqual(12, attr_test["attr2"]["distribution"]["0"][0]) + self.assertEqual(9, attr_test["attr2"]["distribution"]["1"][0]) + self.assertEqual(15, attr_test["attr2"]["distribution"]["2"][0]) + + def test_split_for_classification_multi_label_with_attr(self): + counts = { + (0, 0): 20, + (0, 1): 20, + (0, 2): 30, + (1, 0): 20, + (1, 1): 10, + (1, 2): 20, + } + attr1 = ["attr1", "attr2"] + attr2 = ["attr1", "attr3"] + config = { + "label1": {"attrs": attr1, "counts": counts}, + "label2": {"attrs": attr2, "counts": counts}, + } + source = self._generate_dataset(config) + + splits = [("train", 0.7), ("test", 0.3)] + actual = splitter.ClassificationSplit(source, splits) + + train = actual.get_subset("train") + test = actual.get_subset("test") + self.assertEqual(168, len(train)) + self.assertEqual(72, len(test)) + + # check stats for train + stat_train = compute_ann_statistics(train) + dist_train = stat_train["annotations"]["labels"]["distribution"] + self.assertEqual(84, dist_train["label1"][0]) + self.assertEqual(84, dist_train["label2"][0]) + attr_train = stat_train["annotations"]["labels"]["attributes"] + self.assertEqual(49 * 2, attr_train["attr1"]["distribution"]["0"][0]) + self.assertEqual(35 * 2, attr_train["attr1"]["distribution"]["1"][0]) + self.assertEqual(28, attr_train["attr2"]["distribution"]["0"][0]) + self.assertEqual(21, attr_train["attr2"]["distribution"]["1"][0]) + self.assertEqual(35, attr_train["attr2"]["distribution"]["2"][0]) + self.assertEqual(28, attr_train["attr3"]["distribution"]["0"][0]) + self.assertEqual(21, attr_train["attr3"]["distribution"]["1"][0]) + self.assertEqual(35, attr_train["attr3"]["distribution"]["2"][0]) + + # check stats for test + stat_test = compute_ann_statistics(test) + dist_test = stat_test["annotations"]["labels"]["distribution"] + self.assertEqual(36, dist_test["label1"][0]) + self.assertEqual(36, dist_test["label2"][0]) + attr_test = stat_test["annotations"]["labels"]["attributes"] + self.assertEqual(21 * 2, attr_test["attr1"]["distribution"]["0"][0]) + self.assertEqual(15 * 2, attr_test["attr1"]["distribution"]["1"][0]) + self.assertEqual(12, attr_test["attr2"]["distribution"]["0"][0]) + self.assertEqual(9, attr_test["attr2"]["distribution"]["1"][0]) + self.assertEqual(15, attr_test["attr2"]["distribution"]["2"][0]) + self.assertEqual(12, attr_test["attr3"]["distribution"]["0"][0]) + self.assertEqual(9, attr_test["attr3"]["distribution"]["1"][0]) + self.assertEqual(15, attr_test["attr3"]["distribution"]["2"][0]) + + # random seed test + r1 = splitter.ClassificationSplit(source, splits, seed=1234) + r2 = splitter.ClassificationSplit(source, splits, seed=1234) + r3 = splitter.ClassificationSplit(source, splits, seed=4321) + self.assertEqual( + list(r1.get_subset("test")), list(r2.get_subset("test")) + ) + self.assertNotEqual( + list(r1.get_subset("test")), list(r3.get_subset("test")) + ) + + def test_split_for_classification_gives_error(self): + with self.subTest("no label"): + source = Dataset.from_iterable([ + DatasetItem(1, annotations=[]), + DatasetItem(2, annotations=[]), + ], categories=["a", "b", "c"]) + + with self.assertRaisesRegex(Exception, "exactly one is expected"): + splits = [("train", 0.7), ("test", 0.3)] + actual = splitter.ClassificationSplit(source, splits) + len(actual.get_subset("train")) + + with self.subTest("multi label"): + source = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0), Label(1)]), + DatasetItem(2, annotations=[Label(0), Label(2)]), + ], categories=["a", "b", "c"]) + + with self.assertRaisesRegex(Exception, "exactly one is expected"): + splits = [("train", 0.7), ("test", 0.3)] + splitter.ClassificationSplit(source, splits) + len(actual.get_subset("train")) + + source = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], categories=["a", "b", "c"]) + + with self.subTest("wrong ratio"): + with self.assertRaisesRegex(Exception, "in the range"): + splits = [("train", -0.5), ("test", 1.5)] + splitter.ClassificationSplit(source, splits) + + with self.assertRaisesRegex(Exception, "Sum of ratios"): + splits = [("train", 0.5), ("test", 0.5), ("val", 0.5)] + splitter.ClassificationSplit(source, splits) + + with self.subTest("wrong subset name"): + with self.assertRaisesRegex(Exception, "Subset name"): + splits = [("train_", 0.5), ("val", 0.2), ("test", 0.3)] + splitter.ClassificationSplit(source, splits) + + def test_split_for_matching_reid(self): + counts = {i: (i % 3 + 1) * 7 for i in range(10)} + config = {"person": {"attrs": ["PID"], "counts": counts}} + source = self._generate_dataset(config) + + splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] + actual = splitter.MatchingReIDSplit(source, splits, test_splits) + + stats = dict() + for sname in ["train", "val", "test"]: + subset = actual.get_subset(sname) + stat_subset = compute_ann_statistics(subset)["annotations"] + stat_attr = stat_subset["labels"]["attributes"]["PID"] + stats[sname] = stat_attr + + for sname in ["gallery", "query"]: + subset = actual.get_subset_by_group(sname) + stat_subset = compute_ann_statistics(subset)["annotations"] + stat_attr = stat_subset["labels"]["attributes"]["PID"] + stats[sname] = stat_attr + + self.assertEqual(65, stats["train"]["count"]) # depends on heuristic + self.assertEqual(26, stats["val"]["count"]) # depends on heuristic + self.assertEqual(42, stats["test"]["count"]) # depends on heuristic + + train_ids = stats["train"]["values present"] + self.assertEqual(7, len(train_ids)) + self.assertEqual(train_ids, stats["val"]["values present"]) + + trainval = stats["train"]["count"] + stats["val"]["count"] + self.assertEqual(int(trainval * 0.5 / 0.7), stats["train"]["count"]) + self.assertEqual(int(trainval * 0.2 / 0.7), stats["val"]["count"]) + + dist_train = stats["train"]["distribution"] + dist_val = stats["val"]["distribution"] + for pid in train_ids: + total = counts[int(pid)] + self.assertEqual(int(total * 0.5 / 0.7), dist_train[pid][0]) + self.assertEqual(int(total * 0.2 / 0.7), dist_val[pid][0]) + + test_ids = stats["test"]["values present"] + self.assertEqual(3, len(test_ids)) + self.assertEqual(test_ids, stats["gallery"]["values present"]) + self.assertEqual(test_ids, stats["query"]["values present"]) + + dist_test = stats["test"]["distribution"] + dist_gallery = stats["gallery"]["distribution"] + dist_query = stats["query"]["distribution"] + for pid in test_ids: + total = counts[int(pid)] + self.assertEqual(total, dist_test[pid][0]) + self.assertEqual(int(total * 0.3 / 0.7), dist_gallery[pid][0]) + self.assertEqual(int(total * 0.4 / 0.7), dist_query[pid][0]) + + # random seed test + splits = [("train", 0.5), ("test", 0.5)] + r1 = splitter.MatchingReIDSplit(source, splits, test_splits, seed=1234) + r2 = splitter.MatchingReIDSplit(source, splits, test_splits, seed=1234) + r3 = splitter.MatchingReIDSplit(source, splits, test_splits, seed=4321) + self.assertEqual( + list(r1.get_subset("test")), list(r2.get_subset("test")) + ) + self.assertNotEqual( + list(r1.get_subset("test")), list(r3.get_subset("test")) + ) + + def test_split_for_matching_reid_gives_error(self): + with self.subTest("no label"): + source = Dataset.from_iterable([ + DatasetItem(1, annotations=[]), + DatasetItem(2, annotations=[]), + ], categories=["a", "b", "c"]) + + with self.assertRaisesRegex(Exception, "exactly one is expected"): + splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] + actual = splitter.MatchingReIDSplit(source, splits, test_splits) + len(actual.get_subset("train")) + + with self.subTest(msg="multi label"): + source = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0), Label(1)]), + DatasetItem(2, annotations=[Label(0), Label(2)]), + ], categories=["a", "b", "c"]) + + with self.assertRaisesRegex(Exception, "exactly one is expected"): + splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] + actual = splitter.MatchingReIDSplit(source, splits, test_splits) + len(actual.get_subset("train")) + + counts = {i: (i % 3 + 1) * 7 for i in range(10)} + config = {"person": {"attrs": ["PID"], "counts": counts}} + source = self._generate_dataset(config) + with self.subTest("wrong ratio"): + with self.assertRaisesRegex(Exception, "in the range"): + splits = [("train", -0.5), ("val", 0.2), ("test", 0.3)] + test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] + splitter.MatchingReIDSplit(source, splits, test_splits) + + with self.assertRaisesRegex(Exception, "Sum of ratios"): + splits = [("train", 0.6), ("val", 0.2), ("test", 0.3)] + test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] + splitter.MatchingReIDSplit(source, splits, test_splits) + + with self.assertRaisesRegex(Exception, "in the range"): + splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + test_splits = [("query", -0.4 / 0.7), ("gallery", 0.3 / 0.7)] + actual = splitter.MatchingReIDSplit(source, splits, test_splits) + len(actual.get_subset_by_group("query")) + + with self.assertRaisesRegex(Exception, "Sum of ratios"): + splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + test_splits = [("query", 0.5 / 0.7), ("gallery", 0.3 / 0.7)] + actual = splitter.MatchingReIDSplit(source, splits, test_splits) + len(actual.get_subset_by_group("query")) + + with self.subTest("wrong subset name"): + with self.assertRaisesRegex(Exception, "Subset name"): + splits = [("_train", 0.5), ("val", 0.2), ("test", 0.3)] + test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] + splitter.MatchingReIDSplit(source, splits, test_splits) + + with self.assertRaisesRegex(Exception, "Subset name"): + splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + test_splits = [("_query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] + actual = splitter.MatchingReIDSplit(source, splits, test_splits) + len(actual.get_subset_by_group("query")) + + with self.subTest("wrong attribute name for person id"): + splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)] + actual = splitter.MatchingReIDSplit(source, splits, test_splits) + + with self.assertRaisesRegex(Exception, "Unknown group"): + actual.get_subset_by_group("_gallery") + + def _generate_detection_dataset(self, **kwargs): + append_bbox = kwargs.get("append_bbox") + with_attr = kwargs.get("with_attr", False) + nimages = kwargs.get("nimages", 10) + + label_cat = LabelCategories() + for i in range(6): + label = "label%d" % (i + 1) + if with_attr is True: + attributes = {"attr0", "attr%d" % (i + 1)} + else: + attributes = {} + label_cat.add(label, attributes=attributes) + categories = {AnnotationType.label: label_cat} + + iterable = [] + attr_val = 0 + totals = np.zeros(3) + objects = [(1, 5, 2), (3, 4, 1), (2, 3, 4), (1, 1, 1), (2, 4, 2)] + for img_id in range(nimages): + cnts = objects[img_id % len(objects)] + totals += cnts + annotations = [] + for label_id, count in enumerate(cnts): + attributes = {} + if with_attr: + attr_val += 1 + attributes["attr0"] = attr_val % 3 + attributes["attr%d" % (label_id + 1)] = attr_val % 2 + for ann_id in range(count): + append_bbox(annotations, label_id=label_id, ann_id=ann_id, + attributes=attributes) + item = DatasetItem(img_id, subset=self._get_subset(img_id), + annotations=annotations, attributes={"id": img_id}) + iterable.append(item) + + dataset = Dataset.from_iterable(iterable, categories) + return dataset, totals + + @staticmethod + def _get_append_bbox(dataset_type): + def append_bbox_coco(annotations, **kwargs): + annotations.append( + Bbox(1, 1, 2, 2, label=kwargs["label_id"], + id=kwargs["ann_id"], + attributes=kwargs["attributes"], + group=kwargs["ann_id"], + ) + ) + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + + def append_bbox_voc(annotations, **kwargs): + annotations.append( + Bbox(1, 1, 2, 2, label=kwargs["label_id"], + id=kwargs["ann_id"] + 1, + attributes=kwargs["attributes"], + group=kwargs["ann_id"], + ) + ) # obj + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + annotations.append( + Bbox(1, 1, 2, 2, label=kwargs["label_id"] + 3, + group=kwargs["ann_id"], + ) + ) # part + annotations.append( + Label(kwargs["label_id"] + 3, attributes=kwargs["attributes"]) + ) + + def append_bbox_yolo(annotations, **kwargs): + annotations.append(Bbox(1, 1, 2, 2, label=kwargs["label_id"])) + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + + def append_bbox_cvat(annotations, **kwargs): + annotations.append( + Bbox(1, 1, 2, 2, label=kwargs["label_id"], + id=kwargs["ann_id"], + attributes=kwargs["attributes"], + group=kwargs["ann_id"], + z_order=kwargs["ann_id"], + ) + ) + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + + def append_bbox_labelme(annotations, **kwargs): + annotations.append( + Bbox(1, 1, 2, 2, label=kwargs["label_id"], + id=kwargs["ann_id"], + attributes=kwargs["attributes"], + ) + ) + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + + def append_bbox_mot(annotations, **kwargs): + annotations.append( + Bbox(1, 1, 2, 2, label=kwargs["label_id"], + attributes=kwargs["attributes"], + ) + ) + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + + def append_bbox_widerface(annotations, **kwargs): + annotations.append( + Bbox(1, 1, 2, 2, attributes=kwargs["attributes"]) + ) + annotations.append(Label(0, attributes=kwargs["attributes"])) + + functions = { + "coco": append_bbox_coco, + "voc": append_bbox_voc, + "yolo": append_bbox_yolo, + "cvat": append_bbox_cvat, + "labelme": append_bbox_labelme, + "mot": append_bbox_mot, + "widerface": append_bbox_widerface, + } + + func = functions.get(dataset_type, append_bbox_cvat) + return func + + def test_split_for_detection(self): + dtypes = ["coco", "voc", "yolo", "cvat", "labelme", "mot", "widerface"] + params = [] + for dtype in dtypes: + for with_attr in [False, True]: + params.append((dtype, with_attr, 10, 5, 3, 2)) + params.append((dtype, with_attr, 10, 7, 0, 3)) + + for dtype, with_attr, nimages, train, val, test in params: + source, _ = self._generate_detection_dataset( + append_bbox=self._get_append_bbox(dtype), + with_attr=with_attr, + nimages=nimages, + ) + total = np.sum([train, val, test]) + splits = [ + ("train", train / total), + ("val", val / total), + ("test", test / total), + ] + with self.subTest( + dtype=dtype, + with_attr=with_attr, + nimage=nimages, + train=train, + val=val, + test=test, + ): + actual = splitter.DetectionSplit(source, splits) + + self.assertEqual(train, len(actual.get_subset("train"))) + self.assertEqual(val, len(actual.get_subset("val"))) + self.assertEqual(test, len(actual.get_subset("test"))) + + # random seed test + source, _ = self._generate_detection_dataset( + append_bbox=self._get_append_bbox("cvat"), + with_attr=True, + nimages=10, + ) + + splits = [("train", 0.5), ("test", 0.5)] + r1 = splitter.DetectionSplit(source, splits, seed=1234) + r2 = splitter.DetectionSplit(source, splits, seed=1234) + r3 = splitter.DetectionSplit(source, splits, seed=4321) + self.assertEqual( + list(r1.get_subset("test")), list(r2.get_subset("test")) + ) + self.assertNotEqual( + list(r1.get_subset("test")), list(r3.get_subset("test")) + ) + + def test_split_for_detection_gives_error(self): + with self.subTest(msg="bbox annotation"): + source = Dataset.from_iterable([ + DatasetItem(1, annotations=[Label(0), Label(1)]), + DatasetItem(2, annotations=[Label(0), Label(2)]), + ], categories=["a", "b", "c"]) + + with self.assertRaisesRegex(Exception, "more than one bbox"): + splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + actual = splitter.DetectionSplit(source, splits) + len(actual.get_subset("train")) + + source, _ = self._generate_detection_dataset( + append_bbox=self._get_append_bbox("cvat"), + with_attr=True, + nimages=5, + ) + + with self.subTest("wrong ratio"): + with self.assertRaisesRegex(Exception, "in the range"): + splits = [("train", -0.5), ("test", 1.5)] + splitter.DetectionSplit(source, splits) + + with self.assertRaisesRegex(Exception, "Sum of ratios"): + splits = [("train", 0.5), ("test", 0.5), ("val", 0.5)] + splitter.DetectionSplit(source, splits) + + with self.subTest("wrong subset name"): + with self.assertRaisesRegex(Exception, "Subset name"): + splits = [("train_", 0.5), ("val", 0.2), ("test", 0.3)] + splitter.DetectionSplit(source, splits)