diff --git a/tests/parser/functions/test_reset.py b/tests/parser/functions/test_reset.py index d9979389866..efdba3af783 100644 --- a/tests/parser/functions/test_reset.py +++ b/tests/parser/functions/test_reset.py @@ -92,7 +92,7 @@ def foo(): c.foo() -def test_reset_lists(get_contract_with_gas_estimation): +def test_reset_basic_type_lists(get_contract_with_gas_estimation): contracts = [ """ foobar: int128[3] @@ -209,32 +209,90 @@ def foo(): assert bar[0] == ZERO_ADDRESS assert bar[1] == ZERO_ADDRESS assert bar[2] == ZERO_ADDRESS + """ + ] + + for contract in contracts: + c = get_contract_with_gas_estimation(contract) + c.foo() -foobar: address + +def test_reset_bytes(get_contract_with_gas_estimation): + code = """ +foobar: bytes[5] @public -def foo(): - self.foobar = msg.sender - bar: address = msg.sender +def foo() -> (bytes[5], bytes[5]): + self.foobar = 'Hello' + bar: bytes[5] = 'World' reset(self.foobar) reset(bar) - assert self.foobar == ZERO_ADDRESS - assert bar == ZERO_ADDRESS + return (self.foobar, bar) """ - ] - for contract in contracts: - c = get_contract_with_gas_estimation(contract) - c.foo() + c = get_contract_with_gas_estimation(code) + a, b = c.foo() + assert a == b'' + assert b == b'' + + +def test_reset_struct(get_contract_with_gas_estimation): + code = """ +foobar: { + a: int128, + b: uint256, + c: bool, + d: decimal, + e: bytes32, + f: address +} +@public +def foo(): + self.foobar = { + a: 1, + b: 2, + c: True, + d: 3.0, + e: 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + f: msg.sender + } + bar: { + a: int128, + b: uint256, + c: bool, + d: decimal, + e: bytes32, + f: address + } = { + a: 1, + b: 2, + c: True, + d: 3.0, + e: 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF, + f: msg.sender + } + + reset(self.foobar) + reset(bar) -# def test_reset_struct(get_contract_with_gas_estimation): -# code = """ -# #TODO -# """ + assert self.foobar.a == 0 + assert self.foobar.b == 0 + assert self.foobar.c == False + assert self.foobar.d == 0.0 + assert self.foobar.e == 0x0000000000000000000000000000000000000000000000000000000000000000 + assert self.foobar.f == ZERO_ADDRESS + + assert bar.a == 0 + assert bar.b == 0 + assert bar.c == False + assert bar.d == 0.0 + assert bar.e == 0x0000000000000000000000000000000000000000000000000000000000000000 + assert bar.f == ZERO_ADDRESS + """ -# c = get_contract_with_gas_estimation(code) -# c.foo() + c = get_contract_with_gas_estimation(code) + c.foo() diff --git a/tests/parser/syntax/test_maps.py b/tests/parser/syntax/test_maps.py index 67f4f575783..6f68e8d0c36 100644 --- a/tests/parser/syntax/test_maps.py +++ b/tests/parser/syntax/test_maps.py @@ -172,7 +172,8 @@ def foo(): nom: {c: int128}[3] @public def foo(): - self.mom = {b: 5} + empty: {c: int128}[3] + self.mom = {a: empty, b: 5} """, """ mom: {a: {c: int128}[3], b: int128} diff --git a/vyper/functions/functions.py b/vyper/functions/functions.py index 117a9a011cc..c428af2b545 100644 --- a/vyper/functions/functions.py +++ b/vyper/functions/functions.py @@ -5,6 +5,7 @@ InvalidLiteralException, StructureException, TypeMismatchException, + ParserException, ) from .signature import ( signature, @@ -50,9 +51,6 @@ from vyper.types.convert import ( convert, ) -from vyper.types.reset import ( - reset, -) def enforce_units(typ, obj, expected): @@ -104,10 +102,6 @@ def _convert(expr, context): return convert(expr, context) -def _reset(expr, context): - return reset(expr, context) - - @signature('bytes', start='int128', len='int128') def _slice(expr, args, kwargs, context): sub, start, length = args[0], kwargs['start'], kwargs['len'] @@ -689,6 +683,10 @@ def _can_compare_with_uint256(operand): return LLLnode.from_list(['with', '_l', left, ['with', '_r', right, o]], typ=otyp, pos=getpos(expr)) +def _reset(): + raise ParserException("This function should never be called! `reset()` is currently handled differently than other functions as it self modifies its input argument statement. Please see `_reset()` in `stmt.py`") + + dispatch_table = { 'floor': floor, 'ceil': ceil, diff --git a/vyper/parser/stmt.py b/vyper/parser/stmt.py index 86ce897b268..667ebd79a3b 100644 --- a/vyper/parser/stmt.py +++ b/vyper/parser/stmt.py @@ -228,7 +228,7 @@ def parse_if(self): self.context.end_blockscope(block_scope_id) return o - def reset_var(self): + def _reset(self): # Create zero node none = ast.NameConstant(None) none.lineno = self.stmt.lineno @@ -252,7 +252,7 @@ def call(self): if isinstance(self.stmt.func, ast.Name): if self.stmt.func.id in stmt_dispatch_table: if self.stmt.func.id == 'reset': - return self.reset_var() + return self._reset() else: return stmt_dispatch_table[self.stmt.func.id](self.stmt, self.context) elif self.stmt.func.id in dispatch_table: diff --git a/vyper/types/reset.py b/vyper/types/reset.py deleted file mode 100644 index 73d3eed4f92..00000000000 --- a/vyper/types/reset.py +++ /dev/null @@ -1,65 +0,0 @@ -import ast -import warnings - -from vyper.functions.signature import ( - signature -) -from vyper.exceptions import ( - TypeMismatchException, - ParserException, -) -from vyper.parser.parser_utils import ( - LLLnode, - getpos, -) -from vyper.types import ( - BaseType, - get_type, -) -from vyper.parser.expr import ( - Expr -) - - -def get_mem_location(load, expr): - # Get appropriate load op - if load.value == 'mload': - op = 'mstore' - elif load.value == 'sload': - op = 'sstore' - else: - raise ParserException( - "Malformed variable input load! {}".format(load.value), - expr - ) - - # Get appropriate memory location - args = load.args - if len(args) != 1: - raise ParserException( - "Malformed variable input load! {}".format(args), - expr - ) - memloc = args[0] - return op, memloc - - -@signature(('int128', 'uint256', 'bool', 'decimal', 'bytes32', 'address')) -def reset(expr, args, kwargs, context): - in_arg = args[0] - input_typ, _len = get_type(in_arg) - - if input_typ in ('int128', 'uint256', 'bool', 'decimal', 'bytes32', 'address'): - # Generate op-code - op, memloc = get_mem_location(in_arg, expr) - return LLLnode.from_list( - [op, memloc, 0], - typ=BaseType(input_typ, None, is_literal=True), - pos=getpos(expr) - ) - - else: - raise TypeMismatchException( - "Unable to reset type {}".format(input_typ), - expr - )