Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: function calls within tuples / nested calls #2186

Merged
merged 4 commits into from
Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions tests/parser/syntax/test_return_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,124 @@ def test_tuple_return_fail(bad_code):

with pytest.raises(FunctionDeclarationException):
compiler.compile_code(bad_code)


def test_self_call_in_return_tuple(get_contract):
code = """
@internal
def _foo() -> uint256:
a: uint256[10] = [6,7,8,9,10,11,12,13,14,15]
return 3

@external
def foo() -> (uint256, uint256, uint256, uint256, uint256):
return 1, 2, self._foo(), 4, 5
"""

c = get_contract(code)

assert c.foo() == [1, 2, 3, 4, 5]


def test_call_in_call(get_contract):
code = """
@internal
def _foo(a: uint256, b: uint256, c: uint256) -> (uint256, uint256, uint256, uint256, uint256):
return 1, a, b, c, 5

@internal
def _foo2() -> uint256:
a: uint256[10] = [6,7,8,9,10,11,12,13,15,16]
return 4

@external
def foo() -> (uint256, uint256, uint256, uint256, uint256):
return self._foo(2, 3, self._foo2())
"""

c = get_contract(code)

assert c.foo() == [1, 2, 3, 4, 5]


def test_nested_calls_in_tuple_return(get_contract):
code = """
@internal
def _foo(a: uint256, b: uint256, c: uint256) -> (uint256, uint256):
return 415, 3

@internal
def _foo2(a: uint256) -> uint256:
b: uint256[10] = [6,7,8,9,10,11,12,13,14,15]
return 99

@internal
def _foo3(a: uint256, b: uint256) -> uint256:
c: uint256[10] = [14,15,16,17,18,19,20,21,22,23]
return 42

@internal
def _foo4() -> uint256:
c: uint256[10] = [14,15,16,17,18,19,20,21,22,23]
return 4

@external
def foo() -> (uint256, uint256, uint256, uint256, uint256):
return 1, 2, self._foo(6, 7, self._foo2(self._foo3(9, 11)))[1], self._foo4(), 5
"""

c = get_contract(code)

assert c.foo() == [1, 2, 3, 4, 5]


def test_external_call_in_return_tuple(get_contract):
code = """
@view
@external
def foo() -> (uint256, uint256):
return 3, 4
"""

code2 = """
interface Foo:
def foo() -> (uint256, uint256): view

@external
def foo(a: address) -> (uint256, uint256, uint256, uint256, uint256):
return 1, 2, Foo(a).foo()[0], 4, 5
"""

c = get_contract(code)
c2 = get_contract(code2)

assert c2.foo(c.address) == [1, 2, 3, 4, 5]


def test_nested_external_call_in_return_tuple(get_contract):
code = """
@view
@external
def foo() -> (uint256, uint256):
return 3, 4

@view
@external
def bar(a: uint256) -> uint256:
return a+1
"""

code2 = """
interface Foo:
def foo() -> (uint256, uint256): view
def bar(a: uint256) -> uint256: view

@external
def foo(a: address) -> (uint256, uint256, uint256, uint256, uint256):
return 1, 2, Foo(a).foo()[0], 4, Foo(a).bar(Foo(a).foo()[1])
"""

c = get_contract(code)
c2 = get_contract(code2)

assert c2.foo(c.address) == [1, 2, 3, 4, 5]
15 changes: 15 additions & 0 deletions vyper/codegen/return_.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,21 @@ def gen_tuple_return(stmt, context, sub):
]
return LLLnode.from_list(os, typ=None, pos=getpos(stmt), valency=0)

# for tuple return types where a function is called inside the tuple, we
# process the calls prior to encoding the return data
if sub.value == "seq_unchecked" and sub.args[-1].value == "multi":
encode_out = abi_encode(return_buffer, sub.args[-1], pos=getpos(stmt), returns=True)
load_return_len = ["mload", MemoryPositions.FREE_VAR_SPACE]
os = (
["seq"]
+ sub.args[:-1]
+ [
["mstore", MemoryPositions.FREE_VAR_SPACE, encode_out],
make_return_stmt(stmt, context, return_buffer, load_return_len),
]
)
return LLLnode.from_list(os, typ=None, pos=getpos(stmt), valency=0)

