diff --git a/setup.py b/setup.py index 857b371605..42f6af1522 100644 --- a/setup.py +++ b/setup.py @@ -15,8 +15,8 @@ "packaging", "prettytable>=3.3.0", "pycryptodome>=3.4.6", - "crytic-compile>=0.3.5,<0.4.0", - # "crytic-compile@git+https://github.com/crytic/crytic-compile.git@master#egg=crytic-compile", + # "crytic-compile>=0.3.5,<0.4.0", + "crytic-compile@git+https://github.com/crytic/crytic-compile.git@master#egg=crytic-compile", "web3>=6.0.0", "eth-abi>=4.0.0", "eth-typing>=3.0.0", diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 4400bc265d..aa02597d76 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -89,6 +89,7 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope self._is_interface: bool = False self._is_library: bool = False self._is_fully_implemented: bool = False + self._is_abstract: bool = False self._signatures: Optional[List[str]] = None self._signatures_declared: Optional[List[str]] = None @@ -199,12 +200,34 @@ def comments(self, comments: str): @property def is_fully_implemented(self) -> bool: + """ + bool: True if the contract defines all functions. + In modern Solidity, virtual functions can lack an implementation. + Prior to Solidity 0.6.0, functions like the following would be not fully implemented: + ```solidity + contract ImplicitAbstract{ + function f() public; + } + ``` + """ return self._is_fully_implemented @is_fully_implemented.setter def is_fully_implemented(self, is_fully_implemented: bool): self._is_fully_implemented = is_fully_implemented + @property + def is_abstract(self) -> bool: + """ + Note for Solidity < 0.6.0 it will always be false + bool: True if the contract is abstract. + """ + return self._is_abstract + + @is_abstract.setter + def is_abstract(self, is_abstract: bool): + self._is_abstract = is_abstract + # endregion ################################################################################### ################################################################################### @@ -983,16 +1006,14 @@ def get_enum_from_canonical_name(self, enum_name: str) -> Optional["Enum"]: def get_functions_overridden_by(self, function: "Function") -> List["Function"]: """ - Return the list of functions overriden by the function + Return the list of functions overridden by the function Args: (core.Function) Returns: list(core.Function) """ - candidatess = [c.functions_declared for c in self.inheritance] - candidates = [candidate for sublist in candidatess for candidate in sublist] - return [f for f in candidates if f.full_name == function.full_name] + return function.overrides # endregion ################################################################################### diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index d2baaf7e7b..0a6f5ae2af 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -37,7 +37,7 @@ HighLevelCallType, LibraryCallType, ) - from slither.core.declarations import Contract + from slither.core.declarations import Contract, FunctionContract from slither.core.cfg.node import Node, NodeType from slither.core.variables.variable import Variable from slither.slithir.variables.variable import SlithIRVariable @@ -46,7 +46,6 @@ from slither.slithir.operations import Operation from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.scope.scope import FileScope - from slither.slithir.variables.state_variable import StateIRVariable LOGGER = logging.getLogger("Function") ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"]) @@ -126,6 +125,9 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None: self._pure: bool = False self._payable: bool = False self._visibility: Optional[str] = None + self._virtual: bool = False + self._overrides: List["FunctionContract"] = [] + self._overridden_by: List["FunctionContract"] = [] self._is_implemented: Optional[bool] = None self._is_empty: Optional[bool] = None @@ -441,6 +443,49 @@ def payable(self) -> bool: def payable(self, p: bool): self._payable = p + # endregion + ################################################################################### + ################################################################################### + # region Virtual + ################################################################################### + ################################################################################### + + @property + def is_virtual(self) -> bool: + """ + Note for Solidity < 0.6.0 it will always be false + bool: True if the function is virtual + """ + return self._virtual + + @is_virtual.setter + def is_virtual(self, v: bool): + self._virtual = v + + @property + def is_override(self) -> bool: + """ + Note for Solidity < 0.6.0 it will always be false + bool: True if the function overrides a base function + """ + return len(self._overrides) > 0 + + @property + def overridden_by(self) -> List["FunctionContract"]: + """ + List["FunctionContract"]: List of functions in child contracts that override this function + This may include distinct instances of the same function due to inheritance + """ + return self._overridden_by + + @property + def overrides(self) -> List["FunctionContract"]: + """ + List["FunctionContract"]: List of functions in parent contracts that this function overrides + This may include distinct instances of the same function due to inheritance + """ + return self._overrides + # endregion ################################################################################### ################################################################################### diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index 76220f5bae..8eca260fac 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -22,7 +22,7 @@ from slither.slithir.variables import Constant from slither.utils.colors import red from slither.utils.sarif import read_triage_info -from slither.utils.source_mapping import get_definition, get_references, get_implementation +from slither.utils.source_mapping import get_definition, get_references, get_all_implementations logger = logging.getLogger("Slither") logging.basicConfig() @@ -204,41 +204,53 @@ def offset_to_objects(self, filename_str: str, offset: int) -> Set[SourceMapping def _compute_offsets_from_thing(self, thing: SourceMapping): definition = get_definition(thing, self.crytic_compile) references = get_references(thing) - implementation = get_implementation(thing) + implementations = get_all_implementations(thing, self.contracts) for offset in range(definition.start, definition.end + 1): - if ( - isinstance(thing, TopLevel) + isinstance(thing, (TopLevel, Contract)) or ( isinstance(thing, FunctionContract) and thing.contract_declarer == thing.contract ) or (isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract)) ): + self._offset_to_objects[definition.filename][offset].add(thing) self._offset_to_definitions[definition.filename][offset].add(definition) - self._offset_to_implementations[definition.filename][offset].add(implementation) + self._offset_to_implementations[definition.filename][offset].update(implementations) self._offset_to_references[definition.filename][offset] |= set(references) for ref in references: for offset in range(ref.start, ref.end + 1): - + is_declared_function = ( + isinstance(thing, FunctionContract) + and thing.contract_declarer == thing.contract + ) if ( isinstance(thing, TopLevel) - or ( - isinstance(thing, FunctionContract) - and thing.contract_declarer == thing.contract - ) + or is_declared_function or ( isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract) ) ): self._offset_to_objects[definition.filename][offset].add(thing) - self._offset_to_definitions[ref.filename][offset].add(definition) - self._offset_to_implementations[ref.filename][offset].add(implementation) + if is_declared_function: + # Only show the nearest lexical definition for declared contract-level functions + if ( + thing.contract.source_mapping.start + < offset + < thing.contract.source_mapping.end + ): + + self._offset_to_definitions[ref.filename][offset].add(definition) + + else: + self._offset_to_definitions[ref.filename][offset].add(definition) + + self._offset_to_implementations[ref.filename][offset].update(implementations) self._offset_to_references[ref.filename][offset] |= set(references) def _compute_offsets_to_ref_impl_decl(self): # pylint: disable=too-many-branches @@ -251,15 +263,18 @@ def _compute_offsets_to_ref_impl_decl(self): # pylint: disable=too-many-branche for contract in compilation_unit.contracts: self._compute_offsets_from_thing(contract) - for function in contract.functions: + for function in contract.functions_declared: self._compute_offsets_from_thing(function) for variable in function.local_variables: self._compute_offsets_from_thing(variable) - for modifier in contract.modifiers: + for modifier in contract.modifiers_declared: self._compute_offsets_from_thing(modifier) for variable in modifier.local_variables: self._compute_offsets_from_thing(variable) + for var in contract.state_variables: + self._compute_offsets_from_thing(var) + for st in contract.structures: self._compute_offsets_from_thing(st) @@ -268,6 +283,10 @@ def _compute_offsets_to_ref_impl_decl(self): # pylint: disable=too-many-branche for event in contract.events: self._compute_offsets_from_thing(event) + + for typ in contract.type_aliases: + self._compute_offsets_from_thing(typ) + for enum in compilation_unit.enums_top_level: self._compute_offsets_from_thing(enum) for event in compilation_unit.events_top_level: @@ -276,6 +295,14 @@ def _compute_offsets_to_ref_impl_decl(self): # pylint: disable=too-many-branche self._compute_offsets_from_thing(function) for st in compilation_unit.structures_top_level: self._compute_offsets_from_thing(st) + for var in compilation_unit.variables_top_level: + self._compute_offsets_from_thing(var) + for typ in compilation_unit.type_aliases.values(): + self._compute_offsets_from_thing(typ) + for err in compilation_unit.custom_errors: + self._compute_offsets_from_thing(err) + for event in compilation_unit.events_top_level: + self._compute_offsets_from_thing(event) for import_directive in compilation_unit.import_directives: self._compute_offsets_from_thing(import_directive) for pragma in compilation_unit.pragma_directives: diff --git a/slither/printers/summary/declaration.py b/slither/printers/summary/declaration.py index c7c4798d53..4e3a42cc1e 100644 --- a/slither/printers/summary/declaration.py +++ b/slither/printers/summary/declaration.py @@ -21,18 +21,20 @@ def output(self, _filename: str) -> Output: txt += "\n# Contracts\n" for contract in compilation_unit.contracts: txt += f"# {contract.name}\n" - txt += f"\t- Declaration: {get_definition(contract, compilation_unit.core.crytic_compile).to_detailed_str()}\n" - txt += f"\t- Implementation: {get_implementation(contract).to_detailed_str()}\n" + contract_def = get_definition(contract, compilation_unit.core.crytic_compile) + txt += f"\t- Declaration: {contract_def.to_detailed_str()}\n" + txt += f"\t- Implementation(s): {[x.to_detailed_str() for x in list(self.slither.offset_to_implementations(contract.source_mapping.filename.absolute, contract_def.start))]}\n" txt += ( f"\t- References: {[x.to_detailed_str() for x in get_references(contract)]}\n" ) txt += "\n\t## Function\n" - for func in contract.functions: + for func in contract.functions_declared: txt += f"\t\t- {func.canonical_name}\n" - txt += f"\t\t\t- Declaration: {get_definition(func, compilation_unit.core.crytic_compile).to_detailed_str()}\n" - txt += f"\t\t\t- Implementation: {get_implementation(func).to_detailed_str()}\n" + function_def = get_definition(func, compilation_unit.core.crytic_compile) + txt += f"\t\t\t- Declaration: {function_def.to_detailed_str()}\n" + txt += f"\t\t\t- Implementation(s): {[x.to_detailed_str() for x in list(self.slither.offset_to_implementations(func.source_mapping.filename.absolute, function_def.start))]}\n" txt += f"\t\t\t- References: {[x.to_detailed_str() for x in get_references(func)]}\n" txt += "\n\t## State variables\n" diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index c3161a507d..fcb6c3afa7 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -873,9 +873,7 @@ def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-lo elif isinstance(ir, NewArray): ir.lvalue.set_type(ir.array_type) elif isinstance(ir, NewContract): - contract = node.file_scope.get_contract_from_name(ir.contract_name) - assert contract - ir.lvalue.set_type(UserDefinedType(contract)) + ir.lvalue.set_type(ir.contract_name) elif isinstance(ir, NewElementaryType): ir.lvalue.set_type(ir.type) elif isinstance(ir, NewStructure): @@ -1166,7 +1164,7 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]) -> Union[Call, return n if isinstance(ins.ori, TmpNewContract): - op = NewContract(Constant(ins.ori.contract_name), ins.lvalue) + op = NewContract(ins.ori.contract_name, ins.lvalue) op.set_expression(ins.expression) op.call_id = ins.call_id if ins.call_value: @@ -1211,7 +1209,7 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]) -> Union[Call, internalcall.set_expression(ins.expression) return internalcall - raise Exception(f"Not extracted {type(ins.called)} {ins}") # pylint: disable=bad-option-value + raise SlithIRError(f"Not extracted {type(ins.called)} {ins}") # endregion @@ -1724,6 +1722,7 @@ def convert_type_of_high_and_internal_level_call( Returns: Potential new IR """ + func = None if isinstance(ir, InternalCall): candidates: List[Function] @@ -2019,6 +2018,9 @@ def _find_source_mapping_references(irs: List[Operation]) -> None: if isinstance(ir, NewContract): ir.contract_created.references.append(ir.expression.source_mapping) + if isinstance(ir, HighLevelCall): + ir.function.references.append(ir.expression.source_mapping) + # endregion ################################################################################### diff --git a/slither/slithir/operations/new_contract.py b/slither/slithir/operations/new_contract.py index 9ed00b1ac0..928cbd0135 100644 --- a/slither/slithir/operations/new_contract.py +++ b/slither/slithir/operations/new_contract.py @@ -3,6 +3,7 @@ from slither.core.declarations import Function from slither.core.declarations.contract import Contract from slither.core.variables import Variable +from slither.core.solidity_types import UserDefinedType from slither.slithir.operations import Call, OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.variables.constant import Constant @@ -13,7 +14,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes def __init__( self, - contract_name: Constant, + contract_name: UserDefinedType, lvalue: Union[TemporaryVariableSSA, TemporaryVariable], names: Optional[List[str]] = None, ) -> None: @@ -23,7 +24,9 @@ def __init__( For calls of the form f({argName1 : arg1, ...}), the names of parameters listed in call order. Otherwise, None. """ - assert isinstance(contract_name, Constant) + assert isinstance( + contract_name.type, Contract + ), f"contract_name is {contract_name} of type {type(contract_name)}" assert is_valid_lvalue(lvalue) super().__init__(names=names) self._contract_name = contract_name @@ -58,7 +61,7 @@ def call_salt(self, s): self._call_salt = s @property - def contract_name(self) -> Constant: + def contract_name(self) -> UserDefinedType: return self._contract_name @property @@ -69,10 +72,7 @@ def read(self) -> List[Any]: @property def contract_created(self) -> Contract: - contract_name = self.contract_name - contract_instance = self.node.file_scope.get_contract_from_name(contract_name) - assert contract_instance - return contract_instance + return self.contract_name.type ################################################################################### ################################################################################### diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index dd7f0342b7..660aab1767 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -1,6 +1,6 @@ import logging import re -from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set, Sequence +from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set, Sequence, Tuple from slither.core.declarations import ( Modifier, @@ -65,7 +65,8 @@ def __init__( # use to remap inheritance id self._remapping: Dict[str, str] = {} - self.baseContracts: List[str] = [] + # (referencedDeclaration, offset) + self.baseContracts: List[Tuple[int, str]] = [] self.baseConstructorContractsCalled: List[str] = [] self._linearized_base_contracts: List[int] @@ -175,6 +176,9 @@ def _parse_contract_info(self) -> None: self._contract.is_fully_implemented = attributes["fullyImplemented"] self._linearized_base_contracts = attributes["linearizedBaseContracts"] + if "abstract" in attributes: + self._contract.is_abstract = attributes["abstract"] + # Parse base contract information self._parse_base_contract_info() @@ -202,7 +206,9 @@ def _parse_base_contract_info(self) -> None: # pylint: disable=too-many-branche # Obtain our contract reference and add it to our base contract list referencedDeclaration = base_contract["baseName"]["referencedDeclaration"] - self.baseContracts.append(referencedDeclaration) + self.baseContracts.append( + (referencedDeclaration, base_contract["baseName"]["src"]) + ) # If we have defined arguments in our arguments object, this is a constructor invocation. # (note: 'arguments' can be [], which is not the same as None. [] implies a constructor was @@ -234,7 +240,10 @@ def _parse_base_contract_info(self) -> None: # pylint: disable=too-many-branche referencedDeclaration = base_contract_items[0]["attributes"][ "referencedDeclaration" ] - self.baseContracts.append(referencedDeclaration) + + self.baseContracts.append( + (referencedDeclaration, base_contract_items[0]["src"]) + ) # If we have an 'attributes'->'arguments' which is None, this is not a constructor call. if ( diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 3b02ee923c..4ff77d008b 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -208,8 +208,6 @@ def _analyze_attributes(self) -> None: else: attributes = self._functionNotParsed["attributes"] - if "payable" in attributes: - self._function.payable = attributes["payable"] if "stateMutability" in attributes: if attributes["stateMutability"] == "payable": self._function.payable = True @@ -243,6 +241,34 @@ def _analyze_attributes(self) -> None: if "payable" in attributes: self._function.payable = attributes["payable"] + if "baseFunctions" in attributes: + overrides_ids = attributes["baseFunctions"] + if len(overrides_ids) > 0: + for f_id in overrides_ids: + funcs = self.slither_parser.functions_by_id[f_id] + for f in funcs: + # Do not consider leaf contracts as overrides. + # B is A { function a() override {} } and C is A { function a() override {} } override A.a(), not each other. + if ( + f.contract == self._function.contract + or f.contract in self._function.contract.inheritance + ): + self._function.overrides.append(f) + f.overridden_by.append(self._function) + + # Attaches reference to override specifier e.g. X is referenced by `function a() override(X)` + if "overrides" in attributes and isinstance(attributes["overrides"], dict): + for override in attributes["overrides"].get("overrides", []): + refId = override["referencedDeclaration"] + overridden_contract = self.slither_parser.contracts_by_id.get(refId, None) + if overridden_contract: + overridden_contract.add_reference_from_raw_source( + override["src"], self.compilation_unit + ) + + if "virtual" in attributes: + self._function.is_virtual = attributes["virtual"] + def analyze_params(self) -> None: # Can be re-analyzed due to inheritance if self._params_was_analyzed: diff --git a/slither/solc_parsing/expressions/expression_parsing.py b/slither/solc_parsing/expressions/expression_parsing.py index 4100d16ad7..4991984ff7 100644 --- a/slither/solc_parsing/expressions/expression_parsing.py +++ b/slither/solc_parsing/expressions/expression_parsing.py @@ -614,20 +614,9 @@ def parse_expression(expression: Dict, caller_context: CallerContextExpression) assert type_name[caller_context.get_key()] == "UserDefinedTypeName" - if is_compact_ast: - - # Changed introduced in Solidity 0.8 - # see https://github.com/crytic/slither/issues/794 - - # TODO explore more the changes introduced in 0.8 and the usage of pathNode/IdentifierPath - if "name" not in type_name: - assert "pathNode" in type_name and "name" in type_name["pathNode"] - contract_name = type_name["pathNode"]["name"] - else: - contract_name = type_name["name"] - else: - contract_name = type_name["attributes"]["name"] - new = NewContract(contract_name) + contract_type = parse_type(type_name, caller_context) + assert isinstance(contract_type, UserDefinedType) + new = NewContract(contract_type) new.set_offset(src, caller_context.compilation_unit) return new diff --git a/slither/solc_parsing/slither_compilation_unit_solc.py b/slither/solc_parsing/slither_compilation_unit_solc.py index 02d8307024..721cf69fc8 100644 --- a/slither/solc_parsing/slither_compilation_unit_solc.py +++ b/slither/solc_parsing/slither_compilation_unit_solc.py @@ -1,3 +1,4 @@ +from collections import defaultdict import json import logging import os @@ -7,7 +8,7 @@ from slither.analyses.data_dependency.data_dependency import compute_dependency from slither.core.compilation_unit import SlitherCompilationUnit -from slither.core.declarations import Contract +from slither.core.declarations import Contract, Function from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel from slither.core.declarations.enum_top_level import EnumTopLevel from slither.core.declarations.event_top_level import EventTopLevel @@ -79,7 +80,8 @@ def __init__(self, compilation_unit: SlitherCompilationUnit) -> None: self._compilation_unit: SlitherCompilationUnit = compilation_unit - self._contracts_by_id: Dict[int, ContractSolc] = {} + self._contracts_by_id: Dict[int, Contract] = {} + self._functions_by_id: Dict[int, List[Function]] = defaultdict(list) self._parsed = False self._analyzed = False self._is_compact_ast = False @@ -105,6 +107,7 @@ def all_functions_and_modifiers_parser(self) -> List[FunctionSolc]: def add_function_or_modifier_parser(self, f: FunctionSolc) -> None: self._all_functions_and_modifier_parser.append(f) + self._functions_by_id[f.underlying_function.id].append(f.underlying_function) @property def underlying_contract_to_parser(self) -> Dict[Contract, ContractSolc]: @@ -114,6 +117,14 @@ def underlying_contract_to_parser(self) -> Dict[Contract, ContractSolc]: def slither_parser(self) -> "SlitherCompilationUnitSolc": return self + @property + def contracts_by_id(self) -> Dict[int, Contract]: + return self._contracts_by_id + + @property + def functions_by_id(self) -> Dict[int, List[Function]]: + return self._functions_by_id + ################################################################################### ################################################################################### # region AST @@ -480,13 +491,16 @@ def resolve_remapping_and_renaming(contract_parser: ContractSolc, want: str) -> else: missing_inheritance = i - # Resolve immediate base contracts. - for i in contract_parser.baseContracts: + # Resolve immediate base contracts and attach references. + for (i, src) in contract_parser.baseContracts: if i in contract_parser.remapping: target = resolve_remapping_and_renaming(contract_parser, i) fathers.append(target) + target.add_reference_from_raw_source(src, self.compilation_unit) elif i in self._contracts_by_id: - fathers.append(self._contracts_by_id[i]) + target = self._contracts_by_id[i] + fathers.append(target) + target.add_reference_from_raw_source(src, self.compilation_unit) else: missing_inheritance = i diff --git a/slither/utils/source_mapping.py b/slither/utils/source_mapping.py index b117cd5f78..9bf772894e 100644 --- a/slither/utils/source_mapping.py +++ b/slither/utils/source_mapping.py @@ -1,7 +1,17 @@ -from typing import List +from typing import List, Set from crytic_compile import CryticCompile -from slither.core.declarations import Contract, Function, Enum, Event, Import, Pragma, Structure -from slither.core.solidity_types.type import Type +from slither.core.declarations import ( + Contract, + Function, + Enum, + Event, + Import, + Pragma, + Structure, + CustomError, + FunctionContract, +) +from slither.core.solidity_types import Type, TypeAlias from slither.core.source_mapping.source_mapping import Source, SourceMapping from slither.core.variables.variable import Variable from slither.exceptions import SlitherError @@ -15,6 +25,10 @@ def get_definition(target: SourceMapping, crytic_compile: CryticCompile) -> Sour pattern = "import" elif isinstance(target, Pragma): pattern = "pragma" # todo maybe return with the while pragma statement + elif isinstance(target, CustomError): + pattern = "error" + elif isinstance(target, TypeAlias): + pattern = "type" elif isinstance(target, Type): raise SlitherError("get_definition_generic not implemented for types") else: @@ -52,5 +66,34 @@ def get_implementation(target: SourceMapping) -> Source: return target.source_mapping +def get_all_implementations(target: SourceMapping, contracts: List[Contract]) -> Set[Source]: + """ + Get all implementations of a contract or function, accounting for inheritance and overrides + """ + implementations = set() + # Abstract contracts and interfaces are implemented by their children + if isinstance(target, Contract): + is_interface = target.is_interface + is_implicitly_abstract = not target.is_fully_implemented + is_explicitly_abstract = target.is_abstract + if is_interface or is_implicitly_abstract or is_explicitly_abstract: + for contract in contracts: + if target in contract.immediate_inheritance: + implementations.add(contract.source_mapping) + + # Parent's virtual functions may be overridden by children + elif isinstance(target, FunctionContract): + for over in target.overridden_by: + implementations.add(over.source_mapping) + # Only show implemented virtual functions + if not target.is_virtual or target.is_implemented: + implementations.add(get_implementation(target)) + + else: + implementations.add(get_implementation(target)) + + return implementations + + def get_references(target: SourceMapping) -> List[Source]: return target.references diff --git a/tests/e2e/solc_parsing/test_ast_parsing.py b/tests/e2e/solc_parsing/test_ast_parsing.py index 522d49cce6..96346bf368 100644 --- a/tests/e2e/solc_parsing/test_ast_parsing.py +++ b/tests/e2e/solc_parsing/test_ast_parsing.py @@ -467,6 +467,8 @@ def make_version(minor: int, patch_min: int, patch_max: int) -> List[str]: ), Test("user_defined_operators-0.8.19.sol", ["0.8.19"]), Test("aliasing/main.sol", ["0.8.19"]), + Test("aliasing/alias-unit-NewContract.sol", ["0.8.19"]), + Test("aliasing/alias-symbol-NewContract.sol", ["0.8.19"]), Test("type-aliases.sol", ["0.8.19"]), Test("enum-max-min.sol", ["0.8.19"]), Test("event-top-level.sol", ["0.8.22"]), diff --git a/tests/e2e/solc_parsing/test_data/aliasing/MyContract.sol b/tests/e2e/solc_parsing/test_data/aliasing/MyContract.sol new file mode 100644 index 0000000000..ab7b85a102 --- /dev/null +++ b/tests/e2e/solc_parsing/test_data/aliasing/MyContract.sol @@ -0,0 +1 @@ +contract MyContract {} \ No newline at end of file diff --git a/tests/e2e/solc_parsing/test_data/aliasing/alias-symbol-NewContract.sol b/tests/e2e/solc_parsing/test_data/aliasing/alias-symbol-NewContract.sol new file mode 100644 index 0000000000..1f545461b4 --- /dev/null +++ b/tests/e2e/solc_parsing/test_data/aliasing/alias-symbol-NewContract.sol @@ -0,0 +1,8 @@ +import {MyContract as MyAliasedContract} from "./MyContract.sol"; + +contract Test { +MyAliasedContract c; + constructor() { + c = new MyAliasedContract(); + } +} diff --git a/tests/e2e/solc_parsing/test_data/aliasing/alias-unit-NewContract.sol b/tests/e2e/solc_parsing/test_data/aliasing/alias-unit-NewContract.sol new file mode 100644 index 0000000000..f46693d09f --- /dev/null +++ b/tests/e2e/solc_parsing/test_data/aliasing/alias-unit-NewContract.sol @@ -0,0 +1,9 @@ + +import "./MyContract.sol" as MyAliasedContract; + +contract Test { +MyAliasedContract.MyContract c; + constructor() { + c = new MyAliasedContract.MyContract(); + } +} \ No newline at end of file diff --git a/tests/e2e/solc_parsing/test_data/compile/aliasing/alias-symbol-NewContract.sol-0.8.19-compact.zip b/tests/e2e/solc_parsing/test_data/compile/aliasing/alias-symbol-NewContract.sol-0.8.19-compact.zip new file mode 100644 index 0000000000..b630fb8c49 Binary files /dev/null and b/tests/e2e/solc_parsing/test_data/compile/aliasing/alias-symbol-NewContract.sol-0.8.19-compact.zip differ diff --git a/tests/e2e/solc_parsing/test_data/compile/aliasing/alias-unit-NewContract.sol-0.8.19-compact.zip b/tests/e2e/solc_parsing/test_data/compile/aliasing/alias-unit-NewContract.sol-0.8.19-compact.zip new file mode 100644 index 0000000000..da51136356 Binary files /dev/null and b/tests/e2e/solc_parsing/test_data/compile/aliasing/alias-unit-NewContract.sol-0.8.19-compact.zip differ diff --git a/tests/e2e/solc_parsing/test_data/expected/aliasing/alias-symbol-NewContract.sol-0.8.19-compact.json b/tests/e2e/solc_parsing/test_data/expected/aliasing/alias-symbol-NewContract.sol-0.8.19-compact.json new file mode 100644 index 0000000000..b0ce256111 --- /dev/null +++ b/tests/e2e/solc_parsing/test_data/expected/aliasing/alias-symbol-NewContract.sol-0.8.19-compact.json @@ -0,0 +1,6 @@ +{ + "MyContract": {}, + "Test": { + "constructor()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n" + } +} \ No newline at end of file diff --git a/tests/e2e/solc_parsing/test_data/expected/aliasing/alias-unit-NewContract.sol-0.8.19-compact.json b/tests/e2e/solc_parsing/test_data/expected/aliasing/alias-unit-NewContract.sol-0.8.19-compact.json new file mode 100644 index 0000000000..b0ce256111 --- /dev/null +++ b/tests/e2e/solc_parsing/test_data/expected/aliasing/alias-unit-NewContract.sol-0.8.19-compact.json @@ -0,0 +1,6 @@ +{ + "MyContract": {}, + "Test": { + "constructor()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: EXPRESSION 1\n\"];\n}\n" + } +} \ No newline at end of file diff --git a/tests/unit/core/test_contract_declaration.py b/tests/unit/core/test_contract_declaration.py index 28d0fae550..e3f644b547 100644 --- a/tests/unit/core/test_contract_declaration.py +++ b/tests/unit/core/test_contract_declaration.py @@ -11,26 +11,36 @@ def test_abstract_contract(solc_binary_path) -> None: solc_path = solc_binary_path("0.8.0") slither = Slither(Path(CONTRACT_DECL_TEST_ROOT, "abstract.sol").as_posix(), solc=solc_path) - assert not slither.contracts[0].is_fully_implemented + explicit_abstract = slither.contracts[0] + assert not explicit_abstract.is_fully_implemented + assert explicit_abstract.is_abstract solc_path = solc_binary_path("0.5.0") slither = Slither( Path(CONTRACT_DECL_TEST_ROOT, "implicit_abstract.sol").as_posix(), solc=solc_path ) - assert not slither.contracts[0].is_fully_implemented + implicit_abstract = slither.get_contract_from_name("ImplicitAbstract")[0] + assert not implicit_abstract.is_fully_implemented + # This only is expected to work for newer versions of Solidity + assert not implicit_abstract.is_abstract slither = Slither( Path(CONTRACT_DECL_TEST_ROOT, "implicit_abstract.sol").as_posix(), solc_force_legacy_json=True, solc=solc_path, ) - assert not slither.contracts[0].is_fully_implemented + implicit_abstract = slither.get_contract_from_name("ImplicitAbstract")[0] + assert not implicit_abstract.is_fully_implemented + # This only is expected to work for newer versions of Solidity + assert not implicit_abstract.is_abstract def test_concrete_contract(solc_binary_path) -> None: solc_path = solc_binary_path("0.8.0") slither = Slither(Path(CONTRACT_DECL_TEST_ROOT, "concrete.sol").as_posix(), solc=solc_path) - assert slither.contracts[0].is_fully_implemented + concrete = slither.get_contract_from_name("Concrete")[0] + assert concrete.is_fully_implemented + assert not concrete.is_abstract solc_path = solc_binary_path("0.5.0") slither = Slither( @@ -38,7 +48,9 @@ def test_concrete_contract(solc_binary_path) -> None: solc_force_legacy_json=True, solc=solc_path, ) - assert slither.contracts[0].is_fully_implemented + concrete_old = slither.get_contract_from_name("ConcreteOld")[0] + assert concrete_old.is_fully_implemented + assert not concrete_old.is_abstract def test_private_variable(solc_binary_path) -> None: diff --git a/tests/unit/core/test_data/src_mapping/TopLevelReferences.sol b/tests/unit/core/test_data/src_mapping/TopLevelReferences.sol new file mode 100644 index 0000000000..68f7b48ad9 --- /dev/null +++ b/tests/unit/core/test_data/src_mapping/TopLevelReferences.sol @@ -0,0 +1,16 @@ +type T is uint256; +uint constant U = 1; +error V(T); +event W(T); + +contract E { + type X is int256; + function f() public { + T t = T.wrap(U); + if (T.unwrap(t) == 0) { + revert V(t); + } + emit W(t); + X x = X.wrap(1); + } +} \ No newline at end of file diff --git a/tests/unit/core/test_data/virtual_overrides.sol b/tests/unit/core/test_data/virtual_overrides.sol new file mode 100644 index 0000000000..fa9e1c388c --- /dev/null +++ b/tests/unit/core/test_data/virtual_overrides.sol @@ -0,0 +1,65 @@ +contract Test { + function myVirtualFunction() virtual external { + } +} + +contract A is Test { + function myVirtualFunction() virtual override external { + } +} + +contract B is A { + function myVirtualFunction() override external { + } + +} + +contract C is Test { + function myVirtualFunction() override external { + } +} + +contract X is Test { + function myVirtualFunction() virtual override external { + } +} + +contract Y { + function myVirtualFunction() virtual external { + } +} + +contract Z is Y, X{ + function myVirtualFunction() virtual override(Y, X) external { + } +} + + +abstract contract Name { + constructor() { + + } +} + +contract Name2 is Name { + constructor() { + + } +} + +abstract contract Test2 { + function f() virtual public; +} + +contract A2 is Test2 { + function f() virtual override public { + } +} + +abstract contract I { + function a() public virtual {} +} +contract J is I {} +contract K is J { + function a() public override {} +} \ No newline at end of file diff --git a/tests/unit/core/test_source_mapping.py b/tests/unit/core/test_source_mapping.py index 9577014297..298a192d5f 100644 --- a/tests/unit/core/test_source_mapping.py +++ b/tests/unit/core/test_source_mapping.py @@ -1,17 +1,27 @@ from pathlib import Path - +import pytest from slither import Slither -from slither.core.declarations import Function +from slither.core.declarations import Function, CustomErrorTopLevel, EventTopLevel +from slither.core.solidity_types.type_alias import TypeAliasTopLevel, TypeAliasContract +from slither.core.variables.top_level_variable import TopLevelVariable TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" SRC_MAPPING_TEST_ROOT = Path(TEST_DATA_DIR, "src_mapping") - -def test_source_mapping(solc_binary_path): - solc_path = solc_binary_path("0.6.12") +# Ensure issue fixed in https://github.com/crytic/crytic-compile/pull/554 does not regress in Slither's reference lookup. +@pytest.mark.parametrize("solc_version", ["0.6.12", "0.8.7", "0.8.8"]) +def test_source_mapping_inheritance(solc_binary_path, solc_version): + solc_path = solc_binary_path(solc_version) file = Path(SRC_MAPPING_TEST_ROOT, "inheritance.sol").as_posix() slither = Slither(file, solc=solc_path) + # 3 reference to A in inheritance `contract $ is A` + assert {(x.start, x.end) for x in slither.offset_to_references(file, 9)} == { + (121, 122), + (185, 186), + (299, 300), + } + # Check if A.f() is at the offset 27 functions = slither.offset_to_objects(file, 27) assert len(functions) == 1 @@ -23,8 +33,12 @@ def test_source_mapping(solc_binary_path): assert {(x.start, x.end) for x in slither.offset_to_definitions(file, 27)} == {(26, 28)} # Only one reference for A.f(), in A.test() assert {(x.start, x.end) for x in slither.offset_to_references(file, 27)} == {(92, 93)} - # Only one implementation for A.f(), in A.test() - assert {(x.start, x.end) for x in slither.offset_to_implementations(file, 27)} == {(17, 53)} + # Three overridden implementation of A.f(), in A.test() + assert {(x.start, x.end) for x in slither.offset_to_implementations(file, 27)} == { + (17, 53), + (129, 166), + (193, 230), + } # Check if C.f() is at the offset 203 functions = slither.offset_to_objects(file, 203) @@ -52,11 +66,9 @@ def test_source_mapping(solc_binary_path): assert isinstance(function, Function) assert function.canonical_name in ["A.f()", "B.f()", "C.f()"] - # There are three definitions possible (in A, B or C) + # There is one definition in the lexical scope of A assert {(x.start, x.end) for x in slither.offset_to_definitions(file, 93)} == { (26, 28), - (202, 204), - (138, 140), } # There are two references possible (in A.test() or C.test2() ) @@ -113,6 +125,62 @@ def test_references_user_defined_types_when_casting(solc_binary_path): assert lines == [12, 18] +def test_source_mapping_top_level_defs(solc_binary_path): + solc_path = solc_binary_path("0.8.24") + file = Path(SRC_MAPPING_TEST_ROOT, "TopLevelReferences.sol").as_posix() + slither = Slither(file, solc=solc_path) + + # Check if T is at the offset 5 + types = slither.offset_to_objects(file, 5) + assert len(types) == 1 + type_ = types.pop() + assert isinstance(type_, TypeAliasTopLevel) + assert type_.name == "T" + + assert {(x.start, x.end) for x in slither.offset_to_references(file, 5)} == { + (48, 49), + (60, 61), + (134, 135), + (140, 141), + (163, 164), + } + + # Check if U is at the offset 33 + constants = slither.offset_to_objects(file, 33) + assert len(constants) == 1 + constant = constants.pop() + assert isinstance(constant, TopLevelVariable) + assert constant.name == "U" + assert {(x.start, x.end) for x in slither.offset_to_references(file, 33)} == {(147, 148)} + + # Check if V is at the offset 46 + errors = slither.offset_to_objects(file, 46) + assert len(errors) == 1 + error = errors.pop() + assert isinstance(error, CustomErrorTopLevel) + assert error.name == "V" + assert {(x.start, x.end) for x in slither.offset_to_references(file, 46)} == {(202, 203)} + + # Check if W is at the offset 58 + events = slither.offset_to_objects(file, 58) + assert len(events) == 1 + event = events.pop() + assert isinstance(event, EventTopLevel) + assert event.name == "W" + assert {(x.start, x.end) for x in slither.offset_to_references(file, 58)} == {(231, 232)} + + # Check if X is at the offset 87 + types = slither.offset_to_objects(file, 87) + assert len(types) == 1 + type_ = types.pop() + assert isinstance(type_, TypeAliasContract) + assert type_.name == "X" + assert {(x.start, x.end) for x in slither.offset_to_references(file, 87)} == { + (245, 246), + (251, 252), + } + + def test_references_self_identifier(): """ Tests that shadowing state variables with local variables does not affect references. diff --git a/tests/unit/core/test_virtual_overrides.py b/tests/unit/core/test_virtual_overrides.py new file mode 100644 index 0000000000..a5ca4a8657 --- /dev/null +++ b/tests/unit/core/test_virtual_overrides.py @@ -0,0 +1,151 @@ +from pathlib import Path +from slither import Slither + +TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data" + + +def test_overrides(solc_binary_path) -> None: + # pylint: disable=too-many-statements,too-many-locals + solc_path = solc_binary_path("0.8.15") + slither = Slither(Path(TEST_DATA_DIR, "virtual_overrides.sol").as_posix(), solc=solc_path) + + test = slither.get_contract_from_name("Test")[0] + test_virtual_func = test.get_function_from_full_name("myVirtualFunction()") + assert test_virtual_func.is_virtual + assert not test_virtual_func.is_override + x = test.get_functions_overridden_by(test_virtual_func) + assert len(x) == 0 + x = test_virtual_func.overridden_by + assert len(x) == 5 + assert set(i.canonical_name for i in x) == set( + ["A.myVirtualFunction()", "C.myVirtualFunction()", "X.myVirtualFunction()"] + ) + + a = slither.get_contract_from_name("A")[0] + a_virtual_func = a.get_function_from_full_name("myVirtualFunction()") + assert a_virtual_func.is_virtual + assert a_virtual_func.is_override + x = a.get_functions_overridden_by(a_virtual_func) + assert len(x) == 2 + assert set(i.canonical_name for i in x) == set(["Test.myVirtualFunction()"]) + + b = slither.get_contract_from_name("B")[0] + b_virtual_func = b.get_function_from_full_name("myVirtualFunction()") + assert not b_virtual_func.is_virtual + assert b_virtual_func.is_override + x = b.get_functions_overridden_by(b_virtual_func) + assert len(x) == 2 + assert set(i.canonical_name for i in x) == set(["A.myVirtualFunction()"]) + assert len(b_virtual_func.overridden_by) == 0 + + c = slither.get_contract_from_name("C")[0] + c_virtual_func = c.get_function_from_full_name("myVirtualFunction()") + assert not c_virtual_func.is_virtual + assert c_virtual_func.is_override + x = c.get_functions_overridden_by(c_virtual_func) + assert len(x) == 2 + # C should not override B as they are distinct leaves in the inheritance tree + assert set(i.canonical_name for i in x) == set(["Test.myVirtualFunction()"]) + + y = slither.get_contract_from_name("Y")[0] + y_virtual_func = y.get_function_from_full_name("myVirtualFunction()") + assert y_virtual_func.is_virtual + assert not y_virtual_func.is_override + x = y_virtual_func.overridden_by + assert len(x) == 1 + assert x[0].canonical_name == "Z.myVirtualFunction()" + + z = slither.get_contract_from_name("Z")[0] + z_virtual_func = z.get_function_from_full_name("myVirtualFunction()") + assert z_virtual_func.is_virtual + assert z_virtual_func.is_override + x = z.get_functions_overridden_by(z_virtual_func) + assert len(x) == 4 + assert set(i.canonical_name for i in x) == set( + ["Y.myVirtualFunction()", "X.myVirtualFunction()"] + ) + + k = slither.get_contract_from_name("K")[0] + k_virtual_func = k.get_function_from_full_name("a()") + assert not k_virtual_func.is_virtual + assert k_virtual_func.is_override + assert len(k_virtual_func.overrides) == 3 + x = k_virtual_func.overrides + assert set(i.canonical_name for i in x) == set(["I.a()"]) + + i = slither.get_contract_from_name("I")[0] + i_virtual_func = i.get_function_from_full_name("a()") + assert i_virtual_func.is_virtual + assert not i_virtual_func.is_override + assert len(i_virtual_func.overrides) == 0 + x = i_virtual_func.overridden_by + assert len(x) == 1 + assert x[0].canonical_name == "K.a()" + + +def test_virtual_override_references_and_implementations(solc_binary_path) -> None: + solc_path = solc_binary_path("0.8.15") + file = Path(TEST_DATA_DIR, "virtual_overrides.sol").as_posix() + slither = Slither(file, solc=solc_path) + funcs = slither.offset_to_objects(file, 29) + assert len(funcs) == 1 + func = funcs.pop() + assert func.canonical_name == "Test.myVirtualFunction()" + assert {(x.start, x.end) for x in slither.offset_to_implementations(file, 29)} == { + (20, 73), + (102, 164), + (274, 328), + (357, 419), + } + + funcs = slither.offset_to_objects(file, 111) + assert len(funcs) == 1 + func = funcs.pop() + assert func.canonical_name == "A.myVirtualFunction()" + # A.myVirtualFunction() is implemented in A and also overridden in B + assert {(x.start, x.end) for x in slither.offset_to_implementations(file, 111)} == { + (102, 164), + (190, 244), + } + + # X is inherited by Z and Z.myVirtualFunction() overrides X.myVirtualFunction() + assert {(x.start, x.end) for x in slither.offset_to_references(file, 341)} == { + (514, 515), + (570, 571), + } + # The reference to X in inheritance specifier is the definition of Z + assert {(x.start, x.end) for x in slither.offset_to_definitions(file, 514)} == {(341, 343)} + # The reference to X in the function override specifier is the definition of Z + assert {(x.start, x.end) for x in slither.offset_to_definitions(file, 570)} == {(341, 343)} + + # Y is inherited by Z and Z.myVirtualFunction() overrides Y.myVirtualFunction() + assert {(x.start, x.end) for x in slither.offset_to_references(file, 432)} == { + (511, 512), + (567, 568), + } + # The reference to Y in inheritance specifier is the definition of Z + assert {(x.start, x.end) for x in slither.offset_to_definitions(file, 511)} == {(432, 434)} + # The reference to Y in the function override specifier is the definition of Z + assert {(x.start, x.end) for x in slither.offset_to_definitions(file, 567)} == {(432, 434)} + + # Name is abstract and has no implementation. It is inherited and implemented by Name2 + assert {(x.start, x.end) for x in slither.offset_to_implementations(file, 612)} == {(657, 718)} + + +def test_virtual_is_implemented(solc_binary_path): + solc_path = solc_binary_path("0.8.15") + file = Path(TEST_DATA_DIR, "virtual_overrides.sol").as_posix() + slither = Slither(file, solc=solc_path) + + test2 = slither.get_contract_from_name("Test2")[0] + f = test2.get_function_from_full_name("f()") + assert f.is_virtual + assert not f.is_implemented + + a2 = slither.get_contract_from_name("A2")[0] + f = a2.get_function_from_full_name("f()") + assert f.is_virtual + assert f.is_implemented + + # Test.2f() is not implemented, but A2 inherits from Test2 and overrides f() + assert {(x.start, x.end) for x in slither.offset_to_implementations(file, 759)} == {(809, 853)} diff --git a/tests/unit/slithir/test_implicit_returns.py b/tests/unit/slithir/test_implicit_returns.py index 2e72a010e5..551f869185 100644 --- a/tests/unit/slithir/test_implicit_returns.py +++ b/tests/unit/slithir/test_implicit_returns.py @@ -10,7 +10,8 @@ ) -def test_with_explicit_return(slither_from_solidity_source) -> None: +@pytest.mark.parametrize("legacy", [True, False]) +def test_with_explicit_return(slither_from_solidity_source, legacy) -> None: source = """ contract Contract { function foo(int x) public returns (int y) { @@ -22,31 +23,31 @@ def test_with_explicit_return(slither_from_solidity_source) -> None: } } """ - for legacy in [True, False]: - with slither_from_solidity_source(source, legacy=legacy) as slither: - c: Contract = slither.get_contract_from_name("Contract")[0] - f: Function = c.functions[0] - node_if: Node = f.nodes[1] - node_true = node_if.son_true - node_false = node_if.son_false - assert node_true.type == NodeType.RETURN - assert isinstance(node_true.irs[0], Return) - assert node_true.irs[0].values[0] == f.get_local_variable_from_name("x") - assert len(node_true.sons) == 0 - node_end_if = node_false.sons[0] - assert node_end_if.type == NodeType.ENDIF - assert node_end_if.sons[0].type == NodeType.RETURN - node_ret = node_end_if.sons[0] - assert isinstance(node_ret.irs[0], Return) - assert node_ret.irs[0].values[0] == f.get_local_variable_from_name("y") + with slither_from_solidity_source(source, legacy=legacy) as slither: + c: Contract = slither.get_contract_from_name("Contract")[0] + f: Function = c.functions[0] + node_if: Node = f.nodes[1] + node_true = node_if.son_true + node_false = node_if.son_false + assert node_true.type == NodeType.RETURN + assert isinstance(node_true.irs[0], Return) + assert node_true.irs[0].values[0] == f.get_local_variable_from_name("x") + assert len(node_true.sons) == 0 + node_end_if = node_false.sons[0] + assert node_end_if.type == NodeType.ENDIF + assert node_end_if.sons[0].type == NodeType.RETURN + node_ret = node_end_if.sons[0] + assert isinstance(node_ret.irs[0], Return) + assert node_ret.irs[0].values[0] == f.get_local_variable_from_name("y") -def test_return_multiple_with_struct(slither_from_solidity_source) -> None: +@pytest.mark.parametrize("legacy", [True, False]) +def test_return_multiple_with_struct(slither_from_solidity_source, legacy) -> None: source = """ struct St { uint256 value; } - + contract Contract { function foo(St memory x) public returns (St memory y, uint256 z) { z = x.value; @@ -54,16 +55,15 @@ def test_return_multiple_with_struct(slither_from_solidity_source) -> None: } } """ - for legacy in [True, False]: - with slither_from_solidity_source(source, legacy=legacy) as slither: - c: Contract = slither.get_contract_from_name("Contract")[0] - f: Function = c.functions[0] - assert len(f.nodes) == 4 - node = f.nodes[3] - assert node.type == NodeType.RETURN - assert isinstance(node.irs[0], Return) - assert node.irs[0].values[0] == f.get_local_variable_from_name("y") - assert node.irs[0].values[1] == f.get_local_variable_from_name("z") + with slither_from_solidity_source(source, legacy=legacy) as slither: + c: Contract = slither.get_contract_from_name("Contract")[0] + f: Function = c.functions[0] + assert len(f.nodes) == 4 + node = f.nodes[3] + assert node.type == NodeType.RETURN + assert isinstance(node.irs[0], Return) + assert node.irs[0].values[0] == f.get_local_variable_from_name("y") + assert node.irs[0].values[1] == f.get_local_variable_from_name("z") def test_nested_ifs_with_loop_legacy(slither_from_solidity_source) -> None: @@ -149,7 +149,8 @@ def test_nested_ifs_with_loop_compact(slither_from_solidity_source) -> None: @pytest.mark.xfail # Explicit returns inside assembly are currently not parsed as return nodes -def test_assembly_switch_cases(slither_from_solidity_source): +@pytest.mark.parametrize("legacy", [True, False]) +def test_assembly_switch_cases(slither_from_solidity_source, legacy): source = """ contract Contract { function foo(uint a) public returns (uint x) { @@ -164,28 +165,28 @@ def test_assembly_switch_cases(slither_from_solidity_source): } } """ - for legacy in [True, False]: - with slither_from_solidity_source(source, solc_version="0.8.0", legacy=legacy) as slither: - c: Contract = slither.get_contract_from_name("Contract")[0] - f = c.functions[0] - if legacy: - node = f.nodes[2] - assert node.type == NodeType.RETURN - assert isinstance(node.irs[0], Return) - assert node.irs[0].values[0] == f.get_local_variable_from_name("x") - else: - node_end_if = f.nodes[5] - assert node_end_if.sons[0].type == NodeType.RETURN - node_implicit = node_end_if.sons[0] - assert isinstance(node_implicit.irs[0], Return) - assert node_implicit.irs[0].values[0] == f.get_local_variable_from_name("x") - # This part will fail until issue #1927 is fixed - node_explicit = f.nodes[10] - assert node_explicit.type == NodeType.RETURN - assert len(node_explicit.sons) == 0 + with slither_from_solidity_source(source, solc_version="0.8.0", legacy=legacy) as slither: + c: Contract = slither.get_contract_from_name("Contract")[0] + f = c.functions[0] + if legacy: + node = f.nodes[2] + assert node.type == NodeType.RETURN + assert isinstance(node.irs[0], Return) + assert node.irs[0].values[0] == f.get_local_variable_from_name("x") + else: + node_end_if = f.nodes[5] + assert node_end_if.sons[0].type == NodeType.RETURN + node_implicit = node_end_if.sons[0] + assert isinstance(node_implicit.irs[0], Return) + assert node_implicit.irs[0].values[0] == f.get_local_variable_from_name("x") + # This part will fail until issue #1927 is fixed + node_explicit = f.nodes[10] + assert node_explicit.type == NodeType.RETURN + assert len(node_explicit.sons) == 0 -def test_issue_1846_ternary_in_ternary(slither_from_solidity_source): +@pytest.mark.parametrize("legacy", [True, False]) +def test_issue_1846_ternary_in_ternary(slither_from_solidity_source, legacy): source = """ contract Contract { function foo(uint x) public returns (uint y) { @@ -193,14 +194,13 @@ def test_issue_1846_ternary_in_ternary(slither_from_solidity_source): } } """ - for legacy in [True, False]: - with slither_from_solidity_source(source, legacy=legacy) as slither: - c: Contract = slither.get_contract_from_name("Contract")[0] - f = c.functions[0] - node_end_if = f.nodes[3] - assert node_end_if.type == NodeType.ENDIF - assert len(node_end_if.sons) == 1 - node_ret = node_end_if.sons[0] - assert node_ret.type == NodeType.RETURN - assert isinstance(node_ret.irs[0], Return) - assert node_ret.irs[0].values[0] == f.get_local_variable_from_name("y") + with slither_from_solidity_source(source, legacy=legacy) as slither: + c: Contract = slither.get_contract_from_name("Contract")[0] + f = c.functions[0] + node_end_if = f.nodes[3] + assert node_end_if.type == NodeType.ENDIF + assert len(node_end_if.sons) == 1 + node_ret = node_end_if.sons[0] + assert node_ret.type == NodeType.RETURN + assert isinstance(node_ret.irs[0], Return) + assert node_ret.irs[0].values[0] == f.get_local_variable_from_name("y")