From bc1d96ca985241752485ba02d0b609283192b49c Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Mon, 22 Oct 2018 19:26:20 -0700 Subject: [PATCH] Add support for inline expect tests. (#12825) Summary: expecttest and test_expecttest are the implementation and tests for this functionality. I wired it up to the --accept flag, but there's also a new environment variable EXPECTTEST_ACCEPT which may be more convenient to trigger. Haven't tested if this works in fbcode. There may be a few expect tests which will benefit from inline treatment, but I just did one to show it works. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/12825 Reviewed By: teng-li Differential Revision: D10448630 Pulled By: ezyang fbshipit-source-id: 3d339f82e2d00891309620a60e13039fa1ed8b46 --- .jenkins/pytorch/macos-test.sh | 1 + .jenkins/pytorch/test.sh | 3 + .jenkins/pytorch/win-test.sh | 2 +- test/common_utils.py | 11 +- test/expecttest.py | 202 +++++++++++++++++++++++++++++++++ test/run_test.py | 1 + test/test_expecttest.py | 105 +++++++++++++++++ test/test_jit.py | 7 +- 8 files changed, 326 insertions(+), 6 deletions(-) create mode 100644 test/expecttest.py create mode 100644 test/test_expecttest.py diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh index 9a4c287ee2e199..1ce42ebdb2e0c8 100755 --- a/.jenkins/pytorch/macos-test.sh +++ b/.jenkins/pytorch/macos-test.sh @@ -16,6 +16,7 @@ fi export PATH="${PYTORCH_ENV_DIR}/miniconda3/bin:$PATH" source ${PYTORCH_ENV_DIR}/miniconda3/bin/activate conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja six +pip install hypothesis if [ -z "${IN_CIRCLECI}" ]; then rm -rf ${PYTORCH_ENV_DIR}/miniconda3/lib/python3.6/site-packages/torch* fi diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 4bc82c043bfc93..e1505ec80eb029 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -32,6 +32,9 @@ python ./configure.py --bootstrap export PATH="$PWD:$PATH" popd +# TODO: move this to Docker +pip install hypothesis + # DANGER WILL ROBINSON. The LD_PRELOAD here could cause you problems # if you're not careful. Check this if you made some changes and the # ASAN test is not working diff --git a/.jenkins/pytorch/win-test.sh b/.jenkins/pytorch/win-test.sh index 4fb721364c2dcc..679494fdb5311e 100755 --- a/.jenkins/pytorch/win-test.sh +++ b/.jenkins/pytorch/win-test.sh @@ -53,7 +53,7 @@ call %CONDA_PARENT_DIR%\\Miniconda3\\Scripts\\activate.bat %CONDA_PARENT_DIR%\\M if NOT "%BUILD_ENVIRONMENT%"=="" ( call conda install -y -q numpy mkl cffi pyyaml boto3 ) -pip install ninja future +pip install ninja future hypothesis call "C:\\Program Files (x86)\\Microsoft Visual Studio\\2017\\Community\\VC\\Auxiliary\\Build\\vcvarsall.bat" x86_amd64 diff --git a/test/common_utils.py b/test/common_utils.py index 84f591add4837c..414c6dd8aae17e 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -27,6 +27,8 @@ import __main__ import errno +import expecttest + import torch import torch.cuda from torch._utils_internal import get_writable_path @@ -44,7 +46,8 @@ parser.add_argument('--accept', action='store_true') args, remaining = parser.parse_known_args() SEED = args.seed -ACCEPT = args.accept +if not expecttest.ACCEPT: + expecttest.ACCEPT = args.accept UNITTEST_ARGS = [sys.argv[0]] + remaining torch.manual_seed(SEED) @@ -241,7 +244,7 @@ def __exit__(self, exec_type, exec_value, traceback): self.name, after - before, i)) -class TestCase(unittest.TestCase): +class TestCase(expecttest.TestCase): precision = 1e-5 maxDiff = None _do_cuda_memory_leak_check = False @@ -530,7 +533,7 @@ def accept_output(update_type): except IOError as e: if e.errno != errno.ENOENT: raise - elif ACCEPT: + elif expecttest.ACCEPT: return accept_output("output") else: raise RuntimeError( @@ -543,7 +546,7 @@ def accept_output(update_type): expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected) s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s) - if ACCEPT: + if expecttest.ACCEPT: if expected != s: return accept_output("updated output") else: diff --git a/test/expecttest.py b/test/expecttest.py new file mode 100644 index 00000000000000..f6f2649b57cf86 --- /dev/null +++ b/test/expecttest.py @@ -0,0 +1,202 @@ +import re +import unittest +import traceback +import os +import string + + +ACCEPT = os.getenv('EXPECTTEST_ACCEPT') + + +def nth_line(src, lineno): + """ + Compute the starting index of the n-th line (where n is 1-indexed) + + >>> nth_line("aaa\\nbb\\nc", 2) + 4 + """ + assert lineno >= 1 + pos = 0 + for _ in range(lineno - 1): + pos = src.find('\n', pos) + 1 + return pos + + +def nth_eol(src, lineno): + """ + Compute the ending index of the n-th line (before the newline, + where n is 1-indexed) + + >>> nth_eol("aaa\\nbb\\nc", 2) + 6 + """ + assert lineno >= 1 + pos = -1 + for _ in range(lineno): + pos = src.find('\n', pos + 1) + if pos == -1: + return len(src) + return pos + + +def normalize_nl(t): + return t.replace('\r\n', '\n').replace('\r', '\n') + + +def escape_trailing_quote(s, quote): + if s and s[-1] == quote: + return s[:-1] + '\\' + quote + else: + return s + + +class EditHistory(object): + def __init__(self): + self.state = {} + + def adjust_lineno(self, fn, lineno): + if fn not in self.state: + return lineno + for edit_loc, edit_diff in self.state[fn]: + if lineno > edit_loc: + lineno += edit_diff + return lineno + + def seen_file(self, fn): + return fn in self.state + + def record_edit(self, fn, lineno, delta): + self.state.setdefault(fn, []).append((lineno, delta)) + + +EDIT_HISTORY = EditHistory() + + +def ok_for_raw_triple_quoted_string(s, quote): + """ + Is this string representable inside a raw triple-quoted string? + Due to the fact that backslashes are always treated literally, + some strings are not representable. + + >>> ok_for_raw_triple_quoted_string("blah", quote="'") + True + >>> ok_for_raw_triple_quoted_string("'", quote="'") + False + >>> ok_for_raw_triple_quoted_string("a ''' b", quote="'") + False + """ + return quote * 3 not in s and (not s or s[-1] not in [quote, '\\']) + + +# This operates on the REVERSED string (that's why suffix is first) +RE_EXPECT = re.compile(r"^(?P[^\n]*?)" + r"(?P'''|" r'""")' + r"(?P.*?)" + r"(?P=quote)" + r"(?Pr?)", re.DOTALL) + + +def replace_string_literal(src, lineno, new_string): + r""" + Replace a triple quoted string literal with new contents. + Only handles printable ASCII correctly at the moment. This + will preserve the quote style of the original string, and + makes a best effort to preserve raw-ness (unless it is impossible + to do so.) + + Returns a tuple of the replaced string, as well as a delta of + number of lines added/removed. + + >>> replace_string_literal("'''arf'''", 1, "barf") + ("'''barf'''", 0) + >>> r = replace_string_literal(" moo = '''arf'''", 1, "'a'\n\\b\n") + >>> print(r[0]) + moo = '''\ + 'a' + \\b + ''' + >>> r[1] + 3 + >>> replace_string_literal(" moo = '''\\\narf'''", 2, "'a'\n\\b\n")[1] + 2 + >>> print(replace_string_literal(" f('''\"\"\"''')", 1, "a ''' b")[0]) + f('''a \'\'\' b''') + """ + # Haven't implemented correct escaping for non-printable characters + assert all(c in string.printable for c in new_string) + i = nth_eol(src, lineno) + new_string = normalize_nl(new_string) + + delta = [new_string.count("\n")] + if delta[0] > 0: + delta[0] += 1 # handle the extra \\\n + + def replace(m): + s = new_string + raw = m.group('raw') == 'r' + if not raw or not ok_for_raw_triple_quoted_string(s, quote=m.group('quote')[0]): + raw = False + s = s.replace('\\', '\\\\') + if m.group('quote') == "'''": + s = escape_trailing_quote(s, "'").replace("'''", r"\'\'\'") + else: + s = escape_trailing_quote(s, '"').replace('"""', r'\"\"\"') + + new_body = "\\\n" + s if "\n" in s and not raw else s + delta[0] -= m.group('body').count("\n") + + return ''.join([m.group('suffix'), + m.group('quote'), + new_body[::-1], + m.group('quote'), + 'r' if raw else '', + ]) + + # Having to do this in reverse is very irritating, but it's the + # only way to make the non-greedy matches work correctly. + return (RE_EXPECT.sub(replace, src[:i][::-1], count=1)[::-1] + src[i:], delta[0]) + + +class TestCase(unittest.TestCase): + longMessage = True + + def assertExpectedInline(self, actual, expect, skip=0): + if ACCEPT: + if actual != expect: + # current frame and parent frame, plus any requested skip + tb = traceback.extract_stack(limit=2 + skip) + fn, lineno, _, _ = tb[0] + print("Accepting new output for {} at {}:{}".format(self.id(), fn, lineno)) + with open(fn, 'r+') as f: + old = f.read() + + # compute the change in lineno + lineno = EDIT_HISTORY.adjust_lineno(fn, lineno) + new, delta = replace_string_literal(old, lineno, actual) + + assert old != new, "Failed to substitute string at {}:{}".format(fn, lineno) + + # Only write the backup file the first time we hit the + # file + if not EDIT_HISTORY.seen_file(fn): + with open(fn + ".bak", 'w') as f_bak: + f_bak.write(old) + f.seek(0) + f.truncate(0) + + f.write(new) + + EDIT_HISTORY.record_edit(fn, lineno, delta) + else: + help_text = ("To accept the new output, re-run test with " + "envvar EXPECTTEST_ACCEPT=1 (we recommend " + "staging/committing your changes before doing this)") + if hasattr(self, "assertMultiLineEqual"): + self.assertMultiLineEqual(expect, actual, msg=help_text) + else: + self.assertEqual(expect, actual, msg=help_text) + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/test/run_test.py b/test/run_test.py index 3cea2dee1382c2..1f5a337fb40e6d 100644 --- a/test/run_test.py +++ b/test/run_test.py @@ -25,6 +25,7 @@ 'dataloader', 'distributed', 'distributions', + 'expecttest', 'indexing', 'jit', 'multiprocessing', diff --git a/test/test_expecttest.py b/test/test_expecttest.py new file mode 100644 index 00000000000000..040012fc9526fb --- /dev/null +++ b/test/test_expecttest.py @@ -0,0 +1,105 @@ +import expecttest + +import unittest +import string +import textwrap +import doctest + +import hypothesis +from hypothesis.strategies import text, integers, composite, sampled_from, booleans + + +@composite +def text_lineno(draw): + t = draw(text("a\n")) + lineno = draw(integers(min_value=1, max_value=t.count("\n") + 1)) + return (t, lineno) + + +class TestExpectTest(expecttest.TestCase): + @hypothesis.given(text_lineno()) + def test_nth_line_ref(self, t_lineno): + t, lineno = t_lineno + hypothesis.event("lineno = {}".format(lineno)) + + def nth_line_ref(src, lineno): + xs = src.split("\n")[:lineno] + xs[-1] = '' + return len("\n".join(xs)) + self.assertEqual(expecttest.nth_line(t, lineno), nth_line_ref(t, lineno)) + + @hypothesis.given(text(string.printable), booleans(), sampled_from(['"', "'"])) + def test_replace_string_literal_roundtrip(self, t, raw, quote): + if raw: + hypothesis.assume(expecttest.ok_for_raw_triple_quoted_string(t, quote=quote)) + prog = """\ + r = {r}{quote}placeholder{quote} + r2 = {r}{quote}placeholder2{quote} + r3 = {r}{quote}placeholder3{quote} + """.format(r='r' if raw else '', quote=quote * 3) + new_prog = expecttest.replace_string_literal(textwrap.dedent(prog), 2, t)[0] + ns = {} + exec(new_prog, ns) + msg = "program was:\n{}".format(new_prog) + self.assertEqual(ns['r'], 'placeholder', msg=msg) # noqa: F821 + self.assertEqual(ns['r2'], expecttest.normalize_nl(t), msg=msg) # noqa: F821 + self.assertEqual(ns['r3'], 'placeholder3', msg=msg) # noqa: F821 + + def test_sample(self): + prog = r""" +single_single('''0''') +single_multi('''1''') +multi_single('''\ +2 +''') +multi_multi_less('''\ +3 +4 +''') +multi_multi_same('''\ +5 +''') +multi_multi_more('''\ +6 +''') +""" + # NB: These are the end of the statements, not beginning + # TODO: Test other permutations of these edits + edits = [(2, "a"), + (3, "b\n"), + (6, "c"), + (10, "d\n"), + (13, "e\n"), + (16, "f\ng\n")] + history = expecttest.EditHistory() + fn = 'not_a_real_file.py' + for lineno, actual in edits: + lineno = history.adjust_lineno(fn, lineno) + prog, delta = expecttest.replace_string_literal(prog, lineno, actual) + history.record_edit(fn, lineno, delta) + self.assertExpectedInline(prog, r""" +single_single('''a''') +single_multi('''\ +b +''') +multi_single('''c''') +multi_multi_less('''\ +d +''') +multi_multi_same('''\ +e +''') +multi_multi_more('''\ +f +g +''') +""") + + +def load_tests(loader, tests, ignore): + tests.addTests(doctest.DocTestSuite(expecttest)) + return tests + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_jit.py b/test/test_jit.py index 7566358a39c58d..3246c5180b535a 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -8510,7 +8510,12 @@ def test_script_graph_contains_custom_op(self): @torch.jit.script def func(x): return torch.ops.aten.relu(x) - self.assertExpected(canonical(func.graph)) + self.assertExpectedInline(canonical(func.graph), '''\ +graph(%x : Dynamic) { + %1 : Dynamic = aten::relu(%x) + return (%1); +} +''') # UBSAN per-function exclusions don't seem to work with OpenMP pragmas, # and we have to disable the failing tests here instead.