Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]: add linearization check for initializers #4038

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 247 additions & 1 deletion tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,59 @@ def __init__():
assert compile_code(main, input_bundle=input_bundle) is not None


# test multiple uses in different nodes of the import tree
def test_distant_use_initialize(make_input_bundle):
lib3 = """
counter: uint256

@deploy
def __init__():
self.counter = 1
"""
lib2 = """
import lib3

uses: lib3

counter: uint256

@deploy
def __init__():
self.counter = 1

@external
def foo() ->uint256:
return lib3.counter
"""
lib1 = """
import lib2
import lib3

uses: lib3
initializes: lib2[lib3 := lib3]

@deploy
def __init__():
lib2.__init__()
lib3.counter += 1
"""
main = """
import lib1
import lib3

initializes: lib1[lib3 := lib3]
initializes: lib3

@deploy
def __init__():
lib3.__init__()
lib1.__init__()
"""
input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3})

assert compile_code(main, input_bundle=input_bundle) is not None


def test_initialize_multi_line_uses(make_input_bundle):
lib1 = """
counter: uint256
Expand Down Expand Up @@ -197,10 +250,10 @@ def foo():

@deploy
def __init__():
lib2.__init__()
# demonstrate we can call lib1.__init__ through lib2.lib1
# (not sure this should be allowed, really.
lib2.lib1.__init__()
lib2.__init__()
"""
input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2})

Expand Down Expand Up @@ -238,6 +291,199 @@ def __init__():
assert compile_code(main, input_bundle=input_bundle) is not None


def test_initialize_wrong_order(make_input_bundle):
lib1 = """
counter: uint256

@deploy
def __init__():
pass
"""
lib2 = """
import lib1

uses: lib1

counter: uint256

@deploy
def __init__():
pass

@internal
def foo():
lib1.counter += 1
"""
main = """
import lib1
import lib2

initializes: lib2[lib1 := lib1]
initializes: lib1

@deploy
def __init__():
lib2.__init__()
lib1.__init__()
"""
input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2})

with pytest.raises(InitializerException) as e:
assert compile_code(main, input_bundle=input_bundle) is not None

expected = "Tried to initialize `lib2`, but it depends on `lib1`, "
expected += "which has not been initialized yet."
assert e.value._message == expected
assert e.value._hint == "call `lib1.__init__()` before `lib2.__init__()`."


def test_initializer_order_nested(make_input_bundle):
lib1 = """
a: public(uint256)

@deploy
@payable
def __init__(x: uint256):
self.a = x
"""
lib2 = """
import lib1

uses: lib1

a: uint256

@deploy
def __init__():
# not initialized when called
self.a = lib1.a
"""
lib3 = """
import lib1

initializes: lib1

a: uint256

@deploy
@payable
def __init__(x: uint256):
self.a = x
lib1.__init__(0)
"""
input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3})

main1 = """
import lib1
import lib2
import lib3

initializes: lib2[lib1 := lib1]
initializes: lib3

@deploy
def __init__():
lib3.__init__(0)
lib2.__init__()
"""
assert compile_code(main1, input_bundle=input_bundle) is not None

main2 = """
import lib1
import lib2
import lib3

initializes: lib2[lib1 := lib1]
initializes: lib3

@deploy
def __init__():
lib2.__init__() # opposite order!
lib3.__init__(0)
"""
with pytest.raises(InitializerException) as e:
compile_code(main2, input_bundle=input_bundle)

expected = "Tried to initialize `lib2`, but it depends on `lib1`, which "
expected += "has not been initialized yet."
assert e.value._message == expected

assert e.value._hint == "call `lib1.__init__()` before `lib2.__init__()`."


def test_initializer_nested_order2(make_input_bundle):
lib1 = """
import lib4

a: public(uint256)

initializes: lib4

@deploy
@payable
def __init__(x: uint256):
self.a = x
lib4.__init__(x)
"""

lib2 = """
import lib1
import lib4

uses: lib1
uses: lib4

a: uint256

@deploy
def __init__():
# not initialized when called
self.a = lib1.a + lib4.a
"""

