Skip to content

Commit

Permalink
feat: shift operators (#3019)
Browse files Browse the repository at this point in the history
enable `x << y` and `x >> y`, and deprecate the `shift()` builtin
  • Loading branch information
charles-cooper authored Apr 24, 2023
1 parent 4ae20aa commit 7a64d4b
Show file tree
Hide file tree
Showing 13 changed files with 239 additions and 70 deletions.
5 changes: 5 additions & 0 deletions docs/built-in-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ Bitwise Operations
>>> ExampleContract.foo(2, 8)
512
.. note::

This function has been deprecated from version 0.3.8 onwards. Please use the ``<<`` and ``>>`` operators instead.


Chain Interaction
=================

Expand Down
35 changes: 35 additions & 0 deletions docs/types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,23 @@ Operator Description

``x`` and ``y`` must be of the same type.

Shifts
^^^^^^^^^^^^^^^^^

============= ======================
Operator Description
============= ======================
``x << y`` Left shift
``x >> y`` Right shift
============= ======================

Shifting is only available for 256-bit wide types. That is, ``x`` must be ``int256``, and ``y`` can be any unsigned integer. The right shift for ``int256`` compiles to a signed right shift (EVM ``SAR`` instruction).


.. note::
While at runtime shifts are unchecked (that is, they can be for any number of bits), to prevent common mistakes, the compiler is stricter at compile-time and will prevent out of bounds shifts. For instance, at runtime, ``1 << 257`` will evaluate to ``0``, while that expression at compile-time will raise an ``OverflowException``.


.. index:: ! uint, ! uintN, ! unsigned integer

Unsigned Integer (N bit)
Expand Down Expand Up @@ -188,6 +205,24 @@ Operator Description
.. note::
The Bitwise ``not`` operator is currently only available for ``uint256`` type.

Shifts
^^^^^^^^^^^^^^^^^

============= ======================
Operator Description
============= ======================
``x << y`` Left shift
``x >> y`` Right shift
============= ======================

Shifting is only available for 256-bit wide types. That is, ``x`` must be ``uint256``, and ``y`` can be any unsigned integer. The right shift for ``uint256`` compiles to a signed right shift (EVM ``SHR`` instruction).


.. note::
While at runtime shifts are unchecked (that is, they can be for any number of bits), to prevent common mistakes, the compiler is stricter at compile-time and will prevent out of bounds shifts. For instance, at runtime, ``1 << 257`` will evaluate to ``0``, while that expression at compile-time will raise an ``OverflowException``.



Decimals
--------

Expand Down
81 changes: 63 additions & 18 deletions tests/builtins/folding/test_bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,27 @@
from hypothesis import strategies as st

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from vyper.exceptions import OverflowException
from vyper.semantics.analysis.utils import validate_expected_type
from vyper.semantics.types.shortcuts import INT256_T, UINT256_T
from vyper.utils import unsigned_to_signed

st_uint256 = st.integers(min_value=0, max_value=2**256 - 1)

st_sint256 = st.integers(min_value=-(2**255), max_value=2**255 - 1)


@pytest.mark.fuzzing
@settings(max_examples=50, deadline=1000)
@given(a=st_uint256, b=st_uint256)
@pytest.mark.parametrize("op", ["&", "|", "^"])
def test_bitwise_and_or(get_contract, a, b, op):
@given(a=st_uint256, b=st_uint256)
def test_bitwise_ops(get_contract, a, b, op):
source = f"""
@external
def foo(a: uint256, b: uint256) -> uint256:
return a {op} b
"""

contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}")
Expand All @@ -29,35 +35,74 @@ def foo(a: uint256, b: uint256) -> uint256:

@pytest.mark.fuzzing
@settings(max_examples=50, deadline=1000)
@given(value=st_uint256)
def test_bitwise_not(get_contract, value):
source = """
@pytest.mark.parametrize("op", ["<<", ">>"])
@given(a=st_uint256, b=st.integers(min_value=0, max_value=256))
def test_bitwise_shift_unsigned(get_contract, a, b, op):
source = f"""
@external
def foo(a: uint256) -> uint256:
return ~a
def foo(a: uint256, b: uint256) -> uint256:
return a {op} b
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"~{value}")
vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}")
old_node = vyper_ast.body[0].value
new_node = old_node.evaluate()

