diff --git a/test/test_models_detection_anchor_utils.py b/test/test_models_detection_anchor_utils.py new file mode 100644 index 00000000000..d606b7848e4 --- /dev/null +++ b/test/test_models_detection_anchor_utils.py @@ -0,0 +1,15 @@ +import torch +import unittest +from torchvision.models.detection.anchor_utils import AnchorGenerator +from torchvision.models.detection.image_list import ImageList + + +class Tester(unittest.TestCase): + def test_incorrect_anchors(self): + incorrect_sizes = ((2, 4, 8), (32, 8), ) + incorrect_aspects = (0.5, 1.0) + anc = AnchorGenerator(incorrect_sizes, incorrect_aspects) + image1 = torch.randn(3, 800, 800) + image_list = ImageList(image1, [(800, 800)]) + feature_maps = [torch.randn(1, 50)] + self.assertRaises(ValueError, anc, image_list, feature_maps)