Skip to content

Commit

Permalink
pythongh-121141: add support for copy.replace to AST nodes (python#…
Browse files Browse the repository at this point in the history
  • Loading branch information
picnixz authored Jul 4, 2024
1 parent 94f50f8 commit 9728ead
Show file tree
Hide file tree
Showing 5 changed files with 837 additions and 2 deletions.
8 changes: 6 additions & 2 deletions Doc/whatsnew/3.14.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ Improved Modules
ast
---

Added :func:`ast.compare` for comparing two ASTs.
(Contributed by Batuhan Taskaya and Jeremy Hylton in :issue:`15987`.)
* Added :func:`ast.compare` for comparing two ASTs.
(Contributed by Batuhan Taskaya and Jeremy Hylton in :issue:`15987`.)

* Add support for :func:`copy.replace` for AST nodes.

(Contributed by Bénédikt Tran in :gh:`121141`.)

os
--
Expand Down
272 changes: 272 additions & 0 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,25 @@ def test_none_checks(self) -> None:
class CopyTests(unittest.TestCase):
"""Test copying and pickling AST nodes."""

@staticmethod
def iter_ast_classes():
"""Iterate over the (native) subclasses of ast.AST recursively.
This excludes the special class ast.Index since its constructor
returns an integer.
"""
def do(cls):
if cls.__module__ != 'ast':
return
if cls is ast.Index:
return

yield cls
for sub in cls.__subclasses__():
yield from do(sub)

yield from do(ast.AST)

def test_pickling(self):
import pickle

Expand Down Expand Up @@ -1218,6 +1237,259 @@ def test_copy_with_parents(self):
)):
self.assertEqual(to_tuple(child.parent), to_tuple(node))

def test_replace_interface(self):
for klass in self.iter_ast_classes():
with self.subTest(klass=klass):
self.assertTrue(hasattr(klass, '__replace__'))

fields = set(klass._fields)
with self.subTest(klass=klass, fields=fields):
node = klass(**dict.fromkeys(fields))
# forbid positional arguments in replace()
self.assertRaises(TypeError, copy.replace, node, 1)
self.assertRaises(TypeError, node.__replace__, 1)

def test_replace_native(self):
for klass in self.iter_ast_classes():
fields = set(klass._fields)
attributes = set(klass._attributes)

with self.subTest(klass=klass, fields=fields, attributes=attributes):
# use of object() to ensure that '==' and 'is'
# behave similarly in ast.compare(node, repl)
old_fields = {field: object() for field in fields}
old_attrs = {attr: object() for attr in attributes}

# check shallow copy
node = klass(**old_fields)
repl = copy.replace(node)
self.assertTrue(ast.compare(node, repl, compare_attributes=True))
# check when passing using attributes (they may be optional!)
node = klass(**old_fields, **old_attrs)
repl = copy.replace(node)
self.assertTrue(ast.compare(node, repl, compare_attributes=True))

for field in fields:
# check when we sometimes have attributes and sometimes not
for init_attrs in [{}, old_attrs]:
node = klass(**old_fields, **init_attrs)
# only change a single field (do not change attributes)
new_value = object()
repl = copy.replace(node, **{field: new_value})
for f in fields:
old_value = old_fields[f]
# assert that there is no side-effect
self.assertIs(getattr(node, f), old_value)
# check the changes
if f != field:
self.assertIs(getattr(repl, f), old_value)
else:
self.assertIs(getattr(repl, f), new_value)
self.assertFalse(ast.compare(node, repl, compare_attributes=True))

for attribute in attributes:
node = klass(**old_fields, **old_attrs)
# only change a single attribute (do not change fields)
new_attr = object()
repl = copy.replace(node, **{attribute: new_attr})
for a in attributes:
old_attr = old_attrs[a]
# assert that there is no side-effect
self.assertIs(getattr(node, a), old_attr)
# check the changes
if a != attribute:
self.assertIs(getattr(repl, a), old_attr)
else:
self.assertIs(getattr(repl, a), new_attr)
self.assertFalse(ast.compare(node, repl, compare_attributes=True))

def test_replace_accept_known_class_fields(self):
nid, ctx = object(), object()

node = ast.Name(id=nid, ctx=ctx)
self.assertIs(node.id, nid)
self.assertIs(node.ctx, ctx)

new_nid = object()
repl = copy.replace(node, id=new_nid)
# assert that there is no side-effect
self.assertIs(node.id, nid)
self.assertIs(node.ctx, ctx)
# check the changes
self.assertIs(repl.id, new_nid)
self.assertIs(repl.ctx, node.ctx) # no changes

def test_replace_accept_known_class_attributes(self):
node = ast.parse('x').body[0].value
self.assertEqual(node.id, 'x')
self.assertEqual(node.lineno, 1)

# constructor allows any type so replace() should do the same
lineno = object()
repl = copy.replace(node, lineno=lineno)
# assert that there is no side-effect
self.assertEqual(node.lineno, 1)
# check the changes
self.assertEqual(repl.id, node.id)
self.assertEqual(repl.ctx, node.ctx)
self.assertEqual(repl.lineno, lineno)

_, _, state = node.__reduce__()
self.assertEqual(state['id'], 'x')
self.assertEqual(state['ctx'], node.ctx)
self.assertEqual(state['lineno'], 1)

_, _, state = repl.__reduce__()
self.assertEqual(state['id'], 'x')
self.assertEqual(state['ctx'], node.ctx)
self.assertEqual(state['lineno'], lineno)

