Skip to content

Commit

Permalink
Merge pull request #1378 from crytic/dev-fix-usingfor
Browse files Browse the repository at this point in the history
Improve support using for directive
  • Loading branch information
montyly authored Jan 5, 2023
2 parents 16ebaf6 + ea681f9 commit f55cf6d
Show file tree
Hide file tree
Showing 40 changed files with 816 additions and 29 deletions.
6 changes: 6 additions & 0 deletions slither/core/compilation_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.using_for_top_level import UsingForTopLevel
from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.core.scope.scope import FileScope
from slither.core.variables.state_variable import StateVariable
Expand All @@ -41,6 +42,7 @@ def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit
self._enums_top_level: List[EnumTopLevel] = []
self._variables_top_level: List[TopLevelVariable] = []
self._functions_top_level: List[FunctionTopLevel] = []
self._using_for_top_level: List[UsingForTopLevel] = []
self._pragma_directives: List[Pragma] = []
self._import_directives: List[Import] = []
self._custom_errors: List[CustomError] = []
Expand Down Expand Up @@ -205,6 +207,10 @@ def variables_top_level(self) -> List[TopLevelVariable]:
def functions_top_level(self) -> List[FunctionTopLevel]:
return self._functions_top_level

@property
def using_for_top_level(self) -> List[UsingForTopLevel]:
return self._using_for_top_level

@property
def custom_errors(self) -> List[CustomError]:
return self._custom_errors
Expand Down
3 changes: 3 additions & 0 deletions slither/core/declarations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,8 @@
)
from .structure import Structure
from .enum_contract import EnumContract
from .enum_top_level import EnumTopLevel
from .structure_contract import StructureContract
from .structure_top_level import StructureTopLevel
from .function_contract import FunctionContract
from .function_top_level import FunctionTopLevel
22 changes: 22 additions & 0 deletions slither/core/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope

# The only str is "*"
self._using_for: Dict[Union[str, Type], List[Type]] = {}
self._using_for_complete: Dict[Union[str, Type], List[Type]] = None
self._kind: Optional[str] = None
self._is_interface: bool = False
self._is_library: bool = False
Expand Down Expand Up @@ -266,6 +267,27 @@ def events_as_dict(self) -> Dict[str, "Event"]:
def using_for(self) -> Dict[Union[str, Type], List[Type]]:
return self._using_for

@property
def using_for_complete(self) -> Dict[Union[str, Type], List[Type]]:
"""
Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive
"""

def _merge_using_for(uf1, uf2):
result = {**uf1, **uf2}
for key, value in result.items():
if key in uf1 and key in uf2:
result[key] = value + uf1[key]
return result

if self._using_for_complete is None:
result = self.using_for
top_level_using_for = self.file_scope.using_for_directives
for uftl in top_level_using_for:
result = _merge_using_for(result, uftl.using_for)
self._using_for_complete = result
return self._using_for_complete

# endregion
###################################################################################
###################################################################################
Expand Down
18 changes: 18 additions & 0 deletions slither/core/declarations/using_for_top_level.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import TYPE_CHECKING, List, Dict, Union

from slither.core.solidity_types.type import Type
from slither.core.declarations.top_level import TopLevel

if TYPE_CHECKING:
from slither.core.scope.scope import FileScope


class UsingForTopLevel(TopLevel):
def __init__(self, scope: "FileScope"):
super().__init__()
self._using_for: Dict[Union[str, Type], List[Type]] = {}
self.file_scope: "FileScope" = scope

@property
def using_for(self) -> Dict[Type, List[Type]]:
return self._using_for
5 changes: 5 additions & 0 deletions slither/core/scope/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.using_for_top_level import UsingForTopLevel
from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.core.solidity_types import TypeAlias
from slither.core.variables.top_level_variable import TopLevelVariable
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(self, filename: Filename):
# Because we parse the function signature later on
# So we simplify the logic and have the scope fields all populated
self.functions: Set[FunctionTopLevel] = set()
self.using_for_directives: Set[UsingForTopLevel] = set()
self.imports: Set[Import] = set()
self.pragmas: Set[Pragma] = set()
self.structures: Dict[str, StructureTopLevel] = {}
Expand Down Expand Up @@ -75,6 +77,9 @@ def add_accesible_scopes(self) -> bool:
if not new_scope.functions.issubset(self.functions):
self.functions |= new_scope.functions
learn_something = True
if not new_scope.using_for_directives.issubset(self.using_for_directives):
self.using_for_directives |= new_scope.using_for_directives
learn_something = True
if not new_scope.imports.issubset(self.imports):
self.imports |= new_scope.imports
learn_something = True
Expand Down
80 changes: 58 additions & 22 deletions slither/slithir/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.function_contract import FunctionContract
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder
from slither.core.declarations.solidity_variables import SolidityCustomRevert
from slither.core.expressions import Identifier, Literal
Expand Down Expand Up @@ -199,18 +200,22 @@ def _fits_under_byte(val: Union[int, str]) -> List[str]:
return [f"bytes{f}" for f in range(length, 33)] + ["bytes"]