lib3 = """
import lib1

initializes: lib1

a: uint256

@deploy
@payable
def __init__(x: uint256):
self.a = x
lib1.__init__(0)
"""
lib4 = """
a: uint256

@deploy
@payable
def __init__(x: uint256):
self.a = x
"""
main = """
import lib1
import lib2
import lib3
import lib4

initializes: lib2[lib1 := lib1, lib4 := lib4]
initializes: lib3

@deploy
def __init__():
lib3.__init__(0)
lib2.__init__()
"""

input_bundle = make_input_bundle(
{"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3, "lib4.vy": lib4}
)

assert compile_code(main, input_bundle=input_bundle) is not None


def test_imported_as_different_names(make_input_bundle):
lib1 = """
counter: uint256
Expand Down
26 changes: 25 additions & 1 deletion vyper/semantics/analysis/global_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from vyper.exceptions import ExceptionList, InitializerException
from vyper.semantics.analysis.base import InitializesInfo, UsesInfo
from vyper.semantics.types.module import ModuleT
from vyper.utils import OrderedSet


def validate_compilation_target(module_t: ModuleT):
Expand Down Expand Up @@ -54,14 +55,37 @@ def _validate_global_initializes_constraint(module_t: ModuleT):
all_used_modules = _collect_used_modules_r(module_t)
all_initialized_modules = _collect_initialized_modules_r(module_t)

hint = None

init_calls = []
if module_t.init_function is not None:
init_calls = list(module_t.init_function.reachable_internal_functions)
seen: OrderedSet = OrderedSet()

for init_t in init_calls:
seen.add(init_t)
init_m = init_t.decl_node.module_node._metadata["type"]
init_info = all_initialized_modules[init_m]
for dep in init_info.dependencies:
m = dep.module_t
if m.init_function is None:
continue
if m.init_function not in seen:
# TODO: recover source info
msg = f"Tried to initialize `{init_info.module_info.alias}`, "
msg += f"but it depends on `{dep.alias}`, which has not been "
msg += "initialized yet."
hint = f"call `{dep.alias}.__init__()` before "
hint += f"`{init_info.module_info.alias}.__init__()`."
raise InitializerException(msg, hint=hint)

err_list = ExceptionList()

for u, uses in all_used_modules.items():
if u not in all_initialized_modules:
msg = f"module `{u}` is used but never initialized!"

# construct a hint if the module is in scope
hint = None
found_module = module_t.find_module_info(u)
if found_module is not None:
# TODO: do something about these constants
Expand Down
19 changes: 9 additions & 10 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,33 +321,32 @@ def validate_used_modules(self):
err_list.raise_if_not_empty()

def validate_initialized_modules(self):
# check all `initializes:` modules have `__init__()` called exactly once
# check all `initializes:` modules have `__init__()` called exactly once,
# and check they are called in dependency order
module_t = self.ast._metadata["type"]
should_initialize = {t.module_info.module_t: t for t in module_t.initialized_modules}

# don't call `__init__()` for modules which don't have
# `__init__()` function
for m in should_initialize.copy():
for f in m.functions.values():
if f.is_constructor:
break
else:
if m.init_function is None:
del should_initialize[m]

init_calls = []
for f in self.ast.get_children(vy_ast.FunctionDef):
if f._metadata["func_type"].is_constructor:
init_calls = f.get_descendants(vy_ast.Call)
break
if module_t.init_function is not None:
init_calls = module_t.init_function.ast_def.get_descendants(vy_ast.Call)

# map of seen __init__() function calls
seen_initializers = {}

for call_node in init_calls:
expr_info = call_node.func._expr_info
if expr_info is None:
# this can happen for range() calls; CMC 2024-02-05 try to
# refactor so that range() is properly tagged.
continue

call_t = call_node.func._expr_info.typ
call_t = expr_info.typ

if not isinstance(call_t, ContractFunctionT):
continue
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,9 @@ def used_modules(self):
ret.append(used_module)
return ret

@property
@cached_property
def initialized_modules(self):
# modules which are initialized to
# modules which are initialized
ret = []
for node in self.initializes_decls:
info = node._metadata["initializes_info"]
Expand Down
Loading