Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support using for directive #1378

Merged
merged 15 commits into from
Jan 5, 2023
6 changes: 6 additions & 0 deletions slither/core/compilation_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.using_for_top_level import UsingForTopLevel
from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.core.scope.scope import FileScope
from slither.core.variables.state_variable import StateVariable
Expand All @@ -41,6 +42,7 @@ def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit
self._enums_top_level: List[EnumTopLevel] = []
self._variables_top_level: List[TopLevelVariable] = []
self._functions_top_level: List[FunctionTopLevel] = []
self._using_for_top_level: List[UsingForTopLevel] = []
self._pragma_directives: List[Pragma] = []
self._import_directives: List[Import] = []
self._custom_errors: List[CustomError] = []
Expand Down Expand Up @@ -205,6 +207,10 @@ def variables_top_level(self) -> List[TopLevelVariable]:
def functions_top_level(self) -> List[FunctionTopLevel]:
return self._functions_top_level

@property
def using_for_top_level(self) -> List[UsingForTopLevel]:
return self._using_for_top_level

@property
def custom_errors(self) -> List[CustomError]:
return self._custom_errors
Expand Down
3 changes: 3 additions & 0 deletions slither/core/declarations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,8 @@
)
from .structure import Structure
from .enum_contract import EnumContract
from .enum_top_level import EnumTopLevel
from .structure_contract import StructureContract
from .structure_top_level import StructureTopLevel
from .function_contract import FunctionContract
from .function_top_level import FunctionTopLevel
22 changes: 22 additions & 0 deletions slither/core/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope

# The only str is "*"
self._using_for: Dict[Union[str, Type], List[Type]] = {}
self._using_for_complete: Dict[Union[str, Type], List[Type]] = None
self._kind: Optional[str] = None
self._is_interface: bool = False
self._is_library: bool = False
Expand Down Expand Up @@ -266,6 +267,27 @@ def events_as_dict(self) -> Dict[str, "Event"]:
def using_for(self) -> Dict[Union[str, Type], List[Type]]:
return self._using_for

@property
def using_for_complete(self) -> Dict[Union[str, Type], List[Type]]:
"""
Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive
"""

def _merge_using_for(uf1, uf2):
result = {**uf1, **uf2}
for key, value in result.items():
if key in uf1 and key in uf2:
result[key] = value + uf1[key]
return result

if self._using_for_complete is None:
result = self.using_for
top_level_using_for = self.file_scope.using_for_directives
for uftl in top_level_using_for:
result = _merge_using_for(result, uftl.using_for)
self._using_for_complete = result
return self._using_for_complete

# endregion
###################################################################################
###################################################################################
Expand Down
18 changes: 18 additions & 0 deletions slither/core/declarations/using_for_top_level.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import TYPE_CHECKING, List, Dict, Union

from slither.core.solidity_types.type import Type
from slither.core.declarations.top_level import TopLevel

if TYPE_CHECKING:
from slither.core.scope.scope import FileScope


class UsingForTopLevel(TopLevel):
def __init__(self, scope: "FileScope"):
super().__init__()
self._using_for: Dict[Union[str, Type], List[Type]] = {}
self.file_scope: "FileScope" = scope

@property
def using_for(self) -> Dict[Type, List[Type]]:
return self._using_for
5 changes: 5 additions & 0 deletions slither/core/scope/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.using_for_top_level import UsingForTopLevel
from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.core.solidity_types import TypeAlias
from slither.core.variables.top_level_variable import TopLevelVariable
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(self, filename: Filename):
# Because we parse the function signature later on
# So we simplify the logic and have the scope fields all populated
self.functions: Set[FunctionTopLevel] = set()
self.using_for_directives: Set[UsingForTopLevel] = set()
self.imports: Set[Import] = set()
self.pragmas: Set[Pragma] = set()
self.structures: Dict[str, StructureTopLevel] = {}
Expand Down Expand Up @@ -75,6 +77,9 @@ def add_accesible_scopes(self) -> bool:
if not new_scope.functions.issubset(self.functions):
self.functions |= new_scope.functions
learn_something = True
if not new_scope.using_for_directives.issubset(self.using_for_directives):
self.using_for_directives |= new_scope.using_for_directives
learn_something = True
if not new_scope.imports.issubset(self.imports):
self.imports |= new_scope.imports
learn_something = True
Expand Down
80 changes: 58 additions & 22 deletions slither/slithir/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.function_contract import FunctionContract
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder
from slither.core.declarations.solidity_variables import SolidityCustomRevert
from slither.core.expressions import Identifier, Literal
Expand Down Expand Up @@ -199,18 +200,22 @@ def _fits_under_byte(val: Union[int, str]) -> List[str]:
return [f"bytes{f}" for f in range(length, 33)] + ["bytes"]


