Skip to content

Commit

Permalink
pythongh-122666: Tests for ast optimizations (python#122667)
Browse files Browse the repository at this point in the history
Co-authored-by: Sergey B Kirpichev <[email protected]>
Co-authored-by: Victor Stinner <[email protected]>
Co-authored-by: Jelle Zijlstra <[email protected]>
  • Loading branch information
4 people authored Aug 26, 2024
1 parent 1eed0f9 commit 9f9b00d
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 1 deletion.
1 change: 0 additions & 1 deletion Lib/test/test_ast/snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,6 @@ def main():
print("]")
print("main()")
raise SystemExit
unittest.main()

#### EVERYTHING BELOW IS GENERATED BY python Lib/test/test_ast/snippets.py -g #####
exec_results = [
Expand Down
213 changes: 213 additions & 0 deletions Lib/test/test_ast/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3035,3 +3035,216 @@ def test_cli_file_input(self):
self.assertEqual(expected.splitlines(),
res.out.decode("utf8").splitlines())
self.assertEqual(res.rc, 0)


class ASTOptimiziationTests(unittest.TestCase):
binop = {
"+": ast.Add(),
"-": ast.Sub(),
"*": ast.Mult(),
"/": ast.Div(),
"%": ast.Mod(),
"<<": ast.LShift(),
">>": ast.RShift(),
"|": ast.BitOr(),
"^": ast.BitXor(),
"&": ast.BitAnd(),
"//": ast.FloorDiv(),
"**": ast.Pow(),
}

unaryop = {
"~": ast.Invert(),
"+": ast.UAdd(),
"-": ast.USub(),
}

def wrap_expr(self, expr):
return ast.Module(body=[ast.Expr(value=expr)])

def wrap_for(self, for_statement):
return ast.Module(body=[for_statement])

def assert_ast(self, code, non_optimized_target, optimized_target):
non_optimized_tree = ast.parse(code, optimize=-1)
optimized_tree = ast.parse(code, optimize=1)

# Is a non-optimized tree equal to a non-optimized target?
self.assertTrue(
ast.compare(non_optimized_tree, non_optimized_target),
f"{ast.dump(non_optimized_target)} must equal "
f"{ast.dump(non_optimized_tree)}",
)

# Is a optimized tree equal to a non-optimized target?
self.assertFalse(
ast.compare(optimized_tree, non_optimized_target),
f"{ast.dump(non_optimized_target)} must not equal "
f"{ast.dump(non_optimized_tree)}"
)

# Is a optimized tree is equal to an optimized target?
self.assertTrue(
ast.compare(optimized_tree, optimized_target),
f"{ast.dump(optimized_target)} must equal "
f"{ast.dump(optimized_tree)}",
)

def test_folding_binop(self):
code = "1 %s 1"
operators = self.binop.keys()

def create_binop(operand, left=ast.Constant(1), right=ast.Constant(1)):
return ast.BinOp(left=left, op=self.binop[operand], right=right)

for op in operators:
result_code = code % op
non_optimized_target = self.wrap_expr(create_binop(op))
optimized_target = self.wrap_expr(ast.Constant(value=eval(result_code)))

with self.subTest(
result_code=result_code,
non_optimized_target=non_optimized_target,
optimized_target=optimized_target
):
self.assert_ast(result_code, non_optimized_target, optimized_target)

# Multiplication of constant tuples must be folded
code = "(1,) * 3"
non_optimized_target = self.wrap_expr(create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
optimized_target = self.wrap_expr(ast.Constant(eval(code)))

self.assert_ast(code, non_optimized_target, optimized_target)

def test_folding_unaryop(self):
code = "%s1"
operators = self.unaryop.keys()

def create_unaryop(operand):
return ast.UnaryOp(op=self.unaryop[operand], operand=ast.Constant(1))

for op in operators:
result_code = code % op
non_optimized_target = self.wrap_expr(create_unaryop(op))
optimized_target = self.wrap_expr(ast.Constant(eval(result_code)))

with self.subTest(
result_code=result_code,
non_optimized_target=non_optimized_target,
optimized_target=optimized_target
):
self.assert_ast(result_code, non_optimized_target, optimized_target)

def test_folding_not(self):
code = "not (1 %s (1,))"
operators = {
"in": ast.In(),
"is": ast.Is(),
}
opt_operators = {
"is": ast.IsNot(),
"in": ast.NotIn(),
}

def create_notop(operand):
return ast.UnaryOp(op=ast.Not(), operand=ast.Compare(
left=ast.Constant(value=1),
ops=[operators[operand]],
comparators=[ast.Tuple(elts=[ast.Constant(value=1)])]
))

for op in operators.keys():
result_code = code % op
non_optimized_target = self.wrap_expr(create_notop(op))
optimized_target = self.wrap_expr(
ast.Compare(left=ast.Constant(1), ops=[opt_operators[op]], comparators=[ast.Constant(value=(1,))])
)

with self.subTest(
result_code=result_code,
non_optimized_target=non_optimized_target,
optimized_target=optimized_target
):
self.assert_ast(result_code, non_optimized_target, optimized_target)

def test_folding_format(self):
code = "'%s' % (a,)"

non_optimized_target = self.wrap_expr(
ast.BinOp(
left=ast.Constant(value="%s"),
op=ast.Mod(),
right=ast.Tuple(elts=[ast.Name(id='a')]))
)
optimized_target = self.wrap_expr(
ast.JoinedStr(
values=[
ast.FormattedValue(value=ast.Name(id='a'), conversion=115)
]
)
)

self.assert_ast(code, non_optimized_target, optimized_target)


def test_folding_tuple(self):
code = "(1,)"

non_optimized_target = self.wrap_expr(ast.Tuple(elts=[ast.Constant(1)]))
optimized_target = self.wrap_expr(ast.Constant(value=(1,)))

self.assert_ast(code, non_optimized_target, optimized_target)

def test_folding_comparator(self):
code = "1 %s %s1%s"
operators = [("in", ast.In()), ("not in", ast.NotIn())]
braces = [
("[", "]", ast.List, (1,)),
("{", "}", ast.Set, frozenset({1})),
]
for left, right, non_optimized_comparator, optimized_comparator in braces:
for op, node in operators:
non_optimized_target = self.wrap_expr(ast.Compare(
left=ast.Constant(1), ops=[node],
comparators=[non_optimized_comparator(elts=[ast.Constant(1)])]
))
optimized_target = self.wrap_expr(ast.Compare(
left=ast.Constant(1), ops=[node],
comparators=[ast.Constant(value=optimized_comparator)]
))
self.assert_ast(code % (op, left, right), non_optimized_target, optimized_target)

def test_folding_iter(self):
code = "for _ in %s1%s: pass"
braces = [
("[", "]", ast.List, (1,)),
("{", "}", ast.Set, frozenset({1})),
]

for left, right, ast_cls, optimized_iter in braces:
non_optimized_target = self.wrap_for(ast.For(
target=ast.Name(id="_", ctx=ast.Store()),
iter=ast_cls(elts=[ast.Constant(1)]),
body=[ast.Pass()]
))
optimized_target = self.wrap_for(ast.For(
target=ast.Name(id="_", ctx=ast.Store()),
iter=ast.Constant(value=optimized_iter),
body=[ast.Pass()]
))

self.assert_ast(code % (left, right), non_optimized_target, optimized_target)

def test_folding_subscript(self):
code = "(1,)[0]"

non_optimized_target = self.wrap_expr(
ast.Subscript(value=ast.Tuple(elts=[ast.Constant(value=1)]), slice=ast.Constant(value=0))
)
optimized_target = self.wrap_expr(ast.Constant(value=1))

self.assert_ast(code, non_optimized_target, optimized_target)


if __name__ == "__main__":
unittest.main()

0 comments on commit 9f9b00d

Please sign in to comment.