diff --git a/tests/parser/syntax/test_return_tuple.py b/tests/parser/syntax/test_return_tuple.py index 0c2da518f6..2f442adad1 100644 --- a/tests/parser/syntax/test_return_tuple.py +++ b/tests/parser/syntax/test_return_tuple.py @@ -17,3 +17,124 @@ def test_tuple_return_fail(bad_code): with pytest.raises(FunctionDeclarationException): compiler.compile_code(bad_code) + + +def test_self_call_in_return_tuple(get_contract): + code = """ +@internal +def _foo() -> uint256: + a: uint256[10] = [6,7,8,9,10,11,12,13,14,15] + return 3 + +@external +def foo() -> (uint256, uint256, uint256, uint256, uint256): + return 1, 2, self._foo(), 4, 5 + """ + + c = get_contract(code) + + assert c.foo() == [1, 2, 3, 4, 5] + + +def test_call_in_call(get_contract): + code = """ +@internal +def _foo(a: uint256, b: uint256, c: uint256) -> (uint256, uint256, uint256, uint256, uint256): + return 1, a, b, c, 5 + +@internal +def _foo2() -> uint256: + a: uint256[10] = [6,7,8,9,10,11,12,13,15,16] + return 4 + +@external +def foo() -> (uint256, uint256, uint256, uint256, uint256): + return self._foo(2, 3, self._foo2()) + """ + + c = get_contract(code) + + assert c.foo() == [1, 2, 3, 4, 5] + + +def test_nested_calls_in_tuple_return(get_contract): + code = """ +@internal +def _foo(a: uint256, b: uint256, c: uint256) -> (uint256, uint256): + return 415, 3 + +@internal +def _foo2(a: uint256) -> uint256: + b: uint256[10] = [6,7,8,9,10,11,12,13,14,15] + return 99 + +@internal +def _foo3(a: uint256, b: uint256) -> uint256: + c: uint256[10] = [14,15,16,17,18,19,20,21,22,23] + return 42 + +@internal +def _foo4() -> uint256: + c: uint256[10] = [14,15,16,17,18,19,20,21,22,23] + return 4 + +@external +def foo() -> (uint256, uint256, uint256, uint256, uint256): + return 1, 2, self._foo(6, 7, self._foo2(self._foo3(9, 11)))[1], self._foo4(), 5 + """ + + c = get_contract(code) + + assert c.foo() == [1, 2, 3, 4, 5] + + +def test_external_call_in_return_tuple(get_contract): + code = """ +@view +@external +def foo() -> (uint256, uint256): + return 3, 4 + """ + + code2 = """ +interface Foo: + def foo() -> (uint256, uint256): view + +@external +def foo(a: address) -> (uint256, uint256, uint256, uint256, uint256): + return 1, 2, Foo(a).foo()[0], 4, 5 + """ + + c = get_contract(code) + c2 = get_contract(code2) + + assert c2.foo(c.address) == [1, 2, 3, 4, 5] + + +def test_nested_external_call_in_return_tuple(get_contract): + code = """ +@view +@external +def foo() -> (uint256, uint256): + return 3, 4 + +@view +@external +def bar(a: uint256) -> uint256: + return a+1 + """ + + code2 = """ +interface Foo: + def foo() -> (uint256, uint256): view + def bar(a: uint256) -> uint256: view + +@external +def foo(a: address) -> (uint256, uint256, uint256, uint256, uint256): + return 1, 2, Foo(a).foo()[0], 4, Foo(a).bar(Foo(a).foo()[1]) + """ + + c = get_contract(code) + c2 = get_contract(code2) + + assert c2.foo(c.address) == [1, 2, 3, 4, 5] diff --git a/vyper/codegen/return_.py b/vyper/codegen/return_.py index 6f23f6bf30..cba251e21a 100644 --- a/vyper/codegen/return_.py +++ b/vyper/codegen/return_.py @@ -97,6 +97,21 @@ def gen_tuple_return(stmt, context, sub): ] return LLLnode.from_list(os, typ=None, pos=getpos(stmt), valency=0) + # for tuple return types where a function is called inside the tuple, we + # process the calls prior to encoding the return data + if sub.value == "seq_unchecked" and sub.args[-1].value == "multi": + encode_out = abi_encode(return_buffer, sub.args[-1], pos=getpos(stmt), returns=True) + load_return_len = ["mload", MemoryPositions.FREE_VAR_SPACE] + os = ( + ["seq"] + + sub.args[:-1] + + [ + ["mstore", MemoryPositions.FREE_VAR_SPACE, encode_out], + make_return_stmt(stmt, context, return_buffer, load_return_len), + ] + ) + return LLLnode.from_list(os, typ=None, pos=getpos(stmt), valency=0) + # for all othe cases we are creating a stack variable named sub_loc to store the location # of the return expression. This is done so that the return expression does not get evaluated # abi-encode uses a function named o_list which evaluate the expression multiple times diff --git a/vyper/parser/expr.py b/vyper/parser/expr.py index 2f20bb2ddf..f6b58ae0e4 100644 --- a/vyper/parser/expr.py +++ b/vyper/parser/expr.py @@ -961,11 +961,35 @@ def struct_literals(expr, name, context): def parse_Tuple(self): if not len(self.expr.elements): return - lll_node = [] + call_lll = [] + multi_lll = [] for node in self.expr.elements: - lll_node.append(Expr(node, self.context).lll_node) - typ = TupleType([x.typ for x in lll_node], is_literal=True) - return LLLnode.from_list(["multi"] + lll_node, typ=typ, pos=getpos(self.expr)) + if isinstance(node, vy_ast.Call): + # for calls inside the tuple, we perform the call prior to building the tuple and + # assign it's result to memory - otherwise there is potential for memory corruption + lll_node = Expr(node, self.context).lll_node + target = LLLnode.from_list( + self.context.new_placeholder(lll_node.typ), + typ=lll_node.typ, + location="memory", + pos=getpos(self.expr), + ) + call_lll.append(make_setter(target, lll_node, "memory", pos=getpos(self.expr))) + multi_lll.append( + LLLnode.from_list( + target, typ=lll_node.typ, pos=getpos(self.expr), location="memory" + ), + ) + else: + multi_lll.append(Expr(node, self.context).lll_node) + + typ = TupleType([x.typ for x in multi_lll], is_literal=True) + multi_lll = LLLnode.from_list(["multi"] + multi_lll, typ=typ, pos=getpos(self.expr)) + if not call_lll: + return multi_lll + + lll_node = ["seq_unchecked"] + call_lll + [multi_lll] + return LLLnode.from_list(lll_node, typ=typ, pos=getpos(self.expr)) # Parse an expression that results in a value @classmethod diff --git a/vyper/parser/self_call.py b/vyper/parser/self_call.py index c14f093c2b..ad5b4d7a10 100644 --- a/vyper/parser/self_call.py +++ b/vyper/parser/self_call.py @@ -1,14 +1,14 @@ import itertools +from vyper import ast as vy_ast from vyper.codegen.abi import abi_decode from vyper.exceptions import ( StateAccessViolation, StructureException, TypeCheckFailure, - TypeMismatch, ) from vyper.parser.lll_node import LLLnode -from vyper.parser.parser_utils import getpos, pack_arguments +from vyper.parser.parser_utils import getpos, make_setter, pack_arguments from vyper.signatures.function_signature import FunctionSignature from vyper.types import ( BaseType, @@ -21,38 +21,6 @@ ) -def _call_lookup_specs(stmt_expr, context): - from vyper.parser.expr import Expr - - method_name = stmt_expr.func.attr - - if len(stmt_expr.keywords): - raise TypeMismatch( - "Cannot use keyword arguments in calls to functions via 'self'", stmt_expr, - ) - expr_args = [Expr(arg, context).lll_node for arg in stmt_expr.args] - - sig = FunctionSignature.lookup_sig(context.sigs, method_name, expr_args, stmt_expr, context,) - - return method_name, expr_args, sig - - -def make_call(stmt_expr, context): - method_name, _, sig = _call_lookup_specs(stmt_expr, context) - - if context.is_constant() and sig.mutability not in ("view", "pure"): - raise StateAccessViolation( - f"May not call state modifying function " - f"'{method_name}' within {context.pp_constancy()}.", - getpos(stmt_expr), - ) - - if not sig.internal: - raise StructureException("Cannot call external functions via 'self'", stmt_expr) - - return _call_self_internal(stmt_expr, context, sig) - - def _call_make_placeholder(stmt_expr, context, sig): if sig.output_type is None: return 0, 0, 0 @@ -79,7 +47,7 @@ def _call_make_placeholder(stmt_expr, context, sig): return output_placeholder, returner, output_size -def _call_self_internal(stmt_expr, context, sig): +def make_call(stmt_expr, context): # ** Internal Call ** # Steps: # (x) push current local variables @@ -89,13 +57,49 @@ def _call_self_internal(stmt_expr, context, sig): # (x) pop return values # (x) pop local variables - method_name, expr_args, sig = _call_lookup_specs(stmt_expr, context) pre_init = [] pop_local_vars = [] push_local_vars = [] pop_return_values = [] push_args = [] + from vyper.parser.expr import Expr + + method_name = stmt_expr.func.attr + + expr_args = [] + for arg in stmt_expr.args: + lll_node = Expr(arg, context).lll_node + if isinstance(arg, vy_ast.Call): + # if the argument is a function call, perform the call seperately and + # assign it's result to memory, then reference the memory location when + # building this call. otherwise there is potential for memory corruption + target = LLLnode.from_list( + context.new_placeholder(lll_node.typ), + typ=lll_node.typ, + location="memory", + pos=getpos(arg), + ) + setter = make_setter(target, lll_node, "memory", pos=getpos(arg)) + expr_args.append( + LLLnode.from_list(target, typ=lll_node.typ, pos=getpos(arg), location="memory") + ) + pre_init.append(setter) + else: + expr_args.append(lll_node) + + sig = FunctionSignature.lookup_sig(context.sigs, method_name, expr_args, stmt_expr, context,) + + if context.is_constant() and sig.mutability not in ("view", "pure"): + raise StateAccessViolation( + f"May not call state modifying function " + f"'{method_name}' within {context.pp_constancy()}.", + getpos(stmt_expr), + ) + + if not sig.internal: + raise StructureException("Cannot call external functions via 'self'", stmt_expr) + # Push local variables. var_slots = [(v.pos, v.size) for name, v in context.vars.items() if v.location == "memory"] if var_slots: