From 6ee74f5af58507029192732f828e13c431c273a3 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 19 May 2023 01:05:49 +0800 Subject: [PATCH] fix: complex arguments to builtin functions (#3167) prior to this commit, some builtin functions including ceil, would panic if their arguments were function calls (or otherwise determined to be complex expressions by `is_complex_ir`). this commit fixes the relevant builtin functions by using `cache_when_complex` where appropriate. --------- Co-authored-by: Charles Cooper --- tests/conftest.py | 35 ++++ tests/parser/functions/test_addmod.py | 57 ++++++ tests/parser/functions/test_as_wei_value.py | 31 ++++ tests/parser/functions/test_ceil.py | 34 ++++ tests/parser/functions/test_ec.py | 62 +++++++ tests/parser/functions/test_floor.py | 34 ++++ tests/parser/functions/test_mulmod.py | 75 ++++++++ .../types/numbers/test_unsigned_ints.py | 43 ----- vyper/builtins/functions.py | 167 ++++++++++-------- 9 files changed, 422 insertions(+), 116 deletions(-) create mode 100644 tests/parser/functions/test_addmod.py create mode 100644 tests/parser/functions/test_as_wei_value.py create mode 100644 tests/parser/functions/test_mulmod.py diff --git a/tests/conftest.py b/tests/conftest.py index e1d0996767..1cc9e4e72e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -192,3 +192,38 @@ def _f(_addr, _salt, _initcode): return keccak(prefix + addr + salt + keccak(initcode))[12:] return _f + + +@pytest.fixture +def side_effects_contract(get_contract): + def generate(ret_type): + """ + Generates a Vyper contract with an external `foo()` function, which + returns the specified return value of the specified return type, for + testing side effects using the `assert_side_effects_invoked` fixture. + """ + code = f""" +counter: public(uint256) + +@external +def foo(s: {ret_type}) -> {ret_type}: + self.counter += 1 + return s + """ + contract = get_contract(code) + return contract + + return generate + + +@pytest.fixture +def assert_side_effects_invoked(): + def assert_side_effects_invoked(side_effects_contract, side_effects_trigger, n=1): + start_value = side_effects_contract.counter() + + side_effects_trigger() + + end_value = side_effects_contract.counter() + assert end_value == start_value + n + + return assert_side_effects_invoked diff --git a/tests/parser/functions/test_addmod.py b/tests/parser/functions/test_addmod.py new file mode 100644 index 0000000000..67a7e9b101 --- /dev/null +++ b/tests/parser/functions/test_addmod.py @@ -0,0 +1,57 @@ +def test_uint256_addmod(assert_tx_failed, get_contract_with_gas_estimation): + uint256_code = """ +@external +def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256: + return uint256_addmod(x, y, z) + """ + + c = get_contract_with_gas_estimation(uint256_code) + + assert c._uint256_addmod(1, 2, 2) == 1 + assert c._uint256_addmod(32, 2, 32) == 2 + assert c._uint256_addmod((2**256) - 1, 0, 2) == 1 + assert c._uint256_addmod(2**255, 2**255, 6) == 4 + assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0)) + + +def test_uint256_addmod_ext_call( + w3, side_effects_contract, assert_side_effects_invoked, get_contract +): + code = """ +@external +def foo(f: Foo) -> uint256: + return uint256_addmod(32, 2, f.foo(32)) + +interface Foo: + def foo(x: uint256) -> uint256: payable + """ + + c1 = side_effects_contract("uint256") + c2 = get_contract(code) + + assert c2.foo(c1.address) == 2 + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_uint256_addmod_internal_call(get_contract_with_gas_estimation): + code = """ +@external +def foo() -> uint256: + return uint256_addmod(self.a(), self.b(), self.c()) + +@internal +def a() -> uint256: + return 32 + +@internal +def b() -> uint256: + return 2 + +@internal +def c() -> uint256: + return 32 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo() == 2 diff --git a/tests/parser/functions/test_as_wei_value.py b/tests/parser/functions/test_as_wei_value.py new file mode 100644 index 0000000000..bab0aed616 --- /dev/null +++ b/tests/parser/functions/test_as_wei_value.py @@ -0,0 +1,31 @@ +def test_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract): + code = """ +@external +def foo(a: Foo) -> uint256: + return as_wei_value(a.foo(7), "ether") + +interface Foo: + def foo(x: uint8) -> uint8: nonpayable + """ + + c1 = side_effects_contract("uint8") + c2 = get_contract(code) + + assert c2.foo(c1.address) == w3.to_wei(7, "ether") + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_internal_call(w3, get_contract_with_gas_estimation): + code = """ +@external +def foo() -> uint256: + return as_wei_value(self.bar(), "ether") + +@internal +def bar() -> uint8: + return 7 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo() == w3.to_wei(7, "ether") diff --git a/tests/parser/functions/test_ceil.py b/tests/parser/functions/test_ceil.py index a9bcf62da2..daa9cb7c1b 100644 --- a/tests/parser/functions/test_ceil.py +++ b/tests/parser/functions/test_ceil.py @@ -104,3 +104,37 @@ def ceil_param(p: decimal) -> int256: assert c.fou() == -3 assert c.ceil_param(Decimal("-0.5")) == 0 assert c.ceil_param(Decimal("-7777777.7777777")) == -7777777 + + +def test_ceil_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract): + code = """ +@external +def foo(a: Foo) -> int256: + return ceil(a.foo(2.5)) + +interface Foo: + def foo(x: decimal) -> decimal: payable + """ + + c1 = side_effects_contract("decimal") + c2 = get_contract(code) + + assert c2.foo(c1.address) == 3 + + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_ceil_internal_call(get_contract_with_gas_estimation): + code = """ +@external +def foo() -> int256: + return ceil(self.bar()) + +@internal +def bar() -> decimal: + return 2.5 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo() == 3 diff --git a/tests/parser/functions/test_ec.py b/tests/parser/functions/test_ec.py index be0f6f7ed2..9ce37d0721 100644 --- a/tests/parser/functions/test_ec.py +++ b/tests/parser/functions/test_ec.py @@ -45,6 +45,37 @@ def _ecadd3(x: uint256[2], y: uint256[2]) -> uint256[2]: assert c._ecadd3(G1, negative_G1) == [0, 0] +def test_ecadd_internal_call(get_contract_with_gas_estimation): + code = """ +@internal +def a() -> uint256[2]: + return [1, 2] + +@external +def foo() -> uint256[2]: + return ecadd([1, 2], self.a()) + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() == G1_times_two + + +def test_ecadd_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract): + code = """ +interface Foo: + def foo(x: uint256[2]) -> uint256[2]: payable + +@external +def foo(a: Foo) -> uint256[2]: + return ecadd([1, 2], a.foo([1, 2])) + """ + c1 = side_effects_contract("uint256[2]") + c2 = get_contract(code) + + assert c2.foo(c1.address) == G1_times_two + + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + def test_ecmul(get_contract_with_gas_estimation): ecmuller = """ x3: uint256[2] @@ -74,3 +105,34 @@ def _ecmul3(x: uint256[2], y: uint256) -> uint256[2]: assert c._ecmul(G1, 3) == G1_times_three assert c._ecmul(G1, curve_order - 1) == negative_G1 assert c._ecmul(G1, curve_order) == [0, 0] + + +def test_ecmul_internal_call(get_contract_with_gas_estimation): + code = """ +@internal +def a() -> uint256: + return 3 + +@external +def foo() -> uint256[2]: + return ecmul([1, 2], self.a()) + """ + c = get_contract_with_gas_estimation(code) + assert c.foo() == G1_times_three + + +def test_ecmul_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract): + code = """ +interface Foo: + def foo(x: uint256) -> uint256: payable + +@external +def foo(a: Foo) -> uint256[2]: + return ecmul([1, 2], a.foo(3)) + """ + c1 = side_effects_contract("uint256") + c2 = get_contract(code) + + assert c2.foo(c1.address) == G1_times_three + + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) diff --git a/tests/parser/functions/test_floor.py b/tests/parser/functions/test_floor.py index dc53545ac3..d2fd993785 100644 --- a/tests/parser/functions/test_floor.py +++ b/tests/parser/functions/test_floor.py @@ -108,3 +108,37 @@ def floor_param(p: decimal) -> int256: assert c.fou() == -4 assert c.floor_param(Decimal("-5.6")) == -6 assert c.floor_param(Decimal("-0.0000000001")) == -1 + + +def test_floor_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract): + code = """ +@external +def foo(a: Foo) -> int256: + return floor(a.foo(2.5)) + +interface Foo: + def foo(x: decimal) -> decimal: nonpayable + """ + + c1 = side_effects_contract("decimal") + c2 = get_contract(code) + + assert c2.foo(c1.address) == 2 + + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_floor_internal_call(get_contract_with_gas_estimation): + code = """ +@external +def foo() -> int256: + return floor(self.bar()) + +@internal +def bar() -> decimal: + return 2.5 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo() == 2 diff --git a/tests/parser/functions/test_mulmod.py b/tests/parser/functions/test_mulmod.py new file mode 100644 index 0000000000..1ea7a3f8e8 --- /dev/null +++ b/tests/parser/functions/test_mulmod.py @@ -0,0 +1,75 @@ +def test_uint256_mulmod(assert_tx_failed, get_contract_with_gas_estimation): + uint256_code = """ +@external +def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256: + return uint256_mulmod(x, y, z) + """ + + c = get_contract_with_gas_estimation(uint256_code) + + assert c._uint256_mulmod(3, 1, 2) == 1 + assert c._uint256_mulmod(200, 3, 601) == 600 + assert c._uint256_mulmod(2**255, 1, 3) == 2 + assert c._uint256_mulmod(2**255, 2, 6) == 4 + assert_tx_failed(lambda: c._uint256_mulmod(2, 2, 0)) + + +def test_uint256_mulmod_complex(get_contract_with_gas_estimation): + modexper = """ +@external +def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256: + o: uint256 = 1 + for i in range(256): + o = uint256_mulmod(o, o, modulus) + if exponent & shift(1, 255 - i) != 0: + o = uint256_mulmod(o, base, modulus) + return o + """ + + c = get_contract_with_gas_estimation(modexper) + assert c.exponential(3, 5, 100) == 43 + assert c.exponential(2, 997, 997) == 2 + + +def test_uint256_mulmod_ext_call( + w3, side_effects_contract, assert_side_effects_invoked, get_contract +): + code = """ +@external +def foo(f: Foo) -> uint256: + return uint256_mulmod(200, 3, f.foo(601)) + +interface Foo: + def foo(x: uint256) -> uint256: nonpayable + """ + + c1 = side_effects_contract("uint256") + c2 = get_contract(code) + + assert c2.foo(c1.address) == 600 + + assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={})) + + +def test_uint256_mulmod_internal_call(get_contract_with_gas_estimation): + code = """ +@external +def foo() -> uint256: + return uint256_mulmod(self.a(), self.b(), self.c()) + +@internal +def a() -> uint256: + return 200 + +@internal +def b() -> uint256: + return 3 + +@internal +def c() -> uint256: + return 601 + """ + + c = get_contract_with_gas_estimation(code) + + assert c.foo() == 600 diff --git a/tests/parser/types/numbers/test_unsigned_ints.py b/tests/parser/types/numbers/test_unsigned_ints.py index 82c0f8484c..683684e6be 100644 --- a/tests/parser/types/numbers/test_unsigned_ints.py +++ b/tests/parser/types/numbers/test_unsigned_ints.py @@ -195,49 +195,6 @@ def foo(x: {typ}, y: {typ}) -> bool: assert c.foo(x, y) is expected -# TODO move to tests/parser/functions/test_mulmod.py and test_addmod.py -def test_uint256_mod(assert_tx_failed, get_contract_with_gas_estimation): - uint256_code = """ -@external -def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256: - return uint256_addmod(x, y, z) - -@external -def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256: - return uint256_mulmod(x, y, z) - """ - - c = get_contract_with_gas_estimation(uint256_code) - - assert c._uint256_addmod(1, 2, 2) == 1 - assert c._uint256_addmod(32, 2, 32) == 2 - assert c._uint256_addmod((2**256) - 1, 0, 2) == 1 - assert c._uint256_addmod(2**255, 2**255, 6) == 4 - assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0)) - assert c._uint256_mulmod(3, 1, 2) == 1 - assert c._uint256_mulmod(200, 3, 601) == 600 - assert c._uint256_mulmod(2**255, 1, 3) == 2 - assert c._uint256_mulmod(2**255, 2, 6) == 4 - assert_tx_failed(lambda: c._uint256_mulmod(2, 2, 0)) - - -def test_uint256_modmul(get_contract_with_gas_estimation): - modexper = """ -@external -def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256: - o: uint256 = 1 - for i in range(256): - o = uint256_mulmod(o, o, modulus) - if exponent & (1 << (255 - i)) != 0: - o = uint256_mulmod(o, base, modulus) - return o - """ - - c = get_contract_with_gas_estimation(modexper) - assert c.exponential(3, 5, 100) == 43 - assert c.exponential(2, 997, 997) == 2 - - @pytest.mark.parametrize("typ", types) def test_uint_literal(get_contract, assert_compile_failed, typ): lo, hi = typ.ast_bounds diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index bfe90bb669..915f10ede3 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -148,15 +148,18 @@ def evaluate(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): - return IRnode.from_list( - [ - "if", - ["slt", args[0], 0], - ["sdiv", ["sub", args[0], DECIMAL_DIVISOR - 1], DECIMAL_DIVISOR], - ["sdiv", args[0], DECIMAL_DIVISOR], - ], - typ=INT256_T, - ) + arg = args[0] + with arg.cache_when_complex("arg") as (b1, arg): + ret = IRnode.from_list( + [ + "if", + ["slt", arg, 0], + ["sdiv", ["sub", arg, DECIMAL_DIVISOR - 1], DECIMAL_DIVISOR], + ["sdiv", arg, DECIMAL_DIVISOR], + ], + typ=INT256_T, + ) + return b1.resolve(ret) class Ceil(BuiltinFunction): @@ -175,15 +178,18 @@ def evaluate(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): - return IRnode.from_list( - [ - "if", - ["slt", args[0], 0], - ["sdiv", args[0], DECIMAL_DIVISOR], - ["sdiv", ["add", args[0], DECIMAL_DIVISOR - 1], DECIMAL_DIVISOR], - ], - typ=INT256_T, - ) + arg = args[0] + with arg.cache_when_complex("arg") as (b1, arg): + ret = IRnode.from_list( + [ + "if", + ["slt", arg, 0], + ["sdiv", arg, DECIMAL_DIVISOR], + ["sdiv", ["add", arg, DECIMAL_DIVISOR - 1], DECIMAL_DIVISOR], + ], + typ=INT256_T, + ) + return b1.resolve(ret) class Convert(BuiltinFunction): @@ -800,20 +806,25 @@ def build_IR(self, expr, args, kwargs, context): placeholder_node = IRnode.from_list( context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY ) - o = IRnode.from_list( - [ - "seq", - ["mstore", placeholder_node, _getelem(args[0], 0)], - ["mstore", ["add", placeholder_node, 32], _getelem(args[0], 1)], - ["mstore", ["add", placeholder_node, 64], _getelem(args[1], 0)], - ["mstore", ["add", placeholder_node, 96], _getelem(args[1], 1)], - ["assert", ["staticcall", ["gas"], 6, placeholder_node, 128, placeholder_node, 64]], - placeholder_node, - ], - typ=SArrayT(UINT256_T, 2), - location=MEMORY, - ) - return o + + with args[0].cache_when_complex("a") as (b1, a), args[1].cache_when_complex("b") as (b2, b): + o = IRnode.from_list( + [ + "seq", + ["mstore", placeholder_node, _getelem(a, 0)], + ["mstore", ["add", placeholder_node, 32], _getelem(a, 1)], + ["mstore", ["add", placeholder_node, 64], _getelem(b, 0)], + ["mstore", ["add", placeholder_node, 96], _getelem(b, 1)], + [ + "assert", + ["staticcall", ["gas"], 6, placeholder_node, 128, placeholder_node, 64], + ], + placeholder_node, + ], + typ=SArrayT(UINT256_T, 2), + location=MEMORY, + ) + return b2.resolve(b1.resolve(o)) class ECMul(BuiltinFunction): @@ -826,19 +837,24 @@ def build_IR(self, expr, args, kwargs, context): placeholder_node = IRnode.from_list( context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY ) - o = IRnode.from_list( - [ - "seq", - ["mstore", placeholder_node, _getelem(args[0], 0)], - ["mstore", ["add", placeholder_node, 32], _getelem(args[0], 1)], - ["mstore", ["add", placeholder_node, 64], args[1]], - ["assert", ["staticcall", ["gas"], 7, placeholder_node, 96, placeholder_node, 64]], - placeholder_node, - ], - typ=SArrayT(UINT256_T, 2), - location=MEMORY, - ) - return o + + with args[0].cache_when_complex("a") as (b1, a), args[1].cache_when_complex("b") as (b2, b): + o = IRnode.from_list( + [ + "seq", + ["mstore", placeholder_node, _getelem(a, 0)], + ["mstore", ["add", placeholder_node, 32], _getelem(a, 1)], + ["mstore", ["add", placeholder_node, 64], b], + [ + "assert", + ["staticcall", ["gas"], 7, placeholder_node, 96, placeholder_node, 64], + ], + placeholder_node, + ], + typ=SArrayT(UINT256_T, 2), + location=MEMORY, + ) + return b2.resolve(b1.resolve(o)) def _generic_element_getter(op): @@ -1030,34 +1046,35 @@ def build_IR(self, expr, args, kwargs, context): value = args[0] denom_divisor = self.get_denomination(expr) - if value.typ in (UINT256_T, UINT8_T): - sub = [ - "with", - "ans", - ["mul", value, denom_divisor], - [ - "seq", + with value.cache_when_complex("value") as (b1, value): + if value.typ in (UINT256_T, UINT8_T): + sub = [ + "with", + "ans", + ["mul", value, denom_divisor], [ - "assert", - ["or", ["eq", ["div", "ans", value], denom_divisor], ["iszero", value]], + "seq", + [ + "assert", + ["or", ["eq", ["div", "ans", value], denom_divisor], ["iszero", value]], + ], + "ans", ], - "ans", - ], - ] - elif value.typ == INT128_T: - # signed types do not require bounds checks because the - # largest possible converted value will not overflow 2**256 - sub = ["seq", ["assert", ["sgt", value, -1]], ["mul", value, denom_divisor]] - elif value.typ == DecimalT(): - sub = [ - "seq", - ["assert", ["sgt", value, -1]], - ["div", ["mul", value, denom_divisor], DECIMAL_DIVISOR], - ] - else: - raise CompilerPanic(f"Unexpected type: {value.typ}") + ] + elif value.typ == INT128_T: + # signed types do not require bounds checks because the + # largest possible converted value will not overflow 2**256 + sub = ["seq", ["assert", ["sgt", value, -1]], ["mul", value, denom_divisor]] + elif value.typ == DecimalT(): + sub = [ + "seq", + ["assert", ["sgt", value, -1]], + ["div", ["mul", value, denom_divisor], DECIMAL_DIVISOR], + ] + else: + raise CompilerPanic(f"Unexpected type: {value.typ}") - return IRnode.from_list(sub, typ=UINT256_T) + return IRnode.from_list(b1.resolve(sub), typ=UINT256_T) zero_value = IRnode.from_list(0, typ=UINT256_T) @@ -1516,9 +1533,13 @@ def evaluate(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): - return IRnode.from_list( - ["seq", ["assert", args[2]], [self._opcode, args[0], args[1], args[2]]], typ=UINT256_T - ) + c = args[2] + + with c.cache_when_complex("c") as (b1, c): + ret = IRnode.from_list( + ["seq", ["assert", c], [self._opcode, args[0], args[1], c]], typ=UINT256_T + ) + return b1.resolve(ret) class AddMod(_AddMulMod):