Skip to content

Commit

Permalink
[splitter] add random seed test and modify formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
jihyeonyi committed Feb 11, 2021
1 parent faaf841 commit 7c74aa7
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 49 deletions.
2 changes: 1 addition & 1 deletion datumaro/plugins/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def _split_dataset(self):
)

by_splits = dict()
for sname in subsets:
for sname in self._subsets:
by_splits[sname] = []

total = len(self._extractor)
Expand Down
121 changes: 73 additions & 48 deletions tests/test_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,13 @@ def test_split_for_classification_multi_label_with_attr(self):
splits = [("train", 0.7), ("test", 0.3)]
actual = splitter.ClassificationSplit(source, splits)

self.assertEqual(168, len(actual.get_subset("train")))
self.assertEqual(72, len(actual.get_subset("test")))
train = actual.get_subset("train")
test = actual.get_subset("test")
self.assertEqual(168, len(train))
self.assertEqual(72, len(test))

# check stats for train
stat_train = compute_ann_statistics(actual.get_subset("train"))
stat_train = compute_ann_statistics(train)
dist_train = stat_train["annotations"]["labels"]["distribution"]
self.assertEqual(84, dist_train["label1"][0])
self.assertEqual(84, dist_train["label2"][0])
Expand All @@ -196,7 +198,7 @@ def test_split_for_classification_multi_label_with_attr(self):
self.assertEqual(35, attr_train["attr3"]["distribution"]["2"][0])

# check stats for test
stat_test = compute_ann_statistics(actual.get_subset("test"))
stat_test = compute_ann_statistics(test)
dist_test = stat_test["annotations"]["labels"]["distribution"]
self.assertEqual(36, dist_test["label1"][0])
self.assertEqual(36, dist_test["label2"][0])
Expand All @@ -210,6 +212,17 @@ def test_split_for_classification_multi_label_with_attr(self):
self.assertEqual(9, attr_test["attr3"]["distribution"]["1"][0])
self.assertEqual(15, attr_test["attr3"]["distribution"]["2"][0])

# random seed test
r1 = splitter.ClassificationSplit(source, splits, seed=1234)
r2 = splitter.ClassificationSplit(source, splits, seed=1234)
r3 = splitter.ClassificationSplit(source, splits, seed=4321)
self.assertEqual(
list(r1.get_subset("test")), list(r2.get_subset("test"))
)
self.assertNotEqual(
list(r1.get_subset("test")), list(r3.get_subset("test"))
)

def test_split_for_classification_gives_error(self):
with self.subTest("no label"):
source = Dataset.from_iterable(
Expand Down Expand Up @@ -263,9 +276,9 @@ def test_split_for_matching_reid(self):
config = {"person": {"attrs": ["PID"], "counts": counts}}
source = self._generate_dataset(config)

id_splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)]
actual = splitter.MatchingReIDSplit(source, id_splits, test_splits)
actual = splitter.MatchingReIDSplit(source, splits, test_splits)

stats = dict()
for sname in ["train", "val", "test"]:
Expand Down Expand Up @@ -313,6 +326,18 @@ def test_split_for_matching_reid(self):
self.assertEqual(int(total * 0.3 / 0.7), dist_gallery[pid][0])
self.assertEqual(int(total * 0.4 / 0.7), dist_query[pid][0])

# random seed test
splits = [("train", 0.5), ("test", 0.5)]
r1 = splitter.MatchingReIDSplit(source, splits, test_splits, seed=1234)
r2 = splitter.MatchingReIDSplit(source, splits, test_splits, seed=1234)
r3 = splitter.MatchingReIDSplit(source, splits, test_splits, seed=4321)
self.assertEqual(
list(r1.get_subset("test")), list(r2.get_subset("test"))
)
self.assertNotEqual(
list(r1.get_subset("test")), list(r3.get_subset("test"))
)

def test_split_for_matching_reid_gives_error(self):
with self.subTest("no label"):
source = Dataset.from_iterable(
Expand All @@ -323,10 +348,10 @@ def test_split_for_matching_reid_gives_error(self):
categories=["a", "b", "c"],
)
with self.assertRaisesRegex(Exception, "exact one label"):
id_splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)]
actual = splitter.MatchingReIDSplit(
source, id_splits, test_splits
source, splits, test_splits
)
len(actual.get_subset("train"))

Expand All @@ -339,10 +364,10 @@ def test_split_for_matching_reid_gives_error(self):
categories=["a", "b", "c"],
)
with self.assertRaisesRegex(Exception, "exact one label"):
id_splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)]
actual = splitter.MatchingReIDSplit(
source, id_splits, test_splits
source, splits, test_splits
)
len(actual.get_subset("train"))