assert contract.foo(value) == new_node.value
try:
new_node = old_node.evaluate()
# force bounds check, no-op because validate_numeric_bounds
# already does this, but leave in for hygiene (in case
# more types are added).
validate_expected_type(new_node, UINT256_T)
# compile time behavior does not match runtime behavior.
# compile-time will throw on OOB, runtime will wrap.
except OverflowException: # here: check the wrapped value matches runtime
assert op == "<<"
assert contract.foo(a, b) == (a << b) % (2**256)
else:
assert contract.foo(a, b) == new_node.value


@pytest.mark.fuzzing
@settings(max_examples=50, deadline=1000)
@pytest.mark.parametrize("op", ["<<", ">>"])
@given(a=st_sint256, b=st.integers(min_value=0, max_value=256))
def test_bitwise_shift_signed(get_contract, a, b, op):
source = f"""
@external
def foo(a: int256, b: uint256) -> int256:
return a {op} b
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}")
old_node = vyper_ast.body[0].value

try:
new_node = old_node.evaluate()
validate_expected_type(new_node, INT256_T) # force bounds check
# compile time behavior does not match runtime behavior.
# compile-time will throw on OOB, runtime will wrap.
except OverflowException: # here: check the wrapped value matches runtime
assert op == "<<"
assert contract.foo(a, b) == unsigned_to_signed((a << b) % (2**256), 256)
else:
assert contract.foo(a, b) == new_node.value


@pytest.mark.fuzzing
@settings(max_examples=50, deadline=1000)
@given(value=st_uint256, steps=st.integers(min_value=-256, max_value=256))
def test_shift(get_contract, value, steps):
@given(value=st_uint256)
def test_bitwise_not(get_contract, value):
source = """
@external
def foo(a: uint256, b: int128) -> uint256:
return shift(a, b)
def foo(a: uint256) -> uint256:
return ~a
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"shift({value}, {steps})")
vyper_ast = vy_ast.parse_to_ast(f"~{value}")
old_node = vyper_ast.body[0].value
new_node = vy_fn.Shift().evaluate(old_node)
new_node = old_node.evaluate()

assert contract.foo(value, steps) == new_node.value
assert contract.foo(value) == new_node.value
97 changes: 60 additions & 37 deletions tests/parser/functions/test_bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from vyper.compiler import compile_code
from vyper.evm.opcodes import EVM_VERSIONS
from vyper.exceptions import InvalidLiteral, TypeMismatch
from vyper.exceptions import InvalidLiteral, InvalidOperation, TypeMismatch
from vyper.utils import unsigned_to_signed

code = """
@external
Expand All @@ -22,12 +23,12 @@ def _bitwise_not(x: uint256) -> uint256:
return ~x
@external
def _shift(x: uint256, y: int128) -> uint256:
return shift(x, y)
def _shl(x: uint256, y: uint256) -> uint256:
return x << y
@external
def _negatedShift(x: uint256, y: int128) -> uint256:
return shift(x, -y)
def _shr(x: uint256, y: uint256) -> uint256:
return x >> y
"""


Expand All @@ -51,22 +52,11 @@ def test_test_bitwise(get_contract_with_gas_estimation, evm_version):
assert c._bitwise_or(x, y) == (x | y)
assert c._bitwise_xor(x, y) == (x ^ y)
assert c._bitwise_not(x) == 2**256 - 1 - x
assert c._shift(x, 3) == x * 8
assert c._shift(x, 255) == 0
assert c._shift(y, 255) == 2**255
assert c._shift(x, 256) == 0
assert c._shift(x, 0) == x
assert c._shift(x, -1) == x // 2
assert c._shift(x, -3) == x // 8
assert c._shift(x, -256) == 0
assert c._negatedShift(x, -3) == x * 8
assert c._negatedShift(x, -255) == 0
assert c._negatedShift(y, -255) == 2**255
assert c._negatedShift(x, -256) == 0
assert c._negatedShift(x, -0) == x
assert c._negatedShift(x, 1) == x // 2
assert c._negatedShift(x, 3) == x // 8
assert c._negatedShift(x, 256) == 0

for t in (x, y):
for s in (0, 1, 3, 255, 256):
assert c._shr(t, s) == t >> s
assert c._shl(t, s) == (t << s) % (2**256)