def _find_function_from_parameter(ir: Call, candidates: List[Function]) -> Optional[Function]:
def _find_function_from_parameter(
arguments: List[Variable], candidates: List[Function], full_comparison: bool
) -> Optional[Function]:
"""
Look for a function in candidates that can be the target of the ir's call
Look for a function in candidates that can be the target based on the ir's call arguments

Try the implicit type conversion for uint/int/bytes. Constant values can be both uint/int
While variables stick to their base type, but can changed the size
While variables stick to their base type, but can changed the size.
If full_comparison is True it will do a comparison of all the arguments regardless if
the candidate remained is one.

:param ir:
:param arguments:
:param candidates:
:param full_comparison:
:return:
"""
arguments = ir.arguments
type_args: List[str]
for idx, arg in enumerate(arguments):
if isinstance(arg, (list,)):
Expand Down Expand Up @@ -258,7 +263,7 @@ def _find_function_from_parameter(ir: Call, candidates: List[Function]) -> Optio
not_found = False
candidates_kept.append(candidate)

if len(candidates_kept) == 1:
if len(candidates_kept) == 1 and not full_comparison:
return candidates_kept[0]
candidates = candidates_kept
if len(candidates) == 1:
Expand Down Expand Up @@ -503,7 +508,9 @@ def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals
# propagate the type
node_function = node.function
using_for = (
node_function.contract.using_for if isinstance(node_function, FunctionContract) else {}
node_function.contract.using_for_complete
if isinstance(node_function, FunctionContract)
else {}
)
if isinstance(ir, OperationWithLValue):
# Force assignment in case of missing previous correct type
Expand All @@ -530,9 +537,9 @@ def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals
if can_be_solidity_func(ir):
return convert_to_solidity_func(ir)

# convert library
# convert library or top level function
if t in using_for or "*" in using_for:
new_ir = convert_to_library(ir, node, using_for)
new_ir = convert_to_library_or_top_level(ir, node, using_for)
if new_ir:
return new_ir

Expand Down Expand Up @@ -881,7 +888,9 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: dis
# }
node_func = ins.node.function
using_for = (
node_func.contract.using_for if isinstance(node_func, FunctionContract) else {}
node_func.contract.using_for_complete
if isinstance(node_func, FunctionContract)
else {}
)

targeted_libraries = (
Expand All @@ -894,10 +903,14 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: dis
lib_contract_type.type, Contract
):
continue
lib_contract = lib_contract_type.type
for lib_func in lib_contract.functions:
if lib_func.name == ins.ori.variable_right:
candidates.append(lib_func)
if isinstance(lib_contract_type, FunctionContract):
# Using for with list of functions, this is the function called
candidates.append(lib_contract_type)
else:
lib_contract = lib_contract_type.type
for lib_func in lib_contract.functions:
if lib_func.name == ins.ori.variable_right:
candidates.append(lib_func)

if len(candidates) == 1:
lib_func = candidates[0]
Expand Down Expand Up @@ -1326,9 +1339,32 @@ def convert_to_pop(ir, node):
return ret


def look_for_library(contract, ir, using_for, t):
def look_for_library_or_top_level(contract, ir, using_for, t):
for destination in using_for[t]:
lib_contract = contract.file_scope.get_contract_from_name(str(destination))
if isinstance(destination, FunctionTopLevel) and destination.name == ir.function_name:
smonicas marked this conversation as resolved.
Show resolved Hide resolved
arguments = [ir.destination] + ir.arguments
if (
len(destination.parameters) == len(arguments)
and _find_function_from_parameter(arguments, [destination], True) is not None
):
internalcall = InternalCall(destination, ir.nbr_arguments, ir.lvalue, ir.type_call)
internalcall.set_expression(ir.expression)
internalcall.set_node(ir.node)
internalcall.arguments = [ir.destination] + ir.arguments
return_type = internalcall.function.return_type
if return_type:
if len(return_type) == 1:
internalcall.lvalue.set_type(return_type[0])
elif len(return_type) > 1:
internalcall.lvalue.set_type(return_type)
else:
internalcall.lvalue = None
return internalcall

