Skip to content

Commit

Permalink
Add support for inline expect tests. (pytorch#12825)
Browse files Browse the repository at this point in the history
Summary:
expecttest and test_expecttest are the implementation and tests
for this functionality.  I wired it up to the --accept flag,
but there's also a new environment variable EXPECTTEST_ACCEPT
which may be more convenient to trigger.  Haven't tested if this
works in fbcode.

There may be a few expect tests which will benefit from inline
treatment, but I just did one to show it works.

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#12825

Reviewed By: teng-li

Differential Revision: D10448630

Pulled By: ezyang

fbshipit-source-id: 3d339f82e2d00891309620a60e13039fa1ed8b46
  • Loading branch information
ezyang authored and facebook-github-bot committed Oct 23, 2018
1 parent 952df2b commit bc1d96c
Show file tree
Hide file tree
Showing 8 changed files with 326 additions and 6 deletions.
1 change: 1 addition & 0 deletions .jenkins/pytorch/macos-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ fi
export PATH="${PYTORCH_ENV_DIR}/miniconda3/bin:$PATH"
source ${PYTORCH_ENV_DIR}/miniconda3/bin/activate
conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja six
pip install hypothesis
if [ -z "${IN_CIRCLECI}" ]; then
rm -rf ${PYTORCH_ENV_DIR}/miniconda3/lib/python3.6/site-packages/torch*
fi
Expand Down
3 changes: 3 additions & 0 deletions .jenkins/pytorch/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ python ./configure.py --bootstrap
export PATH="$PWD:$PATH"
popd

# TODO: move this to Docker
pip install hypothesis

# DANGER WILL ROBINSON. The LD_PRELOAD here could cause you problems
# if you're not careful. Check this if you made some changes and the
# ASAN test is not working
Expand Down
2 changes: 1 addition & 1 deletion .jenkins/pytorch/win-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ call %CONDA_PARENT_DIR%\\Miniconda3\\Scripts\\activate.bat %CONDA_PARENT_DIR%\\M
if NOT "%BUILD_ENVIRONMENT%"=="" (
call conda install -y -q numpy mkl cffi pyyaml boto3
)
pip install ninja future
pip install ninja future hypothesis
call "C:\\Program Files (x86)\\Microsoft Visual Studio\\2017\\Community\\VC\\Auxiliary\\Build\\vcvarsall.bat" x86_amd64
Expand Down
11 changes: 7 additions & 4 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import __main__
import errno

import expecttest

import torch
import torch.cuda
from torch._utils_internal import get_writable_path
Expand All @@ -44,7 +46,8 @@
parser.add_argument('--accept', action='store_true')
args, remaining = parser.parse_known_args()
SEED = args.seed
ACCEPT = args.accept
if not expecttest.ACCEPT:
expecttest.ACCEPT = args.accept
UNITTEST_ARGS = [sys.argv[0]] + remaining
torch.manual_seed(SEED)

Expand Down Expand Up @@ -241,7 +244,7 @@ def __exit__(self, exec_type, exec_value, traceback):
self.name, after - before, i))


class TestCase(unittest.TestCase):
class TestCase(expecttest.TestCase):
precision = 1e-5
maxDiff = None
_do_cuda_memory_leak_check = False
Expand Down Expand Up @@ -530,7 +533,7 @@ def accept_output(update_type):
except IOError as e:
if e.errno != errno.ENOENT:
raise
elif ACCEPT:
elif expecttest.ACCEPT:
return accept_output("output")
else:
raise RuntimeError(
Expand All @@ -543,7 +546,7 @@ def accept_output(update_type):
expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected)
s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s)

if ACCEPT:
if expecttest.ACCEPT:
if expected != s:
return accept_output("updated output")
else:
Expand Down
202 changes: 202 additions & 0 deletions test/expecttest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import re
import unittest
import traceback
import os
import string


ACCEPT = os.getenv('EXPECTTEST_ACCEPT')


def nth_line(src, lineno):
"""
Compute the starting index of the n-th line (where n is 1-indexed)
>>> nth_line("aaa\\nbb\\nc", 2)
4
"""
assert lineno >= 1
pos = 0
for _ in range(lineno - 1):
pos = src.find('\n', pos) + 1
return pos


