From 07ddea65b0a6b2b6ac09dfc064d3a73ac20f96da Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 7 Aug 2024 22:59:15 +0800 Subject: [PATCH] fix[lang]: fix certain varinfo comparisons (#4164) for `VarInfo`s which are declared in memory, the `VarInfo` initialization is missing `decl_node`, and therefore different variables with the same type get detected as overlapping in loop iterator modification detection. this commit properly initializes memory `VarInfo`s with the appropriate `decl_node`. --------- Co-authored-by: cyberthirst --- .../features/iteration/test_for_in_list.py | 33 +++++++++++++++++++ tests/unit/ast/test_ast_dict.py | 24 ++++++++++++-- vyper/semantics/analysis/base.py | 5 ++- vyper/semantics/analysis/local.py | 6 ++-- 4 files changed, 61 insertions(+), 7 deletions(-) diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index 184e6a2859..036e7c0647 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -897,3 +897,36 @@ def foo(): compile_code(main, input_bundle=input_bundle) assert e.value._message == "Cannot modify loop variable `queue`" + + +def test_iterator_modification_memory(get_contract): + code = """ +@external +def foo() -> DynArray[uint256, 10]: + # check VarInfos are distinguished by decl_node when they have same type + alreadyDone: DynArray[uint256, 10] = [] + _assets: DynArray[uint256, 10] = [1, 2, 3, 4, 3, 2, 1] + for a: uint256 in _assets: + if a in alreadyDone: + continue + alreadyDone.append(a) + return alreadyDone + """ + c = get_contract(code) + assert c.foo() == [1, 2, 3, 4] + + +def test_iterator_modification_func_arg(get_contract): + code = """ +@internal +def boo(a: DynArray[uint256, 12] = [], b: DynArray[uint256, 12] = []) -> DynArray[uint256, 12]: + for i: uint256 in a: + b.append(i) + return b + +@external +def foo() -> DynArray[uint256, 12]: + return self.boo([1, 2, 3]) + """ + c = get_contract(code) + assert c.foo() == [1, 2, 3] diff --git a/tests/unit/ast/test_ast_dict.py b/tests/unit/ast/test_ast_dict.py index 81c3dc46fa..07da3c0ace 100644 --- a/tests/unit/ast/test_ast_dict.py +++ b/tests/unit/ast/test_ast_dict.py @@ -1255,7 +1255,13 @@ def qux2(): { "annotation": {"ast_type": "Name", "id": "uint256"}, "ast_type": "AnnAssign", - "target": {"ast_type": "Name", "id": "x"}, + "target": { + "ast_type": "Name", + "id": "x", + "variable_reads": [ + {"name": "x", "decl_node": {"node_id": 15, "source_id": 0}, "access_path": []} + ], + }, "value": { "ast_type": "Attribute", "attr": "counter", @@ -1300,7 +1306,13 @@ def qux2(): { "annotation": {"ast_type": "Name", "id": "uint256"}, "ast_type": "AnnAssign", - "target": {"ast_type": "Name", "id": "x"}, + "target": { + "ast_type": "Name", + "id": "x", + "variable_reads": [ + {"name": "x", "decl_node": {"node_id": 35, "source_id": 0}, "access_path": []} + ], + }, "value": { "ast_type": "Attribute", "attr": "counter", @@ -1317,7 +1329,13 @@ def qux2(): { "annotation": {"ast_type": "Name", "id": "uint256"}, "ast_type": "AnnAssign", - "target": {"ast_type": "Name", "id": "y"}, + "target": { + "ast_type": "Name", + "id": "y", + "variable_reads": [ + {"name": "y", "decl_node": {"node_id": 44, "source_id": 0}, "access_path": []} + ], + }, "value": { "ast_type": "Attribute", "attr": "counter", diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 026e0626e7..65bc8df3ab 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -257,7 +257,10 @@ def to_dict(self): # map SUBSCRIPT_ACCESS to `"$subscript_access"` (which is an identifier # which can't be constructed by the user) path = ["$subscript_access" if s is self.SUBSCRIPT_ACCESS else s for s in self.path] - varname = var.decl_node.target.id + if isinstance(var.decl_node, vy_ast.arg): + varname = var.decl_node.arg + else: + varname = var.decl_node.target.id decl_node = var.decl_node.get_id_dict() ret = {"name": varname, "decl_node": decl_node, "access_path": path} diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 26c6a4ef9f..b5292b1dad 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -317,7 +317,7 @@ def analyze(self): for arg in self.func.arguments: self.namespace[arg.name] = VarInfo( - arg.typ, location=location, modifiability=modifiability + arg.typ, location=location, modifiability=modifiability, decl_node=arg.ast_source ) for node in self.fn_node.body: @@ -363,7 +363,7 @@ def visit_AnnAssign(self, node): # validate the value before adding it to the namespace self.expr_visitor.visit(node.value, typ) - self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY) + self.namespace[name] = VarInfo(typ, location=DataLocation.MEMORY, decl_node=node) self.expr_visitor.visit(node.target, typ) @@ -575,7 +575,7 @@ def visit_For(self, node): target_name = node.target.target.id # maybe we should introduce a new Modifiability: LOOP_VARIABLE self.namespace[target_name] = VarInfo( - target_type, modifiability=Modifiability.RUNTIME_CONSTANT + target_type, modifiability=Modifiability.RUNTIME_CONSTANT, decl_node=node.target ) self.expr_visitor.visit(node.target.target, target_type)