if isinstance(destination, FunctionContract) and destination.contract.is_library:
lib_contract = destination.contract
else:
lib_contract = contract.file_scope.get_contract_from_name(str(destination))
if lib_contract:
lib_call = LibraryCall(
lib_contract,
Expand All @@ -1348,19 +1384,19 @@ def look_for_library(contract, ir, using_for, t):
return None


def convert_to_library(ir, node, using_for):
def convert_to_library_or_top_level(ir, node, using_for):
# We use contract_declarer, because Solidity resolve the library
# before resolving the inheritance.
# Though we could use .contract as libraries cannot be shadowed
contract = node.function.contract_declarer
t = ir.destination.type
if t in using_for:
new_ir = look_for_library(contract, ir, using_for, t)
new_ir = look_for_library_or_top_level(contract, ir, using_for, t)
if new_ir:
return new_ir

if "*" in using_for:
new_ir = look_for_library(contract, ir, using_for, "*")
new_ir = look_for_library_or_top_level(contract, ir, using_for, "*")
if new_ir:
return new_ir

Expand Down Expand Up @@ -1406,7 +1442,7 @@ def convert_type_library_call(ir: HighLevelCall, lib_contract: Contract):
# TODO: handle collision with multiple state variables/functions
func = lib_contract.get_state_variable_from_name(ir.function_name)
if func is None and candidates:
func = _find_function_from_parameter(ir, candidates)
func = _find_function_from_parameter(ir.arguments, candidates, False)

# In case of multiple binding to the same type
# TODO: this part might not be needed with _find_function_from_parameter
Expand Down Expand Up @@ -1502,7 +1538,7 @@ def convert_type_of_high_and_internal_level_call(ir: Operation, contract: Option
if f.name == ir.function_name and len(f.parameters) == len(ir.arguments)
]

func = _find_function_from_parameter(ir, candidates)
func = _find_function_from_parameter(ir.arguments, candidates, False)

if not func:
assert contract
Expand All @@ -1525,7 +1561,7 @@ def convert_type_of_high_and_internal_level_call(ir: Operation, contract: Option
# TODO: handle collision with multiple state variables/functions
func = contract.get_state_variable_from_name(ir.function_name)
if func is None and candidates:
func = _find_function_from_parameter(ir, candidates)
func = _find_function_from_parameter(ir.arguments, candidates, False)

# lowlelvel lookup needs to be done at last step
if not func:
Expand Down
52 changes: 47 additions & 5 deletions slither/solc_parsing/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from slither.core.declarations.contract import Contract
from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.declarations.function_contract import FunctionContract
from slither.core.solidity_types import ElementaryType, TypeAliasContract
from slither.core.solidity_types import ElementaryType, TypeAliasContract, Type
from slither.core.variables.state_variable import StateVariable
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
Expand Down Expand Up @@ -577,21 +577,26 @@ def analyze_state_variables(self):
except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing state variable {e}")

def analyze_using_for(self):
def analyze_using_for(self): # pylint: disable=too-many-branches
try:
for father in self._contract.inheritance:
self._contract.using_for.update(father.using_for)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@smonicas One thing I wasn't sure about is if there'd be a reason we also need to update child contracts' using_for_complete

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there is a need because the first use of using_for_complete is when converting to slithir which comes after parsing is finished


if self.is_compact_ast:
for using_for in self._usingForNotParsed:
lib_name = parse_type(using_for["libraryName"], self)
if "typeName" in using_for and using_for["typeName"]:
type_name = parse_type(using_for["typeName"], self)
else:
type_name = "*"
if type_name not in self._contract.using_for:
self._contract.using_for[type_name] = []
self._contract.using_for[type_name].append(lib_name)

if "libraryName" in using_for:
self._contract.using_for[type_name].append(
parse_type(using_for["libraryName"], self)
)
else:
# We have a list of functions. A function can be topLevel or a library function
self._analyze_function_list(using_for["functionList"], type_name)
else:
for using_for in self._usingForNotParsed:
children = using_for[self.get_children()]
Expand All @@ -609,6 +614,43 @@ def analyze_using_for(self):
except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing using for {e}")

def _analyze_function_list(self, function_list: List, type_name: Type):
for f in function_list:
function_name = f["function"]["name"]
if function_name.find(".") != -1:
# Library function
self._analyze_library_function(function_name, type_name)
else:
# Top level function
for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name:
self._contract.using_for[type_name].append(tl_function)

def _analyze_library_function(self, function_name: str, type_name: Type) -> None:
function_name_split = function_name.split(".")
# TODO this doesn't handle the case if there is an import with an alias
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we open a separate issue to track this, with an example?

# e.g. MyImport.MyLib.a
if len(function_name_split) == 2:
library_name = function_name_split[0]
function_name = function_name_split[1]
# Get the library function
found = False
for c in self.compilation_unit.contracts:
if found:
break
if c.name == library_name:
for f in c.functions:
if f.name == function_name:
self._contract.using_for[type_name].append(f)
found = True
break
if not found:
self.log_incorrect_parsing(f"Library function not found {function_name}")
else:
self.log_incorrect_parsing(
f"Expected library function instead received {function_name}"
)

def analyze_enums(self):
try:
for father in self._contract.inheritance:
Expand Down
Loading