def _find_function_from_parameter(ir: Call, candidates: List[Function]) -> Optional[Function]:
def _find_function_from_parameter(
arguments: List[Variable], candidates: List[Function], full_comparison: bool
) -> Optional[Function]:
"""
Look for a function in candidates that can be the target of the ir's call
Look for a function in candidates that can be the target based on the ir's call arguments
Try the implicit type conversion for uint/int/bytes. Constant values can be both uint/int
While variables stick to their base type, but can changed the size
While variables stick to their base type, but can changed the size.
If full_comparison is True it will do a comparison of all the arguments regardless if
the candidate remained is one.
:param ir:
:param arguments:
:param candidates:
:param full_comparison:
:return:
"""
arguments = ir.arguments
type_args: List[str]
for idx, arg in enumerate(arguments):
if isinstance(arg, (list,)):
Expand Down Expand Up @@ -258,7 +263,7 @@ def _find_function_from_parameter(ir: Call, candidates: List[Function]) -> Optio
not_found = False
candidates_kept.append(candidate)

if len(candidates_kept) == 1:
if len(candidates_kept) == 1 and not full_comparison:
return candidates_kept[0]
candidates = candidates_kept
if len(candidates) == 1:
Expand Down Expand Up @@ -503,7 +508,9 @@ def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals
# propagate the type
node_function = node.function
using_for = (
node_function.contract.using_for if isinstance(node_function, FunctionContract) else {}
node_function.contract.using_for_complete
if isinstance(node_function, FunctionContract)
else {}
)
if isinstance(ir, OperationWithLValue):
# Force assignment in case of missing previous correct type
Expand All @@ -530,9 +537,9 @@ def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals
if can_be_solidity_func(ir):
return convert_to_solidity_func(ir)

# convert library
# convert library or top level function
if t in using_for or "*" in using_for:
new_ir = convert_to_library(ir, node, using_for)
new_ir = convert_to_library_or_top_level(ir, node, using_for)
if new_ir:
return new_ir

Expand Down Expand Up @@ -881,7 +888,9 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: dis
# }
node_func = ins.node.function
using_for = (
node_func.contract.using_for if isinstance(node_func, FunctionContract) else {}
node_func.contract.using_for_complete
if isinstance(node_func, FunctionContract)
else {}
)

targeted_libraries = (
Expand All @@ -894,10 +903,14 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: dis
lib_contract_type.type, Contract
):
continue
lib_contract = lib_contract_type.type
for lib_func in lib_contract.functions:
if lib_func.name == ins.ori.variable_right:
candidates.append(lib_func)
if isinstance(lib_contract_type, FunctionContract):
# Using for with list of functions, this is the function called
candidates.append(lib_contract_type)
else:
lib_contract = lib_contract_type.type
for lib_func in lib_contract.functions:
if lib_func.name == ins.ori.variable_right:
candidates.append(lib_func)

if len(candidates) == 1:
lib_func = candidates[0]
Expand Down Expand Up @@ -1326,9 +1339,32 @@ def convert_to_pop(ir, node):
return ret


def look_for_library(contract, ir, using_for, t):
def look_for_library_or_top_level(contract, ir, using_for, t):
for destination in using_for[t]:
lib_contract = contract.file_scope.get_contract_from_name(str(destination))
if isinstance(destination, FunctionTopLevel) and destination.name == ir.function_name:
arguments = [ir.destination] + ir.arguments
if (
len(destination.parameters) == len(arguments)
and _find_function_from_parameter(arguments, [destination], True) is not None
):
internalcall = InternalCall(destination, ir.nbr_arguments, ir.lvalue, ir.type_call)
internalcall.set_expression(ir.expression)
internalcall.set_node(ir.node)
internalcall.arguments = [ir.destination] + ir.arguments
return_type = internalcall.function.return_type
if return_type:
if len(return_type) == 1:
internalcall.lvalue.set_type(return_type[0])
elif len(return_type) > 1:
internalcall.lvalue.set_type(return_type)
else:
internalcall.lvalue = None
return internalcall

