Skip to content

Commit

Permalink
Unroll calls to any #5062 (#5103)
Browse files Browse the repository at this point in the history
Unroll calls to any #5062
  • Loading branch information
nicoddemus authored May 27, 2019
2 parents 0a57124 + 22d91a3 commit 2b9ca34
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 0 deletions.
1 change: 1 addition & 0 deletions changelog/5062.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Unroll calls to ``all`` to full for-loops for better failure messages, especially when using Generator Expressions.
25 changes: 25 additions & 0 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,8 @@ def visit_Call_35(self, call):
"""
visit `ast.Call` nodes on Python3.5 and after
"""
if isinstance(call.func, ast.Name) and call.func.id == "all":
return self._visit_all(call)
new_func, func_expl = self.visit(call.func)
arg_expls = []
new_args = []
Expand All @@ -972,6 +974,27 @@ def visit_Call_35(self, call):
outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
return res, outer_expl

def _visit_all(self, call):
"""Special rewrite for the builtin all function, see #5062"""
if not isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp)):
return
gen_exp = call.args[0]
assertion_module = ast.Module(
body=[ast.Assert(test=gen_exp.elt, lineno=1, msg="", col_offset=1)]
)
AssertionRewriter(module_path=None, config=None).run(assertion_module)
for_loop = ast.For(
iter=gen_exp.generators[0].iter,
target=gen_exp.generators[0].target,
body=assertion_module.body,
orelse=[],
)
self.statements.append(for_loop)
return (
ast.Num(n=1),
"",
) # Return an empty expression, all the asserts are in the for_loop

def visit_Starred(self, starred):
# From Python 3.5, a Starred node can appear in a function call
res, expl = self.visit(starred.value)
Expand All @@ -982,6 +1005,8 @@ def visit_Call_legacy(self, call):
"""
visit `ast.Call nodes on 3.4 and below`
"""
if isinstance(call.func, ast.Name) and call.func.id == "all":
return self._visit_all(call)
new_func, func_expl = self.visit(call.func)
arg_expls = []
new_args = []
Expand Down
53 changes: 53 additions & 0 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,12 @@ def __repr__(self):
else:
assert lines == ["assert 0 == 1\n + where 1 = \\n{ \\n~ \\n}.a"]

def test_unroll_expression(self):
def f():
assert all(x == 1 for x in range(10))

assert "0 == 1" in getmsg(f)

def test_custom_repr_non_ascii(self):
def f():
class A(object):
Expand All @@ -671,6 +677,53 @@ def __repr__(self):
assert "UnicodeDecodeError" not in msg
assert "UnicodeEncodeError" not in msg

def test_unroll_generator(self, testdir):
testdir.makepyfile(
"""
def check_even(num):
if num % 2 == 0:
return True
return False
def test_generator():
odd_list = list(range(1,9,2))
assert all(check_even(num) for num in odd_list)"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])

def test_unroll_list_comprehension(self, testdir):
testdir.makepyfile(
"""
def check_even(num):
if num % 2 == 0:
return True
return False
def test_list_comprehension():
odd_list = list(range(1,9,2))
assert all([check_even(num) for num in odd_list])"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])

def test_for_loop(self, testdir):
testdir.makepyfile(
"""
def check_even(num):
if num % 2 == 0:
return True
return False
def test_for_loop():
odd_list = list(range(1,9,2))
for num in odd_list:
assert check_even(num)
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])


class TestRewriteOnImport(object):
def test_pycache_is_a_file(self, testdir):
Expand Down

0 comments on commit 2b9ca34

Please sign in to comment.