Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: fix test for slice #3633

Merged
merged 11 commits into from
Nov 2, 2023
88 changes: 55 additions & 33 deletions tests/parser/functions/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
_bytes_1024 = st.binary(min_size=0, max_size=1024)


@pytest.mark.parametrize("literal_start", (True, False))
@pytest.mark.parametrize("literal_length", (True, False))
@pytest.mark.parametrize("use_literal_start", (True, False))
@pytest.mark.parametrize("use_literal_length", (True, False))
@pytest.mark.parametrize("opt_level", list(OptimizationLevel))
@given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024)
@settings(max_examples=100)
Expand All @@ -45,13 +45,13 @@
opt_level,
bytesdata,
start,
literal_start,
use_literal_start,
length,
literal_length,
use_literal_length,
length_bound,
):
_start = start if literal_start else "start"
_length = length if literal_length else "length"
_start = start if use_literal_start else "start"
_length = length if use_literal_length else "length"

code = f"""
IMMUTABLE_BYTES: immutable(Bytes[{length_bound}])
Expand All @@ -71,10 +71,10 @@
return get_contract(code, bytesdata, start, length, override_opt_level=opt_level)

if (
(start + length > length_bound and literal_start and literal_length)
or (literal_length and length > length_bound)
or (literal_start and start > length_bound)
or (literal_length and length < 1)
(start + length > length_bound and use_literal_start and use_literal_length)
or (use_literal_length and length > length_bound)
or (use_literal_start and start > length_bound)
or (use_literal_length and length == 0)
):
assert_compile_failed(lambda: _get_contract(), ArgumentException)
elif start + length > len(bytesdata) or (len(bytesdata) > length_bound):
Expand All @@ -86,32 +86,42 @@


@pytest.mark.parametrize("location", ("storage", "calldata", "memory", "literal", "code"))
@pytest.mark.parametrize("literal_start", (True, False))
@pytest.mark.parametrize("literal_length", (True, False))
@pytest.mark.parametrize("use_literal_start", (True, False))
@pytest.mark.parametrize("use_literal_length", (True, False))
@pytest.mark.parametrize("opt_level", list(OptimizationLevel))
@given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024)
@settings(max_examples=100)
@pytest.mark.fuzzing
def test_slice_bytes(
def test_slice_bytes_fuzz(
get_contract,
assert_compile_failed,
assert_tx_failed,
opt_level,
location,
bytesdata,
start,
literal_start,
use_literal_start,
length,
literal_length,
use_literal_length,
length_bound,
):
preamble = ""
if location == "memory":
spliced_code = f"foo: Bytes[{length_bound}] = inp"
foo = "foo"
elif location == "storage":
preamble = f"""
foo: Bytes[{length_bound}]
"""
spliced_code = "self.foo = inp"
foo = "self.foo"
elif location == "code":
preamble = f"""
IMMUTABLE_BYTES: immutable(Bytes[{length_bound}])
@external
def __init__(foo: Bytes[{length_bound}]):
IMMUTABLE_BYTES = foo
"""
spliced_code = ""
foo = "IMMUTABLE_BYTES"
elif location == "literal":
Expand All @@ -123,15 +133,11 @@
else:
raise Exception("unreachable")

_start = start if literal_start else "start"
_length = length if literal_length else "length"
_start = start if use_literal_start else "start"
_length = length if use_literal_length else "length"

code = f"""
foo: Bytes[{length_bound}]
IMMUTABLE_BYTES: immutable(Bytes[{length_bound}])
@external
def __init__(foo: Bytes[{length_bound}]):
IMMUTABLE_BYTES = foo
{preamble}

@external
def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Bytes[{length_bound}]:
Expand All @@ -142,24 +148,40 @@
def _get_contract():
return get_contract(code, bytesdata, override_opt_level=opt_level)

data_length = len(bytesdata) if location == "literal" else length_bound
if (
(start + length > data_length and literal_start and literal_length)
or (literal_length and length > data_length)
or (location == "literal" and len(bytesdata) > length_bound)
or (literal_start and start > data_length)
or (literal_length and length < 1)
):
# length bound is the container size; input_bound is the bound on the input
# (which can be different, if the input is a literal)
input_bound = length_bound
slice_output_too_large = False

if location == "literal":
input_bound = len(bytesdata)

# ex.:
# @external
# def do_slice(inp: Bytes[1], start: uint256, length: uint256) -> Bytes[1]:
# return slice(b'\x00\x00', 0, length)
Dismissed Show dismissed Hide dismissed
output_length = length if use_literal_length else input_bound
slice_output_too_large = output_length > length_bound

end = start + length

compile_time_oob = (
(use_literal_length and (length > input_bound or length == 0))
or (use_literal_start and start > input_bound)
or (use_literal_start and use_literal_length and start + length > input_bound)
)

if compile_time_oob or slice_output_too_large:
assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch))
elif len(bytesdata) > data_length:
elif location == "code" and len(bytesdata) > length_bound:
# deploy fail
assert_tx_failed(lambda: _get_contract())
elif start + length > len(bytesdata):
elif end > len(bytesdata) or len(bytesdata) > length_bound:
c = _get_contract()
assert_tx_failed(lambda: c.do_slice(bytesdata, start, length))
else:
c = _get_contract()
assert c.do_slice(bytesdata, start, length) == bytesdata[start : start + length], code
assert c.do_slice(bytesdata, start, length) == bytesdata[start:end], code


def test_slice_private(get_contract):
Expand Down
Loading