# for all othe cases we are creating a stack variable named sub_loc to store the location
# of the return expression. This is done so that the return expression does not get evaluated
# abi-encode uses a function named o_list which evaluate the expression multiple times
Expand Down
32 changes: 28 additions & 4 deletions vyper/parser/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,11 +961,35 @@ def struct_literals(expr, name, context):
def parse_Tuple(self):
if not len(self.expr.elements):
return
lll_node = []
call_lll = []
multi_lll = []
for node in self.expr.elements:
lll_node.append(Expr(node, self.context).lll_node)
typ = TupleType([x.typ for x in lll_node], is_literal=True)
return LLLnode.from_list(["multi"] + lll_node, typ=typ, pos=getpos(self.expr))
if isinstance(node, vy_ast.Call):
# for calls inside the tuple, we perform the call prior to building the tuple and
# assign it's result to memory - otherwise there is potential for memory corruption
lll_node = Expr(node, self.context).lll_node
target = LLLnode.from_list(
self.context.new_placeholder(lll_node.typ),
typ=lll_node.typ,
location="memory",
pos=getpos(self.expr),
)
call_lll.append(make_setter(target, lll_node, "memory", pos=getpos(self.expr)))
multi_lll.append(
LLLnode.from_list(
target, typ=lll_node.typ, pos=getpos(self.expr), location="memory"
),
)
else:
multi_lll.append(Expr(node, self.context).lll_node)

typ = TupleType([x.typ for x in multi_lll], is_literal=True)
multi_lll = LLLnode.from_list(["multi"] + multi_lll, typ=typ, pos=getpos(self.expr))
if not call_lll:
return multi_lll

lll_node = ["seq_unchecked"] + call_lll + [multi_lll]
return LLLnode.from_list(lll_node, typ=typ, pos=getpos(self.expr))

# Parse an expression that results in a value
@classmethod
Expand Down
76 changes: 40 additions & 36 deletions vyper/parser/self_call.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import itertools

from vyper import ast as vy_ast
from vyper.codegen.abi import abi_decode
from vyper.exceptions import (
StateAccessViolation,
StructureException,
TypeCheckFailure,
TypeMismatch,
)
from vyper.parser.lll_node import LLLnode
from vyper.parser.parser_utils import getpos, pack_arguments
from vyper.parser.parser_utils import getpos, make_setter, pack_arguments
from vyper.signatures.function_signature import FunctionSignature
from vyper.types import (
BaseType,
Expand All @@ -21,38 +21,6 @@
)


def _call_lookup_specs(stmt_expr, context):
from vyper.parser.expr import Expr

method_name = stmt_expr.func.attr

if len(stmt_expr.keywords):
raise TypeMismatch(
"Cannot use keyword arguments in calls to functions via 'self'", stmt_expr,
)
expr_args = [Expr(arg, context).lll_node for arg in stmt_expr.args]

sig = FunctionSignature.lookup_sig(context.sigs, method_name, expr_args, stmt_expr, context,)

return method_name, expr_args, sig


def make_call(stmt_expr, context):
method_name, _, sig = _call_lookup_specs(stmt_expr, context)

if context.is_constant() and sig.mutability not in ("view", "pure"):
raise StateAccessViolation(
f"May not call state modifying function "
f"'{method_name}' within {context.pp_constancy()}.",
getpos(stmt_expr),
)

if not sig.internal:
raise StructureException("Cannot call external functions via 'self'", stmt_expr)

return _call_self_internal(stmt_expr, context, sig)


def _call_make_placeholder(stmt_expr, context, sig):
if sig.output_type is None:
return 0, 0, 0
Expand All @@ -79,7 +47,7 @@ def _call_make_placeholder(stmt_expr, context, sig):
return output_placeholder, returner, output_size


def _call_self_internal(stmt_expr, context, sig):
def make_call(stmt_expr, context):
# ** Internal Call **
# Steps:
# (x) push current local variables
Expand All @@ -89,13 +57,49 @@ def _call_self_internal(stmt_expr, context, sig):
# (x) pop return values
# (x) pop local variables

method_name, expr_args, sig = _call_lookup_specs(stmt_expr, context)
pre_init = []
pop_local_vars = []
push_local_vars = []
pop_return_values = []
push_args = []

from vyper.parser.expr import Expr

method_name = stmt_expr.func.attr

expr_args = []
for arg in stmt_expr.args:
lll_node = Expr(arg, context).lll_node
if isinstance(arg, vy_ast.Call):
# if the argument is a function call, perform the call seperately and
# assign it's result to memory, then reference the memory location when
# building this call. otherwise there is potential for memory corruption
target = LLLnode.from_list(
context.new_placeholder(lll_node.typ),
typ=lll_node.typ,
location="memory",
pos=getpos(arg),
)
setter = make_setter(target, lll_node, "memory", pos=getpos(arg))
expr_args.append(
LLLnode.from_list(target, typ=lll_node.typ, pos=getpos(arg), location="memory")
)
pre_init.append(setter)
else:
expr_args.append(lll_node)

sig = FunctionSignature.lookup_sig(context.sigs, method_name, expr_args, stmt_expr, context,)

if context.is_constant() and sig.mutability not in ("view", "pure"):
raise StateAccessViolation(
f"May not call state modifying function "
f"'{method_name}' within {context.pp_constancy()}.",
getpos(stmt_expr),
)

if not sig.internal:
raise StructureException("Cannot call external functions via 'self'", stmt_expr)

# Push local variables.
var_slots = [(v.pos, v.size) for name, v in context.vars.items() if v.location == "memory"]
if var_slots:
Expand Down