diff --git a/examples/scripts/data_dependency.py b/examples/scripts/data_dependency.py index 4783947669..23c82cae11 100644 --- a/examples/scripts/data_dependency.py +++ b/examples/scripts/data_dependency.py @@ -18,6 +18,8 @@ contract = contracts[0] destination = contract.get_state_variable_from_name("destination") source = contract.get_state_variable_from_name("source") +assert source +assert destination print(f"{source} is dependent of {destination}: {is_dependent(source, destination, contract)}") assert not is_dependent(source, destination, contract) @@ -47,9 +49,11 @@ assert is_tainted(destination, contract) destination_indirect_1 = contract.get_state_variable_from_name("destination_indirect_1") +assert destination_indirect_1 print(f"{destination_indirect_1} is tainted {is_tainted(destination_indirect_1, contract)}") assert is_tainted(destination_indirect_1, contract) destination_indirect_2 = contract.get_state_variable_from_name("destination_indirect_2") +assert destination_indirect_2 print(f"{destination_indirect_2} is tainted {is_tainted(destination_indirect_2, contract)}") assert is_tainted(destination_indirect_2, contract) @@ -88,6 +92,8 @@ contract_derived = slither.get_contract_from_name("Derived")[0] destination = contract.get_state_variable_from_name("destination") source = contract.get_state_variable_from_name("source") +assert destination +assert source print(f"{destination} is dependent of {source}: {is_dependent(destination, source, contract)}") assert not is_dependent(destination, source, contract) diff --git a/examples/scripts/variable_in_condition.py b/examples/scripts/variable_in_condition.py index 43dcf41e7e..bde41424da 100644 --- a/examples/scripts/variable_in_condition.py +++ b/examples/scripts/variable_in_condition.py @@ -14,6 +14,7 @@ contract = contracts[0] # Get the variable var_a = contract.get_state_variable_from_name("a") +assert var_a # Get the functions reading the variable functions_reading_a = contract.get_functions_reading_from_variable(var_a) diff --git a/slither/__main__.py b/slither/__main__.py index 5d0dda9e04..ca61a82691 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -66,7 +66,7 @@ def process_single( args: argparse.Namespace, detector_classes: List[Type[AbstractDetector]], printer_classes: List[Type[AbstractPrinter]], -) -> Tuple[Slither, List[Dict], List[Dict], int]: +) -> Tuple[Slither, List[Dict], List[Output], int]: """ The core high-level code for running Slither static analysis. @@ -86,7 +86,7 @@ def process_all( args: argparse.Namespace, detector_classes: List[Type[AbstractDetector]], printer_classes: List[Type[AbstractPrinter]], -) -> Tuple[List[Slither], List[Dict], List[Dict], int]: +) -> Tuple[List[Slither], List[Dict], List[Output], int]: compilations = compile_all(target, **vars(args)) slither_instances = [] results_detectors = [] @@ -141,23 +141,6 @@ def _process( return slither, results_detectors, results_printers, analyzed_contracts_count -# TODO: delete me? -def process_from_asts( - filenames: List[str], - args: argparse.Namespace, - detector_classes: List[Type[AbstractDetector]], - printer_classes: List[Type[AbstractPrinter]], -) -> Tuple[Slither, List[Dict], List[Dict], int]: - all_contracts: List[str] = [] - - for filename in filenames: - with open(filename, encoding="utf8") as file_open: - contract_loaded = json.load(file_open) - all_contracts.append(contract_loaded["ast"]) - - return process_single(all_contracts, args, detector_classes, printer_classes) - - # endregion ################################################################################### ################################################################################### @@ -605,9 +588,6 @@ def parse_args( default=False, ) - # if the json is splitted in different files - parser.add_argument("--splitted", help=argparse.SUPPRESS, action="store_true", default=False) - # Disable the throw/catch on partial analyses parser.add_argument( "--disallow-partial", help=argparse.SUPPRESS, action="store_true", default=False @@ -623,7 +603,7 @@ def parse_args( args.filter_paths = parse_filter_paths(args) # Verify our json-type output is valid - args.json_types = set(args.json_types.split(",")) + args.json_types = set(args.json_types.split(",")) # type:ignore for json_type in args.json_types: if json_type not in JSON_OUTPUT_TYPES: raise Exception(f'Error: "{json_type}" is not a valid JSON result output type.') @@ -632,7 +612,9 @@ def parse_args( class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods - def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs + def __call__( + self, parser: Any, *args: Any, **kwargs: Any + ) -> None: # pylint: disable=signature-differs detectors, _ = get_detectors_and_printers() output_detectors(detectors) parser.exit() @@ -694,14 +676,14 @@ def __call__( class FormatterCryticCompile(logging.Formatter): - def format(self, record): + def format(self, record: logging.LogRecord) -> str: # for i, msg in enumerate(record.msg): if record.msg.startswith("Compilation warnings/errors on "): - txt = record.args[1] - txt = txt.split("\n") + txt = record.args[1] # type:ignore + txt = txt.split("\n") # type:ignore txt = [red(x) if "Error" in x else x for x in txt] txt = "\n".join(txt) - record.args = (record.args[0], txt) + record.args = (record.args[0], txt) # type:ignore return super().format(record) @@ -744,7 +726,7 @@ def main_impl( set_colorization_enabled(False if args.disable_color else sys.stdout.isatty()) # Define some variables for potential JSON output - json_results = {} + json_results: Dict[str, Any] = {} output_error = None outputting_json = args.json is not None outputting_json_stdout = args.json == "-" @@ -793,7 +775,7 @@ def main_impl( crytic_compile_error.setLevel(logging.INFO) results_detectors: List[Dict] = [] - results_printers: List[Dict] = [] + results_printers: List[Output] = [] try: filename = args.filename @@ -806,26 +788,17 @@ def main_impl( number_contracts = 0 slither_instances = [] - if args.splitted: + for filename in filenames: ( slither_instance, - results_detectors, - results_printers, - number_contracts, - ) = process_from_asts(filenames, args, detector_classes, printer_classes) + results_detectors_tmp, + results_printers_tmp, + number_contracts_tmp, + ) = process_single(filename, args, detector_classes, printer_classes) + number_contracts += number_contracts_tmp + results_detectors += results_detectors_tmp + results_printers += results_printers_tmp slither_instances.append(slither_instance) - else: - for filename in filenames: - ( - slither_instance, - results_detectors_tmp, - results_printers_tmp, - number_contracts_tmp, - ) = process_single(filename, args, detector_classes, printer_classes) - number_contracts += number_contracts_tmp - results_detectors += results_detectors_tmp - results_printers += results_printers_tmp - slither_instances.append(slither_instance) # Rely on CryticCompile to discern the underlying type of compilations. else: diff --git a/slither/analyses/data_dependency/data_dependency.py b/slither/analyses/data_dependency/data_dependency.py index b2a1546729..d133cd2dc1 100644 --- a/slither/analyses/data_dependency/data_dependency.py +++ b/slither/analyses/data_dependency/data_dependency.py @@ -2,8 +2,9 @@ Compute the data depenency between all the SSA variables """ from collections import defaultdict -from typing import Union, Set, Dict, TYPE_CHECKING +from typing import Union, Set, Dict, TYPE_CHECKING, List +from slither.core.cfg.node import Node from slither.core.declarations import ( Contract, Enum, @@ -12,11 +13,14 @@ SolidityVariable, SolidityVariableComposed, Structure, + FunctionContract, ) from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder +from slither.core.solidity_types.type import Type from slither.core.variables.top_level_variable import TopLevelVariable from slither.core.variables.variable import Variable from slither.slithir.operations import Index, OperationWithLValue, InternalCall, Operation +from slither.slithir.utils.utils import LVALUE from slither.slithir.variables import ( Constant, LocalIRVariable, @@ -26,12 +30,11 @@ TemporaryVariableSSA, TupleVariableSSA, ) -from slither.core.solidity_types.type import Type +from slither.slithir.variables.variable import SlithIRVariable if TYPE_CHECKING: from slither.core.compilation_unit import SlitherCompilationUnit - ################################################################################### ################################################################################### # region User APIs @@ -39,26 +42,39 @@ ################################################################################### -Variable_types = Union[Variable, SolidityVariable] +SUPPORTED_TYPES = Union[Variable, SolidityVariable] + +# TODO refactor the data deps to be better suited for top level function object +# Right now we allow to pass a node to ease the API, but we need something +# better +# The deps propagation for top level elements is also not working as expected +Context_types_API = Union[Contract, Function, Node] Context_types = Union[Contract, Function] def is_dependent( - variable: Variable_types, - source: Variable_types, - context: Context_types, + variable: SUPPORTED_TYPES, + source: SUPPORTED_TYPES, + context: Context_types_API, only_unprotected: bool = False, ) -> bool: """ + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not + Args: variable (Variable) source (Variable) - context (Contract|Function) + context (Contract|Function|Node). only_unprotected (bool): True only unprotected function are considered Returns: bool """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func + if isinstance(variable, Constant): return False if variable == source: @@ -74,12 +90,15 @@ def is_dependent( def is_dependent_ssa( - variable: Variable_types, - source: Variable_types, - context: Context_types, + variable: SUPPORTED_TYPES, + source: SUPPORTED_TYPES, + context: Context_types_API, only_unprotected: bool = False, ) -> bool: """ + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not + Args: variable (Variable) taint (Variable) @@ -88,7 +107,10 @@ def is_dependent_ssa( Returns: bool """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func context_dict = context.context if isinstance(variable, Constant): return False @@ -111,12 +133,15 @@ def is_dependent_ssa( def is_tainted( - variable: Variable_types, - context: Context_types, + variable: SUPPORTED_TYPES, + context: Context_types_API, only_unprotected: bool = False, ignore_generic_taint: bool = False, ) -> bool: """ + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not + Args: variable context (Contract|Function) @@ -124,7 +149,10 @@ def is_tainted( Returns: bool """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if isinstance(variable, Constant): return False @@ -138,12 +166,15 @@ def is_tainted( def is_tainted_ssa( - variable: Variable_types, - context: Context_types, + variable: SUPPORTED_TYPES, + context: Context_types_API, only_unprotected: bool = False, ignore_generic_taint: bool = False, -): +) -> bool: """ + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not + Args: variable context (Contract|Function) @@ -151,7 +182,10 @@ def is_tainted_ssa( Returns: bool """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if isinstance(variable, Constant): return False @@ -165,19 +199,24 @@ def is_tainted_ssa( def get_dependencies( - variable: Variable_types, - context: Context_types, + variable: SUPPORTED_TYPES, + context: Context_types_API, only_unprotected: bool = False, ) -> Set[Variable]: """ Return the variables for which `variable` depends on. + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not :param variable: The target :param context: Either a function (interprocedural) or a contract (inter transactional) :param only_unprotected: True if consider only protected functions :return: set(Variable) """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if only_unprotected: return context.context[KEY_NON_SSA_UNPROTECTED].get(variable, set()) @@ -185,16 +224,21 @@ def get_dependencies( def get_all_dependencies( - context: Context_types, only_unprotected: bool = False + context: Context_types_API, only_unprotected: bool = False ) -> Dict[Variable, Set[Variable]]: """ Return the dictionary of dependencies. + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not :param context: Either a function (interprocedural) or a contract (inter transactional) :param only_unprotected: True if consider only protected functions :return: Dict(Variable, set(Variable)) """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if only_unprotected: return context.context[KEY_NON_SSA_UNPROTECTED] @@ -202,19 +246,24 @@ def get_all_dependencies( def get_dependencies_ssa( - variable: Variable_types, - context: Context_types, + variable: SUPPORTED_TYPES, + context: Context_types_API, only_unprotected: bool = False, ) -> Set[Variable]: """ Return the variables for which `variable` depends on (SSA version). + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not :param variable: The target (must be SSA variable) :param context: Either a function (interprocedural) or a contract (inter transactional) :param only_unprotected: True if consider only protected functions :return: set(Variable) """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if only_unprotected: return context.context[KEY_SSA_UNPROTECTED].get(variable, set()) @@ -222,16 +271,21 @@ def get_dependencies_ssa( def get_all_dependencies_ssa( - context: Context_types, only_unprotected: bool = False + context: Context_types_API, only_unprotected: bool = False ) -> Dict[Variable, Set[Variable]]: """ Return the dictionary of dependencies. + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not :param context: Either a function (interprocedural) or a contract (inter transactional) :param only_unprotected: True if consider only protected functions :return: Dict(Variable, set(Variable)) """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if only_unprotected: return context.context[KEY_SSA_UNPROTECTED] @@ -341,13 +395,9 @@ def transitive_close_dependencies( while changed: changed = False to_add = defaultdict(set) - [ # pylint: disable=expression-not-assigned - [ + for key, items in context.context[context_key].items(): + for item in items & keys: to_add[key].update(context.context[context_key][item] - {key} - items) - for item in items & keys - ] - for key, items in context.context[context_key].items() - ] for k, v in to_add.items(): # Because we dont have any check on the update operation # We might update an empty set with an empty set @@ -366,20 +416,20 @@ def add_dependency(lvalue: Variable, function: Function, ir: Operation, is_prote function.context[KEY_SSA][lvalue] = set() if not is_protected: function.context[KEY_SSA_UNPROTECTED][lvalue] = set() + read: Union[List[Union[LVALUE, SolidityVariableComposed]], List[SlithIRVariable]] if isinstance(ir, Index): read = [ir.variable_left] - elif isinstance(ir, InternalCall): + elif isinstance(ir, InternalCall) and ir.function: read = ir.function.return_values_ssa else: read = ir.read - # pylint: disable=expression-not-assigned - [function.context[KEY_SSA][lvalue].add(v) for v in read if not isinstance(v, Constant)] + for v in read: + if not isinstance(v, Constant): + function.context[KEY_SSA][lvalue].add(v) if not is_protected: - [ - function.context[KEY_SSA_UNPROTECTED][lvalue].add(v) - for v in read - if not isinstance(v, Constant) - ] + for v in read: + if not isinstance(v, Constant): + function.context[KEY_SSA_UNPROTECTED][lvalue].add(v) def compute_dependency_function(function: Function) -> None: @@ -407,7 +457,7 @@ def compute_dependency_function(function: Function) -> None: ) -def convert_variable_to_non_ssa(v: Variable_types) -> Variable_types: +def convert_variable_to_non_ssa(v: SUPPORTED_TYPES) -> SUPPORTED_TYPES: if isinstance( v, ( @@ -438,10 +488,10 @@ def convert_variable_to_non_ssa(v: Variable_types) -> Variable_types: def convert_to_non_ssa( - data_depencies: Dict[Variable_types, Set[Variable_types]] -) -> Dict[Variable_types, Set[Variable_types]]: + data_depencies: Dict[SUPPORTED_TYPES, Set[SUPPORTED_TYPES]] +) -> Dict[SUPPORTED_TYPES, Set[SUPPORTED_TYPES]]: # Need to create new set() as its changed during iteration - ret: Dict[Variable_types, Set[Variable_types]] = {} + ret: Dict[SUPPORTED_TYPES, Set[SUPPORTED_TYPES]] = {} for (k, values) in data_depencies.items(): var = convert_variable_to_non_ssa(k) if not var in ret: diff --git a/slither/analyses/write/are_variables_written.py b/slither/analyses/write/are_variables_written.py index 1b430012f3..2f8f83063d 100644 --- a/slither/analyses/write/are_variables_written.py +++ b/slither/analyses/write/are_variables_written.py @@ -2,10 +2,10 @@ Detect if all the given variables are written in all the paths of the function """ from collections import defaultdict -from typing import Dict, Set, List +from typing import Dict, Set, List, Any, Optional from slither.core.cfg.node import NodeType, Node -from slither.core.declarations import SolidityFunction +from slither.core.declarations import SolidityFunction, Function from slither.core.variables.variable import Variable from slither.slithir.operations import ( Index, @@ -18,7 +18,7 @@ class State: # pylint: disable=too-few-public-methods - def __init__(self): + def __init__(self) -> None: # Map node -> list of variables set # Were each variables set represents a configuration of a path # If two paths lead to the exact same set of variables written, we dont need to explore both @@ -34,11 +34,11 @@ def __init__(self): # pylint: disable=too-many-branches def _visit( - node: Node, + node: Optional[Node], state: State, variables_written: Set[Variable], variables_to_write: List[Variable], -): +) -> List[Variable]: """ Explore all the nodes to look for values not written when the node's function return Fixpoint reaches if no new written variables are found @@ -51,6 +51,8 @@ def _visit( refs = {} variables_written = set(variables_written) + if not node: + return [] for ir in node.irs: if isinstance(ir, SolidityCall): # TODO convert the revert to a THROW node @@ -70,17 +72,20 @@ def _visit( if ir.lvalue and not isinstance(ir.lvalue, (TemporaryVariable, ReferenceVariable)): variables_written.add(ir.lvalue) - lvalue = ir.lvalue + lvalue: Any = ir.lvalue while isinstance(lvalue, ReferenceVariable): if lvalue not in refs: break - if refs[lvalue] and not isinstance( - refs[lvalue], (TemporaryVariable, ReferenceVariable) + refs_lvalues = refs[lvalue] + if ( + refs_lvalues + and isinstance(refs_lvalues, Variable) + and not isinstance(refs_lvalues, (TemporaryVariable, ReferenceVariable)) ): - variables_written.add(refs[lvalue]) - lvalue = refs[lvalue] + variables_written.add(refs_lvalues) + lvalue = refs_lvalues - ret = [] + ret: List[Variable] = [] if not node.sons and node.type not in [NodeType.THROW, NodeType.RETURN]: ret += [v for v in variables_to_write if v not in variables_written] @@ -96,7 +101,7 @@ def _visit( return ret -def are_variables_written(function, variables_to_write): +def are_variables_written(function: Function, variables_to_write: List[Variable]) -> List[Variable]: """ Return the list of variable that are not written at the end of the function diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index 7643b19b7c..5138e796aa 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -5,8 +5,7 @@ from typing import Optional, List, Set, Dict, Tuple, Union, TYPE_CHECKING from slither.all_exceptions import SlitherException -from slither.core.children.child_function import ChildFunction -from slither.core.declarations import Contract, Function +from slither.core.declarations import Contract, Function, FunctionContract from slither.core.declarations.solidity_variables import ( SolidityVariable, SolidityFunction, @@ -33,6 +32,7 @@ Return, Operation, ) +from slither.slithir.utils.utils import RVALUE from slither.slithir.variables import ( Constant, LocalIRVariable, @@ -106,7 +106,7 @@ class NodeType(Enum): # I am not sure why, but pylint reports a lot of "no-member" issue that are not real (Josselin) # pylint: disable=no-member -class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-methods +class Node(SourceMapping): # pylint: disable=too-many-public-methods """ Node class @@ -146,12 +146,12 @@ def __init__( self._node_id: int = node_id self._vars_written: List[Variable] = [] - self._vars_read: List[Variable] = [] + self._vars_read: List[Union[Variable, SolidityVariable]] = [] self._ssa_vars_written: List["SlithIRVariable"] = [] self._ssa_vars_read: List["SlithIRVariable"] = [] - self._internal_calls: List["Function"] = [] + self._internal_calls: List[Union["Function", "SolidityFunction"]] = [] self._solidity_calls: List[SolidityFunction] = [] self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls self._library_calls: List["LibraryCallType"] = [] @@ -172,7 +172,9 @@ def __init__( self._local_vars_read: List[LocalVariable] = [] self._local_vars_written: List[LocalVariable] = [] - self._slithir_vars: Set["SlithIRVariable"] = set() # non SSA + self._slithir_vars: Set[ + Union["SlithIRVariable", ReferenceVariable, TemporaryVariable, TupleVariable] + ] = set() # non SSA self._ssa_local_vars_read: List[LocalIRVariable] = [] self._ssa_local_vars_written: List[LocalIRVariable] = [] @@ -189,6 +191,7 @@ def __init__( self.scope: Union["Scope", "Function"] = scope self.file_scope: "FileScope" = file_scope + self._function: Optional["Function"] = None ################################################################################### ################################################################################### @@ -213,7 +216,7 @@ def type(self) -> NodeType: return self._node_type @type.setter - def type(self, new_type: NodeType): + def type(self, new_type: NodeType) -> None: self._node_type = new_type @property @@ -224,6 +227,13 @@ def will_return(self) -> bool: return True return False + def set_function(self, function: "Function") -> None: + self._function = function + + @property + def function(self) -> "Function": + return self._function + # endregion ################################################################################### ################################################################################### @@ -232,7 +242,7 @@ def will_return(self) -> bool: ################################################################################### @property - def variables_read(self) -> List[Variable]: + def variables_read(self) -> List[Union[Variable, SolidityVariable]]: """ list(Variable): Variables read (local/state/solidity) """ @@ -285,11 +295,13 @@ def variables_read_as_expression(self) -> List[Expression]: return self._expression_vars_read @variables_read_as_expression.setter - def variables_read_as_expression(self, exprs: List[Expression]): + def variables_read_as_expression(self, exprs: List[Expression]) -> None: self._expression_vars_read = exprs @property - def slithir_variables(self) -> List["SlithIRVariable"]: + def slithir_variables( + self, + ) -> List[Union["SlithIRVariable", ReferenceVariable, TemporaryVariable, TupleVariable]]: return list(self._slithir_vars) @property @@ -339,7 +351,7 @@ def variables_written_as_expression(self) -> List[Expression]: return self._expression_vars_written @variables_written_as_expression.setter - def variables_written_as_expression(self, exprs: List[Expression]): + def variables_written_as_expression(self, exprs: List[Expression]) -> None: self._expression_vars_written = exprs # endregion @@ -399,7 +411,7 @@ def external_calls_as_expressions(self) -> List[Expression]: return self._external_calls_as_expressions @external_calls_as_expressions.setter - def external_calls_as_expressions(self, exprs: List[Expression]): + def external_calls_as_expressions(self, exprs: List[Expression]) -> None: self._external_calls_as_expressions = exprs @property @@ -410,7 +422,7 @@ def internal_calls_as_expressions(self) -> List[Expression]: return self._internal_calls_as_expressions @internal_calls_as_expressions.setter - def internal_calls_as_expressions(self, exprs: List[Expression]): + def internal_calls_as_expressions(self, exprs: List[Expression]) -> None: self._internal_calls_as_expressions = exprs @property @@ -418,10 +430,10 @@ def calls_as_expression(self) -> List[Expression]: return list(self._expression_calls) @calls_as_expression.setter - def calls_as_expression(self, exprs: List[Expression]): + def calls_as_expression(self, exprs: List[Expression]) -> None: self._expression_calls = exprs - def can_reenter(self, callstack=None) -> bool: + def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Check if the node can re-enter Do not consider CREATE as potential re-enter, but check if the @@ -567,7 +579,7 @@ def add_father(self, father: "Node") -> None: """ self._fathers.append(father) - def set_fathers(self, fathers: List["Node"]): + def set_fathers(self, fathers: List["Node"]) -> None: """Set the father nodes Args: @@ -663,20 +675,20 @@ def irs_ssa(self) -> List[Operation]: return self._irs_ssa @irs_ssa.setter - def irs_ssa(self, irs): + def irs_ssa(self, irs: List[Operation]) -> None: self._irs_ssa = irs def add_ssa_ir(self, ir: Operation) -> None: """ Use to place phi operation """ - ir.set_node(self) + ir.set_node(self) # type: ignore self._irs_ssa.append(ir) def slithir_generation(self) -> None: if self.expression: expression = self.expression - self._irs = convert_expression(expression, self) + self._irs = convert_expression(expression, self) # type:ignore self._find_read_write_call() @@ -713,7 +725,7 @@ def dominators(self) -> Set["Node"]: return self._dominators @dominators.setter - def dominators(self, dom: Set["Node"]): + def dominators(self, dom: Set["Node"]) -> None: self._dominators = dom @property @@ -725,7 +737,7 @@ def immediate_dominator(self) -> Optional["Node"]: return self._immediate_dominator @immediate_dominator.setter - def immediate_dominator(self, idom: "Node"): + def immediate_dominator(self, idom: "Node") -> None: self._immediate_dominator = idom @property @@ -737,7 +749,7 @@ def dominance_frontier(self) -> Set["Node"]: return self._dominance_frontier @dominance_frontier.setter - def dominance_frontier(self, doms: Set["Node"]): + def dominance_frontier(self, doms: Set["Node"]) -> None: """ Returns: set(Node) @@ -789,6 +801,7 @@ def phi_origins_state_variables( def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node") -> None: if variable.name not in self._phi_origins_local_variables: + assert variable.name self._phi_origins_local_variables[variable.name] = (variable, set()) (v, nodes) = self._phi_origins_local_variables[variable.name] assert v == variable @@ -827,7 +840,8 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements if isinstance(ir, OperationWithLValue): var = ir.lvalue if var and self._is_valid_slithir_var(var): - self._slithir_vars.add(var) + # The type is checked by is_valid_slithir_var + self._slithir_vars.add(var) # type: ignore if not isinstance(ir, (Phi, Index, Member)): self._vars_read += [v for v in ir.read if self._is_non_slithir_var(v)] @@ -835,8 +849,9 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements if isinstance(var, ReferenceVariable): self._vars_read.append(var.points_to_origin) elif isinstance(ir, (Member, Index)): + # TODO investigate types for member variable left var = ir.variable_left if isinstance(ir, Member) else ir.variable_right - if self._is_non_slithir_var(var): + if var and self._is_non_slithir_var(var): self._vars_read.append(var) if isinstance(var, ReferenceVariable): origin = var.points_to_origin @@ -860,14 +875,21 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements self._internal_calls.append(ir.function) if isinstance(ir, LowLevelCall): assert isinstance(ir.destination, (Variable, SolidityVariable)) - self._low_level_calls.append((ir.destination, ir.function_name.value)) + self._low_level_calls.append((ir.destination, str(ir.function_name.value))) elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall): + # Todo investigate this if condition + # It does seem right to compare against a contract + # This might need a refactoring if isinstance(ir.destination.type, Contract): self._high_level_calls.append((ir.destination.type, ir.function)) elif ir.destination == SolidityVariable("this"): - self._high_level_calls.append((self.function.contract, ir.function)) + func = self.function + # Can't use this in a top level function + assert isinstance(func, FunctionContract) + self._high_level_calls.append((func.contract, ir.function)) else: try: + # Todo this part needs more tests and documentation self._high_level_calls.append((ir.destination.type.type, ir.function)) except AttributeError as error: # pylint: disable=raise-missing-from @@ -883,7 +905,9 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements self._vars_read = list(set(self._vars_read)) self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] - self._solidity_vars_read = [v for v in self._vars_read if isinstance(v, SolidityVariable)] + self._solidity_vars_read = [ + v_ for v_ in self._vars_read if isinstance(v_, SolidityVariable) + ] self._vars_written = list(set(self._vars_written)) self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] @@ -895,12 +919,15 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements @staticmethod def _convert_ssa(v: Variable) -> Optional[Union[StateVariable, LocalVariable]]: + non_ssa_var: Optional[Union[StateVariable, LocalVariable]] if isinstance(v, StateIRVariable): contract = v.contract + assert v.name non_ssa_var = contract.get_state_variable_from_name(v.name) return non_ssa_var assert isinstance(v, LocalIRVariable) function = v.function + assert v.name non_ssa_var = function.get_local_variable_from_name(v.name) return non_ssa_var @@ -921,10 +948,11 @@ def update_read_write_using_ssa(self) -> None: self._ssa_vars_read.append(origin) elif isinstance(ir, (Member, Index)): - if isinstance(ir.variable_right, (StateIRVariable, LocalIRVariable)): - self._ssa_vars_read.append(ir.variable_right) - if isinstance(ir.variable_right, ReferenceVariable): - origin = ir.variable_right.points_to_origin + variable_right: RVALUE = ir.variable_right + if isinstance(variable_right, (StateIRVariable, LocalIRVariable)): + self._ssa_vars_read.append(variable_right) + if isinstance(variable_right, ReferenceVariable): + origin = variable_right.points_to_origin if isinstance(origin, (StateIRVariable, LocalIRVariable)): self._ssa_vars_read.append(origin) @@ -944,20 +972,20 @@ def update_read_write_using_ssa(self) -> None: self._ssa_local_vars_read = [v for v in self._ssa_vars_read if isinstance(v, LocalVariable)] self._ssa_vars_written = list(set(self._ssa_vars_written)) self._ssa_state_vars_written = [ - v for v in self._ssa_vars_written if isinstance(v, StateVariable) + v for v in self._ssa_vars_written if v and isinstance(v, StateIRVariable) ] self._ssa_local_vars_written = [ - v for v in self._ssa_vars_written if isinstance(v, LocalVariable) + v for v in self._ssa_vars_written if v and isinstance(v, LocalIRVariable) ] vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read] vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written] - self._vars_read += [v for v in vars_read if v not in self._vars_read] + self._vars_read += [v_ for v_ in vars_read if v_ and v_ not in self._vars_read] self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] - self._vars_written += [v for v in vars_written if v not in self._vars_written] + self._vars_written += [v_ for v_ in vars_written if v_ and v_ not in self._vars_written] self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] @@ -974,7 +1002,7 @@ def __str__(self) -> str: additional_info += " " + str(self.expression) elif self.variable_declaration: additional_info += " " + str(self.variable_declaration) - txt = self._node_type.value + additional_info + txt = str(self._node_type.value) + additional_info return txt diff --git a/slither/core/children/child_contract.py b/slither/core/children/child_contract.py deleted file mode 100644 index 86f9dea532..0000000000 --- a/slither/core/children/child_contract.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import TYPE_CHECKING - -from slither.core.source_mapping.source_mapping import SourceMapping - -if TYPE_CHECKING: - from slither.core.declarations import Contract - - -class ChildContract(SourceMapping): - def __init__(self) -> None: - super().__init__() - self._contract = None - - def set_contract(self, contract: "Contract") -> None: - self._contract = contract - - @property - def contract(self) -> "Contract": - return self._contract diff --git a/slither/core/children/child_event.py b/slither/core/children/child_event.py index df91596e34..e69de29bb2 100644 --- a/slither/core/children/child_event.py +++ b/slither/core/children/child_event.py @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.declarations import Event - - -class ChildEvent: - def __init__(self) -> None: - super().__init__() - self._event = None - - def set_event(self, event: "Event"): - self._event = event - - @property - def event(self) -> "Event": - return self._event diff --git a/slither/core/children/child_expression.py b/slither/core/children/child_expression.py deleted file mode 100644 index 0064658c01..0000000000 --- a/slither/core/children/child_expression.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import TYPE_CHECKING, Union - -if TYPE_CHECKING: - from slither.core.expressions.expression import Expression - from slither.slithir.operations import Operation - - -class ChildExpression: - def __init__(self) -> None: - super().__init__() - self._expression = None - - def set_expression(self, expression: Union["Expression", "Operation"]) -> None: - self._expression = expression - - @property - def expression(self) -> Union["Expression", "Operation"]: - return self._expression diff --git a/slither/core/children/child_function.py b/slither/core/children/child_function.py deleted file mode 100644 index 5367320cab..0000000000 --- a/slither/core/children/child_function.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.declarations import Function - - -class ChildFunction: - def __init__(self) -> None: - super().__init__() - self._function = None - - def set_function(self, function: "Function") -> None: - self._function = function - - @property - def function(self) -> "Function": - return self._function diff --git a/slither/core/children/child_inheritance.py b/slither/core/children/child_inheritance.py deleted file mode 100644 index 30b32f6c19..0000000000 --- a/slither/core/children/child_inheritance.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.declarations import Contract - - -class ChildInheritance: - def __init__(self) -> None: - super().__init__() - self._contract_declarer = None - - def set_contract_declarer(self, contract: "Contract") -> None: - self._contract_declarer = contract - - @property - def contract_declarer(self) -> "Contract": - return self._contract_declarer diff --git a/slither/core/children/child_node.py b/slither/core/children/child_node.py deleted file mode 100644 index 8e6e1f0b5d..0000000000 --- a/slither/core/children/child_node.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.compilation_unit import SlitherCompilationUnit - from slither.core.cfg.node import Node - from slither.core.declarations import Function, Contract - - -class ChildNode: - def __init__(self) -> None: - super().__init__() - self._node = None - - def set_node(self, node: "Node") -> None: - self._node = node - - @property - def node(self) -> "Node": - return self._node - - @property - def function(self) -> "Function": - return self.node.function - - @property - def contract(self) -> "Contract": - return self.node.function.contract - - @property - def compilation_unit(self) -> "SlitherCompilationUnit": - return self.node.compilation_unit diff --git a/slither/core/children/child_structure.py b/slither/core/children/child_structure.py deleted file mode 100644 index abcb041c21..0000000000 --- a/slither/core/children/child_structure.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.declarations import Structure - - -class ChildStructure: - def __init__(self) -> None: - super().__init__() - self._structure = None - - def set_structure(self, structure: "Structure") -> None: - self._structure = structure - - @property - def structure(self) -> "Structure": - return self._structure diff --git a/slither/core/compilation_unit.py b/slither/core/compilation_unit.py index f54f08ab37..8d71674515 100644 --- a/slither/core/compilation_unit.py +++ b/slither/core/compilation_unit.py @@ -57,7 +57,7 @@ def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit self._storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {} - self._contract_with_missing_inheritance = set() + self._contract_with_missing_inheritance: Set[Contract] = set() self._source_units: Dict[int, str] = {} @@ -88,7 +88,8 @@ def compiler_version(self) -> CompilerVersion: @property def solc_version(self) -> str: - return self._crytic_compile_compilation_unit.compiler_version.version + # TODO: make version a non optional argument of compiler version in cc + return self._crytic_compile_compilation_unit.compiler_version.version # type:ignore @property def crytic_compile_compilation_unit(self) -> CompilationUnit: @@ -162,13 +163,14 @@ def add_modifier(self, modif: Modifier) -> None: @property def functions_and_modifiers(self) -> List[Function]: - return self.functions + self.modifiers + return self.functions + list(self.modifiers) def propagate_function_calls(self) -> None: for f in self.functions_and_modifiers: for node in f.nodes: for ir in node.irs_ssa: if isinstance(ir, InternalCall): + assert ir.function ir.function.add_reachable_from_node(node, ir) # endregion @@ -181,8 +183,8 @@ def propagate_function_calls(self) -> None: @property def state_variables(self) -> List[StateVariable]: if self._all_state_variables is None: - state_variables = [c.state_variables for c in self.contracts] - state_variables = [item for sublist in state_variables for item in sublist] + state_variabless = [c.state_variables for c in self.contracts] + state_variables = [item for sublist in state_variabless for item in sublist] self._all_state_variables = set(state_variables) return list(self._all_state_variables) @@ -229,7 +231,7 @@ def user_defined_value_types(self) -> Dict[str, TypeAliasTopLevel]: ################################################################################### @property - def contracts_with_missing_inheritance(self) -> Set: + def contracts_with_missing_inheritance(self) -> Set[Contract]: return self._contract_with_missing_inheritance # endregion @@ -266,6 +268,7 @@ def compute_storage_layout(self) -> None: if var.is_constant or var.is_immutable: continue + assert var.type size, new_slot = var.type.storage_size if new_slot: @@ -285,7 +288,7 @@ def compute_storage_layout(self) -> None: else: offset += size - def storage_layout_of(self, contract, var) -> Tuple[int, int]: + def storage_layout_of(self, contract: Contract, var: StateVariable) -> Tuple[int, int]: return self._storage_layouts[contract.name][var.canonical_name] # endregion diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index b6a0b3d243..fd8f761c6b 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -49,6 +49,9 @@ LOGGER = logging.getLogger("Contract") +USING_FOR_KEY = Union[str, Type] +USING_FOR_ITEM = List[Union[Type, Function]] + class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ @@ -80,8 +83,8 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope self._custom_errors: Dict[str, "CustomErrorContract"] = {} # The only str is "*" - self._using_for: Dict[Union[str, Type], List[Type]] = {} - self._using_for_complete: Dict[Union[str, Type], List[Type]] = None + self._using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = {} + self._using_for_complete: Optional[Dict[USING_FOR_KEY, USING_FOR_ITEM]] = None self._kind: Optional[str] = None self._is_interface: bool = False self._is_library: bool = False @@ -126,7 +129,7 @@ def name(self) -> str: return self._name @name.setter - def name(self, name: str): + def name(self, name: str) -> None: self._name = name @property @@ -136,7 +139,7 @@ def id(self) -> int: return self._id @id.setter - def id(self, new_id): + def id(self, new_id: int) -> None: """Unique id.""" self._id = new_id @@ -149,7 +152,7 @@ def contract_kind(self) -> Optional[str]: return self._kind @contract_kind.setter - def contract_kind(self, kind): + def contract_kind(self, kind: str) -> None: self._kind = kind @property @@ -157,7 +160,7 @@ def is_interface(self) -> bool: return self._is_interface @is_interface.setter - def is_interface(self, is_interface: bool): + def is_interface(self, is_interface: bool) -> None: self._is_interface = is_interface @property @@ -165,7 +168,7 @@ def is_library(self) -> bool: return self._is_library @is_library.setter - def is_library(self, is_library: bool): + def is_library(self, is_library: bool) -> None: self._is_library = is_library @property @@ -302,16 +305,18 @@ def events_as_dict(self) -> Dict[str, "Event"]: ################################################################################### @property - def using_for(self) -> Dict[Union[str, Type], List[Type]]: + def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: return self._using_for @property - def using_for_complete(self) -> Dict[Union[str, Type], List[Type]]: + def using_for_complete(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: """ Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive """ - def _merge_using_for(uf1, uf2): + def _merge_using_for( + uf1: Dict[USING_FOR_KEY, USING_FOR_ITEM], uf2: Dict[USING_FOR_KEY, USING_FOR_ITEM] + ) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: result = {**uf1, **uf2} for key, value in result.items(): if key in uf1 and key in uf2: @@ -491,7 +496,7 @@ def constructors_declared(self) -> Optional["Function"]: ) @property - def constructors(self) -> List["Function"]: + def constructors(self) -> List["FunctionContract"]: """ Return the list of constructors (including inherited) """ @@ -560,14 +565,14 @@ def functions(self) -> List["FunctionContract"]: """ return list(self._functions.values()) - def available_functions_as_dict(self) -> Dict[str, "FunctionContract"]: + def available_functions_as_dict(self) -> Dict[str, "Function"]: if self._available_functions_as_dict is None: self._available_functions_as_dict = { f.full_name: f for f in self._functions.values() if not f.is_shadowed } return self._available_functions_as_dict - def add_function(self, func: "FunctionContract"): + def add_function(self, func: "FunctionContract") -> None: self._functions[func.canonical_name] = func def set_functions(self, functions: Dict[str, "FunctionContract"]) -> None: @@ -735,7 +740,7 @@ def derived_contracts(self) -> List["Contract"]: list(Contract): Return the list of contracts derived from self """ candidates = self.compilation_unit.contracts - return [c for c in candidates if self in c.inheritance] + return [c for c in candidates if self in c.inheritance] # type: ignore # endregion ################################################################################### @@ -891,7 +896,7 @@ def get_enum_from_name(self, enum_name: str) -> Optional["Enum"]: """ return next((e for e in self.enums if e.name == enum_name), None) - def get_enum_from_canonical_name(self, enum_name) -> Optional["Enum"]: + def get_enum_from_canonical_name(self, enum_name: str) -> Optional["Enum"]: """ Return an enum from a canonical name Args: @@ -992,7 +997,9 @@ def all_high_level_calls(self) -> List["HighLevelCallType"]: ################################################################################### ################################################################################### - def get_summary(self, include_shadowed=True) -> Tuple[str, List[str], List[str], List, List]: + def get_summary( + self, include_shadowed: bool = True + ) -> Tuple[str, List[str], List[str], List, List]: """Return the function summary :param include_shadowed: boolean to indicate if shadowed functions should be included (default True) @@ -1245,7 +1252,7 @@ def is_truffle_migration(self) -> bool: @property def is_test(self) -> bool: - return is_test_contract(self) or self.is_truffle_migration + return is_test_contract(self) or self.is_truffle_migration # type: ignore # endregion ################################################################################### @@ -1255,7 +1262,7 @@ def is_test(self) -> bool: ################################################################################### def update_read_write_using_ssa(self) -> None: - for function in self.functions + self.modifiers: + for function in self.functions + list(self.modifiers): function.update_read_write_using_ssa() # endregion @@ -1290,7 +1297,7 @@ def is_upgradeable(self) -> bool: return self._is_upgradeable @is_upgradeable.setter - def is_upgradeable(self, upgradeable: bool): + def is_upgradeable(self, upgradeable: bool) -> None: self._is_upgradeable = upgradeable @property @@ -1319,7 +1326,7 @@ def is_upgradeable_proxy(self) -> bool: return self._is_upgradeable_proxy @is_upgradeable_proxy.setter - def is_upgradeable_proxy(self, upgradeable_proxy: bool): + def is_upgradeable_proxy(self, upgradeable_proxy: bool) -> None: self._is_upgradeable_proxy = upgradeable_proxy @property @@ -1327,7 +1334,7 @@ def upgradeable_version(self) -> Optional[str]: return self._upgradeable_version @upgradeable_version.setter - def upgradeable_version(self, version_name: str): + def upgradeable_version(self, version_name: str) -> None: self._upgradeable_version = version_name # endregion @@ -1346,7 +1353,7 @@ def is_incorrectly_constructed(self) -> bool: return self._is_incorrectly_parsed @is_incorrectly_constructed.setter - def is_incorrectly_constructed(self, incorrect: bool): + def is_incorrectly_constructed(self, incorrect: bool) -> None: self._is_incorrectly_parsed = incorrect def add_constructor_variables(self) -> None: @@ -1358,8 +1365,8 @@ def add_constructor_variables(self) -> None: constructor_variable = FunctionContract(self.compilation_unit) constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES) - constructor_variable.set_contract(self) - constructor_variable.set_contract_declarer(self) + constructor_variable.set_contract(self) # type: ignore + constructor_variable.set_contract_declarer(self) # type: ignore constructor_variable.set_visibility("internal") # For now, source mapping of the constructor variable is the whole contract # Could be improved with a targeted source mapping @@ -1390,8 +1397,8 @@ def add_constructor_variables(self) -> None: constructor_variable.set_function_type( FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES ) - constructor_variable.set_contract(self) - constructor_variable.set_contract_declarer(self) + constructor_variable.set_contract(self) # type: ignore + constructor_variable.set_contract_declarer(self) # type: ignore constructor_variable.set_visibility("internal") # For now, source mapping of the constructor variable is the whole contract # Could be improved with a targeted source mapping @@ -1472,22 +1479,23 @@ def convert_expression_to_slithir_ssa(self) -> None: all_ssa_state_variables_instances[v.canonical_name] = new_var self._initial_state_variables.append(new_var) - for func in self.functions + self.modifiers: + for func in self.functions + list(self.modifiers): func.generate_slithir_ssa(all_ssa_state_variables_instances) def fix_phi(self) -> None: - last_state_variables_instances = {} - initial_state_variables_instances = {} + last_state_variables_instances: Dict[str, List["StateVariable"]] = {} + initial_state_variables_instances: Dict[str, "StateVariable"] = {} for v in self._initial_state_variables: last_state_variables_instances[v.canonical_name] = [] initial_state_variables_instances[v.canonical_name] = v - for func in self.functions + self.modifiers: + for func in self.functions + list(self.modifiers): result = func.get_last_ssa_state_variables_instances() for variable_name, instances in result.items(): - last_state_variables_instances[variable_name] += instances + # TODO: investigate the next operation + last_state_variables_instances[variable_name] += list(instances) - for func in self.functions + self.modifiers: + for func in self.functions + list(self.modifiers): func.fix_phi(last_state_variables_instances, initial_state_variables_instances) # endregion @@ -1497,7 +1505,7 @@ def fix_phi(self) -> None: ################################################################################### ################################################################################### - def __eq__(self, other: SourceMapping) -> bool: + def __eq__(self, other: Any) -> bool: if isinstance(other, str): return other == self.name return NotImplemented @@ -1511,6 +1519,6 @@ def __str__(self) -> str: return self.name def __hash__(self) -> int: - return self._id + return self._id # type:ignore # endregion diff --git a/slither/core/declarations/contract_level.py b/slither/core/declarations/contract_level.py new file mode 100644 index 0000000000..9b81e6d337 --- /dev/null +++ b/slither/core/declarations/contract_level.py @@ -0,0 +1,29 @@ +from typing import TYPE_CHECKING, Optional + +from slither.core.source_mapping.source_mapping import SourceMapping + +if TYPE_CHECKING: + from slither.core.declarations import Contract + + +class ContractLevel(SourceMapping): + """ + This class is used to represent objects that are at the contract level + The opposite is TopLevel + + """ + + def __init__(self) -> None: + super().__init__() + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._contract: Optional["Contract"] = None + + def set_contract(self, contract: "Contract") -> None: + self._contract = contract + + @property + def contract(self) -> "Contract": + assert self._contract + return self._contract diff --git a/slither/core/declarations/custom_error.py b/slither/core/declarations/custom_error.py index 5e851c8da1..7e78748c60 100644 --- a/slither/core/declarations/custom_error.py +++ b/slither/core/declarations/custom_error.py @@ -1,4 +1,4 @@ -from typing import List, TYPE_CHECKING, Optional, Type, Union +from typing import List, TYPE_CHECKING, Optional, Type from slither.core.solidity_types import UserDefinedType from slither.core.source_mapping.source_mapping import SourceMapping @@ -42,7 +42,7 @@ def compilation_unit(self) -> "SlitherCompilationUnit": ################################################################################### @staticmethod - def _convert_type_for_solidity_signature(t: Optional[Union[Type, List[Type]]]) -> str: + def _convert_type_for_solidity_signature(t: Optional[Type]) -> str: # pylint: disable=import-outside-toplevel from slither.core.declarations import Contract @@ -51,7 +51,7 @@ def _convert_type_for_solidity_signature(t: Optional[Union[Type, List[Type]]]) - return str(t) @property - def solidity_signature(self) -> Optional[str]: + def solidity_signature(self) -> str: """ Return a signature following the Solidity Standard Contract and converted into address @@ -63,7 +63,7 @@ def solidity_signature(self) -> Optional[str]: # (set_solidity_sig was not called before find_variable) if self._solidity_signature is None: raise ValueError("Custom Error not yet built") - return self._solidity_signature + return self._solidity_signature # type: ignore def set_solidity_sig(self) -> None: """ @@ -72,7 +72,7 @@ def set_solidity_sig(self) -> None: Returns: """ - parameters = [x.type for x in self.parameters] + parameters = [x.type for x in self.parameters if x.type] self._full_name = self.name + "(" + ",".join(map(str, parameters)) + ")" solidity_parameters = map(self._convert_type_for_solidity_signature, parameters) self._solidity_signature = self.name + "(" + ",".join(solidity_parameters) + ")" diff --git a/slither/core/declarations/custom_error_contract.py b/slither/core/declarations/custom_error_contract.py index a96f120575..cd279a3a62 100644 --- a/slither/core/declarations/custom_error_contract.py +++ b/slither/core/declarations/custom_error_contract.py @@ -1,9 +1,15 @@ -from slither.core.children.child_contract import ChildContract +from typing import TYPE_CHECKING +from slither.core.declarations.contract_level import ContractLevel + + from slither.core.declarations.custom_error import CustomError +if TYPE_CHECKING: + from slither.core.declarations import Contract + -class CustomErrorContract(CustomError, ChildContract): - def is_declared_by(self, contract): +class CustomErrorContract(CustomError, ContractLevel): + def is_declared_by(self, contract: "Contract") -> bool: """ Check if the element is declared by the contract :param contract: diff --git a/slither/core/declarations/custom_error_top_level.py b/slither/core/declarations/custom_error_top_level.py index 29a9fd41ac..64a6a85353 100644 --- a/slither/core/declarations/custom_error_top_level.py +++ b/slither/core/declarations/custom_error_top_level.py @@ -9,6 +9,6 @@ class CustomErrorTopLevel(CustomError, TopLevel): - def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"): + def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None: super().__init__(compilation_unit) self.file_scope: "FileScope" = scope diff --git a/slither/core/declarations/enum_contract.py b/slither/core/declarations/enum_contract.py index 46168d1073..2e51ae5116 100644 --- a/slither/core/declarations/enum_contract.py +++ b/slither/core/declarations/enum_contract.py @@ -1,13 +1,13 @@ from typing import TYPE_CHECKING -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations import Enum if TYPE_CHECKING: from slither.core.declarations import Contract -class EnumContract(Enum, ChildContract): +class EnumContract(Enum, ContractLevel): def is_declared_by(self, contract: "Contract") -> bool: """ Check if the element is declared by the contract diff --git a/slither/core/declarations/event.py b/slither/core/declarations/event.py index d616679a22..9d42ac224b 100644 --- a/slither/core/declarations/event.py +++ b/slither/core/declarations/event.py @@ -1,6 +1,6 @@ from typing import List, Tuple, TYPE_CHECKING -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.event_variable import EventVariable @@ -8,7 +8,7 @@ from slither.core.declarations import Contract -class Event(ChildContract, SourceMapping): +class Event(ContractLevel, SourceMapping): def __init__(self) -> None: super().__init__() self._name = None diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index c383fc99b0..e778019617 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -47,7 +47,6 @@ from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.scope.scope import FileScope from slither.slithir.variables.state_variable import StateIRVariable - from slither.core.declarations.function_contract import FunctionContract LOGGER = logging.getLogger("Function") ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"]) @@ -298,7 +297,7 @@ def contains_assembly(self) -> bool: def contains_assembly(self, c: bool): self._contains_assembly = c - def can_reenter(self, callstack: Optional[List["FunctionContract"]] = None) -> bool: + def can_reenter(self, callstack: Optional[List[Union["Function", "Variable"]]] = None) -> bool: """ Check if the function can re-enter Follow internal calls. @@ -1720,8 +1719,8 @@ def _unchange_phi(ir: "Operation") -> bool: def fix_phi( self, - last_state_variables_instances: Dict[str, List["StateIRVariable"]], - initial_state_variables_instances: Dict[str, "StateIRVariable"], + last_state_variables_instances: Dict[str, List["StateVariable"]], + initial_state_variables_instances: Dict[str, "StateVariable"], ) -> None: from slither.slithir.operations import InternalCall, PhiCallback from slither.slithir.variables import Constant, StateIRVariable diff --git a/slither/core/declarations/function_contract.py b/slither/core/declarations/function_contract.py index 01077353b1..8f68ab7b60 100644 --- a/slither/core/declarations/function_contract.py +++ b/slither/core/declarations/function_contract.py @@ -1,10 +1,9 @@ """ Function module """ -from typing import Dict, TYPE_CHECKING, List, Tuple +from typing import Dict, TYPE_CHECKING, List, Tuple, Optional -from slither.core.children.child_contract import ChildContract -from slither.core.children.child_inheritance import ChildInheritance +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations import Function from slither.utils.code_complexity import compute_cyclomatic_complexity @@ -15,9 +14,31 @@ from slither.core.declarations import Contract from slither.core.scope.scope import FileScope from slither.slithir.variables.state_variable import StateIRVariable + from slither.core.compilation_unit import SlitherCompilationUnit -class FunctionContract(Function, ChildContract, ChildInheritance): +class FunctionContract(Function, ContractLevel): + def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None: + super().__init__(compilation_unit) + self._contract_declarer: Optional["Contract"] = None + + def set_contract_declarer(self, contract: "Contract") -> None: + self._contract_declarer = contract + + @property + def contract_declarer(self) -> "Contract": + """ + Return the contract where this function was declared. Only functions have both a contract, and contract_declarer + This is because we need to have separate representation of the function depending of the contract's context + For example a function calling super.f() will generate different IR depending on the current contract's inheritance + + Returns: + The contract where this function was declared + """ + + assert self._contract_declarer + return self._contract_declarer + @property def canonical_name(self) -> str: """ diff --git a/slither/core/declarations/solidity_variables.py b/slither/core/declarations/solidity_variables.py index 9569cde939..f0e903d7b2 100644 --- a/slither/core/declarations/solidity_variables.py +++ b/slither/core/declarations/solidity_variables.py @@ -82,7 +82,7 @@ } -def solidity_function_signature(name): +def solidity_function_signature(name: str) -> str: """ Return the function signature (containing the return value) It is useful if a solidity function is used as a pointer @@ -106,7 +106,7 @@ def _check_name(self, name: str) -> None: # pylint: disable=no-self-use assert name in SOLIDITY_VARIABLES or name.endswith(("_slot", "_offset")) @property - def state_variable(self): + def state_variable(self) -> str: if self._name.endswith("_slot"): return self._name[:-5] if self._name.endswith("_offset"): @@ -125,7 +125,7 @@ def type(self) -> ElementaryType: def __str__(self) -> str: return self._name - def __eq__(self, other: SourceMapping) -> bool: + def __eq__(self, other: Any) -> bool: return self.__class__ == other.__class__ and self.name == other.name def __hash__(self) -> int: @@ -182,13 +182,13 @@ def return_type(self) -> List[Union[TypeInformation, ElementaryType]]: return self._return_type @return_type.setter - def return_type(self, r: List[Union[TypeInformation, ElementaryType]]): + def return_type(self, r: List[Union[TypeInformation, ElementaryType]]) -> None: self._return_type = r def __str__(self) -> str: return self._name - def __eq__(self, other: "SolidityFunction") -> bool: + def __eq__(self, other: Any) -> bool: return self.__class__ == other.__class__ and self.name == other.name def __hash__(self) -> int: @@ -201,7 +201,7 @@ def __init__(self, custom_error: CustomError) -> None: # pylint: disable=super- self._custom_error = custom_error self._return_type: List[Union[TypeInformation, ElementaryType]] = [] - def __eq__(self, other: Union["SolidityCustomRevert", SolidityFunction]) -> bool: + def __eq__(self, other: Any) -> bool: return ( self.__class__ == other.__class__ and self.name == other.name diff --git a/slither/core/declarations/structure_contract.py b/slither/core/declarations/structure_contract.py index aaf660e1ef..c9d05ce4ef 100644 --- a/slither/core/declarations/structure_contract.py +++ b/slither/core/declarations/structure_contract.py @@ -1,8 +1,8 @@ -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations import Structure -class StructureContract(Structure, ChildContract): +class StructureContract(Structure, ContractLevel): def is_declared_by(self, contract): """ Check if the element is declared by the contract diff --git a/slither/core/declarations/top_level.py b/slither/core/declarations/top_level.py index 15facf2f99..01e6f6dfde 100644 --- a/slither/core/declarations/top_level.py +++ b/slither/core/declarations/top_level.py @@ -2,4 +2,8 @@ class TopLevel(SourceMapping): - pass + """ + This class is used to represent objects that are at the top level + The opposite is ContractLevel + + """ diff --git a/slither/core/declarations/using_for_top_level.py b/slither/core/declarations/using_for_top_level.py index 27d1f90e47..edf846a5b1 100644 --- a/slither/core/declarations/using_for_top_level.py +++ b/slither/core/declarations/using_for_top_level.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, List, Dict, Union +from slither.core.declarations.contract import USING_FOR_KEY, USING_FOR_ITEM from slither.core.solidity_types.type import Type from slither.core.declarations.top_level import TopLevel @@ -14,5 +15,5 @@ def __init__(self, scope: "FileScope") -> None: self.file_scope: "FileScope" = scope @property - def using_for(self) -> Dict[Union[str, Type], List[Type]]: + def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: return self._using_for diff --git a/slither/core/dominators/utils.py b/slither/core/dominators/utils.py index ca5c51282e..4dd55749d9 100644 --- a/slither/core/dominators/utils.py +++ b/slither/core/dominators/utils.py @@ -95,4 +95,5 @@ def compute_dominance_frontier(nodes: List["Node"]) -> None: runner.dominance_frontier = runner.dominance_frontier.union({node}) while runner != node.immediate_dominator: runner.dominance_frontier = runner.dominance_frontier.union({node}) + assert runner.immediate_dominator runner = runner.immediate_dominator diff --git a/slither/core/expressions/assignment_operation.py b/slither/core/expressions/assignment_operation.py index 22aba57fb0..3f08aefd7c 100644 --- a/slither/core/expressions/assignment_operation.py +++ b/slither/core/expressions/assignment_operation.py @@ -2,7 +2,6 @@ from enum import Enum from typing import Optional, TYPE_CHECKING, List -from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.exceptions import SlitherCoreError @@ -78,7 +77,7 @@ def __str__(self) -> str: raise SlitherCoreError(f"str: Unknown operation type {self})") -class AssignmentOperation(ExpressionTyped): +class AssignmentOperation(Expression): def __init__( self, left_expression: Expression, @@ -91,7 +90,7 @@ def __init__( super().__init__() left_expression.set_lvalue() self._expressions = [left_expression, right_expression] - self._type: Optional["AssignmentOperationType"] = expression_type + self._type: AssignmentOperationType = expression_type self._expression_return_type: Optional["Type"] = expression_return_type @property diff --git a/slither/core/expressions/binary_operation.py b/slither/core/expressions/binary_operation.py index a3d435075c..a395d07cf8 100644 --- a/slither/core/expressions/binary_operation.py +++ b/slither/core/expressions/binary_operation.py @@ -2,7 +2,6 @@ from enum import Enum from typing import List -from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.exceptions import SlitherCoreError @@ -148,7 +147,7 @@ def __str__(self) -> str: # pylint: disable=too-many-branches raise SlitherCoreError(f"str: Unknown operation type {self})") -class BinaryOperation(ExpressionTyped): +class BinaryOperation(Expression): def __init__( self, left_expression: Expression, diff --git a/slither/core/expressions/call_expression.py b/slither/core/expressions/call_expression.py index 1dbc4074a8..6708dda7e2 100644 --- a/slither/core/expressions/call_expression.py +++ b/slither/core/expressions/call_expression.py @@ -22,7 +22,7 @@ def call_value(self) -> Optional[Expression]: return self._value @call_value.setter - def call_value(self, v): + def call_value(self, v: Optional[Expression]) -> None: self._value = v @property @@ -30,15 +30,15 @@ def call_gas(self) -> Optional[Expression]: return self._gas @call_gas.setter - def call_gas(self, gas): + def call_gas(self, gas: Optional[Expression]) -> None: self._gas = gas @property - def call_salt(self): + def call_salt(self) -> Optional[Expression]: return self._salt @call_salt.setter - def call_salt(self, salt): + def call_salt(self, salt: Optional[Expression]) -> None: self._salt = salt @property diff --git a/slither/core/expressions/conditional_expression.py b/slither/core/expressions/conditional_expression.py index 818425ba1c..3c0afdb4af 100644 --- a/slither/core/expressions/conditional_expression.py +++ b/slither/core/expressions/conditional_expression.py @@ -42,7 +42,7 @@ def else_expression(self) -> Expression: def then_expression(self) -> Expression: return self._then_expression - def __str__(self): + def __str__(self) -> str: return ( "if " + str(self._if_expression) diff --git a/slither/core/expressions/expression_typed.py b/slither/core/expressions/expression_typed.py deleted file mode 100644 index 2bf3fe39dc..0000000000 --- a/slither/core/expressions/expression_typed.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Optional, TYPE_CHECKING - -from slither.core.expressions.expression import Expression - -if TYPE_CHECKING: - from slither.core.solidity_types.type import Type - - -class ExpressionTyped(Expression): # pylint: disable=too-few-public-methods - def __init__(self) -> None: - super().__init__() - self._type: Optional["Type"] = None - - @property - def type(self) -> Optional["Type"]: - return self._type - - @type.setter - def type(self, new_type: "Type"): - self._type = new_type diff --git a/slither/core/expressions/identifier.py b/slither/core/expressions/identifier.py index 0b10c56159..8ffabad894 100644 --- a/slither/core/expressions/identifier.py +++ b/slither/core/expressions/identifier.py @@ -1,18 +1,80 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Union + +from slither.core.declarations.contract_level import ContractLevel +from slither.core.declarations.top_level import TopLevel +from slither.core.expressions.expression import Expression +from slither.core.variables.variable import Variable -from slither.core.expressions.expression_typed import ExpressionTyped if TYPE_CHECKING: - from slither.core.variables.variable import Variable + from slither.core.solidity_types.type import Type + from slither.core.declarations import Contract, SolidityVariable, SolidityFunction + from slither.solc_parsing.yul.evm_functions import YulBuiltin -class Identifier(ExpressionTyped): - def __init__(self, value) -> None: +class Identifier(Expression): + def __init__( + self, + value: Union[ + Variable, + "TopLevel", + "ContractLevel", + "Contract", + "SolidityVariable", + "SolidityFunction", + "YulBuiltin", + ], + ) -> None: super().__init__() - self._value: "Variable" = value + + # pylint: disable=import-outside-toplevel + from slither.core.declarations import Contract, SolidityVariable, SolidityFunction + from slither.solc_parsing.yul.evm_functions import YulBuiltin + + assert isinstance( + value, + ( + Variable, + TopLevel, + ContractLevel, + Contract, + SolidityVariable, + SolidityFunction, + YulBuiltin, + ), + ) + + self._value: Union[ + Variable, + "TopLevel", + "ContractLevel", + "Contract", + "SolidityVariable", + "SolidityFunction", + "YulBuiltin", + ] = value + self._type: Optional["Type"] = None + + @property + def type(self) -> Optional["Type"]: + return self._type + + @type.setter + def type(self, new_type: "Type") -> None: + self._type = new_type @property - def value(self) -> "Variable": + def value( + self, + ) -> Union[ + Variable, + "TopLevel", + "ContractLevel", + "Contract", + "SolidityVariable", + "SolidityFunction", + "YulBuiltin", + ]: return self._value def __str__(self) -> str: diff --git a/slither/core/expressions/index_access.py b/slither/core/expressions/index_access.py index 4f96a56d6f..22f014242d 100644 --- a/slither/core/expressions/index_access.py +++ b/slither/core/expressions/index_access.py @@ -1,27 +1,18 @@ -from typing import Union, List, TYPE_CHECKING +from typing import Union, List -from slither.core.expressions.expression_typed import ExpressionTyped +from slither.core.expressions.expression import Expression from slither.core.expressions.identifier import Identifier from slither.core.expressions.literal import Literal -if TYPE_CHECKING: - from slither.core.expressions.expression import Expression - from slither.core.solidity_types.type import Type - - -class IndexAccess(ExpressionTyped): +class IndexAccess(Expression): def __init__( self, left_expression: Union["IndexAccess", Identifier], right_expression: Union[Literal, Identifier], - index_type: str, ) -> None: super().__init__() self._expressions = [left_expression, right_expression] - # TODO type of undexAccess is not always a Type - # assert isinstance(index_type, Type) - self._type: "Type" = index_type @property def expressions(self) -> List["Expression"]: @@ -35,9 +26,5 @@ def expression_left(self) -> "Expression": def expression_right(self) -> "Expression": return self._expressions[1] - @property - def type(self) -> "Type": - return self._type - def __str__(self) -> str: return str(self.expression_left) + "[" + str(self.expression_right) + "]" diff --git a/slither/core/expressions/literal.py b/slither/core/expressions/literal.py index 5dace3c41e..8848ce9668 100644 --- a/slither/core/expressions/literal.py +++ b/slither/core/expressions/literal.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, TYPE_CHECKING +from typing import Optional, Union, TYPE_CHECKING, Any from slither.core.expressions.expression import Expression from slither.core.solidity_types.elementary_type import Fixed, Int, Ufixed, Uint @@ -47,7 +47,7 @@ def __str__(self) -> str: # be sure to handle any character return str(self._value) - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: if not isinstance(other, Literal): return False return (self.value, self.subdenomination) == (other.value, other.subdenomination) diff --git a/slither/core/expressions/member_access.py b/slither/core/expressions/member_access.py index 36d6818b2a..e240243182 100644 --- a/slither/core/expressions/member_access.py +++ b/slither/core/expressions/member_access.py @@ -1,10 +1,9 @@ from slither.core.expressions.expression import Expression -from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.solidity_types.type import Type -class MemberAccess(ExpressionTyped): +class MemberAccess(Expression): def __init__(self, member_name: str, member_type: str, expression: Expression) -> None: # assert isinstance(member_type, Type) # TODO member_type is not always a Type diff --git a/slither/core/expressions/type_conversion.py b/slither/core/expressions/type_conversion.py index b9cd6879e1..2acc8bd521 100644 --- a/slither/core/expressions/type_conversion.py +++ b/slither/core/expressions/type_conversion.py @@ -1,6 +1,5 @@ from typing import Union, TYPE_CHECKING -from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type @@ -14,7 +13,7 @@ from slither.core.solidity_types.user_defined_type import UserDefinedType -class TypeConversion(ExpressionTyped): +class TypeConversion(Expression): def __init__( self, expression: Union[ @@ -28,6 +27,14 @@ def __init__( self._expression: Expression = expression self._type: Type = expression_type + @property + def type(self) -> Type: + return self._type + + @type.setter + def type(self, new_type: Type) -> None: + self._type = new_type + @property def expression(self) -> Expression: return self._expression diff --git a/slither/core/expressions/unary_operation.py b/slither/core/expressions/unary_operation.py index a04c575915..6572249278 100644 --- a/slither/core/expressions/unary_operation.py +++ b/slither/core/expressions/unary_operation.py @@ -2,7 +2,6 @@ from typing import Union from enum import Enum -from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.exceptions import SlitherCoreError from slither.core.expressions.identifier import Identifier @@ -91,7 +90,7 @@ def is_prefix(operation_type: "UnaryOperationType") -> bool: raise SlitherCoreError(f"is_prefix: Unknown operation type {operation_type}") -class UnaryOperation(ExpressionTyped): +class UnaryOperation(Expression): def __init__( self, expression: Union[Literal, Identifier, IndexAccess, TupleExpression], diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index e5f4e830a1..798008707f 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -13,7 +13,7 @@ from crytic_compile import CryticCompile from crytic_compile.utils.naming import Filename -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.context.context import Context from slither.core.declarations import Contract, FunctionContract @@ -206,7 +206,7 @@ def _compute_offsets_from_thing(self, thing: SourceMapping): isinstance(thing, FunctionContract) and thing.contract_declarer == thing.contract ) - or (isinstance(thing, ChildContract) and not isinstance(thing, FunctionContract)) + or (isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract)) ): self._offset_to_objects[definition.filename][offset].add(thing) @@ -224,7 +224,7 @@ def _compute_offsets_from_thing(self, thing: SourceMapping): and thing.contract_declarer == thing.contract ) or ( - isinstance(thing, ChildContract) and not isinstance(thing, FunctionContract) + isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract) ) ): self._offset_to_objects[definition.filename][offset].add(thing) @@ -482,8 +482,8 @@ def add_path_to_filter(self, path: str): ################################################################################### @property - def crytic_compile(self) -> Optional[CryticCompile]: - return self._crytic_compile + def crytic_compile(self) -> CryticCompile: + return self._crytic_compile # type: ignore # endregion ################################################################################### diff --git a/slither/core/solidity_types/array_type.py b/slither/core/solidity_types/array_type.py index 9a0b12c00f..9dfd3cf17b 100644 --- a/slither/core/solidity_types/array_type.py +++ b/slither/core/solidity_types/array_type.py @@ -1,38 +1,37 @@ from typing import Union, Optional, Tuple, Any, TYPE_CHECKING from slither.core.expressions.expression import Expression +from slither.core.expressions.literal import Literal +from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.type import Type from slither.visitors.expression.constants_folding import ConstantFolding -from slither.core.expressions.literal import Literal if TYPE_CHECKING: from slither.core.expressions.binary_operation import BinaryOperation from slither.core.expressions.identifier import Identifier - from slither.core.solidity_types.elementary_type import ElementaryType - from slither.core.solidity_types.function_type import FunctionType - from slither.core.solidity_types.type_alias import TypeAliasTopLevel class ArrayType(Type): def __init__( self, - t: Union["TypeAliasTopLevel", "ArrayType", "FunctionType", "ElementaryType"], + t: Type, length: Optional[Union["Identifier", Literal, "BinaryOperation", int]], ) -> None: assert isinstance(t, Type) if length: if isinstance(length, int): - length = Literal(length, "uint256") - assert isinstance(length, Expression) + length = Literal(length, ElementaryType("uint256")) + super().__init__() self._type: Type = t + assert length is None or isinstance(length, Expression) self._length: Optional[Expression] = length if length: if not isinstance(length, Literal): cf = ConstantFolding(length, "uint256") length = cf.result() - self._length_value = length + self._length_value: Optional[Literal] = length else: self._length_value = None diff --git a/slither/core/solidity_types/elementary_type.py b/slither/core/solidity_types/elementary_type.py index ec2b0ef044..a9f45c8d81 100644 --- a/slither/core/solidity_types/elementary_type.py +++ b/slither/core/solidity_types/elementary_type.py @@ -1,5 +1,5 @@ import itertools -from typing import Tuple +from typing import Tuple, Optional, Any from slither.core.solidity_types.type import Type @@ -176,7 +176,7 @@ def name(self) -> str: return self.type @property - def size(self) -> int: + def size(self) -> Optional[int]: """ Return the size in bits Return None if the size is not known @@ -219,7 +219,7 @@ def max(self) -> int: def __str__(self) -> str: return self._type - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: if not isinstance(other, ElementaryType): return False return self.type == other.type diff --git a/slither/core/solidity_types/mapping_type.py b/slither/core/solidity_types/mapping_type.py index a8acb4d9c4..9741569edf 100644 --- a/slither/core/solidity_types/mapping_type.py +++ b/slither/core/solidity_types/mapping_type.py @@ -1,4 +1,4 @@ -from typing import Union, Tuple, TYPE_CHECKING +from typing import Union, Tuple, TYPE_CHECKING, Any from slither.core.solidity_types.type import Type @@ -38,7 +38,7 @@ def is_dynamic(self) -> bool: def __str__(self) -> str: return f"mapping({str(self._from)} => {str(self._to)})" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, MappingType): return False return self.type_from == other.type_from and self.type_to == other.type_to diff --git a/slither/core/solidity_types/type_alias.py b/slither/core/solidity_types/type_alias.py index 5b9ea0a37f..9387f511aa 100644 --- a/slither/core/solidity_types/type_alias.py +++ b/slither/core/solidity_types/type_alias.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING, Tuple -from slither.core.children.child_contract import ChildContract from slither.core.declarations.top_level import TopLevel +from slither.core.declarations.contract_level import ContractLevel from slither.core.solidity_types import Type, ElementaryType if TYPE_CHECKING: @@ -40,7 +40,7 @@ def is_dynamic(self) -> bool: class TypeAliasTopLevel(TypeAlias, TopLevel): - def __init__(self, underlying_type: Type, name: str, scope: "FileScope") -> None: + def __init__(self, underlying_type: ElementaryType, name: str, scope: "FileScope") -> None: super().__init__(underlying_type, name) self.file_scope: "FileScope" = scope @@ -48,8 +48,8 @@ def __str__(self) -> str: return self.name -class TypeAliasContract(TypeAlias, ChildContract): - def __init__(self, underlying_type: Type, name: str, contract: "Contract") -> None: +class TypeAliasContract(TypeAlias, ContractLevel): + def __init__(self, underlying_type: ElementaryType, name: str, contract: "Contract") -> None: super().__init__(underlying_type, name) self._contract: "Contract" = contract diff --git a/slither/core/solidity_types/type_information.py b/slither/core/solidity_types/type_information.py index 2af0b097ac..9cef9352c7 100644 --- a/slither/core/solidity_types/type_information.py +++ b/slither/core/solidity_types/type_information.py @@ -1,4 +1,4 @@ -from typing import Union, TYPE_CHECKING, Tuple +from typing import Union, TYPE_CHECKING, Tuple, Any from slither.core.solidity_types import ElementaryType from slither.core.solidity_types.type import Type @@ -40,10 +40,10 @@ def storage_size(self) -> Tuple[int, bool]: def is_dynamic(self) -> bool: raise NotImplementedError - def __str__(self): + def __str__(self) -> str: return f"type({self.type.name})" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, TypeInformation): return False return self.type == other.type diff --git a/slither/core/source_mapping/source_mapping.py b/slither/core/source_mapping/source_mapping.py index e58b04c212..fceab78559 100644 --- a/slither/core/source_mapping/source_mapping.py +++ b/slither/core/source_mapping/source_mapping.py @@ -1,6 +1,6 @@ import re from abc import ABCMeta -from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional +from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional, Any from Crypto.Hash import SHA1 from crytic_compile.utils.naming import Filename @@ -98,10 +98,10 @@ def __str__(self) -> str: filename_short: str = self.filename.short if self.filename.short else "" return f"{filename_short}{lines}" - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, type(self)): return NotImplemented return ( diff --git a/slither/core/variables/event_variable.py b/slither/core/variables/event_variable.py index f3ad60d0b7..3b6b6c5113 100644 --- a/slither/core/variables/event_variable.py +++ b/slither/core/variables/event_variable.py @@ -1,8 +1,7 @@ from slither.core.variables.variable import Variable -from slither.core.children.child_event import ChildEvent -class EventVariable(ChildEvent, Variable): +class EventVariable(Variable): def __init__(self) -> None: super().__init__() self._indexed = False @@ -16,5 +15,5 @@ def indexed(self) -> bool: return self._indexed @indexed.setter - def indexed(self, is_indexed: bool): + def indexed(self, is_indexed: bool) -> None: self._indexed = is_indexed diff --git a/slither/core/variables/local_variable.py b/slither/core/variables/local_variable.py index 7b7b4f8bce..fc23eeba75 100644 --- a/slither/core/variables/local_variable.py +++ b/slither/core/variables/local_variable.py @@ -1,7 +1,6 @@ -from typing import Optional +from typing import Optional, TYPE_CHECKING from slither.core.variables.variable import Variable -from slither.core.children.child_function import ChildFunction from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.array_type import ArrayType from slither.core.solidity_types.mapping_type import MappingType @@ -9,11 +8,23 @@ from slither.core.declarations.structure import Structure +if TYPE_CHECKING: # type: ignore + from slither.core.declarations import Function -class LocalVariable(ChildFunction, Variable): + +class LocalVariable(Variable): def __init__(self) -> None: super().__init__() self._location: Optional[str] = None + self._function: Optional["Function"] = None + + def set_function(self, function: "Function") -> None: + self._function = function + + @property + def function(self) -> "Function": + assert self._function + return self._function def set_location(self, loc: str) -> None: self._location = loc diff --git a/slither/core/variables/state_variable.py b/slither/core/variables/state_variable.py index 47b7682a47..f2a2d6ee3c 100644 --- a/slither/core/variables/state_variable.py +++ b/slither/core/variables/state_variable.py @@ -1,6 +1,6 @@ from typing import Optional, TYPE_CHECKING -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.variables.variable import Variable if TYPE_CHECKING: @@ -8,7 +8,7 @@ from slither.core.declarations import Contract -class StateVariable(ChildContract, Variable): +class StateVariable(ContractLevel, Variable): def __init__(self) -> None: super().__init__() self._node_initialization: Optional["Node"] = None diff --git a/slither/core/variables/structure_variable.py b/slither/core/variables/structure_variable.py index c6034da63e..3a001b6a9d 100644 --- a/slither/core/variables/structure_variable.py +++ b/slither/core/variables/structure_variable.py @@ -1,6 +1,19 @@ +from typing import TYPE_CHECKING, Optional from slither.core.variables.variable import Variable -from slither.core.children.child_structure import ChildStructure -class StructureVariable(ChildStructure, Variable): - pass +if TYPE_CHECKING: + from slither.core.declarations import Structure + + +class StructureVariable(Variable): + def __init__(self) -> None: + super().__init__() + self._structure: Optional["Structure"] = None + + def set_structure(self, structure: "Structure") -> None: + self._structure = structure + + @property + def structure(self) -> "Structure": + return self._structure diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index 8607a89217..2b777e6723 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -55,7 +55,7 @@ def initialized(self) -> Optional[bool]: return self._initialized @initialized.setter - def initialized(self, is_init: bool): + def initialized(self, is_init: bool) -> None: self._initialized = is_init @property @@ -73,23 +73,24 @@ def name(self) -> Optional[str]: return self._name @name.setter - def name(self, name): + def name(self, name: str) -> None: self._name = name @property - def type(self) -> Optional[Union[Type, List[Type]]]: + def type(self) -> Optional[Type]: return self._type @type.setter - def type(self, types: Union[Type, List[Type]]): - self._type = types + def type(self, new_type: Type) -> None: + assert isinstance(new_type, Type) + self._type = new_type @property def is_constant(self) -> bool: return self._is_constant @is_constant.setter - def is_constant(self, is_cst: bool): + def is_constant(self, is_cst: bool) -> None: self._is_constant = is_cst @property @@ -159,8 +160,8 @@ def signature(self) -> Tuple[str, List[str], List[str]]: return ( self.name, - [str(x) for x in export_nested_types_from_variable(self)], - [str(x) for x in export_return_type_from_variable(self)], + [str(x) for x in export_nested_types_from_variable(self)], # type: ignore + [str(x) for x in export_return_type_from_variable(self)], # type: ignore ) @property @@ -178,4 +179,5 @@ def solidity_signature(self) -> str: return f'{name}({",".join(parameters)})' def __str__(self) -> str: + assert self._name return self._name diff --git a/slither/detectors/abstract_detector.py b/slither/detectors/abstract_detector.py index 8e2dd490d4..7bb8eb93fb 100644 --- a/slither/detectors/abstract_detector.py +++ b/slither/detectors/abstract_detector.py @@ -59,6 +59,8 @@ def make_solc_versions(minor: int, patch_min: int, patch_max: int) -> List[str]: ALL_SOLC_VERSIONS_07 = make_solc_versions(7, 0, 6) # No VERSIONS_08 as it is still in dev +DETECTOR_INFO = List[Union[str, SupportedOutput]] + class AbstractDetector(metaclass=abc.ABCMeta): ARGUMENT = "" # run the detector with slither.py --ARGUMENT @@ -251,7 +253,7 @@ def color(self) -> Callable[[str], str]: def generate_result( self, - info: Union[str, List[Union[str, SupportedOutput]]], + info: DETECTOR_INFO, additional_fields: Optional[Dict] = None, ) -> Output: output = Output( diff --git a/slither/detectors/assembly/shift_parameter_mixup.py b/slither/detectors/assembly/shift_parameter_mixup.py index 31dad23716..a4169499a7 100644 --- a/slither/detectors/assembly/shift_parameter_mixup.py +++ b/slither/detectors/assembly/shift_parameter_mixup.py @@ -1,5 +1,9 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Binary, BinaryType from slither.slithir.variables import Constant from slither.core.declarations.function_contract import FunctionContract @@ -49,7 +53,12 @@ def _check_function(self, f: FunctionContract) -> List[Output]: BinaryType.RIGHT_SHIFT, ]: if isinstance(ir.variable_left, Constant): - info = [f, " contains an incorrect shift operation: ", node, "\n"] + info: DETECTOR_INFO = [ + f, + " contains an incorrect shift operation: ", + node, + "\n", + ] json = self.generate_result(info) results.append(json) diff --git a/slither/detectors/attributes/const_functions_asm.py b/slither/detectors/attributes/const_functions_asm.py index e3a9383614..01798e0858 100644 --- a/slither/detectors/attributes/const_functions_asm.py +++ b/slither/detectors/attributes/const_functions_asm.py @@ -2,11 +2,14 @@ Module detecting constant functions Recursively check the called functions """ -from typing import List +from typing import List, Dict + +from slither.core.compilation_unit import SlitherCompilationUnit from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, + DETECTOR_INFO, ) from slither.formatters.attributes.const_functions import custom_format from slither.utils.output import Output @@ -73,7 +76,10 @@ def _detect(self) -> List[Output]: if f.contains_assembly: attr = "view" if f.view else "pure" - info = [f, f" is declared {attr} but contains assembly code\n"] + info: DETECTOR_INFO = [ + f, + f" is declared {attr} but contains assembly code\n", + ] res = self.generate_result(info, {"contains_assembly": True}) results.append(res) @@ -81,5 +87,5 @@ def _detect(self) -> List[Output]: return results @staticmethod - def _format(comilation_unit, result): + def _format(comilation_unit: SlitherCompilationUnit, result: Dict) -> None: custom_format(comilation_unit, result) diff --git a/slither/detectors/attributes/const_functions_state.py b/slither/detectors/attributes/const_functions_state.py index 36ea8f32d6..d86ca7c0e1 100644 --- a/slither/detectors/attributes/const_functions_state.py +++ b/slither/detectors/attributes/const_functions_state.py @@ -2,11 +2,14 @@ Module detecting constant functions Recursively check the called functions """ -from typing import List +from typing import List, Dict + +from slither.core.compilation_unit import SlitherCompilationUnit from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, + DETECTOR_INFO, ) from slither.formatters.attributes.const_functions import custom_format from slither.utils.output import Output @@ -74,7 +77,7 @@ def _detect(self) -> List[Output]: if variables_written: attr = "view" if f.view else "pure" - info = [ + info: DETECTOR_INFO = [ f, f" is declared {attr} but changes state variables:\n", ] @@ -89,5 +92,5 @@ def _detect(self) -> List[Output]: return results @staticmethod - def _format(slither, result): + def _format(slither: SlitherCompilationUnit, result: Dict) -> None: custom_format(slither, result) diff --git a/slither/detectors/attributes/constant_pragma.py b/slither/detectors/attributes/constant_pragma.py index 2164a78e8c..2ed76c86ad 100644 --- a/slither/detectors/attributes/constant_pragma.py +++ b/slither/detectors/attributes/constant_pragma.py @@ -1,9 +1,14 @@ """ Check that the same pragma is used in all the files """ -from typing import List - -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from typing import List, Dict + +from slither.core.compilation_unit import SlitherCompilationUnit +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.formatters.attributes.constant_pragma import custom_format from slither.utils.output import Output @@ -31,7 +36,7 @@ def _detect(self) -> List[Output]: versions = sorted(list(set(versions))) if len(versions) > 1: - info = ["Different versions of Solidity are used:\n"] + info: DETECTOR_INFO = ["Different versions of Solidity are used:\n"] info += [f"\t- Version used: {[str(v) for v in versions]}\n"] for p in sorted(pragma, key=lambda x: x.version): @@ -44,5 +49,5 @@ def _detect(self) -> List[Output]: return results @staticmethod - def _format(slither, result): + def _format(slither: SlitherCompilationUnit, result: Dict) -> None: custom_format(slither, result) diff --git a/slither/detectors/attributes/incorrect_solc.py b/slither/detectors/attributes/incorrect_solc.py index 73874cffc6..eaf40bf21f 100644 --- a/slither/detectors/attributes/incorrect_solc.py +++ b/slither/detectors/attributes/incorrect_solc.py @@ -5,7 +5,11 @@ import re from typing import List, Optional, Tuple -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.formatters.attributes.incorrect_solc import custom_format from slither.utils.output import Output @@ -141,7 +145,7 @@ def _detect(self) -> List[Output]: # If we found any disallowed pragmas, we output our findings. if disallowed_pragmas: for (reason, p) in disallowed_pragmas: - info = ["Pragma version", p, f" {reason}\n"] + info: DETECTOR_INFO = ["Pragma version", p, f" {reason}\n"] json = self.generate_result(info) diff --git a/slither/detectors/attributes/locked_ether.py b/slither/detectors/attributes/locked_ether.py index 2fdabaea67..a6f882922c 100644 --- a/slither/detectors/attributes/locked_ether.py +++ b/slither/detectors/attributes/locked_ether.py @@ -4,7 +4,11 @@ from typing import List from slither.core.declarations.contract import Contract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import ( HighLevelCall, LowLevelCall, @@ -85,7 +89,7 @@ def _detect(self) -> List[Output]: funcs_payable = [function for function in contract.functions if function.payable] if funcs_payable: if self.do_no_send_ether(contract): - info = ["Contract locking ether found:\n"] + info: DETECTOR_INFO = ["Contract locking ether found:\n"] info += ["\tContract ", contract, " has payable functions:\n"] for function in funcs_payable: info += ["\t - ", function, "\n"] diff --git a/slither/detectors/attributes/unimplemented_interface.py b/slither/detectors/attributes/unimplemented_interface.py index ff0889d116..5c6c9c5f26 100644 --- a/slither/detectors/attributes/unimplemented_interface.py +++ b/slither/detectors/attributes/unimplemented_interface.py @@ -5,7 +5,11 @@ Check for contracts which implement all interface functions but do not explicitly derive from those interfaces. """ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.declarations.contract import Contract from slither.utils.output import Output @@ -139,7 +143,7 @@ def _detect(self) -> List[Output]: continue intended_interfaces = self.detect_unimplemented_interface(contract, interfaces) for interface in intended_interfaces: - info = [contract, " should inherit from ", interface, "\n"] + info: DETECTOR_INFO = [contract, " should inherit from ", interface, "\n"] res = self.generate_result(info) results.append(res) return results diff --git a/slither/detectors/compiler_bugs/array_by_reference.py b/slither/detectors/compiler_bugs/array_by_reference.py index 83ed69b9b6..04dfe085a8 100644 --- a/slither/detectors/compiler_bugs/array_by_reference.py +++ b/slither/detectors/compiler_bugs/array_by_reference.py @@ -2,7 +2,14 @@ Detects the passing of arrays located in memory to functions which expect to modify arrays via storage reference. """ from typing import List, Set, Tuple, Union -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + +from slither.core.declarations import Function +from slither.core.variables import Variable +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.solidity_types.array_type import ArrayType from slither.core.variables.state_variable import StateVariable from slither.core.variables.local_variable import LocalVariable @@ -89,12 +96,7 @@ def get_funcs_modifying_array_params(contracts: List[Contract]) -> Set[FunctionC @staticmethod def detect_calls_passing_ref_to_function( contracts: List[Contract], array_modifying_funcs: Set[FunctionContract] - ) -> List[ - Union[ - Tuple[Node, StateVariable, FunctionContract], - Tuple[Node, LocalVariable, FunctionContract], - ] - ]: + ) -> List[Tuple[Node, Variable, Union[Function, Variable]]]: """ Obtains all calls passing storage arrays by value to a function which cannot write to them successfully. :param contracts: The collection of contracts to check for problematic calls in. @@ -105,7 +107,7 @@ def detect_calls_passing_ref_to_function( write to the array unsuccessfully. """ # Define our resulting array. - results = [] + results: List[Tuple[Node, Variable, Union[Function, Variable]]] = [] # Verify we have functions in our list to check for. if not array_modifying_funcs: @@ -159,7 +161,7 @@ def _detect(self) -> List[Output]: if problematic_calls: for calling_node, affected_argument, invoked_function in problematic_calls: - info = [ + info: DETECTOR_INFO = [ calling_node.function, " passes array ", affected_argument, diff --git a/slither/detectors/compiler_bugs/enum_conversion.py b/slither/detectors/compiler_bugs/enum_conversion.py index 671b8d6995..c7f1bcf4e2 100644 --- a/slither/detectors/compiler_bugs/enum_conversion.py +++ b/slither/detectors/compiler_bugs/enum_conversion.py @@ -10,6 +10,7 @@ AbstractDetector, DetectorClassification, make_solc_versions, + DETECTOR_INFO, ) from slither.slithir.operations import TypeConversion from slither.core.declarations.enum import Enum @@ -73,10 +74,14 @@ def _detect(self) -> List[Output]: for c in self.compilation_unit.contracts: ret = _detect_dangerous_enum_conversions(c) for node, var in ret: - func_info = [node, " has a dangerous enum conversion\n"] + func_info: DETECTOR_INFO = [node, " has a dangerous enum conversion\n"] # Output each node with the function info header as a separate result. - variable_info = ["\t- Variable: ", var, f" of type: {str(var.type)}\n"] - node_info = ["\t- Enum conversion: ", node, "\n"] + variable_info: DETECTOR_INFO = [ + "\t- Variable: ", + var, + f" of type: {str(var.type)}\n", + ] + node_info: DETECTOR_INFO = ["\t- Enum conversion: ", node, "\n"] json = self.generate_result(func_info + variable_info + node_info) results.append(json) diff --git a/slither/detectors/compiler_bugs/multiple_constructor_schemes.py b/slither/detectors/compiler_bugs/multiple_constructor_schemes.py index 3486cc41b1..ae325b2a6e 100644 --- a/slither/detectors/compiler_bugs/multiple_constructor_schemes.py +++ b/slither/detectors/compiler_bugs/multiple_constructor_schemes.py @@ -1,6 +1,10 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -58,7 +62,10 @@ def _detect(self) -> List[Output]: # If there is more than one, we encountered the described issue occurring. if constructors and len(constructors) > 1: - info = [contract, " contains multiple constructors in the same contract:\n"] + info: DETECTOR_INFO = [ + contract, + " contains multiple constructors in the same contract:\n", + ] for constructor in constructors: info += ["\t- ", constructor, "\n"] diff --git a/slither/detectors/compiler_bugs/reused_base_constructor.py b/slither/detectors/compiler_bugs/reused_base_constructor.py index 73cfac12e9..73bd410c79 100644 --- a/slither/detectors/compiler_bugs/reused_base_constructor.py +++ b/slither/detectors/compiler_bugs/reused_base_constructor.py @@ -6,6 +6,7 @@ AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, + DETECTOR_INFO, ) from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract @@ -151,7 +152,7 @@ def _detect(self) -> List[Output]: continue # Generate data to output. - info = [ + info: DETECTOR_INFO = [ contract, " gives base constructor ", base_constructor, diff --git a/slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py b/slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py index aee6361c64..dd34eb5e0d 100644 --- a/slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py +++ b/slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py @@ -6,6 +6,7 @@ AbstractDetector, DetectorClassification, make_solc_versions, + DETECTOR_INFO, ) from slither.core.solidity_types import ArrayType from slither.core.solidity_types import UserDefinedType @@ -122,7 +123,13 @@ def _detect(self) -> List[Output]: for contract in self.contracts: storage_abiencoderv2_arrays = self._detect_storage_abiencoderv2_arrays(contract) for function, node in storage_abiencoderv2_arrays: - info = ["Function ", function, " trigger an abi encoding bug:\n\t- ", node, "\n"] + info: DETECTOR_INFO = [ + "Function ", + function, + " trigger an abi encoding bug:\n\t- ", + node, + "\n", + ] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/compiler_bugs/storage_signed_integer_array.py b/slither/detectors/compiler_bugs/storage_signed_integer_array.py index 736f667892..cfd13cdbc9 100644 --- a/slither/detectors/compiler_bugs/storage_signed_integer_array.py +++ b/slither/detectors/compiler_bugs/storage_signed_integer_array.py @@ -1,18 +1,21 @@ """ Module detecting storage signed integer array bug """ -from typing import List +from typing import List, Tuple, Set +from slither.core.declarations import Function, Contract from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, make_solc_versions, + DETECTOR_INFO, ) -from slither.core.cfg.node import NodeType +from slither.core.cfg.node import NodeType, Node from slither.core.solidity_types import ArrayType from slither.core.solidity_types.elementary_type import Int, ElementaryType from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable +from slither.slithir.operations import Operation, OperationWithLValue from slither.slithir.operations.assignment import Assignment from slither.slithir.operations.init_array import InitArray from slither.utils.output import Output @@ -60,7 +63,7 @@ class StorageSignedIntegerArray(AbstractDetector): VULNERABLE_SOLC_VERSIONS = make_solc_versions(4, 7, 25) + make_solc_versions(5, 0, 9) @staticmethod - def _is_vulnerable_type(ir): + def _is_vulnerable_type(ir: Operation) -> bool: """ Detect if the IR lvalue is a vulnerable type Must be a storage allocation, and an array of Int @@ -68,23 +71,28 @@ def _is_vulnerable_type(ir): """ # Storage allocation # Base type is signed integer + if not isinstance(ir, OperationWithLValue): + return False + return ( ( isinstance(ir.lvalue, StateVariable) or (isinstance(ir.lvalue, LocalVariable) and ir.lvalue.is_storage) ) - and isinstance(ir.lvalue.type.type, ElementaryType) - and ir.lvalue.type.type.type in Int + and isinstance(ir.lvalue.type.type, ElementaryType) # type: ignore + and ir.lvalue.type.type.type in Int # type: ignore ) - def detect_storage_signed_integer_arrays(self, contract): + def detect_storage_signed_integer_arrays( + self, contract: Contract + ) -> Set[Tuple[Function, Node]]: """ Detects and returns all nodes with storage-allocated signed integer array init/assignment :param contract: Contract to detect within :return: A list of tuples with (function, node) where function node has storage-allocated signed integer array init/assignment """ # Create our result set. - results = set() + results: Set[Tuple[Function, Node]] = set() # Loop for each function and modifier. for function in contract.functions_and_modifiers_declared: @@ -118,9 +126,13 @@ def _detect(self) -> List[Output]: for contract in self.contracts: storage_signed_integer_arrays = self.detect_storage_signed_integer_arrays(contract) for function, node in storage_signed_integer_arrays: - contract_info = ["Contract ", contract, " \n"] - function_info = ["\t- Function ", function, "\n"] - node_info = ["\t\t- ", node, " has a storage signed integer array assignment\n"] + contract_info: DETECTOR_INFO = ["Contract ", contract, " \n"] + function_info: DETECTOR_INFO = ["\t- Function ", function, "\n"] + node_info: DETECTOR_INFO = [ + "\t\t- ", + node, + " has a storage signed integer array assignment\n", + ] res = self.generate_result(contract_info + function_info + node_info) results.append(res) diff --git a/slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py b/slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py index 6685948b35..826b671bd1 100644 --- a/slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py +++ b/slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py @@ -6,6 +6,7 @@ AbstractDetector, DetectorClassification, make_solc_versions, + DETECTOR_INFO, ) from slither.slithir.operations import InternalDynamicCall, OperationWithLValue from slither.slithir.variables import ReferenceVariable @@ -115,10 +116,10 @@ def _detect(self) -> List[Output]: results = [] for contract in self.compilation_unit.contracts: - contract_info = ["Contract ", contract, " \n"] + contract_info: DETECTOR_INFO = ["Contract ", contract, " \n"] nodes = self._detect_uninitialized_function_ptr_in_constructor(contract) for node in nodes: - node_info = [ + node_info: DETECTOR_INFO = [ "\t ", node, " is an unintialized function pointer call in a constructor\n", diff --git a/slither/detectors/erc/erc20/arbitrary_send_erc20.py b/slither/detectors/erc/erc20/arbitrary_send_erc20.py index 17b1fba30f..f060054590 100644 --- a/slither/detectors/erc/erc20/arbitrary_send_erc20.py +++ b/slither/detectors/erc/erc20/arbitrary_send_erc20.py @@ -61,12 +61,12 @@ def _arbitrary_from(nodes: List[Node], results: List[Node]) -> None: is_dependent( ir.arguments[0], SolidityVariableComposed("msg.sender"), - node.function.contract, + node, ) or is_dependent( ir.arguments[0], SolidityVariable("this"), - node.function.contract, + node, ) ) ): @@ -79,12 +79,12 @@ def _arbitrary_from(nodes: List[Node], results: List[Node]) -> None: is_dependent( ir.arguments[1], SolidityVariableComposed("msg.sender"), - node.function.contract, + node, ) or is_dependent( ir.arguments[1], SolidityVariable("this"), - node.function.contract, + node, ) ) ): diff --git a/slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py b/slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py index f43b6302ec..351f1dcfa7 100644 --- a/slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py +++ b/slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py @@ -1,5 +1,9 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output from .arbitrary_send_erc20 import ArbitrarySendErc20 @@ -38,7 +42,7 @@ def _detect(self) -> List[Output]: arbitrary_sends.detect() for node in arbitrary_sends.no_permit_results: func = node.function - info = [func, " uses arbitrary from in transferFrom: ", node, "\n"] + info: DETECTOR_INFO = [func, " uses arbitrary from in transferFrom: ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py b/slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py index 1d311c442c..ca4c4a7939 100644 --- a/slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py +++ b/slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py @@ -1,5 +1,9 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output from .arbitrary_send_erc20 import ArbitrarySendErc20 @@ -41,7 +45,7 @@ def _detect(self) -> List[Output]: arbitrary_sends.detect() for node in arbitrary_sends.permit_results: func = node.function - info = [ + info: DETECTOR_INFO = [ func, " uses arbitrary from in transferFrom in combination with permit: ", node, diff --git a/slither/detectors/erc/erc20/incorrect_erc20_interface.py b/slither/detectors/erc/erc20/incorrect_erc20_interface.py index 4da6ab5ae0..a17f04e8c7 100644 --- a/slither/detectors/erc/erc20/incorrect_erc20_interface.py +++ b/slither/detectors/erc/erc20/incorrect_erc20_interface.py @@ -6,7 +6,11 @@ from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -109,7 +113,7 @@ def _detect(self) -> List[Output]: functions = IncorrectERC20InterfaceDetection.detect_incorrect_erc20_interface(c) if functions: for function in functions: - info = [ + info: DETECTOR_INFO = [ c, " has incorrect ERC20 function interface:", function, diff --git a/slither/detectors/erc/incorrect_erc721_interface.py b/slither/detectors/erc/incorrect_erc721_interface.py index 8327e8b2ee..e05f3ce8e9 100644 --- a/slither/detectors/erc/incorrect_erc721_interface.py +++ b/slither/detectors/erc/incorrect_erc721_interface.py @@ -2,7 +2,11 @@ Detect incorrect erc721 interface. """ from typing import Any, List, Tuple, Union -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract from slither.utils.output import Output @@ -89,7 +93,9 @@ def incorrect_erc721_interface( return False @staticmethod - def detect_incorrect_erc721_interface(contract: Contract) -> List[Union[FunctionContract, Any]]: + def detect_incorrect_erc721_interface( + contract: Contract, + ) -> List[Union[FunctionContract, Any]]: """Detect incorrect ERC721 interface Returns: @@ -119,7 +125,7 @@ def _detect(self) -> List[Output]: functions = IncorrectERC721InterfaceDetection.detect_incorrect_erc721_interface(c) if functions: for function in functions: - info = [ + info: DETECTOR_INFO = [ c, " has incorrect ERC721 function interface:", function, diff --git a/slither/detectors/examples/backdoor.py b/slither/detectors/examples/backdoor.py index 0e8e9ad81f..3928346415 100644 --- a/slither/detectors/examples/backdoor.py +++ b/slither/detectors/examples/backdoor.py @@ -1,6 +1,10 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -28,7 +32,7 @@ def _detect(self) -> List[Output]: for f in contract.functions: if "backdoor" in f.name: # Info to be printed - info = ["Backdoor function found in ", f, "\n"] + info: DETECTOR_INFO = ["Backdoor function found in ", f, "\n"] # Add the result in result res = self.generate_result(info) diff --git a/slither/detectors/functions/arbitrary_send_eth.py b/slither/detectors/functions/arbitrary_send_eth.py index 390b1f2abf..f6c688a3fc 100644 --- a/slither/detectors/functions/arbitrary_send_eth.py +++ b/slither/detectors/functions/arbitrary_send_eth.py @@ -18,7 +18,9 @@ from slither.core.declarations.solidity_variables import ( SolidityFunction, SolidityVariableComposed, + SolidityVariable, ) +from slither.core.variables import Variable from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import ( HighLevelCall, @@ -39,6 +41,10 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]: ret: List[Node] = [] for node in func.nodes: + func = node.function + deps_target: Union[Contract, Function] = ( + func.contract if isinstance(func, FunctionContract) else func + ) for ir in node.irs: if isinstance(ir, SolidityCall): if ir.function == SolidityFunction("ecrecover(bytes32,uint8,bytes32,bytes32)"): @@ -49,7 +55,7 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]: if is_dependent( ir.variable_right, SolidityVariableComposed("msg.sender"), - func.contract, + deps_target, ): return False if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)): @@ -64,12 +70,13 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]: if is_dependent( ir.call_value, SolidityVariableComposed("msg.value"), - func.contract, + node, ): continue - if is_tainted(ir.destination, func.contract): - ret.append(node) + if isinstance(ir.destination, (Variable, SolidityVariable)): + if is_tainted(ir.destination, node): + ret.append(node) return ret diff --git a/slither/detectors/functions/cyclomatic_complexity.py b/slither/detectors/functions/cyclomatic_complexity.py index 53212fd4f9..1151b80a0b 100644 --- a/slither/detectors/functions/cyclomatic_complexity.py +++ b/slither/detectors/functions/cyclomatic_complexity.py @@ -1,7 +1,11 @@ from typing import List, Tuple from slither.core.declarations import Function -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.code_complexity import compute_cyclomatic_complexity from slither.utils.output import Output @@ -44,7 +48,7 @@ def _detect(self) -> List[Output]: _check_for_high_cc(high_cc_functions, f) for f, cc in high_cc_functions: - info = [f, f" has a high cyclomatic complexity ({cc}).\n"] + info: DETECTOR_INFO = [f, f" has a high cyclomatic complexity ({cc}).\n"] res = self.generate_result(info) results.append(res) return results diff --git a/slither/detectors/functions/dead_code.py b/slither/detectors/functions/dead_code.py index 1a25c57761..98eb97ff7e 100644 --- a/slither/detectors/functions/dead_code.py +++ b/slither/detectors/functions/dead_code.py @@ -4,7 +4,11 @@ from typing import List, Tuple from slither.core.declarations import Function, FunctionContract, Contract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -72,7 +76,7 @@ def _detect(self) -> List[Output]: # Continue if the functon is not implemented because it means the contract is abstract if not function.is_implemented: continue - info = [function, " is never used and should be removed\n"] + info: DETECTOR_INFO = [function, " is never used and should be removed\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/functions/modifier.py b/slither/detectors/functions/modifier.py index 271d8e6cb0..61ec1825e2 100644 --- a/slither/detectors/functions/modifier.py +++ b/slither/detectors/functions/modifier.py @@ -6,7 +6,11 @@ default value can still be returned. """ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.cfg.node import Node, NodeType from slither.utils.output import Output @@ -82,7 +86,11 @@ def _detect(self) -> List[Output]: node = None else: # Nothing was found in the outer scope - info = ["Modifier ", mod, " does not always execute _; or revert"] + info: DETECTOR_INFO = [ + "Modifier ", + mod, + " does not always execute _; or revert", + ] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/functions/permit_domain_signature_collision.py b/slither/detectors/functions/permit_domain_signature_collision.py index de64ec52eb..39543fb496 100644 --- a/slither/detectors/functions/permit_domain_signature_collision.py +++ b/slither/detectors/functions/permit_domain_signature_collision.py @@ -6,7 +6,11 @@ from slither.core.declarations import Function from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.function import get_function_id from slither.utils.output import Output @@ -63,7 +67,7 @@ def _detect(self) -> List[Output]: assert isinstance(func_or_var, StateVariable) incorrect_return_type = func_or_var.type != ElementaryType("bytes32") if hash_collision or incorrect_return_type: - info = [ + info: DETECTOR_INFO = [ "The function signature of ", func_or_var, " collides with DOMAIN_SEPARATOR and should be renamed or removed.\n", diff --git a/slither/detectors/functions/protected_variable.py b/slither/detectors/functions/protected_variable.py index 68ed098c79..5796729262 100644 --- a/slither/detectors/functions/protected_variable.py +++ b/slither/detectors/functions/protected_variable.py @@ -6,7 +6,11 @@ from typing import List from slither.core.declarations import Function, Contract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -58,7 +62,7 @@ def _analyze_function(self, function: Function, contract: Contract) -> List[Outp self.logger.error(f"{function_sig} not found") continue if function_protection not in function.all_internal_calls(): - info = [ + info: DETECTOR_INFO = [ function, " should have ", function_protection, diff --git a/slither/detectors/functions/suicidal.py b/slither/detectors/functions/suicidal.py index 7741da57da..1f8cb52f9c 100644 --- a/slither/detectors/functions/suicidal.py +++ b/slither/detectors/functions/suicidal.py @@ -7,7 +7,11 @@ from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -78,7 +82,7 @@ def _detect(self) -> List[Output]: functions = self.detect_suicidal(c) for func in functions: - info = [func, " allows anyone to destruct the contract\n"] + info: DETECTOR_INFO = [func, " allows anyone to destruct the contract\n"] res = self.generate_result(info) diff --git a/slither/detectors/functions/unimplemented.py b/slither/detectors/functions/unimplemented.py index 11a1fad80d..27a2d94a98 100644 --- a/slither/detectors/functions/unimplemented.py +++ b/slither/detectors/functions/unimplemented.py @@ -8,7 +8,13 @@ Do not consider fallback function or constructor """ from typing import List, Set -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + +from slither.core.declarations import Function +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract from slither.utils.output import Output @@ -62,7 +68,7 @@ class UnimplementedFunctionDetection(AbstractDetector): def _match_state_variable(contract: Contract, f: FunctionContract) -> bool: return any(s.full_name == f.full_name for s in contract.state_variables) - def _detect_unimplemented_function(self, contract: Contract) -> Set[FunctionContract]: + def _detect_unimplemented_function(self, contract: Contract) -> Set[Function]: """ Detects any function definitions which are not implemented in the given contract. :param contract: The contract to search unimplemented functions for. @@ -77,6 +83,8 @@ def _detect_unimplemented_function(self, contract: Contract) -> Set[FunctionCont # fallback function and constructor. unimplemented = set() for f in contract.all_functions_called: + if not isinstance(f, Function): + continue if ( not f.is_implemented and not f.is_constructor @@ -102,7 +110,7 @@ def _detect(self) -> List[Output]: for contract in self.compilation_unit.contracts_derived: functions = self._detect_unimplemented_function(contract) if functions: - info = [contract, " does not implement functions:\n"] + info: DETECTOR_INFO = [contract, " does not implement functions:\n"] for function in sorted(functions, key=lambda x: x.full_name): info += ["\t- ", function, "\n"] diff --git a/slither/detectors/naming_convention/naming_convention.py b/slither/detectors/naming_convention/naming_convention.py index 96d3964fa5..02deb719e7 100644 --- a/slither/detectors/naming_convention/naming_convention.py +++ b/slither/detectors/naming_convention/naming_convention.py @@ -1,6 +1,10 @@ import re from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.formatters.naming_convention.naming_convention import custom_format from slither.utils.output import Output @@ -63,6 +67,7 @@ def should_avoid_name(name: str) -> bool: def _detect(self) -> List[Output]: results = [] + info: DETECTOR_INFO for contract in self.contracts: if not self.is_cap_words(contract.name): diff --git a/slither/detectors/operations/bad_prng.py b/slither/detectors/operations/bad_prng.py index d8bf28f6c4..f816e96c83 100644 --- a/slither/detectors/operations/bad_prng.py +++ b/slither/detectors/operations/bad_prng.py @@ -50,14 +50,17 @@ def contains_bad_PRNG_sources(func: Function, blockhash_ret_values: List[Variabl for node in func.nodes: for ir in node.irs_ssa: if isinstance(ir, Binary) and ir.type == BinaryType.MODULO: + var_left = ir.variable_left + if not isinstance(var_left, (Variable, SolidityVariable)): + continue if is_dependent_ssa( - ir.variable_left, SolidityVariableComposed("block.timestamp"), func.contract - ) or is_dependent_ssa(ir.variable_left, SolidityVariable("now"), func.contract): + var_left, SolidityVariableComposed("block.timestamp"), node + ) or is_dependent_ssa(var_left, SolidityVariable("now"), node): ret.add(node) break for ret_val in blockhash_ret_values: - if is_dependent_ssa(ir.variable_left, ret_val, func.contract): + if is_dependent_ssa(var_left, ret_val, node): ret.add(node) break return list(ret) diff --git a/slither/detectors/operations/block_timestamp.py b/slither/detectors/operations/block_timestamp.py index b80c8c392b..d5c2c8df78 100644 --- a/slither/detectors/operations/block_timestamp.py +++ b/slither/detectors/operations/block_timestamp.py @@ -6,12 +6,17 @@ from slither.analyses.data_dependency.data_dependency import is_dependent from slither.core.cfg.node import Node -from slither.core.declarations import Function, Contract +from slither.core.declarations import Function, Contract, FunctionContract from slither.core.declarations.solidity_variables import ( SolidityVariableComposed, SolidityVariable, ) -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.core.variables import Variable +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Binary, BinaryType from slither.utils.output import Output @@ -21,25 +26,25 @@ def _timestamp(func: Function) -> List[Node]: for node in func.nodes: if node.contains_require_or_assert(): for var in node.variables_read: - if is_dependent(var, SolidityVariableComposed("block.timestamp"), func.contract): + if is_dependent(var, SolidityVariableComposed("block.timestamp"), node): ret.add(node) - if is_dependent(var, SolidityVariable("now"), func.contract): + if is_dependent(var, SolidityVariable("now"), node): ret.add(node) for ir in node.irs: if isinstance(ir, Binary) and BinaryType.return_bool(ir.type): - for var in ir.read: - if is_dependent( - var, SolidityVariableComposed("block.timestamp"), func.contract - ): + for var_read in ir.read: + if not isinstance(var_read, (Variable, SolidityVariable)): + continue + if is_dependent(var_read, SolidityVariableComposed("block.timestamp"), node): ret.add(node) - if is_dependent(var, SolidityVariable("now"), func.contract): + if is_dependent(var_read, SolidityVariable("now"), node): ret.add(node) return sorted(list(ret), key=lambda x: x.node_id) def _detect_dangerous_timestamp( contract: Contract, -) -> List[Tuple[Function, List[Node]]]: +) -> List[Tuple[FunctionContract, List[Node]]]: """ Args: contract (Contract) @@ -48,7 +53,7 @@ def _detect_dangerous_timestamp( """ ret = [] for f in [f for f in contract.functions if f.contract_declarer == contract]: - nodes = _timestamp(f) + nodes: List[Node] = _timestamp(f) if nodes: ret.append((f, nodes)) return ret @@ -78,7 +83,7 @@ def _detect(self) -> List[Output]: dangerous_timestamp = _detect_dangerous_timestamp(c) for (func, nodes) in dangerous_timestamp: - info = [func, " uses timestamp for comparisons\n"] + info: DETECTOR_INFO = [func, " uses timestamp for comparisons\n"] info += ["\tDangerous comparisons:\n"] diff --git a/slither/detectors/operations/low_level_calls.py b/slither/detectors/operations/low_level_calls.py index 1ea91c37a9..463c748757 100644 --- a/slither/detectors/operations/low_level_calls.py +++ b/slither/detectors/operations/low_level_calls.py @@ -2,7 +2,11 @@ Module detecting usage of low level calls """ from typing import List, Tuple -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import LowLevelCall from slither.core.cfg.node import Node from slither.core.declarations.contract import Contract @@ -52,7 +56,7 @@ def _detect(self) -> List[Output]: for c in self.contracts: values = self.detect_low_level_calls(c) for func, nodes in values: - info = ["Low level call in ", func, ":\n"] + info: DETECTOR_INFO = ["Low level call in ", func, ":\n"] # sort the nodes to get deterministic results nodes.sort(key=lambda x: x.node_id) diff --git a/slither/detectors/operations/missing_events_access_control.py b/slither/detectors/operations/missing_events_access_control.py index 20c2297596..853eafd734 100644 --- a/slither/detectors/operations/missing_events_access_control.py +++ b/slither/detectors/operations/missing_events_access_control.py @@ -11,7 +11,11 @@ from slither.core.declarations.modifier import Modifier from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations.event_call import EventCall from slither.utils.output import Output @@ -100,7 +104,7 @@ def _detect(self) -> List[Output]: for contract in self.compilation_unit.contracts_derived: missing_events = self._detect_missing_events(contract) for (function, nodes) in missing_events: - info = [function, " should emit an event for: \n"] + info: DETECTOR_INFO = [function, " should emit an event for: \n"] for (node, _sv, _mod) in nodes: info += ["\t- ", node, " \n"] res = self.generate_result(info) diff --git a/slither/detectors/operations/missing_events_arithmetic.py b/slither/detectors/operations/missing_events_arithmetic.py index 6e1d5fbb50..c17ed32a3a 100644 --- a/slither/detectors/operations/missing_events_arithmetic.py +++ b/slither/detectors/operations/missing_events_arithmetic.py @@ -10,7 +10,11 @@ from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types.elementary_type import ElementaryType, Int, Uint from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations.event_call import EventCall from slither.utils.output import Output @@ -122,7 +126,7 @@ def _detect(self) -> List[Output]: for contract in self.compilation_unit.contracts_derived: missing_events = self._detect_missing_events(contract) for (function, nodes) in missing_events: - info = [function, " should emit an event for: \n"] + info: DETECTOR_INFO = [function, " should emit an event for: \n"] for (node, _) in nodes: info += ["\t- ", node, " \n"] res = self.generate_result(info) diff --git a/slither/detectors/operations/missing_zero_address_validation.py b/slither/detectors/operations/missing_zero_address_validation.py index a6c8de9ff9..4feac9d0ce 100644 --- a/slither/detectors/operations/missing_zero_address_validation.py +++ b/slither/detectors/operations/missing_zero_address_validation.py @@ -12,7 +12,11 @@ from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.variables.local_variable import LocalVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Call from slither.slithir.operations import Send, Transfer, LowLevelCall from slither.utils.output import Output @@ -155,7 +159,7 @@ def _detect(self) -> List[Output]: missing_zero_address_validation = self._detect_missing_zero_address_validation(contract) for (_, var_nodes) in missing_zero_address_validation: for var, nodes in var_nodes.items(): - info = [var, " lacks a zero-check on ", ":\n"] + info: DETECTOR_INFO = [var, " lacks a zero-check on ", ":\n"] for node in nodes: info += ["\t\t- ", node, "\n"] res = self.generate_result(info) diff --git a/slither/detectors/operations/unused_return_values.py b/slither/detectors/operations/unused_return_values.py index 7edde20fc5..93dda274aa 100644 --- a/slither/detectors/operations/unused_return_values.py +++ b/slither/detectors/operations/unused_return_values.py @@ -7,7 +7,11 @@ from slither.core.declarations import Function from slither.core.declarations.function_contract import FunctionContract from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import HighLevelCall from slither.slithir.operations.operation import Operation from slither.utils.output import Output @@ -91,7 +95,7 @@ def _detect(self) -> List[Output]: if unused_return: for node in unused_return: - info = [f, " ignores return value by ", node, "\n"] + info: DETECTOR_INFO = [f, " ignores return value by ", node, "\n"] res = self.generate_result(info) diff --git a/slither/detectors/operations/void_constructor.py b/slither/detectors/operations/void_constructor.py index fb44ea98c9..365904fa9e 100644 --- a/slither/detectors/operations/void_constructor.py +++ b/slither/detectors/operations/void_constructor.py @@ -1,6 +1,10 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Nop from slither.utils.output import Output @@ -39,7 +43,7 @@ def _detect(self) -> List[Output]: for constructor_call in cst.explicit_base_constructor_calls_statements: for node in constructor_call.nodes: if any(isinstance(ir, Nop) for ir in node.irs): - info = ["Void constructor called in ", cst, ":\n"] + info: DETECTOR_INFO = ["Void constructor called in ", cst, ":\n"] info += ["\t- ", node, "\n"] res = self.generate_result(info) diff --git a/slither/detectors/reentrancy/token.py b/slither/detectors/reentrancy/token.py index c960bffa72..d906a73038 100644 --- a/slither/detectors/reentrancy/token.py +++ b/slither/detectors/reentrancy/token.py @@ -4,7 +4,11 @@ from slither.analyses.data_dependency.data_dependency import is_dependent from slither.core.cfg.node import Node from slither.core.declarations import Function, Contract, SolidityVariableComposed -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import LowLevelCall, HighLevelCall from slither.utils.output import Output @@ -88,7 +92,7 @@ def _detect(self) -> List[Output]: for contract in self.compilation_unit.contracts_derived: vulns = _detect_token_reentrant(contract) for function, nodes in vulns.items(): - info = [function, " is an reentrancy unsafe token function:\n"] + info: DETECTOR_INFO = [function, " is an reentrancy unsafe token function:\n"] for node in nodes: info += ["\t-", node, "\n"] json = self.generate_result(info) diff --git a/slither/detectors/shadowing/builtin_symbols.py b/slither/detectors/shadowing/builtin_symbols.py index b0a44c8e2d..ab54861053 100644 --- a/slither/detectors/shadowing/builtin_symbols.py +++ b/slither/detectors/shadowing/builtin_symbols.py @@ -9,7 +9,11 @@ from slither.core.declarations.modifier import Modifier from slither.core.variables import Variable from slither.core.variables.local_variable import LocalVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -194,7 +198,7 @@ def _detect(self) -> List[Output]: shadow_type = shadow[0] shadow_object = shadow[1] - info = [ + info: DETECTOR_INFO = [ shadow_object, f' ({shadow_type}) shadows built-in symbol"\n', ] diff --git a/slither/detectors/shadowing/local.py b/slither/detectors/shadowing/local.py index 07abe52489..d67b5f688b 100644 --- a/slither/detectors/shadowing/local.py +++ b/slither/detectors/shadowing/local.py @@ -9,7 +9,11 @@ from slither.core.declarations.modifier import Modifier from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -85,7 +89,7 @@ def detect_shadowing_definitions( ] = [] # Loop through all functions + modifiers in this contract. - for function in contract.functions + contract.modifiers: + for function in contract.functions + list(contract.modifiers): # We should only look for functions declared directly in this contract (not in a base contract). if function.contract_declarer != contract: continue @@ -144,7 +148,7 @@ def _detect(self) -> List[Output]: for shadow in shadows: local_variable = shadow[0] overshadowed = shadow[1] - info = [local_variable, " shadows:\n"] + info: DETECTOR_INFO = [local_variable, " shadows:\n"] for overshadowed_entry in overshadowed: info += [ "\t- ", diff --git a/slither/detectors/shadowing/state.py b/slither/detectors/shadowing/state.py index 801c370a59..c08dbfd25a 100644 --- a/slither/detectors/shadowing/state.py +++ b/slither/detectors/shadowing/state.py @@ -6,7 +6,11 @@ from slither.core.declarations import Contract from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.detectors.shadowing.common import is_upgradable_gap_variable from slither.utils.output import Output @@ -89,7 +93,7 @@ def _detect(self) -> List[Output]: for all_variables in shadowing: shadow = all_variables[0] variables = all_variables[1:] - info = [shadow, " shadows:\n"] + info: DETECTOR_INFO = [shadow, " shadows:\n"] for var in variables: info += ["\t- ", var, "\n"] diff --git a/slither/detectors/slither/name_reused.py b/slither/detectors/slither/name_reused.py index f6f2820fa2..babce6389f 100644 --- a/slither/detectors/slither/name_reused.py +++ b/slither/detectors/slither/name_reused.py @@ -1,12 +1,17 @@ from collections import defaultdict -from typing import Any, List +from typing import List from slither.core.compilation_unit import SlitherCompilationUnit -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.core.declarations import Contract +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output -def _find_missing_inheritance(compilation_unit: SlitherCompilationUnit) -> List[Any]: +def _find_missing_inheritance(compilation_unit: SlitherCompilationUnit) -> List[Contract]: """ Filter contracts with missing inheritance to return only the "most base" contracts in the inheritance tree. @@ -80,7 +85,7 @@ def _detect(self) -> List[Output]: inheritance_corrupted[father.name].append(contract) for contract_name, files in names_reused.items(): - info = [contract_name, " is re-used:\n"] + info: DETECTOR_INFO = [contract_name, " is re-used:\n"] for file in files: if file is None: info += ["\t- In an file not found, most likely in\n"] diff --git a/slither/detectors/source/rtlo.py b/slither/detectors/source/rtlo.py index f89eb70eb2..b020f69f9f 100644 --- a/slither/detectors/source/rtlo.py +++ b/slither/detectors/source/rtlo.py @@ -1,7 +1,11 @@ import re from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -78,7 +82,7 @@ def _detect(self) -> List[Output]: idx = start_index + result_index relative = self.slither.crytic_compile.filename_lookup(filename).relative - info = f"{relative} contains a unicode right-to-left-override character at byte offset {idx}:\n" + info: DETECTOR_INFO = f"{relative} contains a unicode right-to-left-override character at byte offset {idx}:\n" # We have a patch, so pattern.find will return at least one result diff --git a/slither/detectors/statements/array_length_assignment.py b/slither/detectors/statements/array_length_assignment.py index 51302a2c94..70dc5aadbb 100644 --- a/slither/detectors/statements/array_length_assignment.py +++ b/slither/detectors/statements/array_length_assignment.py @@ -1,7 +1,9 @@ """ Module detecting assignment of array length """ -from typing import List, Set +from typing import List, Set, Union + +from slither.core.variables import Variable from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, @@ -14,7 +16,7 @@ from slither.slithir.operations.binary import Binary from slither.analyses.data_dependency.data_dependency import is_tainted from slither.core.declarations.contract import Contract -from slither.utils.output import Output +from slither.utils.output import Output, SupportedOutput def detect_array_length_assignment(contract: Contract) -> Set[Node]: @@ -50,7 +52,7 @@ def detect_array_length_assignment(contract: Contract) -> Set[Node]: elif isinstance(ir, (Assignment, Binary)): if isinstance(ir.lvalue, ReferenceVariable): if ir.lvalue in array_length_refs and any( - is_tainted(v, contract) for v in ir.read + is_tainted(v, contract) for v in ir.read if isinstance(v, Variable) ): # the taint is not precise enough yet # as a result, REF_0 = REF_0 + 1 @@ -120,12 +122,16 @@ def _detect(self) -> List[Output]: for contract in self.contracts: array_length_assignments = detect_array_length_assignment(contract) if array_length_assignments: - contract_info = [ + contract_info: List[Union[str, SupportedOutput]] = [ contract, " contract sets array length with a user-controlled value:\n", ] for node in array_length_assignments: - node_info = contract_info + ["\t- ", node, "\n"] + node_info: List[Union[str, SupportedOutput]] = contract_info + [ + "\t- ", + node, + "\n", + ] res = self.generate_result(node_info) results.append(res) diff --git a/slither/detectors/statements/assembly.py b/slither/detectors/statements/assembly.py index 2c0c49f091..25b5d8034a 100644 --- a/slither/detectors/statements/assembly.py +++ b/slither/detectors/statements/assembly.py @@ -6,7 +6,11 @@ from slither.core.cfg.node import Node, NodeType from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -52,7 +56,7 @@ def _detect(self) -> List[Output]: for c in self.contracts: values = self.detect_assembly(c) for func, nodes in values: - info = [func, " uses assembly\n"] + info: DETECTOR_INFO = [func, " uses assembly\n"] # sort the nodes to get deterministic results nodes.sort(key=lambda x: x.node_id) diff --git a/slither/detectors/statements/assert_state_change.py b/slither/detectors/statements/assert_state_change.py index c82919de68..769d730b82 100644 --- a/slither/detectors/statements/assert_state_change.py +++ b/slither/detectors/statements/assert_state_change.py @@ -6,7 +6,11 @@ from slither.core.cfg.node import Node from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations.internal_call import InternalCall from slither.utils.output import Output @@ -25,7 +29,7 @@ def detect_assert_state_change( results = [] # Loop for each function and modifier. - for function in contract.functions_declared + contract.modifiers_declared: + for function in contract.functions_declared + list(contract.modifiers_declared): for node in function.nodes: # Detect assert() calls if any(c.name == "assert(bool)" for c in node.internal_calls) and ( @@ -36,7 +40,9 @@ def detect_assert_state_change( any( ir for ir in node.irs - if isinstance(ir, InternalCall) and ir.function.state_variables_written + if isinstance(ir, InternalCall) + and ir.function + and ir.function.state_variables_written ) ): results.append((function, node)) @@ -85,7 +91,10 @@ def _detect(self) -> List[Output]: for contract in self.contracts: assert_state_change = detect_assert_state_change(contract) for (func, node) in assert_state_change: - info = [func, " has an assert() call which possibly changes state.\n"] + info: DETECTOR_INFO = [ + func, + " has an assert() call which possibly changes state.\n", + ] info += ["\t-", node, "\n"] info += [ "Consider using require() or change the invariant to not modify the state.\n" diff --git a/slither/detectors/statements/boolean_constant_equality.py b/slither/detectors/statements/boolean_constant_equality.py index 5b91f364f8..97eb14aa5c 100644 --- a/slither/detectors/statements/boolean_constant_equality.py +++ b/slither/detectors/statements/boolean_constant_equality.py @@ -6,7 +6,11 @@ from slither.core.cfg.node import Node from slither.core.declarations import Function from slither.core.declarations.contract import Contract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import ( Binary, BinaryType, @@ -84,7 +88,7 @@ def _detect(self) -> List[Output]: boolean_constant_misuses = self._detect_boolean_equality(contract) for (func, nodes) in boolean_constant_misuses: for node in nodes: - info = [ + info: DETECTOR_INFO = [ func, " compares to a boolean constant:\n\t-", node, diff --git a/slither/detectors/statements/boolean_constant_misuse.py b/slither/detectors/statements/boolean_constant_misuse.py index 96dd2012f0..093e43fee6 100644 --- a/slither/detectors/statements/boolean_constant_misuse.py +++ b/slither/detectors/statements/boolean_constant_misuse.py @@ -7,7 +7,11 @@ from slither.core.declarations import Function from slither.core.declarations.contract import Contract from slither.core.solidity_types import ElementaryType -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import ( Assignment, Call, @@ -120,7 +124,7 @@ def _detect(self) -> List[Output]: boolean_constant_misuses = self._detect_boolean_constant_misuses(contract) for (func, nodes) in boolean_constant_misuses: for node in nodes: - info = [ + info: DETECTOR_INFO = [ func, " uses a Boolean constant improperly:\n\t-", node, diff --git a/slither/detectors/statements/calls_in_loop.py b/slither/detectors/statements/calls_in_loop.py index fdd0c67329..d40d18f599 100644 --- a/slither/detectors/statements/calls_in_loop.py +++ b/slither/detectors/statements/calls_in_loop.py @@ -1,6 +1,10 @@ from typing import List, Optional from slither.core.cfg.node import NodeType, Node -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.declarations import Contract from slither.utils.output import Output from slither.slithir.operations import ( @@ -44,6 +48,7 @@ def call_in_loop( continue ret.append(ir.node) if isinstance(ir, (InternalCall)): + assert ir.function call_in_loop(ir.function.entry_point, in_loop_counter, visited, ret) for son in node.sons: @@ -94,7 +99,7 @@ def _detect(self) -> List[Output]: for node in values: func = node.function - info = [func, " has external calls inside a loop: ", node, "\n"] + info: DETECTOR_INFO = [func, " has external calls inside a loop: ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/controlled_delegatecall.py b/slither/detectors/statements/controlled_delegatecall.py index 08280940d1..32e59d6eb7 100644 --- a/slither/detectors/statements/controlled_delegatecall.py +++ b/slither/detectors/statements/controlled_delegatecall.py @@ -3,7 +3,11 @@ from slither.analyses.data_dependency.data_dependency import is_tainted from slither.core.cfg.node import Node from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import LowLevelCall from slither.utils.output import Output @@ -58,13 +62,13 @@ def _detect(self) -> List[Output]: continue nodes = controlled_delegatecall(f) if nodes: - func_info = [ + func_info: DETECTOR_INFO = [ f, " uses delegatecall to a input-controlled function id\n", ] for node in nodes: - node_info = func_info + ["\t- ", node, "\n"] + node_info: DETECTOR_INFO = func_info + ["\t- ", node, "\n"] res = self.generate_result(node_info) results.append(res) diff --git a/slither/detectors/statements/costly_operations_in_loop.py b/slither/detectors/statements/costly_operations_in_loop.py index 930085cc61..53fa126477 100644 --- a/slither/detectors/statements/costly_operations_in_loop.py +++ b/slither/detectors/statements/costly_operations_in_loop.py @@ -1,6 +1,10 @@ from typing import List, Optional from slither.core.cfg.node import NodeType, Node -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.declarations import Contract from slither.utils.output import Output from slither.slithir.operations import InternalCall, OperationWithLValue @@ -39,7 +43,7 @@ def costly_operations_in_loop( if isinstance(ir, OperationWithLValue) and isinstance(ir.lvalue, StateVariable): ret.append(ir.node) break - if isinstance(ir, (InternalCall)): + if isinstance(ir, (InternalCall)) and ir.function: costly_operations_in_loop(ir.function.entry_point, in_loop_counter, visited, ret) for son in node.sons: @@ -98,7 +102,7 @@ def _detect(self) -> List[Output]: values = detect_costly_operations_in_loop(c) for node in values: func = node.function - info = [func, " has costly operations inside a loop:\n"] + info: DETECTOR_INFO = [func, " has costly operations inside a loop:\n"] info += ["\t- ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/delegatecall_in_loop.py b/slither/detectors/statements/delegatecall_in_loop.py index b7bf70cbc7..bdcf5dcae8 100644 --- a/slither/detectors/statements/delegatecall_in_loop.py +++ b/slither/detectors/statements/delegatecall_in_loop.py @@ -1,6 +1,10 @@ from typing import List, Optional from slither.core.cfg.node import NodeType, Node -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import LowLevelCall, InternalCall from slither.core.declarations import Contract from slither.utils.output import Output @@ -38,7 +42,7 @@ def delegatecall_in_loop( and ir.function_name == "delegatecall" ): results.append(ir.node) - if isinstance(ir, (InternalCall)): + if isinstance(ir, (InternalCall)) and ir.function: delegatecall_in_loop(ir.function.entry_point, in_loop_counter, visited, results) for son in node.sons: @@ -94,7 +98,12 @@ def _detect(self) -> List[Output]: for node in values: func = node.function - info = [func, " has delegatecall inside a loop in a payable function: ", node, "\n"] + info: DETECTOR_INFO = [ + func, + " has delegatecall inside a loop in a payable function: ", + node, + "\n", + ] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/deprecated_calls.py b/slither/detectors/statements/deprecated_calls.py index 3d0ca4ba9a..e59d254bb1 100644 --- a/slither/detectors/statements/deprecated_calls.py +++ b/slither/detectors/statements/deprecated_calls.py @@ -11,7 +11,11 @@ ) from slither.core.expressions.expression import Expression from slither.core.variables import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import LowLevelCall from slither.utils.output import Output from slither.visitors.expression.export_values import ExportValues @@ -186,7 +190,7 @@ def _detect(self) -> List[Output]: for deprecated_reference in deprecated_references: source_object = deprecated_reference[0] deprecated_entries = deprecated_reference[1] - info = ["Deprecated standard detected ", source_object, ":\n"] + info: DETECTOR_INFO = ["Deprecated standard detected ", source_object, ":\n"] for (_dep_id, original_desc, recommended_disc) in deprecated_entries: info += [ diff --git a/slither/detectors/statements/divide_before_multiply.py b/slither/detectors/statements/divide_before_multiply.py index a9de76b407..6f199db414 100644 --- a/slither/detectors/statements/divide_before_multiply.py +++ b/slither/detectors/statements/divide_before_multiply.py @@ -2,13 +2,18 @@ Module detecting possible loss of precision due to divide before multiple """ from collections import defaultdict -from typing import Any, DefaultDict, List, Set, Tuple +from typing import DefaultDict, List, Set, Tuple from slither.core.cfg.node import Node from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Binary, Assignment, BinaryType, LibraryCall, Operation +from slither.slithir.utils.utils import LVALUE from slither.slithir.variables import Constant from slither.utils.output import Output @@ -19,7 +24,7 @@ def is_division(ir: Operation) -> bool: return True if isinstance(ir, LibraryCall): - if ir.function.name.lower() in [ + if ir.function.name and ir.function.name.lower() in [ "div", "safediv", ]: @@ -35,7 +40,7 @@ def is_multiplication(ir: Operation) -> bool: return True if isinstance(ir, LibraryCall): - if ir.function.name.lower() in [ + if ir.function.name and ir.function.name.lower() in [ "mul", "safemul", ]: @@ -58,7 +63,7 @@ def is_assert(node: Node) -> bool: # pylint: disable=too-many-branches def _explore( - to_explore: Set[Node], f_results: List[Node], divisions: DefaultDict[Any, Any] + to_explore: Set[Node], f_results: List[List[Node]], divisions: DefaultDict[LVALUE, List[Node]] ) -> None: explored = set() while to_explore: # pylint: disable=too-many-nested-blocks @@ -70,22 +75,22 @@ def _explore( equality_found = False # List of nodes related to one bug instance - node_results = [] + node_results: List[Node] = [] for ir in node.irs: if isinstance(ir, Assignment): if ir.rvalue in divisions: # Avoid dupplicate. We dont use set so we keep the order of the nodes - if node not in divisions[ir.rvalue]: - divisions[ir.lvalue] = divisions[ir.rvalue] + [node] + if node not in divisions[ir.rvalue]: # type: ignore + divisions[ir.lvalue] = divisions[ir.rvalue] + [node] # type: ignore else: - divisions[ir.lvalue] = divisions[ir.rvalue] + divisions[ir.lvalue] = divisions[ir.rvalue] # type: ignore if is_division(ir): - divisions[ir.lvalue] = [node] + divisions[ir.lvalue] = [node] # type: ignore if is_multiplication(ir): - mul_arguments = ir.read if isinstance(ir, Binary) else ir.arguments + mul_arguments = ir.read if isinstance(ir, Binary) else ir.arguments # type: ignore nodes = [] for r in mul_arguments: if not isinstance(r, Constant) and (r in divisions): @@ -125,7 +130,7 @@ def detect_divide_before_multiply( # List of tuple (function -> list(list(nodes))) # Each list(nodes) of the list is one bug instances # Each node in the list(nodes) is involved in the bug - results = [] + results: List[Tuple[FunctionContract, List[Node]]] = [] # Loop for each function and modifier. for function in contract.functions_declared: @@ -134,11 +139,11 @@ def detect_divide_before_multiply( # List of list(nodes) # Each list(nodes) is one bug instances - f_results = [] + f_results: List[List[Node]] = [] # lvalue -> node # track all the division results (and the assignment of the division results) - divisions = defaultdict(list) + divisions: DefaultDict[LVALUE, List[Node]] = defaultdict(list) _explore({function.entry_point}, f_results, divisions) @@ -190,7 +195,7 @@ def _detect(self) -> List[Output]: if divisions_before_multiplications: for (func, nodes) in divisions_before_multiplications: - info = [ + info: DETECTOR_INFO = [ func, " performs a multiplication on the result of a division:\n", ] diff --git a/slither/detectors/statements/mapping_deletion.py b/slither/detectors/statements/mapping_deletion.py index 59882cc961..4cdac72400 100644 --- a/slither/detectors/statements/mapping_deletion.py +++ b/slither/detectors/statements/mapping_deletion.py @@ -8,7 +8,11 @@ from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types import MappingType, UserDefinedType -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Delete from slither.utils.output import Output @@ -83,7 +87,7 @@ def _detect(self) -> List[Output]: for c in self.contracts: mapping = MappingDeletionDetection.detect_mapping_deletion(c) for (func, struct, node) in mapping: - info = [func, " deletes ", struct, " which contains a mapping:\n"] + info: DETECTOR_INFO = [func, " deletes ", struct, " which contains a mapping:\n"] info += ["\t-", node, "\n"] res = self.generate_result(info) diff --git a/slither/detectors/statements/msg_value_in_loop.py b/slither/detectors/statements/msg_value_in_loop.py index bfd541201c..55bd9bfc2a 100644 --- a/slither/detectors/statements/msg_value_in_loop.py +++ b/slither/detectors/statements/msg_value_in_loop.py @@ -1,6 +1,10 @@ from typing import List, Optional from slither.core.cfg.node import NodeType, Node -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import InternalCall from slither.core.declarations import SolidityVariableComposed, Contract from slither.utils.output import Output @@ -86,7 +90,7 @@ def _detect(self) -> List[Output]: for node in values: func = node.function - info = [func, " use msg.value in a loop: ", node, "\n"] + info: DETECTOR_INFO = [func, " use msg.value in a loop: ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/redundant_statements.py b/slither/detectors/statements/redundant_statements.py index 7e72231342..cebaecebeb 100644 --- a/slither/detectors/statements/redundant_statements.py +++ b/slither/detectors/statements/redundant_statements.py @@ -7,7 +7,11 @@ from slither.core.declarations.contract import Contract from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression from slither.core.expressions.identifier import Identifier -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -87,7 +91,13 @@ def _detect(self) -> List[Output]: if redundant_statements: for redundant_statement in redundant_statements: - info = ['Redundant expression "', redundant_statement, '" in', contract, "\n"] + info: DETECTOR_INFO = [ + 'Redundant expression "', + redundant_statement, + '" in', + contract, + "\n", + ] json = self.generate_result(info) results.append(json) diff --git a/slither/detectors/statements/too_many_digits.py b/slither/detectors/statements/too_many_digits.py index 239efa4bed..a5e09a34c8 100644 --- a/slither/detectors/statements/too_many_digits.py +++ b/slither/detectors/statements/too_many_digits.py @@ -7,7 +7,11 @@ from slither.core.cfg.node import Node from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.variables import Constant from slither.utils.output import Output @@ -88,9 +92,9 @@ def _detect(self) -> List[Output]: # iterate over all the nodes ret = self._detect_too_many_digits(f) if ret: - func_info = [f, " uses literals with too many digits:"] + func_info: DETECTOR_INFO = [f, " uses literals with too many digits:"] for node in ret: - node_info = func_info + ["\n\t- ", node, "\n"] + node_info: DETECTOR_INFO = func_info + ["\n\t- ", node, "\n"] # Add the result in result res = self.generate_result(node_info) diff --git a/slither/detectors/statements/tx_origin.py b/slither/detectors/statements/tx_origin.py index 34f8173d53..49bf6006d1 100644 --- a/slither/detectors/statements/tx_origin.py +++ b/slither/detectors/statements/tx_origin.py @@ -6,7 +6,11 @@ from slither.core.cfg.node import Node from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -80,7 +84,7 @@ def _detect(self) -> List[Output]: for func, nodes in values: for node in nodes: - info = [func, " uses tx.origin for authorization: ", node, "\n"] + info: DETECTOR_INFO = [func, " uses tx.origin for authorization: ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/type_based_tautology.py b/slither/detectors/statements/type_based_tautology.py index 9edb1f53ec..2e0fc84806 100644 --- a/slither/detectors/statements/type_based_tautology.py +++ b/slither/detectors/statements/type_based_tautology.py @@ -17,10 +17,9 @@ def typeRange(t: str) -> Tuple[int, int]: bits = int(t.split("int")[1]) if t in Uint: return 0, (2**bits) - 1 - if t in Int: - v = (2 ** (bits - 1)) - 1 - return -v, v - return None + assert t in Int + v = (2 ** (bits - 1)) - 1 + return -v, v def _detect_tautology_or_contradiction(low: int, high: int, cval: int, op: BinaryType) -> bool: diff --git a/slither/detectors/statements/unary.py b/slither/detectors/statements/unary.py index 5bb8d9c3c6..9c0add2389 100644 --- a/slither/detectors/statements/unary.py +++ b/slither/detectors/statements/unary.py @@ -4,29 +4,43 @@ from typing import List from slither.core.expressions.assignment_operation import AssignmentOperation +from slither.core.expressions.expression import Expression from slither.core.expressions.unary_operation import UnaryOperationType, UnaryOperation -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output from slither.visitors.expression.expression import ExpressionVisitor - +# pylint: disable=too-few-public-methods class InvalidUnaryExpressionDetector(ExpressionVisitor): + def __init__(self, expression: Expression) -> None: + self.result: bool = False + super().__init__(expression) + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: if isinstance(expression.expression_right, UnaryOperation): if expression.expression_right.type == UnaryOperationType.PLUS_PRE: # This is defined in ExpressionVisitor but pylint # Seems to think its not # pylint: disable=attribute-defined-outside-init - self._result = True + self.result = True +# pylint: disable=too-few-public-methods class InvalidUnaryStateVariableDetector(ExpressionVisitor): + def __init__(self, expression: Expression) -> None: + self.result: bool = False + super().__init__(expression) + def _post_unary_operation(self, expression: UnaryOperation) -> None: if expression.type == UnaryOperationType.PLUS_PRE: # This is defined in ExpressionVisitor but pylint # Seems to think its not # pylint: disable=attribute-defined-outside-init - self._result = True + self.result = True class IncorrectUnaryExpressionDetection(AbstractDetector): @@ -72,15 +86,18 @@ def _detect(self) -> List[Output]: for variable in c.state_variables: if ( variable.expression - and InvalidUnaryStateVariableDetector(variable.expression).result() + and InvalidUnaryStateVariableDetector(variable.expression).result ): - info = [variable, f" uses an dangerous unary operator: {variable.expression}\n"] + info: DETECTOR_INFO = [ + variable, + f" uses an dangerous unary operator: {variable.expression}\n", + ] json = self.generate_result(info) results.append(json) for f in c.functions_and_modifiers_declared: for node in f.nodes: - if node.expression and InvalidUnaryExpressionDetector(node.expression).result(): + if node.expression and InvalidUnaryExpressionDetector(node.expression).result: info = [node.function, " uses an dangerous unary operator: ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/unprotected_upgradeable.py b/slither/detectors/statements/unprotected_upgradeable.py index 1adf495407..30e6300f17 100644 --- a/slither/detectors/statements/unprotected_upgradeable.py +++ b/slither/detectors/statements/unprotected_upgradeable.py @@ -2,7 +2,11 @@ from slither.core.declarations import SolidityFunction, Function from slither.core.declarations.contract import Contract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import LowLevelCall, SolidityCall from slither.utils.output import Output @@ -110,17 +114,15 @@ def _detect(self) -> List[Output]: item for sublist in vars_init_in_constructors_ for item in sublist ] if vars_init and (set(vars_init) - set(vars_init_in_constructors)): - info = ( - [ - contract, - " is an upgradeable contract that does not protect its initialize functions: ", - ] - + initialize_functions - + [ - ". Anyone can delete the contract with: ", - ] - + functions_that_can_destroy - ) + info: DETECTOR_INFO = [ + contract, + " is an upgradeable contract that does not protect its initialize functions: ", + ] + info += initialize_functions + info += [ + ". Anyone can delete the contract with: ", + ] + info += functions_that_can_destroy res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/write_after_write.py b/slither/detectors/statements/write_after_write.py index 5b2e29925c..1f11921cb4 100644 --- a/slither/detectors/statements/write_after_write.py +++ b/slither/detectors/statements/write_after_write.py @@ -4,7 +4,11 @@ from slither.core.solidity_types import ElementaryType from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import ( OperationWithLValue, HighLevelCall, @@ -33,6 +37,8 @@ def _handle_ir( _remove_states(written) if isinstance(ir, InternalCall): + if not ir.function: + return if ir.function.all_high_level_calls() or ir.function.all_library_calls(): _remove_states(written) @@ -128,10 +134,17 @@ def _detect(self) -> List[Output]: for contract in self.compilation_unit.contracts_derived: for function in contract.functions: if function.entry_point: - ret = [] + ret: List[Tuple[Variable, Node, Node]] = [] _detect_write_after_write(function.entry_point, set(), {}, ret) for var, node1, node2 in ret: - info = [var, " is written in both\n\t", node1, "\n\t", node2, "\n"] + info: DETECTOR_INFO = [ + var, + " is written in both\n\t", + node1, + "\n\t", + node2, + "\n", + ] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/variables/function_init_state_variables.py b/slither/detectors/variables/function_init_state_variables.py index e35cfe351c..e440a4f964 100644 --- a/slither/detectors/variables/function_init_state_variables.py +++ b/slither/detectors/variables/function_init_state_variables.py @@ -6,7 +6,11 @@ from slither.core.declarations.contract import Contract from slither.core.declarations.function import Function from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output from slither.visitors.expression.export_values import ExportValues @@ -104,7 +108,7 @@ def _detect(self) -> List[Output]: state_variables = detect_function_init_state_vars(contract) if state_variables: for state_variable in state_variables: - info = [ + info: DETECTOR_INFO = [ state_variable, " is set pre-construction with a non-constant function or state variable:\n", ] diff --git a/slither/detectors/variables/predeclaration_usage_local.py b/slither/detectors/variables/predeclaration_usage_local.py index 177035ef43..b4d75e51af 100644 --- a/slither/detectors/variables/predeclaration_usage_local.py +++ b/slither/detectors/variables/predeclaration_usage_local.py @@ -11,6 +11,7 @@ AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, + DETECTOR_INFO, ) from slither.utils.output import Output @@ -154,7 +155,7 @@ def _detect(self) -> List[Output]: predeclared_usage_node, predeclared_usage_local_variable, ) in predeclared_usage_nodes: - info = [ + info: DETECTOR_INFO = [ "Variable '", predeclared_usage_local_variable, "' in ", diff --git a/slither/detectors/variables/similar_variables.py b/slither/detectors/variables/similar_variables.py index d0a15aaab7..465e1ce01d 100644 --- a/slither/detectors/variables/similar_variables.py +++ b/slither/detectors/variables/similar_variables.py @@ -7,7 +7,11 @@ from slither.core.declarations.contract import Contract from slither.core.variables.local_variable import LocalVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -86,7 +90,13 @@ def _detect(self) -> List[Output]: for (v1, v2) in sorted(allVars, key=lambda x: (x[0].name, x[1].name)): v_left = v1 if v1.name < v2.name else v2 v_right = v2 if v_left == v1 else v1 - info = ["Variable ", v_left, " is too similar to ", v_right, "\n"] + info: DETECTOR_INFO = [ + "Variable ", + v_left, + " is too similar to ", + v_right, + "\n", + ] json = self.generate_result(info) results.append(json) return results diff --git a/slither/detectors/variables/uninitialized_state_variables.py b/slither/detectors/variables/uninitialized_state_variables.py index 0fbb73b5dc..13cf110521 100644 --- a/slither/detectors/variables/uninitialized_state_variables.py +++ b/slither/detectors/variables/uninitialized_state_variables.py @@ -14,7 +14,11 @@ from slither.core.declarations.contract import Contract from slither.core.variables import Variable from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import InternalCall, LibraryCall from slither.slithir.variables import ReferenceVariable from slither.utils.output import Output @@ -140,7 +144,7 @@ def _detect(self) -> List[Output]: ret = self._detect_uninitialized(c) for variable, functions in ret: - info = [variable, " is never initialized. It is used in:\n"] + info: DETECTOR_INFO = [variable, " is never initialized. It is used in:\n"] for f in functions: info += ["\t- ", f, "\n"] diff --git a/slither/detectors/variables/unused_state_variables.py b/slither/detectors/variables/unused_state_variables.py index d542f67d30..afb4e3ac5e 100644 --- a/slither/detectors/variables/unused_state_variables.py +++ b/slither/detectors/variables/unused_state_variables.py @@ -1,13 +1,19 @@ """ Module detecting unused state variables """ -from typing import List, Optional +from typing import List, Optional, Dict from slither.core.compilation_unit import SlitherCompilationUnit +from slither.core.declarations import Function from slither.core.declarations.contract import Contract from slither.core.solidity_types import ArrayType +from slither.core.variables import Variable from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.formatters.variables.unused_state_variables import custom_format from slither.utils.output import Output from slither.visitors.expression.export_values import ExportValues @@ -18,14 +24,19 @@ def detect_unused(contract: Contract) -> Optional[List[StateVariable]]: return None # Get all the variables read in all the functions and modifiers - all_functions = contract.all_functions_called + contract.modifiers + all_functions = [ + f + for f in contract.all_functions_called + list(contract.modifiers) + if isinstance(f, Function) + ] variables_used = [x.state_variables_read for x in all_functions] variables_used += [ x.state_variables_written for x in all_functions if not x.is_constructor_variables ] - array_candidates = [x.variables for x in all_functions] - array_candidates = [i for sl in array_candidates for i in sl] + contract.state_variables + array_candidates_ = [x.variables for x in all_functions] + array_candidates: List[Variable] = [i for sl in array_candidates_ for i in sl] + array_candidates += contract.state_variables array_candidates = [ x.type.length for x in array_candidates if isinstance(x.type, ArrayType) and x.type.length ] @@ -65,12 +76,12 @@ def _detect(self) -> List[Output]: unusedVars = detect_unused(c) if unusedVars: for var in unusedVars: - info = [var, " is never used in ", c, "\n"] + info: DETECTOR_INFO = [var, " is never used in ", c, "\n"] json = self.generate_result(info) results.append(json) return results @staticmethod - def _format(compilation_unit: SlitherCompilationUnit, result): + def _format(compilation_unit: SlitherCompilationUnit, result: Dict) -> None: custom_format(compilation_unit, result) diff --git a/slither/detectors/variables/var_read_using_this.py b/slither/detectors/variables/var_read_using_this.py index b224f8c17d..a2b93a7d8b 100644 --- a/slither/detectors/variables/var_read_using_this.py +++ b/slither/detectors/variables/var_read_using_this.py @@ -2,7 +2,11 @@ from slither.core.cfg.node import Node from slither.core.declarations import Function, SolidityVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations.high_level_call import HighLevelCall from slither.utils.output import Output @@ -35,7 +39,7 @@ def _detect(self) -> List[Output]: for c in self.contracts: for func in c.functions: for node in self._detect_var_read_using_this(func): - info = [ + info: DETECTOR_INFO = [ "The function ", func, " reads ", diff --git a/slither/formatters/attributes/const_functions.py b/slither/formatters/attributes/const_functions.py index 33588af74e..feb404f7b8 100644 --- a/slither/formatters/attributes/const_functions.py +++ b/slither/formatters/attributes/const_functions.py @@ -1,11 +1,12 @@ import re +from typing import Dict from slither.core.compilation_unit import SlitherCompilationUnit from slither.formatters.exceptions import FormatError from slither.formatters.utils.patches import create_patch -def custom_format(compilation_unit: SlitherCompilationUnit, result): +def custom_format(compilation_unit: SlitherCompilationUnit, result: Dict) -> None: for file_scope in compilation_unit.scopes.values(): elements = result["elements"] for element in elements: @@ -33,8 +34,12 @@ def custom_format(compilation_unit: SlitherCompilationUnit, result): def _patch( - compilation_unit: SlitherCompilationUnit, result, in_file, modify_loc_start, modify_loc_end -): + compilation_unit: SlitherCompilationUnit, + result: Dict, + in_file: str, + modify_loc_start: int, + modify_loc_end: int, +) -> None: in_file_str = compilation_unit.core.source_code[in_file].encode("utf8") old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end] # Find the keywords view|pure|constant and remove them diff --git a/slither/formatters/attributes/constant_pragma.py b/slither/formatters/attributes/constant_pragma.py index 251dd07ae5..1127b1e431 100644 --- a/slither/formatters/attributes/constant_pragma.py +++ b/slither/formatters/attributes/constant_pragma.py @@ -1,4 +1,7 @@ import re +from typing import Dict, List, Union + +from slither.core.compilation_unit import SlitherCompilationUnit from slither.formatters.exceptions import FormatImpossible from slither.formatters.utils.patches import create_patch @@ -16,9 +19,9 @@ PATTERN = re.compile(r"(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)") -def custom_format(slither, result): +def custom_format(slither: SlitherCompilationUnit, result: Dict) -> None: elements = result["elements"] - versions_used = [] + versions_used: List[str] = [] for element in elements: versions_used.append("".join(element["type_specific_fields"]["directive"][1:])) solc_version_replace = _analyse_versions(versions_used) @@ -33,7 +36,7 @@ def custom_format(slither, result): ) -def _analyse_versions(used_solc_versions): +def _analyse_versions(used_solc_versions: List[str]) -> str: replace_solc_versions = [] for version in used_solc_versions: replace_solc_versions.append(_determine_solc_version_replacement(version)) @@ -42,7 +45,7 @@ def _analyse_versions(used_solc_versions): return replace_solc_versions[0] -def _determine_solc_version_replacement(used_solc_version): +def _determine_solc_version_replacement(used_solc_version: str) -> str: versions = PATTERN.findall(used_solc_version) if len(versions) == 1: version = versions[0] @@ -64,10 +67,16 @@ def _determine_solc_version_replacement(used_solc_version): raise FormatImpossible("Unknown version!") +# pylint: disable=too-many-arguments def _patch( - slither, result, in_file, pragma, modify_loc_start, modify_loc_end -): # pylint: disable=too-many-arguments - in_file_str = slither.source_code[in_file].encode("utf8") + slither: SlitherCompilationUnit, + result: Dict, + in_file: str, + pragma: Union[str, bytes], + modify_loc_start: int, + modify_loc_end: int, +) -> None: + in_file_str = slither.core.source_code[in_file].encode("utf8") old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end] create_patch( result, diff --git a/slither/formatters/naming_convention/naming_convention.py b/slither/formatters/naming_convention/naming_convention.py index 76974296f9..4aadad072a 100644 --- a/slither/formatters/naming_convention/naming_convention.py +++ b/slither/formatters/naming_convention/naming_convention.py @@ -1,9 +1,10 @@ import re import logging -from typing import List +from typing import List, Set, Dict, Union, Optional, Callable, Type, Sequence from slither.core.compilation_unit import SlitherCompilationUnit +from slither.core.variables import Variable from slither.slithir.operations import ( Send, Transfer, @@ -14,7 +15,7 @@ InternalDynamicCall, Operation, ) -from slither.core.declarations import Modifier +from slither.core.declarations import Modifier, Event from slither.core.solidity_types import UserDefinedType, MappingType from slither.core.declarations import Enum, Contract, Structure, Function from slither.core.solidity_types.elementary_type import ElementaryTypeName @@ -29,7 +30,7 @@ # pylint: disable=anomalous-backslash-in-string -def custom_format(compilation_unit: SlitherCompilationUnit, result): +def custom_format(compilation_unit: SlitherCompilationUnit, result: Dict) -> None: elements = result["elements"] for element in elements: target = element["additional_fields"]["target"] @@ -129,24 +130,24 @@ def custom_format(compilation_unit: SlitherCompilationUnit, result): SOLIDITY_KEYWORDS += ElementaryTypeName -def _name_already_use(slither, name): +def _name_already_use(slither: SlitherCompilationUnit, name: str) -> bool: # Do not convert to a name used somewhere else if not KEY in slither.context: - all_names = set() + all_names: Set[str] = set() for contract in slither.contracts_derived: all_names = all_names.union({st.name for st in contract.structures}) all_names = all_names.union({f.name for f in contract.functions_and_modifiers}) all_names = all_names.union({e.name for e in contract.enums}) - all_names = all_names.union({s.name for s in contract.state_variables}) + all_names = all_names.union({s.name for s in contract.state_variables if s.name}) for function in contract.functions: - all_names = all_names.union({v.name for v in function.variables}) + all_names = all_names.union({v.name for v in function.variables if v.name}) slither.context[KEY] = all_names return name in slither.context[KEY] -def _convert_CapWords(original_name, slither): +def _convert_CapWords(original_name: str, slither: SlitherCompilationUnit) -> str: name = original_name.capitalize() while "_" in name: @@ -162,10 +163,13 @@ def _convert_CapWords(original_name, slither): return name -def _convert_mixedCase(original_name, compilation_unit: SlitherCompilationUnit): - name = original_name - if isinstance(name, bytes): - name = name.decode("utf8") +def _convert_mixedCase( + original_name: Union[str, bytes], compilation_unit: SlitherCompilationUnit +) -> str: + if isinstance(original_name, bytes): + name = original_name.decode("utf8") + else: + name = original_name while "_" in name: offset = name.find("_") @@ -174,13 +178,15 @@ def _convert_mixedCase(original_name, compilation_unit: SlitherCompilationUnit): name = name[0].lower() + name[1:] if _name_already_use(compilation_unit, name): - raise FormatImpossible(f"{original_name} cannot be converted to {name} (already used)") + raise FormatImpossible(f"{original_name} cannot be converted to {name} (already used)") # type: ignore if name in SOLIDITY_KEYWORDS: - raise FormatImpossible(f"{original_name} cannot be converted to {name} (Solidity keyword)") + raise FormatImpossible(f"{original_name} cannot be converted to {name} (Solidity keyword)") # type: ignore return name -def _convert_UPPER_CASE_WITH_UNDERSCORES(name, compilation_unit: SlitherCompilationUnit): +def _convert_UPPER_CASE_WITH_UNDERSCORES( + name: str, compilation_unit: SlitherCompilationUnit +) -> str: if _name_already_use(compilation_unit, name.upper()): raise FormatImpossible(f"{name} cannot be converted to {name.upper()} (already used)") if name.upper() in SOLIDITY_KEYWORDS: @@ -188,7 +194,10 @@ def _convert_UPPER_CASE_WITH_UNDERSCORES(name, compilation_unit: SlitherCompilat return name.upper() -conventions = { +TARGET_TYPE = Union[Contract, Variable, Function] +CONVENTION_F_TYPE = Callable[[str, SlitherCompilationUnit], str] + +conventions: Dict[str, CONVENTION_F_TYPE] = { "CapWords": _convert_CapWords, "mixedCase": _convert_mixedCase, "UPPER_CASE_WITH_UNDERSCORES": _convert_UPPER_CASE_WITH_UNDERSCORES, @@ -203,7 +212,9 @@ def _convert_UPPER_CASE_WITH_UNDERSCORES(name, compilation_unit: SlitherCompilat ################################################################################### -def _get_from_contract(compilation_unit: SlitherCompilationUnit, element, name, getter): +def _get_from_contract( + compilation_unit: SlitherCompilationUnit, element: Dict, name: str, getter: str +) -> TARGET_TYPE: scope = compilation_unit.get_scope(element["source_mapping"]["filename_absolute"]) contract_name = element["type_specific_fields"]["parent"]["name"] contract = scope.get_contract_from_name(contract_name) @@ -218,9 +229,13 @@ def _get_from_contract(compilation_unit: SlitherCompilationUnit, element, name, ################################################################################### -def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target): +def _patch( + compilation_unit: SlitherCompilationUnit, result: Dict, element: Dict, _target: str +) -> None: scope = compilation_unit.get_scope(element["source_mapping"]["filename_absolute"]) + target: Optional[TARGET_TYPE] = None + if _target == "contract": target = scope.get_contract_from_name(element["name"]) @@ -257,7 +272,9 @@ def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target): ] param_name = element["name"] contract = scope.get_contract_from_name(contract_name) + assert contract function = contract.get_function_from_full_name(function_sig) + assert function target = function.get_local_variable_from_name(param_name) elif _target in ["variable", "variable_constant"]: @@ -271,7 +288,9 @@ def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target): ] var_name = element["name"] contract = scope.get_contract_from_name(contract_name) + assert contract function = contract.get_function_from_full_name(function_sig) + assert function target = function.get_local_variable_from_name(var_name) # State variable else: @@ -287,6 +306,7 @@ def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target): else: raise FormatError("Unknown naming convention! " + _target) + assert target _explore( compilation_unit, result, target, conventions[element["additional_fields"]["convention"]] ) @@ -310,7 +330,7 @@ def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target): ) -def _is_var_declaration(slither, filename, start): +def _is_var_declaration(slither: SlitherCompilationUnit, filename: str, start: int) -> bool: """ Detect usage of 'var ' for Solidity < 0.5 :param slither: @@ -319,12 +339,19 @@ def _is_var_declaration(slither, filename, start): :return: """ v = "var " - return slither.source_code[filename][start : start + len(v)] == v + return slither.core.source_code[filename][start : start + len(v)] == v def _explore_type( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches - slither, result, target, convert, custom_type, filename_source_code, start, end -): + slither: SlitherCompilationUnit, + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, + custom_type: Optional[Union[Type, List[Type]]], + filename_source_code: str, + start: int, + end: int, +) -> None: if isinstance(custom_type, UserDefinedType): # Patch type based on contract/enum if isinstance(custom_type.type, (Enum, Contract)): @@ -358,7 +385,7 @@ def _explore_type( # pylint: disable=too-many-arguments,too-many-locals,too-man # Structure contain a list of elements, that might need patching # .elems return a list of VariableStructure _explore_variables_declaration( - slither, custom_type.type.elems.values(), result, target, convert + slither, list(custom_type.type.elems.values()), result, target, convert ) if isinstance(custom_type, MappingType): @@ -377,7 +404,7 @@ def _explore_type( # pylint: disable=too-many-arguments,too-many-locals,too-man full_txt_start = start full_txt_end = end - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] re_match = re.match(RE_MAPPING, full_txt) @@ -417,14 +444,19 @@ def _explore_type( # pylint: disable=too-many-arguments,too-many-locals,too-man def _explore_variables_declaration( # pylint: disable=too-many-arguments,too-many-locals,too-many-nested-blocks - slither, variables, result, target, convert, patch_comment=False -): + slither: SlitherCompilationUnit, + variables: Sequence[Variable], + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, + patch_comment: bool = False, +) -> None: for variable in variables: # First explore the type of the variable filename_source_code = variable.source_mapping.filename.absolute full_txt_start = variable.source_mapping.start full_txt_end = full_txt_start + variable.source_mapping.length - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] @@ -442,6 +474,8 @@ def _explore_variables_declaration( # pylint: disable=too-many-arguments,too-ma # If the variable is the target if variable == target: old_str = variable.name + if old_str is None: + old_str = "" new_str = convert(old_str, slither) loc_start = full_txt_start + full_txt.find(old_str.encode("utf8")) @@ -458,10 +492,10 @@ def _explore_variables_declaration( # pylint: disable=too-many-arguments,too-ma idx = len(func.parameters) - func.parameters.index(variable) + 1 first_line = end_line - idx - 2 - potential_comments = slither.source_code[filename_source_code].encode( + potential_comments_ = slither.core.source_code[filename_source_code].encode( "utf8" ) - potential_comments = potential_comments.splitlines(keepends=True)[ + potential_comments = potential_comments_.splitlines(keepends=True)[ first_line : end_line - 1 ] @@ -491,10 +525,16 @@ def _explore_variables_declaration( # pylint: disable=too-many-arguments,too-ma idx_beginning += len(line) -def _explore_structures_declaration(slither, structures, result, target, convert): +def _explore_structures_declaration( + slither: SlitherCompilationUnit, + structures: Sequence[Structure], + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: for st in structures: # Explore the variable declared within the structure (VariableStructure) - _explore_variables_declaration(slither, st.elems.values(), result, target, convert) + _explore_variables_declaration(slither, list(st.elems.values()), result, target, convert) # If the structure is the target if st == target: @@ -504,7 +544,7 @@ def _explore_structures_declaration(slither, structures, result, target, convert filename_source_code = st.source_mapping.filename.absolute full_txt_start = st.source_mapping.start full_txt_end = full_txt_start + st.source_mapping.length - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] @@ -517,7 +557,13 @@ def _explore_structures_declaration(slither, structures, result, target, convert create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) -def _explore_events_declaration(slither, events, result, target, convert): +def _explore_events_declaration( + slither: SlitherCompilationUnit, + events: Sequence[Event], + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: for event in events: # Explore the parameters _explore_variables_declaration(slither, event.elems, result, target, convert) @@ -535,7 +581,7 @@ def _explore_events_declaration(slither, events, result, target, convert): create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) -def get_ir_variables(ir): +def get_ir_variables(ir: Operation) -> List[Union[Variable, Function]]: all_vars = ir.read if isinstance(ir, (InternalCall, InternalDynamicCall, HighLevelCall)): @@ -553,9 +599,15 @@ def get_ir_variables(ir): return [v for v in all_vars if v] -def _explore_irs(slither, irs: List[Operation], result, target, convert): +def _explore_irs( + slither: SlitherCompilationUnit, + irs: List[Operation], + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: # pylint: disable=too-many-locals - if irs is None: + if not irs: return for ir in irs: for v in get_ir_variables(ir): @@ -568,7 +620,7 @@ def _explore_irs(slither, irs: List[Operation], result, target, convert): filename_source_code = source_mapping.filename.absolute full_txt_start = source_mapping.start full_txt_end = full_txt_start + source_mapping.length - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] @@ -600,7 +652,13 @@ def _explore_irs(slither, irs: List[Operation], result, target, convert): ) -def _explore_functions(slither, functions, result, target, convert): +def _explore_functions( + slither: SlitherCompilationUnit, + functions: List[Function], + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: for function in functions: _explore_variables_declaration(slither, function.variables, result, target, convert, True) _explore_irs(slither, function.all_slithir_operations(), result, target, convert) @@ -612,7 +670,7 @@ def _explore_functions(slither, functions, result, target, convert): filename_source_code = function.source_mapping.filename.absolute full_txt_start = function.source_mapping.start full_txt_end = full_txt_start + function.source_mapping.length - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] @@ -628,7 +686,13 @@ def _explore_functions(slither, functions, result, target, convert): create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) -def _explore_enums(slither, enums, result, target, convert): +def _explore_enums( + slither: SlitherCompilationUnit, + enums: Sequence[Enum], + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: for enum in enums: if enum == target: old_str = enum.name @@ -637,7 +701,7 @@ def _explore_enums(slither, enums, result, target, convert): filename_source_code = enum.source_mapping.filename.absolute full_txt_start = enum.source_mapping.start full_txt_end = full_txt_start + enum.source_mapping.length - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] @@ -650,7 +714,13 @@ def _explore_enums(slither, enums, result, target, convert): create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) -def _explore_contract(slither, contract, result, target, convert): +def _explore_contract( + slither: SlitherCompilationUnit, + contract: Contract, + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: _explore_variables_declaration(slither, contract.state_variables, result, target, convert) _explore_structures_declaration(slither, contract.structures, result, target, convert) _explore_functions(slither, contract.functions_and_modifiers, result, target, convert) @@ -660,7 +730,7 @@ def _explore_contract(slither, contract, result, target, convert): filename_source_code = contract.source_mapping.filename.absolute full_txt_start = contract.source_mapping.start full_txt_end = full_txt_start + contract.source_mapping.length - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] @@ -677,7 +747,12 @@ def _explore_contract(slither, contract, result, target, convert): create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) -def _explore(compilation_unit: SlitherCompilationUnit, result, target, convert): +def _explore( + compilation_unit: SlitherCompilationUnit, + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: for contract in compilation_unit.contracts_derived: _explore_contract(compilation_unit, contract, result, target, convert) diff --git a/slither/formatters/variables/unused_state_variables.py b/slither/formatters/variables/unused_state_variables.py index 8e0852a175..90009c7f11 100644 --- a/slither/formatters/variables/unused_state_variables.py +++ b/slither/formatters/variables/unused_state_variables.py @@ -1,8 +1,10 @@ +from typing import Dict + from slither.core.compilation_unit import SlitherCompilationUnit from slither.formatters.utils.patches import create_patch -def custom_format(compilation_unit: SlitherCompilationUnit, result): +def custom_format(compilation_unit: SlitherCompilationUnit, result: Dict) -> None: elements = result["elements"] for element in elements: if element["type"] == "variable": @@ -14,7 +16,9 @@ def custom_format(compilation_unit: SlitherCompilationUnit, result): ) -def _patch(compilation_unit: SlitherCompilationUnit, result, in_file, modify_loc_start): +def _patch( + compilation_unit: SlitherCompilationUnit, result: Dict, in_file: str, modify_loc_start: int +) -> None: in_file_str = compilation_unit.core.source_code[in_file].encode("utf8") old_str_of_interest = in_file_str[modify_loc_start:] old_str = ( diff --git a/slither/printers/call/call_graph.py b/slither/printers/call/call_graph.py index e10db3f76e..38225e6d7a 100644 --- a/slither/printers/call/call_graph.py +++ b/slither/printers/call/call_graph.py @@ -6,33 +6,37 @@ The output is a dot file named filename.dot """ from collections import defaultdict -from slither.printers.abstract_printer import AbstractPrinter -from slither.core.declarations.solidity_variables import SolidityFunction +from typing import Optional, Union, Dict, Set, Tuple, Sequence + +from slither.core.declarations import Contract, FunctionContract from slither.core.declarations.function import Function +from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.variables.variable import Variable +from slither.printers.abstract_printer import AbstractPrinter +from slither.utils.output import Output -def _contract_subgraph(contract): +def _contract_subgraph(contract: Contract) -> str: return f"cluster_{contract.id}_{contract.name}" # return unique id for contract function to use as node name -def _function_node(contract, function): +def _function_node(contract: Contract, function: Union[Function, Variable]) -> str: return f"{contract.id}_{function.name}" # return unique id for solidity function to use as node name -def _solidity_function_node(solidity_function): +def _solidity_function_node(solidity_function: SolidityFunction) -> str: return f"{solidity_function.name}" # return dot language string to add graph edge -def _edge(from_node, to_node): +def _edge(from_node: str, to_node: str) -> str: return f'"{from_node}" -> "{to_node}"' # return dot language string to add graph node (with optional label) -def _node(node, label=None): +def _node(node: str, label: Optional[str] = None) -> str: return " ".join( ( f'"{node}"', @@ -43,13 +47,13 @@ def _node(node, label=None): # pylint: disable=too-many-arguments def _process_internal_call( - contract, - function, - internal_call, - contract_calls, - solidity_functions, - solidity_calls, -): + contract: Contract, + function: Function, + internal_call: Union[Function, SolidityFunction], + contract_calls: Dict[Contract, Set[str]], + solidity_functions: Set[str], + solidity_calls: Set[str], +) -> None: if isinstance(internal_call, (Function)): contract_calls[contract].add( _edge( @@ -69,11 +73,15 @@ def _process_internal_call( ) -def _render_external_calls(external_calls): +def _render_external_calls(external_calls: Set[str]) -> str: return "\n".join(external_calls) -def _render_internal_calls(contract, contract_functions, contract_calls): +def _render_internal_calls( + contract: Contract, + contract_functions: Dict[Contract, Set[str]], + contract_calls: Dict[Contract, Set[str]], +) -> str: lines = [] lines.append(f"subgraph {_contract_subgraph(contract)} {{") @@ -87,7 +95,7 @@ def _render_internal_calls(contract, contract_functions, contract_calls): return "\n".join(lines) -def _render_solidity_calls(solidity_functions, solidity_calls): +def _render_solidity_calls(solidity_functions: Set[str], solidity_calls: Set[str]) -> str: lines = [] lines.append("subgraph cluster_solidity {") @@ -102,13 +110,13 @@ def _render_solidity_calls(solidity_functions, solidity_calls): def _process_external_call( - contract, - function, - external_call, - contract_functions, - external_calls, - all_contracts, -): + contract: Contract, + function: Function, + external_call: Tuple[Contract, Union[Function, Variable]], + contract_functions: Dict[Contract, Set[str]], + external_calls: Set[str], + all_contracts: Set[Contract], +) -> None: external_contract, external_function = external_call if not external_contract in all_contracts: @@ -133,15 +141,15 @@ def _process_external_call( # pylint: disable=too-many-arguments def _process_function( - contract, - function, - contract_functions, - contract_calls, - solidity_functions, - solidity_calls, - external_calls, - all_contracts, -): + contract: Contract, + function: Function, + contract_functions: Dict[Contract, Set[str]], + contract_calls: Dict[Contract, Set[str]], + solidity_functions: Set[str], + solidity_calls: Set[str], + external_calls: Set[str], + all_contracts: Set[Contract], +) -> None: contract_functions[contract].add( _node(_function_node(contract, function), function.name), ) @@ -166,29 +174,35 @@ def _process_function( ) -def _process_functions(functions): - contract_functions = defaultdict(set) # contract -> contract functions nodes - contract_calls = defaultdict(set) # contract -> contract calls edges +def _process_functions(functions: Sequence[Function]) -> str: + # TODO add support for top level function + + contract_functions: Dict[Contract, Set[str]] = defaultdict( + set + ) # contract -> contract functions nodes + contract_calls: Dict[Contract, Set[str]] = defaultdict(set) # contract -> contract calls edges - solidity_functions = set() # solidity function nodes - solidity_calls = set() # solidity calls edges - external_calls = set() # external calls edges + solidity_functions: Set[str] = set() # solidity function nodes + solidity_calls: Set[str] = set() # solidity calls edges + external_calls: Set[str] = set() # external calls edges all_contracts = set() for function in functions: - all_contracts.add(function.contract_declarer) + if isinstance(function, FunctionContract): + all_contracts.add(function.contract_declarer) for function in functions: - _process_function( - function.contract_declarer, - function, - contract_functions, - contract_calls, - solidity_functions, - solidity_calls, - external_calls, - all_contracts, - ) + if isinstance(function, FunctionContract): + _process_function( + function.contract_declarer, + function, + contract_functions, + contract_calls, + solidity_functions, + solidity_calls, + external_calls, + all_contracts, + ) render_internal_calls = "" for contract in all_contracts: @@ -209,7 +223,7 @@ class PrinterCallGraph(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#call-graph" - def output(self, filename): + def output(self, filename: str) -> Output: """ Output the graph in filename Args: @@ -241,7 +255,9 @@ def output(self, filename): function.canonical_name: function for function in all_functions } content = "\n".join( - ["strict digraph {"] + [_process_functions(all_functions_as_dict.values())] + ["}"] + ["strict digraph {"] + + [_process_functions(list(all_functions_as_dict.values()))] + + ["}"] ) f.write(content) results.append((all_contracts_filename, content)) diff --git a/slither/printers/functions/authorization.py b/slither/printers/functions/authorization.py index ab61d354e6..48b94c297e 100644 --- a/slither/printers/functions/authorization.py +++ b/slither/printers/functions/authorization.py @@ -1,10 +1,12 @@ """ Module printing summary of the contract """ +from typing import List from slither.printers.abstract_printer import AbstractPrinter from slither.core.declarations.function import Function from slither.utils.myprettytable import MyPrettyTable +from slither.utils.output import Output class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): @@ -15,11 +17,15 @@ class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#variables-written-and-authorization" @staticmethod - def get_msg_sender_checks(function): - all_functions = function.all_internal_calls() + [function] + function.modifiers + def get_msg_sender_checks(function: Function) -> List[str]: + all_functions = ( + [f for f in function.all_internal_calls() if isinstance(f, Function)] + + [function] + + [m for m in function.modifiers if isinstance(m, Function)] + ) - all_nodes = [f.nodes for f in all_functions if isinstance(f, Function)] - all_nodes = [item for sublist in all_nodes for item in sublist] + all_nodes_ = [f.nodes for f in all_functions] + all_nodes = [item for sublist in all_nodes_ for item in sublist] all_conditional_nodes = [ n for n in all_nodes if n.contains_if() or n.contains_require_or_assert() @@ -31,7 +37,7 @@ def get_msg_sender_checks(function): ] return all_conditional_nodes_on_msg_sender - def output(self, _filename): + def output(self, _filename: str) -> Output: """ _filename is not used Args: @@ -40,7 +46,7 @@ def output(self, _filename): txt = "" all_tables = [] - for contract in self.contracts: + for contract in self.contracts: # type: ignore if contract.is_top_level: continue txt += f"\nContract {contract.name}\n" @@ -49,7 +55,9 @@ def output(self, _filename): ) for function in contract.functions: - state_variables_written = [v.name for v in function.all_state_variables_written()] + state_variables_written = [ + v.name for v in function.all_state_variables_written() if v.name + ] msg_sender_condition = self.get_msg_sender_checks(function) table.add_row( [ diff --git a/slither/printers/functions/cfg.py b/slither/printers/functions/cfg.py index 03e010ff40..3c75f723f4 100644 --- a/slither/printers/functions/cfg.py +++ b/slither/printers/functions/cfg.py @@ -1,4 +1,5 @@ from slither.printers.abstract_printer import AbstractPrinter +from slither.utils.output import Output class CFG(AbstractPrinter): @@ -8,7 +9,7 @@ class CFG(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#cfg" - def output(self, filename): + def output(self, filename: str) -> Output: """ _filename is not used Args: @@ -17,10 +18,10 @@ def output(self, filename): info = "" all_files = [] - for contract in self.contracts: + for contract in self.contracts: # type: ignore if contract.is_top_level: continue - for function in contract.functions + contract.modifiers: + for function in contract.functions + list(contract.modifiers): if filename: new_filename = f"{filename}-{contract.name}-{function.full_name}.dot" else: diff --git a/slither/printers/functions/dominator.py b/slither/printers/functions/dominator.py index f618fd5dbd..1b32498f95 100644 --- a/slither/printers/functions/dominator.py +++ b/slither/printers/functions/dominator.py @@ -1,4 +1,5 @@ from slither.printers.abstract_printer import AbstractPrinter +from slither.utils.output import Output class Dominator(AbstractPrinter): @@ -8,7 +9,7 @@ class Dominator(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#dominator" - def output(self, filename): + def output(self, filename: str) -> Output: """ _filename is not used Args: diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index b160dd0e6f..acbf5b0158 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -80,7 +80,7 @@ def _is_constant(f: Function) -> bool: # pylint: disable=too-many-branches :return: """ if f.view or f.pure: - if not f.contract.compilation_unit.solc_version.startswith("0.4"): + if not f.compilation_unit.solc_version.startswith("0.4"): return True if f.payable: return False @@ -103,11 +103,11 @@ def _is_constant(f: Function) -> bool: # pylint: disable=too-many-branches if isinstance(ir, HighLevelCall): if isinstance(ir.function, Variable) or ir.function.view or ir.function.pure: # External call to constant functions are ensured to be constant only for solidity >= 0.5 - if f.contract.compilation_unit.solc_version.startswith("0.4"): + if f.compilation_unit.solc_version.startswith("0.4"): return False else: return False - if isinstance(ir, InternalCall): + if isinstance(ir, InternalCall) and ir.function: # Storage write are not properly handled by all_state_variables_written if any(parameter.is_storage for parameter in ir.function.parameters): return False diff --git a/slither/printers/summary/constructor_calls.py b/slither/printers/summary/constructor_calls.py index 665c765469..789811c360 100644 --- a/slither/printers/summary/constructor_calls.py +++ b/slither/printers/summary/constructor_calls.py @@ -5,6 +5,7 @@ from slither.core.source_mapping.source_mapping import Source from slither.printers.abstract_printer import AbstractPrinter from slither.utils import output +from slither.utils.output import Output def _get_source_code(cst: Function) -> str: @@ -17,7 +18,7 @@ class ConstructorPrinter(AbstractPrinter): ARGUMENT = "constructor-calls" HELP = "Print the constructors executed" - def output(self, _filename): + def output(self, _filename: str) -> Output: info = "" for contract in self.slither.contracts_derived: stack_name = [] diff --git a/slither/printers/summary/contract.py b/slither/printers/summary/contract.py index 5af953e202..5fee944169 100644 --- a/slither/printers/summary/contract.py +++ b/slither/printers/summary/contract.py @@ -2,9 +2,13 @@ Module printing summary of the contract """ import collections +from typing import Dict, List + +from slither.core.declarations import FunctionContract from slither.printers.abstract_printer import AbstractPrinter from slither.utils import output from slither.utils.colors import blue, green, magenta +from slither.utils.output import Output class ContractSummary(AbstractPrinter): @@ -13,7 +17,7 @@ class ContractSummary(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#contract-summary" - def output(self, _filename): # pylint: disable=too-many-locals + def output(self, _filename: str) -> Output: # pylint: disable=too-many-locals """ _filename is not used Args: @@ -53,17 +57,16 @@ def output(self, _filename): # pylint: disable=too-many-locals # Order the function with # contract_declarer -> list_functions - public = [ + public_function = [ (f.contract_declarer.name, f) for f in c.functions if (not f.is_shadowed and not f.is_constructor_variables) ] - collect = collections.defaultdict(list) - for a, b in public: + collect: Dict[str, List[FunctionContract]] = collections.defaultdict(list) + for a, b in public_function: collect[a].append(b) - public = list(collect.items()) - for contract, functions in public: + for contract, functions in collect.items(): txt += blue(f" - From {contract}\n") functions = sorted(functions, key=lambda f: f.full_name) @@ -90,7 +93,7 @@ def output(self, _filename): # pylint: disable=too-many-locals self.info(txt) res = self.generate_output(txt) - for contract, additional_fields in all_contracts: - res.add(contract, additional_fields=additional_fields) + for current_contract, current_additional_fields in all_contracts: + res.add(current_contract, additional_fields=current_additional_fields) return res diff --git a/slither/printers/summary/data_depenency.py b/slither/printers/summary/data_depenency.py index 41659a299a..f1c0dc8d59 100644 --- a/slither/printers/summary/data_depenency.py +++ b/slither/printers/summary/data_depenency.py @@ -1,19 +1,22 @@ """ Module printing summary of the contract """ +from typing import List +from slither.core.declarations import Contract from slither.printers.abstract_printer import AbstractPrinter -from slither.analyses.data_dependency.data_dependency import get_dependencies +from slither.analyses.data_dependency.data_dependency import get_dependencies, SUPPORTED_TYPES from slither.slithir.variables import TemporaryVariable, ReferenceVariable from slither.utils.myprettytable import MyPrettyTable +from slither.utils.output import Output -def _get(v, c): +def _get(v: SUPPORTED_TYPES, c: Contract) -> List[str]: return list( { d.name for d in get_dependencies(v, c) - if not isinstance(d, (TemporaryVariable, ReferenceVariable)) + if not isinstance(d, (TemporaryVariable, ReferenceVariable)) and d.name } ) @@ -25,7 +28,7 @@ class DataDependency(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#data-dependencies" - def output(self, _filename): + def output(self, _filename: str) -> Output: """ _filename is not used Args: @@ -42,6 +45,7 @@ def output(self, _filename): txt += f"\nContract {c.name}\n" table = MyPrettyTable(["Variable", "Dependencies"]) for v in c.state_variables: + assert v.name table.add_row([v.name, sorted(_get(v, c))]) txt += str(table) diff --git a/slither/printers/summary/declaration.py b/slither/printers/summary/declaration.py index 529aba5f08..c7c4798d53 100644 --- a/slither/printers/summary/declaration.py +++ b/slither/printers/summary/declaration.py @@ -1,4 +1,5 @@ from slither.printers.abstract_printer import AbstractPrinter +from slither.utils.output import Output from slither.utils.source_mapping import get_definition, get_implementation, get_references @@ -8,7 +9,7 @@ class Declaration(AbstractPrinter): WIKI = "TODO" - def output(self, _filename): + def output(self, _filename: str) -> Output: """ _filename is not used Args: diff --git a/slither/printers/summary/variable_order.py b/slither/printers/summary/variable_order.py index 9dc9e77c2e..3325b7a010 100644 --- a/slither/printers/summary/variable_order.py +++ b/slither/printers/summary/variable_order.py @@ -4,6 +4,7 @@ from slither.printers.abstract_printer import AbstractPrinter from slither.utils.myprettytable import MyPrettyTable +from slither.utils.output import Output class VariableOrder(AbstractPrinter): @@ -13,7 +14,7 @@ class VariableOrder(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#variable-order" - def output(self, _filename): + def output(self, _filename: str) -> Output: """ _filename is not used Args: diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index 9657e23172..3cfddf5e60 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, List, TYPE_CHECKING, Union, Optional +from typing import Any, List, TYPE_CHECKING, Union, Optional, Dict # pylint: disable= too-many-lines,import-outside-toplevel,too-many-branches,too-many-statements,too-many-nested-blocks from slither.core.declarations import ( @@ -13,12 +13,14 @@ SolidityVariableComposed, Structure, ) +from slither.core.declarations.contract import USING_FOR_KEY, USING_FOR_ITEM from slither.core.declarations.custom_error import CustomError from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_top_level import FunctionTopLevel from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.declarations.solidity_variables import SolidityCustomRevert from slither.core.expressions import Identifier, Literal +from slither.core.expressions.expression import Expression from slither.core.solidity_types import ( ArrayType, ElementaryType, @@ -83,28 +85,6 @@ from slither.utils.function import get_function_id from slither.utils.type import export_nested_types_from_variable from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR -import slither.core.declarations.contract -import slither.core.declarations.function -import slither.core.solidity_types.elementary_type -import slither.core.solidity_types.function_type -import slither.core.solidity_types.user_defined_type -import slither.slithir.operations.assignment -import slither.slithir.operations.binary -import slither.slithir.operations.call -import slither.slithir.operations.high_level_call -import slither.slithir.operations.index -import slither.slithir.operations.init_array -import slither.slithir.operations.internal_call -import slither.slithir.operations.length -import slither.slithir.operations.library_call -import slither.slithir.operations.low_level_call -import slither.slithir.operations.member -import slither.slithir.operations.operation -import slither.slithir.operations.send -import slither.slithir.operations.solidity_call -import slither.slithir.operations.transfer -import slither.slithir.variables.temporary -from slither.core.expressions.expression import Expression if TYPE_CHECKING: from slither.core.cfg.node import Node @@ -112,7 +92,7 @@ logger = logging.getLogger("ConvertToIR") -def convert_expression(expression: Expression, node: "Node") -> List[Any]: +def convert_expression(expression: Expression, node: "Node") -> List[Operation]: # handle standlone expression # such as return true; from slither.core.cfg.node import NodeType @@ -122,8 +102,7 @@ def convert_expression(expression: Expression, node: "Node") -> List[Any]: cond = Condition(cst) cond.set_expression(expression) cond.set_node(node) - result = [cond] - return result + return [cond] if isinstance(expression, Identifier) and node.type in [ NodeType.IF, NodeType.IFLOOP, @@ -131,8 +110,7 @@ def convert_expression(expression: Expression, node: "Node") -> List[Any]: cond = Condition(expression.value) cond.set_expression(expression) cond.set_node(node) - result = [cond] - return result + return [cond] visitor = ExpressionToSlithIR(expression, node) result = visitor.result() @@ -141,15 +119,17 @@ def convert_expression(expression: Expression, node: "Node") -> List[Any]: if result: if node.type in [NodeType.IF, NodeType.IFLOOP]: - assert isinstance(result[-1], (OperationWithLValue)) - cond = Condition(result[-1].lvalue) + prev = result[-1] + assert isinstance(prev, (OperationWithLValue)) and prev.lvalue + cond = Condition(prev.lvalue) cond.set_expression(expression) cond.set_node(node) result.append(cond) elif node.type == NodeType.RETURN: # May return None - if isinstance(result[-1], (OperationWithLValue)): - r = Return(result[-1].lvalue) + prev = result[-1] + if isinstance(prev, (OperationWithLValue)): + r = Return(prev.lvalue) r.set_expression(expression) r.set_node(node) result.append(r) @@ -273,7 +253,7 @@ def _find_function_from_parameter( type_args += ["string"] not_found = True - candidates_kept = [] + candidates_kept: List[Function] = [] for type_arg in type_args: if not not_found: break @@ -336,7 +316,7 @@ def integrate_value_gas(result: List[Operation]) -> List[Operation]: # Find all the assignments assigments = {} for i in result: - if isinstance(i, OperationWithLValue): + if isinstance(i, OperationWithLValue) and i.lvalue: assigments[i.lvalue.name] = i if isinstance(i, TmpCall): if isinstance(i.called, Variable) and i.called.name in assigments: @@ -350,20 +330,25 @@ def integrate_value_gas(result: List[Operation]) -> List[Operation]: for idx, ins in enumerate(result): # value can be shadowed, so we check that the prev ins # is an Argument - if is_value(ins) and isinstance(result[idx - 1], Argument): + if idx == 0: + continue + prev_ins = result[idx - 1] + if is_value(ins) and isinstance(prev_ins, Argument): was_changed = True - result[idx - 1].set_type(ArgumentType.VALUE) - result[idx - 1].call_id = ins.ori.variable_left.name - calls.append(ins.ori.variable_left) + prev_ins.set_type(ArgumentType.VALUE) + # Types checked by is_value + prev_ins.call_id = ins.ori.variable_left.name # type: ignore + calls.append(ins.ori.variable_left) # type: ignore to_remove.append(ins) - variable_to_replace[ins.lvalue.name] = ins.ori.variable_left - elif is_gas(ins) and isinstance(result[idx - 1], Argument): + variable_to_replace[ins.lvalue.name] = ins.ori.variable_left # type: ignore + elif is_gas(ins) and isinstance(prev_ins, Argument): was_changed = True - result[idx - 1].set_type(ArgumentType.GAS) - result[idx - 1].call_id = ins.ori.variable_left.name - calls.append(ins.ori.variable_left) + prev_ins.set_type(ArgumentType.GAS) + # Types checked by is_gas + prev_ins.call_id = ins.ori.variable_left.name # type: ignore + calls.append(ins.ori.variable_left) # type: ignore to_remove.append(ins) - variable_to_replace[ins.lvalue.name] = ins.ori.variable_left + variable_to_replace[ins.lvalue.name] = ins.ori.variable_left # type: ignore # Remove the call to value/gas instruction result = [i for i in result if not i in to_remove] @@ -446,7 +431,7 @@ def propagate_type_and_convert_call(result: List[Operation], node: "Node") -> Li if isinstance(ins, (HighLevelCall, NewContract, InternalDynamicCall)): if ins.call_id in calls_value: ins.call_value = calls_value[ins.call_id] - if ins.call_id in calls_gas: + if ins.call_id in calls_gas and isinstance(ins, (HighLevelCall, InternalDynamicCall)): ins.call_gas = calls_gas[ins.call_id] if isinstance(ins, (Call, NewContract, NewStructure)): @@ -528,12 +513,12 @@ def _convert_type_contract(ir: Member) -> Assignment: def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-locals # propagate the type node_function = node.function - using_for = ( + using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = ( node_function.contract.using_for_complete if isinstance(node_function, FunctionContract) else {} ) - if isinstance(ir, OperationWithLValue): + if isinstance(ir, OperationWithLValue) and ir.lvalue: # Force assignment in case of missing previous correct type if not ir.lvalue.type: if isinstance(ir, Assignment): @@ -644,11 +629,12 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, (ElementaryType, ArrayType)) ): - length = Length(ir.variable_left, ir.lvalue) - length.set_expression(ir.expression) - length.lvalue.points_to = ir.variable_left - length.set_node(ir.node) - return length + new_length = Length(ir.variable_left, ir.lvalue) + assert ir.expression + new_length.set_expression(ir.expression) + new_length.lvalue.points_to = ir.variable_left + new_length.set_node(ir.node) + return new_length # This only happen for .balance/code/codehash access on a variable for which we dont know at # early parsing time the type # Like @@ -729,7 +715,7 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo return _convert_type_contract(ir) left = ir.variable_left t = None - ir_func = ir.function + ir_func = ir.node.function # Handling of this.function_name usage if ( left == SolidityVariable("this") @@ -792,6 +778,7 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo ir.lvalue.set_type(ir.array_type) elif isinstance(ir, NewContract): contract = node.file_scope.get_contract_from_name(ir.contract_name) + assert contract ir.lvalue.set_type(UserDefinedType(contract)) elif isinstance(ir, NewElementaryType): ir.lvalue.set_type(ir.type) @@ -835,9 +822,7 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo # pylint: disable=too-many-locals -def extract_tmp_call( - ins: TmpCall, contract: Optional[Contract] -) -> slither.slithir.operations.call.Call: +def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]) -> Union[Call, Nop]: assert isinstance(ins, TmpCall) if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType): # If the call is made to a variable member, where the member is this @@ -1297,7 +1282,7 @@ def convert_to_push_set_val( element_to_add = ReferenceVariable(node) element_to_add.set_type(new_type) - ir_assign_element_to_add = Index(element_to_add, arr, length_val, ElementaryType("uint256")) + ir_assign_element_to_add = Index(element_to_add, arr, length_val) ir_assign_element_to_add.set_expression(ir.expression) ir_assign_element_to_add.set_node(ir.node) ret.append(ir_assign_element_to_add) @@ -1326,16 +1311,8 @@ def convert_to_push_set_val( def convert_to_push( - ir: slither.slithir.operations.high_level_call.HighLevelCall, node: "Node" -) -> List[ - Union[ - slither.slithir.operations.length.Length, - slither.slithir.operations.assignment.Assignment, - slither.slithir.operations.binary.Binary, - slither.slithir.operations.index.Index, - slither.slithir.operations.init_array.InitArray, - ] -]: + ir: HighLevelCall, node: "Node" +) -> List[Union[Length, Assignment, Binary, Index, InitArray,]]: """ Convert a call to a series of operations to push a new value onto the array @@ -1355,22 +1332,23 @@ def convert_to_push( return ret -def convert_to_pop(ir, node): +def convert_to_pop(ir: HighLevelCall, node: "Node") -> List[Operation]: """ Convert pop operators Return a list of 6 operations """ - ret = [] + ret: List[Operation] = [] arr = ir.destination length = ReferenceVariable(node) length.set_type(ElementaryType("uint256")) ir_length = Length(arr, length) + assert ir.expression ir_length.set_expression(ir.expression) ir_length.set_node(ir.node) - ir_length.lvalue.points_to = arr + length.points_to = arr ret.append(ir_length) val = TemporaryVariable(node) @@ -1381,7 +1359,9 @@ def convert_to_pop(ir, node): ret.append(ir_sub_1) element_to_delete = ReferenceVariable(node) - ir_assign_element_to_delete = Index(element_to_delete, arr, val, ElementaryType("uint256")) + ir_assign_element_to_delete = Index(element_to_delete, arr, val) + # TODO the following is equivalent to length.points_to = arr + # Should it be removed? ir_length.lvalue.points_to = arr element_to_delete.set_type(ElementaryType("uint256")) ir_assign_element_to_delete.set_expression(ir.expression) @@ -1397,7 +1377,7 @@ def convert_to_pop(ir, node): length_to_assign.set_type(ElementaryType("uint256")) ir_length = Length(arr, length_to_assign) ir_length.set_expression(ir.expression) - ir_length.lvalue.points_to = arr + length_to_assign.points_to = arr ir_length.set_node(ir.node) ret.append(ir_length) diff --git a/slither/slithir/operations/assignment.py b/slither/slithir/operations/assignment.py index 0ed5f70a49..5bedf2c856 100644 --- a/slither/slithir/operations/assignment.py +++ b/slither/slithir/operations/assignment.py @@ -1,20 +1,21 @@ import logging -from typing import List +from typing import List, Union from slither.core.declarations.function import Function +from slither.core.solidity_types import Type from slither.slithir.operations.lvalue import OperationWithLValue -from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue, RVALUE, LVALUE from slither.slithir.variables import TupleVariable, ReferenceVariable -from slither.core.source_mapping.source_mapping import SourceMapping -from slither.core.variables.variable import Variable - logger = logging.getLogger("AssignmentOperationIR") class Assignment(OperationWithLValue): def __init__( - self, left_variable: Variable, right_variable: SourceMapping, variable_return_type + self, + left_variable: LVALUE, + right_variable: Union[RVALUE, Function, TupleVariable], + variable_return_type: Type, ) -> None: assert is_valid_lvalue(left_variable) assert is_valid_rvalue(right_variable) or isinstance( @@ -22,30 +23,32 @@ def __init__( ) super().__init__() self._variables = [left_variable, right_variable] - self._lvalue = left_variable - self._rvalue = right_variable + self._lvalue: LVALUE = left_variable + self._rvalue: Union[RVALUE, Function, TupleVariable] = right_variable self._variable_return_type = variable_return_type @property - def variables(self): + def variables(self) -> List[Union[LVALUE, RVALUE, Function, TupleVariable]]: return list(self._variables) @property - def read(self) -> List[SourceMapping]: + def read(self) -> List[Union[RVALUE, Function, TupleVariable]]: return [self.rvalue] @property - def variable_return_type(self): + def variable_return_type(self) -> Type: return self._variable_return_type @property - def rvalue(self) -> SourceMapping: + def rvalue(self) -> Union[RVALUE, Function, TupleVariable]: return self._rvalue - def __str__(self): - if isinstance(self.lvalue, ReferenceVariable): - points = self.lvalue.points_to + def __str__(self) -> str: + lvalue = self.lvalue + assert lvalue + if lvalue and isinstance(lvalue, ReferenceVariable): + points = lvalue.points_to while isinstance(points, ReferenceVariable): points = points.points_to - return f"{self.lvalue} (->{points}) := {self.rvalue}({self.rvalue.type})" - return f"{self.lvalue}({self.lvalue.type}) := {self.rvalue}({self.rvalue.type})" + return f"{lvalue} (->{points}) := {self.rvalue}({self.rvalue.type})" + return f"{lvalue}({lvalue.type}) := {self.rvalue}({self.rvalue.type})" diff --git a/slither/slithir/operations/binary.py b/slither/slithir/operations/binary.py index ad65e3e758..d1355a9652 100644 --- a/slither/slithir/operations/binary.py +++ b/slither/slithir/operations/binary.py @@ -1,17 +1,14 @@ import logging -from typing import List - from enum import Enum +from typing import List, Union from slither.core.declarations import Function from slither.core.solidity_types import ElementaryType +from slither.core.variables.variable import Variable from slither.slithir.exceptions import SlithIRError from slither.slithir.operations.lvalue import OperationWithLValue -from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue, LVALUE, RVALUE from slither.slithir.variables import ReferenceVariable -from slither.core.source_mapping.source_mapping import SourceMapping -from slither.core.variables.variable import Variable - logger = logging.getLogger("BinaryOperationIR") @@ -51,7 +48,7 @@ def return_bool(operation_type: "BinaryType") -> bool: ] @staticmethod - def get_type(operation_type): # pylint: disable=too-many-branches + def get_type(operation_type: str) -> "BinaryType": # pylint: disable=too-many-branches if operation_type == "**": return BinaryType.POWER if operation_type == "*": @@ -93,7 +90,7 @@ def get_type(operation_type): # pylint: disable=too-many-branches raise SlithIRError(f"get_type: Unknown operation type {operation_type})") - def can_be_checked_for_overflow(self): + def can_be_checked_for_overflow(self) -> bool: return self in [ BinaryType.POWER, BinaryType.MULTIPLICATION, @@ -108,8 +105,8 @@ class Binary(OperationWithLValue): def __init__( self, result: Variable, - left_variable: SourceMapping, - right_variable: Variable, + left_variable: Union[RVALUE, Function], + right_variable: Union[RVALUE, Function], operation_type: BinaryType, ) -> None: assert is_valid_rvalue(left_variable) or isinstance(left_variable, Function) @@ -126,36 +123,38 @@ def __init__( result.set_type(left_variable.type) @property - def read(self) -> List[SourceMapping]: + def read(self) -> List[Union[RVALUE, LVALUE, Function]]: return [self.variable_left, self.variable_right] @property - def get_variable(self): + def get_variable(self) -> List[Union[RVALUE, Function]]: return self._variables @property - def variable_left(self) -> SourceMapping: - return self._variables[0] + def variable_left(self) -> Union[RVALUE, Function]: + return self._variables[0] # type: ignore @property - def variable_right(self) -> Variable: - return self._variables[1] + def variable_right(self) -> Union[RVALUE, Function]: + return self._variables[1] # type: ignore @property def type(self) -> BinaryType: return self._type @property - def type_str(self): + def type_str(self) -> str: if self.node.scope.is_checked and self._type.can_be_checked_for_overflow(): - return "(c)" + self._type.value - return self._type.value - - def __str__(self): - if isinstance(self.lvalue, ReferenceVariable): - points = self.lvalue.points_to + return "(c)" + str(self._type.value) + return str(self._type.value) + + def __str__(self) -> str: + lvalue = self.lvalue + assert lvalue + if isinstance(lvalue, ReferenceVariable): + points = lvalue.points_to while isinstance(points, ReferenceVariable): points = points.points_to - return f"{str(self.lvalue)}(-> {points}) = {self.variable_left} {self.type_str} {self.variable_right}" + return f"{str(lvalue)}(-> {points}) = {self.variable_left} {self.type_str} {self.variable_right}" - return f"{str(self.lvalue)}({self.lvalue.type}) = {self.variable_left} {self.type_str} {self.variable_right}" + return f"{str(lvalue)}({lvalue.type}) = {self.variable_left} {self.type_str} {self.variable_right}" diff --git a/slither/slithir/operations/call.py b/slither/slithir/operations/call.py index 07304fa99f..816c56e1d9 100644 --- a/slither/slithir/operations/call.py +++ b/slither/slithir/operations/call.py @@ -1,22 +1,25 @@ -from typing import Optional, List +from typing import Optional, List, Union +from slither.core.declarations import Function +from slither.core.variables import Variable from slither.slithir.operations.operation import Operation class Call(Operation): def __init__(self) -> None: super().__init__() - self._arguments = [] + self._arguments: List[Variable] = [] @property - def arguments(self): + def arguments(self) -> List[Variable]: return self._arguments @arguments.setter - def arguments(self, v): + def arguments(self, v: List[Variable]) -> None: self._arguments = v - def can_reenter(self, _callstack: Optional[List] = None) -> bool: # pylint: disable=no-self-use + # pylint: disable=no-self-use + def can_reenter(self, _callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass :return: bool diff --git a/slither/slithir/operations/codesize.py b/slither/slithir/operations/codesize.py index 6640f4fd89..13aa430eb9 100644 --- a/slither/slithir/operations/codesize.py +++ b/slither/slithir/operations/codesize.py @@ -29,5 +29,5 @@ def read(self) -> List[Union[LocalIRVariable, LocalVariable]]: def value(self) -> LocalVariable: return self._value - def __str__(self): + def __str__(self) -> str: return f"{self.lvalue} -> CODESIZE {self.value}" diff --git a/slither/slithir/operations/condition.py b/slither/slithir/operations/condition.py index 41fb3d933d..ccec033d9b 100644 --- a/slither/slithir/operations/condition.py +++ b/slither/slithir/operations/condition.py @@ -1,13 +1,7 @@ -from typing import List, Union -from slither.slithir.operations.operation import Operation +from typing import List -from slither.slithir.utils.utils import is_valid_rvalue -from slither.core.variables.local_variable import LocalVariable -from slither.slithir.variables.constant import Constant -from slither.slithir.variables.local_variable import LocalIRVariable -from slither.slithir.variables.temporary import TemporaryVariable -from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA -from slither.core.variables.variable import Variable +from slither.slithir.operations.operation import Operation +from slither.slithir.utils.utils import is_valid_rvalue, RVALUE class Condition(Operation): @@ -18,9 +12,7 @@ class Condition(Operation): def __init__( self, - value: Union[ - LocalVariable, TemporaryVariableSSA, TemporaryVariable, Constant, LocalIRVariable - ], + value: RVALUE, ) -> None: assert is_valid_rvalue(value) super().__init__() @@ -29,14 +21,12 @@ def __init__( @property def read( self, - ) -> List[ - Union[LocalIRVariable, Constant, LocalVariable, TemporaryVariableSSA, TemporaryVariable] - ]: + ) -> List[RVALUE]: return [self.value] @property - def value(self) -> Variable: + def value(self) -> RVALUE: return self._value - def __str__(self): + def __str__(self) -> str: return f"CONDITION {self.value}" diff --git a/slither/slithir/operations/delete.py b/slither/slithir/operations/delete.py index 496d170ad6..d241033c53 100644 --- a/slither/slithir/operations/delete.py +++ b/slither/slithir/operations/delete.py @@ -36,5 +36,5 @@ def variable( ) -> Union[StateIRVariable, StateVariable, ReferenceVariable, ReferenceVariableSSA]: return self._variable - def __str__(self): + def __str__(self) -> str: return f"{self.lvalue} = delete {self.variable} " diff --git a/slither/slithir/operations/high_level_call.py b/slither/slithir/operations/high_level_call.py index a129175193..5d654fc800 100644 --- a/slither/slithir/operations/high_level_call.py +++ b/slither/slithir/operations/high_level_call.py @@ -1,5 +1,6 @@ from typing import List, Optional, Union +from slither.core.declarations import Contract from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.core.variables.variable import Variable @@ -32,7 +33,8 @@ def __init__( assert is_valid_lvalue(result) or result is None self._check_destination(destination) super().__init__() - self._destination = destination + # Contract is only possible for library call, which inherits from highlevelcall + self._destination: Union[Variable, SolidityVariable, Contract] = destination # type: ignore self._function_name = function_name self._nbr_arguments = nbr_arguments self._type_call = type_call @@ -44,8 +46,9 @@ def __init__( self._call_gas = None # Development function, to be removed once the code is stable - # It is ovveride by LbraryCall - def _check_destination(self, destination: SourceMapping) -> None: # pylint: disable=no-self-use + # It is overridden by LibraryCall + # pylint: disable=no-self-use + def _check_destination(self, destination: Union[Variable, SolidityVariable, Contract]) -> None: assert isinstance(destination, (Variable, SolidityVariable)) @property @@ -79,7 +82,14 @@ def read(self) -> List[SourceMapping]: return [x for x in all_read if x] @property - def destination(self) -> SourceMapping: + def destination(self) -> Union[Variable, SolidityVariable, Contract]: + """ + Return a variable or a solidityVariable + Contract is only possible for LibraryCall + + Returns: + + """ return self._destination @property @@ -116,7 +126,7 @@ def is_static_call(self) -> bool: return True return False - def can_reenter(self, callstack: None = None) -> bool: + def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass For Solidity > 0.5, filter access to public variables and constant/pure/view diff --git a/slither/slithir/operations/index.py b/slither/slithir/operations/index.py index ade84fe1d5..f38a25927f 100644 --- a/slither/slithir/operations/index.py +++ b/slither/slithir/operations/index.py @@ -1,21 +1,16 @@ from typing import List, Union + from slither.core.declarations import SolidityVariableComposed -from slither.slithir.operations.lvalue import OperationWithLValue -from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue -from slither.slithir.variables.reference import ReferenceVariable -from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.variable import Variable -from slither.slithir.variables.reference_ssa import ReferenceVariableSSA +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue, RVALUE, LVALUE +from slither.slithir.variables.reference import ReferenceVariable class Index(OperationWithLValue): def __init__( - self, - result: Union[ReferenceVariable, ReferenceVariableSSA], - left_variable: Variable, - right_variable: SourceMapping, - index_type: Union[ElementaryType, str], + self, result: ReferenceVariable, left_variable: Variable, right_variable: RVALUE ) -> None: super().__init__() assert is_valid_lvalue(left_variable) or left_variable == SolidityVariableComposed( @@ -24,28 +19,23 @@ def __init__( assert is_valid_rvalue(right_variable) assert isinstance(result, ReferenceVariable) self._variables = [left_variable, right_variable] - self._type = index_type - self._lvalue = result + self._lvalue: ReferenceVariable = result @property def read(self) -> List[SourceMapping]: return list(self.variables) @property - def variables(self) -> List[SourceMapping]: - return self._variables - - @property - def variable_left(self) -> Variable: - return self._variables[0] + def variables(self) -> List[Union[LVALUE, RVALUE, SolidityVariableComposed]]: + return self._variables # type: ignore @property - def variable_right(self) -> SourceMapping: - return self._variables[1] + def variable_left(self) -> Union[LVALUE, SolidityVariableComposed]: + return self._variables[0] # type: ignore @property - def index_type(self) -> Union[ElementaryType, str]: - return self._type + def variable_right(self) -> RVALUE: + return self._variables[1] # type: ignore - def __str__(self): + def __str__(self) -> str: return f"{self.lvalue}({self.lvalue.type}) -> {self.variable_left}[{self.variable_right}]" diff --git a/slither/slithir/operations/internal_call.py b/slither/slithir/operations/internal_call.py index 395c688466..1983b885fe 100644 --- a/slither/slithir/operations/internal_call.py +++ b/slither/slithir/operations/internal_call.py @@ -24,7 +24,7 @@ def __init__( super().__init__() self._contract_name = "" if isinstance(function, Function): - self._function = function + self._function: Optional[Function] = function self._function_name = function.name if isinstance(function, FunctionContract): self._contract_name = function.contract_declarer.name @@ -45,7 +45,7 @@ def read(self) -> List[Any]: return list(self._unroll(self.arguments)) @property - def function(self): + def function(self) -> Optional[Function]: return self._function @function.setter diff --git a/slither/slithir/operations/internal_dynamic_call.py b/slither/slithir/operations/internal_dynamic_call.py index a1ad1aa15f..ca245167e1 100644 --- a/slither/slithir/operations/internal_dynamic_call.py +++ b/slither/slithir/operations/internal_dynamic_call.py @@ -24,7 +24,7 @@ def __init__( assert isinstance(function, Variable) assert is_valid_lvalue(lvalue) or lvalue is None super().__init__() - self._function = function + self._function: Variable = function self._function_type = function_type self._lvalue = lvalue @@ -37,7 +37,7 @@ def read(self) -> List[Union[Constant, LocalIRVariable, LocalVariable]]: return self._unroll(self.arguments) + [self.function] @property - def function(self) -> Union[LocalVariable, LocalIRVariable]: + def function(self) -> Variable: return self._function @property diff --git a/slither/slithir/operations/library_call.py b/slither/slithir/operations/library_call.py index ebe9bf5efd..1b7f4e8a6e 100644 --- a/slither/slithir/operations/library_call.py +++ b/slither/slithir/operations/library_call.py @@ -1,4 +1,7 @@ -from slither.core.declarations import Function +from typing import Union, Optional, List + +from slither.core.declarations import Function, SolidityVariable +from slither.core.variables import Variable from slither.slithir.operations.high_level_call import HighLevelCall from slither.core.declarations.contract import Contract @@ -9,10 +12,10 @@ class LibraryCall(HighLevelCall): """ # Development function, to be removed once the code is stable - def _check_destination(self, destination: Contract) -> None: + def _check_destination(self, destination: Union[Variable, SolidityVariable, Contract]) -> None: assert isinstance(destination, Contract) - def can_reenter(self, callstack: None = None) -> bool: + def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass :return: bool @@ -20,11 +23,11 @@ def can_reenter(self, callstack: None = None) -> bool: if self.is_static_call(): return False # In case of recursion, return False - callstack = [] if callstack is None else callstack - if self.function in callstack: + callstack_local = [] if callstack is None else callstack + if self.function in callstack_local: return False - callstack = callstack + [self.function] - return self.function.can_reenter(callstack) + callstack_local = callstack_local + [self.function] + return self.function.can_reenter(callstack_local) def __str__(self): gas = "" diff --git a/slither/slithir/operations/low_level_call.py b/slither/slithir/operations/low_level_call.py index 7e8c278b8a..eac779d271 100644 --- a/slither/slithir/operations/low_level_call.py +++ b/slither/slithir/operations/low_level_call.py @@ -1,4 +1,6 @@ -from typing import List, Union +from typing import List, Union, Optional + +from slither.core.declarations import Function from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.core.variables.variable import Variable @@ -74,7 +76,7 @@ def read( # remove None return self._unroll([x for x in all_read if x]) - def can_reenter(self, _callstack: None = None) -> bool: + def can_reenter(self, _callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass :return: bool diff --git a/slither/slithir/operations/lvalue.py b/slither/slithir/operations/lvalue.py index d9b800c92b..b983d1c5d4 100644 --- a/slither/slithir/operations/lvalue.py +++ b/slither/slithir/operations/lvalue.py @@ -1,4 +1,6 @@ -from typing import Any, List +from typing import Any, List, Optional + +from slither.core.variables import Variable from slither.slithir.operations.operation import Operation @@ -10,16 +12,16 @@ class OperationWithLValue(Operation): def __init__(self) -> None: super().__init__() - self._lvalue = None + self._lvalue: Optional[Variable] = None @property - def lvalue(self): + def lvalue(self) -> Optional[Variable]: return self._lvalue - @property - def used(self) -> List[Any]: - return self.read + [self.lvalue] - @lvalue.setter - def lvalue(self, lvalue): + def lvalue(self, lvalue: Variable) -> None: self._lvalue = lvalue + + @property + def used(self) -> List[Optional[Any]]: + return self.read + [self.lvalue] diff --git a/slither/slithir/operations/member.py b/slither/slithir/operations/member.py index 9a561ea87a..0942813cfc 100644 --- a/slither/slithir/operations/member.py +++ b/slither/slithir/operations/member.py @@ -5,7 +5,7 @@ from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.solidity_types import ElementaryType from slither.slithir.operations.lvalue import OperationWithLValue -from slither.slithir.utils.utils import is_valid_rvalue +from slither.slithir.utils.utils import is_valid_rvalue, RVALUE from slither.slithir.variables.constant import Constant from slither.slithir.variables.reference import ReferenceVariable from slither.core.source_mapping.source_mapping import SourceMapping @@ -39,7 +39,9 @@ def __init__( assert isinstance(variable_right, Constant) assert isinstance(result, ReferenceVariable) super().__init__() - self._variable_left = variable_left + self._variable_left: Union[ + RVALUE, Contract, Enum, Function, CustomError, SolidityImportPlaceHolder, ElementaryType + ] = variable_left self._variable_right = variable_right self._lvalue = result self._gas = None @@ -50,7 +52,11 @@ def read(self) -> List[SourceMapping]: return [self.variable_left, self.variable_right] @property - def variable_left(self) -> SourceMapping: + def variable_left( + self, + ) -> Union[ + RVALUE, Contract, Enum, Function, CustomError, SolidityImportPlaceHolder, ElementaryType + ]: return self._variable_left @property diff --git a/slither/slithir/operations/new_contract.py b/slither/slithir/operations/new_contract.py index 08ddbd9604..10fa91efd4 100644 --- a/slither/slithir/operations/new_contract.py +++ b/slither/slithir/operations/new_contract.py @@ -1,11 +1,13 @@ from typing import Optional, Any, List, Union + +from slither.core.declarations import Function +from slither.core.declarations.contract import Contract +from slither.core.variables import Variable from slither.slithir.operations import Call, OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.variables.constant import Constant -from slither.core.declarations.contract import Contract from slither.slithir.variables.temporary import TemporaryVariable from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA -from slither.core.declarations.function_contract import FunctionContract class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes @@ -60,6 +62,7 @@ def read(self) -> List[Any]: def contract_created(self) -> Contract: contract_name = self.contract_name contract_instance = self.node.file_scope.get_contract_from_name(contract_name) + assert contract_instance return contract_instance ################################################################################### @@ -68,7 +71,7 @@ def contract_created(self) -> Contract: ################################################################################### ################################################################################### - def can_reenter(self, callstack: Optional[List[FunctionContract]] = None) -> bool: + def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass For Solidity > 0.5, filter access to public variables and constant/pure/view @@ -94,7 +97,7 @@ def can_send_eth(self) -> bool: # endregion - def __str__(self): + def __str__(self) -> str: options = "" if self.call_value: options = f"value:{self.call_value} " diff --git a/slither/slithir/operations/new_structure.py b/slither/slithir/operations/new_structure.py index 752de6a3d8..f24b3bccd9 100644 --- a/slither/slithir/operations/new_structure.py +++ b/slither/slithir/operations/new_structure.py @@ -14,7 +14,9 @@ class NewStructure(Call, OperationWithLValue): def __init__( - self, structure: StructureContract, lvalue: Union[TemporaryVariableSSA, TemporaryVariable] + self, + structure: StructureContract, + lvalue: Union[TemporaryVariableSSA, TemporaryVariable], ) -> None: super().__init__() assert isinstance(structure, Structure) diff --git a/slither/slithir/operations/operation.py b/slither/slithir/operations/operation.py index fcf5f48686..aca3e645bf 100644 --- a/slither/slithir/operations/operation.py +++ b/slither/slithir/operations/operation.py @@ -1,11 +1,14 @@ import abc -from typing import Any, List +from typing import Any, List, Optional, TYPE_CHECKING from slither.core.context.context import Context -from slither.core.children.child_expression import ChildExpression -from slither.core.children.child_node import ChildNode +from slither.core.expressions.expression import Expression from slither.core.variables.variable import Variable from slither.utils.utils import unroll +if TYPE_CHECKING: + from slither.core.compilation_unit import SlitherCompilationUnit + from slither.core.cfg.node import Node + class AbstractOperation(abc.ABC): @property @@ -25,7 +28,24 @@ def used(self): pass # pylint: disable=unnecessary-pass -class Operation(Context, ChildExpression, ChildNode, AbstractOperation): +class Operation(Context, AbstractOperation): + def __init__(self) -> None: + super().__init__() + self._node: Optional["Node"] = None + self._expression: Optional[Expression] = None + + def set_node(self, node: "Node") -> None: + self._node = node + + @property + def node(self) -> "Node": + assert self._node + return self._node + + @property + def compilation_unit(self) -> "SlitherCompilationUnit": + return self.node.compilation_unit + @property def used(self) -> List[Variable]: """ @@ -37,3 +57,10 @@ def used(self) -> List[Variable]: @staticmethod def _unroll(l: List[Any]) -> List[Any]: return unroll(l) + + def set_expression(self, expression: Expression) -> None: + self._expression = expression + + @property + def expression(self) -> Optional[Expression]: + return self._expression diff --git a/slither/slithir/operations/return_operation.py b/slither/slithir/operations/return_operation.py index c21579763d..290572ebf8 100644 --- a/slither/slithir/operations/return_operation.py +++ b/slither/slithir/operations/return_operation.py @@ -1,11 +1,10 @@ -from typing import List +from typing import List, Optional, Union, Any from slither.core.declarations import Function +from slither.core.variables.variable import Variable from slither.slithir.operations.operation import Operation - +from slither.slithir.utils.utils import is_valid_rvalue, RVALUE from slither.slithir.variables.tuple import TupleVariable -from slither.slithir.utils.utils import is_valid_rvalue -from slither.core.variables.variable import Variable class Return(Operation): @@ -14,10 +13,13 @@ class Return(Operation): Only present as last operation in RETURN node """ - def __init__(self, values) -> None: + def __init__( + self, values: Optional[Union[RVALUE, TupleVariable, Function, List[RVALUE]]] + ) -> None: # Note: Can return None # ex: return call() # where call() dont return + self._values: List[Union[RVALUE, TupleVariable, Function]] if not isinstance(values, list): assert ( is_valid_rvalue(values) @@ -25,20 +27,19 @@ def __init__(self, values) -> None: or values is None ) if values is None: - values = [] + self._values = [] else: - values = [values] + self._values = [values] else: # Remove None # Prior Solidity 0.5 # return (0,) # was valid for returns(uint) - values = [v for v in values if not v is None] - self._valid_value(values) + self._values = [v for v in values if not v is None] + self._valid_value(self._values) super().__init__() - self._values = values - def _valid_value(self, value) -> bool: + def _valid_value(self, value: Any) -> bool: if isinstance(value, list): assert all(self._valid_value(v) for v in value) else: @@ -53,5 +54,5 @@ def read(self) -> List[Variable]: def values(self) -> List[Variable]: return self._unroll(self._values) - def __str__(self): + def __str__(self) -> str: return f"RETURN {','.join([f'{x}' for x in self.values])}" diff --git a/slither/slithir/operations/solidity_call.py b/slither/slithir/operations/solidity_call.py index b059c55a67..ad0e139374 100644 --- a/slither/slithir/operations/solidity_call.py +++ b/slither/slithir/operations/solidity_call.py @@ -1,17 +1,16 @@ from typing import Any, List, Union -from slither.core.declarations.solidity_variables import SolidityCustomRevert, SolidityFunction +from slither.core.declarations.solidity_variables import SolidityFunction +from slither.core.solidity_types.elementary_type import ElementaryType from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue -from slither.core.children.child_node import ChildNode -from slither.core.solidity_types.elementary_type import ElementaryType class SolidityCall(Call, OperationWithLValue): def __init__( self, - function: Union[SolidityCustomRevert, SolidityFunction], + function: SolidityFunction, nbr_arguments: int, - result: ChildNode, + result, type_call: Union[str, List[ElementaryType]], ) -> None: assert isinstance(function, SolidityFunction) @@ -26,7 +25,7 @@ def read(self) -> List[Any]: return self._unroll(self.arguments) @property - def function(self) -> Union[SolidityCustomRevert, SolidityFunction]: + def function(self) -> SolidityFunction: return self._function @property diff --git a/slither/slithir/operations/type_conversion.py b/slither/slithir/operations/type_conversion.py index f351f1fdd5..e9998bc65b 100644 --- a/slither/slithir/operations/type_conversion.py +++ b/slither/slithir/operations/type_conversion.py @@ -1,13 +1,12 @@ from typing import List, Union + from slither.core.declarations import Contract -from slither.core.solidity_types.type import Type -from slither.slithir.operations.lvalue import OperationWithLValue -from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue -import slither.core.declarations.contract from slither.core.solidity_types.elementary_type import ElementaryType -from slither.core.solidity_types.type_alias import TypeAliasContract, TypeAliasTopLevel +from slither.core.solidity_types.type_alias import TypeAlias from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.source_mapping.source_mapping import SourceMapping +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue from slither.slithir.variables.temporary import TemporaryVariable from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA @@ -17,15 +16,15 @@ def __init__( self, result: Union[TemporaryVariableSSA, TemporaryVariable], variable: SourceMapping, - variable_type: Union[TypeAliasContract, UserDefinedType, ElementaryType, TypeAliasTopLevel], + variable_type: Union[TypeAlias, UserDefinedType, ElementaryType], ) -> None: super().__init__() assert is_valid_rvalue(variable) or isinstance(variable, Contract) assert is_valid_lvalue(result) - assert isinstance(variable_type, Type) + assert isinstance(variable_type, (TypeAlias, UserDefinedType, ElementaryType)) self._variable = variable - self._type = variable_type + self._type: Union[TypeAlias, UserDefinedType, ElementaryType] = variable_type self._lvalue = result @property @@ -35,18 +34,12 @@ def variable(self) -> SourceMapping: @property def type( self, - ) -> Union[ - TypeAliasContract, - TypeAliasTopLevel, - slither.core.declarations.contract.Contract, - UserDefinedType, - ElementaryType, - ]: + ) -> Union[TypeAlias, UserDefinedType, ElementaryType,]: return self._type @property def read(self) -> List[SourceMapping]: return [self.variable] - def __str__(self): + def __str__(self) -> str: return str(self.lvalue) + f" = CONVERT {self.variable} to {self.type}" diff --git a/slither/slithir/tmp_operations/argument.py b/slither/slithir/tmp_operations/argument.py index 25ea5d0191..638c5dcb4e 100644 --- a/slither/slithir/tmp_operations/argument.py +++ b/slither/slithir/tmp_operations/argument.py @@ -1,4 +1,7 @@ from enum import Enum +from typing import Optional, List + +from slither.core.expressions.expression import Expression from slither.slithir.operations.operation import Operation @@ -10,26 +13,26 @@ class ArgumentType(Enum): class Argument(Operation): - def __init__(self, argument) -> None: + def __init__(self, argument: Expression) -> None: super().__init__() self._argument = argument self._type = ArgumentType.CALL - self._callid = None + self._callid: Optional[str] = None @property - def argument(self): + def argument(self) -> Expression: return self._argument @property - def call_id(self): + def call_id(self) -> Optional[str]: return self._callid @call_id.setter - def call_id(self, c): + def call_id(self, c: str) -> None: self._callid = c @property - def read(self): + def read(self) -> List[Expression]: return [self.argument] def set_type(self, t: ArgumentType) -> None: @@ -39,7 +42,7 @@ def set_type(self, t: ArgumentType) -> None: def get_type(self) -> ArgumentType: return self._type - def __str__(self): + def __str__(self) -> str: call_id = "none" if self.call_id: call_id = f"(id ({self.call_id}))" diff --git a/slither/slithir/utils/ssa.py b/slither/slithir/utils/ssa.py index 156914b611..9a180d14f6 100644 --- a/slither/slithir/utils/ssa.py +++ b/slither/slithir/utils/ssa.py @@ -366,7 +366,7 @@ def last_name( def is_used_later( initial_node: Node, - variable: Union[StateIRVariable, LocalVariable], + variable: Union[StateIRVariable, LocalVariable, TemporaryVariableSSA], ) -> bool: # TODO: does not handle the case where its read and written in the declaration node # It can be problematic if this happens in a loop/if structure @@ -751,8 +751,7 @@ def copy_ir(ir: Operation, *instances) -> Operation: lvalue = get_variable(ir, lambda x: x.lvalue, *instances) variable_left = get_variable(ir, lambda x: x.variable_left, *instances) variable_right = get_variable(ir, lambda x: x.variable_right, *instances) - index_type = ir.index_type - return Index(lvalue, variable_left, variable_right, index_type) + return Index(lvalue, variable_left, variable_right) if isinstance(ir, InitArray): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) init_values = get_rec_values(ir, lambda x: x.init_values, *instances) diff --git a/slither/slithir/utils/utils.py b/slither/slithir/utils/utils.py index 0a50f8e50e..49b1a879cc 100644 --- a/slither/slithir/utils/utils.py +++ b/slither/slithir/utils/utils.py @@ -1,3 +1,5 @@ +from typing import Union, Optional + from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable @@ -10,8 +12,26 @@ from slither.slithir.variables.tuple import TupleVariable from slither.core.source_mapping.source_mapping import SourceMapping +RVALUE = Union[ + StateVariable, + LocalVariable, + TopLevelVariable, + TemporaryVariable, + Constant, + SolidityVariable, + ReferenceVariable, +] + +LVALUE = Union[ + StateVariable, + LocalVariable, + TemporaryVariable, + ReferenceVariable, + TupleVariable, +] + -def is_valid_rvalue(v: SourceMapping) -> bool: +def is_valid_rvalue(v: Optional[SourceMapping]) -> bool: return isinstance( v, ( @@ -26,7 +46,7 @@ def is_valid_rvalue(v: SourceMapping) -> bool: ) -def is_valid_lvalue(v) -> bool: +def is_valid_lvalue(v: Optional[SourceMapping]) -> bool: return isinstance( v, ( diff --git a/slither/slithir/variables/constant.py b/slither/slithir/variables/constant.py index ddfc9e0546..5321e52500 100644 --- a/slither/slithir/variables/constant.py +++ b/slither/slithir/variables/constant.py @@ -28,7 +28,7 @@ def __init__( assert isinstance(constant_type, ElementaryType) self._type = constant_type if constant_type.type in Int + Uint + ["address"]: - self._val = convert_string_to_int(val) + self._val: Union[bool, int, str] = convert_string_to_int(val) elif constant_type.type == "bool": self._val = (val == "true") | (val == "True") else: @@ -41,6 +41,8 @@ def __init__( self._type = ElementaryType("string") self._val = val + self._name = str(self._val) + @property def value(self) -> Union[bool, int, str]: """ @@ -63,20 +65,18 @@ def original_value(self) -> str: def __str__(self) -> str: return str(self.value) - @property - def name(self) -> str: - return str(self) - - def __eq__(self, other: Union["Constant", str]) -> bool: + def __eq__(self, other: object) -> bool: return self.value == other - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return self.value != other - def __lt__(self, other): - return self.value < other + def __lt__(self, other: object) -> bool: + if not isinstance(other, (Constant, str)): + raise NotImplementedError + return self.value < other # type: ignore - def __repr__(self): + def __repr__(self) -> str: return f"{str(self.value)}" def __hash__(self) -> int: diff --git a/slither/slithir/variables/local_variable.py b/slither/slithir/variables/local_variable.py index eb32d40247..35b624a013 100644 --- a/slither/slithir/variables/local_variable.py +++ b/slither/slithir/variables/local_variable.py @@ -41,11 +41,11 @@ def __init__(self, local_variable: LocalVariable) -> None: self._non_ssa_version = local_variable @property - def index(self): + def index(self) -> int: return self._index @index.setter - def index(self, idx): + def index(self, idx: int) -> None: self._index = idx @property diff --git a/slither/slithir/variables/reference.py b/slither/slithir/variables/reference.py index 95802b7e21..9ab51be655 100644 --- a/slither/slithir/variables/reference.py +++ b/slither/slithir/variables/reference.py @@ -1,6 +1,5 @@ from typing import Optional, TYPE_CHECKING -from slither.core.children.child_node import ChildNode from slither.core.declarations import Contract, Enum, SolidityVariable, Function from slither.core.variables.variable import Variable @@ -8,7 +7,7 @@ from slither.core.cfg.node import Node -class ReferenceVariable(ChildNode, Variable): +class ReferenceVariable(Variable): def __init__(self, node: "Node", index: Optional[int] = None) -> None: super().__init__() if index is None: @@ -19,6 +18,10 @@ def __init__(self, node: "Node", index: Optional[int] = None) -> None: self._points_to = None self._node = node + @property + def node(self) -> "Node": + return self._node + @property def index(self): return self._index diff --git a/slither/slithir/variables/state_variable.py b/slither/slithir/variables/state_variable.py index 7bb3a4077d..f7fb8ab8a8 100644 --- a/slither/slithir/variables/state_variable.py +++ b/slither/slithir/variables/state_variable.py @@ -30,11 +30,11 @@ def __init__(self, state_variable: StateVariable) -> None: self._non_ssa_version = state_variable @property - def index(self): + def index(self) -> int: return self._index @index.setter - def index(self, idx): + def index(self, idx: int) -> None: self._index = idx @property diff --git a/slither/slithir/variables/temporary.py b/slither/slithir/variables/temporary.py index 8cb1cf3503..5a485f9856 100644 --- a/slither/slithir/variables/temporary.py +++ b/slither/slithir/variables/temporary.py @@ -1,13 +1,12 @@ from typing import Optional, TYPE_CHECKING -from slither.core.children.child_node import ChildNode from slither.core.variables.variable import Variable if TYPE_CHECKING: from slither.core.cfg.node import Node -class TemporaryVariable(ChildNode, Variable): +class TemporaryVariable(Variable): def __init__(self, node: "Node", index: Optional[int] = None) -> None: super().__init__() if index is None: @@ -17,6 +16,10 @@ def __init__(self, node: "Node", index: Optional[int] = None) -> None: self._index = index self._node = node + @property + def node(self) -> "Node": + return self._node + @property def index(self): return self._index diff --git a/slither/slithir/variables/tuple.py b/slither/slithir/variables/tuple.py index dc085347e9..9a13b1d5d0 100644 --- a/slither/slithir/variables/tuple.py +++ b/slither/slithir/variables/tuple.py @@ -1,13 +1,12 @@ from typing import Optional, TYPE_CHECKING -from slither.core.children.child_node import ChildNode from slither.slithir.variables.variable import SlithIRVariable if TYPE_CHECKING: from slither.core.cfg.node import Node -class TupleVariable(ChildNode, SlithIRVariable): +class TupleVariable(SlithIRVariable): def __init__(self, node: "Node", index: Optional[int] = None) -> None: super().__init__() if index is None: @@ -18,6 +17,10 @@ def __init__(self, node: "Node", index: Optional[int] = None) -> None: self._node = node + @property + def node(self) -> "Node": + return self._node + @property def index(self): return self._index diff --git a/slither/slithir/variables/variable.py b/slither/slithir/variables/variable.py index a1a1a6df9b..20d203ea43 100644 --- a/slither/slithir/variables/variable.py +++ b/slither/slithir/variables/variable.py @@ -7,8 +7,9 @@ def __init__(self) -> None: self._index = 0 @property - def ssa_name(self): + def ssa_name(self) -> str: + assert self.name return self.name - def __str__(self): + def __str__(self) -> str: return self.ssa_name diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 8ee755fa11..f3202d00cb 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -1,12 +1,18 @@ import logging import re -from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set - -from slither.core.declarations import Modifier, Event, EnumContract, StructureContract, Function -from slither.core.declarations.contract import Contract +from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set, Sequence + +from slither.core.declarations import ( + Modifier, + Event, + EnumContract, + StructureContract, + Function, +) +from slither.core.declarations.contract import Contract, USING_FOR_KEY from slither.core.declarations.custom_error_contract import CustomErrorContract from slither.core.declarations.function_contract import FunctionContract -from slither.core.solidity_types import ElementaryType, TypeAliasContract, Type +from slither.core.solidity_types import ElementaryType, TypeAliasContract from slither.core.variables.state_variable import StateVariable from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.declarations.custom_error import CustomErrorSolc @@ -297,7 +303,7 @@ def _parse_struct(self, struct: Dict) -> None: st.set_contract(self._contract) st.set_offset(struct["src"], self._contract.compilation_unit) - st_parser = StructureContractSolc(st, struct, self) + st_parser = StructureContractSolc(st, struct, self) # type: ignore self._contract.structures_as_dict[st.name] = st self._structures_parser.append(st_parser) @@ -307,7 +313,7 @@ def parse_structs(self) -> None: for struct in self._structuresNotParsed: self._parse_struct(struct) - self._structuresNotParsed = None + self._structuresNotParsed = [] def _parse_custom_error(self, custom_error: Dict) -> None: ce = CustomErrorContract(self.compilation_unit) @@ -324,7 +330,7 @@ def parse_custom_errors(self) -> None: for custom_error in self._customErrorParsed: self._parse_custom_error(custom_error) - self._customErrorParsed = None + self._customErrorParsed = [] def parse_state_variables(self) -> None: for father in self._contract.inheritance_reverse: @@ -351,6 +357,7 @@ def parse_state_variables(self) -> None: var_parser = StateVariableSolc(var, varNotParsed) self._variables_parser.append(var_parser) + assert var.name self._contract.variables_as_dict[var.name] = var self._contract.add_variables_ordered([var]) @@ -360,7 +367,7 @@ def _parse_modifier(self, modifier_data: Dict) -> None: modif.set_contract(self._contract) modif.set_contract_declarer(self._contract) - modif_parser = ModifierSolc(modif, modifier_data, self, self.slither_parser) + modif_parser = ModifierSolc(modif, modifier_data, self, self.slither_parser) # type: ignore self._contract.compilation_unit.add_modifier(modif) self._modifiers_no_params.append(modif_parser) self._modifiers_parser.append(modif_parser) @@ -370,7 +377,7 @@ def _parse_modifier(self, modifier_data: Dict) -> None: def parse_modifiers(self) -> None: for modifier in self._modifiersNotParsed: self._parse_modifier(modifier) - self._modifiersNotParsed = None + self._modifiersNotParsed = [] def _parse_function(self, function_data: Dict) -> None: func = FunctionContract(self._contract.compilation_unit) @@ -378,7 +385,7 @@ def _parse_function(self, function_data: Dict) -> None: func.set_contract(self._contract) func.set_contract_declarer(self._contract) - func_parser = FunctionSolc(func, function_data, self, self._slither_parser) + func_parser = FunctionSolc(func, function_data, self, self._slither_parser) # type: ignore self._contract.compilation_unit.add_function(func) self._functions_no_params.append(func_parser) self._functions_parser.append(func_parser) @@ -390,7 +397,7 @@ def parse_functions(self) -> None: for function in self._functionsNotParsed: self._parse_function(function) - self._functionsNotParsed = None + self._functionsNotParsed = [] # endregion ################################################################################### @@ -434,7 +441,8 @@ def analyze_params_modifiers(self) -> None: Cls_parser, self._modifiers_parser, ) - self._contract.set_modifiers(modifiers) + # modifiers will be using Modifier so we can ignore the next type check + self._contract.set_modifiers(modifiers) # type: ignore except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing params {e}") self._modifiers_no_params = [] @@ -454,7 +462,8 @@ def analyze_params_functions(self) -> None: Cls_parser, self._functions_parser, ) - self._contract.set_functions(functions) + # function will be using FunctionContract so we can ignore the next type check + self._contract.set_functions(functions) # type: ignore except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing params {e}") self._functions_no_params = [] @@ -465,7 +474,7 @@ def _analyze_params_element( # pylint: disable=too-many-arguments Cls_parser: Callable, element_parser: FunctionSolc, explored_reference_id: Set[str], - parser: List[FunctionSolc], + parser: Union[List[FunctionSolc], List[ModifierSolc]], all_elements: Dict[str, Function], ) -> None: elem = Cls(self._contract.compilation_unit) @@ -503,13 +512,13 @@ def _analyze_params_element( # pylint: disable=too-many-arguments def _analyze_params_elements( # pylint: disable=too-many-arguments,too-many-locals self, - elements_no_params: List[FunctionSolc], + elements_no_params: Sequence[FunctionSolc], getter: Callable[["ContractSolc"], List[FunctionSolc]], getter_available: Callable[[Contract], List[FunctionContract]], Cls: Callable, Cls_parser: Callable, - parser: List[FunctionSolc], - ) -> Dict[str, Union[FunctionContract, Modifier]]: + parser: Union[List[FunctionSolc], List[ModifierSolc]], + ) -> Dict[str, Function]: """ Analyze the parameters of the given elements (Function or Modifier). The function iterates over the inheritance to create an instance or inherited elements (Function or Modifier) @@ -521,13 +530,13 @@ def _analyze_params_elements( # pylint: disable=too-many-arguments,too-many-loc :param Cls: Class to create for collision :return: """ - all_elements = {} + all_elements: Dict[str, Function] = {} - explored_reference_id = set() + explored_reference_id: Set[str] = set() try: for father in self._contract.inheritance: father_parser = self._slither_parser.underlying_contract_to_parser[father] - for element_parser in getter(father_parser): + for element_parser in getter(father_parser): # type: ignore self._analyze_params_element( Cls, Cls_parser, element_parser, explored_reference_id, parser, all_elements ) @@ -592,7 +601,7 @@ def analyze_using_for(self) -> None: # pylint: disable=too-many-branches if self.is_compact_ast: for using_for in self._usingForNotParsed: if "typeName" in using_for and using_for["typeName"]: - type_name = parse_type(using_for["typeName"], self) + type_name: USING_FOR_KEY = parse_type(using_for["typeName"], self) else: type_name = "*" if type_name not in self._contract.using_for: @@ -611,7 +620,7 @@ def analyze_using_for(self) -> None: # pylint: disable=too-many-branches assert children and len(children) <= 2 if len(children) == 2: new = parse_type(children[0], self) - old = parse_type(children[1], self) + old: USING_FOR_KEY = parse_type(children[1], self) else: new = parse_type(children[0], self) old = "*" @@ -622,7 +631,7 @@ def analyze_using_for(self) -> None: # pylint: disable=too-many-branches except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing using for {e}") - def _analyze_function_list(self, function_list: List, type_name: Type) -> None: + def _analyze_function_list(self, function_list: List, type_name: USING_FOR_KEY) -> None: for f in function_list: full_name_split = f["function"]["name"].split(".") if len(full_name_split) == 1: @@ -641,7 +650,9 @@ def _analyze_function_list(self, function_list: List, type_name: Type) -> None: function_name = full_name_split[2] self._analyze_library_function(library_name, function_name, type_name) - def _check_aliased_import(self, first_part: str, function_name: str, type_name: Type) -> None: + def _check_aliased_import( + self, first_part: str, function_name: str, type_name: USING_FOR_KEY + ) -> None: # We check if the first part appear as alias for an import # if it is then function_name must be a top level function # otherwise it's a library function @@ -651,13 +662,13 @@ def _check_aliased_import(self, first_part: str, function_name: str, type_name: return self._analyze_library_function(first_part, function_name, type_name) - def _analyze_top_level_function(self, function_name: str, type_name: Type) -> None: + def _analyze_top_level_function(self, function_name: str, type_name: USING_FOR_KEY) -> None: for tl_function in self.compilation_unit.functions_top_level: if tl_function.name == function_name: self._contract.using_for[type_name].append(tl_function) def _analyze_library_function( - self, library_name: str, function_name: str, type_name: Type + self, library_name: str, function_name: str, type_name: USING_FOR_KEY ) -> None: # Get the library function found = False @@ -684,22 +695,13 @@ def analyze_enums(self) -> None: # for enum, we can parse and analyze it # at the same time self._analyze_enum(enum) - self._enumsNotParsed = None + self._enumsNotParsed = [] except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing enum {e}") def _analyze_enum( self, - enum: Dict[ - str, - Union[ - str, - int, - List[Dict[str, Union[int, str]]], - Dict[str, str], - List[Dict[str, Union[Dict[str, str], int, str]]], - ], - ], + enum: Dict, ) -> None: # Enum can be parsed in one pass if self.is_compact_ast: @@ -748,13 +750,13 @@ def analyze_events(self) -> None: event.set_contract(self._contract) event.set_offset(event_to_parse["src"], self._contract.compilation_unit) - event_parser = EventSolc(event, event_to_parse, self) - event_parser.analyze(self) + event_parser = EventSolc(event, event_to_parse, self) # type: ignore + event_parser.analyze(self) # type: ignore self._contract.events_as_dict[event.full_name] = event except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing event {e}") - self._eventsNotParsed = None + self._eventsNotParsed = [] # endregion ################################################################################### @@ -763,7 +765,7 @@ def analyze_events(self) -> None: ################################################################################### ################################################################################### - def delete_content(self): + def delete_content(self) -> None: """ Remove everything not parsed from the contract This is used only if something went wrong with the inheritance parsing @@ -828,7 +830,7 @@ def _handle_comment(self, attributes: Dict) -> None: ################################################################################### ################################################################################### - def __hash__(self): + def __hash__(self) -> int: return self._contract.id # endregion diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 9671d9bbe2..ba2f225f06 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -242,7 +242,7 @@ def _analyze_attributes(self) -> None: if "payable" in attributes: self._function.payable = attributes["payable"] - def analyze_params(self): + def analyze_params(self) -> None: # Can be re-analyzed due to inheritance if self._params_was_analyzed: return @@ -272,7 +272,7 @@ def analyze_params(self): if returns: self._parse_returns(returns) - def analyze_content(self): + def analyze_content(self) -> None: if self._content_was_analyzed: return @@ -308,8 +308,8 @@ def analyze_content(self): for node_parser in self._node_to_nodesolc.values(): node_parser.analyze_expressions(self) - for node_parser in self._node_to_yulobject.values(): - node_parser.analyze_expressions() + for yul_parser in self._node_to_yulobject.values(): + yul_parser.analyze_expressions() self._rewrite_ternary_as_if_else() @@ -1297,7 +1297,7 @@ def _remove_incorrect_edges(self): son.remove_father(node) node.set_sons(new_sons) - def _remove_alone_endif(self): + def _remove_alone_endif(self) -> None: """ Can occur on: if(..){ diff --git a/slither/solc_parsing/expressions/expression_parsing.py b/slither/solc_parsing/expressions/expression_parsing.py index ea433a9216..d0dc4c7e02 100644 --- a/slither/solc_parsing/expressions/expression_parsing.py +++ b/slither/solc_parsing/expressions/expression_parsing.py @@ -481,11 +481,14 @@ def parse_expression(expression: Dict, caller_context: CallerContextExpression) if name == "IndexAccess": if is_compact_ast: - index_type = expression["typeDescriptions"]["typeString"] + # We dont use the index type here, as we recover it later + # We could change the paradigm with the current AST parsing + # And do the type parsing in advanced for most of the operation + # index_type = expression["typeDescriptions"]["typeString"] left = expression["baseExpression"] right = expression.get("indexExpression", None) else: - index_type = expression["attributes"]["type"] + # index_type = expression["attributes"]["type"] children = expression["children"] left = children[0] right = children[1] if len(children) > 1 else None @@ -502,7 +505,7 @@ def parse_expression(expression: Dict, caller_context: CallerContextExpression) left_expression = parse_expression(left, caller_context) right_expression = parse_expression(right, caller_context) - index = IndexAccess(left_expression, right_expression, index_type) + index = IndexAccess(left_expression, right_expression) index.set_offset(src, caller_context.compilation_unit) return index diff --git a/slither/solc_parsing/variables/variable_declaration.py b/slither/solc_parsing/variables/variable_declaration.py index d21d89875d..69b72a521b 100644 --- a/slither/solc_parsing/variables/variable_declaration.py +++ b/slither/solc_parsing/variables/variable_declaration.py @@ -1,6 +1,6 @@ import logging import re -from typing import Dict, Optional +from typing import Dict, Optional, Union from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.expressions.expression_parsing import parse_expression @@ -42,12 +42,12 @@ def __init__(self, variable: Variable, variable_data: Dict) -> None: self._variable = variable self._was_analyzed = False - self._elem_to_parse = None - self._initializedNotParsed = None + self._elem_to_parse: Optional[Union[Dict, UnknownType]] = None + self._initializedNotParsed: Optional[Dict] = None self._is_compact_ast = False - self._reference_id = None + self._reference_id: Optional[int] = None if "nodeType" in variable_data: self._is_compact_ast = True @@ -87,7 +87,7 @@ def __init__(self, variable: Variable, variable_data: Dict) -> None: declaration = variable_data["children"][0] self._init_from_declaration(declaration, init) elif nodeType == "VariableDeclaration": - self._init_from_declaration(variable_data, False) + self._init_from_declaration(variable_data, None) else: raise ParsingError(f"Incorrect variable declaration type {nodeType}") @@ -101,6 +101,7 @@ def reference_id(self) -> int: Return the solc id. It can be compared with the referencedDeclaration attr Returns None if it was not parsed (legacy AST) """ + assert self._reference_id return self._reference_id def _handle_comment(self, attributes: Dict) -> None: @@ -127,7 +128,7 @@ def _analyze_variable_attributes(self, attributes: Dict) -> None: self._variable.visibility = "internal" def _init_from_declaration( - self, var: Dict, init: Optional[bool] + self, var: Dict, init: Optional[Dict] ) -> None: # pylint: disable=too-many-branches if self._is_compact_ast: attributes = var @@ -195,7 +196,7 @@ def _init_from_declaration( self._initializedNotParsed = init elif len(var["children"]) in [0, 1]: self._variable.initialized = False - self._initializedNotParsed = [] + self._initializedNotParsed = None else: assert len(var["children"]) == 2 self._variable.initialized = True @@ -212,5 +213,6 @@ def analyze(self, caller_context: CallerContextExpression) -> None: self._elem_to_parse = None if self._variable.initialized: + assert self._initializedNotParsed self._variable.expression = parse_expression(self._initializedNotParsed, caller_context) self._initializedNotParsed = None diff --git a/slither/tools/doctor/checks/versions.py b/slither/tools/doctor/checks/versions.py index ec7ef1d1f3..00662b3e9f 100644 --- a/slither/tools/doctor/checks/versions.py +++ b/slither/tools/doctor/checks/versions.py @@ -1,6 +1,6 @@ from importlib import metadata import json -from typing import Optional +from typing import Optional, Any import urllib from packaging.version import parse, Version @@ -17,6 +17,7 @@ def get_installed_version(name: str) -> Optional[Version]: def get_github_version(name: str) -> Optional[Version]: try: + # type: ignore with urllib.request.urlopen( f"https://api.github.com/repos/crytic/{name}/releases/latest" ) as response: @@ -27,7 +28,7 @@ def get_github_version(name: str) -> Optional[Version]: return None -def show_versions(**_kwargs) -> None: +def show_versions(**_kwargs: Any) -> None: versions = { "Slither": (get_installed_version("slither-analyzer"), get_github_version("slither")), "crytic-compile": ( diff --git a/slither/tools/mutator/__main__.py b/slither/tools/mutator/__main__.py index 27e396d0b1..84286ce66c 100644 --- a/slither/tools/mutator/__main__.py +++ b/slither/tools/mutator/__main__.py @@ -79,9 +79,10 @@ def main() -> None: print(args.codebase) sl = Slither(args.codebase, **vars(args)) - for M in _get_mutators(): - m = M(sl) - m.mutate() + for compilation_unit in sl.compilation_units: + for M in _get_mutators(): + m = M(compilation_unit) + m.mutate() # endregion diff --git a/slither/tools/mutator/mutators/abstract_mutator.py b/slither/tools/mutator/mutators/abstract_mutator.py index 850c3c399a..169d8725e4 100644 --- a/slither/tools/mutator/mutators/abstract_mutator.py +++ b/slither/tools/mutator/mutators/abstract_mutator.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Optional, Dict -from slither import Slither +from slither.core.compilation_unit import SlitherCompilationUnit from slither.formatters.utils.patches import apply_patch, create_diff logger = logging.getLogger("Slither") @@ -34,8 +34,11 @@ class AbstractMutator(metaclass=abc.ABCMeta): # pylint: disable=too-few-public- FAULTCLASS = FaultClass.Undefined FAULTNATURE = FaultNature.Undefined - def __init__(self, slither: Slither, rate: int = 10, seed: Optional[int] = None): - self.slither = slither + def __init__( + self, compilation_unit: SlitherCompilationUnit, rate: int = 10, seed: Optional[int] = None + ): + self.compilation_unit = compilation_unit + self.slither = compilation_unit.core self.seed = seed self.rate = rate @@ -87,7 +90,7 @@ def mutate(self) -> None: continue for patch in patches: patched_txt, offset = apply_patch(patched_txt, patch, offset) - diff = create_diff(self.slither, original_txt, patched_txt, file) + diff = create_diff(self.compilation_unit, original_txt, patched_txt, file) if not diff: logger.info(f"Impossible to generate patch; empty {patches}") print(diff) diff --git a/slither/tools/mutator/utils/command_line.py b/slither/tools/mutator/utils/command_line.py index 840976ccfc..feb479c5c8 100644 --- a/slither/tools/mutator/utils/command_line.py +++ b/slither/tools/mutator/utils/command_line.py @@ -18,6 +18,6 @@ def output_mutators(mutators_classes: List[Type[AbstractMutator]]) -> None: mutators_list = sorted(mutators_list, key=lambda element: (element[2], element[3], element[0])) idx = 1 for (argument, help_info, fault_class, fault_nature) in mutators_list: - table.add_row([idx, argument, help_info, fault_class, fault_nature]) + table.add_row([str(idx), argument, help_info, fault_class, fault_nature]) idx = idx + 1 print(table) diff --git a/slither/tools/read_storage/utils/utils.py b/slither/tools/read_storage/utils/utils.py index befd3d0e79..4a04a5b6d2 100644 --- a/slither/tools/read_storage/utils/utils.py +++ b/slither/tools/read_storage/utils/utils.py @@ -1,7 +1,8 @@ from typing import Union from eth_typing.evm import ChecksumAddress -from eth_utils import to_checksum_address, to_int, to_text +from eth_utils import to_int, to_text, to_checksum_address +from web3 import Web3 def get_offset_value(hex_bytes: bytes, offset: int, size: int) -> bytes: @@ -48,7 +49,7 @@ def coerce_type( if "address" in solidity_type: if not isinstance(value, (str, bytes)): raise TypeError - return to_checksum_address(value) + return to_checksum_address(value) # type: ignore if not isinstance(value, bytes): raise TypeError @@ -56,7 +57,7 @@ def coerce_type( def get_storage_data( - web3, checksum_address: ChecksumAddress, slot: bytes, block: Union[int, str] + web3: Web3, checksum_address: ChecksumAddress, slot: bytes, block: Union[int, str] ) -> bytes: """ Retrieves the storage data from the blockchain at target address and slot. diff --git a/slither/tools/similarity/cache.py b/slither/tools/similarity/cache.py index 53fc7f5f00..ccd64b84b9 100644 --- a/slither/tools/similarity/cache.py +++ b/slither/tools/similarity/cache.py @@ -1,4 +1,5 @@ import sys +from typing import Dict, Optional try: import numpy as np @@ -8,7 +9,7 @@ sys.exit(-1) -def load_cache(infile, nsamples=None): +def load_cache(infile: str, nsamples: Optional[int] = None) -> Dict: cache = {} with np.load(infile, allow_pickle=True) as data: array = data["arr_0"][0] @@ -20,5 +21,5 @@ def load_cache(infile, nsamples=None): return cache -def save_cache(cache, outfile): +def save_cache(cache: Dict, outfile: str) -> None: np.savez(outfile, [np.array(cache)]) diff --git a/slither/tools/upgradeability/checks/abstract_checks.py b/slither/tools/upgradeability/checks/abstract_checks.py index 016be2647b..a3ab137a32 100644 --- a/slither/tools/upgradeability/checks/abstract_checks.py +++ b/slither/tools/upgradeability/checks/abstract_checks.py @@ -34,6 +34,8 @@ class CheckClassification(ComparableEnum): CheckClassification.HIGH: "High", } +CHECK_INFO = List[Union[str, SupportedOutput]] + class AbstractCheck(metaclass=abc.ABCMeta): ARGUMENT = "" @@ -140,7 +142,7 @@ def check(self) -> List[Dict]: def generate_result( self, - info: Union[str, List[Union[str, SupportedOutput]]], + info: CHECK_INFO, additional_fields: Optional[Dict] = None, ) -> Output: output = Output( diff --git a/slither/tools/upgradeability/checks/constant.py b/slither/tools/upgradeability/checks/constant.py index a5a80bf5ab..bd98146496 100644 --- a/slither/tools/upgradeability/checks/constant.py +++ b/slither/tools/upgradeability/checks/constant.py @@ -1,7 +1,11 @@ +from typing import List + from slither.tools.upgradeability.checks.abstract_checks import ( AbstractCheck, CheckClassification, + CHECK_INFO, ) +from slither.utils.output import Output class WereConstant(AbstractCheck): @@ -47,10 +51,12 @@ class WereConstant(AbstractCheck): REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True - def _check(self): + def _check(self) -> List[Output]: contract_v1 = self.contract contract_v2 = self.contract_v2 + if contract_v2 is None: + raise Exception("were-constant requires a V2 contract") state_variables_v1 = contract_v1.state_variables state_variables_v2 = contract_v2.state_variables @@ -81,7 +87,7 @@ def _check(self): v2_additional_variables -= 1 idx_v2 += 1 continue - info = [state_v1, " was constant, but ", state_v2, "is not.\n"] + info: CHECK_INFO = [state_v1, " was constant, but ", state_v2, "is not.\n"] json = self.generate_result(info) results.append(json) @@ -134,10 +140,13 @@ class BecameConstant(AbstractCheck): REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True - def _check(self): + def _check(self) -> List[Output]: contract_v1 = self.contract contract_v2 = self.contract_v2 + if contract_v2 is None: + raise Exception("became-constant requires a V2 contract") + state_variables_v1 = contract_v1.state_variables state_variables_v2 = contract_v2.state_variables @@ -169,7 +178,7 @@ def _check(self): idx_v2 += 1 continue elif state_v2.is_constant: - info = [state_v1, " was not constant but ", state_v2, " is.\n"] + info: CHECK_INFO = [state_v1, " was not constant but ", state_v2, " is.\n"] json = self.generate_result(info) results.append(json) diff --git a/slither/tools/upgradeability/checks/variable_initialization.py b/slither/tools/upgradeability/checks/variable_initialization.py index e8ae9b26c3..b4535ddfe3 100644 --- a/slither/tools/upgradeability/checks/variable_initialization.py +++ b/slither/tools/upgradeability/checks/variable_initialization.py @@ -1,7 +1,11 @@ +from typing import List + from slither.tools.upgradeability.checks.abstract_checks import ( CheckClassification, AbstractCheck, + CHECK_INFO, ) +from slither.utils.output import Output class VariableWithInit(AbstractCheck): @@ -37,11 +41,11 @@ class VariableWithInit(AbstractCheck): REQUIRE_CONTRACT = True - def _check(self): + def _check(self) -> List[Output]: results = [] for s in self.contract.state_variables_ordered: if s.initialized and not (s.is_constant or s.is_immutable): - info = [s, " is a state variable with an initial value.\n"] + info: CHECK_INFO = [s, " is a state variable with an initial value.\n"] json = self.generate_result(info) results.append(json) return results diff --git a/slither/tools/upgradeability/checks/variables_order.py b/slither/tools/upgradeability/checks/variables_order.py index 030fb0f65f..fc83c44c6a 100644 --- a/slither/tools/upgradeability/checks/variables_order.py +++ b/slither/tools/upgradeability/checks/variables_order.py @@ -1,7 +1,12 @@ +from typing import List + +from slither.core.declarations import Contract from slither.tools.upgradeability.checks.abstract_checks import ( CheckClassification, AbstractCheck, + CHECK_INFO, ) +from slither.utils.output import Output class MissingVariable(AbstractCheck): @@ -45,9 +50,12 @@ class MissingVariable(AbstractCheck): REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True - def _check(self): + def _check(self) -> List[Output]: contract1 = self.contract contract2 = self.contract_v2 + + assert contract2 + order1 = [ variable for variable in contract1.state_variables_ordered @@ -63,7 +71,7 @@ def _check(self): for idx, _ in enumerate(order1): variable1 = order1[idx] if len(order2) <= idx: - info = ["Variable missing in ", contract2, ": ", variable1, "\n"] + info: CHECK_INFO = ["Variable missing in ", contract2, ": ", variable1, "\n"] json = self.generate_result(info) results.append(json) @@ -108,13 +116,14 @@ class DifferentVariableContractProxy(AbstractCheck): REQUIRE_CONTRACT = True REQUIRE_PROXY = True - def _contract1(self): + def _contract1(self) -> Contract: return self.contract - def _contract2(self): + def _contract2(self) -> Contract: + assert self.proxy return self.proxy - def _check(self): + def _check(self) -> List[Output]: contract1 = self._contract1() contract2 = self._contract2() order1 = [ @@ -128,7 +137,7 @@ def _check(self): if not (variable.is_constant or variable.is_immutable) ] - results = [] + results: List[Output] = [] for idx, _ in enumerate(order1): if len(order2) <= idx: # Handle by MissingVariable @@ -137,7 +146,7 @@ def _check(self): variable1 = order1[idx] variable2 = order2[idx] if (variable1.name != variable2.name) or (variable1.type != variable2.type): - info = [ + info: CHECK_INFO = [ "Different variables between ", contract1, " and ", @@ -190,7 +199,8 @@ class DifferentVariableContractNewContract(DifferentVariableContractProxy): REQUIRE_PROXY = False REQUIRE_CONTRACT_V2 = True - def _contract2(self): + def _contract2(self) -> Contract: + assert self.contract_v2 return self.contract_v2 @@ -235,13 +245,14 @@ class ExtraVariablesProxy(AbstractCheck): REQUIRE_CONTRACT = True REQUIRE_PROXY = True - def _contract1(self): + def _contract1(self) -> Contract: return self.contract - def _contract2(self): + def _contract2(self) -> Contract: + assert self.proxy return self.proxy - def _check(self): + def _check(self) -> List[Output]: contract1 = self._contract1() contract2 = self._contract2() order1 = [ @@ -264,7 +275,7 @@ def _check(self): while idx < len(order2): variable2 = order2[idx] - info = ["Extra variables in ", contract2, ": ", variable2, "\n"] + info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"] json = self.generate_result(info) results.append(json) idx = idx + 1 @@ -299,5 +310,6 @@ class ExtraVariablesNewContract(ExtraVariablesProxy): REQUIRE_PROXY = False REQUIRE_CONTRACT_V2 = True - def _contract2(self): + def _contract2(self) -> Contract: + assert self.contract_v2 return self.contract_v2 diff --git a/slither/tools/upgradeability/utils/command_line.py b/slither/tools/upgradeability/utils/command_line.py index 88b61ceed5..c5767a5221 100644 --- a/slither/tools/upgradeability/utils/command_line.py +++ b/slither/tools/upgradeability/utils/command_line.py @@ -63,7 +63,7 @@ def output_detectors(detector_classes: List[Type[AbstractCheck]]) -> None: def output_to_markdown(detector_classes: List[Type[AbstractCheck]], _filter_wiki: str) -> None: - def extract_help(cls: AbstractCheck) -> str: + def extract_help(cls: Type[AbstractCheck]) -> str: if cls.WIKI == "": return cls.HELP return f"[{cls.HELP}]({cls.WIKI})" diff --git a/slither/utils/code_complexity.py b/slither/utils/code_complexity.py index a389663b33..aa78384999 100644 --- a/slither/utils/code_complexity.py +++ b/slither/utils/code_complexity.py @@ -35,7 +35,7 @@ def compute_strongly_connected_components(function: "Function") -> List[List["No components = [] l = [] - def visit(node): + def visit(node: "Node") -> None: if not visited[node]: visited[node] = True for son in node.sons: @@ -45,7 +45,7 @@ def visit(node): for n in function.nodes: visit(n) - def assign(node: "Node", root: List["Node"]): + def assign(node: "Node", root: List["Node"]) -> None: if not assigned[node]: assigned[node] = True root.append(node) diff --git a/slither/utils/colors.py b/slither/utils/colors.py index 5d688489b4..1a2ff1da39 100644 --- a/slither/utils/colors.py +++ b/slither/utils/colors.py @@ -28,7 +28,7 @@ def enable_windows_virtual_terminal_sequences() -> bool: try: # pylint: disable=import-outside-toplevel - from ctypes import windll, byref + from ctypes import windll, byref # type: ignore from ctypes.wintypes import DWORD, HANDLE kernel32 = windll.kernel32 @@ -65,7 +65,7 @@ def enable_windows_virtual_terminal_sequences() -> bool: return True -def set_colorization_enabled(enabled: bool): +def set_colorization_enabled(enabled: bool) -> None: """ Sets the enabled state of output colorization. :param enabled: Boolean indicating whether output should be colorized. diff --git a/slither/utils/myprettytable.py b/slither/utils/myprettytable.py index a1dfd7ac01..af10a6ff25 100644 --- a/slither/utils/myprettytable.py +++ b/slither/utils/myprettytable.py @@ -1,4 +1,4 @@ -from typing import List, Dict +from typing import List, Dict, Union from prettytable import PrettyTable @@ -8,7 +8,7 @@ def __init__(self, field_names: List[str]): self._field_names = field_names self._rows: List = [] - def add_row(self, row: List[str]) -> None: + def add_row(self, row: List[Union[str, List[str]]]) -> None: self._rows.append(row) def to_pretty_table(self) -> PrettyTable: diff --git a/slither/utils/output.py b/slither/utils/output.py index 9dba15e311..84c9ac65a1 100644 --- a/slither/utils/output.py +++ b/slither/utils/output.py @@ -10,8 +10,17 @@ from pkg_resources import require from slither.core.cfg.node import Node -from slither.core.declarations import Contract, Function, Enum, Event, Structure, Pragma +from slither.core.declarations import ( + Contract, + Function, + Enum, + Event, + Structure, + Pragma, + FunctionContract, +) from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.local_variable import LocalVariable from slither.core.variables.variable import Variable from slither.exceptions import SlitherError from slither.utils.colors import yellow @@ -351,21 +360,19 @@ def _create_parent_element( ], ]: # pylint: disable=import-outside-toplevel - from slither.core.children.child_contract import ChildContract - from slither.core.children.child_function import ChildFunction - from slither.core.children.child_inheritance import ChildInheritance + from slither.core.declarations.contract_level import ContractLevel - if isinstance(element, ChildInheritance): + if isinstance(element, FunctionContract): if element.contract_declarer: contract = Output("") contract.add_contract(element.contract_declarer) return contract.data["elements"][0] - elif isinstance(element, ChildContract): + elif isinstance(element, ContractLevel): if element.contract: contract = Output("") contract.add_contract(element.contract) return contract.data["elements"][0] - elif isinstance(element, ChildFunction): + elif isinstance(element, (LocalVariable, Node)): if element.function: function = Output("") function.add_function(element.function) diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index 5f419ef999..12eb6be9d1 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -1,6 +1,7 @@ from fractions import Fraction -from typing import Union, TYPE_CHECKING +from typing import Union +from slither.core import expressions from slither.core.expressions import ( BinaryOperationType, Literal, @@ -11,12 +12,11 @@ TupleExpression, TypeConversion, ) +from slither.core.variables import Variable from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int from slither.visitors.expression.expression import ExpressionVisitor - -if TYPE_CHECKING: - from slither.core.solidity_types.elementary_type import ElementaryType +from slither.core.solidity_types.elementary_type import ElementaryType class NotConstant(Exception): @@ -45,11 +45,19 @@ class ConstantFolding(ExpressionVisitor): def __init__( self, expression: CONSTANT_TYPES_OPERATIONS, custom_type: Union[str, "ElementaryType"] ) -> None: - self._type = custom_type + if isinstance(custom_type, str): + custom_type = ElementaryType(custom_type) + self._type: ElementaryType = custom_type super().__init__(expression) + @property + def expression(self) -> CONSTANT_TYPES_OPERATIONS: + # We make the assumption that the expression is always a CONSTANT_TYPES_OPERATIONS + # Other expression are not supported for constant unfolding + return self._expression # type: ignore + def result(self) -> "Literal": - value = get_val(self._expression) + value = get_val(self.expression) if isinstance(value, Fraction): value = int(value) # emulate 256-bit wrapping @@ -58,30 +66,75 @@ def result(self) -> "Literal": return Literal(value, self._type) def _post_identifier(self, expression: Identifier) -> None: + if not isinstance(expression.value, Variable): + return if not expression.value.is_constant: raise NotConstant expr = expression.value.expression # assumption that we won't have infinite loop - if not isinstance(expr, Literal): + # Everything outside of literal + if isinstance( + expr, (BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion) + ): cf = ConstantFolding(expr, self._type) expr = cf.result() + assert isinstance(expr, Literal) set_val(expression, convert_string_to_int(expr.converted_value)) # pylint: disable=too-many-branches def _post_binary_operation(self, expression: BinaryOperation) -> None: - left = get_val(expression.expression_left) - right = get_val(expression.expression_right) - if expression.type == BinaryOperationType.POWER: - set_val(expression, left**right) - elif expression.type == BinaryOperationType.MULTIPLICATION: + expression_left = expression.expression_left + expression_right = expression.expression_right + if not isinstance( + expression_left, + (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ): + raise NotConstant + if not isinstance( + expression_right, + (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ): + raise NotConstant + + left = get_val(expression_left) + right = get_val(expression_right) + + if ( + expression.type == BinaryOperationType.POWER + and isinstance(left, (int, Fraction)) + and isinstance(right, (int, Fraction)) + ): + set_val(expression, left**right) # type: ignore + elif ( + expression.type == BinaryOperationType.MULTIPLICATION + and isinstance(left, (int, Fraction)) + and isinstance(right, (int, Fraction)) + ): set_val(expression, left * right) - elif expression.type == BinaryOperationType.DIVISION: - set_val(expression, left / right) - elif expression.type == BinaryOperationType.MODULO: + elif ( + expression.type == BinaryOperationType.DIVISION + and isinstance(left, (int, Fraction)) + and isinstance(right, (int, Fraction)) + ): + # TODO: maybe check for right + left to be int to use // ? + set_val(expression, left // right if isinstance(right, int) else left / right) + elif ( + expression.type == BinaryOperationType.MODULO + and isinstance(left, (int, Fraction)) + and isinstance(right, (int, Fraction)) + ): set_val(expression, left % right) - elif expression.type == BinaryOperationType.ADDITION: + elif ( + expression.type == BinaryOperationType.ADDITION + and isinstance(left, (int, Fraction)) + and isinstance(right, (int, Fraction)) + ): set_val(expression, left + right) - elif expression.type == BinaryOperationType.SUBTRACTION: + elif ( + expression.type == BinaryOperationType.SUBTRACTION + and isinstance(left, (int, Fraction)) + and isinstance(right, (int, Fraction)) + ): set_val(expression, left - right) # Convert to int for operations not supported by Fraction elif expression.type == BinaryOperationType.LEFT_SHIFT: @@ -118,7 +171,10 @@ def _post_unary_operation(self, expression: UnaryOperation) -> None: # Case of uint a = -7; uint[-a] arr; if expression.type == UnaryOperationType.MINUS_PRE: expr = expression.expression - if not isinstance(expr, Literal): + # Everything outside of literal + if isinstance( + expr, (BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion) + ): cf = ConstantFolding(expr, self._type) expr = cf.result() assert isinstance(expr, Literal) @@ -135,45 +191,66 @@ def _post_literal(self, expression: Literal) -> None: except ValueError as e: raise NotConstant from e - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: expressions.AssignmentOperation) -> None: raise NotConstant - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: expressions.CallExpression) -> None: raise NotConstant - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: expressions.ConditionalExpression) -> None: raise NotConstant - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression( + self, expression: expressions.ElementaryTypeNameExpression + ) -> None: raise NotConstant - def _post_index_access(self, expression): + def _post_index_access(self, expression: expressions.IndexAccess) -> None: raise NotConstant - def _post_member_access(self, expression): + def _post_member_access(self, expression: expressions.MemberAccess) -> None: raise NotConstant - def _post_new_array(self, expression): + def _post_new_array(self, expression: expressions.NewArray) -> None: raise NotConstant - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: expressions.NewContract) -> None: raise NotConstant - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: expressions.NewElementaryType) -> None: raise NotConstant - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: expressions.TupleExpression) -> None: if expression.expressions: if len(expression.expressions) == 1: - cf = ConstantFolding(expression.expressions[0], self._type) + first_expr = expression.expressions[0] + if not isinstance( + first_expr, + ( + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + ), + ): + raise NotConstant + cf = ConstantFolding(first_expr, self._type) expr = cf.result() assert isinstance(expr, Literal) set_val(expression, convert_string_to_fraction(expr.converted_value)) return raise NotConstant - def _post_type_conversion(self, expression): - cf = ConstantFolding(expression.expression, self._type) + def _post_type_conversion(self, expression: expressions.TypeConversion) -> None: + expr = expression.expression + if not isinstance( + expr, + (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ): + raise NotConstant + cf = ConstantFolding(expr, self._type) expr = cf.result() assert isinstance(expr, Literal) set_val(expression, convert_string_to_fraction(expr.converted_value)) diff --git a/slither/visitors/expression/export_values.py b/slither/visitors/expression/export_values.py index f5ca39a969..0c51e78315 100644 --- a/slither/visitors/expression/export_values.py +++ b/slither/visitors/expression/export_values.py @@ -1,4 +1,15 @@ -from typing import Any, List +from typing import Any, List, Optional + +from slither.core.expressions import ( + AssignmentOperation, + ConditionalExpression, + ElementaryTypeNameExpression, + IndexAccess, + NewArray, + NewContract, + UnaryOperation, + NewElementaryType, +) from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.call_expression import CallExpression from slither.core.expressions.identifier import Identifier @@ -25,12 +36,16 @@ def set_val(expression: Expression, val: List[Any]) -> None: class ExportValues(ExpressionVisitor): - def result(self) -> List[Any]: + def __init__(self, expression: Expression) -> None: + self._result: Optional[List[Expression]] = None + super().__init__(expression) + + def result(self) -> List[Expression]: if self._result is None: self._result = list(set(get(self.expression))) return self._result - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right @@ -49,20 +64,22 @@ def _post_call_expression(self, expression: CallExpression) -> None: val = called + args set_val(expression, val) - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: if_expr = get(expression.if_expression) else_expr = get(expression.else_expression) then_expr = get(expression.then_expression) val = if_expr + else_expr + then_expr set_val(expression, val) - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression( + self, expression: ElementaryTypeNameExpression + ) -> None: set_val(expression, []) def _post_identifier(self, expression: Identifier) -> None: set_val(expression, [expression.value]) - def _post_index_access(self, expression): + def _post_index_access(self, expression: IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right @@ -76,13 +93,13 @@ def _post_member_access(self, expression: MemberAccess) -> None: val = expr set_val(expression, val) - def _post_new_array(self, expression): + def _post_new_array(self, expression: NewArray) -> None: set_val(expression, []) - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: set_val(expression, []) def _post_tuple_expression(self, expression: TupleExpression) -> None: @@ -95,7 +112,7 @@ def _post_type_conversion(self, expression: TypeConversion) -> None: val = expr set_val(expression, val) - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: UnaryOperation) -> None: expr = get(expression.expression) val = expr set_val(expression, val) diff --git a/slither/visitors/expression/expression.py b/slither/visitors/expression/expression.py index 464ea12858..41886a1023 100644 --- a/slither/visitors/expression/expression.py +++ b/slither/visitors/expression/expression.py @@ -1,5 +1,4 @@ import logging -from typing import Optional, Any from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.binary_operation import BinaryOperation @@ -22,16 +21,14 @@ logger = logging.getLogger("ExpressionVisitor") +# pylint: disable=too-few-public-methods class ExpressionVisitor: def __init__(self, expression: Expression) -> None: - # Inherited class must declared their variables prior calling super().__init__ + super().__init__() + # Inherited class must declare their variables prior calling super().__init__ self._expression = expression - self._result: Any = None self._visit_expression(self.expression) - def result(self) -> Optional[bool]: - return self._result - @property def expression(self) -> Expression: return self._expression @@ -146,7 +143,7 @@ def _visit_new_array(self, expression: NewArray) -> None: def _visit_new_contract(self, expression: NewContract) -> None: pass - def _visit_new_elementary_type(self, expression): + def _visit_new_elementary_type(self, expression: NewElementaryType) -> None: pass def _visit_tuple_expression(self, expression: TupleExpression) -> None: @@ -162,7 +159,7 @@ def _visit_unary_operation(self, expression: UnaryOperation) -> None: # pre visit - def _pre_visit(self, expression) -> None: # pylint: disable=too-many-branches + def _pre_visit(self, expression: Expression) -> None: # pylint: disable=too-many-branches if isinstance(expression, AssignmentOperation): self._pre_assignement_operation(expression) @@ -251,7 +248,7 @@ def _pre_new_array(self, expression: NewArray) -> None: def _pre_new_contract(self, expression: NewContract) -> None: pass - def _pre_new_elementary_type(self, expression): + def _pre_new_elementary_type(self, expression: NewElementaryType) -> None: pass def _pre_tuple_expression(self, expression: TupleExpression) -> None: @@ -265,7 +262,7 @@ def _pre_unary_operation(self, expression: UnaryOperation) -> None: # post visit - def _post_visit(self, expression) -> None: # pylint: disable=too-many-branches + def _post_visit(self, expression: Expression) -> None: # pylint: disable=too-many-branches if isinstance(expression, AssignmentOperation): self._post_assignement_operation(expression) @@ -328,7 +325,7 @@ def _post_binary_operation(self, expression: BinaryOperation) -> None: def _post_call_expression(self, expression: CallExpression) -> None: pass - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: pass def _post_elementary_type_name_expression( @@ -354,7 +351,7 @@ def _post_new_array(self, expression: NewArray) -> None: def _post_new_contract(self, expression: NewContract) -> None: pass - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: pass def _post_tuple_expression(self, expression: TupleExpression) -> None: diff --git a/slither/visitors/expression/expression_printer.py b/slither/visitors/expression/expression_printer.py index 317e1ace67..601627c028 100644 --- a/slither/visitors/expression/expression_printer.py +++ b/slither/visitors/expression/expression_printer.py @@ -1,97 +1,107 @@ +from typing import Optional + +from slither.core import expressions +from slither.core.expressions.expression import Expression from slither.visitors.expression.expression import ExpressionVisitor -def get(expression): +def get(expression: Expression) -> str: val = expression.context["ExpressionPrinter"] # we delete the item to reduce memory use del expression.context["ExpressionPrinter"] return val -def set_val(expression, val): +def set_val(expression: Expression, val: str) -> None: expression.context["ExpressionPrinter"] = val class ExpressionPrinter(ExpressionVisitor): - def result(self): + def __init__(self, expression: Expression) -> None: + self._result: Optional[str] = None + super().__init__(expression) + + def result(self) -> str: if not self._result: self._result = get(self.expression) return self._result - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: expressions.AssignmentOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = f"{left} {expression.type} {right}" set_val(expression, val) - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: expressions.BinaryOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = f"{left} {expression.type} {right}" set_val(expression, val) - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: expressions.CallExpression) -> None: called = get(expression.called) arguments = ",".join([get(x) for x in expression.arguments if x]) val = f"{called}({arguments})" set_val(expression, val) - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: expressions.ConditionalExpression) -> None: if_expr = get(expression.if_expression) else_expr = get(expression.else_expression) then_expr = get(expression.then_expression) val = f"if {if_expr} then {else_expr} else {then_expr}" set_val(expression, val) - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression( + self, expression: expressions.ElementaryTypeNameExpression + ) -> None: set_val(expression, str(expression.type)) - def _post_identifier(self, expression): + def _post_identifier(self, expression: expressions.Identifier) -> None: set_val(expression, str(expression.value)) - def _post_index_access(self, expression): + def _post_index_access(self, expression: expressions.IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = f"{left}[{right}]" set_val(expression, val) - def _post_literal(self, expression): + def _post_literal(self, expression: expressions.Literal) -> None: set_val(expression, str(expression.value)) - def _post_member_access(self, expression): + def _post_member_access(self, expression: expressions.MemberAccess) -> None: expr = get(expression.expression) member_name = str(expression.member_name) val = f"{expr}.{member_name}" set_val(expression, val) - def _post_new_array(self, expression): + def _post_new_array(self, expression: expressions.NewArray) -> None: array = str(expression.array_type) depth = expression.depth val = f"new {array}{'[]' * depth}" set_val(expression, val) - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: expressions.NewContract) -> None: contract = str(expression.contract_name) val = f"new {contract}" set_val(expression, val) - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: expressions.NewElementaryType) -> None: t = str(expression.type) val = f"new {t}" set_val(expression, val) - def _post_tuple_expression(self, expression): - expressions = [get(e) for e in expression.expressions if e] - val = f"({','.join(expressions)})" + def _post_tuple_expression(self, expression: expressions.TupleExpression) -> None: + underlying_expressions = [get(e) for e in expression.expressions if e] + val = f"({','.join(underlying_expressions)})" set_val(expression, val) - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: expressions.TypeConversion) -> None: t = str(expression.type) expr = get(expression.expression) val = f"{t}({expr})" set_val(expression, val) - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: expressions.UnaryOperation) -> None: t = str(expression.type) expr = get(expression.expression) if expression.is_prefix: diff --git a/slither/visitors/expression/find_calls.py b/slither/visitors/expression/find_calls.py index 6653a97592..ce00533ed8 100644 --- a/slither/visitors/expression/find_calls.py +++ b/slither/visitors/expression/find_calls.py @@ -1,5 +1,6 @@ -from typing import Any, Union, List +from typing import Any, Union, List, Optional +from slither.core.expressions import NewElementaryType from slither.core.expressions.expression import Expression from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.assignment_operation import AssignmentOperation @@ -32,6 +33,10 @@ def set_val(expression: Expression, val: List[Union[Any, CallExpression]]) -> No class FindCalls(ExpressionVisitor): + def __init__(self, expression: Expression) -> None: + self._result: Optional[List[Expression]] = None + super().__init__(expression) + def result(self) -> List[Expression]: if self._result is None: self._result = list(set(get(self.expression))) @@ -51,8 +56,8 @@ def _post_binary_operation(self, expression: BinaryOperation) -> None: def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression.called) - args = [get(a) for a in expression.arguments if a] - args = [item for sublist in args for item in sublist] + argss = [get(a) for a in expression.arguments if a] + args = [item for sublist in argss for item in sublist] val = called + args val += [expression] set_val(expression, val) @@ -93,7 +98,7 @@ def _post_new_array(self, expression: NewArray) -> None: def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: set_val(expression, []) def _post_tuple_expression(self, expression: TupleExpression) -> None: diff --git a/slither/visitors/expression/find_push.py b/slither/visitors/expression/find_push.py deleted file mode 100644 index cf2b07e601..0000000000 --- a/slither/visitors/expression/find_push.py +++ /dev/null @@ -1,96 +0,0 @@ -from slither.visitors.expression.expression import ExpressionVisitor - -from slither.visitors.expression.right_value import RightValue - -key = "FindPush" - - -def get(expression): - val = expression.context[key] - # we delete the item to reduce memory use - del expression.context[key] - return val - - -def set_val(expression, val): - expression.context[key] = val - - -class FindPush(ExpressionVisitor): - def result(self): - if self._result is None: - self._result = list(set(get(self.expression))) - return self._result - - def _post_assignement_operation(self, expression): - left = get(expression.expression_left) - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_binary_operation(self, expression): - left = get(expression.expression_left) - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_call_expression(self, expression): - called = get(expression.called) - args = [get(a) for a in expression.arguments if a] - args = [item for sublist in args for item in sublist] - val = called + args - set_val(expression, val) - - def _post_conditional_expression(self, expression): - if_expr = get(expression.if_expression) - else_expr = get(expression.else_expression) - then_expr = get(expression.then_expression) - val = if_expr + else_expr + then_expr - set_val(expression, val) - - def _post_elementary_type_name_expression(self, expression): - set_val(expression, []) - - # save only identifier expression - def _post_identifier(self, expression): - set_val(expression, []) - - def _post_index_access(self, expression): - left = get(expression.expression_left) - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_literal(self, expression): - set_val(expression, []) - - def _post_member_access(self, expression): - val = [] - if expression.member_name == "push": - right = RightValue(expression.expression) - val = right.result() - set_val(expression, val) - - def _post_new_array(self, expression): - set_val(expression, []) - - def _post_new_contract(self, expression): - set_val(expression, []) - - def _post_new_elementary_type(self, expression): - set_val(expression, []) - - def _post_tuple_expression(self, expression): - expressions = [get(e) for e in expression.expressions if e] - val = [item for sublist in expressions for item in sublist] - set_val(expression, val) - - def _post_type_conversion(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) - - def _post_unary_operation(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) diff --git a/slither/visitors/expression/has_conditional.py b/slither/visitors/expression/has_conditional.py index b866a696b5..613138533f 100644 --- a/slither/visitors/expression/has_conditional.py +++ b/slither/visitors/expression/has_conditional.py @@ -1,13 +1,15 @@ +from slither.core.expressions.expression import Expression from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.conditional_expression import ConditionalExpression class HasConditional(ExpressionVisitor): + def __init__(self, expression: Expression) -> None: + self._result: bool = False + super().__init__(expression) + def result(self) -> bool: - # == True, to convert None to false - return self._result is True + return self._result def _post_conditional_expression(self, expression: ConditionalExpression) -> None: - # if self._result is True: - # raise('Slither does not support nested ternary operator') self._result = True diff --git a/slither/visitors/expression/left_value.py b/slither/visitors/expression/left_value.py deleted file mode 100644 index 3b16c8c26a..0000000000 --- a/slither/visitors/expression/left_value.py +++ /dev/null @@ -1,109 +0,0 @@ -# Return the 'left' value of an expression - -from slither.visitors.expression.expression import ExpressionVisitor - -from slither.core.expressions.assignment_operation import AssignmentOperationType - -from slither.core.variables.variable import Variable - -key = "LeftValue" - - -def get(expression): - val = expression.context[key] - # we delete the item to reduce memory use - del expression.context[key] - return val - - -def set_val(expression, val): - expression.context[key] = val - - -class LeftValue(ExpressionVisitor): - def result(self): - if self._result is None: - self._result = list(set(get(self.expression))) - return self._result - - # overide index access visitor to explore only left part - def _visit_index_access(self, expression): - self._visit_expression(expression.expression_left) - - def _post_assignement_operation(self, expression): - if expression.type != AssignmentOperationType.ASSIGN: - left = get(expression.expression_left) - else: - left = [] - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_binary_operation(self, expression): - left = get(expression.expression_left) - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_call_expression(self, expression): - called = get(expression.called) - args = [get(a) for a in expression.arguments if a] - args = [item for sublist in args for item in sublist] - val = called + args - set_val(expression, val) - - def _post_conditional_expression(self, expression): - if_expr = get(expression.if_expression) - else_expr = get(expression.else_expression) - then_expr = get(expression.then_expression) - val = if_expr + else_expr + then_expr - set_val(expression, val) - - def _post_elementary_type_name_expression(self, expression): - set_val(expression, []) - - # save only identifier expression - def _post_identifier(self, expression): - if isinstance(expression.value, Variable): - set_val(expression, [expression.value]) - # elif isinstance(expression.value, SolidityInbuilt): - # set_val(expression, [expression]) - else: - set_val(expression, []) - - def _post_index_access(self, expression): - left = get(expression.expression_left) - val = left - set_val(expression, val) - - def _post_literal(self, expression): - set_val(expression, []) - - def _post_member_access(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) - - def _post_new_array(self, expression): - set_val(expression, []) - - def _post_new_contract(self, expression): - set_val(expression, []) - - def _post_new_elementary_type(self, expression): - set_val(expression, []) - - def _post_tuple_expression(self, expression): - expressions = [get(e) for e in expression.expressions if e] - val = [item for sublist in expressions for item in sublist] - set_val(expression, val) - - def _post_type_conversion(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) - - def _post_unary_operation(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) diff --git a/slither/visitors/expression/read_var.py b/slither/visitors/expression/read_var.py index e8f5aae67e..a0efdde618 100644 --- a/slither/visitors/expression/read_var.py +++ b/slither/visitors/expression/read_var.py @@ -1,5 +1,6 @@ -from typing import Any, List, Union +from typing import Any, List, Union, Optional +from slither.core.expressions import NewElementaryType from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.assignment_operation import ( @@ -40,7 +41,11 @@ def set_val(expression: Expression, val: List[Union[Identifier, IndexAccess, Any class ReadVar(ExpressionVisitor): - def result(self) -> List[Union[Identifier, IndexAccess, Any]]: + def __init__(self, expression: Expression) -> None: + self._result: Optional[List[Expression]] = None + super().__init__(expression) + + def result(self) -> List[Expression]: if self._result is None: self._result = list(set(get(self.expression))) return self._result @@ -69,8 +74,8 @@ def _post_binary_operation(self, expression: BinaryOperation) -> None: def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression.called) - args = [get(a) for a in expression.arguments if a] - args = [item for sublist in args for item in sublist] + argss = [get(a) for a in expression.arguments if a] + args = [item for sublist in argss for item in sublist] val = called + args set_val(expression, val) @@ -91,6 +96,7 @@ def _post_identifier(self, expression: Identifier) -> None: if isinstance(expression.value, Variable): set_val(expression, [expression]) elif isinstance(expression.value, SolidityVariable): + # TODO: investigate if this branch can be reached, and if Identifier.value has the correct type set_val(expression, [expression]) else: set_val(expression, []) @@ -115,7 +121,7 @@ def _post_new_array(self, expression: NewArray) -> None: def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: set_val(expression, []) def _post_tuple_expression(self, expression: TupleExpression) -> None: diff --git a/slither/visitors/expression/right_value.py b/slither/visitors/expression/right_value.py deleted file mode 100644 index 5a97846bcb..0000000000 --- a/slither/visitors/expression/right_value.py +++ /dev/null @@ -1,115 +0,0 @@ -# Return the 'right' value of an expression -# On index access, explore the left -# on member access, return the member_name -# a.b.c[d] returns c - -from slither.visitors.expression.expression import ExpressionVisitor - -from slither.core.expressions.assignment_operation import AssignmentOperationType -from slither.core.expressions.expression import Expression - -from slither.core.variables.variable import Variable - -key = "RightValue" - - -def get(expression): - val = expression.context[key] - # we delete the item to reduce memory use - del expression.context[key] - return val - - -def set_val(expression, val): - expression.context[key] = val - - -class RightValue(ExpressionVisitor): - def result(self): - if self._result is None: - self._result = list(set(get(self.expression))) - return self._result - - # overide index access visitor to explore only left part - def _visit_index_access(self, expression): - self._visit_expression(expression.expression_left) - - def _post_assignement_operation(self, expression): - if expression.type != AssignmentOperationType.ASSIGN: - left = get(expression.expression_left) - else: - left = [] - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_binary_operation(self, expression): - left = get(expression.expression_left) - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_call_expression(self, expression): - called = get(expression.called) - args = [get(a) for a in expression.arguments if a] - args = [item for sublist in args for item in sublist] - val = called + args - set_val(expression, val) - - def _post_conditional_expression(self, expression): - if_expr = get(expression.if_expression) - else_expr = get(expression.else_expression) - then_expr = get(expression.then_expression) - val = if_expr + else_expr + then_expr - set_val(expression, val) - - def _post_elementary_type_name_expression(self, expression): - set_val(expression, []) - - # save only identifier expression - def _post_identifier(self, expression): - if isinstance(expression.value, Variable): - set_val(expression, [expression.value]) - # elif isinstance(expression.value, SolidityInbuilt): - # set_val(expression, [expression]) - else: - set_val(expression, []) - - def _post_index_access(self, expression): - left = get(expression.expression_left) - val = left - set_val(expression, val) - - def _post_literal(self, expression): - set_val(expression, []) - - def _post_member_access(self, expression): - val = [] - if isinstance(expression.member_name, Expression): - expr = get(expression.member_name) - val = expr - set_val(expression, val) - - def _post_new_array(self, expression): - set_val(expression, []) - - def _post_new_contract(self, expression): - set_val(expression, []) - - def _post_new_elementary_type(self, expression): - set_val(expression, []) - - def _post_tuple_expression(self, expression): - expressions = [get(e) for e in expression.expressions if e] - val = [item for sublist in expressions for item in sublist] - set_val(expression, val) - - def _post_type_conversion(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) - - def _post_unary_operation(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) diff --git a/slither/visitors/expression/write_var.py b/slither/visitors/expression/write_var.py index 97d3858e7e..1c0b6108f5 100644 --- a/slither/visitors/expression/write_var.py +++ b/slither/visitors/expression/write_var.py @@ -1,4 +1,6 @@ -from typing import Any, List +from typing import Any, List, Optional + +from slither.core.expressions import NewElementaryType from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.binary_operation import BinaryOperation @@ -32,6 +34,10 @@ def set_val(expression: Expression, val: List[Any]) -> None: class WriteVar(ExpressionVisitor): + def __init__(self, expression: Expression) -> None: + self._result: Optional[List[Expression]] = None + super().__init__(expression) + def result(self) -> List[Any]: if self._result is None: self._result = list(set(get(self.expression))) @@ -123,7 +129,7 @@ def _post_new_array(self, expression: NewArray) -> None: def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: set_val(expression, []) def _post_tuple_expression(self, expression: TupleExpression) -> None: diff --git a/slither/visitors/slithir/expression_to_slithir.py b/slither/visitors/slithir/expression_to_slithir.py index 3a975939ab..90905be4ee 100644 --- a/slither/visitors/slithir/expression_to_slithir.py +++ b/slither/visitors/slithir/expression_to_slithir.py @@ -1,14 +1,16 @@ import logging -from typing import Union, List, TYPE_CHECKING +from typing import Union, List, TYPE_CHECKING, Any +from slither.core import expressions from slither.core.declarations import ( Function, SolidityVariable, SolidityVariableComposed, SolidityFunction, Contract, + EnumContract, + EnumTopLevel, ) -from slither.core.declarations.enum import Enum from slither.core.expressions import ( AssignmentOperation, AssignmentOperationType, @@ -18,6 +20,8 @@ CallExpression, Identifier, MemberAccess, + ConditionalExpression, + NewElementaryType, ) from slither.core.expressions.binary_operation import BinaryOperation from slither.core.expressions.expression import Expression @@ -27,7 +31,7 @@ from slither.core.expressions.new_contract import NewContract from slither.core.expressions.tuple_expression import TupleExpression from slither.core.expressions.unary_operation import UnaryOperation -from slither.core.solidity_types import ArrayType, ElementaryType, TypeAlias +from slither.core.solidity_types import ArrayType, ElementaryType, TypeAlias, UserDefinedType from slither.core.solidity_types.type import Type from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple @@ -71,18 +75,14 @@ key = "expressionToSlithIR" -def get(expression: Union[Expression, Operation]): +def get(expression: Expression) -> Any: val = expression.context[key] # we delete the item to reduce memory use del expression.context[key] return val -def get_without_removing(expression): - return expression.context[key] - - -def set_val(expression: Union[Expression, Operation], val) -> None: +def set_val(expression: Expression, val: Any) -> None: expression.context[key] = val @@ -121,7 +121,7 @@ def convert_assignment( left: Union[LocalVariable, StateVariable, ReferenceVariable], right: Union[LocalVariable, StateVariable, ReferenceVariable], t: AssignmentOperationType, - return_type, + return_type: Type, ) -> Union[Binary, Assignment]: if t == AssignmentOperationType.ASSIGN: return Assignment(left, right, return_type) @@ -150,6 +150,7 @@ def convert_assignment( class ExpressionToSlithIR(ExpressionVisitor): + # pylint: disable=super-init-not-called def __init__(self, expression: Expression, node: "Node") -> None: from slither.core.cfg.node import NodeType # pylint: disable=import-outside-toplevel @@ -171,11 +172,16 @@ def result(self) -> List[Operation]: def _post_assignement_operation(self, expression: AssignmentOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) + operation: Operation if isinstance(left, list): # tuple expression: if isinstance(right, list): # unbox assigment assert len(left) == len(right) for idx, _ in enumerate(left): - if not left[idx] is None: + if ( + not left[idx] is None + and expression.type + and expression.expression_return_type + ): operation = convert_assignment( left[idx], right[idx], @@ -282,6 +288,8 @@ def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression_called) args = [get(a) for a in expression.arguments if a] + val: Union[TupleVariable, TemporaryVariable] + var: Operation for arg in args: arg_ = Argument(arg) arg_.set_expression(expression) @@ -290,6 +298,7 @@ def _post_call_expression(self, expression: CallExpression) -> None: # internal call # If tuple + if expression.type_call.startswith("tuple(") and expression.type_call != "tuple()": val = TupleVariable(self._node) else: @@ -308,7 +317,7 @@ def _post_call_expression(self, expression: CallExpression) -> None: ): # wrap: underlying_type -> alias # unwrap: alias -> underlying_type - dest_type = ( + dest_type: Union[TypeAlias, ElementaryType] = ( called if expression_called.member_name == "wrap" else called.underlying_type ) val = TemporaryVariable(self._node) @@ -321,19 +330,19 @@ def _post_call_expression(self, expression: CallExpression) -> None: # yul things elif called.name == "caller()": val = TemporaryVariable(self._node) - var = Assignment(val, SolidityVariableComposed("msg.sender"), "uint256") + var = Assignment(val, SolidityVariableComposed("msg.sender"), ElementaryType("uint256")) self._result.append(var) set_val(expression, val) elif called.name == "origin()": val = TemporaryVariable(self._node) - var = Assignment(val, SolidityVariableComposed("tx.origin"), "uint256") + var = Assignment(val, SolidityVariableComposed("tx.origin"), ElementaryType("uint256")) self._result.append(var) set_val(expression, val) elif called.name == "extcodesize(uint256)": - val = ReferenceVariable(self._node) - var = Member(args[0], Constant("codesize"), val) + val_ref = ReferenceVariable(self._node) + var = Member(args[0], Constant("codesize"), val_ref) self._result.append(var) - set_val(expression, val) + set_val(expression, val_ref) elif called.name == "selfbalance()": val = TemporaryVariable(self._node) var = TypeConversion(val, SolidityVariable("this"), ElementaryType("address")) @@ -352,7 +361,7 @@ def _post_call_expression(self, expression: CallExpression) -> None: set_val(expression, val) elif called.name == "callvalue()": val = TemporaryVariable(self._node) - var = Assignment(val, SolidityVariableComposed("msg.value"), "uint256") + var = Assignment(val, SolidityVariableComposed("msg.value"), ElementaryType("uint256")) self._result.append(var) set_val(expression, val) @@ -379,7 +388,7 @@ def _post_call_expression(self, expression: CallExpression) -> None: self._result.append(message_call) set_val(expression, val) - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: raise Exception(f"Ternary operator are not convertible to SlithIR {expression}") def _post_elementary_type_name_expression( @@ -394,12 +403,13 @@ def _post_identifier(self, expression: Identifier) -> None: def _post_index_access(self, expression: IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) + operation: Operation # Left can be a type for abi.decode(var, uint[2]) if isinstance(left, Type): # Nested type are not yet supported by abi.decode, so the assumption # Is that the right variable must be a constant assert isinstance(right, Constant) - t = ArrayType(left, right.value) + t = ArrayType(left, int(right.value)) set_val(expression, t) return val = ReferenceVariable(self._node) @@ -412,13 +422,15 @@ def _post_index_access(self, expression: IndexAccess) -> None: operation = InitArray(init_array_right, init_array_val) operation.set_expression(expression) self._result.append(operation) - operation = Index(val, left, right, expression.type) + operation = Index(val, left, right) operation.set_expression(expression) self._result.append(operation) set_val(expression, val) def _post_literal(self, expression: Literal) -> None: - cst = Constant(expression.value, expression.type, expression.subdenomination) + expression_type = expression.type + assert isinstance(expression_type, ElementaryType) + cst = Constant(expression.value, expression_type, expression.subdenomination) set_val(expression, cst) def _post_member_access(self, expression: MemberAccess) -> None: @@ -436,25 +448,33 @@ def _post_member_access(self, expression: MemberAccess) -> None: assert len(expression.expression.arguments) == 1 val = TemporaryVariable(self._node) type_expression_found = expression.expression.arguments[0] + type_found: Union[ElementaryType, UserDefinedType] if isinstance(type_expression_found, ElementaryTypeNameExpression): - type_found = type_expression_found.type + type_expression_found_type = type_expression_found.type + assert isinstance(type_expression_found_type, ElementaryType) + type_found = type_expression_found_type + min_value = type_found.min + max_value = type_found.max constant_type = type_found else: # type(enum).max/min assert isinstance(type_expression_found, Identifier) - type_found = type_expression_found.value - assert isinstance(type_found, Enum) + type_found_in_expression = type_expression_found.value + assert isinstance(type_found_in_expression, (EnumContract, EnumTopLevel)) + type_found = UserDefinedType(type_found_in_expression) constant_type = None + min_value = type_found_in_expression.min + max_value = type_found_in_expression.max if expression.member_name == "min": op = Assignment( val, - Constant(str(type_found.min), constant_type), + Constant(str(min_value), constant_type), type_found, ) else: op = Assignment( val, - Constant(str(type_found.max), constant_type), + Constant(str(max_value), constant_type), type_found, ) self._result.append(op) @@ -500,11 +520,11 @@ def _post_member_access(self, expression: MemberAccess) -> None: set_val(expression, expr.custom_errors_as_dict[expression.member_name]) return - val = ReferenceVariable(self._node) - member = Member(expr, Constant(expression.member_name), val) + val_ref = ReferenceVariable(self._node) + member = Member(expr, Constant(expression.member_name), val_ref) member.set_expression(expression) self._result.append(member) - set_val(expression, val) + set_val(expression, val_ref) def _post_new_array(self, expression: NewArray) -> None: val = TemporaryVariable(self._node) @@ -527,7 +547,7 @@ def _post_new_contract(self, expression: NewContract) -> None: self._result.append(operation) set_val(expression, val) - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: # TODO unclear if this is ever used? val = TemporaryVariable(self._node) operation = TmpNewElementaryType(expression.type, val) @@ -536,17 +556,20 @@ def _post_new_elementary_type(self, expression): set_val(expression, val) def _post_tuple_expression(self, expression: TupleExpression) -> None: - expressions = [get(e) if e else None for e in expression.expressions] - if len(expressions) == 1: - val = expressions[0] + all_expressions = [get(e) if e else None for e in expression.expressions] + if len(all_expressions) == 1: + val = all_expressions[0] else: - val = expressions + val = all_expressions set_val(expression, val) - def _post_type_conversion(self, expression: TypeConversion) -> None: + def _post_type_conversion(self, expression: expressions.TypeConversion) -> None: + assert expression.expression expr = get(expression.expression) val = TemporaryVariable(self._node) - operation = TypeConversion(val, expr, expression.type) + expression_type = expression.type + assert isinstance(expression_type, (TypeAlias, UserDefinedType, ElementaryType)) + operation = TypeConversion(val, expr, expression_type) val.set_type(expression.type) operation.set_expression(expression) self._result.append(operation) @@ -555,6 +578,7 @@ def _post_type_conversion(self, expression: TypeConversion) -> None: # pylint: disable=too-many-statements def _post_unary_operation(self, expression: UnaryOperation) -> None: value = get(expression.expression) + operation: Operation if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]: lvalue = TemporaryVariable(self._node) operation = Unary(lvalue, value, expression.type) @@ -598,6 +622,7 @@ def _post_unary_operation(self, expression: UnaryOperation) -> None: set_val(expression, value) elif expression.type in [UnaryOperationType.MINUS_PRE]: lvalue = TemporaryVariable(self._node) + assert isinstance(value.type, ElementaryType) operation = Binary(lvalue, Constant("0", value.type), value, BinaryType.SUBTRACTION) operation.set_expression(expression) self._result.append(operation) diff --git a/tests/test_ssa_generation.py b/tests/test_ssa_generation.py index 00d2f23fd4..94620285eb 100644 --- a/tests/test_ssa_generation.py +++ b/tests/test_ssa_generation.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from inspect import getsourcefile from tempfile import NamedTemporaryFile -from typing import Union, List, Optional +from typing import Union, List, Optional, Dict, Callable import pytest from solc_select import solc_select @@ -15,6 +15,7 @@ from slither import Slither from slither.core.cfg.node import Node, NodeType from slither.core.declarations import Function, Contract +from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable from slither.slithir.operations import ( OperationWithLValue, @@ -35,10 +36,11 @@ ReferenceVariable, LocalIRVariable, StateIRVariable, + TemporaryVariableSSA, ) # Directory of currently executing script. Will be used as basis for temporary file names. -SCRIPT_DIR = pathlib.Path(getsourcefile(lambda: 0)).parent +SCRIPT_DIR = pathlib.Path(getsourcefile(lambda: 0)).parent # type:ignore def valid_version(ver: str) -> bool: @@ -54,15 +56,15 @@ def valid_version(ver: str) -> bool: return False -def have_ssa_if_ir(function: Function): +def have_ssa_if_ir(function: Function) -> None: """Verifies that all nodes in a function that have IR also have SSA IR""" for n in function.nodes: if n.irs: assert n.irs_ssa -# pylint: disable=too-many-branches -def ssa_basic_properties(function: Function): +# pylint: disable=too-many-branches, too-many-locals +def ssa_basic_properties(function: Function) -> None: """Verifies that basic properties of ssa holds 1. Every name is defined only once @@ -76,12 +78,14 @@ def ssa_basic_properties(function: Function): """ ssa_lvalues = set() ssa_rvalues = set() - lvalue_assignments = {} + lvalue_assignments: Dict[str, int] = {} for n in function.nodes: for ir in n.irs: - if isinstance(ir, OperationWithLValue): + if isinstance(ir, OperationWithLValue) and ir.lvalue: name = ir.lvalue.name + if name is None: + continue if name in lvalue_assignments: lvalue_assignments[name] += 1 else: @@ -94,8 +98,9 @@ def ssa_basic_properties(function: Function): ssa_lvalues.add(ssa.lvalue) # 2 (if Local/State Var) - if isinstance(ssa.lvalue, (StateIRVariable, LocalIRVariable)): - assert ssa.lvalue.index > 0 + ssa_lvalue = ssa.lvalue + if isinstance(ssa_lvalue, (StateIRVariable, LocalIRVariable)): + assert ssa_lvalue.index > 0 for rvalue in filter( lambda x: not isinstance(x, (StateIRVariable, Constant)), ssa.read @@ -112,15 +117,18 @@ def ssa_basic_properties(function: Function): undef_vars.add(rvalue.non_ssa_version) # 4 - ssa_defs = defaultdict(int) + ssa_defs: Dict[str, int] = defaultdict(int) for v in ssa_lvalues: - ssa_defs[v.name] += 1 + if v and v.name: + ssa_defs[v.name] += 1 - for (k, n) in lvalue_assignments.items(): - assert ssa_defs[k] >= n + for (k, count) in lvalue_assignments.items(): + assert ssa_defs[k] >= count # Helper 5/6 - def check_property_5_and_6(variables, ssavars): + def check_property_5_and_6( + variables: List[LocalVariable], ssavars: List[LocalIRVariable] + ) -> None: for var in filter(lambda x: x.name, variables): ssa_vars = [x for x in ssavars if x.non_ssa_version == var] assert len(ssa_vars) == 1 @@ -137,7 +145,7 @@ def check_property_5_and_6(variables, ssavars): check_property_5_and_6(function.returns, function.returns_ssa) -def ssa_phi_node_properties(f: Function): +def ssa_phi_node_properties(f: Function) -> None: """Every phi-function should have as many args as predecessors This does not apply if the phi-node refers to state variables, @@ -153,7 +161,7 @@ def ssa_phi_node_properties(f: Function): # TODO (hbrodin): This should probably go into another file, not specific to SSA -def dominance_properties(f: Function): +def dominance_properties(f: Function) -> None: """Verifies properties related to dominators holds 1. Every node have an immediate dominator except entry_node which have none @@ -181,14 +189,16 @@ def find_path(from_node: Node, to: Node) -> bool: assert find_path(node.immediate_dominator, node) -def phi_values_inserted(f: Function): +def phi_values_inserted(f: Function) -> None: """Verifies that phi-values are inserted at the right places For every node that has a dominance frontier, any def (including phi) should be a phi function in its dominance frontier """ - def have_phi_for_var(node: Node, var): + def have_phi_for_var( + node: Node, var: Union[StateIRVariable, LocalIRVariable, TemporaryVariableSSA] + ) -> bool: """Checks if a node has a phi-instruction for var The ssa version would ideally be checked, but then @@ -199,7 +209,14 @@ def have_phi_for_var(node: Node, var): non_ssa = var.non_ssa_version for ssa in node.irs_ssa: if isinstance(ssa, Phi): - if non_ssa in map(lambda ssa_var: ssa_var.non_ssa_version, ssa.read): + if non_ssa in map( + lambda ssa_var: ssa_var.non_ssa_version, + [ + r + for r in ssa.read + if isinstance(r, (StateIRVariable, LocalIRVariable, TemporaryVariableSSA)) + ], + ): return True return False @@ -207,12 +224,15 @@ def have_phi_for_var(node: Node, var): for df in node.dominance_frontier: for ssa in node.irs_ssa: if isinstance(ssa, OperationWithLValue): - if is_used_later(node, ssa.lvalue): - assert have_phi_for_var(df, ssa.lvalue) + ssa_lvalue = ssa.lvalue + if isinstance( + ssa_lvalue, (StateIRVariable, LocalIRVariable, TemporaryVariableSSA) + ) and is_used_later(node, ssa_lvalue): + assert have_phi_for_var(df, ssa_lvalue) @contextmanager -def select_solc_version(version: Optional[str]): +def select_solc_version(version: Optional[str]) -> None: """Selects solc version to use for running tests. If no version is provided, latest is used.""" @@ -257,17 +277,17 @@ def slither_from_source(source_code: str, solc_version: Optional[str] = None): pathlib.Path(fname).unlink() -def verify_properties_hold(source_code_or_slither: Union[str, Slither]): +def verify_properties_hold(source_code_or_slither: Union[str, Slither]) -> None: """Ensures that basic properties of SSA hold true""" - def verify_func(func: Function): + def verify_func(func: Function) -> None: have_ssa_if_ir(func) phi_values_inserted(func) ssa_basic_properties(func) ssa_phi_node_properties(func) dominance_properties(func) - def verify(slither): + def verify(slither: Slither) -> None: for cu in slither.compilation_units: for func in cu.functions_and_modifiers: _dump_function(func) @@ -281,11 +301,12 @@ def verify(slither): if isinstance(source_code_or_slither, Slither): verify(source_code_or_slither) else: + slither: Slither with slither_from_source(source_code_or_slither) as slither: verify(slither) -def _dump_function(f: Function): +def _dump_function(f: Function) -> None: """Helper function to print nodes/ssa ir for a function or modifier""" print(f"---- {f.name} ----") for n in f.nodes: @@ -295,13 +316,13 @@ def _dump_function(f: Function): print("") -def _dump_functions(c: Contract): +def _dump_functions(c: Contract) -> None: """Helper function to print functions and modifiers of a contract""" for f in c.functions_and_modifiers: _dump_function(f) -def get_filtered_ssa(f: Union[Function, Node], flt) -> List[Operation]: +def get_filtered_ssa(f: Union[Function, Node], flt: Callable) -> List[Operation]: """Returns a list of all ssanodes filtered by filter for all nodes in function f""" if isinstance(f, Function): return [ssanode for node in f.nodes for ssanode in node.irs_ssa if flt(ssanode)] @@ -315,7 +336,7 @@ def get_ssa_of_type(f: Union[Function, Node], ssatype) -> List[Operation]: return get_filtered_ssa(f, lambda ssanode: isinstance(ssanode, ssatype)) -def test_multi_write(): +def test_multi_write() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -328,7 +349,7 @@ def test_multi_write(): verify_properties_hold(contract) -def test_single_branch_phi(): +def test_single_branch_phi() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -343,7 +364,7 @@ def test_single_branch_phi(): verify_properties_hold(contract) -def test_basic_phi(): +def test_basic_phi() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -360,7 +381,7 @@ def test_basic_phi(): verify_properties_hold(contract) -def test_basic_loop_phi(): +def test_basic_loop_phi() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -376,7 +397,7 @@ def test_basic_loop_phi(): @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") -def test_phi_propagation_loop(): +def test_phi_propagation_loop() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -397,7 +418,7 @@ def test_phi_propagation_loop(): @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") -def test_free_function_properties(): +def test_free_function_properties() -> None: contract = """ pragma solidity ^0.8.11; @@ -418,7 +439,7 @@ def test_free_function_properties(): verify_properties_hold(contract) -def test_ssa_inter_transactional(): +def test_ssa_inter_transactional() -> None: source = """ pragma solidity ^0.8.11; contract A { @@ -461,7 +482,7 @@ def test_ssa_inter_transactional(): @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") -def test_ssa_phi_callbacks(): +def test_ssa_phi_callbacks() -> None: source = """ pragma solidity ^0.8.11; contract A { @@ -520,7 +541,7 @@ def test_ssa_phi_callbacks(): @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") -def test_storage_refers_to(): +def test_storage_refers_to() -> None: """Test the storage aspects of the SSA IR When declaring a var as being storage, start tracking what storage it refers_to. @@ -690,7 +711,7 @@ def test_initial_version_exists_for_state_variables_function_assign(): # temporary variable, that is then assigned to a call = get_ssa_of_type(ctor, InternalCall)[0] - assert call.function == f + assert call.node.function == f assign = get_ssa_of_type(ctor, Assignment)[0] assert assign.rvalue == call.lvalue assert isinstance(assign.lvalue, StateIRVariable)