Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve name resolution of type aliases #2061

Merged
merged 2 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions slither/core/compilation_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit
self._pragma_directives: List[Pragma] = []
self._import_directives: List[Import] = []
self._custom_errors: List[CustomErrorTopLevel] = []
self._user_defined_value_types: Dict[str, TypeAliasTopLevel] = {}
self._type_aliases: Dict[str, TypeAliasTopLevel] = {}

self._all_functions: Set[Function] = set()
self._all_modifiers: Set[Modifier] = set()
Expand Down Expand Up @@ -220,8 +220,8 @@ def custom_errors(self) -> List[CustomErrorTopLevel]:
return self._custom_errors

@property
def user_defined_value_types(self) -> Dict[str, TypeAliasTopLevel]:
return self._user_defined_value_types
def type_aliases(self) -> Dict[str, TypeAliasTopLevel]:
return self._type_aliases

# endregion
###################################################################################
Expand Down
34 changes: 34 additions & 0 deletions slither/core/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.scope.scope import FileScope
from slither.core.cfg.node import Node
from slither.core.solidity_types import TypeAliasContract


LOGGER = logging.getLogger("Contract")
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope
self._functions: Dict[str, "FunctionContract"] = {}
self._linearizedBaseContracts: List[int] = []
self._custom_errors: Dict[str, "CustomErrorContract"] = {}
self._type_aliases: Dict[str, "TypeAliasContract"] = {}

# The only str is "*"
self._using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = {}
Expand Down Expand Up @@ -364,6 +366,38 @@ def custom_errors_declared(self) -> List["CustomErrorContract"]:
def custom_errors_as_dict(self) -> Dict[str, "CustomErrorContract"]:
return self._custom_errors

# endregion
###################################################################################
###################################################################################
# region Custom Errors
###################################################################################
###################################################################################

@property
def type_aliases(self) -> List["TypeAliasContract"]:
"""
list(TypeAliasContract): List of the contract's custom errors
"""
return list(self._type_aliases.values())

@property
def type_aliases_inherited(self) -> List["TypeAliasContract"]:
"""
list(TypeAliasContract): List of the inherited custom errors
"""
return [s for s in self.type_aliases if s.contract != self]

@property
def type_aliases_declared(self) -> List["TypeAliasContract"]:
"""
list(TypeAliasContract): List of the custom errors declared within the contract (not inherited)
"""
return [s for s in self.type_aliases if s.contract == self]

@property
def type_aliases_as_dict(self) -> Dict[str, "TypeAliasContract"]:
return self._type_aliases

# endregion
###################################################################################
###################################################################################
Expand Down
6 changes: 3 additions & 3 deletions slither/core/scope/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, filename: Filename) -> None:

# User defined types
# Name -> type alias
self.user_defined_types: Dict[str, TypeAlias] = {}
self.type_aliases: Dict[str, TypeAlias] = {}

def add_accesible_scopes(self) -> bool:
"""
Expand Down Expand Up @@ -95,8 +95,8 @@ def add_accesible_scopes(self) -> bool:
if not _dict_contain(new_scope.renaming, self.renaming):
self.renaming.update(new_scope.renaming)
learn_something = True
if not _dict_contain(new_scope.user_defined_types, self.user_defined_types):
self.user_defined_types.update(new_scope.user_defined_types)
if not _dict_contain(new_scope.type_aliases, self.type_aliases):
self.type_aliases.update(new_scope.type_aliases)
learn_something = True

return learn_something
Expand Down
8 changes: 4 additions & 4 deletions slither/solc_parsing/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,10 @@ def _parse_type_alias(self, item: Dict) -> None:
alias = item["name"]
alias_canonical = self._contract.name + "." + item["name"]

user_defined_type = TypeAliasContract(original_type, alias, self.underlying_contract)
user_defined_type.set_offset(item["src"], self.compilation_unit)
self._contract.file_scope.user_defined_types[alias] = user_defined_type
self._contract.file_scope.user_defined_types[alias_canonical] = user_defined_type
type_alias = TypeAliasContract(original_type, alias, self.underlying_contract)
type_alias.set_offset(item["src"], self.compilation_unit)
self._contract.type_aliases_as_dict[alias] = type_alias
self._contract.file_scope.type_aliases[alias_canonical] = type_alias

def _parse_struct(self, struct: Dict) -> None:

Expand Down
2 changes: 1 addition & 1 deletion slither/solc_parsing/declarations/using_for_top_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _propagate_global(self, type_name: Union[TypeAliasTopLevel, UserDefinedType]
if self._global:
for scope in self.compilation_unit.scopes.values():
if isinstance(type_name, TypeAliasTopLevel):
for alias in scope.user_defined_types.values():
for alias in scope.type_aliases.values():
if alias == type_name:
scope.using_for_directives.add(self._using_for)
elif isinstance(type_name, UserDefinedType):
Expand Down
9 changes: 6 additions & 3 deletions slither/solc_parsing/expressions/find_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def find_top_level(
:return:
:rtype:
"""
if var_name in scope.type_aliases:
return scope.type_aliases[var_name], False

