diff --git a/CHANGELOG.md b/CHANGELOG.md index c6a57f84dc5a..a13ea5e8a243 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ### Changed -- +- Allowed arbitrary subset count and names in classification and detection splitters () ### Deprecated - diff --git a/datumaro/plugins/splitter.py b/datumaro/plugins/splitter.py index e4e3b432f75a..69240f4ef4d7 100644 --- a/datumaro/plugins/splitter.py +++ b/datumaro/plugins/splitter.py @@ -4,10 +4,12 @@ import logging as log import numpy as np +from math import gcd from datumaro.components.extractor import (Transform, AnnotationType, DEFAULT_SUBSET_NAME) from datumaro.components.cli_plugin import CliPlugin +from datumaro.util import cast NEAR_ZERO = 1e-7 @@ -33,20 +35,23 @@ def _split_arg(s): raise argparse.ArgumentTypeError() return (parts[0], float(parts[1])) - def __init__(self, dataset, splits, seed): + def __init__(self, dataset, splits, seed, restrict=False): super().__init__(dataset) if splits is None: splits = self._default_split - snames, sratio = self._validate_splits(splits) + snames, sratio, subsets = self._validate_splits(splits, restrict) self._snames = snames self._sratio = sratio self._seed = seed - self._subsets = {"train", "val", "test"} # output subset names + # remove subset name restriction + # regarding https://github.com/openvinotoolkit/datumaro/issues/194 + # self._subsets = {"train", "val", "test"} # output subset names + self._subsets = subsets self._parts = [] self._length = "parent" @@ -70,21 +75,29 @@ def _get_uniq_annotations(dataset): return annotations @staticmethod - def _validate_splits(splits, valid=None): + def _validate_splits(splits, restrict=False): snames = [] ratios = [] - if valid is None: - valid = ["train", "val", "test"] + subsets = set() + valid = ["train", "val", "test"] + # remove subset name restriction + # regarding https://github.com/openvinotoolkit/datumaro/issues/194 for subset, ratio in splits: - assert subset in valid, \ - "Subset name must be one of %s, but got %s" % (valid, subset) + if restrict: + assert subset in valid, \ + "Subset name must be one of %s, got %s" % (valid, subset) assert 0.0 <= ratio and ratio <= 1.0, \ "Ratio is expected to be in the range " \ "[0, 1], but got %s for %s" % (ratio, subset) # ignore near_zero ratio because it may produce partition error. if ratio > NEAR_ZERO: + # handling duplication + if subset in snames: + raise Exception("Subset (%s) is duplicated" % subset) snames.append(subset) ratios.append(float(ratio)) + subsets.add(subset) + ratios = np.array(ratios) total_ratio = np.sum(ratios) @@ -94,15 +107,26 @@ def _validate_splits(splits, valid=None): % (splits, total_ratio) ) - return snames, ratios + return snames, ratios, subsets @staticmethod def _get_required(ratio): - min_value = np.max(ratio) - for i in ratio: - if NEAR_ZERO < i and i < min_value: - min_value = i - required = int(np.around(1.0) / min_value) + if len(ratio) < 2: + return 1 + + for scale in [10, 100]: + farray = np.array(ratio) * scale + iarray = farray.astype(int) + if np.array_equal(iarray, farray): + break + + # find gcd + common_divisor = iarray[0] + for val in iarray[1:]: + common_divisor = gcd(common_divisor, val) + + required = np.sum(np.array(iarray / common_divisor).astype(int)) + return required @staticmethod @@ -129,37 +153,75 @@ def _group_by_attr(items): Returns: by_attributes: dict of { combination-of-attrs : list of index } """ + + # float--> numerical, others(int, string, bool) --> categorical + def _is_float(value): + if isinstance(value, str): + casted = cast(value, float) + if casted is not None: + if cast(casted, str) == value: + return True + return False + elif isinstance(value, float): + cast(value, float) + return True + return False + # group by attributes by_attributes = dict() for idx, ann in items: - attributes = tuple(sorted(ann.attributes.items())) + # ignore numeric attributes + filtered = {} + for k, v in ann.attributes.items(): + if _is_float(v): + continue + filtered[k] = v + attributes = tuple(sorted(filtered.items())) if attributes not in by_attributes: by_attributes[attributes] = [] by_attributes[attributes].append(idx) + return by_attributes def _split_by_attr(self, datasets, snames, ratio, out_splits, - dataset_key=None): + merge_small_classes=True): + + def _split_indice(indice): + sections = self._get_sections(len(indice), ratio) + splits = np.array_split(indice, sections) + for subset, split in zip(snames, splits): + if 0 < len(split): + out_splits[subset].extend(split) + required = self._get_required(ratio) - if dataset_key is None: - dataset_key = "label" - for key, items in datasets.items(): + rest = [] + for _, items in datasets.items(): np.random.shuffle(items) by_attributes = self._group_by_attr(items) - for attributes, indice in by_attributes.items(): - gname = "%s: %s, attrs: %s" % (dataset_key, key, attributes) - splits = self._split_indice(indice, gname, ratio, required) - for subset, split in zip(snames, splits): - if 0 < len(split): - out_splits[subset].extend(split) - - def _split_indice(self, indice, group_name, ratio, required): - filtered_size = len(indice) - if filtered_size < required: - log.warning("Not enough samples for a group, '%s'" % group_name) - sections = self._get_sections(filtered_size, ratio) - splits = np.array_split(indice, sections) - return splits + attr_names = list(by_attributes.keys()) + np.random.shuffle(attr_names) # add randomness + for attr in attr_names: + indice = by_attributes[attr] + quo = len(indice) // required + if quo > 0: + filtered_size = quo * required + _split_indice(indice[:filtered_size]) + rest.extend(indice[filtered_size:]) + else: + rest.extend(indice) + + quo = len(rest) // required + if quo > 0: + filtered_size = quo * required + _split_indice(rest[:filtered_size]) + rest = rest[filtered_size:] + + if not merge_small_classes and len(rest) > 0: + _split_indice(rest) + rest = [] + + if len(rest) > 0: + _split_indice(rest) def _find_split(self, index): for subset_indices, subset in self._parts: @@ -181,7 +243,7 @@ def __iter__(self): class ClassificationSplit(_TaskSpecificSplit): """ - Splits dataset into train/val/test set in class-wise manner. |n + Splits dataset into subsets(train/val/test) in class-wise manner. |n Splits dataset images in the specified ratio, keeping the initial class distribution.|n |n @@ -201,7 +263,6 @@ def __init__(self, dataset, splits, seed=None): dataset : Dataset splits : list A list of (subset(str), ratio(float)) - Subset is expected to be one of ["train", "val", "test"]. The sum of ratios is expected to be 1. seed : int, optional """ @@ -214,6 +275,7 @@ def _split_dataset(self): # 1. group by label by_labels = dict() annotations = self._get_uniq_annotations(self._extractor) + for idx, ann in enumerate(annotations): label = getattr(ann, 'label', None) if label not in by_labels: @@ -290,7 +352,7 @@ def __init__(self, dataset, splits, query=None, if this is not specified, label would be used. seed : int, optional """ - super().__init__(dataset, splits, seed) + super().__init__(dataset, splits, seed, restrict=True) if query is None: query = self._default_query_ratio @@ -300,7 +362,7 @@ def __init__(self, dataset, splits, query=None, "[0, 1], but got %f" % query test_splits = [('test-query', query), ('test-gallery', 1.0 - query)] - # reset output subset names + # remove subset name restriction self._subsets = {"train", "val", "test-gallery", "test-query"} self._test_splits = test_splits self._attr_for_id = attr_for_id @@ -350,7 +412,6 @@ def _split_dataset(self): splits = np.array_split(IDs, sections) testset = {pid: by_id[pid] for pid in splits[0]} trval = {pid: by_id[pid] for pid in splits[1]} - # follow the ratio of datasetitems as possible. # naive heuristic: exchange the best item one by one. expected_count = int(len(self._extractor) * split_ratio[0]) @@ -373,7 +434,7 @@ def _split_dataset(self): test_ratio.append(float(ratio)) self._split_by_attr(testset, test_snames, test_ratio, by_splits, - dataset_key=attr_for_id) + merge_small_classes=False) # 3. split 'trval' into 'train' and 'val' trval_snames = ["train", "val"] @@ -395,7 +456,7 @@ def _split_dataset(self): else: trval_ratio /= total_ratio # normalize self._split_by_attr(trval, trval_snames, trval_ratio, by_splits, - dataset_key=attr_for_id) + merge_small_classes=False) self._set_parts(by_splits) @@ -448,7 +509,7 @@ def _rebalancing(test, trval, expected_count, testset_total): class DetectionSplit(_TaskSpecificSplit): """ - Splits a dataset into train/val/test subsets for detection task, + Splits a dataset into subsets(train/val/test) for detection task, using object annotations as a basis for splitting.|n Tries to produce an image split with the specified ratio, keeping the initial distribution of class objects.|n @@ -476,7 +537,6 @@ def __init__(self, dataset, splits, seed=None): dataset : Dataset splits : list A list of (subset(str), ratio(float)) - Subset is expected to be one of ["train", "val", "test"]. The sum of ratios is expected to be 1. seed : int, optional """ @@ -507,79 +567,105 @@ def _split_dataset(self): by_labels = self._group_by_bbox_labels(self._extractor) # 2. group by attributes - by_combinations = dict() - for label, items in by_labels.items(): + required = self._get_required(sratio) + by_combinations = list() + for _, items in by_labels.items(): by_attributes = self._group_by_attr(items) - for attributes, indice in by_attributes.items(): - gname = "label: %s, attributes: %s" % (label, attributes) - by_combinations[gname] = indice + # merge groups which have too small samples. + attr_names = list(by_attributes.keys()) + np.random.shuffle(attr_names) # add randomless + cluster = [] + minumum = max(required, len(items) * 0.1) # temp solution + for attr in attr_names: + indice = by_attributes[attr] + if len(indice) >= minumum: + by_combinations.append(indice) + else: + cluster.extend(indice) + if len(cluster) >= minumum: + by_combinations.append(cluster) + cluster = [] + if len(cluster) > 0: + by_combinations.append(cluster) + cluster = [] + + total = len(self._extractor) # total number of GT samples per label-attr combinations - n_combs = {k: len(v) for k, v in by_combinations.items()} + n_combs = [len(v) for v in by_combinations] # 3-1. initially count per-image GT samples - scores_all = {} + counts_all = {idx: dict() for idx in range(total)} + for idx_comb, indice in enumerate(by_combinations): + for idx in indice: + if idx_comb not in counts_all[idx]: + counts_all[idx] = {idx_comb: 1} + else: + counts_all[idx][idx_comb] += 1 + init_scores = {} - for idx, _ in enumerate(self._extractor): - counts = {k: v.count(idx) for k, v in by_combinations.items()} - scores_all[idx] = counts - init_scores[idx] = np.sum( - [v / n_combs[k] for k, v in counts.items()] - ) + for idx, counts in counts_all.items(): + norm_sum = 0.0 + for idx_comb, count in counts.items(): + norm_sum += count / n_combs[idx_comb] + init_scores[idx] = norm_sum by_splits = dict() for sname in self._subsets: by_splits[sname] = [] - total = len(self._extractor) target_size = dict() expected = [] # expected numbers of per split GT samples for sname, ratio in zip(subsets, sratio): target_size[sname] = total * ratio - expected.append( - (sname, {k: v * ratio for k, v in n_combs.items()}) - ) + expected.append([sname, np.array(n_combs) * ratio]) # functions for keep the # of annotations not exceed the expected num def compute_penalty(counts, n_combs): p = 0 - for k, v in counts.items(): - p += max(0, (v / n_combs[k]) - 1.0) + for idx_comb, v in counts.items(): + p += max(0, (v / n_combs[idx_comb]) - 1.0) return p def update_nc(counts, n_combs): - for k, v in counts.items(): - n_combs[k] = max(0, n_combs[k] - v) - if n_combs[k] == 0: - n_combs[k] = -1 - return n_combs + for idx_comb, v in counts.items(): + n_combs[idx_comb] = max(0, n_combs[idx_comb] - v) + if n_combs[idx_comb] == 0: + n_combs[idx_comb] = -1 + + by_scores = dict() + for idx, score in init_scores.items(): + if score not in by_scores: + by_scores[score] = [idx] + else: + by_scores[score].append(idx) # 3-2. assign each DatasetItem to a split, one by one - for idx, _ in sorted( - init_scores.items(), key=lambda item: item[1], reverse=True - ): - counts = scores_all[idx] - - # shuffling split order to add randomness - # when two or more splits have the same penalty value - np.random.shuffle(expected) - - pp = [] - for sname, nc in expected: - if target_size[sname] <= len(by_splits[sname]): - # the split has enough images, - # stop adding more images to this split - pp.append(1e08) - else: - # compute penalty based on the number of GT samples - # added in the split - pp.append(compute_penalty(counts, nc)) - - # we push an image to a split with the minimum penalty - midx = np.argmin(pp) - - sname, nc = expected[midx] - by_splits[sname].append(idx) - update_nc(counts, nc) + for score in sorted(by_scores.keys(), reverse=True): + indice = by_scores[score] + np.random.shuffle(indice) # add randomness for the same score + + for idx in indice: + counts = counts_all[idx] + # shuffling split order to add randomness + # when two or more splits have the same penalty value + np.random.shuffle(expected) + + pp = [] + for sname, nc in expected: + if target_size[sname] <= len(by_splits[sname]): + # the split has enough images, + # stop adding more images to this split + pp.append(1e08) + else: + # compute penalty based on the number of GT samples + # added in the split + pp.append(compute_penalty(counts, nc)) + + # we push an image to a split with the minimum penalty + midx = np.argmin(pp) + sname, nc = expected[midx] + by_splits[sname].append(idx) + update_nc(counts, nc) self._set_parts(by_splits) diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 838694a62b5f..8091cb1f1ccb 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -44,6 +44,7 @@ def _generate_dataset(self, config): annotations=[ Label(label_id, attributes=attributes) ], + image=np.ones((1, 1, 3)) ) ) else: @@ -51,7 +52,8 @@ def _generate_dataset(self, config): idx += 1 iterable.append( DatasetItem(idx, subset=self._get_subset(idx), - annotations=[Label(label_id)]) + annotations=[Label(label_id)], + image=np.ones((1, 1, 3))) ) categories = {AnnotationType.label: label_cat} dataset = Dataset.from_iterable(iterable, categories) @@ -123,29 +125,37 @@ def test_split_for_classification_single_class_multi_attr(self): config = {"label": {"attrs": attrs, "counts": counts}} source = self._generate_dataset(config) - splits = [("train", 0.7), ("test", 0.3)] - actual = splitter.ClassificationSplit(source, splits) - - self.assertEqual(84, len(actual.get_subset("train"))) - self.assertEqual(36, len(actual.get_subset("test"))) - - # check stats for train - stat_train = compute_ann_statistics(actual.get_subset("train")) - attr_train = stat_train["annotations"]["labels"]["attributes"] - self.assertEqual(49, attr_train["attr1"]["distribution"]["0"][0]) - self.assertEqual(35, attr_train["attr1"]["distribution"]["1"][0]) - self.assertEqual(28, attr_train["attr2"]["distribution"]["0"][0]) - self.assertEqual(21, attr_train["attr2"]["distribution"]["1"][0]) - self.assertEqual(35, attr_train["attr2"]["distribution"]["2"][0]) - - # check stats for test - stat_test = compute_ann_statistics(actual.get_subset("test")) - attr_test = stat_test["annotations"]["labels"]["attributes"] - self.assertEqual(21, attr_test["attr1"]["distribution"]["0"][0]) - self.assertEqual(15, attr_test["attr1"]["distribution"]["1"][0]) - self.assertEqual(12, attr_test["attr2"]["distribution"]["0"][0]) - self.assertEqual(9, attr_test["attr2"]["distribution"]["1"][0]) - self.assertEqual(15, attr_test["attr2"]["distribution"]["2"][0]) + with self.subTest("zero remainder"): + splits = [("train", 0.7), ("test", 0.3)] + actual = splitter.ClassificationSplit(source, splits) + + self.assertEqual(84, len(actual.get_subset("train"))) + self.assertEqual(36, len(actual.get_subset("test"))) + + # check stats for train + stat_train = compute_ann_statistics(actual.get_subset("train")) + attr_train = stat_train["annotations"]["labels"]["attributes"] + self.assertEqual(49, attr_train["attr1"]["distribution"]["0"][0]) + self.assertEqual(35, attr_train["attr1"]["distribution"]["1"][0]) + self.assertEqual(28, attr_train["attr2"]["distribution"]["0"][0]) + self.assertEqual(21, attr_train["attr2"]["distribution"]["1"][0]) + self.assertEqual(35, attr_train["attr2"]["distribution"]["2"][0]) + + # check stats for test + stat_test = compute_ann_statistics(actual.get_subset("test")) + attr_test = stat_test["annotations"]["labels"]["attributes"] + self.assertEqual(21, attr_test["attr1"]["distribution"]["0"][0]) + self.assertEqual(15, attr_test["attr1"]["distribution"]["1"][0]) + self.assertEqual(12, attr_test["attr2"]["distribution"]["0"][0]) + self.assertEqual(9, attr_test["attr2"]["distribution"]["1"][0]) + self.assertEqual(15, attr_test["attr2"]["distribution"]["2"][0]) + + with self.subTest("non-zero remainder"): + splits = [("train", 0.95), ("test", 0.05)] + actual = splitter.ClassificationSplit(source, splits) + + self.assertEqual(114, len(actual.get_subset("train"))) + self.assertEqual(6, len(actual.get_subset("test"))) def test_split_for_classification_multi_label_with_attr(self): counts = { @@ -221,7 +231,7 @@ def test_split_for_classification_zero_ratio(self): splits = [("train", 0.1), ("val", 0.9), ("test", 0.0)] actual = splitter.ClassificationSplit(source, splits) - + self.assertEqual(1, len(actual.get_subset("train"))) self.assertEqual(4, len(actual.get_subset("val"))) self.assertEqual(0, len(actual.get_subset("test"))) @@ -263,9 +273,9 @@ def test_split_for_classification_gives_error(self): splits = [("train", 0.5), ("test", 0.5), ("val", 0.5)] splitter.ClassificationSplit(source, splits) - with self.subTest("wrong subset name"): - with self.assertRaisesRegex(Exception, "Subset name"): - splits = [("train_", 0.5), ("val", 0.2), ("test", 0.3)] + with self.subTest("duplicated subset name"): + with self.assertRaisesRegex(Exception, "duplicated"): + splits = [("train", 0.5), ("train", 0.2), ("test", 0.3)] splitter.ClassificationSplit(source, splits) def test_split_for_reidentification(self): @@ -427,6 +437,11 @@ def test_split_for_reidentification_gives_error(self): splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] actual = splitter.ReidentificationSplit(source, splits, -query) + with self.subTest("duplicated subset name"): + with self.assertRaisesRegex(Exception, "duplicated"): + splits = [("train", 0.5), ("train", 0.2), ("test", 0.3)] + splitter.ReidentificationSplit(source, splits, query) + with self.subTest("wrong subset name"): with self.assertRaisesRegex(Exception, "Subset name"): splits = [("_train", 0.5), ("val", 0.2), ("test", 0.3)] @@ -650,7 +665,36 @@ def test_split_for_detection_gives_error(self): splits = [("train", 0.5), ("test", 0.5), ("val", 0.5)] splitter.DetectionSplit(source, splits) - with self.subTest("wrong subset name"): - with self.assertRaisesRegex(Exception, "Subset name"): - splits = [("train_", 0.5), ("val", 0.2), ("test", 0.3)] + with self.subTest("duplicated subset name"): + with self.assertRaisesRegex(Exception, "duplicated"): + splits = [("train", 0.5), ("train", 0.2), ("test", 0.3)] splitter.DetectionSplit(source, splits) + + def test_no_subset_name_and_count_restriction(self): + splits = [("_train", 0.5), ("valid", 0.1), ("valid2", 0.1), + ("test*", 0.2), ("test2", 0.1)] + + with self.subTest("classification"): + config = { + "label1": {"attrs": None, "counts": 10} + } + source = self._generate_dataset(config) + actual = splitter.ClassificationSplit(source, splits) + self.assertEqual(5, len(actual.get_subset("_train"))) + self.assertEqual(1, len(actual.get_subset("valid"))) + self.assertEqual(1, len(actual.get_subset("valid2"))) + self.assertEqual(2, len(actual.get_subset("test*"))) + self.assertEqual(1, len(actual.get_subset("test2"))) + + with self.subTest("detection"): + source, _ = self._generate_detection_dataset( + append_bbox=self._get_append_bbox("cvat"), + with_attr=True, + nimages=10, + ) + actual = splitter.DetectionSplit(source, splits) + self.assertEqual(5, len(actual.get_subset("_train"))) + self.assertEqual(1, len(actual.get_subset("valid"))) + self.assertEqual(1, len(actual.get_subset("valid2"))) + self.assertEqual(2, len(actual.get_subset("test*"))) + self.assertEqual(1, len(actual.get_subset("test2")))