Skip to content

Commit

Permalink
fix: complex arguments to builtin functions (#3167)
Browse files Browse the repository at this point in the history
prior to this commit, some builtin functions including ceil, would panic
if their arguments were function calls (or otherwise determined to be
complex expressions by `is_complex_ir`). this commit fixes the relevant
builtin functions by using `cache_when_complex` where appropriate.

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
tserg and charles-cooper authored May 18, 2023
1 parent a8382f5 commit 6ee74f5
Show file tree
Hide file tree
Showing 9 changed files with 422 additions and 116 deletions.
35 changes: 35 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,38 @@ def _f(_addr, _salt, _initcode):
return keccak(prefix + addr + salt + keccak(initcode))[12:]

return _f


@pytest.fixture
def side_effects_contract(get_contract):
def generate(ret_type):
"""
Generates a Vyper contract with an external `foo()` function, which
returns the specified return value of the specified return type, for
testing side effects using the `assert_side_effects_invoked` fixture.
"""
code = f"""
counter: public(uint256)
@external
def foo(s: {ret_type}) -> {ret_type}:
self.counter += 1
return s
"""
contract = get_contract(code)
return contract

return generate


@pytest.fixture
def assert_side_effects_invoked():
def assert_side_effects_invoked(side_effects_contract, side_effects_trigger, n=1):
start_value = side_effects_contract.counter()

side_effects_trigger()

end_value = side_effects_contract.counter()
assert end_value == start_value + n

return assert_side_effects_invoked
57 changes: 57 additions & 0 deletions tests/parser/functions/test_addmod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
def test_uint256_addmod(assert_tx_failed, get_contract_with_gas_estimation):
uint256_code = """
@external
def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256:
return uint256_addmod(x, y, z)
"""

c = get_contract_with_gas_estimation(uint256_code)

assert c._uint256_addmod(1, 2, 2) == 1
assert c._uint256_addmod(32, 2, 32) == 2
assert c._uint256_addmod((2**256) - 1, 0, 2) == 1
assert c._uint256_addmod(2**255, 2**255, 6) == 4
assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0))


def test_uint256_addmod_ext_call(
w3, side_effects_contract, assert_side_effects_invoked, get_contract
):
code = """
@external
def foo(f: Foo) -> uint256:
return uint256_addmod(32, 2, f.foo(32))
interface Foo:
def foo(x: uint256) -> uint256: payable
"""

c1 = side_effects_contract("uint256")
c2 = get_contract(code)

assert c2.foo(c1.address) == 2
assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))


def test_uint256_addmod_internal_call(get_contract_with_gas_estimation):
code = """
@external
def foo() -> uint256:
return uint256_addmod(self.a(), self.b(), self.c())
@internal
def a() -> uint256:
return 32
@internal
def b() -> uint256:
return 2
@internal
def c() -> uint256:
return 32
"""

c = get_contract_with_gas_estimation(code)

assert c.foo() == 2
31 changes: 31 additions & 0 deletions tests/parser/functions/test_as_wei_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
def test_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
code = """
@external
def foo(a: Foo) -> uint256:
return as_wei_value(a.foo(7), "ether")
interface Foo:
def foo(x: uint8) -> uint8: nonpayable
"""

c1 = side_effects_contract("uint8")
c2 = get_contract(code)

assert c2.foo(c1.address) == w3.to_wei(7, "ether")
assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))


def test_internal_call(w3, get_contract_with_gas_estimation):
code = """
@external
def foo() -> uint256:
return as_wei_value(self.bar(), "ether")
@internal
def bar() -> uint8:
return 7
"""

c = get_contract_with_gas_estimation(code)

assert c.foo() == w3.to_wei(7, "ether")
34 changes: 34 additions & 0 deletions tests/parser/functions/test_ceil.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,37 @@ def ceil_param(p: decimal) -> int256:
assert c.fou() == -3
assert c.ceil_param(Decimal("-0.5")) == 0
assert c.ceil_param(Decimal("-7777777.7777777")) == -7777777


def test_ceil_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
code = """
@external
def foo(a: Foo) -> int256:
return ceil(a.foo(2.5))
interface Foo:
def foo(x: decimal) -> decimal: payable
"""

c1 = side_effects_contract("decimal")
c2 = get_contract(code)

assert c2.foo(c1.address) == 3

assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))


def test_ceil_internal_call(get_contract_with_gas_estimation):
code = """
@external
def foo() -> int256:
return ceil(self.bar())
@internal
def bar() -> decimal:
return 2.5
"""

c = get_contract_with_gas_estimation(code)

assert c.foo() == 3
62 changes: 62 additions & 0 deletions tests/parser/functions/test_ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,37 @@ def _ecadd3(x: uint256[2], y: uint256[2]) -> uint256[2]:
assert c._ecadd3(G1, negative_G1) == [0, 0]


