forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for inline expect tests. (pytorch#12825)
Summary: expecttest and test_expecttest are the implementation and tests for this functionality. I wired it up to the --accept flag, but there's also a new environment variable EXPECTTEST_ACCEPT which may be more convenient to trigger. Haven't tested if this works in fbcode. There may be a few expect tests which will benefit from inline treatment, but I just did one to show it works. Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: pytorch#12825 Reviewed By: teng-li Differential Revision: D10448630 Pulled By: ezyang fbshipit-source-id: 3d339f82e2d00891309620a60e13039fa1ed8b46
- Loading branch information
1 parent
952df2b
commit bc1d96c
Showing
8 changed files
with
326 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
import re | ||
import unittest | ||
import traceback | ||
import os | ||
import string | ||
|
||
|
||
ACCEPT = os.getenv('EXPECTTEST_ACCEPT') | ||
|
||
|
||
def nth_line(src, lineno): | ||
""" | ||
Compute the starting index of the n-th line (where n is 1-indexed) | ||
>>> nth_line("aaa\\nbb\\nc", 2) | ||
4 | ||
""" | ||
assert lineno >= 1 | ||
pos = 0 | ||
for _ in range(lineno - 1): | ||
pos = src.find('\n', pos) + 1 | ||
return pos | ||
|
||
|
||
def nth_eol(src, lineno): | ||
""" | ||
Compute the ending index of the n-th line (before the newline, | ||
where n is 1-indexed) | ||
>>> nth_eol("aaa\\nbb\\nc", 2) | ||
6 | ||
""" | ||
assert lineno >= 1 | ||
pos = -1 | ||
for _ in range(lineno): | ||
pos = src.find('\n', pos + 1) | ||
if pos == -1: | ||
return len(src) | ||
return pos | ||
|
||
|
||
def normalize_nl(t): | ||
return t.replace('\r\n', '\n').replace('\r', '\n') | ||
|
||
|
||
def escape_trailing_quote(s, quote): | ||
if s and s[-1] == quote: | ||
return s[:-1] + '\\' + quote | ||
else: | ||
return s | ||
|
||
|
||
class EditHistory(object): | ||
def __init__(self): | ||
self.state = {} | ||
|
||
def adjust_lineno(self, fn, lineno): | ||
if fn not in self.state: | ||
return lineno | ||
for edit_loc, edit_diff in self.state[fn]: | ||
if lineno > edit_loc: | ||
lineno += edit_diff | ||
return lineno | ||
|
||
def seen_file(self, fn): | ||
return fn in self.state | ||
|
||
def record_edit(self, fn, lineno, delta): | ||
self.state.setdefault(fn, []).append((lineno, delta)) | ||
|
||
|
||
EDIT_HISTORY = EditHistory() | ||
|
||
|
||
def ok_for_raw_triple_quoted_string(s, quote): | ||
""" | ||
Is this string representable inside a raw triple-quoted string? | ||
Due to the fact that backslashes are always treated literally, | ||
some strings are not representable. | ||
>>> ok_for_raw_triple_quoted_string("blah", quote="'") | ||
True | ||
>>> ok_for_raw_triple_quoted_string("'", quote="'") | ||
False | ||
>>> ok_for_raw_triple_quoted_string("a ''' b", quote="'") | ||
False | ||
""" | ||
return quote * 3 not in s and (not s or s[-1] not in [quote, '\\']) | ||
|
||
|
||
# This operates on the REVERSED string (that's why suffix is first) | ||
RE_EXPECT = re.compile(r"^(?P<suffix>[^\n]*?)" | ||
r"(?P<quote>'''|" r'""")' | ||
r"(?P<body>.*?)" | ||
r"(?P=quote)" | ||
r"(?P<raw>r?)", re.DOTALL) | ||
|
||
|
||
def replace_string_literal(src, lineno, new_string): | ||
r""" | ||
Replace a triple quoted string literal with new contents. | ||
Only handles printable ASCII correctly at the moment. This | ||
will preserve the quote style of the original string, and | ||
makes a best effort to preserve raw-ness (unless it is impossible | ||
to do so.) | ||
Returns a tuple of the replaced string, as well as a delta of | ||
number of lines added/removed. | ||
>>> replace_string_literal("'''arf'''", 1, "barf") | ||
("'''barf'''", 0) | ||
>>> r = replace_string_literal(" moo = '''arf'''", 1, "'a'\n\\b\n") | ||
>>> print(r[0]) | ||
moo = '''\ | ||
'a' | ||
\\b | ||
''' | ||
>>> r[1] | ||
3 | ||
>>> replace_string_literal(" moo = '''\\\narf'''", 2, "'a'\n\\b\n")[1] | ||
2 | ||
>>> print(replace_string_literal(" f('''\"\"\"''')", 1, "a ''' b")[0]) | ||
f('''a \'\'\' b''') | ||
""" | ||
# Haven't implemented correct escaping for non-printable characters | ||
assert all(c in string.printable for c in new_string) | ||
i = nth_eol(src, lineno) | ||
new_string = normalize_nl(new_string) | ||
|
||
delta = [new_string.count("\n")] | ||
if delta[0] > 0: | ||
delta[0] += 1 # handle the extra \\\n | ||
|
||
def replace(m): | ||
s = new_string | ||
raw = m.group('raw') == 'r' | ||
if not raw or not ok_for_raw_triple_quoted_string(s, quote=m.group('quote')[0]): | ||
raw = False | ||
s = s.replace('\\', '\\\\') | ||
if m.group('quote') == "'''": | ||
s = escape_trailing_quote(s, "'").replace("'''", r"\'\'\'") | ||
else: | ||
s = escape_trailing_quote(s, '"').replace('"""', r'\"\"\"') | ||
|
||
new_body = "\\\n" + s if "\n" in s and not raw else s | ||
delta[0] -= m.group('body').count("\n") | ||
|
||
return ''.join([m.group('suffix'), | ||
m.group('quote'), | ||
new_body[::-1], | ||
m.group('quote'), | ||
'r' if raw else '', | ||
]) | ||
|
||
# Having to do this in reverse is very irritating, but it's the | ||
# only way to make the non-greedy matches work correctly. | ||
return (RE_EXPECT.sub(replace, src[:i][::-1], count=1)[::-1] + src[i:], delta[0]) | ||
|
||
|
||
class TestCase(unittest.TestCase): | ||
longMessage = True | ||
|
||
def assertExpectedInline(self, actual, expect, skip=0): | ||
if ACCEPT: | ||
if actual != expect: | ||
# current frame and parent frame, plus any requested skip | ||
tb = traceback.extract_stack(limit=2 + skip) | ||
fn, lineno, _, _ = tb[0] | ||
print("Accepting new output for {} at {}:{}".format(self.id(), fn, lineno)) | ||
with open(fn, 'r+') as f: | ||
old = f.read() | ||
|
||
# compute the change in lineno | ||
lineno = EDIT_HISTORY.adjust_lineno(fn, lineno) | ||
new, delta = replace_string_literal(old, lineno, actual) | ||
|
||
assert old != new, "Failed to substitute string at {}:{}".format(fn, lineno) | ||
|
||
# Only write the backup file the first time we hit the | ||
# file | ||
if not EDIT_HISTORY.seen_file(fn): | ||
with open(fn + ".bak", 'w') as f_bak: | ||
f_bak.write(old) | ||
f.seek(0) | ||
f.truncate(0) | ||
|
||
f.write(new) | ||
|
||
EDIT_HISTORY.record_edit(fn, lineno, delta) | ||
else: | ||
help_text = ("To accept the new output, re-run test with " | ||
"envvar EXPECTTEST_ACCEPT=1 (we recommend " | ||
"staging/committing your changes before doing this)") | ||
if hasattr(self, "assertMultiLineEqual"): | ||
self.assertMultiLineEqual(expect, actual, msg=help_text) | ||
else: | ||
self.assertEqual(expect, actual, msg=help_text) | ||
|
||
|
||
if __name__ == "__main__": | ||
import doctest | ||
doctest.testmod() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ | |
'dataloader', | ||
'distributed', | ||
'distributions', | ||
'expecttest', | ||
'indexing', | ||
'jit', | ||
'multiprocessing', | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import expecttest | ||
|
||
import unittest | ||
import string | ||
import textwrap | ||
import doctest | ||
|
||
import hypothesis | ||
from hypothesis.strategies import text, integers, composite, sampled_from, booleans | ||
|
||
|
||
@composite | ||
def text_lineno(draw): | ||
t = draw(text("a\n")) | ||
lineno = draw(integers(min_value=1, max_value=t.count("\n") + 1)) | ||
return (t, lineno) | ||
|
||
|
||
class TestExpectTest(expecttest.TestCase): | ||
@hypothesis.given(text_lineno()) | ||
def test_nth_line_ref(self, t_lineno): | ||
t, lineno = t_lineno | ||
hypothesis.event("lineno = {}".format(lineno)) | ||
|
||
def nth_line_ref(src, lineno): | ||
xs = src.split("\n")[:lineno] | ||
xs[-1] = '' | ||
return len("\n".join(xs)) | ||
self.assertEqual(expecttest.nth_line(t, lineno), nth_line_ref(t, lineno)) | ||
|
||
@hypothesis.given(text(string.printable), booleans(), sampled_from(['"', "'"])) | ||
def test_replace_string_literal_roundtrip(self, t, raw, quote): | ||
if raw: | ||
hypothesis.assume(expecttest.ok_for_raw_triple_quoted_string(t, quote=quote)) | ||
prog = """\ | ||
r = {r}{quote}placeholder{quote} | ||
r2 = {r}{quote}placeholder2{quote} | ||
r3 = {r}{quote}placeholder3{quote} | ||
""".format(r='r' if raw else '', quote=quote * 3) | ||
new_prog = expecttest.replace_string_literal(textwrap.dedent(prog), 2, t)[0] | ||
ns = {} | ||
exec(new_prog, ns) | ||
msg = "program was:\n{}".format(new_prog) | ||
self.assertEqual(ns['r'], 'placeholder', msg=msg) # noqa: F821 | ||
self.assertEqual(ns['r2'], expecttest.normalize_nl(t), msg=msg) # noqa: F821 | ||
self.assertEqual(ns['r3'], 'placeholder3', msg=msg) # noqa: F821 | ||
|
||
def test_sample(self): | ||
prog = r""" | ||
single_single('''0''') | ||
single_multi('''1''') | ||
multi_single('''\ | ||
2 | ||
''') | ||
multi_multi_less('''\ | ||
3 | ||
4 | ||
''') | ||
multi_multi_same('''\ | ||
5 | ||
''') | ||
multi_multi_more('''\ | ||
6 | ||
''') | ||
""" | ||
# NB: These are the end of the statements, not beginning | ||
# TODO: Test other permutations of these edits | ||
edits = [(2, "a"), | ||
(3, "b\n"), | ||
(6, "c"), | ||
(10, "d\n"), | ||
(13, "e\n"), | ||
(16, "f\ng\n")] | ||
history = expecttest.EditHistory() | ||
fn = 'not_a_real_file.py' | ||
for lineno, actual in edits: | ||
lineno = history.adjust_lineno(fn, lineno) | ||
prog, delta = expecttest.replace_string_literal(prog, lineno, actual) | ||
history.record_edit(fn, lineno, delta) | ||
self.assertExpectedInline(prog, r""" | ||
single_single('''a''') | ||
single_multi('''\ | ||
b | ||
''') | ||
multi_single('''c''') | ||
multi_multi_less('''\ | ||
d | ||
''') | ||
multi_multi_same('''\ | ||
e | ||
''') | ||
multi_multi_more('''\ | ||
f | ||
g | ||
''') | ||
""") | ||
|
||
|
||
def load_tests(loader, tests, ignore): | ||
tests.addTests(doctest.DocTestSuite(expecttest)) | ||
return tests | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Oops, something went wrong.