if isinstance(destination, FunctionContract) and destination.contract.is_library:
lib_contract = destination.contract
else:
lib_contract = contract.file_scope.get_contract_from_name(str(destination))
if lib_contract:
lib_call = LibraryCall(
lib_contract,
Expand All @@ -1348,19 +1384,19 @@ def look_for_library(contract, ir, using_for, t):
return None


def convert_to_library(ir, node, using_for):
def convert_to_library_or_top_level(ir, node, using_for):
# We use contract_declarer, because Solidity resolve the library
# before resolving the inheritance.
# Though we could use .contract as libraries cannot be shadowed
contract = node.function.contract_declarer
t = ir.destination.type
if t in using_for:
new_ir = look_for_library(contract, ir, using_for, t)
new_ir = look_for_library_or_top_level(contract, ir, using_for, t)
if new_ir:
return new_ir

if "*" in using_for:
new_ir = look_for_library(contract, ir, using_for, "*")
new_ir = look_for_library_or_top_level(contract, ir, using_for, "*")
if new_ir:
return new_ir

Expand Down Expand Up @@ -1406,7 +1442,7 @@ def convert_type_library_call(ir: HighLevelCall, lib_contract: Contract):
# TODO: handle collision with multiple state variables/functions
func = lib_contract.get_state_variable_from_name(ir.function_name)
if func is None and candidates:
func = _find_function_from_parameter(ir, candidates)
func = _find_function_from_parameter(ir.arguments, candidates, False)

# In case of multiple binding to the same type
# TODO: this part might not be needed with _find_function_from_parameter
Expand Down Expand Up @@ -1502,7 +1538,7 @@ def convert_type_of_high_and_internal_level_call(ir: Operation, contract: Option
if f.name == ir.function_name and len(f.parameters) == len(ir.arguments)
]

func = _find_function_from_parameter(ir, candidates)
func = _find_function_from_parameter(ir.arguments, candidates, False)

if not func:
assert contract
Expand All @@ -1525,7 +1561,7 @@ def convert_type_of_high_and_internal_level_call(ir: Operation, contract: Option
# TODO: handle collision with multiple state variables/functions
func = contract.get_state_variable_from_name(ir.function_name)
if func is None and candidates:
func = _find_function_from_parameter(ir, candidates)
func = _find_function_from_parameter(ir.arguments, candidates, False)

# lowlelvel lookup needs to be done at last step
if not func:
Expand Down
52 changes: 47 additions & 5 deletions slither/solc_parsing/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from slither.core.declarations.contract import Contract
from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.declarations.function_contract import FunctionContract
from slither.core.solidity_types import ElementaryType, TypeAliasContract
from slither.core.solidity_types import ElementaryType, TypeAliasContract, Type
from slither.core.variables.state_variable import StateVariable
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
Expand Down Expand Up @@ -577,21 +577,26 @@ def analyze_state_variables(self):
except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing state variable {e}")

def analyze_using_for(self):
def analyze_using_for(self): # pylint: disable=too-many-branches
try:
for father in self._contract.inheritance:
self._contract.using_for.update(father.using_for)

if self.is_compact_ast:
for using_for in self._usingForNotParsed:
lib_name = parse_type(using_for["libraryName"], self)
if "typeName" in using_for and using_for["typeName"]:
type_name = parse_type(using_for["typeName"], self)
else:
type_name = "*"
if type_name not in self._contract.using_for:
self._contract.using_for[type_name] = []
self._contract.using_for[type_name].append(lib_name)

if "libraryName" in using_for:
self._contract.using_for[type_name].append(
parse_type(using_for["libraryName"], self)
)
else:
# We have a list of functions. A function can be topLevel or a library function
self._analyze_function_list(using_for["functionList"], type_name)
else:
for using_for in self._usingForNotParsed:
children = using_for[self.get_children()]
Expand All @@ -609,6 +614,43 @@ def analyze_using_for(self):
except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing using for {e}")

def _analyze_function_list(self, function_list: List, type_name: Type):
for f in function_list:
function_name = f["function"]["name"]
if function_name.find(".") != -1:
# Library function
self._analyze_library_function(function_name, type_name)
else:
# Top level function
for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name:
self._contract.using_for[type_name].append(tl_function)

def _analyze_library_function(self, function_name: str, type_name: Type) -> None:
function_name_split = function_name.split(".")
# TODO this doesn't handle the case if there is an import with an alias
# e.g. MyImport.MyLib.a
if len(function_name_split) == 2:
library_name = function_name_split[0]
function_name = function_name_split[1]
# Get the library function
found = False
for c in self.compilation_unit.contracts:
if found:
break
if c.name == library_name:
for f in c.functions:
if f.name == function_name:
self._contract.using_for[type_name].append(f)
found = True
break
if not found:
self.log_incorrect_parsing(f"Library function not found {function_name}")
else:
self.log_incorrect_parsing(
f"Expected library function instead received {function_name}"
)

def analyze_enums(self):
try:
for father in self._contract.inheritance:
Expand Down
Loading

0 comments on commit f55cf6d

Please sign in to comment.