def test_ecadd_internal_call(get_contract_with_gas_estimation):
code = """
@internal
def a() -> uint256[2]:
return [1, 2]
@external
def foo() -> uint256[2]:
return ecadd([1, 2], self.a())
"""
c = get_contract_with_gas_estimation(code)
assert c.foo() == G1_times_two


def test_ecadd_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
code = """
interface Foo:
def foo(x: uint256[2]) -> uint256[2]: payable
@external
def foo(a: Foo) -> uint256[2]:
return ecadd([1, 2], a.foo([1, 2]))
"""
c1 = side_effects_contract("uint256[2]")
c2 = get_contract(code)

assert c2.foo(c1.address) == G1_times_two

assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))


def test_ecmul(get_contract_with_gas_estimation):
ecmuller = """
x3: uint256[2]
Expand Down Expand Up @@ -74,3 +105,34 @@ def _ecmul3(x: uint256[2], y: uint256) -> uint256[2]:
assert c._ecmul(G1, 3) == G1_times_three
assert c._ecmul(G1, curve_order - 1) == negative_G1
assert c._ecmul(G1, curve_order) == [0, 0]


def test_ecmul_internal_call(get_contract_with_gas_estimation):
code = """
@internal
def a() -> uint256:
return 3
@external
def foo() -> uint256[2]:
return ecmul([1, 2], self.a())
"""
c = get_contract_with_gas_estimation(code)
assert c.foo() == G1_times_three


def test_ecmul_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
code = """
interface Foo:
def foo(x: uint256) -> uint256: payable
@external
def foo(a: Foo) -> uint256[2]:
return ecmul([1, 2], a.foo(3))
"""
c1 = side_effects_contract("uint256")
c2 = get_contract(code)

assert c2.foo(c1.address) == G1_times_three

assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))
34 changes: 34 additions & 0 deletions tests/parser/functions/test_floor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,37 @@ def floor_param(p: decimal) -> int256:
assert c.fou() == -4
assert c.floor_param(Decimal("-5.6")) == -6
assert c.floor_param(Decimal("-0.0000000001")) == -1


def test_floor_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
code = """
@external
def foo(a: Foo) -> int256:
return floor(a.foo(2.5))
interface Foo:
def foo(x: decimal) -> decimal: nonpayable
"""

c1 = side_effects_contract("decimal")
c2 = get_contract(code)

assert c2.foo(c1.address) == 2

assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))


def test_floor_internal_call(get_contract_with_gas_estimation):
code = """
@external
def foo() -> int256:
return floor(self.bar())
@internal
def bar() -> decimal:
return 2.5
"""

c = get_contract_with_gas_estimation(code)

assert c.foo() == 2
75 changes: 75 additions & 0 deletions tests/parser/functions/test_mulmod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
def test_uint256_mulmod(assert_tx_failed, get_contract_with_gas_estimation):
uint256_code = """
@external
def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256:
return uint256_mulmod(x, y, z)
"""

c = get_contract_with_gas_estimation(uint256_code)

assert c._uint256_mulmod(3, 1, 2) == 1
assert c._uint256_mulmod(200, 3, 601) == 600
assert c._uint256_mulmod(2**255, 1, 3) == 2
assert c._uint256_mulmod(2**255, 2, 6) == 4
assert_tx_failed(lambda: c._uint256_mulmod(2, 2, 0))


def test_uint256_mulmod_complex(get_contract_with_gas_estimation):
modexper = """
@external
def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256:
o: uint256 = 1
for i in range(256):
o = uint256_mulmod(o, o, modulus)
if exponent & shift(1, 255 - i) != 0:
o = uint256_mulmod(o, base, modulus)
return o
"""

c = get_contract_with_gas_estimation(modexper)
assert c.exponential(3, 5, 100) == 43
assert c.exponential(2, 997, 997) == 2


def test_uint256_mulmod_ext_call(
w3, side_effects_contract, assert_side_effects_invoked, get_contract
):
code = """
@external
def foo(f: Foo) -> uint256:
return uint256_mulmod(200, 3, f.foo(601))
interface Foo:
def foo(x: uint256) -> uint256: nonpayable
"""

c1 = side_effects_contract("uint256")
c2 = get_contract(code)

assert c2.foo(c1.address) == 600

assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))


def test_uint256_mulmod_internal_call(get_contract_with_gas_estimation):
code = """
@external
def foo() -> uint256:
return uint256_mulmod(self.a(), self.b(), self.c())
@internal
def a() -> uint256:
return 200
@internal
def b() -> uint256:
return 3
@internal
def c() -> uint256:
return 601
"""

c = get_contract_with_gas_estimation(code)

assert c.foo() == 600
Loading

0 comments on commit 6ee74f5

Please sign in to comment.