Skip to content

Commit

Permalink
Fix for issue 649.
Browse files Browse the repository at this point in the history
  • Loading branch information
NanduTej committed Jul 18, 2019
1 parent 0a1c826 commit 5fbfff2
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,26 @@ def foo():
return x[0]
"""

correct_example9 = """
def test():
print('test')
"""

correct_example10 = """
def test(some):
if some:
return
print('test')
"""

correct_example11 = """
def test(some):
if some:
return some
print('test')
return None
"""

# Wrong:

wrong_example1 = """
Expand Down Expand Up @@ -103,6 +123,26 @@ def decorator(*args, **kwargs):
return decorator
"""

wrong_example6 = """
def test():
print('test')
return None
"""

wrong_example7 = """
def test():
print('test')
return
"""

wrong_example8 = """
def test(some):
if some:
return
print('test')
return
"""

double_wrong_example1 = """
def some():
if something() == 1:
Expand All @@ -121,6 +161,9 @@ def some():
wrong_example3,
wrong_example4,
wrong_example5,
wrong_example6,
wrong_example7,
wrong_example8
])
def test_wrong_return_variable(
assert_errors,
Expand All @@ -146,6 +189,9 @@ def test_wrong_return_variable(
correct_example6,
correct_example7,
correct_example8,
correct_example9,
correct_example10,
correct_example11,
])
def test_correct_return_statements(
assert_errors,
Expand Down
18 changes: 17 additions & 1 deletion wemake_python_styleguide/visitors/ast/keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ def _check_variables_for_return(self, node: AnyFunctionDef) -> None:
returns, return_sub_nodes = self._get_return_node_variables(nodes)

returns = {name: returns[name] for name in returns if name in assign}

if not return_sub_nodes:
self._check_return_at_the_end(node)
self._check_for_violations(names, return_sub_nodes, returns)

def _check_for_violations(self, names, return_sub_nodes, returns) -> None:
Expand All @@ -301,6 +302,21 @@ def _check_for_violations(self, names, return_sub_nodes, returns) -> None:
),
)

def _check_return_at_the_end(self, node):
if len(node.body) <= 1:
return
last = node.body[-1]
if isinstance(last, ast.Return):
if last.value is None:
self.add_violation(
InconsistentReturnVariableViolation(last),
)
elif isinstance(last.value, ast.NameConstant) and (last.value.value is None):
self.add_violation(
InconsistentReturnVariableViolation(last),
)


def visit_return_variable(self, node: AnyFunctionDef) -> None:
"""
Helper to get all ``return`` variables in a function at once.
Expand Down

0 comments on commit 5fbfff2

Please sign in to comment.