diff --git a/docs/built-in-functions.rst b/docs/built-in-functions.rst index ab3ebb3fff..e56dc680bf 100644 --- a/docs/built-in-functions.rst +++ b/docs/built-in-functions.rst @@ -106,6 +106,11 @@ Bitwise Operations >>> ExampleContract.foo(2, 8) 512 +.. note:: + + This function has been deprecated from version 0.3.8 onwards. Please use the ``<<`` and ``>>`` operators instead. + + Chain Interaction ================= diff --git a/docs/types.rst b/docs/types.rst index 5e82c6ecba..2c3a1e75d8 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -115,6 +115,23 @@ Operator Description ``x`` and ``y`` must be of the same type. +Shifts +^^^^^^^^^^^^^^^^^ + +============= ====================== +Operator Description +============= ====================== +``x << y`` Left shift +``x >> y`` Right shift +============= ====================== + +Shifting is only available for 256-bit wide types. That is, ``x`` must be ``int256``, and ``y`` can be any unsigned integer. The right shift for ``int256`` compiles to a signed right shift (EVM ``SAR`` instruction). + + +.. note:: + While at runtime shifts are unchecked (that is, they can be for any number of bits), to prevent common mistakes, the compiler is stricter at compile-time and will prevent out of bounds shifts. For instance, at runtime, ``1 << 257`` will evaluate to ``0``, while that expression at compile-time will raise an ``OverflowException``. + + .. index:: ! uint, ! uintN, ! unsigned integer Unsigned Integer (N bit) @@ -188,6 +205,24 @@ Operator Description .. note:: The Bitwise ``not`` operator is currently only available for ``uint256`` type. +Shifts +^^^^^^^^^^^^^^^^^ + +============= ====================== +Operator Description +============= ====================== +``x << y`` Left shift +``x >> y`` Right shift +============= ====================== + +Shifting is only available for 256-bit wide types. That is, ``x`` must be ``uint256``, and ``y`` can be any unsigned integer. The right shift for ``uint256`` compiles to a signed right shift (EVM ``SHR`` instruction). + + +.. note:: + While at runtime shifts are unchecked (that is, they can be for any number of bits), to prevent common mistakes, the compiler is stricter at compile-time and will prevent out of bounds shifts. For instance, at runtime, ``1 << 257`` will evaluate to ``0``, while that expression at compile-time will raise an ``OverflowException``. + + + Decimals -------- diff --git a/tests/builtins/folding/test_bitwise.py b/tests/builtins/folding/test_bitwise.py index 9be0b6817d..385cd43084 100644 --- a/tests/builtins/folding/test_bitwise.py +++ b/tests/builtins/folding/test_bitwise.py @@ -3,21 +3,27 @@ from hypothesis import strategies as st from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from vyper.exceptions import OverflowException +from vyper.semantics.analysis.utils import validate_expected_type +from vyper.semantics.types.shortcuts import INT256_T, UINT256_T +from vyper.utils import unsigned_to_signed st_uint256 = st.integers(min_value=0, max_value=2**256 - 1) +st_sint256 = st.integers(min_value=-(2**255), max_value=2**255 - 1) + @pytest.mark.fuzzing @settings(max_examples=50, deadline=1000) -@given(a=st_uint256, b=st_uint256) @pytest.mark.parametrize("op", ["&", "|", "^"]) -def test_bitwise_and_or(get_contract, a, b, op): +@given(a=st_uint256, b=st_uint256) +def test_bitwise_ops(get_contract, a, b, op): source = f""" @external def foo(a: uint256, b: uint256) -> uint256: return a {op} b """ + contract = get_contract(source) vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") @@ -29,35 +35,74 @@ def foo(a: uint256, b: uint256) -> uint256: @pytest.mark.fuzzing @settings(max_examples=50, deadline=1000) -@given(value=st_uint256) -def test_bitwise_not(get_contract, value): - source = """ +@pytest.mark.parametrize("op", ["<<", ">>"]) +@given(a=st_uint256, b=st.integers(min_value=0, max_value=256)) +def test_bitwise_shift_unsigned(get_contract, a, b, op): + source = f""" @external -def foo(a: uint256) -> uint256: - return ~a +def foo(a: uint256, b: uint256) -> uint256: + return a {op} b """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"~{value}") + vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") old_node = vyper_ast.body[0].value - new_node = old_node.evaluate() - assert contract.foo(value) == new_node.value + try: + new_node = old_node.evaluate() + # force bounds check, no-op because validate_numeric_bounds + # already does this, but leave in for hygiene (in case + # more types are added). + validate_expected_type(new_node, UINT256_T) + # compile time behavior does not match runtime behavior. + # compile-time will throw on OOB, runtime will wrap. + except OverflowException: # here: check the wrapped value matches runtime + assert op == "<<" + assert contract.foo(a, b) == (a << b) % (2**256) + else: + assert contract.foo(a, b) == new_node.value + + +@pytest.mark.fuzzing +@settings(max_examples=50, deadline=1000) +@pytest.mark.parametrize("op", ["<<", ">>"]) +@given(a=st_sint256, b=st.integers(min_value=0, max_value=256)) +def test_bitwise_shift_signed(get_contract, a, b, op): + source = f""" +@external +def foo(a: int256, b: uint256) -> int256: + return a {op} b + """ + contract = get_contract(source) + + vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") + old_node = vyper_ast.body[0].value + + try: + new_node = old_node.evaluate() + validate_expected_type(new_node, INT256_T) # force bounds check + # compile time behavior does not match runtime behavior. + # compile-time will throw on OOB, runtime will wrap. + except OverflowException: # here: check the wrapped value matches runtime + assert op == "<<" + assert contract.foo(a, b) == unsigned_to_signed((a << b) % (2**256), 256) + else: + assert contract.foo(a, b) == new_node.value @pytest.mark.fuzzing @settings(max_examples=50, deadline=1000) -@given(value=st_uint256, steps=st.integers(min_value=-256, max_value=256)) -def test_shift(get_contract, value, steps): +@given(value=st_uint256) +def test_bitwise_not(get_contract, value): source = """ @external -def foo(a: uint256, b: int128) -> uint256: - return shift(a, b) +def foo(a: uint256) -> uint256: + return ~a """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"shift({value}, {steps})") + vyper_ast = vy_ast.parse_to_ast(f"~{value}") old_node = vyper_ast.body[0].value - new_node = vy_fn.Shift().evaluate(old_node) + new_node = old_node.evaluate() - assert contract.foo(value, steps) == new_node.value + assert contract.foo(value) == new_node.value diff --git a/tests/parser/functions/test_bitwise.py b/tests/parser/functions/test_bitwise.py index 171ba7daeb..800803907a 100644 --- a/tests/parser/functions/test_bitwise.py +++ b/tests/parser/functions/test_bitwise.py @@ -2,7 +2,8 @@ from vyper.compiler import compile_code from vyper.evm.opcodes import EVM_VERSIONS -from vyper.exceptions import InvalidLiteral, TypeMismatch +from vyper.exceptions import InvalidLiteral, InvalidOperation, TypeMismatch +from vyper.utils import unsigned_to_signed code = """ @external @@ -22,12 +23,12 @@ def _bitwise_not(x: uint256) -> uint256: return ~x @external -def _shift(x: uint256, y: int128) -> uint256: - return shift(x, y) +def _shl(x: uint256, y: uint256) -> uint256: + return x << y @external -def _negatedShift(x: uint256, y: int128) -> uint256: - return shift(x, -y) +def _shr(x: uint256, y: uint256) -> uint256: + return x >> y """ @@ -51,22 +52,11 @@ def test_test_bitwise(get_contract_with_gas_estimation, evm_version): assert c._bitwise_or(x, y) == (x | y) assert c._bitwise_xor(x, y) == (x ^ y) assert c._bitwise_not(x) == 2**256 - 1 - x - assert c._shift(x, 3) == x * 8 - assert c._shift(x, 255) == 0 - assert c._shift(y, 255) == 2**255 - assert c._shift(x, 256) == 0 - assert c._shift(x, 0) == x - assert c._shift(x, -1) == x // 2 - assert c._shift(x, -3) == x // 8 - assert c._shift(x, -256) == 0 - assert c._negatedShift(x, -3) == x * 8 - assert c._negatedShift(x, -255) == 0 - assert c._negatedShift(y, -255) == 2**255 - assert c._negatedShift(x, -256) == 0 - assert c._negatedShift(x, -0) == x - assert c._negatedShift(x, 1) == x // 2 - assert c._negatedShift(x, 3) == x // 8 - assert c._negatedShift(x, 256) == 0 + + for t in (x, y): + for s in (0, 1, 3, 255, 256): + assert c._shr(t, s) == t >> s + assert c._shl(t, s) == (t << s) % (2**256) POST_BYZANTIUM = [k for (k, v) in EVM_VERSIONS.items() if v > 0] @@ -76,8 +66,12 @@ def test_test_bitwise(get_contract_with_gas_estimation, evm_version): def test_signed_shift(get_contract_with_gas_estimation, evm_version): code = """ @external -def _signedShift(x: int256, y: int128) -> int256: - return shift(x, y) +def _sar(x: int256, y: uint256) -> int256: + return x >> y + +@external +def _shl(x: int256, y: uint256) -> int256: + return x << y """ c = get_contract_with_gas_estimation(code, evm_version=evm_version) x = 126416208461208640982146408124 @@ -85,10 +79,9 @@ def _signedShift(x: int256, y: int128) -> int256: cases = [x, y, -x, -y] for t in cases: - assert c._signedShift(t, 0) == t >> 0 - assert c._signedShift(t, -1) == t >> 1 - assert c._signedShift(t, -3) == t >> 3 - assert c._signedShift(t, -256) == t >> 256 + for s in (0, 1, 3, 255, 256): + assert c._sar(t, s) == t >> s + assert c._shl(t, s) == unsigned_to_signed((t << s) % (2**256), 256) def test_precedence(get_contract): @@ -115,41 +108,71 @@ def baz(a: uint256, b: uint256, c: uint256) -> (uint256, uint256): def test_literals(get_contract, evm_version): code = """ @external -def left(x: uint256) -> uint256: - return shift(x, -3) +def _shr(x: uint256) -> uint256: + return x >> 3 @external -def right(x: uint256) -> uint256: - return shift(x, 3) +def _shl(x: uint256) -> uint256: + return x << 3 """ c = get_contract(code, evm_version=evm_version) - assert c.left(80) == 10 - assert c.right(80) == 640 + assert c._shr(80) == 10 + assert c._shl(80) == 640 fail_list = [ ( + # cannot shift non-uint256/int256 argument + """ +@external +def foo(x: uint8, y: uint8) -> uint8: + return x << y + """, + InvalidOperation, + ), + ( + # cannot shift non-uint256/int256 argument + """ +@external +def foo(x: int8, y: uint8) -> int8: + return x << y + """, + InvalidOperation, + ), + ( + # cannot shift by non-uint bits """ @external -def foo(x: uint8, y: int128) -> uint256: - return shift(x, y) +def foo(x: uint256, y: int128) -> uint256: + return x << y """, TypeMismatch, ), ( + # cannot left shift by more than 256 bits + """ +@external +def foo() -> uint256: + return 2 << 257 + """, + InvalidLiteral, + ), + ( + # cannot shift by negative amount """ @external def foo() -> uint256: - return shift(2, 257) + return 2 << -1 """, InvalidLiteral, ), ( + # cannot shift by negative amount """ @external def foo() -> uint256: - return shift(2, -257) + return 2 << -1 """, InvalidLiteral, ), diff --git a/tests/parser/types/numbers/test_unsigned_ints.py b/tests/parser/types/numbers/test_unsigned_ints.py index 97a4097923..82c0f8484c 100644 --- a/tests/parser/types/numbers/test_unsigned_ints.py +++ b/tests/parser/types/numbers/test_unsigned_ints.py @@ -228,7 +228,7 @@ def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256: o: uint256 = 1 for i in range(256): o = uint256_mulmod(o, o, modulus) - if exponent & shift(1, 255 - i) != 0: + if exponent & (1 << (255 - i)) != 0: o = uint256_mulmod(o, base, modulus) return o """ diff --git a/tests/parser/types/test_bytes.py b/tests/parser/types/test_bytes.py index 28602d61b1..01ec75d5c1 100644 --- a/tests/parser/types/test_bytes.py +++ b/tests/parser/types/test_bytes.py @@ -268,9 +268,8 @@ def to_little_endian_64(_value: uint256) -> Bytes[8]: y: uint256 = 0 x: uint256 = _value for _ in range(8): - y = shift(y, 8) - y = y + (x & 255) - x = shift(x, -8) + y = (y << 8) | (x & 255) + x >>= 8 return slice(convert(y, bytes32), 24, 8) @external diff --git a/tests/parser/types/test_bytes_zero_padding.py b/tests/parser/types/test_bytes_zero_padding.py index 9bc774f12f..ee938fdffb 100644 --- a/tests/parser/types/test_bytes_zero_padding.py +++ b/tests/parser/types/test_bytes_zero_padding.py @@ -11,9 +11,8 @@ def to_little_endian_64(_value: uint256) -> Bytes[8]: y: uint256 = 0 x: uint256 = _value for _ in range(8): - y = shift(y, 8) - y = y + (x & 255) - x = shift(x, -8) + y = (y << 8) | (x & 255) + x >>= 8 return slice(convert(y, bytes32), 24, 8) @external diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index a553de82a8..19e50d8895 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -960,6 +960,12 @@ def evaluate(self) -> ExprNode: if not isinstance(left, (Int, Decimal)): raise UnfoldableNode("Node contains invalid field(s) for evaluation") + # this validation is performed to prevent the compiler from hanging + # on very large shifts and improve the error message for negative + # values. + if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256): + raise InvalidLiteral("Shift bits must be between 0 and 256", right) + value = self.op._op(left.value, right.value) _validate_numeric_bounds(self, value) return type(left).from_node(self, value=value) @@ -1072,6 +1078,20 @@ class BitXor(Operator): _op = operator.xor +class LShift(Operator): + __slots__ = () + _description = "bitwise left shift" + _pretty = "<<" + _op = operator.lshift + + +class RShift(Operator): + __slots__ = () + _description = "bitwise right shift" + _pretty = ">>" + _op = operator.rshift + + class BoolOp(ExprNode): __slots__ = ("op", "values") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 942640b6e2..93563516f3 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -155,6 +155,8 @@ class Mult(VyperNode): ... class Div(VyperNode): ... class Mod(VyperNode): ... class Pow(VyperNode): ... +class LShift(VyperNode): ... +class RShift(VyperNode): ... class BitAnd(VyperNode): ... class BitOr(VyperNode): ... class BitXor(VyperNode): ... diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index f81fb20a64..e71da851cd 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -1432,10 +1432,15 @@ def build_IR(self, expr, args, kwargs, context): class Shift(BuiltinFunction): _id = "shift" - _inputs = [("x", (UINT256_T, INT256_T)), ("_shift", IntegerT.any())] + _inputs = [("x", (UINT256_T, INT256_T)), ("_shift_bits", IntegerT.any())] _return_type = UINT256_T + _warned = False def evaluate(self, node): + if not self.__class__._warned: + vyper_warn("`shift()` is deprecated! Please use the << or >> operator instead.") + self.__class__._warned = True + validate_call_args(node, 2) if [i for i in node.args if not isinstance(i, vy_ast.Num)]: raise UnfoldableNode diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 33c400941e..908f410321 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -17,6 +17,9 @@ is_numeric_type, is_tuple_like, pop_dyn_array, + sar, + shl, + shr, unwrap_location, ) from vyper.codegen.ir_node import IRnode @@ -357,9 +360,11 @@ def parse_BinOp(self): left = Expr.parse_value_expr(self.expr.left, self.context) right = Expr.parse_value_expr(self.expr.right, self.context) - # Sanity check - ensure that we aren't dealing with different types - # This should be unreachable due to the type check pass - assert left.typ == right.typ, f"unreachable, {left.typ}!={right.typ}" + if not isinstance(self.expr.op, (vy_ast.LShift, vy_ast.RShift)): + # Sanity check - ensure that we aren't dealing with different types + # This should be unreachable due to the type check pass + assert left.typ == right.typ, f"unreachable, {left.typ} != {right.typ}" + assert is_numeric_type(left.typ) or is_enum_type(left.typ) out_typ = left.typ @@ -371,6 +376,21 @@ def parse_BinOp(self): if isinstance(self.expr.op, vy_ast.BitXor): return IRnode.from_list(["xor", left, right], typ=out_typ) + if isinstance(self.expr.op, vy_ast.LShift): + new_typ = left.typ + if new_typ.bits != 256: + # TODO implement me. ["and", 2**bits - 1, shl(right, left)] + return + return IRnode.from_list(shl(right, left), typ=new_typ) + if isinstance(self.expr.op, vy_ast.RShift): + new_typ = left.typ + if new_typ.bits != 256: + # TODO implement me. promote_signed_int(op(right, left), bits) + return + op = shr if not left.typ.is_signed else sar + # note: sar NotImplementedError for pre-constantinople + return IRnode.from_list(op(right, left), typ=new_typ) + # enums can only do bit ops, not arithmetic. assert is_numeric_type(left.typ) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 15b79dde44..0ae59e4e5f 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -198,7 +198,14 @@ def types_from_Attribute(self, node): def types_from_BinOp(self, node): # binary operation: `x + y` - types_list = get_common_types(node.left, node.right) + if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)): + # ad-hoc handling for LShift and RShift, since operands + # can be different types + types_list = get_possible_types_from_node(node.left) + # check rhs is unsigned integer + validate_expected_type(node.right, IntegerT.unsigneds()) + else: + types_list = get_common_types(node.left, node.right) if ( isinstance(node.op, (vy_ast.Div, vy_ast.Mod)) diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index 5b1ab5ab8e..07d1a21a94 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -141,14 +141,23 @@ def validate_numeric_op( if isinstance(node.op, self._invalid_ops): self._raise_invalid_op(node) - if isinstance(node.op, vy_ast.Pow): + def _get_lr(): if isinstance(node, vy_ast.BinOp): - left, right = node.left, node.right + return node.left, node.right elif isinstance(node, vy_ast.AugAssign): - left, right = node.target, node.value + return node.target, node.value else: raise CompilerPanic(f"Unexpected node type for numeric op: {type(node).__name__}") + if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)): + if self._bits != 256: + raise InvalidOperation( + f"Cannot perform {node.op.description} on non-int256/uint256 type!", node + ) + + if isinstance(node.op, vy_ast.Pow): + left, right = _get_lr() + value_bits = self._bits - (1 if self._is_signed else 0) # TODO double check: this code seems duplicated with constant eval