Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more tests for simple augmented assignments #1612

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
45 changes: 45 additions & 0 deletions tests/unittest_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,6 +1415,51 @@ def test_augassign(self) -> None:
self.assertIsInstance(inferred[0], nodes.Const)
self.assertEqual(inferred[0].value, 3)

def test_augassign_multi(self) -> None:
code = """
a = 1
a += 1
a += 1
print (a)
"""
ast = parse(code, __name__)
inferred = list(test_utils.get_name_node(ast, "a").infer())

self.assertEqual(len(inferred), 1)
self.assertIsInstance(inferred[0], nodes.Const)
self.assertEqual(inferred[0].value, 3)

def test_augassign_multi_expr(self) -> None:
code = """
a = 1
a += 1
a += 1
a
"""
ast = parse(code, __name__)
# No inference function for Expr
inferred = list(ast.body[-1].value.infer())

self.assertEqual(len(inferred), 1)
self.assertIsInstance(inferred[0], nodes.Const)
self.assertEqual(inferred[0].value, 3)

def test_augassign_multi_list(self) -> None:
code = """
a = []
a += [1]
a += [1]
print (a)
"""
ast = parse(code, __name__)
inferred = list(test_utils.get_name_node(ast, "a").infer())

self.assertEqual(len(inferred), 1)
self.assertIsInstance(inferred[0], nodes.List)
self.assertEqual(len(inferred[0].elts), 2)
self.assertEqual(inferred[0].elts[1].value, 1)
self.assertEqual(inferred[0].elts[0].value, 1)

def test_nonregr_func_arg(self) -> None:
code = """
def foo(self, bar):
Expand Down
57 changes: 57 additions & 0 deletions tests/unittest_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,63 @@ def test_except_assign_after_block_overwritten(self) -> None:
self.assertEqual(len(stmts), 1)
self.assertEqual(stmts[0].lineno, 8)

def test_except_assign_exclusive_branches_getattr(self) -> None:
"""When a variable is assigned in exlcusive branches, both are returned"""
code = """
try:
1 / 0
except ZeroDivisionError:
x = 10
except NameError:
x = 100
print(x)
"""
astroid = builder.parse(code)
stmts = astroid.getattr("x")
self.assertEqual(len(stmts), 2)

self.assertEqual(stmts[0].lineno, 5)
self.assertEqual(stmts[1].lineno, 7)

def test_except_assign_after_block_overwritten_getattr(self) -> None:
"""When a variable is assigned in an except clause, it is not returned
when it is reassigned and used after the except block.
"""
code = """
try:
1 / 0
except ZeroDivisionError:
x = 10
except NameError:
x = 100
x = 1000
print(x)
"""
astroid = builder.parse(code)
stmts = astroid.getattr("x")
self.assertEqual(len(stmts), 1)
self.assertEqual(stmts[0].lineno, 8)

def test_except_assign_after_block_overwritten_getattr_class(self) -> None:
"""When a variable is assigned in an except clause, it is not returned
when it is reassigned and used after the except block.
"""
code = """
class C:
try:
1 / 0
except ZeroDivisionError:
x = 10
except NameError:
x = 100
x = 1000
print(x)
C.x
"""
astroid = builder.parse(code)
stmts = list(astroid.body[-1].value.infer())
self.assertEqual(len(stmts), 1)


if __name__ == "__main__":
unittest.main()