def nth_eol(src, lineno):
"""
Compute the ending index of the n-th line (before the newline,
where n is 1-indexed)
>>> nth_eol("aaa\\nbb\\nc", 2)
6
"""
assert lineno >= 1
pos = -1
for _ in range(lineno):
pos = src.find('\n', pos + 1)
if pos == -1:
return len(src)
return pos


def normalize_nl(t):
return t.replace('\r\n', '\n').replace('\r', '\n')


def escape_trailing_quote(s, quote):
if s and s[-1] == quote:
return s[:-1] + '\\' + quote
else:
return s


class EditHistory(object):
def __init__(self):
self.state = {}

def adjust_lineno(self, fn, lineno):
if fn not in self.state:
return lineno
for edit_loc, edit_diff in self.state[fn]:
if lineno > edit_loc:
lineno += edit_diff
return lineno

def seen_file(self, fn):
return fn in self.state

def record_edit(self, fn, lineno, delta):
self.state.setdefault(fn, []).append((lineno, delta))


EDIT_HISTORY = EditHistory()


def ok_for_raw_triple_quoted_string(s, quote):
"""
Is this string representable inside a raw triple-quoted string?
Due to the fact that backslashes are always treated literally,
some strings are not representable.
>>> ok_for_raw_triple_quoted_string("blah", quote="'")
True
>>> ok_for_raw_triple_quoted_string("'", quote="'")
False
>>> ok_for_raw_triple_quoted_string("a ''' b", quote="'")
False
"""
return quote * 3 not in s and (not s or s[-1] not in [quote, '\\'])


# This operates on the REVERSED string (that's why suffix is first)
RE_EXPECT = re.compile(r"^(?P<suffix>[^\n]*?)"
r"(?P<quote>'''|" r'""")'
r"(?P<body>.*?)"
r"(?P=quote)"
r"(?P<raw>r?)", re.DOTALL)


def replace_string_literal(src, lineno, new_string):
r"""
Replace a triple quoted string literal with new contents.
Only handles printable ASCII correctly at the moment. This
will preserve the quote style of the original string, and
makes a best effort to preserve raw-ness (unless it is impossible
to do so.)
Returns a tuple of the replaced string, as well as a delta of
number of lines added/removed.
>>> replace_string_literal("'''arf'''", 1, "barf")
("'''barf'''", 0)
>>> r = replace_string_literal(" moo = '''arf'''", 1, "'a'\n\\b\n")
>>> print(r[0])
moo = '''\
'a'
\\b
'''
>>> r[1]
3
>>> replace_string_literal(" moo = '''\\\narf'''", 2, "'a'\n\\b\n")[1]
2
>>> print(replace_string_literal(" f('''\"\"\"''')", 1, "a ''' b")[0])
f('''a \'\'\' b''')
"""
# Haven't implemented correct escaping for non-printable characters
assert all(c in string.printable for c in new_string)
i = nth_eol(src, lineno)
new_string = normalize_nl(new_string)

delta = [new_string.count("\n")]
if delta[0] > 0:
delta[0] += 1 # handle the extra \\\n

def replace(m):
s = new_string
raw = m.group('raw') == 'r'
if not raw or not ok_for_raw_triple_quoted_string(s, quote=m.group('quote')[0]):
raw = False
s = s.replace('\\', '\\\\')
if m.group('quote') == "'''":
s = escape_trailing_quote(s, "'").replace("'''", r"\'\'\'")
else:
s = escape_trailing_quote(s, '"').replace('"""', r'\"\"\"')

new_body = "\\\n" + s if "\n" in s and not raw else s
delta[0] -= m.group('body').count("\n")

return ''.join([m.group('suffix'),
m.group('quote'),
new_body[::-1],
m.group('quote'),
'r' if raw else '',
])

# Having to do this in reverse is very irritating, but it's the
# only way to make the non-greedy matches work correctly.
return (RE_EXPECT.sub(replace, src[:i][::-1], count=1)[::-1] + src[i:], delta[0])


class TestCase(unittest.TestCase):
longMessage = True

def assertExpectedInline(self, actual, expect, skip=0):
if ACCEPT:
if actual != expect:
# current frame and parent frame, plus any requested skip
tb = traceback.extract_stack(limit=2 + skip)
fn, lineno, _, _ = tb[0]
print("Accepting new output for {} at {}:{}".format(self.id(), fn, lineno))
with open(fn, 'r+') as f:
old = f.read()

