Skip to content

Commit

Permalink
fix: block mload merging when src and dst overlap (#3635)
Browse files Browse the repository at this point in the history
this commit fixes an optimization bug when the target architecture has
the `mcopy` instruction (i.e. `cancun` or later). the bug was introduced
in 5dc3ac7. specifically, the `merge_mload` step can incorrectly merge
`mload`/`mstore` sequences (into `mcopy`) when the source and
destination buffers overlap, and the destination buffer is "ahead of"
(i.e. greater than) the source buffer. this commit fixes the issue by
blocking the optimization in these cases, and adds unit and functional
tests demonstrating the correct behavior.

---------

Co-authored-by: Robert Chen <[email protected]>
  • Loading branch information
charles-cooper and chen-robert authored Oct 3, 2023
1 parent 8aae7cd commit e9c16e4
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 2 deletions.
107 changes: 107 additions & 0 deletions tests/compiler/ir/test_optimize_ir.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import pytest

from vyper.codegen.ir_node import IRnode
from vyper.evm.opcodes import EVM_VERSIONS, anchor_evm_version
from vyper.exceptions import StaticAssertionException
from vyper.ir import optimizer

POST_CANCUN = {k: v for k, v in EVM_VERSIONS.items() if v >= EVM_VERSIONS["cancun"]}


optimize_list = [
(["eq", 1, 2], [0]),
(["lt", 1, 2], [1]),
Expand Down Expand Up @@ -272,3 +276,106 @@ def test_operator_set_values():
assert optimizer.COMPARISON_OPS == {"lt", "gt", "le", "ge", "slt", "sgt", "sle", "sge"}
assert optimizer.STRICT_COMPARISON_OPS == {"lt", "gt", "slt", "sgt"}
assert optimizer.UNSTRICT_COMPARISON_OPS == {"le", "ge", "sle", "sge"}


mload_merge_list = [
# copy "backward" with no overlap between src and dst buffers,
# OK to become mcopy
(
["seq", ["mstore", 32, ["mload", 128]], ["mstore", 64, ["mload", 160]]],
["mcopy", 32, 128, 64],
),
# copy with overlap "backwards", OK to become mcopy
(["seq", ["mstore", 32, ["mload", 64]], ["mstore", 64, ["mload", 96]]], ["mcopy", 32, 64, 64]),
# "stationary" overlap (i.e. a no-op mcopy), OK to become mcopy
(["seq", ["mstore", 32, ["mload", 32]], ["mstore", 64, ["mload", 64]]], ["mcopy", 32, 32, 64]),
# copy "forward" with no overlap, OK to become mcopy
(["seq", ["mstore", 64, ["mload", 0]], ["mstore", 96, ["mload", 32]]], ["mcopy", 64, 0, 64]),
# copy "forwards" with overlap by one word, must NOT become mcopy
(["seq", ["mstore", 64, ["mload", 32]], ["mstore", 96, ["mload", 64]]], None),
# check "forward" overlap by one byte, must NOT become mcopy
(["seq", ["mstore", 64, ["mload", 1]], ["mstore", 96, ["mload", 33]]], None),
# check "forward" overlap by one byte again, must NOT become mcopy
(["seq", ["mstore", 63, ["mload", 0]], ["mstore", 95, ["mload", 32]]], None),
# copy 3 words with partial overlap "forwards", partially becomes mcopy
# (2 words are mcopied and 1 word is mload/mstored
(
[
"seq",
["mstore", 96, ["mload", 32]],
["mstore", 128, ["mload", 64]],
["mstore", 160, ["mload", 96]],
],
["seq", ["mcopy", 96, 32, 64], ["mstore", 160, ["mload", 96]]],
),
# copy 4 words with partial overlap "forwards", becomes 2 mcopies of 2 words each
(
[
"seq",
["mstore", 96, ["mload", 32]],
["mstore", 128, ["mload", 64]],
["mstore", 160, ["mload", 96]],
["mstore", 192, ["mload", 128]],
],
["seq", ["mcopy", 96, 32, 64], ["mcopy", 160, 96, 64]],
),
# copy 4 words with 1 byte of overlap, must NOT become mcopy
(
[
"seq",
["mstore", 96, ["mload", 33]],
["mstore", 128, ["mload", 65]],
["mstore", 160, ["mload", 97]],
["mstore", 192, ["mload", 129]],
],
None,
),
# Ensure only sequential mstore + mload sequences are optimized
(
[
"seq",
["mstore", 0, ["mload", 32]],
["sstore", 0, ["calldataload", 4]],
["mstore", 32, ["mload", 64]],
],
None,
),
# not-word aligned optimizations (not overlap)
(["seq", ["mstore", 0, ["mload", 1]], ["mstore", 32, ["mload", 33]]], ["mcopy", 0, 1, 64]),
# not-word aligned optimizations (overlap)
(["seq", ["mstore", 1, ["mload", 0]], ["mstore", 33, ["mload", 32]]], None),
# not-word aligned optimizations (overlap and not-overlap)
(
[
"seq",
["mstore", 0, ["mload", 1]],
["mstore", 32, ["mload", 33]],
["mstore", 1, ["mload", 0]],
["mstore", 33, ["mload", 32]],
],
["seq", ["mcopy", 0, 1, 64], ["mstore", 1, ["mload", 0]], ["mstore", 33, ["mload", 32]]],
),
# overflow test
(
[
"seq",
["mstore", 2**256 - 1 - 31 - 32, ["mload", 0]],
["mstore", 2**256 - 1 - 31, ["mload", 32]],
],
["mcopy", 2**256 - 1 - 31 - 32, 0, 64],
),
]


@pytest.mark.parametrize("ir", mload_merge_list)
@pytest.mark.parametrize("evm_version", list(POST_CANCUN.keys()))
def test_mload_merge(ir, evm_version):
with anchor_evm_version(evm_version):
optimized = optimizer.optimize(IRnode.from_list(ir[0]))
if ir[1] is None:
# no-op, assert optimizer does nothing
expected = IRnode.from_list(ir[0])
else:
expected = IRnode.from_list(ir[1])

assert optimized == expected
60 changes: 60 additions & 0 deletions tests/parser/features/test_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,63 @@ def bug(p: Point) -> Point:
"""
c = get_contract(code)
assert c.bug((1, 2)) == (2, 1)


mload_merge_codes = [
(
"""
@external
def foo() -> uint256[4]:
# copy "backwards"
xs: uint256[4] = [1, 2, 3, 4]
# dst < src
xs[0] = xs[1]
xs[1] = xs[2]
xs[2] = xs[3]
return xs
""",
[2, 3, 4, 4],
),
(
"""
@external
def foo() -> uint256[4]:
# copy "forwards"
xs: uint256[4] = [1, 2, 3, 4]
# src < dst
xs[1] = xs[0]
xs[2] = xs[1]
xs[3] = xs[2]
return xs
""",
[1, 1, 1, 1],
),
(
"""
@external
def foo() -> uint256[5]:
# partial "forward" copy
xs: uint256[5] = [1, 2, 3, 4, 5]
# src < dst
xs[2] = xs[0]
xs[3] = xs[1]
xs[4] = xs[2]
return xs
""",
[1, 2, 1, 2, 1],
),
]


# functional test that mload merging does not occur when source and dest
# buffers overlap. (note: mload merging only applies after cancun)
@pytest.mark.parametrize("code,expected_result", mload_merge_codes)
def test_mcopy_overlap(get_contract, code, expected_result):
c = get_contract(code)
assert c.foo() == expected_result
9 changes: 7 additions & 2 deletions vyper/ir/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,10 +662,10 @@ def _rewrite_mstore_dload(argz):
def _merge_mload(argz):
if not version_check(begin="cancun"):
return False
return _merge_load(argz, "mload", "mcopy")
return _merge_load(argz, "mload", "mcopy", allow_overlap=False)


def _merge_load(argz, _LOAD, _COPY):
def _merge_load(argz, _LOAD, _COPY, allow_overlap=True):
# look for sequential operations copying from X to Y
# and merge them into a single copy operation
changed = False
Expand All @@ -689,9 +689,14 @@ def _merge_load(argz, _LOAD, _COPY):
initial_dst_offset = dst_offset
initial_src_offset = src_offset
idx = i

# dst and src overlap, discontinue the optimization
has_overlap = initial_src_offset < initial_dst_offset < src_offset + 32

if (
initial_dst_offset + total_length == dst_offset
and initial_src_offset + total_length == src_offset
and (allow_overlap or not has_overlap)
):
mstore_nodes.append(ir_node)
total_length += 32
Expand Down

0 comments on commit e9c16e4

Please sign in to comment.