diff --git a/test/common_utils.py b/test/common_utils.py
index e5713dc0832..9f0d6a2dcf7 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -105,16 +105,14 @@ def remove_prefix_suffix(text, prefix, suffix):
                                      "expect",
                                      munged_id)
 
-        subname_output = ""
         if subname:
             expected_file += "_" + subname
-            subname_output = " ({})".format(subname)
         expected_file += "_expect.pkl"
 
         if not ACCEPT and not os.path.exists(expected_file):
             raise RuntimeError(
-                ("No expect file exists for {}{}; to accept the current output, run:\n"
-                 "python {} {} --accept").format(munged_id, subname_output, __main__.__file__, munged_id))
+                ("No expect file exists for {}; to accept the current output, run:\n"
+                 "python {} {} --accept").format(os.path.basename(expected_file), __main__.__file__, munged_id))
 
         return expected_file
 
@@ -139,11 +137,13 @@ def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
         expected_file = self._get_expected_file(subname, strip_suffix)
 
         if ACCEPT:
-            print("Accepting updated output for {}:\n\n{}".format(os.path.basename(expected_file), output))
+            filename = os.path.basename(expected_file)
+            print("Accepting updated output for {}:\n\n{}".format(filename, output))
             torch.save(output, expected_file)
             MAX_PICKLE_SIZE = 50 * 1000  # 50 KB
             binary_size = os.path.getsize(expected_file)
-            self.assertTrue(binary_size <= MAX_PICKLE_SIZE)
+            if binary_size > MAX_PICKLE_SIZE:
+                raise RuntimeError("The output for {} is larger than 50 KB".format(filename))
         else:
             expected = torch.load(expected_file)
             self.assertEqual(output, expected, prec=prec)
diff --git a/test/expect/ModelTester.test_fasterrcnn_resnet50_fpn_expect.pkl b/test/expect/ModelTester.test_fasterrcnn_resnet50_fpn_expect.pkl
index 3e4fc8ec641..878b9f5d431 100644
Binary files a/test/expect/ModelTester.test_fasterrcnn_resnet50_fpn_expect.pkl and b/test/expect/ModelTester.test_fasterrcnn_resnet50_fpn_expect.pkl differ
diff --git a/test/expect/ModelTester.test_keypointrcnn_resnet50_fpn_expect.pkl b/test/expect/ModelTester.test_keypointrcnn_resnet50_fpn_expect.pkl
index a12333d6fb1..18e14836238 100644
Binary files a/test/expect/ModelTester.test_keypointrcnn_resnet50_fpn_expect.pkl and b/test/expect/ModelTester.test_keypointrcnn_resnet50_fpn_expect.pkl differ
diff --git a/test/expect/ModelTester.test_maskrcnn_resnet50_fpn_expect.pkl b/test/expect/ModelTester.test_maskrcnn_resnet50_fpn_expect.pkl
index c05342100a3..f2406f8cabb 100644
Binary files a/test/expect/ModelTester.test_maskrcnn_resnet50_fpn_expect.pkl and b/test/expect/ModelTester.test_maskrcnn_resnet50_fpn_expect.pkl differ
diff --git a/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl b/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl
index 548b0a22e1c..1bffc79fd07 100644
Binary files a/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl and b/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl differ
diff --git a/test/test_models.py b/test/test_models.py
index e27021c4337..7a8e1d83b6e 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -1,14 +1,15 @@
 from common_utils import TestCase, map_nested_tensor_object, freeze_rng_state
 from collections import OrderedDict
 from itertools import product
+import functools
+import operator
 import torch
 import torch.nn as nn
 import numpy as np
 from torchvision import models
 import unittest
 import random
-
-from torchvision.models.detection._utils import overwrite_eps
+import warnings
 
 
 def set_rng_seed(seed):
@@ -88,14 +89,10 @@ def get_available_video_models():
 # trying autocast. However, they still try an autocasted forward pass, so they still ensure
 # autocast coverage suffices to prevent dtype errors in each model.
 autocast_flaky_numerics = (
-    "fasterrcnn_resnet50_fpn",
     "inception_v3",
-    "keypointrcnn_resnet50_fpn",
-    "maskrcnn_resnet50_fpn",
     "resnet101",
     "resnet152",
     "wide_resnet101_2",
-    "retinanet_resnet50_fpn",
 )
 
 
@@ -148,10 +145,9 @@ def _test_detection_model(self, name, dev):
         set_rng_seed(0)
         kwargs = {}
         if "retinanet" in name:
-            kwargs["score_thresh"] = 0.013
+            # Reduce the default threshold to ensure the returned boxes are not empty.
+            kwargs["score_thresh"] = 0.01
         model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
-        if "keypointrcnn" in name or "retinanet" in name:
-            overwrite_eps(model, 0.0)
         model.eval().to(device=dev)
         input_shape = (3, 300, 300)
         # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
@@ -163,15 +159,22 @@ def _test_detection_model(self, name, dev):
         def check_out(out):
             self.assertEqual(len(out), 1)
 
+            def compact(tensor):
+                size = tensor.size()
+                elements_per_sample = functools.reduce(operator.mul, size[1:], 1)
+                if elements_per_sample > 30:
+                    return compute_mean_std(tensor)
+                else:
+                    return subsample_tensor(tensor)
+
             def subsample_tensor(tensor):
-                num_elems = tensor.numel()
+                num_elems = tensor.size(0)
                 num_samples = 20
                 if num_elems <= num_samples:
                     return tensor
 
-                flat_tensor = tensor.flatten()
                 ith_index = num_elems // num_samples
-                return flat_tensor[ith_index - 1::ith_index]
+                return tensor[ith_index - 1::ith_index]
 
             def compute_mean_std(tensor):
                 # can't compute mean of integral tensor
@@ -180,18 +183,32 @@ def compute_mean_std(tensor):
                 std = torch.std(tensor)
                 return {"mean": mean, "std": std}
 
-            if name == "maskrcnn_resnet50_fpn":
-                # maskrcnn_resnet_50_fpn numerically unstable across platforms, so for now
-                # compare results with mean and std
-                test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
-                # mean values are small, use large prec
-                self.assertExpected(test_value, prec=.01, strip_suffix="_" + dev)
-            else:
-                self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor),
-                                    prec=0.01,
-                                    strip_suffix="_" + dev)
-
-        check_out(out)
+            output = map_nested_tensor_object(out, tensor_map_fn=compact)
+            prec = 0.01
+            strip_suffix = "_" + dev
+            try:
+                # We first try to assert the entire output if possible. This is not
+                # only the best way to assert results but also handles the cases
+                # where we need to create a new expected result.
+                self.assertExpected(output, prec=prec, strip_suffix=strip_suffix)
+            except AssertionError:
+                # Unfortunately detection models are flaky due to the unstable sort
+                # in NMS. If matching across all outputs fails, use the same approach
+                # as in NMSTester.test_nms_cuda to see if this is caused by duplicate
+                # scores.
+                expected_file = self._get_expected_file(strip_suffix=strip_suffix)
+                expected = torch.load(expected_file)
+                self.assertEqual(output[0]["scores"], expected[0]["scores"], prec=prec)
+
+                # Note: Fmassa proposed turning off NMS by adapting the threshold
+                # and then using the Hungarian algorithm as in DETR to find the
+                # best match between output and expected boxes and eliminate some
+                # of the flakiness. Worth exploring.
+                return False  # Partial validation performed
+
+            return True  # Full validation performed
+
+        full_validation = check_out(out)
 
         scripted_model = torch.jit.script(model)
         scripted_model.eval()
@@ -200,9 +217,6 @@ def compute_mean_std(tensor):
             self.assertEqual(scripted_out[0]["scores"], out[0]["scores"])
             # labels currently float in script: need to investigate (though same result)
             self.assertEqual(scripted_out[0]["labels"].to(dtype=torch.long), out[0]["labels"])
-            self.assertTrue("boxes" in out[0])
-            self.assertTrue("scores" in out[0])
-            self.assertTrue("labels" in out[0])
             # don't check script because we are compiling it here:
             # TODO: refactor tests
             # self.check_script(model, name)
@@ -213,7 +227,15 @@ def compute_mean_std(tensor):
             out = model(model_input)
             # See autocast_flaky_numerics comment at top of file.
             if name not in autocast_flaky_numerics:
-                check_out(out)
+                full_validation &= check_out(out)
+
+        if not full_validation:
+            msg = "The output of {} could only be partially validated. " \
+                  "This is likely due to unit-test flakiness, but you may " \
+                  "want to do additional manual checks if you made " \
+                  "significant changes to the codebase.".format(self._testMethodName)
+            warnings.warn(msg, RuntimeWarning)
+            raise unittest.SkipTest(msg)
 
     def _test_detection_model_validation(self, name):
         set_rng_seed(0)
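
A usage note on the acceptance flow in common_utils.py above: when an expect file is missing, `_get_expected_file` now prints the exact command needed to (re)generate it, and `assertExpected` raises a descriptive RuntimeError instead of a bare `assertTrue` when the accepted pickle exceeds 50 KB. Refreshing an expected result therefore looks roughly like this (the test id here is illustrative):

    python test/test_models.py ModelTester.test_fasterrcnn_resnet50_fpn --accept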
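
For context on the `subsample_tensor` rewrite: slicing along dim 0 instead of flattening keeps every sampled element a complete detection (a whole box rather than an arbitrary scalar), which is what makes the subsampled output comparable across runs. A minimal sketch of the slicing, with made-up shapes:

    import torch

    boxes = torch.arange(400.0).reshape(100, 4)  # 100 boxes, 4 coordinates each
    num_elems = boxes.size(0)                    # 100 boxes, not 400 scalars
    num_samples = 20
    ith_index = num_elems // num_samples         # sample every 5th box
    sampled = boxes[ith_index - 1::ith_index]    # rows 4, 9, ..., 99
    assert sampled.shape == (20, 4)              # complete boxes, coordinates intact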
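
On the note about Fmassa's Hungarian-matching proposal: a minimal sketch of what such an order-independent comparison could look like, assuming SciPy is available and using a plain L1 cost between box coordinates; `assert_boxes_match`, the cost choice, and the tolerance are illustrative, not part of this PR:

    import torch
    from scipy.optimize import linear_sum_assignment

    def assert_boxes_match(output_boxes, expected_boxes, tol=0.04):
        # Same number of detections is still required; only their order may differ.
        assert output_boxes.shape == expected_boxes.shape
        # Pairwise L1 distance between every output box and every expected box.
        cost = torch.cdist(output_boxes, expected_boxes, p=1)
        # Hungarian algorithm: the assignment minimizing the total matching cost.
        row_idx, col_idx = linear_sum_assignment(cost.numpy())
        matched = cost[torch.as_tensor(row_idx), torch.as_tensor(col_idx)]
        # Each box is compared to its optimal partner, so the check no longer
        # depends on the unstable ordering NMS produces for duplicate scores.
        assert bool((matched <= tol).all()), "boxes differ even under optimal matching"

Such a helper could replace the scores-only fallback in the except branch with a box-level check that tolerates reordering.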