# compute the change in lineno
lineno = EDIT_HISTORY.adjust_lineno(fn, lineno)
new, delta = replace_string_literal(old, lineno, actual)

assert old != new, "Failed to substitute string at {}:{}".format(fn, lineno)

# Only write the backup file the first time we hit the
# file
if not EDIT_HISTORY.seen_file(fn):
with open(fn + ".bak", 'w') as f_bak:
f_bak.write(old)
f.seek(0)
f.truncate(0)

f.write(new)

EDIT_HISTORY.record_edit(fn, lineno, delta)
else:
help_text = ("To accept the new output, re-run test with "
"envvar EXPECTTEST_ACCEPT=1 (we recommend "
"staging/committing your changes before doing this)")
if hasattr(self, "assertMultiLineEqual"):
self.assertMultiLineEqual(expect, actual, msg=help_text)
else:
self.assertEqual(expect, actual, msg=help_text)


if __name__ == "__main__":
import doctest
doctest.testmod()
1 change: 1 addition & 0 deletions test/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
'dataloader',
'distributed',
'distributions',
'expecttest',
'indexing',
'jit',
'multiprocessing',
Expand Down
105 changes: 105 additions & 0 deletions test/test_expecttest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import expecttest

import unittest
import string
import textwrap
import doctest

import hypothesis
from hypothesis.strategies import text, integers, composite, sampled_from, booleans


@composite
def text_lineno(draw):
t = draw(text("a\n"))
lineno = draw(integers(min_value=1, max_value=t.count("\n") + 1))
return (t, lineno)


class TestExpectTest(expecttest.TestCase):
@hypothesis.given(text_lineno())
def test_nth_line_ref(self, t_lineno):
t, lineno = t_lineno
hypothesis.event("lineno = {}".format(lineno))

def nth_line_ref(src, lineno):
xs = src.split("\n")[:lineno]
xs[-1] = ''
return len("\n".join(xs))
self.assertEqual(expecttest.nth_line(t, lineno), nth_line_ref(t, lineno))

@hypothesis.given(text(string.printable), booleans(), sampled_from(['"', "'"]))
def test_replace_string_literal_roundtrip(self, t, raw, quote):
if raw:
hypothesis.assume(expecttest.ok_for_raw_triple_quoted_string(t, quote=quote))
prog = """\
r = {r}{quote}placeholder{quote}
r2 = {r}{quote}placeholder2{quote}
r3 = {r}{quote}placeholder3{quote}
""".format(r='r' if raw else '', quote=quote * 3)
new_prog = expecttest.replace_string_literal(textwrap.dedent(prog), 2, t)[0]
ns = {}
exec(new_prog, ns)
msg = "program was:\n{}".format(new_prog)
self.assertEqual(ns['r'], 'placeholder', msg=msg) # noqa: F821
self.assertEqual(ns['r2'], expecttest.normalize_nl(t), msg=msg) # noqa: F821
self.assertEqual(ns['r3'], 'placeholder3', msg=msg) # noqa: F821

def test_sample(self):
prog = r"""
single_single('''0''')
single_multi('''1''')
multi_single('''\
2
''')
multi_multi_less('''\
3
4
''')
multi_multi_same('''\
5
''')
multi_multi_more('''\
6
''')
"""
# NB: These are the end of the statements, not beginning
# TODO: Test other permutations of these edits
edits = [(2, "a"),
(3, "b\n"),
(6, "c"),
(10, "d\n"),
(13, "e\n"),
(16, "f\ng\n")]
history = expecttest.EditHistory()
fn = 'not_a_real_file.py'
for lineno, actual in edits:
lineno = history.adjust_lineno(fn, lineno)
prog, delta = expecttest.replace_string_literal(prog, lineno, actual)
history.record_edit(fn, lineno, delta)
self.assertExpectedInline(prog, r"""
single_single('''a''')
single_multi('''\
b
''')
multi_single('''c''')
multi_multi_less('''\
d
''')
multi_multi_same('''\
e
''')
multi_multi_more('''\
f
g
''')
""")


def load_tests(loader, tests, ignore):
tests.addTests(doctest.DocTestSuite(expecttest))
return tests


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit bc1d96c

Please sign in to comment.