Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[lang]: fix array index checks when the subscript is folded #3924

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tests/unit/ast/nodes/test_fold_subscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from hypothesis import strategies as st

from tests.utils import parse_and_fold
from vyper.compiler import compile_code
from vyper.exceptions import ArrayIndexException


@pytest.mark.fuzzing
Expand All @@ -24,3 +26,23 @@ def foo(array: int128[10], idx: uint256) -> int128:
new_node = old_node.get_folded_value()

assert contract.foo(array, idx) == new_node.value


def test_negative_index():
source = """
@external
def foo(array: int128[10]) -> int128:
return array[0 - 1]
"""
with pytest.raises(ArrayIndexException):
compile_code(source)


def test_oob_index():
source = """
@external
def foo(array: int128[10]) -> int128:
return array[9 + 1]
"""
with pytest.raises(ArrayIndexException):
compile_code(source)
5 changes: 5 additions & 0 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,11 @@ def get_folded_value(self) -> "ExprNode":
except KeyError:
raise UnfoldableNode("not foldable", self)

def reduced(self) -> "ExprNode":
if self.has_folded_value:
return self.get_folded_value()
return self

def _set_folded_value(self, node: "VyperNode") -> None:
# sanity check this is only called once
assert "folded_value" not in self._metadata
Expand Down
1 change: 1 addition & 0 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class VyperNode:
def get_fields(cls: Any) -> set: ...
def set_parent(self, parent: VyperNode) -> VyperNode: ...
def get_folded_value(self) -> ExprNode: ...
def reduced(self) -> ExprNode: ...
def _set_folded_value(self, node: ExprNode) -> None: ...
@classmethod
def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ...
Expand Down
7 changes: 4 additions & 3 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ class Expr:

def __init__(self, node, context, is_stmt=False):
assert isinstance(node, vy_ast.VyperNode)
if node.has_folded_value:
node = node.get_folded_value()
node = node.reduced()

self.expr = node
self.context = context
Expand Down Expand Up @@ -347,7 +346,9 @@ def parse_Subscript(self):
index = Expr.parse_value_expr(self.expr.slice, self.context)

elif is_tuple_like(sub.typ):
index = self.expr.slice.n
# should we annotate expr.slice in the frontend with the
# folded value instead of calling reduced() here?
index = self.expr.slice.reduced().n
# note: this check should also happen in get_element_ptr
if not 0 <= index < len(sub.typ.member_types):
raise TypeCheckFailure("unreachable")
Expand Down
10 changes: 3 additions & 7 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,7 @@ def _analyse_range_iter(self, iter_node, target_type):

def _analyse_list_iter(self, iter_node, target_type):
# iteration over a variable or literal list
iter_val = iter_node
if iter_val.has_folded_value:
iter_val = iter_val.get_folded_value()
iter_val = iter_node.reduced()

if isinstance(iter_val, vy_ast.List):
len_ = len(iter_val.elements)
Expand Down Expand Up @@ -946,12 +944,10 @@ def _validate_range_call(node: vy_ast.Call):
validate_call_args(node, (1, 2), kwargs=["bound"])
kwargs = {s.arg: s.value for s in node.keywords or []}
start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args
start, end = [i.get_folded_value() if i.has_folded_value else i for i in (start, end)]
start, end = [i.reduced() for i in (start, end)]

if "bound" in kwargs:
bound = kwargs["bound"]
if bound.has_folded_value:
bound = bound.get_folded_value()
bound = kwargs["bound"].reduced()
if not isinstance(bound, vy_ast.Int):
raise StructureException("Bound must be a literal integer", bound)
if bound.value <= 0:
Expand Down
9 changes: 6 additions & 3 deletions vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def validate_index_type(self, node):
# TODO break this cycle
from vyper.semantics.analysis.utils import validate_expected_type

node = node.reduced()

if isinstance(node, vy_ast.Int):
if node.value < 0:
raise ArrayIndexException("Vyper does not support negative indexing", node)
Expand Down Expand Up @@ -290,9 +292,7 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "DArrayT":
if not isinstance(node.slice, vy_ast.Tuple) or len(node.slice.elements) != 2:
raise StructureException(err_msg, node.slice)

length_node = node.slice.elements[1]
if length_node.has_folded_value:
length_node = length_node.get_folded_value()
length_node = node.slice.elements[1].reduced()

if not isinstance(length_node, vy_ast.Int):
raise StructureException(err_msg, length_node)
Expand Down Expand Up @@ -367,6 +367,8 @@ def size_in_bytes(self):
return sum(i.size_in_bytes for i in self.member_types)

def validate_index_type(self, node):
node = node.reduced()

if not isinstance(node, vy_ast.Int):
raise InvalidType("Tuple indexes must be literals", node)
if node.value < 0:
Expand All @@ -375,6 +377,7 @@ def validate_index_type(self, node):
raise ArrayIndexException("Index out of range", node)

def get_subscripted_type(self, node):
node = node.reduced()
return self.member_types[node.value]

def compare_type(self, other):
Expand Down
3 changes: 1 addition & 2 deletions vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,7 @@ def get_index_value(node: vy_ast.VyperNode) -> int:
# TODO: revisit this!
from vyper.semantics.analysis.utils import get_possible_types_from_node

if node.has_folded_value:
node = node.get_folded_value()
node = node.reduced()

if not isinstance(node, vy_ast.Int):
# even though the subscript is an invalid type, first check if it's a valid _something_
Expand Down
Loading