if var_name in scope.structures:
return scope.structures[var_name], False
Expand Down Expand Up @@ -205,6 +207,10 @@ def _find_in_contract(
if sig == var_name:
return modifier

type_aliases = contract.type_aliases_as_dict
if var_name in type_aliases:
return type_aliases[var_name]

# structures are looked on the contract declarer
structures = contract.structures_as_dict
if var_name in structures:
Expand Down Expand Up @@ -362,9 +368,6 @@ def find_variable(
if var_name in current_scope.renaming:
var_name = current_scope.renaming[var_name]

if var_name in current_scope.user_defined_types:
return current_scope.user_defined_types[var_name], False

# Use ret0/ret1 to help mypy
ret0 = _find_variable_from_ref_declaration(
referenced_declaration, direct_contracts, direct_functions
Expand Down
8 changes: 4 additions & 4 deletions slither/solc_parsing/slither_compilation_unit_solc.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,10 @@ def parse_top_level_from_loaded_json(self, data_loaded: Dict, filename: str) ->

original_type = ElementaryType(underlying_type["name"])

user_defined_type = TypeAliasTopLevel(original_type, alias, scope)
user_defined_type.set_offset(top_level_data["src"], self._compilation_unit)
self._compilation_unit.user_defined_value_types[alias] = user_defined_type
scope.user_defined_types[alias] = user_defined_type
type_alias = TypeAliasTopLevel(original_type, alias, scope)
type_alias.set_offset(top_level_data["src"], self._compilation_unit)
self._compilation_unit.type_aliases[alias] = type_alias
scope.type_aliases[alias] = type_alias

else:
raise SlitherException(f"Top level {top_level_data[self.get_key()]} not supported")
Expand Down
30 changes: 15 additions & 15 deletions slither/solc_parsing/solidity_types/type_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def parse_type(

sl: "SlitherCompilationUnit"
renaming: Dict[str, str]
user_defined_types: Dict[str, TypeAlias]
type_aliases: Dict[str, TypeAlias]
enums_direct_access: List["Enum"] = []
# Note: for convenicence top level functions use the same parser than function in contract
# but contract_parser is set to None
Expand All @@ -247,13 +247,13 @@ def parse_type(
sl = caller_context.compilation_unit
next_context = caller_context
renaming = {}
user_defined_types = sl.user_defined_value_types
type_aliases = sl.type_aliases
else:
assert isinstance(caller_context, FunctionSolc)
sl = caller_context.underlying_function.compilation_unit
next_context = caller_context.slither_parser
renaming = caller_context.underlying_function.file_scope.renaming
user_defined_types = caller_context.underlying_function.file_scope.user_defined_types
type_aliases = caller_context.underlying_function.file_scope.type_aliases
structures_direct_access = sl.structures_top_level
all_structuress = [c.structures for c in sl.contracts]
all_structures = [item for sublist in all_structuress for item in sublist]
Expand Down Expand Up @@ -299,7 +299,7 @@ def parse_type(
functions = list(scope.functions)

renaming = scope.renaming
user_defined_types = scope.user_defined_types
type_aliases = scope.type_aliases
elif isinstance(caller_context, (ContractSolc, FunctionSolc)):
sl = caller_context.compilation_unit
if isinstance(caller_context, FunctionSolc):
Expand Down Expand Up @@ -329,7 +329,7 @@ def parse_type(
functions = contract.functions + contract.modifiers

renaming = scope.renaming
user_defined_types = scope.user_defined_types
type_aliases = scope.type_aliases
else:
raise ParsingError(f"Incorrect caller context: {type(caller_context)}")

Expand All @@ -343,8 +343,8 @@ def parse_type(
name = t.name
if name in renaming:
name = renaming[name]
if name in user_defined_types:
return user_defined_types[name]
if name in type_aliases:
return type_aliases[name]
return _find_from_type_name(
name,
functions,
Expand All @@ -365,9 +365,9 @@ def parse_type(
name = t["typeDescriptions"]["typeString"]
if name in renaming:
name = renaming[name]
if name in user_defined_types:
_add_type_references(user_defined_types[name], t["src"], sl)
return user_defined_types[name]
if name in type_aliases:
_add_type_references(type_aliases[name], t["src"], sl)
return type_aliases[name]
type_found = _find_from_type_name(
name,
functions,
Expand All @@ -386,9 +386,9 @@ def parse_type(
name = t["attributes"][type_name_key]
if name in renaming:
name = renaming[name]
if name in user_defined_types:
_add_type_references(user_defined_types[name], t["src"], sl)
return user_defined_types[name]
if name in type_aliases:
_add_type_references(type_aliases[name], t["src"], sl)
return type_aliases[name]
type_found = _find_from_type_name(
name,
functions,
Expand All @@ -407,8 +407,8 @@ def parse_type(
name = t["name"]
if name in renaming:
name = renaming[name]
if name in user_defined_types:
return user_defined_types[name]
if name in type_aliases:
return type_aliases[name]
type_found = _find_from_type_name(
name,
functions,
Expand Down
4 changes: 2 additions & 2 deletions slither/visitors/slithir/expression_to_slithir.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,8 +516,8 @@ def _post_member_access(self, expression: MemberAccess) -> None:
# contract A { type MyInt is int}
# contract B { function f() public{ A.MyInt test = A.MyInt.wrap(1);}}
# The logic is handled by _post_call_expression
if expression.member_name in expr.file_scope.user_defined_types:
set_val(expression, expr.file_scope.user_defined_types[expression.member_name])
if expression.member_name in expr.file_scope.type_aliases:
set_val(expression, expr.file_scope.type_aliases[expression.member_name])
return
# Lookup errors referred to as member of contract e.g. Test.myError.selector
if expression.member_name in expr.custom_errors_as_dict:
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/solc_parsing/test_ast_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ def make_version(minor: int, patch_min: int, patch_max: int) -> List[str]:
["0.6.9", "0.7.6", "0.8.16"],
),
Test("user_defined_operators-0.8.19.sol", ["0.8.19"]),
Test("type-aliases.sol", ["0.8.19"]),
]
# create the output folder if needed
try:
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"OtherTest": {
"myfunc()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n}\n"
},
"DeleteTest": {}
}
20 changes: 20 additions & 0 deletions tests/e2e/solc_parsing/test_data/type-aliases.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

struct Z {
int x;
int y;
}

contract OtherTest {
struct Z {
int x;
int y;
}

function myfunc() external {
Z memory z = Z(2,3);
}
}

contract DeleteTest {
type Z is int;
}
6 changes: 2 additions & 4 deletions tests/unit/core/test_source_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,13 @@ def test_references_user_defined_aliases(solc_binary_path):
file = Path(SRC_MAPPING_TEST_ROOT, "ReferencesUserDefinedAliases.sol").as_posix()
slither = Slither(file, solc=solc_path)

alias_top_level = slither.compilation_units[0].user_defined_value_types["aliasTopLevel"]
alias_top_level = slither.compilation_units[0].type_aliases["aliasTopLevel"]
assert len(alias_top_level.references) == 2
lines = _sort_references_lines(alias_top_level.references)
assert lines == [12, 16]

alias_contract_level = (
slither.compilation_units[0]
.contracts[0]
.file_scope.user_defined_types["C.aliasContractLevel"]
slither.compilation_units[0].contracts[0].file_scope.type_aliases["C.aliasContractLevel"]
)
assert len(alias_contract_level.references) == 2
lines = _sort_references_lines(alias_contract_level.references)
Expand Down