From bc68427b21457cf350733de19d834dfbf3f84c87 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 30 Sep 2023 09:46:38 +0800 Subject: [PATCH 01/11] relax slice return type checking; get return type from semantics --- vyper/builtins/functions.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index f07202831d..7a128567a0 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -333,8 +333,6 @@ def fetch_call_return(self, node): # we know the length statically if length_literal is not None: return_type.set_length(length_literal) - else: - return_type.set_min_length(arg_type.length) return return_type @@ -378,12 +376,8 @@ def build_IR(self, expr, args, kwargs, context): buflen += 32 # Get returntype string or bytes - assert isinstance(src.typ, _BytestringT) or is_bytes32 - # TODO: try to get dst_typ from semantic analysis - if isinstance(src.typ, StringT): - dst_typ = StringT(dst_maxlen) - else: - dst_typ = BytesT(dst_maxlen) + dst_typ = expr._metadata.get("type") + assert isinstance(dst_typ, _BytestringT) or is_bytes32 # allocate a buffer for the return value buf = context.new_internal_variable(BytesT(buflen)) From 008e7021e56eb79f9842d0d58b650e93f90c2c41 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 30 Sep 2023 09:46:44 +0800 Subject: [PATCH 02/11] modify tests --- tests/parser/functions/test_slice.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 3090dafda0..37beffa598 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -146,12 +146,11 @@ def _get_contract(): if ( (start + length > data_length and literal_start and literal_length) or (literal_length and length > data_length) - or (location == "literal" and len(bytesdata) > length_bound) or (literal_start and start > data_length) or (literal_length and length < 1) ): assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch)) - elif len(bytesdata) > data_length: + elif len(bytesdata) > data_length or (location == "literal" and len(bytesdata) > length_bound): # deploy fail assert_tx_failed(lambda: _get_contract()) elif start + length > len(bytesdata): From 0933c9f8d5685c696684a02aca413faa2dc0ce57 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sat, 30 Sep 2023 10:43:57 +0800 Subject: [PATCH 03/11] set max len if not annotated --- vyper/builtins/functions.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 7a128567a0..d71592ee1f 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -377,7 +377,11 @@ def build_IR(self, expr, args, kwargs, context): # Get returntype string or bytes dst_typ = expr._metadata.get("type") - assert isinstance(dst_typ, _BytestringT) or is_bytes32 + assert isinstance(dst_typ, _BytestringT) + + # set the length of the return type if it was not defined in annotation + if dst_typ.length == 0: + dst_typ.set_length(dst_maxlen) # allocate a buffer for the return value buf = context.new_internal_variable(BytesT(buflen)) From 56ff2399676ae3343cf5a9191d5ea1573eaf51ba Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:03:56 +0800 Subject: [PATCH 04/11] Revert "set max len if not annotated" This reverts commit 0933c9f8d5685c696684a02aca413faa2dc0ce57. --- vyper/builtins/functions.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index d71592ee1f..7a128567a0 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -377,11 +377,7 @@ def build_IR(self, expr, args, kwargs, context): # Get returntype string or bytes dst_typ = expr._metadata.get("type") - assert isinstance(dst_typ, _BytestringT) - - # set the length of the return type if it was not defined in annotation - if dst_typ.length == 0: - dst_typ.set_length(dst_maxlen) + assert isinstance(dst_typ, _BytestringT) or is_bytes32 # allocate a buffer for the return value buf = context.new_internal_variable(BytesT(buflen)) From a702209194592c97fd5b742ece144b57ca3d4297 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:04:04 +0800 Subject: [PATCH 05/11] Revert "modify tests" This reverts commit 008e7021e56eb79f9842d0d58b650e93f90c2c41. --- tests/parser/functions/test_slice.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 37beffa598..3090dafda0 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -146,11 +146,12 @@ def _get_contract(): if ( (start + length > data_length and literal_start and literal_length) or (literal_length and length > data_length) + or (location == "literal" and len(bytesdata) > length_bound) or (literal_start and start > data_length) or (literal_length and length < 1) ): assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch)) - elif len(bytesdata) > data_length or (location == "literal" and len(bytesdata) > length_bound): + elif len(bytesdata) > data_length: # deploy fail assert_tx_failed(lambda: _get_contract()) elif start + length > len(bytesdata): From 6aaa48ce4fe03ec4ad28d4da1aae7bfc15b23cf3 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:04:11 +0800 Subject: [PATCH 06/11] Revert "relax slice return type checking; get return type from semantics" This reverts commit bc68427b21457cf350733de19d834dfbf3f84c87. --- vyper/builtins/functions.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 7a128567a0..f07202831d 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -333,6 +333,8 @@ def fetch_call_return(self, node): # we know the length statically if length_literal is not None: return_type.set_length(length_literal) + else: + return_type.set_min_length(arg_type.length) return return_type @@ -376,8 +378,12 @@ def build_IR(self, expr, args, kwargs, context): buflen += 32 # Get returntype string or bytes - dst_typ = expr._metadata.get("type") - assert isinstance(dst_typ, _BytestringT) or is_bytes32 + assert isinstance(src.typ, _BytestringT) or is_bytes32 + # TODO: try to get dst_typ from semantic analysis + if isinstance(src.typ, StringT): + dst_typ = StringT(dst_maxlen) + else: + dst_typ = BytesT(dst_maxlen) # allocate a buffer for the return value buf = context.new_internal_variable(BytesT(buflen)) From df1378871577ebb6b184c9f3a0214f1c07c5a12c Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Wed, 4 Oct 2023 12:26:38 +0800 Subject: [PATCH 07/11] fix test --- tests/parser/functions/test_slice.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 3090dafda0..0098242348 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -105,13 +105,23 @@ def test_slice_bytes( literal_length, length_bound, ): + preamble = "" if location == "memory": spliced_code = f"foo: Bytes[{length_bound}] = inp" foo = "foo" elif location == "storage": + preamble = f""" +foo: Bytes[{length_bound}] + """ spliced_code = "self.foo = inp" foo = "self.foo" elif location == "code": + preamble = f""" +IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) +@external +def __init__(foo: Bytes[{length_bound}]): + IMMUTABLE_BYTES = foo + """ spliced_code = "" foo = "IMMUTABLE_BYTES" elif location == "literal": @@ -127,11 +137,7 @@ def test_slice_bytes( _length = length if literal_length else "length" code = f""" -foo: Bytes[{length_bound}] -IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) -@external -def __init__(foo: Bytes[{length_bound}]): - IMMUTABLE_BYTES = foo +{preamble} @external def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Bytes[{length_bound}]: @@ -146,15 +152,15 @@ def _get_contract(): if ( (start + length > data_length and literal_start and literal_length) or (literal_length and length > data_length) - or (location == "literal" and len(bytesdata) > length_bound) + or (location == "literal" and len(bytesdata) > length_bound and not literal_length) or (literal_start and start > data_length) or (literal_length and length < 1) ): assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch)) - elif len(bytesdata) > data_length: + elif location in "code" and len(bytesdata) > data_length: # deploy fail assert_tx_failed(lambda: _get_contract()) - elif start + length > len(bytesdata): + elif start + length > len(bytesdata) or len(bytesdata) > length_bound: c = _get_contract() assert_tx_failed(lambda: c.do_slice(bytesdata, start, length)) else: From ccbbbabfad0e5c0d1f0e63626eef991b9ba375e1 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:03:31 +0800 Subject: [PATCH 08/11] fix literal case --- tests/parser/functions/test_slice.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 0098242348..5db7bed63b 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -152,7 +152,11 @@ def _get_contract(): if ( (start + length > data_length and literal_start and literal_length) or (literal_length and length > data_length) - or (location == "literal" and len(bytesdata) > length_bound and not literal_length) + or ( + location == "literal" + and len(bytesdata) > length_bound + and ((literal_length and length > length_bound) or (not literal_length)) + ) or (literal_start and start > data_length) or (literal_length and length < 1) ): From f0f4d72ee4d6dd9d9dfdd3112925d78076796f7d Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 6 Oct 2023 23:40:37 +0800 Subject: [PATCH 09/11] fix location check --- tests/parser/functions/test_slice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 5db7bed63b..411c61bdf1 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -161,7 +161,7 @@ def _get_contract(): or (literal_length and length < 1) ): assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch)) - elif location in "code" and len(bytesdata) > data_length: + elif location == "code" and len(bytesdata) > data_length: # deploy fail assert_tx_failed(lambda: _get_contract()) elif start + length > len(bytesdata) or len(bytesdata) > length_bound: From 84f0a83f00a94ce38596e02ff6cbdf3325e0915d Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Thu, 12 Oct 2023 15:48:02 +0800 Subject: [PATCH 10/11] add bools; add comment --- tests/parser/functions/test_slice.py | 38 +++++++++++++++++++--------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 411c61bdf1..28cd21dc8c 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -149,27 +149,41 @@ def _get_contract(): return get_contract(code, bytesdata, override_opt_level=opt_level) data_length = len(bytesdata) if location == "literal" else length_bound + end = start + length + + is_zero_literal_length = literal_length and length < 1 + literal_start_exceeds_data_length = literal_start and start > data_length + literal_length_exceeds_data = (literal_length and length > data_length) or ( + end > data_length and literal_start and literal_length + ) + + data_longer_than_length_bound = len(bytesdata) > length_bound + # `not literal_length` condition catches this contract: + # @external + # def do_slice(inp: Bytes[1], start: uint256, length: uint256) -> Bytes[1]: + # return slice(b'\x00\x00', 0, length) + invalid_slice_literal = ( + location == "literal" + and data_longer_than_length_bound + and ((literal_length and length > length_bound) or not literal_length) + ) + if ( - (start + length > data_length and literal_start and literal_length) - or (literal_length and length > data_length) - or ( - location == "literal" - and len(bytesdata) > length_bound - and ((literal_length and length > length_bound) or (not literal_length)) - ) - or (literal_start and start > data_length) - or (literal_length and length < 1) + is_zero_literal_length + or literal_start_exceeds_data_length + or literal_length_exceeds_data + or invalid_slice_literal ): assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch)) - elif location == "code" and len(bytesdata) > data_length: + elif location == "code" and data_longer_than_length_bound: # deploy fail assert_tx_failed(lambda: _get_contract()) - elif start + length > len(bytesdata) or len(bytesdata) > length_bound: + elif end > len(bytesdata) or data_longer_than_length_bound: c = _get_contract() assert_tx_failed(lambda: c.do_slice(bytesdata, start, length)) else: c = _get_contract() - assert c.do_slice(bytesdata, start, length) == bytesdata[start : start + length], code + assert c.do_slice(bytesdata, start, length) == bytesdata[start:end], code def test_slice_private(get_contract): From da640030e7f890312af7e54333d4e5f823ba91c7 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 30 Oct 2023 11:08:03 -0400 Subject: [PATCH 11/11] simplify oob logic --- tests/parser/functions/test_slice.py | 80 ++++++++++++++-------------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 28cd21dc8c..53e092019f 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -32,8 +32,8 @@ def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: _bytes_1024 = st.binary(min_size=0, max_size=1024) -@pytest.mark.parametrize("literal_start", (True, False)) -@pytest.mark.parametrize("literal_length", (True, False)) +@pytest.mark.parametrize("use_literal_start", (True, False)) +@pytest.mark.parametrize("use_literal_length", (True, False)) @pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) @settings(max_examples=100) @@ -45,13 +45,13 @@ def test_slice_immutable( opt_level, bytesdata, start, - literal_start, + use_literal_start, length, - literal_length, + use_literal_length, length_bound, ): - _start = start if literal_start else "start" - _length = length if literal_length else "length" + _start = start if use_literal_start else "start" + _length = length if use_literal_length else "length" code = f""" IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) @@ -71,10 +71,10 @@ def _get_contract(): return get_contract(code, bytesdata, start, length, override_opt_level=opt_level) if ( - (start + length > length_bound and literal_start and literal_length) - or (literal_length and length > length_bound) - or (literal_start and start > length_bound) - or (literal_length and length < 1) + (start + length > length_bound and use_literal_start and use_literal_length) + or (use_literal_length and length > length_bound) + or (use_literal_start and start > length_bound) + or (use_literal_length and length == 0) ): assert_compile_failed(lambda: _get_contract(), ArgumentException) elif start + length > len(bytesdata) or (len(bytesdata) > length_bound): @@ -86,13 +86,13 @@ def _get_contract(): @pytest.mark.parametrize("location", ("storage", "calldata", "memory", "literal", "code")) -@pytest.mark.parametrize("literal_start", (True, False)) -@pytest.mark.parametrize("literal_length", (True, False)) +@pytest.mark.parametrize("use_literal_start", (True, False)) +@pytest.mark.parametrize("use_literal_length", (True, False)) @pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) @settings(max_examples=100) @pytest.mark.fuzzing -def test_slice_bytes( +def test_slice_bytes_fuzz( get_contract, assert_compile_failed, assert_tx_failed, @@ -100,9 +100,9 @@ def test_slice_bytes( location, bytesdata, start, - literal_start, + use_literal_start, length, - literal_length, + use_literal_length, length_bound, ): preamble = "" @@ -133,8 +133,8 @@ def __init__(foo: Bytes[{length_bound}]): else: raise Exception("unreachable") - _start = start if literal_start else "start" - _length = length if literal_length else "length" + _start = start if use_literal_start else "start" + _length = length if use_literal_length else "length" code = f""" {preamble} @@ -148,37 +148,35 @@ def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Byt def _get_contract(): return get_contract(code, bytesdata, override_opt_level=opt_level) - data_length = len(bytesdata) if location == "literal" else length_bound - end = start + length + # length bound is the container size; input_bound is the bound on the input + # (which can be different, if the input is a literal) + input_bound = length_bound + slice_output_too_large = False - is_zero_literal_length = literal_length and length < 1 - literal_start_exceeds_data_length = literal_start and start > data_length - literal_length_exceeds_data = (literal_length and length > data_length) or ( - end > data_length and literal_start and literal_length - ) + if location == "literal": + input_bound = len(bytesdata) - data_longer_than_length_bound = len(bytesdata) > length_bound - # `not literal_length` condition catches this contract: - # @external - # def do_slice(inp: Bytes[1], start: uint256, length: uint256) -> Bytes[1]: - # return slice(b'\x00\x00', 0, length) - invalid_slice_literal = ( - location == "literal" - and data_longer_than_length_bound - and ((literal_length and length > length_bound) or not literal_length) + # ex.: + # @external + # def do_slice(inp: Bytes[1], start: uint256, length: uint256) -> Bytes[1]: + # return slice(b'\x00\x00', 0, length) + output_length = length if use_literal_length else input_bound + slice_output_too_large = output_length > length_bound + + end = start + length + + compile_time_oob = ( + (use_literal_length and (length > input_bound or length == 0)) + or (use_literal_start and start > input_bound) + or (use_literal_start and use_literal_length and start + length > input_bound) ) - if ( - is_zero_literal_length - or literal_start_exceeds_data_length - or literal_length_exceeds_data - or invalid_slice_literal - ): + if compile_time_oob or slice_output_too_large: assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch)) - elif location == "code" and data_longer_than_length_bound: + elif location == "code" and len(bytesdata) > length_bound: # deploy fail assert_tx_failed(lambda: _get_contract()) - elif end > len(bytesdata) or data_longer_than_length_bound: + elif end > len(bytesdata) or len(bytesdata) > length_bound: c = _get_contract() assert_tx_failed(lambda: c.do_slice(bytesdata, start, length)) else: