Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[lang]: fix panic in call cycle detection #4200

Merged
merged 11 commits into from
Aug 7, 2024
58 changes: 55 additions & 3 deletions tests/unit/semantics/analysis/test_cyclic_function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,28 @@ def foo():
self.foo()
"""
vyper_module = parse_to_ast(code)
with pytest.raises(CallViolation):
with pytest.raises(CallViolation) as e:
analyze_module(vyper_module, dummy_input_bundle)

assert e.value.message == "Contract contains cyclic function call: foo -> foo"


def test_self_function_call2(dummy_input_bundle):
code = """
@external
def foo():
self.bar()

@internal
def bar():
self.bar()
"""
vyper_module = parse_to_ast(code)
with pytest.raises(CallViolation) as e:
analyze_module(vyper_module, dummy_input_bundle)

assert e.value.message == "Contract contains cyclic function call: foo -> bar -> bar"


def test_cyclic_function_call(dummy_input_bundle):
code = """
Expand All @@ -27,9 +46,11 @@ def bar():
self.foo()
"""
vyper_module = parse_to_ast(code)
with pytest.raises(CallViolation):
with pytest.raises(CallViolation) as e:
analyze_module(vyper_module, dummy_input_bundle)

assert e.value.message == "Contract contains cyclic function call: foo -> bar -> foo"


def test_multi_cyclic_function_call(dummy_input_bundle):
code = """
Expand All @@ -50,9 +71,40 @@ def potato():
self.foo()
"""
vyper_module = parse_to_ast(code)
with pytest.raises(CallViolation):
with pytest.raises(CallViolation) as e:
analyze_module(vyper_module, dummy_input_bundle)

expected_message = "Contract contains cyclic function call: foo -> bar -> baz -> potato -> foo"

assert e.value.message == expected_message


def test_multi_cyclic_function_call2(dummy_input_bundle):
code = """
@internal
def foo():
self.bar()

@internal
def bar():
self.baz()

@internal
def baz():
self.potato()

@internal
def potato():
self.bar()
"""
vyper_module = parse_to_ast(code)
with pytest.raises(CallViolation) as e:
analyze_module(vyper_module, dummy_input_bundle)

expected_message = "Contract contains cyclic function call: foo -> bar -> baz -> potato -> bar"

assert e.value.message == expected_message


def test_global_ann_assign_callable_no_crash(dummy_input_bundle):
code = """
Expand Down
6 changes: 3 additions & 3 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,15 @@ def _compute_reachable_set(fn_t: ContractFunctionT, path: list[ContractFunctionT
path = path or []

path.append(fn_t)
root = path[0]

for g in fn_t.called_functions:
if g in fn_t.reachable_internal_functions:
# already seen
continue

if g == root:
message = " -> ".join([f.name for f in path])
if g in path:
extended_path = path + [g]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, very clean!

message = " -> ".join([f.name for f in extended_path])
raise CallViolation(f"Contract contains cyclic function call: {message}")

_compute_reachable_set(g, path=path)
Expand Down
Loading