Expand All @@ -351,45 +376,45 @@ def test_split_for_matching_reid_gives_error(self):
source = self._generate_dataset(config)
with self.subTest("wrong ratio"):
with self.assertRaisesRegex(Exception, "in the range"):
id_splits = [("train", -0.5), ("val", 0.2), ("test", 0.3)]
splits = [("train", -0.5), ("val", 0.2), ("test", 0.3)]
test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)]
splitter.MatchingReIDSplit(source, id_splits, test_splits)
splitter.MatchingReIDSplit(source, splits, test_splits)
with self.assertRaisesRegex(Exception, "Sum of ratios"):
id_splits = [("train", 0.6), ("val", 0.2), ("test", 0.3)]
splits = [("train", 0.6), ("val", 0.2), ("test", 0.3)]
test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)]
splitter.MatchingReIDSplit(source, id_splits, test_splits)
splitter.MatchingReIDSplit(source, splits, test_splits)
with self.assertRaisesRegex(Exception, "in the range"):
id_splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
test_splits = [("query", -0.4 / 0.7), ("gallery", 0.3 / 0.7)]
actual = splitter.MatchingReIDSplit(
source, id_splits, test_splits
source, splits, test_splits
)
len(actual.get_subset_by_group("query"))
with self.assertRaisesRegex(Exception, "Sum of ratios"):
id_splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
test_splits = [("query", 0.5 / 0.7), ("gallery", 0.3 / 0.7)]
actual = splitter.MatchingReIDSplit(
source, id_splits, test_splits
source, splits, test_splits
)
len(actual.get_subset_by_group("query"))

with self.subTest("wrong subset name"):
with self.assertRaisesRegex(Exception, "Subset name"):
id_splits = [("_train", 0.5), ("val", 0.2), ("test", 0.3)]
splits = [("_train", 0.5), ("val", 0.2), ("test", 0.3)]
test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)]
splitter.MatchingReIDSplit(source, id_splits, test_splits)
splitter.MatchingReIDSplit(source, splits, test_splits)
with self.assertRaisesRegex(Exception, "Subset name"):
id_splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
test_splits = [("_query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)]
actual = splitter.MatchingReIDSplit(
source, id_splits, test_splits
source, splits, test_splits
)
len(actual.get_subset_by_group("query"))

with self.subTest("wrong attribute name for person id"):
id_splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
test_splits = [("query", 0.4 / 0.7), ("gallery", 0.3 / 0.7)]
actual = splitter.MatchingReIDSplit(source, id_splits, test_splits)
actual = splitter.MatchingReIDSplit(source, splits, test_splits)
with self.assertRaisesRegex(Exception, "Unknown group"):
actual.get_subset_by_group("_gallery")

Expand Down Expand Up @@ -445,10 +470,7 @@ def _get_append_bbox(dataset_type):
def append_bbox_coco(annotations, **kwargs):
annotations.append(
Bbox(
1,
1,
2,
2,
1, 1, 2, 2,
label=kwargs["label_id"],
id=kwargs["ann_id"],
attributes=kwargs["attributes"],
Expand All @@ -462,10 +484,7 @@ def append_bbox_coco(annotations, **kwargs):
def append_bbox_voc(annotations, **kwargs):
annotations.append(
Bbox(
1,
1,
2,
2,
1, 1, 2, 2,
label=kwargs["label_id"],
id=kwargs["ann_id"] + 1,
attributes=kwargs["attributes"],
Expand All @@ -477,10 +496,7 @@ def append_bbox_voc(annotations, **kwargs):
)
annotations.append(
Bbox(
1,
1,
2,
2,
1, 1, 2, 2,
label=kwargs["label_id"] + 3,
group=kwargs["ann_id"],
)
Expand All @@ -498,10 +514,7 @@ def append_bbox_yolo(annotations, **kwargs):
def append_bbox_cvat(annotations, **kwargs):
annotations.append(
Bbox(
1,
1,
2,
2,
1, 1, 2, 2,
label=kwargs["label_id"],
id=kwargs["ann_id"],
attributes=kwargs["attributes"],
Expand All @@ -516,10 +529,7 @@ def append_bbox_cvat(annotations, **kwargs):
def append_bbox_labelme(annotations, **kwargs):
annotations.append(
Bbox(
1,
1,
2,
2,
1, 1, 2, 2,
label=kwargs["label_id"],
attributes=kwargs["attributes"],
id=kwargs["ann_id"],
Expand All @@ -532,10 +542,7 @@ def append_bbox_labelme(annotations, **kwargs):
def append_bbox_mot(annotations, **kwargs):
annotations.append(
Bbox(
1,
1,
2,
2,
1, 1, 2, 2,
label=kwargs["label_id"],
attributes=kwargs["attributes"],
)
Expand Down Expand Up @@ -597,6 +604,24 @@ def test_split_for_detection(self):
self.assertEqual(val, len(actual.get_subset("val")))
self.assertEqual(test, len(actual.get_subset("test")))

# random seed test
source, _ = self._generate_detection_dataset(
append_bbox=self._get_append_bbox("cvat"),
with_attr=True,
nimages=10,
)

splits = [("train", 0.5), ("test", 0.5)]
r1 = splitter.DetectionSplit(source, splits, seed=1234)
r2 = splitter.DetectionSplit(source, splits, seed=1234)
r3 = splitter.DetectionSplit(source, splits, seed=4321)
self.assertEqual(
list(r1.get_subset("test")), list(r2.get_subset("test"))
)
self.assertNotEqual(
list(r1.get_subset("test")), list(r3.get_subset("test"))
)

def test_split_for_detection_gives_error(self):
with self.subTest(msg="bbox annotation"):
source = Dataset.from_iterable(
Expand Down

0 comments on commit 7c74aa7

Please sign in to comment.