POST_BYZANTIUM = [k for (k, v) in EVM_VERSIONS.items() if v > 0]
Expand All @@ -76,19 +66,22 @@ def test_test_bitwise(get_contract_with_gas_estimation, evm_version):
def test_signed_shift(get_contract_with_gas_estimation, evm_version):
code = """
@external
def _signedShift(x: int256, y: int128) -> int256:
return shift(x, y)
def _sar(x: int256, y: uint256) -> int256:
return x >> y
@external
def _shl(x: int256, y: uint256) -> int256:
return x << y
"""
c = get_contract_with_gas_estimation(code, evm_version=evm_version)
x = 126416208461208640982146408124
y = 7128468721412412459
cases = [x, y, -x, -y]

for t in cases:
assert c._signedShift(t, 0) == t >> 0
assert c._signedShift(t, -1) == t >> 1
assert c._signedShift(t, -3) == t >> 3
assert c._signedShift(t, -256) == t >> 256
for s in (0, 1, 3, 255, 256):
assert c._sar(t, s) == t >> s
assert c._shl(t, s) == unsigned_to_signed((t << s) % (2**256), 256)


def test_precedence(get_contract):
Expand All @@ -115,41 +108,71 @@ def baz(a: uint256, b: uint256, c: uint256) -> (uint256, uint256):
def test_literals(get_contract, evm_version):
code = """
@external
def left(x: uint256) -> uint256:
return shift(x, -3)
def _shr(x: uint256) -> uint256:
return x >> 3
@external
def right(x: uint256) -> uint256:
return shift(x, 3)
def _shl(x: uint256) -> uint256:
return x << 3
"""

c = get_contract(code, evm_version=evm_version)
assert c.left(80) == 10
assert c.right(80) == 640
assert c._shr(80) == 10
assert c._shl(80) == 640


fail_list = [
(
# cannot shift non-uint256/int256 argument
"""
@external
def foo(x: uint8, y: uint8) -> uint8:
return x << y
""",
InvalidOperation,
),
(
# cannot shift non-uint256/int256 argument
"""
@external
def foo(x: int8, y: uint8) -> int8:
return x << y
""",
InvalidOperation,
),
(
# cannot shift by non-uint bits
"""
@external
def foo(x: uint8, y: int128) -> uint256:
return shift(x, y)
def foo(x: uint256, y: int128) -> uint256:
return x << y
""",
TypeMismatch,
),
(
# cannot left shift by more than 256 bits
"""
@external
def foo() -> uint256:
return 2 << 257
""",
InvalidLiteral,
),
(
# cannot shift by negative amount
"""
@external
def foo() -> uint256:
return shift(2, 257)
return 2 << -1
""",
InvalidLiteral,
),
(
# cannot shift by negative amount
"""
@external
def foo() -> uint256:
return shift(2, -257)
return 2 << -1
""",
InvalidLiteral,
),
Expand Down
2 changes: 1 addition & 1 deletion tests/parser/types/numbers/test_unsigned_ints.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256:
o: uint256 = 1
for i in range(256):
o = uint256_mulmod(o, o, modulus)
if exponent & shift(1, 255 - i) != 0:
if exponent & (1 << (255 - i)) != 0:
o = uint256_mulmod(o, base, modulus)
return o
"""
Expand Down
5 changes: 2 additions & 3 deletions tests/parser/types/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,8 @@ def to_little_endian_64(_value: uint256) -> Bytes[8]:
y: uint256 = 0
x: uint256 = _value
for _ in range(8):
y = shift(y, 8)
y = y + (x & 255)
x = shift(x, -8)
y = (y << 8) | (x & 255)
x >>= 8
return slice(convert(y, bytes32), 24, 8)
@external
Expand Down
5 changes: 2 additions & 3 deletions tests/parser/types/test_bytes_zero_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ def to_little_endian_64(_value: uint256) -> Bytes[8]:
y: uint256 = 0
x: uint256 = _value
for _ in range(8):
y = shift(y, 8)
y = y + (x & 255)
x = shift(x, -8)
y = (y << 8) | (x & 255)
x >>= 8
return slice(convert(y, bytes32), 24, 8)
@external
Expand Down
Loading

0 comments on commit 7a64d4b

Please sign in to comment.