def test_replace_accept_known_custom_class_fields(self):
class MyNode(ast.AST):
_fields = ('name', 'data')
__annotations__ = {'name': str, 'data': object}
__match_args__ = ('name', 'data')

name, data = 'name', object()

node = MyNode(name, data)
self.assertIs(node.name, name)
self.assertIs(node.data, data)
# check shallow copy
repl = copy.replace(node)
# assert that there is no side-effect
self.assertIs(node.name, name)
self.assertIs(node.data, data)
# check the shallow copy
self.assertIs(repl.name, name)
self.assertIs(repl.data, data)

node = MyNode(name, data)
repl_data = object()
# replace custom but known field
repl = copy.replace(node, data=repl_data)
# assert that there is no side-effect
self.assertIs(node.name, name)
self.assertIs(node.data, data)
# check the changes
self.assertIs(repl.name, node.name)
self.assertIs(repl.data, repl_data)

def test_replace_accept_known_custom_class_attributes(self):
class MyNode(ast.AST):
x = 0
y = 1
_attributes = ('x', 'y')

node = MyNode()
self.assertEqual(node.x, 0)
self.assertEqual(node.y, 1)

y = object()
# custom attributes are currently not supported and raise a warning
# because the allowed attributes are hard-coded !
msg = (
"MyNode.__init__ got an unexpected keyword argument 'y'. "
"Support for arbitrary keyword arguments is deprecated and "
"will be removed in Python 3.15"
)
with self.assertWarnsRegex(DeprecationWarning, re.escape(msg)):
repl = copy.replace(node, y=y)
# assert that there is no side-effect
self.assertEqual(node.x, 0)
self.assertEqual(node.y, 1)
# check the changes
self.assertEqual(repl.x, 0)
self.assertEqual(repl.y, y)

def test_replace_ignore_known_custom_instance_fields(self):
node = ast.parse('x').body[0].value
node.extra = extra = object() # add instance 'extra' field
context = node.ctx

# assert initial values
self.assertIs(node.id, 'x')
self.assertIs(node.ctx, context)
self.assertIs(node.extra, extra)
# shallow copy, but drops extra fields
repl = copy.replace(node)
# assert that there is no side-effect
self.assertIs(node.id, 'x')
self.assertIs(node.ctx, context)
self.assertIs(node.extra, extra)
# verify that the 'extra' field is not kept
self.assertIs(repl.id, 'x')
self.assertIs(repl.ctx, context)
self.assertRaises(AttributeError, getattr, repl, 'extra')

# change known native field
repl = copy.replace(node, id='y')
# assert that there is no side-effect
self.assertIs(node.id, 'x')
self.assertIs(node.ctx, context)
self.assertIs(node.extra, extra)
# verify that the 'extra' field is not kept
self.assertIs(repl.id, 'y')
self.assertIs(repl.ctx, context)
self.assertRaises(AttributeError, getattr, repl, 'extra')

def test_replace_reject_missing_field(self):
# case: warn if deleted field is not replaced
node = ast.parse('x').body[0].value
context = node.ctx
del node.id

self.assertRaises(AttributeError, getattr, node, 'id')
self.assertIs(node.ctx, context)
msg = "Name.__replace__ missing 1 keyword argument: 'id'."
with self.assertRaisesRegex(TypeError, re.escape(msg)):
copy.replace(node)
# assert that there is no side-effect
self.assertRaises(AttributeError, getattr, node, 'id')
self.assertIs(node.ctx, context)

# case: do not raise if deleted field is replaced
node = ast.parse('x').body[0].value
context = node.ctx
del node.id

self.assertRaises(AttributeError, getattr, node, 'id')
self.assertIs(node.ctx, context)
repl = copy.replace(node, id='y')
# assert that there is no side-effect
self.assertRaises(AttributeError, getattr, node, 'id')
self.assertIs(node.ctx, context)
self.assertIs(repl.id, 'y')
self.assertIs(repl.ctx, context)

def test_replace_reject_known_custom_instance_fields_commits(self):
node = ast.parse('x').body[0].value
node.extra = extra = object() # add instance 'extra' field
context = node.ctx

# explicit rejection of known instance fields
self.assertTrue(hasattr(node, 'extra'))
msg = "Name.__replace__ got an unexpected keyword argument 'extra'."
with self.assertRaisesRegex(TypeError, re.escape(msg)):
copy.replace(node, extra=1)
# assert that there is no side-effect
self.assertIs(node.id, 'x')
self.assertIs(node.ctx, context)
self.assertIs(node.extra, extra)

def test_replace_reject_unknown_instance_fields(self):
node = ast.parse('x').body[0].value
context = node.ctx

# explicit rejection of unknown extra fields
self.assertRaises(AttributeError, getattr, node, 'unknown')
msg = "Name.__replace__ got an unexpected keyword argument 'unknown'."
with self.assertRaisesRegex(TypeError, re.escape(msg)):
copy.replace(node, unknown=1)
# assert that there is no side-effect
self.assertIs(node.id, 'x')
self.assertIs(node.ctx, context)
self.assertRaises(AttributeError, getattr, node, 'unknown')

class ASTHelpers_Test(unittest.TestCase):
maxDiff = None
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for :func:`copy.replace` to AST nodes. Patch by Bénédikt Tran.
Loading

0 comments on commit 9728ead

Please sign in to comment.