diff --git a/tests/parser/features/test_external_contract_calls.py b/tests/parser/features/test_external_contract_calls.py index 8fe6d82831a..cef169eb06a 100644 --- a/tests/parser/features/test_external_contract_calls.py +++ b/tests/parser/features/test_external_contract_calls.py @@ -199,6 +199,28 @@ def __init__(arg1: address): print('Successfully executed a multiple external contract calls') +def test_invalid_external_contract_call_to_the_same_contract(assert_tx_failed): + contract = """ +class Bar(): + def bar() -> num: pass + +def bar() -> num: + return 1 + +def _expr(x: address) -> num: + return Bar(x).bar() + +def _stmt(x: address) -> num: + return Bar(x).bar() + """ + t.s = t.Chain() + c = get_contract(contract) + c._expr(t.a1) + c._stmt(t.a1) + assert_tx_failed(t, lambda: c._expr(c.address)) + assert_tx_failed(t, lambda: c._stmt(c.address)) + + def test_invalid_contract_reference_declaration(assert_tx_failed): contract = """ class Bar(): diff --git a/viper/parser/parser.py b/viper/parser/parser.py index c4a5408bdf6..97fd7ce73df 100644 --- a/viper/parser/parser.py +++ b/viper/parser/parser.py @@ -460,7 +460,9 @@ def external_contract_call_stmt(stmt, context): sig = context.sigs[contract_name][method_name] contract_address = parse_expr(stmt.func.value.args[0], context) inargs, inargsize = pack_arguments(sig, [parse_expr(arg, context) for arg in stmt.args], context) - o = LLLnode.from_list(['assert', ['call', ['gas'], ['mload', contract_address], 0, inargs, inargsize, 0, 0]], + o = LLLnode.from_list(['seq', + ['assert', ['ne', 'address', ['mload', contract_address]]], + ['assert', ['call', ['gas'], ['mload', contract_address], 0, inargs, inargsize, 0, 0]]], typ=None, location='memory', pos=getpos(stmt)) return o @@ -484,6 +486,7 @@ def external_contract_call_expr(expr, context): else: raise TypeMismatchException("Invalid output type: %r" % sig.output_type, expr) o = LLLnode.from_list(['seq', + ['assert', ['ne', 'address', ['mload', contract_address]]], ['assert', ['call', ['gas'], ['mload', contract_address], 0, inargs, inargsize, output_placeholder, get_size_of_type(sig.output_type) * 32]],