Skip to content

Commit

Permalink
Simplify the assertExpected method (pytorch#2965)
Browse files Browse the repository at this point in the history
* Simplify the ACCEPT=True logic in assertExpected().

* Separate the expected filename estimation from assertExpected
  • Loading branch information
datumbox authored and bryant1410 committed Nov 22, 2020
1 parent 6d02f67 commit 05e88ee
Showing 1 changed file with 31 additions and 40 deletions.
71 changes: 31 additions & 40 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,7 @@ def is_iterable(obj):
class TestCase(unittest.TestCase):
precision = 1e-5

def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
r"""
Test that a python value matches the recorded contents of a file
derived from the name of this test and subname. The value must be
pickable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using --accept.
If you call this multiple times in a single function, you must
give a unique subname each time.
strip_suffix allows different tests that expect similar numerics, e.g.
"test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
strip_suffix="_cpu", and they would both use a data file name based on
"test_xyz".
"""
def _get_expected_file(self, subname=None, strip_suffix=None):
def remove_prefix_suffix(text, prefix, suffix):
if text.startswith(prefix):
text = text[len(prefix):]
Expand All @@ -128,33 +111,41 @@ def remove_prefix_suffix(text, prefix, suffix):
subname_output = " ({})".format(subname)
expected_file += "_expect.pkl"

def accept_output(update_type):
print("Accepting {} for {}{}:\n\n{}".format(update_type, munged_id, subname_output, output))
if not ACCEPT and not os.path.exists(expected_file):
raise RuntimeError(
("No expect file exists for {}{}; to accept the current output, run:\n"
"python {} {} --accept").format(munged_id, subname_output, __main__.__file__, munged_id))

return expected_file

def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
r"""
Test that a python value matches the recorded contents of a file
derived from the name of this test and subname. The value must be
pickable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using --accept.
If you call this multiple times in a single function, you must
give a unique subname each time.
strip_suffix allows different tests that expect similar numerics, e.g.
"test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
strip_suffix="_cpu", and they would both use a data file name based on
"test_xyz".
"""
expected_file = self._get_expected_file(subname, strip_suffix)

if ACCEPT:
print("Accepting updated output for {}:\n\n{}".format(os.path.basename(expected_file), output))
torch.save(output, expected_file)
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
binary_size = os.path.getsize(expected_file)
self.assertTrue(binary_size <= MAX_PICKLE_SIZE)

try:
expected = torch.load(expected_file)
except IOError as e:
if e.errno != errno.ENOENT:
raise
elif ACCEPT:
accept_output("output")
return
else:
raise RuntimeError(
("I got this output for {}{}:\n\n{}\n\n"
"No expect file exists; to accept the current output, run:\n"
"python {} {} --accept").format(munged_id, subname_output, output, __main__.__file__, munged_id))

if ACCEPT:
try:
self.assertEqual(output, expected, prec=prec)
except Exception:
accept_output("updated output")
else:
expected = torch.load(expected_file)
self.assertEqual(output, expected, prec=prec)

def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
Expand Down

0 comments on commit 05e88ee

Please sign in to comment.