From 5f0cae387ff0f7d1e35d5ffe123029ec3a74fed2 Mon Sep 17 00:00:00 2001 From: Josselin Date: Fri, 8 Feb 2019 18:16:12 -0500 Subject: [PATCH] Improve code readability of major modules: - Group methods together - Use region/endregion format Remove unused import --- slither/__main__.py | 290 +++--- .../data_dependency/data_dependency.py | 87 +- slither/core/cfg/node.py | 249 +++-- slither/core/declarations/contract.py | 405 +++++--- slither/core/declarations/function.py | 568 +++++++---- slither/core/solidity_types/function_type.py | 1 - slither/slithir/convert.py | 909 ++++++++++-------- slither/slithir/utils/ssa.py | 257 +++-- slither/solc_parsing/declarations/contract.py | 273 +++--- slither/solc_parsing/declarations/function.py | 379 +++++--- .../expressions/expression_parsing.py | 215 +++-- slither/solc_parsing/slitherSolc.py | 31 +- 12 files changed, 2174 insertions(+), 1490 deletions(-) diff --git a/slither/__main__.py b/slither/__main__.py index 57a97c39f1..34b11953ee 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -1,32 +1,38 @@ #!/usr/bin/env python3 -import inspect - import argparse import glob +import inspect import json import logging import os +import subprocess import sys import traceback -import subprocess from pkg_resources import iter_entry_points, require from slither.detectors import all_detectors -from slither.printers import all_printers from slither.detectors.abstract_detector import (AbstractDetector, DetectorClassification) +from slither.printers import all_printers from slither.printers.abstract_printer import AbstractPrinter from slither.slither import Slither -from slither.utils.colors import red -from slither.utils.command_line import output_to_markdown, output_detectors, output_printers, output_detectors_json, output_wiki -from slither.utils.colors import set_colorization_enabled +from slither.utils.colors import red, set_colorization_enabled +from slither.utils.command_line import (output_detectors, + output_detectors_json, output_printers, + output_to_markdown, output_wiki) logging.basicConfig() logger = logging.getLogger("Slither") +################################################################################### +################################################################################### +# region Process functions +################################################################################### +################################################################################### + def process(filename, args, detector_classes, printer_classes): """ The core high-level code for running Slither static analysis. @@ -104,10 +110,23 @@ def process_files(filenames, args, detector_classes, printer_classes): slither = Slither(all_contracts, args.solc, args.disable_solc_warnings, args.solc_args) return _process(slither, detector_classes, printer_classes) +# endregion +################################################################################### +################################################################################### +# region Output +################################################################################### +################################################################################### + def output_json(results, filename): with open(filename, 'w', encoding='utf8') as f: json.dump(results, f) +# endregion +################################################################################### +################################################################################### +# region Exit +################################################################################### +################################################################################### def exit(results): if not results: @@ -115,6 +134,13 @@ def exit(results): sys.exit(len(results)) +# endregion +################################################################################### +################################################################################### +# region Detectors and printers +################################################################################### +################################################################################### + def get_detectors_and_printers(): """ NOTE: This contains just a few detectors and printers that we made public. @@ -144,87 +170,69 @@ def get_detectors_and_printers(): return detectors, printers -def main(): - detectors, printers = get_detectors_and_printers() - - main_impl(all_detector_classes=detectors, all_printer_classes=printers) - - -def main_impl(all_detector_classes, all_printer_classes): - """ - :param all_detector_classes: A list of all detectors that can be included/excluded. - :param all_printer_classes: A list of all printers that can be included. - """ - args = parse_args(all_detector_classes, all_printer_classes) - - # Set colorization option - set_colorization_enabled(not args.disable_color) - - printer_classes = choose_printers(args, all_printer_classes) - detector_classes = choose_detectors(args, all_detector_classes) - - default_log = logging.INFO if not args.debug else logging.DEBUG +def choose_detectors(args, all_detector_classes): + # If detectors are specified, run only these ones - for (l_name, l_level) in [('Slither', default_log), - ('Contract', default_log), - ('Function', default_log), - ('Node', default_log), - ('Parsing', default_log), - ('Detectors', default_log), - ('FunctionSolc', default_log), - ('ExpressionParsing', default_log), - ('TypeParsing', default_log), - ('SSA_Conversion', default_log), - ('Printers', default_log)]: - l = logging.getLogger(l_name) - l.setLevel(l_level) + detectors_to_run = [] + detectors = {d.ARGUMENT: d for d in all_detector_classes} - try: - filename = args.filename + if args.detectors_to_run == 'all': + detectors_to_run = all_detector_classes + detectors_excluded = args.detectors_to_exclude.split(',') + for d in detectors: + if d in detectors_excluded: + detectors_to_run.remove(detectors[d]) + else: + for d in args.detectors_to_run.split(','): + if d in detectors: + detectors_to_run.append(detectors[d]) + else: + raise Exception('Error: {} is not a detector'.format(d)) + detectors_to_run = sorted(detectors_to_run, key=lambda x: x.IMPACT) + return detectors_to_run - globbed_filenames = glob.glob(filename, recursive=True) + if args.exclude_informational: + detectors_to_run = [d for d in detectors_to_run if + d.IMPACT != DetectorClassification.INFORMATIONAL] + if args.exclude_low: + detectors_to_run = [d for d in detectors_to_run if + d.IMPACT != DetectorClassification.LOW] + if args.exclude_medium: + detectors_to_run = [d for d in detectors_to_run if + d.IMPACT != DetectorClassification.MEDIUM] + if args.exclude_high: + detectors_to_run = [d for d in detectors_to_run if + d.IMPACT != DetectorClassification.HIGH] + if args.detectors_to_exclude: + detectors_to_run = [d for d in detectors_to_run if + d.ARGUMENT not in args.detectors_to_exclude] - if os.path.isfile(filename): - (results, number_contracts) = process(filename, args, detector_classes, printer_classes) + detectors_to_run = sorted(detectors_to_run, key=lambda x: x.IMPACT) - elif os.path.isfile(os.path.join(filename, 'truffle.js')) or os.path.isfile(os.path.join(filename, 'truffle-config.js')): - (results, number_contracts) = process_truffle(filename, args, detector_classes, printer_classes) + return detectors_to_run - elif os.path.isdir(filename) or len(globbed_filenames) > 0: - extension = "*.sol" if not args.solc_ast else "*.json" - filenames = glob.glob(os.path.join(filename, extension)) - if not filenames: - filenames = globbed_filenames - number_contracts = 0 - results = [] - if args.splitted and args.solc_ast: - (results, number_contracts) = process_files(filenames, args, detector_classes, printer_classes) - else: - for filename in filenames: - (results_tmp, number_contracts_tmp) = process(filename, args, detector_classes, printer_classes) - number_contracts += number_contracts_tmp - results += results_tmp +def choose_printers(args, all_printer_classes): + printers_to_run = [] - else: - raise Exception("Unrecognised file/dir path: '#{filename}'".format(filename=filename)) + # disable default printer + if args.printers_to_run == '': + return [] - if args.json: - output_json(results, args.json) - # Dont print the number of result for printers - if number_contracts == 0: - logger.warn(red('No contract was analyzed')) - if printer_classes: - logger.info('%s analyzed (%d contracts)', filename, number_contracts) + printers = {p.ARGUMENT: p for p in all_printer_classes} + for p in args.printers_to_run.split(','): + if p in printers: + printers_to_run.append(printers[p]) else: - logger.info('%s analyzed (%d contracts), %d result(s) found', filename, number_contracts, len(results)) - exit(results) - - except Exception: - logging.error('Error in %s' % args.filename) - logging.error(traceback.format_exc()) - sys.exit(-1) + raise Exception('Error: {} is not a printer'.format(p)) + return printers_to_run +# endregion +################################################################################### +################################################################################### +# region Command line parsing +################################################################################### +################################################################################### def parse_args(detector_classes, printer_classes): parser = argparse.ArgumentParser(description='Slither', @@ -405,63 +413,99 @@ def __call__(self, parser, args, values, option_string=None): output_wiki(detectors, values) parser.exit() -def choose_detectors(args, all_detector_classes): - # If detectors are specified, run only these ones - detectors_to_run = [] - detectors = {d.ARGUMENT: d for d in all_detector_classes} +# endregion +################################################################################### +################################################################################### +# region Main +################################################################################### +################################################################################### - if args.detectors_to_run == 'all': - detectors_to_run = all_detector_classes - detectors_excluded = args.detectors_to_exclude.split(',') - for d in detectors: - if d in detectors_excluded: - detectors_to_run.remove(detectors[d]) - else: - for d in args.detectors_to_run.split(','): - if d in detectors: - detectors_to_run.append(detectors[d]) - else: - raise Exception('Error: {} is not a detector'.format(d)) - detectors_to_run = sorted(detectors_to_run, key=lambda x: x.IMPACT) - return detectors_to_run +def main(): + detectors, printers = get_detectors_and_printers() - if args.exclude_informational: - detectors_to_run = [d for d in detectors_to_run if - d.IMPACT != DetectorClassification.INFORMATIONAL] - if args.exclude_low: - detectors_to_run = [d for d in detectors_to_run if - d.IMPACT != DetectorClassification.LOW] - if args.exclude_medium: - detectors_to_run = [d for d in detectors_to_run if - d.IMPACT != DetectorClassification.MEDIUM] - if args.exclude_high: - detectors_to_run = [d for d in detectors_to_run if - d.IMPACT != DetectorClassification.HIGH] - if args.detectors_to_exclude: - detectors_to_run = [d for d in detectors_to_run if - d.ARGUMENT not in args.detectors_to_exclude] + main_impl(all_detector_classes=detectors, all_printer_classes=printers) - detectors_to_run = sorted(detectors_to_run, key=lambda x: x.IMPACT) - return detectors_to_run +def main_impl(all_detector_classes, all_printer_classes): + """ + :param all_detector_classes: A list of all detectors that can be included/excluded. + :param all_printer_classes: A list of all printers that can be included. + """ + args = parse_args(all_detector_classes, all_printer_classes) + # Set colorization option + set_colorization_enabled(not args.disable_color) -def choose_printers(args, all_printer_classes): - printers_to_run = [] + printer_classes = choose_printers(args, all_printer_classes) + detector_classes = choose_detectors(args, all_detector_classes) + + default_log = logging.INFO if not args.debug else logging.DEBUG + + for (l_name, l_level) in [('Slither', default_log), + ('Contract', default_log), + ('Function', default_log), + ('Node', default_log), + ('Parsing', default_log), + ('Detectors', default_log), + ('FunctionSolc', default_log), + ('ExpressionParsing', default_log), + ('TypeParsing', default_log), + ('SSA_Conversion', default_log), + ('Printers', default_log)]: + l = logging.getLogger(l_name) + l.setLevel(l_level) + + try: + filename = args.filename + + globbed_filenames = glob.glob(filename, recursive=True) + + if os.path.isfile(filename): + (results, number_contracts) = process(filename, args, detector_classes, printer_classes) + + elif os.path.isfile(os.path.join(filename, 'truffle.js')) or os.path.isfile(os.path.join(filename, 'truffle-config.js')): + (results, number_contracts) = process_truffle(filename, args, detector_classes, printer_classes) + + elif os.path.isdir(filename) or len(globbed_filenames) > 0: + extension = "*.sol" if not args.solc_ast else "*.json" + filenames = glob.glob(os.path.join(filename, extension)) + if not filenames: + filenames = globbed_filenames + number_contracts = 0 + results = [] + if args.splitted and args.solc_ast: + (results, number_contracts) = process_files(filenames, args, detector_classes, printer_classes) + else: + for filename in filenames: + (results_tmp, number_contracts_tmp) = process(filename, args, detector_classes, printer_classes) + number_contracts += number_contracts_tmp + results += results_tmp - # disable default printer - if args.printers_to_run == '': - return [] - printers = {p.ARGUMENT: p for p in all_printer_classes} - for p in args.printers_to_run.split(','): - if p in printers: - printers_to_run.append(printers[p]) else: - raise Exception('Error: {} is not a printer'.format(p)) - return printers_to_run + raise Exception("Unrecognised file/dir path: '#{filename}'".format(filename=filename)) + + if args.json: + output_json(results, args.json) + # Dont print the number of result for printers + if number_contracts == 0: + logger.warn(red('No contract was analyzed')) + if printer_classes: + logger.info('%s analyzed (%d contracts)', filename, number_contracts) + else: + logger.info('%s analyzed (%d contracts), %d result(s) found', filename, number_contracts, len(results)) + exit(results) + + except Exception: + logging.error('Error in %s' % args.filename) + logging.error(traceback.format_exc()) + sys.exit(-1) + if __name__ == '__main__': main() + + +# endregion diff --git a/slither/analyses/data_dependency/data_dependency.py b/slither/analyses/data_dependency/data_dependency.py index bf70c8cf0e..1412c24757 100644 --- a/slither/analyses/data_dependency/data_dependency.py +++ b/slither/analyses/data_dependency/data_dependency.py @@ -2,40 +2,18 @@ Compute the data depenency between all the SSA variables """ from slither.core.declarations import Contract, Function -from slither.slithir.operations import Index, Member, OperationWithLValue -from slither.slithir.variables import ReferenceVariable, Constant -from slither.slithir.variables import (Constant, LocalIRVariable, StateIRVariable, - ReferenceVariable, TemporaryVariable, - TupleVariable) - - from slither.core.declarations.solidity_variables import \ SolidityVariableComposed +from slither.slithir.operations import Index, OperationWithLValue +from slither.slithir.variables import (Constant, LocalIRVariable, + ReferenceVariable, StateIRVariable, + TemporaryVariable) -KEY_SSA = "DATA_DEPENDENCY_SSA" -KEY_NON_SSA = "DATA_DEPENDENCY" - -# Only for unprotected functions -KEY_SSA_UNPROTECTED = "DATA_DEPENDENCY_SSA_UNPROTECTED" -KEY_NON_SSA_UNPROTECTED = "DATA_DEPENDENCY_UNPROTECTED" - -KEY_INPUT = "DATA_DEPENDENCY_INPUT" -KEY_INPUT_SSA = "DATA_DEPENDENCY_INPUT_SSA" - -def pprint_dependency(context): - print('#### SSA ####') - context = context.context - for k, values in context[KEY_SSA].items(): - print('{} ({}):'.format(k, id(k))) - for v in values: - print('\t- {}'.format(v)) - - print('#### NON SSA ####') - for k, values in context[KEY_NON_SSA].items(): - print('{} ({}):'.format(k, hex(id(k)))) - for v in values: - print('\t- {} ({})'.format(v, hex(id(v)))) - +################################################################################### +################################################################################### +# region User APIs +################################################################################### +################################################################################### def is_dependent(variable, source, context, only_unprotected=False): ''' @@ -119,6 +97,53 @@ def is_tainted_ssa(variable, context, only_unprotected=False): taints |= GENERIC_TAINT return variable in taints or any(is_dependent_ssa(variable, t, context, only_unprotected) for t in taints) + +# endregion +################################################################################### +################################################################################### +# region Module constants +################################################################################### +################################################################################### + +KEY_SSA = "DATA_DEPENDENCY_SSA" +KEY_NON_SSA = "DATA_DEPENDENCY" + +# Only for unprotected functions +KEY_SSA_UNPROTECTED = "DATA_DEPENDENCY_SSA_UNPROTECTED" +KEY_NON_SSA_UNPROTECTED = "DATA_DEPENDENCY_UNPROTECTED" + +KEY_INPUT = "DATA_DEPENDENCY_INPUT" +KEY_INPUT_SSA = "DATA_DEPENDENCY_INPUT_SSA" + + +# endregion +################################################################################### +################################################################################### +# region Debug +################################################################################### +################################################################################### + +def pprint_dependency(context): + print('#### SSA ####') + context = context.context + for k, values in context[KEY_SSA].items(): + print('{} ({}):'.format(k, id(k))) + for v in values: + print('\t- {}'.format(v)) + + print('#### NON SSA ####') + for k, values in context[KEY_NON_SSA].items(): + print('{} ({}):'.format(k, hex(id(k)))) + for v in values: + print('\t- {} ({})'.format(v, hex(id(v)))) + +# endregion +################################################################################### +################################################################################### +# region Analyses +################################################################################### +################################################################################### + def compute_dependency(slither): slither.context[KEY_INPUT] = set() diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index 106d220620..7b15937e77 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -5,25 +5,29 @@ from slither.core.children.child_function import ChildFunction from slither.core.declarations import Contract -from slither.core.declarations.solidity_variables import (SolidityFunction, - SolidityVariable) +from slither.core.declarations.solidity_variables import SolidityVariable from slither.core.source_mapping.source_mapping import SourceMapping -from slither.core.variables.state_variable import StateVariable from slither.core.variables.local_variable import LocalVariable +from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable from slither.slithir.convert import convert_expression from slither.slithir.operations import (Balance, HighLevelCall, Index, InternalCall, Length, LibraryCall, LowLevelCall, Member, - OperationWithLValue, SolidityCall, Phi, PhiCallback) -from slither.slithir.variables import (Constant, ReferenceVariable, - TemporaryVariable, TupleVariable, StateIRVariable, LocalIRVariable) -from slither.visitors.expression.expression_printer import ExpressionPrinter -from slither.visitors.expression.read_var import ReadVar -from slither.visitors.expression.write_var import WriteVar + OperationWithLValue, Phi, PhiCallback, + SolidityCall) +from slither.slithir.variables import (Constant, LocalIRVariable, + ReferenceVariable, StateIRVariable, + TemporaryVariable, TupleVariable) logger = logging.getLogger("Node") +################################################################################### +################################################################################### +# region NodeType +################################################################################### +################################################################################### + class NodeType: ENTRYPOINT = 0x0 # no expression @@ -89,10 +93,22 @@ def str(t): return 'END_LOOP' return 'Unknown type {}'.format(hex(t)) + +# endregion +################################################################################### +################################################################################### +# region Utils +################################################################################### +################################################################################### + def link_nodes(n1, n2): n1.add_son(n2) n2.add_father(n1) + + +# endregion + class Node(SourceMapping, ChildFunction): """ Node class @@ -159,68 +175,11 @@ def __init__(self, node_type, node_id): self._expression_vars_read = [] self._expression_calls = [] - - @property - def dominators(self): - ''' - Returns: - set(Node) - ''' - return self._dominators - - @property - def immediate_dominator(self): - ''' - Returns: - Node or None - ''' - return self._immediate_dominator - - @property - def dominance_frontier(self): - ''' - Returns: - set(Node) - ''' - return self._dominance_frontier - - @property - def dominator_successors(self): - return self._dom_successors - - @dominators.setter - def dominators(self, dom): - self._dominators = dom - - @immediate_dominator.setter - def immediate_dominator(self, idom): - self._immediate_dominator = idom - - @dominance_frontier.setter - def dominance_frontier(self, dom): - self._dominance_frontier = dom - - @property - def phi_origins_local_variables(self): - return self._phi_origins_local_variables - - @property - def phi_origins_state_variables(self): - return self._phi_origins_state_variables - - def add_phi_origin_local_variable(self, variable, node): - if variable.name not in self._phi_origins_local_variables: - self._phi_origins_local_variables[variable.name] = (variable, set()) - (v, nodes) = self._phi_origins_local_variables[variable.name] - assert v == variable - nodes.add(node) - - def add_phi_origin_state_variable(self, variable, node): - if variable.canonical_name not in self._phi_origins_state_variables: - self._phi_origins_state_variables[variable.canonical_name] = (variable, set()) - (v, nodes) = self._phi_origins_state_variables[variable.canonical_name] - assert v == variable - nodes.add(node) + ################################################################################### + ################################################################################### + # region General's properties + ################################################################################### + ################################################################################### @property def slither(self): @@ -242,6 +201,13 @@ def type(self): def type(self, t): self._node_type = t + # endregion + ################################################################################### + ################################################################################### + # region Variables + ################################################################################### + ################################################################################### + @property def variables_read(self): """ @@ -291,8 +257,6 @@ def ssa_local_variables_read(self): """ return list(self._ssa_local_vars_read) - - @property def variables_read_as_expression(self): return self._expression_vars_read @@ -347,6 +311,13 @@ def ssa_local_variables_written(self): def variables_written_as_expression(self): return self._expression_vars_written + # endregion + ################################################################################### + ################################################################################### + # region Calls + ################################################################################### + ################################################################################### + @property def internal_calls(self): """ @@ -392,6 +363,13 @@ def external_calls_as_expressions(self): def calls_as_expression(self): return list(self._expression_calls) + # endregion + ################################################################################### + ################################################################################### + # region Expressions + ################################################################################### + ################################################################################### + @property def expression(self): """ @@ -418,9 +396,12 @@ def variable_declaration(self): """ return self._variable_declaration - def __str__(self): - txt = NodeType.str(self._node_type) + ' '+ str(self.expression) - return txt + # endregion + ################################################################################### + ################################################################################### + # region Summary information + ################################################################################### + ################################################################################### def contains_require_or_assert(self): """ @@ -449,6 +430,14 @@ def is_conditional(self, include_loop=True): """ return self.contains_if(include_loop) or self.contains_require_or_assert() + + # endregion + ################################################################################### + ################################################################################### + # region Graph + ################################################################################### + ################################################################################### + def add_father(self, father): """ Add a father node @@ -474,7 +463,6 @@ def fathers(self): """ return list(self._fathers) - def remove_father(self, father): """ Remove the father node. Do nothing if the node is not a father @@ -516,6 +504,13 @@ def sons(self): """ return list(self._sons) + # endregion + ################################################################################### + ################################################################################### + # region SlithIR + ################################################################################### + ################################################################################### + @property def irs(self): """ Returns the slithIR representation @@ -559,6 +554,92 @@ def _is_non_slithir_var(var): def _is_valid_slithir_var(var): return isinstance(var, (ReferenceVariable, TemporaryVariable, TupleVariable)) + # endregion + ################################################################################### + ################################################################################### + # region Dominators + ################################################################################### + ################################################################################### + + @property + def dominators(self): + ''' + Returns: + set(Node) + ''' + return self._dominators + + @property + def immediate_dominator(self): + ''' + Returns: + Node or None + ''' + return self._immediate_dominator + + @property + def dominance_frontier(self): + ''' + Returns: + set(Node) + ''' + return self._dominance_frontier + + @property + def dominator_successors(self): + return self._dom_successors + + @dominators.setter + def dominators(self, dom): + self._dominators = dom + + @immediate_dominator.setter + def immediate_dominator(self, idom): + self._immediate_dominator = idom + + @dominance_frontier.setter + def dominance_frontier(self, dom): + self._dominance_frontier = dom + + # endregion + ################################################################################### + ################################################################################### + # region Phi operation + ################################################################################### + ################################################################################### + + @property + def phi_origins_local_variables(self): + return self._phi_origins_local_variables + + @property + def phi_origins_state_variables(self): + return self._phi_origins_state_variables + + def add_phi_origin_local_variable(self, variable, node): + if variable.name not in self._phi_origins_local_variables: + self._phi_origins_local_variables[variable.name] = (variable, set()) + (v, nodes) = self._phi_origins_local_variables[variable.name] + assert v == variable + nodes.add(node) + + def add_phi_origin_state_variable(self, variable, node): + if variable.canonical_name not in self._phi_origins_state_variables: + self._phi_origins_state_variables[variable.canonical_name] = (variable, set()) + (v, nodes) = self._phi_origins_state_variables[variable.canonical_name] + assert v == variable + nodes.add(node) + + + + + # endregion + ################################################################################### + ################################################################################### + # region Analyses + ################################################################################### + ################################################################################### + def _find_read_write_call(self): for ir in self.irs: @@ -686,3 +767,17 @@ def update_read_write_using_ssa(self): self._vars_written += [v for v in vars_written if v not in self._vars_written] self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] + + + # endregion + ################################################################################### + ################################################################################### + # region Built in definitions + ################################################################################### + ################################################################################### + + def __str__(self): + txt = NodeType.str(self._node_type) + ' '+ str(self.expression) + return txt + + # endregion diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index e95fb4e0e7..787172843f 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -38,15 +38,11 @@ def __init__(self): self._initial_state_variables = [] # ssa - def __eq__(self, other): - if isinstance(other, str): - return other == self.name - return NotImplemented - - def __neq__(self, other): - if isinstance(other, str): - return other != self.name - return NotImplemented + ################################################################################### + ################################################################################### + # region General's properties + ################################################################################### + ################################################################################### @property def name(self): @@ -59,59 +55,114 @@ def id(self): return self._id @property - def inheritance(self): - ''' - list(Contract): Inheritance list. Order: the first elem is the first father to be executed - ''' - return list(self._inheritance) + def contract_kind(self): + return self._kind + + # endregion + ################################################################################### + ################################################################################### + # region Structures + ################################################################################### + ################################################################################### @property - def immediate_inheritance(self): + def structures(self): ''' - list(Contract): List of contracts immediately inherited from (fathers). Order: order of declaration. + list(Structure): List of the structures ''' - return list(self._immediate_inheritance) + return list(self._structures.values()) + + def structures_as_dict(self): + return self._structures + + # endregion + ################################################################################### + ################################################################################### + # region Enums + ################################################################################### + ################################################################################### @property - def inheritance_reverse(self): + def enums(self): + return list(self._enums.values()) + + def enums_as_dict(self): + return self._enums + + # endregion + ################################################################################### + ################################################################################### + # region Events + ################################################################################### + ################################################################################### + + @property + def events(self): ''' - list(Contract): Inheritance list. Order: the last elem is the first father to be executed + list(Event): List of the events ''' - return reversed(self._inheritance) + return list(self._events.values()) - def setInheritance(self, inheritance, immediate_inheritance, called_base_constructor_contracts): - self._inheritance = inheritance - self._immediate_inheritance = immediate_inheritance - self._explicit_base_constructor_calls = called_base_constructor_contracts + def events_as_dict(self): + return self._events + + # endregion + ################################################################################### + ################################################################################### + # region Using for + ################################################################################### + ################################################################################### @property - def derived_contracts(self): + def using_for(self): + return self._using_for + + def reverse_using_for(self, name): ''' - list(Contract): Return the list of contracts derived from self + Returns: + (list) ''' - candidates = self.slither.contracts - return [c for c in candidates if self in c.inheritance] + return self._using_for[name] + + # endregion + ################################################################################### + ################################################################################### + # region Variables + ################################################################################### + ################################################################################### @property - def structures(self): + def variables(self): ''' - list(Structure): List of the structures + list(StateVariable): List of the state variables. Alias to self.state_variables ''' - return list(self._structures.values()) + return list(self.state_variables) - def structures_as_dict(self): - return self._structures + def variables_as_dict(self): + return self._variables @property - def enums(self): - return list(self._enums.values()) - - def enums_as_dict(self): - return self._enums + def state_variables(self): + ''' + list(StateVariable): List of the state variables. + ''' + return list(self._variables.values()) + @property + def slithir_variables(self): + ''' + List all of the slithir variables (non SSA) + ''' + slithir_variables = [f.slithir_variables for f in self.functions + self.modifiers] + slithir_variables = [item for sublist in slithir_variables for item in sublist] + return list(set(slithir_variables)) - def modifiers_as_dict(self): - return self._modifiers + # endregion + ################################################################################### + ################################################################################### + # region Constructors + ################################################################################### + ################################################################################### @property def constructor(self): @@ -141,6 +192,25 @@ def constructors(self): ''' return [func for func in self.functions if func.is_constructor] + @property + def explicit_base_constructor_calls(self): + """ + list(Function): List of the base constructors called explicitly by this contract definition. + + Base constructors called by any constructor definition will not be included. + Base constructors implicitly called by the contract definition (without + parenthesis) will not be included. + + On "contract B is A(){..}" it returns the constructor of A + """ + return [c.constructor for c in self._explicit_base_constructor_calls if c.constructor] + + # endregion + ################################################################################### + ################################################################################### + # region Functions and Modifiers + ################################################################################### + ################################################################################### @property def functions(self): @@ -149,6 +219,9 @@ def functions(self): ''' return list(self._functions.values()) + def functions_as_dict(self): + return self._functions + @property def functions_inherited(self): ''' @@ -170,19 +243,6 @@ def functions_entry_points(self): ''' return [f for f in self.functions if f.visibility in ['public', 'external']] - @property - def explicit_base_constructor_calls(self): - """ - list(Function): List of the base constructors called explicitly by this contract definition. - - Base constructors called by any constructor definition will not be included. - Base constructors implicitly called by the contract definition (without - parenthesis) will not be included. - - On "contract B is A(){..}" it returns the constructor of A - """ - return [c.constructor for c in self._explicit_base_constructor_calls if c.constructor] - @property def modifiers(self): ''' @@ -190,6 +250,9 @@ def modifiers(self): ''' return list(self._modifiers.values()) + def modifiers_as_dict(self): + return self._modifiers + @property def modifiers_inherited(self): ''' @@ -225,109 +288,53 @@ def functions_and_modifiers_not_inherited(self): ''' return self.functions_not_inherited + self.modifiers_not_inherited - def get_functions_overridden_by(self, function): - ''' - Return the list of functions overriden by the function - Args: - (core.Function) - Returns: - list(core.Function) - - ''' - candidates = [c.functions_not_inherited for c in self.inheritance] - candidates = [candidate for sublist in candidates for candidate in sublist] - return [f for f in candidates if f.full_name == function.full_name] - - @property - def all_functions_called(self): - ''' - list(Function): List of functions reachable from the contract (include super) - ''' - all_calls = [f.all_internal_calls() for f in self.functions + self.modifiers] + [self.functions + self.modifiers] - all_calls = [item for sublist in all_calls for item in sublist] + self.functions - all_calls = list(set(all_calls)) - - all_constructors = [c.constructor for c in self.inheritance] - all_constructors = list(set([c for c in all_constructors if c])) - - all_calls = set(all_calls+all_constructors) - - return [c for c in all_calls if isinstance(c, Function)] - - @property - def all_state_variables_written(self): - ''' - list(StateVariable): List all of the state variables written - ''' - all_state_variables_written = [f.all_state_variables_written() for f in self.functions + self.modifiers] - all_state_variables_written = [item for sublist in all_state_variables_written for item in sublist] - return list(set(all_state_variables_written)) + # endregion + ################################################################################### + ################################################################################### + # region Inheritance + ################################################################################### + ################################################################################### @property - def all_state_variables_read(self): - ''' - list(StateVariable): List all of the state variables read - ''' - all_state_variables_read = [f.all_state_variables_read() for f in self.functions + self.modifiers] - all_state_variables_read = [item for sublist in all_state_variables_read for item in sublist] - return list(set(all_state_variables_read)) - - @property - def slithir_variables(self): - ''' - List all of the slithir variables (non SSA) - ''' - slithir_variables = [f.slithir_variables for f in self.functions + self.modifiers] - slithir_variables = [item for sublist in slithir_variables for item in sublist] - return list(set(slithir_variables)) - - def functions_as_dict(self): - return self._functions - - @property - def events(self): + def inheritance(self): ''' - list(Event): List of the events + list(Contract): Inheritance list. Order: the first elem is the first father to be executed ''' - return list(self._events.values()) - - def events_as_dict(self): - return self._events + return list(self._inheritance) @property - def state_variables(self): + def immediate_inheritance(self): ''' - list(StateVariable): List of the state variables. + list(Contract): List of contracts immediately inherited from (fathers). Order: order of declaration. ''' - return list(self._variables.values()) + return list(self._immediate_inheritance) @property - def variables(self): + def inheritance_reverse(self): ''' - list(StateVariable): List of the state variables. Alias to self.state_variables + list(Contract): Inheritance list. Order: the last elem is the first father to be executed ''' - return list(self.state_variables) + return reversed(self._inheritance) - def variables_as_dict(self): - return self._variables + def setInheritance(self, inheritance, immediate_inheritance, called_base_constructor_contracts): + self._inheritance = inheritance + self._immediate_inheritance = immediate_inheritance + self._explicit_base_constructor_calls = called_base_constructor_contracts @property - def using_for(self): - return self._using_for - - def reverse_using_for(self, name): + def derived_contracts(self): ''' - Returns: - (list) + list(Contract): Return the list of contracts derived from self ''' - return self._using_for[name] - - @property - def contract_kind(self): - return self._kind + candidates = self.slither.contracts + return [c for c in candidates if self in c.inheritance] - def __str__(self): - return self.name + # endregion + ################################################################################### + ################################################################################### + # region Getters from/to object + ################################################################################### + ################################################################################### def get_functions_reading_from_variable(self, variable): ''' @@ -341,14 +348,6 @@ def get_functions_writing_to_variable(self, variable): ''' return [f for f in self.functions if f.is_writing(variable)] - def is_signature_only(self): - """ Detect if the contract has only abstract functions - - Returns: - bool: true if the function are abstract functions - """ - return all((not f.is_implemented) for f in self.functions) - def get_source_var_declaration(self, var): """ Return the source mapping where the variable is declared @@ -449,6 +448,85 @@ def get_enum_from_canonical_name(self, enum_name): """ return next((e for e in self.enums if e.canonical_name == enum_name), None) + def get_functions_overridden_by(self, function): + ''' + Return the list of functions overriden by the function + Args: + (core.Function) + Returns: + list(core.Function) + + ''' + candidates = [c.functions_not_inherited for c in self.inheritance] + candidates = [candidate for sublist in candidates for candidate in sublist] + return [f for f in candidates if f.full_name == function.full_name] + + # endregion + ################################################################################### + ################################################################################### + # region Recursive getters + ################################################################################### + ################################################################################### + + @property + def all_functions_called(self): + ''' + list(Function): List of functions reachable from the contract (include super) + ''' + all_calls = [f.all_internal_calls() for f in self.functions + self.modifiers] + [self.functions + self.modifiers] + all_calls = [item for sublist in all_calls for item in sublist] + self.functions + all_calls = list(set(all_calls)) + + all_constructors = [c.constructor for c in self.inheritance] + all_constructors = list(set([c for c in all_constructors if c])) + + all_calls = set(all_calls+all_constructors) + + return [c for c in all_calls if isinstance(c, Function)] + + @property + def all_state_variables_written(self): + ''' + list(StateVariable): List all of the state variables written + ''' + all_state_variables_written = [f.all_state_variables_written() for f in self.functions + self.modifiers] + all_state_variables_written = [item for sublist in all_state_variables_written for item in sublist] + return list(set(all_state_variables_written)) + + @property + def all_state_variables_read(self): + ''' + list(StateVariable): List all of the state variables read + ''' + all_state_variables_read = [f.all_state_variables_read() for f in self.functions + self.modifiers] + all_state_variables_read = [item for sublist in all_state_variables_read for item in sublist] + return list(set(all_state_variables_read)) + + # endregion + ################################################################################### + ################################################################################### + # region Summary information + ################################################################################### + ################################################################################### + + def get_summary(self): + """ Return the function summary + + Returns: + (str, list, list, list, list): (name, inheritance, variables, fuction summaries, modifier summaries) + """ + func_summaries = [f.get_summary() for f in self.functions] + modif_summaries = [f.get_summary() for f in self.modifiers] + return (self.name, [str(x) for x in self.inheritance], [str(x) for x in self.variables], func_summaries, modif_summaries) + + def is_signature_only(self): + """ Detect if the contract has only abstract functions + + Returns: + bool: true if the function are abstract functions + """ + return all((not f.is_implemented) for f in self.functions) + def is_erc20(self): """ Check if the contract is an erc20 token @@ -462,16 +540,35 @@ def is_erc20(self): 'transferFrom(address,address,uint256)' in full_names and\ 'approve(address,uint256)' in full_names + # endregion + ################################################################################### + ################################################################################### + # region Function analyses + ################################################################################### + ################################################################################### + def update_read_write_using_ssa(self): for function in self.functions + self.modifiers: function.update_read_write_using_ssa() - def get_summary(self): - """ Return the function summary + # endregion + ################################################################################### + ################################################################################### + # region Built in definitions + ################################################################################### + ################################################################################### - Returns: - (str, list, list, list, list): (name, inheritance, variables, fuction summaries, modifier summaries) - """ - func_summaries = [f.get_summary() for f in self.functions] - modif_summaries = [f.get_summary() for f in self.modifiers] - return (self.name, [str(x) for x in self.inheritance], [str(x) for x in self.variables], func_summaries, modif_summaries) + def __eq__(self, other): + if isinstance(other, str): + return other == self.name + return NotImplemented + + def __neq__(self, other): + if isinstance(other, str): + return other != self.name + return NotImplemented + + def __str__(self): + return self.name + + # endregion diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 6ac4f7415a..d2f3b6dbd6 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -82,28 +82,11 @@ def __init__(self): self._reachable_from_nodes = set() self._reachable_from_functions = set() - @property - def contains_assembly(self): - return self._contains_assembly - - @property - def return_type(self): - """ - Return the list of return type - If no return, return None - """ - returns = self.returns - if returns: - return [r.type for r in returns] - return None - - @property - def type(self): - """ - Return the list of return type - If no return, return None - """ - return self.return_type + ################################################################################### + ################################################################################### + # region General properties + ################################################################################### + ################################################################################### @property def name(self): @@ -118,25 +101,35 @@ def name(self): return self._name @property - def nodes(self): + def full_name(self): """ - list(Node): List of the nodes + str: func_name(type1,type2) + Return the function signature without the return values """ - return list(self._nodes) + name, parameters, _ = self.signature + return name+'('+','.join(parameters)+')' @property - def entry_point(self): + def is_constructor(self): """ - Node: Entry point of the function + bool: True if the function is the constructor """ - return self._entry_point + return self._is_constructor or self._name == self.contract.name @property - def visibility(self): - """ - str: Function visibility - """ - return self._visibility + def contains_assembly(self): + return self._contains_assembly + + @property + def slither(self): + return self.contract.slither + + # endregion + ################################################################################### + ################################################################################### + # region Payable + ################################################################################### + ################################################################################### @property def payable(self): @@ -145,12 +138,19 @@ def payable(self): """ return self._payable + # endregion + ################################################################################### + ################################################################################### + # region Visibility + ################################################################################### + ################################################################################### + @property - def is_constructor(self): + def visibility(self): """ - bool: True if the function is the constructor + str: Function visibility """ - return self._is_constructor or self._name == self.contract.name + return self._visibility @property def view(self): @@ -166,6 +166,13 @@ def pure(self): """ return self._pure + # endregion + ################################################################################### + ################################################################################### + # region Function's body + ################################################################################### + ################################################################################### + @property def is_implemented(self): """ @@ -180,6 +187,36 @@ def is_empty(self): """ return self._is_empty + + + # endregion + ################################################################################### + ################################################################################### + # region Nodes + ################################################################################### + ################################################################################### + + @property + def nodes(self): + """ + list(Node): List of the nodes + """ + return list(self._nodes) + + @property + def entry_point(self): + """ + Node: Entry point of the function + """ + return self._entry_point + + # endregion + ################################################################################### + ################################################################################### + # region Parameters + ################################################################################### + ################################################################################### + @property def parameters(self): """ @@ -197,6 +234,32 @@ def parameters_ssa(self): def add_parameter_ssa(self, var): self._parameters_ssa.append(var) + # endregion + ################################################################################### + ################################################################################### + # region Return values + ################################################################################### + ################################################################################### + + @property + def return_type(self): + """ + Return the list of return type + If no return, return None + """ + returns = self.returns + if returns: + return [r.type for r in returns] + return None + + @property + def type(self): + """ + Return the list of return type + If no return, return None + """ + return self.return_type + @property def returns(self): """ @@ -214,6 +277,13 @@ def returns_ssa(self): def add_return_ssa(self, var): self._returns_ssa.append(var) + # endregion + ################################################################################### + ################################################################################### + # region Modifiers + ################################################################################### + ################################################################################### + @property def modifiers(self): """ @@ -232,8 +302,13 @@ def explicit_base_constructor_calls(self): # This is a list of contracts internally, so we convert it to a list of constructor functions. return [c.constructor_not_inherited for c in self._explicit_base_constructor_calls if c.constructor_not_inherited] - def __str__(self): - return self._name + + # endregion + ################################################################################### + ################################################################################### + # region Variables + ################################################################################### + ################################################################################### @property def variables(self): @@ -311,6 +386,13 @@ def slithir_variables(self): return list(self._slithir_variables) + # endregion + ################################################################################### + ################################################################################### + # region Calls + ################################################################################### + ################################################################################### + @property def internal_calls(self): """ @@ -353,6 +435,13 @@ def external_calls_as_expressions(self): """ return list(self._external_calls_as_expressions) + # endregion + ################################################################################### + ################################################################################### + # region Expressions + ################################################################################### + ################################################################################### + @property def calls_as_expressions(self): return self._expression_calls @@ -368,6 +457,13 @@ def expressions(self): self._expressions = expressions return self._expressions + # endregion + ################################################################################### + ################################################################################### + # region SlithIR + ################################################################################### + ################################################################################### + @property def slithir_operations(self): """ @@ -379,6 +475,13 @@ def slithir_operations(self): self._slithir_operations = operations return self._slithir_operations + # endregion + ################################################################################### + ################################################################################### + # region Signature + ################################################################################### + ################################################################################### + @property def signature(self): """ @@ -396,14 +499,12 @@ def signature_str(self): name, parameters, returnVars = self.signature return name+'('+','.join(parameters)+') returns('+','.join(returnVars)+')' - @property - def full_name(self): - """ - str: func_name(type1,type2) - Return the function signature without the return values - """ - name, parameters, _ = self.signature - return name+'('+','.join(parameters)+')' + # endregion + ################################################################################### + ################################################################################### + # region Functions + ################################################################################### + ################################################################################### @property def functions_shadowed(self): @@ -418,9 +519,12 @@ def functions_shadowed(self): return [f for f in candidates if f.full_name == self.full_name] - @property - def slither(self): - return self.contract.slither + # endregion + ################################################################################### + ################################################################################### + # region Reachable + ################################################################################### + ################################################################################### @property def reachable_from_nodes(self): @@ -438,111 +542,12 @@ def add_reachable_from_node(self, n, ir): self._reachable_from_nodes.add(ReacheableNode(n, ir)) self._reachable_from_functions.add(n.function) - def _filter_state_variables_written(self, expressions): - ret = [] - for expression in expressions: - if isinstance(expression, Identifier): - ret.append(expression) - if isinstance(expression, UnaryOperation): - ret.append(expression.expression) - if isinstance(expression, MemberAccess): - ret.append(expression.expression) - if isinstance(expression, IndexAccess): - ret.append(expression.expression_left) - return ret - - def _analyze_read_write(self): - """ Compute variables read/written/... - - """ - write_var = [x.variables_written_as_expression for x in self.nodes] - write_var = [x for x in write_var if x] - write_var = [item for sublist in write_var for item in sublist] - write_var = list(set(write_var)) - # Remove dupplicate if they share the same string representation - write_var = [next(obj) for i, obj in groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x))] - self._expression_vars_written = write_var - - write_var = [x.variables_written for x in self.nodes] - write_var = [x for x in write_var if x] - write_var = [item for sublist in write_var for item in sublist] - write_var = list(set(write_var)) - # Remove dupplicate if they share the same string representation - write_var = [next(obj) for i, obj in\ - groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x))] - self._vars_written = write_var - - read_var = [x.variables_read_as_expression for x in self.nodes] - read_var = [x for x in read_var if x] - read_var = [item for sublist in read_var for item in sublist] - # Remove dupplicate if they share the same string representation - read_var = [next(obj) for i, obj in\ - groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x))] - self._expression_vars_read = read_var - - read_var = [x.variables_read for x in self.nodes] - read_var = [x for x in read_var if x] - read_var = [item for sublist in read_var for item in sublist] - # Remove dupplicate if they share the same string representation - read_var = [next(obj) for i, obj in\ - groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x))] - self._vars_read = read_var - - self._state_vars_written = [x for x in self.variables_written if\ - isinstance(x, StateVariable)] - self._state_vars_read = [x for x in self.variables_read if\ - isinstance(x, (StateVariable))] - self._solidity_vars_read = [x for x in self.variables_read if\ - isinstance(x, (SolidityVariable))] - - self._vars_read_or_written = self._vars_written + self._vars_read - - slithir_variables = [x.slithir_variables for x in self.nodes] - slithir_variables = [x for x in slithir_variables if x] - self._slithir_variables = [item for sublist in slithir_variables for item in sublist] - - def _analyze_calls(self): - calls = [x.calls_as_expression for x in self.nodes] - calls = [x for x in calls if x] - calls = [item for sublist in calls for item in sublist] - # Remove dupplicate if they share the same string representation - # TODO: check if groupby is still necessary here - calls = [next(obj) for i, obj in\ - groupby(sorted(calls, key=lambda x: str(x)), lambda x: str(x))] - self._expression_calls = calls - - internal_calls = [x.internal_calls for x in self.nodes] - internal_calls = [x for x in internal_calls if x] - internal_calls = [item for sublist in internal_calls for item in sublist] - internal_calls = [next(obj) for i, obj in - groupby(sorted(internal_calls, key=lambda x: str(x)), lambda x: str(x))] - self._internal_calls = internal_calls - - self._solidity_calls = [c for c in internal_calls if isinstance(c, SolidityFunction)] - - low_level_calls = [x.low_level_calls for x in self.nodes] - low_level_calls = [x for x in low_level_calls if x] - low_level_calls = [item for sublist in low_level_calls for item in sublist] - low_level_calls = [next(obj) for i, obj in - groupby(sorted(low_level_calls, key=lambda x: str(x)), lambda x: str(x))] - - self._low_level_calls = low_level_calls - - high_level_calls = [x.high_level_calls for x in self.nodes] - high_level_calls = [x for x in high_level_calls if x] - high_level_calls = [item for sublist in high_level_calls for item in sublist] - high_level_calls = [next(obj) for i, obj in - groupby(sorted(high_level_calls, key=lambda x: str(x)), lambda x: str(x))] - - self._high_level_calls = high_level_calls - - external_calls_as_expressions = [x.external_calls_as_expressions for x in self.nodes] - external_calls_as_expressions = [x for x in external_calls_as_expressions if x] - external_calls_as_expressions = [item for sublist in external_calls_as_expressions for item in sublist] - external_calls_as_expressions = [next(obj) for i, obj in - groupby(sorted(external_calls_as_expressions, key=lambda x: str(x)), lambda x: str(x))] - self._external_calls_as_expressions = external_calls_as_expressions - + # endregion + ################################################################################### + ################################################################################### + # region Recursive getters + ################################################################################### + ################################################################################### def _explore_functions(self, f_new_values): values = f_new_values(self) @@ -698,49 +703,12 @@ def all_solidity_variables_used_as_args(self): lambda x: self._explore_func_nodes(x, self._solidity_variable_in_internal_calls)) return self._all_solidity_variables_used_as_args - def is_reading(self, variable): - """ - Check if the function reads the variable - Args: - variable (Variable): - Returns: - bool: True if the variable is read - """ - return variable in self.variables_read - - def is_reading_in_conditional_node(self, variable): - """ - Check if the function reads the variable in a IF node - Args: - variable (Variable): - Returns: - bool: True if the variable is read - """ - variables_read = [n.variables_read for n in self.nodes if n.contains_if()] - variables_read = [item for sublist in variables_read for item in sublist] - return variable in variables_read - - def is_reading_in_require_or_assert(self, variable): - """ - Check if the function reads the variable in an require or assert - Args: - variable (Variable): - Returns: - bool: True if the variable is read - """ - variables_read = [n.variables_read for n in self.nodes if n.contains_require_or_assert()] - variables_read = [item for sublist in variables_read for item in sublist] - return variable in variables_read - - def is_writing(self, variable): - """ - Check if the function writes the variable - Args: - variable (Variable): - Returns: - bool: True if the variable is written - """ - return variable in self.variables_written + # endregion + ################################################################################### + ################################################################################### + # region Visitor + ################################################################################### + ################################################################################### def apply_visitor(self, Visitor): """ @@ -754,6 +722,29 @@ def apply_visitor(self, Visitor): v = [Visitor(e).result() for e in expressions] return [item for sublist in v for item in sublist] + # endregion + ################################################################################### + ################################################################################### + # region Getters from/to object + ################################################################################### + ################################################################################### + + def get_local_variable_from_name(self, variable_name): + """ + Return a local variable from a name + Args: + varible_name (str): name of the variable + Returns: + LocalVariable + """ + return next((v for v in self.variables if v.name == variable_name), None) + + # endregion + ################################################################################### + ################################################################################### + # region Export + ################################################################################### + ################################################################################### def cfg_to_dot(self, filename): """ @@ -812,6 +803,57 @@ def description(node): f.write("}\n") + # endregion + ################################################################################### + ################################################################################### + # region Summary information + ################################################################################### + ################################################################################### + + def is_reading(self, variable): + """ + Check if the function reads the variable + Args: + variable (Variable): + Returns: + bool: True if the variable is read + """ + return variable in self.variables_read + + def is_reading_in_conditional_node(self, variable): + """ + Check if the function reads the variable in a IF node + Args: + variable (Variable): + Returns: + bool: True if the variable is read + """ + variables_read = [n.variables_read for n in self.nodes if n.contains_if()] + variables_read = [item for sublist in variables_read for item in sublist] + return variable in variables_read + + def is_reading_in_require_or_assert(self, variable): + """ + Check if the function reads the variable in an require or assert + Args: + variable (Variable): + Returns: + bool: True if the variable is read + """ + variables_read = [n.variables_read for n in self.nodes if n.contains_require_or_assert()] + variables_read = [item for sublist in variables_read for item in sublist] + return variable in variables_read + + def is_writing(self, variable): + """ + Check if the function writes the variable + Args: + variable (Variable): + Returns: + bool: True if the variable is written + """ + return variable in self.variables_written + def get_summary(self): """ Return the function summary @@ -844,12 +886,128 @@ def is_protected(self): args_vars = self.all_solidity_variables_used_as_args() return SolidityVariableComposed('msg.sender') in conditional_vars + args_vars - def get_local_variable_from_name(self, variable_name): - """ - Return a local variable from a name - Args: - varible_name (str): name of the variable - Returns: - LocalVariable + # endregion + ################################################################################### + ################################################################################### + # region Analyses + ################################################################################### + ################################################################################### + + def _filter_state_variables_written(self, expressions): + ret = [] + for expression in expressions: + if isinstance(expression, Identifier): + ret.append(expression) + if isinstance(expression, UnaryOperation): + ret.append(expression.expression) + if isinstance(expression, MemberAccess): + ret.append(expression.expression) + if isinstance(expression, IndexAccess): + ret.append(expression.expression_left) + return ret + + def _analyze_read_write(self): + """ Compute variables read/written/... + """ - return next((v for v in self.variables if v.name == variable_name), None) + write_var = [x.variables_written_as_expression for x in self.nodes] + write_var = [x for x in write_var if x] + write_var = [item for sublist in write_var for item in sublist] + write_var = list(set(write_var)) + # Remove dupplicate if they share the same string representation + write_var = [next(obj) for i, obj in groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x))] + self._expression_vars_written = write_var + + write_var = [x.variables_written for x in self.nodes] + write_var = [x for x in write_var if x] + write_var = [item for sublist in write_var for item in sublist] + write_var = list(set(write_var)) + # Remove dupplicate if they share the same string representation + write_var = [next(obj) for i, obj in\ + groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x))] + self._vars_written = write_var + + read_var = [x.variables_read_as_expression for x in self.nodes] + read_var = [x for x in read_var if x] + read_var = [item for sublist in read_var for item in sublist] + # Remove dupplicate if they share the same string representation + read_var = [next(obj) for i, obj in\ + groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x))] + self._expression_vars_read = read_var + + read_var = [x.variables_read for x in self.nodes] + read_var = [x for x in read_var if x] + read_var = [item for sublist in read_var for item in sublist] + # Remove dupplicate if they share the same string representation + read_var = [next(obj) for i, obj in\ + groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x))] + self._vars_read = read_var + + self._state_vars_written = [x for x in self.variables_written if\ + isinstance(x, StateVariable)] + self._state_vars_read = [x for x in self.variables_read if\ + isinstance(x, (StateVariable))] + self._solidity_vars_read = [x for x in self.variables_read if\ + isinstance(x, (SolidityVariable))] + + self._vars_read_or_written = self._vars_written + self._vars_read + + slithir_variables = [x.slithir_variables for x in self.nodes] + slithir_variables = [x for x in slithir_variables if x] + self._slithir_variables = [item for sublist in slithir_variables for item in sublist] + + def _analyze_calls(self): + calls = [x.calls_as_expression for x in self.nodes] + calls = [x for x in calls if x] + calls = [item for sublist in calls for item in sublist] + # Remove dupplicate if they share the same string representation + # TODO: check if groupby is still necessary here + calls = [next(obj) for i, obj in\ + groupby(sorted(calls, key=lambda x: str(x)), lambda x: str(x))] + self._expression_calls = calls + + internal_calls = [x.internal_calls for x in self.nodes] + internal_calls = [x for x in internal_calls if x] + internal_calls = [item for sublist in internal_calls for item in sublist] + internal_calls = [next(obj) for i, obj in + groupby(sorted(internal_calls, key=lambda x: str(x)), lambda x: str(x))] + self._internal_calls = internal_calls + + self._solidity_calls = [c for c in internal_calls if isinstance(c, SolidityFunction)] + + low_level_calls = [x.low_level_calls for x in self.nodes] + low_level_calls = [x for x in low_level_calls if x] + low_level_calls = [item for sublist in low_level_calls for item in sublist] + low_level_calls = [next(obj) for i, obj in + groupby(sorted(low_level_calls, key=lambda x: str(x)), lambda x: str(x))] + + self._low_level_calls = low_level_calls + + high_level_calls = [x.high_level_calls for x in self.nodes] + high_level_calls = [x for x in high_level_calls if x] + high_level_calls = [item for sublist in high_level_calls for item in sublist] + high_level_calls = [next(obj) for i, obj in + groupby(sorted(high_level_calls, key=lambda x: str(x)), lambda x: str(x))] + + self._high_level_calls = high_level_calls + + external_calls_as_expressions = [x.external_calls_as_expressions for x in self.nodes] + external_calls_as_expressions = [x for x in external_calls_as_expressions if x] + external_calls_as_expressions = [item for sublist in external_calls_as_expressions for item in sublist] + external_calls_as_expressions = [next(obj) for i, obj in + groupby(sorted(external_calls_as_expressions, key=lambda x: str(x)), lambda x: str(x))] + self._external_calls_as_expressions = external_calls_as_expressions + + + + # endregion + ################################################################################### + ################################################################################### + # region Built in definitions + ################################################################################### + ################################################################################### + + def __str__(self): + return self._name + + # endregion diff --git a/slither/core/solidity_types/function_type.py b/slither/core/solidity_types/function_type.py index 3afb18de48..b5cdb0f639 100644 --- a/slither/core/solidity_types/function_type.py +++ b/slither/core/solidity_types/function_type.py @@ -1,6 +1,5 @@ from slither.core.solidity_types.type import Type from slither.core.variables.function_type_variable import FunctionTypeVariable -from slither.core.expressions.expression import Expression class FunctionType(Type): diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index f5562cd98b..e09cd2737f 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -1,20 +1,24 @@ import logging -from slither.core.declarations import (Contract, Enum, Event, SolidityFunction, - Structure, SolidityVariableComposed, Function, SolidityVariable) -from slither.core.expressions import Identifier, Literal, TupleExpression -from slither.core.solidity_types import ElementaryType, UserDefinedType, MappingType, ArrayType, FunctionType +from slither.core.declarations import (Contract, Enum, Event, Function, + SolidityFunction, SolidityVariable, + SolidityVariableComposed, Structure) +from slither.core.expressions import Identifier, Literal +from slither.core.solidity_types import (ArrayType, ElementaryType, + FunctionType, MappingType, + UserDefinedType) from slither.core.variables.variable import Variable -from slither.slithir.operations import (Assignment, Binary, BinaryType, Call, - Condition, Delete, EventCall, - HighLevelCall, Index, InitArray, - InternalCall, InternalDynamicCall, LibraryCall, - LowLevelCall, Member, NewArray, - NewContract, NewElementaryType, - NewStructure, OperationWithLValue, - Push, Return, Send, SolidityCall, - Transfer, TypeConversion, Unary, - Unpack, Length, Balance) +from slither.slithir.operations import (Assignment, Balance, Binary, + BinaryType, Call, Condition, Delete, + EventCall, HighLevelCall, Index, + InitArray, InternalCall, + InternalDynamicCall, Length, + LibraryCall, LowLevelCall, Member, + NewArray, NewContract, + NewElementaryType, NewStructure, + OperationWithLValue, Push, Return, + Send, SolidityCall, Transfer, + TypeConversion, Unary, Unpack) from slither.slithir.tmp_operations.argument import Argument, ArgumentType from slither.slithir.tmp_operations.tmp_call import TmpCall from slither.slithir.tmp_operations.tmp_new_array import TmpNewArray @@ -23,11 +27,47 @@ TmpNewElementaryType from slither.slithir.tmp_operations.tmp_new_structure import TmpNewStructure from slither.slithir.variables import (Constant, ReferenceVariable, - TemporaryVariable, TupleVariable) + TemporaryVariable) from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR logger = logging.getLogger('ConvertToIR') +def convert_expression(expression, node): + # handle standlone expression + # such as return true; + from slither.core.cfg.node import NodeType + + if isinstance(expression, Literal) and node.type in [NodeType.IF, NodeType.IFLOOP]: + result = [Condition(Constant(expression.value))] + return result + if isinstance(expression, Identifier) and node.type in [NodeType.IF, NodeType.IFLOOP]: + result = [Condition(expression.value)] + return result + + + visitor = ExpressionToSlithIR(expression, node) + result = visitor.result() + + result = apply_ir_heuristics(result, node) + + if result: + if node.type in [NodeType.IF, NodeType.IFLOOP]: + assert isinstance(result[-1], (OperationWithLValue)) + result.append(Condition(result[-1].lvalue)) + elif node.type == NodeType.RETURN: + # May return None + if isinstance(result[-1], (OperationWithLValue)): + result.append(Return(result[-1].lvalue)) + + return result + + +################################################################################### +################################################################################### +# region Helpers +################################################################################### +################################################################################### + def is_value(ins): if isinstance(ins, TmpCall): if isinstance(ins.ori, Member): @@ -42,7 +82,65 @@ def is_gas(ins): return True return False +def get_sig(ir): + ''' + Return a list of potential signature + It is a list, as Constant variables can be converted to int256 + Args: + ir (slithIR.operation) + Returns: + list(str) + ''' + sig = '{}({})' + name = ir.function_name + + # list of list of arguments + argss = [[]] + for arg in ir.arguments: + if isinstance(arg, (list,)): + type_arg = '{}[{}]'.format(get_type(arg[0].type), len(arg)) + elif isinstance(arg, Function): + type_arg = arg.signature_str + else: + type_arg = get_type(arg.type) + if isinstance(arg, Constant) and arg.type == ElementaryType('uint256'): + # If it is a constant + # We dupplicate the existing list + # And we add uint256 and int256 cases + # There is no potential collision, as the compiler + # Prevent it with a + # "not unique after argument-dependent loopkup" issue + argss_new = [list(args) for args in argss] + for args in argss: + args.append(str(ElementaryType('uint256'))) + for args in argss_new: + args.append(str(ElementaryType('int256'))) + argss = argss + argss_new + else: + for args in argss: + args.append(type_arg) + return [sig.format(name, ','.join(args)) for args in argss] + +def is_temporary(ins): + return isinstance(ins, (Argument, + TmpNewElementaryType, + TmpNewContract, + TmpNewArray, + TmpNewStructure)) + + + +# endregion +################################################################################### +################################################################################### +# region Calls modification +################################################################################### +################################################################################### + def integrate_value_gas(result): + ''' + Integrate value and gas temporary arguments to call instruction + ''' was_changed = True calls = [] @@ -110,7 +208,17 @@ def integrate_value_gas(result): return result -def propage_type_and_convert_call(result, node): +# endregion +################################################################################### +################################################################################### +# region Calls modification and Type propagation +################################################################################### +################################################################################### + +def propagate_type_and_convert_call(result, node): + ''' + Propagate the types variables and convert tmp call to real call operation + ''' calls_value = {} calls_gas = {} @@ -179,133 +287,343 @@ def propage_type_and_convert_call(result, node): idx = idx +1 return result -def convert_to_low_level(ir): - """ - Convert to a transfer/send/or low level call - The funciton assume to receive a correct IR - The checks must be done by the caller - - Additionally convert abi... to solidityfunction - """ - if ir.function_name == 'transfer': - assert len(ir.arguments) == 1 - ir = Transfer(ir.destination, ir.arguments[0]) - return ir - elif ir.function_name == 'send': - assert len(ir.arguments) == 1 - ir = Send(ir.destination, ir.arguments[0], ir.lvalue) - ir.lvalue.set_type(ElementaryType('bool')) - return ir - elif ir.destination.name == 'abi' and ir.function_name in ['encode', - 'encodePacked', - 'encodeWithSelector', - 'encodeWithSignature', - 'decode']: - - call = SolidityFunction('abi.{}()'.format(ir.function_name)) - new_ir = SolidityCall(call, ir.nbr_arguments, ir.lvalue, ir.type_call) - new_ir.arguments = ir.arguments - if isinstance(call.return_type, list) and len(call.return_type) == 1: - new_ir.lvalue.set_type(call.return_type[0]) - else: - new_ir.lvalue.set_type(call.return_type) - return new_ir - elif ir.function_name in ['call', - 'delegatecall', - 'callcode', - 'staticcall']: - new_ir = LowLevelCall(ir.destination, - ir.function_name, - ir.nbr_arguments, - ir.lvalue, - ir.type_call) - new_ir.call_gas = ir.call_gas - new_ir.call_value = ir.call_value - new_ir.arguments = ir.arguments - new_ir.lvalue.set_type(ElementaryType('bool')) - return new_ir - logger.error('Incorrect conversion to low level {}'.format(ir)) - exit(-1) - -def convert_to_push(ir, node): - """ - Convert a call to a PUSH operaiton - - The funciton assume to receive a correct IR - The checks must be done by the caller - - May necessitate to create an intermediate operation (InitArray) - Necessitate to return the lenght (see push documentation) - As a result, the function return may return a list - """ - - - lvalue = ir.lvalue - if isinstance(ir.arguments[0], list): - ret = [] - - val = TemporaryVariable(node) - operation = InitArray(ir.arguments[0], val) - ret.append(operation) - - ir = Push(ir.destination, val) - - length = Literal(len(operation.init_values)) - t = operation.init_values[0].type - ir.lvalue.set_type(ArrayType(t, length)) - - ret.append(ir) - - if lvalue: - length = Length(ir.array, lvalue) - length.lvalue.points_to = ir.lvalue - ret.append(length) - - return ret - - ir = Push(ir.destination, ir.arguments[0]) - - if lvalue: - ret = [] - ret.append(ir) +def propagate_types(ir, node): + # propagate the type + using_for = node.function.contract.using_for + if isinstance(ir, OperationWithLValue): + # Force assignment in case of missing previous correct type + if not ir.lvalue.type: + if isinstance(ir, Assignment): + ir.lvalue.set_type(ir.rvalue.type) + elif isinstance(ir, Binary): + if BinaryType.return_bool(ir.type): + ir.lvalue.set_type(ElementaryType('bool')) + else: + ir.lvalue.set_type(ir.variable_left.type) + elif isinstance(ir, Delete): + # nothing to propagate + pass + elif isinstance(ir, LibraryCall): + return convert_type_library_call(ir, ir.destination) + elif isinstance(ir, HighLevelCall): + t = ir.destination.type - length = Length(ir.array, lvalue) - length.lvalue.points_to = ir.lvalue - ret.append(length) - return ret + # Temporary operation (they are removed later) + if t is None: + return - return ir + # convert library + if t in using_for or '*' in using_for: + new_ir = convert_to_library(ir, node, using_for) + if new_ir: + return new_ir -def look_for_library(contract, ir, node, using_for, t): - for destination in using_for[t]: - lib_contract = contract.slither.get_contract_from_name(str(destination)) - if lib_contract: - lib_call = LibraryCall(lib_contract, - ir.function_name, - ir.nbr_arguments, - ir.lvalue, - ir.type_call) - lib_call.call_gas = ir.call_gas - lib_call.arguments = [ir.destination] + ir.arguments - new_ir = convert_type_library_call(lib_call, lib_contract) - if new_ir: - new_ir.set_node(ir.node) - return new_ir - return None + if isinstance(t, UserDefinedType): + # UserdefinedType + t_type = t.type + if isinstance(t_type, Contract): + contract = node.slither.get_contract_from_name(t_type.name) + return convert_type_of_high_level_call(ir, contract) -def convert_to_library(ir, node, using_for): - contract = node.function.contract - t = ir.destination.type + # Convert HighLevelCall to LowLevelCall + if isinstance(t, ElementaryType) and t.name == 'address': + if ir.destination.name == 'this': + return convert_type_of_high_level_call(ir, node.function.contract) + return convert_to_low_level(ir) - if t in using_for: - new_ir = look_for_library(contract, ir, node, using_for, t) - if new_ir: - return new_ir + # Convert push operations + # May need to insert a new operation + # Which leads to return a list of operation + if isinstance(t, ArrayType): + if ir.function_name == 'push' and len(ir.arguments) == 1: + return convert_to_push(ir, node) - if '*' in using_for: - new_ir = look_for_library(contract, ir, node, using_for, '*') - if new_ir: - return new_ir + elif isinstance(ir, Index): + if isinstance(ir.variable_left.type, MappingType): + ir.lvalue.set_type(ir.variable_left.type.type_to) + elif isinstance(ir.variable_left.type, ArrayType): + ir.lvalue.set_type(ir.variable_left.type.type) + + elif isinstance(ir, InitArray): + length = len(ir.init_values) + t = ir.init_values[0].type + ir.lvalue.set_type(ArrayType(t, length)) + elif isinstance(ir, InternalCall): + # if its not a tuple, return a singleton + return_type = ir.function.return_type + if return_type: + if len(return_type) == 1: + ir.lvalue.set_type(return_type[0]) + elif len(return_type)>1: + ir.lvalue.set_type(return_type) + else: + ir.lvalue = None + elif isinstance(ir, InternalDynamicCall): + # if its not a tuple, return a singleton + return_type = ir.function_type.return_type + if return_type: + if len(return_type) == 1: + ir.lvalue.set_type(return_type[0]) + else: + ir.lvalue.set_type(return_type) + else: + ir.lvalue = None + elif isinstance(ir, LowLevelCall): + # Call are not yet converted + # This should not happen + assert False + elif isinstance(ir, Member): + # TODO we should convert the reference to a temporary if the member is a length or a balance + if ir.variable_right == 'length' and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, (ElementaryType, ArrayType)): + length = Length(ir.variable_left, ir.lvalue) + length.lvalue.points_to = ir.variable_left + return length + if ir.variable_right == 'balance'and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, ElementaryType): + return Balance(ir.variable_left, ir.lvalue) + left = ir.variable_left + if isinstance(left, (Variable, SolidityVariable)): + t = ir.variable_left.type + elif isinstance(left, (Contract, Enum, Structure)): + t = UserDefinedType(left) + # can be None due to temporary operation + if t: + if isinstance(t, UserDefinedType): + # UserdefinedType + type_t = t.type + if isinstance(type_t, Enum): + ir.lvalue.set_type(t) + elif isinstance(type_t, Structure): + elems = type_t.elems + for elem in elems: + if elem == ir.variable_right: + ir.lvalue.set_type(elems[elem].type) + else: + assert isinstance(type_t, Contract) + elif isinstance(ir, NewArray): + ir.lvalue.set_type(ir.array_type) + elif isinstance(ir, NewContract): + contract = node.slither.get_contract_from_name(ir.contract_name) + ir.lvalue.set_type(UserDefinedType(contract)) + elif isinstance(ir, NewElementaryType): + ir.lvalue.set_type(ir.type) + elif isinstance(ir, NewStructure): + ir.lvalue.set_type(UserDefinedType(ir.structure)) + elif isinstance(ir, Push): + # No change required + pass + elif isinstance(ir, Send): + ir.lvalue.set_type(ElementaryType('bool')) + elif isinstance(ir, SolidityCall): + return_type = ir.function.return_type + if len(return_type) == 1: + ir.lvalue.set_type(return_type[0]) + elif len(return_type)>1: + ir.lvalue.set_type(return_type) + elif isinstance(ir, TypeConversion): + ir.lvalue.set_type(ir.type) + elif isinstance(ir, Unary): + ir.lvalue.set_type(ir.rvalue.type) + elif isinstance(ir, Unpack): + types = ir.tuple.type.type + idx = ir.index + t = types[idx] + ir.lvalue.set_type(t) + elif isinstance(ir, (Argument, TmpCall, TmpNewArray, TmpNewContract, TmpNewStructure, TmpNewElementaryType)): + # temporary operation; they will be removed + pass + else: + logger.error('Not handling {} during type propgation'.format(type(ir))) + exit(-1) + +def extract_tmp_call(ins): + assert isinstance(ins, TmpCall) + + if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType): + call = InternalDynamicCall(ins.lvalue, ins.called, ins.called.type) + call.call_id = ins.call_id + return call + if isinstance(ins.ori, Member): + if isinstance(ins.ori.variable_left, Contract): + st = ins.ori.variable_left.get_structure_from_name(ins.ori.variable_right) + if st: + op = NewStructure(st, ins.lvalue) + op.call_id = ins.call_id + return op + libcall = LibraryCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call) + libcall.call_id = ins.call_id + return libcall + msgcall = HighLevelCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call) + msgcall.call_id = ins.call_id + return msgcall + + if isinstance(ins.ori, TmpCall): + r = extract_tmp_call(ins.ori) + return r + if isinstance(ins.called, SolidityVariableComposed): + if str(ins.called) == 'block.blockhash': + ins.called = SolidityFunction('blockhash(uint256)') + elif str(ins.called) == 'this.balance': + return SolidityCall(SolidityFunction('this.balance()'), ins.nbr_arguments, ins.lvalue, ins.type_call) + + if isinstance(ins.called, SolidityFunction): + return SolidityCall(ins.called, ins.nbr_arguments, ins.lvalue, ins.type_call) + + if isinstance(ins.ori, TmpNewElementaryType): + return NewElementaryType(ins.ori.type, ins.lvalue) + + if isinstance(ins.ori, TmpNewContract): + op = NewContract(Constant(ins.ori.contract_name), ins.lvalue) + op.call_id = ins.call_id + return op + + if isinstance(ins.ori, TmpNewArray): + return NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue) + + if isinstance(ins.called, Structure): + op = NewStructure(ins.called, ins.lvalue) + op.call_id = ins.call_id + return op + + if isinstance(ins.called, Event): + return EventCall(ins.called.name) + + + raise Exception('Not extracted {} {}'.format(type(ins.called), ins)) + +# endregion +################################################################################### +################################################################################### +# region Conversion operations +################################################################################### +################################################################################### + +def convert_to_low_level(ir): + """ + Convert to a transfer/send/or low level call + The funciton assume to receive a correct IR + The checks must be done by the caller + + Additionally convert abi... to solidityfunction + """ + if ir.function_name == 'transfer': + assert len(ir.arguments) == 1 + ir = Transfer(ir.destination, ir.arguments[0]) + return ir + elif ir.function_name == 'send': + assert len(ir.arguments) == 1 + ir = Send(ir.destination, ir.arguments[0], ir.lvalue) + ir.lvalue.set_type(ElementaryType('bool')) + return ir + elif ir.destination.name == 'abi' and ir.function_name in ['encode', + 'encodePacked', + 'encodeWithSelector', + 'encodeWithSignature', + 'decode']: + + call = SolidityFunction('abi.{}()'.format(ir.function_name)) + new_ir = SolidityCall(call, ir.nbr_arguments, ir.lvalue, ir.type_call) + new_ir.arguments = ir.arguments + if isinstance(call.return_type, list) and len(call.return_type) == 1: + new_ir.lvalue.set_type(call.return_type[0]) + else: + new_ir.lvalue.set_type(call.return_type) + return new_ir + elif ir.function_name in ['call', + 'delegatecall', + 'callcode', + 'staticcall']: + new_ir = LowLevelCall(ir.destination, + ir.function_name, + ir.nbr_arguments, + ir.lvalue, + ir.type_call) + new_ir.call_gas = ir.call_gas + new_ir.call_value = ir.call_value + new_ir.arguments = ir.arguments + new_ir.lvalue.set_type(ElementaryType('bool')) + return new_ir + logger.error('Incorrect conversion to low level {}'.format(ir)) + exit(-1) + +def convert_to_push(ir, node): + """ + Convert a call to a PUSH operaiton + + The funciton assume to receive a correct IR + The checks must be done by the caller + + May necessitate to create an intermediate operation (InitArray) + Necessitate to return the lenght (see push documentation) + As a result, the function return may return a list + """ + + + lvalue = ir.lvalue + if isinstance(ir.arguments[0], list): + ret = [] + + val = TemporaryVariable(node) + operation = InitArray(ir.arguments[0], val) + ret.append(operation) + + ir = Push(ir.destination, val) + + length = Literal(len(operation.init_values)) + t = operation.init_values[0].type + ir.lvalue.set_type(ArrayType(t, length)) + + ret.append(ir) + + if lvalue: + length = Length(ir.array, lvalue) + length.lvalue.points_to = ir.lvalue + ret.append(length) + + return ret + + ir = Push(ir.destination, ir.arguments[0]) + + if lvalue: + ret = [] + ret.append(ir) + + length = Length(ir.array, lvalue) + length.lvalue.points_to = ir.lvalue + ret.append(length) + return ret + + return ir + +def look_for_library(contract, ir, node, using_for, t): + for destination in using_for[t]: + lib_contract = contract.slither.get_contract_from_name(str(destination)) + if lib_contract: + lib_call = LibraryCall(lib_contract, + ir.function_name, + ir.nbr_arguments, + ir.lvalue, + ir.type_call) + lib_call.call_gas = ir.call_gas + lib_call.arguments = [ir.destination] + ir.arguments + new_ir = convert_type_library_call(lib_call, lib_contract) + if new_ir: + new_ir.set_node(ir.node) + return new_ir + return None + +def convert_to_library(ir, node, using_for): + contract = node.function.contract + t = ir.destination.type + + if t in using_for: + new_ir = look_for_library(contract, ir, node, using_for, t) + if new_ir: + return new_ir + + if '*' in using_for: + new_ir = look_for_library(contract, ir, node, using_for, '*') + if new_ir: + return new_ir return None @@ -316,47 +634,8 @@ def get_type(t): """ if isinstance(t, UserDefinedType): if isinstance(t.type, Contract): - return 'address' - return str(t) - -def get_sig(ir): - ''' - Return a list of potential signature - It is a list, as Constant variables can be converted to int256 - Args: - ir (slithIR.operation) - Returns: - list(str) - ''' - sig = '{}({})' - name = ir.function_name - - # list of list of arguments - argss = [[]] - for arg in ir.arguments: - if isinstance(arg, (list,)): - type_arg = '{}[{}]'.format(get_type(arg[0].type), len(arg)) - elif isinstance(arg, Function): - type_arg = arg.signature_str - else: - type_arg = get_type(arg.type) - if isinstance(arg, Constant) and arg.type == ElementaryType('uint256'): - # If it is a constant - # We dupplicate the existing list - # And we add uint256 and int256 cases - # There is no potential collision, as the compiler - # Prevent it with a - # "not unique after argument-dependent loopkup" issue - argss_new = [list(args) for args in argss] - for args in argss: - args.append(str(ElementaryType('uint256'))) - for args in argss_new: - args.append(str(ElementaryType('int256'))) - argss = argss + argss_new - else: - for args in argss: - args.append(type_arg) - return [sig.format(name, ','.join(args)) for args in argss] + return 'address' + return str(t) def convert_type_library_call(ir, lib_contract): sigs = get_sig(ir) @@ -448,167 +727,12 @@ def convert_type_of_high_level_call(ir, contract): return None -def propagate_types(ir, node): - # propagate the type - using_for = node.function.contract.using_for - if isinstance(ir, OperationWithLValue): - # Force assignment in case of missing previous correct type - if not ir.lvalue.type: - if isinstance(ir, Assignment): - ir.lvalue.set_type(ir.rvalue.type) - elif isinstance(ir, Binary): - if BinaryType.return_bool(ir.type): - ir.lvalue.set_type(ElementaryType('bool')) - else: - ir.lvalue.set_type(ir.variable_left.type) - elif isinstance(ir, Delete): - # nothing to propagate - pass - elif isinstance(ir, LibraryCall): - return convert_type_library_call(ir, ir.destination) - elif isinstance(ir, HighLevelCall): - t = ir.destination.type - - # Temporary operation (they are removed later) - if t is None: - return - - # convert library - if t in using_for or '*' in using_for: - new_ir = convert_to_library(ir, node, using_for) - if new_ir: - return new_ir - - if isinstance(t, UserDefinedType): - # UserdefinedType - t_type = t.type - if isinstance(t_type, Contract): - contract = node.slither.get_contract_from_name(t_type.name) - return convert_type_of_high_level_call(ir, contract) - - # Convert HighLevelCall to LowLevelCall - if isinstance(t, ElementaryType) and t.name == 'address': - if ir.destination.name == 'this': - return convert_type_of_high_level_call(ir, node.function.contract) - return convert_to_low_level(ir) - - # Convert push operations - # May need to insert a new operation - # Which leads to return a list of operation - if isinstance(t, ArrayType): - if ir.function_name == 'push' and len(ir.arguments) == 1: - return convert_to_push(ir, node) - - elif isinstance(ir, Index): - if isinstance(ir.variable_left.type, MappingType): - ir.lvalue.set_type(ir.variable_left.type.type_to) - elif isinstance(ir.variable_left.type, ArrayType): - ir.lvalue.set_type(ir.variable_left.type.type) - - elif isinstance(ir, InitArray): - length = len(ir.init_values) - t = ir.init_values[0].type - ir.lvalue.set_type(ArrayType(t, length)) - elif isinstance(ir, InternalCall): - # if its not a tuple, return a singleton - return_type = ir.function.return_type - if return_type: - if len(return_type) == 1: - ir.lvalue.set_type(return_type[0]) - elif len(return_type)>1: - ir.lvalue.set_type(return_type) - else: - ir.lvalue = None - elif isinstance(ir, InternalDynamicCall): - # if its not a tuple, return a singleton - return_type = ir.function_type.return_type - if return_type: - if len(return_type) == 1: - ir.lvalue.set_type(return_type[0]) - else: - ir.lvalue.set_type(return_type) - else: - ir.lvalue = None - elif isinstance(ir, LowLevelCall): - # Call are not yet converted - # This should not happen - assert False - elif isinstance(ir, Member): - # TODO we should convert the reference to a temporary if the member is a length or a balance - if ir.variable_right == 'length' and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, (ElementaryType, ArrayType)): - length = Length(ir.variable_left, ir.lvalue) - length.lvalue.points_to = ir.variable_left - return length - if ir.variable_right == 'balance'and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, ElementaryType): - return Balance(ir.variable_left, ir.lvalue) - left = ir.variable_left - if isinstance(left, (Variable, SolidityVariable)): - t = ir.variable_left.type - elif isinstance(left, (Contract, Enum, Structure)): - t = UserDefinedType(left) - # can be None due to temporary operation - if t: - if isinstance(t, UserDefinedType): - # UserdefinedType - type_t = t.type - if isinstance(type_t, Enum): - ir.lvalue.set_type(t) - elif isinstance(type_t, Structure): - elems = type_t.elems - for elem in elems: - if elem == ir.variable_right: - ir.lvalue.set_type(elems[elem].type) - else: - assert isinstance(type_t, Contract) - elif isinstance(ir, NewArray): - ir.lvalue.set_type(ir.array_type) - elif isinstance(ir, NewContract): - contract = node.slither.get_contract_from_name(ir.contract_name) - ir.lvalue.set_type(UserDefinedType(contract)) - elif isinstance(ir, NewElementaryType): - ir.lvalue.set_type(ir.type) - elif isinstance(ir, NewStructure): - ir.lvalue.set_type(UserDefinedType(ir.structure)) - elif isinstance(ir, Push): - # No change required - pass - elif isinstance(ir, Send): - ir.lvalue.set_type(ElementaryType('bool')) - elif isinstance(ir, SolidityCall): - return_type = ir.function.return_type - if len(return_type) == 1: - ir.lvalue.set_type(return_type[0]) - elif len(return_type)>1: - ir.lvalue.set_type(return_type) - elif isinstance(ir, TypeConversion): - ir.lvalue.set_type(ir.type) - elif isinstance(ir, Unary): - ir.lvalue.set_type(ir.rvalue.type) - elif isinstance(ir, Unpack): - types = ir.tuple.type.type - idx = ir.index - t = types[idx] - ir.lvalue.set_type(t) - elif isinstance(ir, (Argument, TmpCall, TmpNewArray, TmpNewContract, TmpNewStructure, TmpNewElementaryType)): - # temporary operation; they will be removed - pass - else: - logger.error('Not handling {} during type propgation'.format(type(ir))) - exit(-1) - -def apply_ir_heuristics(irs, node): - """ - Apply a set of heuristic to improve slithIR - """ - - irs = integrate_value_gas(irs) - - irs = propage_type_and_convert_call(irs, node) - irs = remove_unused(irs) - find_references_origin(irs) - - - return irs +# endregion +################################################################################### +################################################################################### +# region Points to operation +################################################################################### +################################################################################### def find_references_origin(irs): """ @@ -619,13 +743,12 @@ def find_references_origin(irs): if isinstance(ir, (Index, Member)): ir.lvalue.points_to = ir.variable_left -def is_temporary(ins): - return isinstance(ins, (Argument, - TmpNewElementaryType, - TmpNewContract, - TmpNewArray, - TmpNewStructure)) - +# endregion +################################################################################### +################################################################################### +# region Operation filtering +################################################################################### +################################################################################### def remove_temporary(result): result = [ins for ins in result if not isinstance(ins, (Argument, @@ -668,88 +791,24 @@ def remove_unused(result): result = [i for i in result if not i in to_remove] return result +# endregion +################################################################################### +################################################################################### +# region Heuristics selection +################################################################################### +################################################################################### +def apply_ir_heuristics(irs, node): + """ + Apply a set of heuristic to improve slithIR + """ -def extract_tmp_call(ins): - assert isinstance(ins, TmpCall) - - if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType): - call = InternalDynamicCall(ins.lvalue, ins.called, ins.called.type) - call.call_id = ins.call_id - return call - if isinstance(ins.ori, Member): - if isinstance(ins.ori.variable_left, Contract): - st = ins.ori.variable_left.get_structure_from_name(ins.ori.variable_right) - if st: - op = NewStructure(st, ins.lvalue) - op.call_id = ins.call_id - return op - libcall = LibraryCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call) - libcall.call_id = ins.call_id - return libcall - msgcall = HighLevelCall(ins.ori.variable_left, ins.ori.variable_right, ins.nbr_arguments, ins.lvalue, ins.type_call) - msgcall.call_id = ins.call_id - return msgcall - - if isinstance(ins.ori, TmpCall): - r = extract_tmp_call(ins.ori) - return r - if isinstance(ins.called, SolidityVariableComposed): - if str(ins.called) == 'block.blockhash': - ins.called = SolidityFunction('blockhash(uint256)') - elif str(ins.called) == 'this.balance': - return SolidityCall(SolidityFunction('this.balance()'), ins.nbr_arguments, ins.lvalue, ins.type_call) - - if isinstance(ins.called, SolidityFunction): - return SolidityCall(ins.called, ins.nbr_arguments, ins.lvalue, ins.type_call) - - if isinstance(ins.ori, TmpNewElementaryType): - return NewElementaryType(ins.ori.type, ins.lvalue) - - if isinstance(ins.ori, TmpNewContract): - op = NewContract(Constant(ins.ori.contract_name), ins.lvalue) - op.call_id = ins.call_id - return op - - if isinstance(ins.ori, TmpNewArray): - return NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue) - - if isinstance(ins.called, Structure): - op = NewStructure(ins.called, ins.lvalue) - op.call_id = ins.call_id - return op - - if isinstance(ins.called, Event): - return EventCall(ins.called.name) - - - raise Exception('Not extracted {} {}'.format(type(ins.called), ins)) - -def convert_expression(expression, node): - # handle standlone expression - # such as return true; - from slither.core.cfg.node import NodeType - - if isinstance(expression, Literal) and node.type in [NodeType.IF, NodeType.IFLOOP]: - result = [Condition(Constant(expression.value))] - return result - if isinstance(expression, Identifier) and node.type in [NodeType.IF, NodeType.IFLOOP]: - result = [Condition(expression.value)] - return result - + irs = integrate_value_gas(irs) - visitor = ExpressionToSlithIR(expression, node) - result = visitor.result() + irs = propagate_type_and_convert_call(irs, node) + irs = remove_unused(irs) + find_references_origin(irs) - result = apply_ir_heuristics(result, node) - if result: - if node.type in [NodeType.IF, NodeType.IFLOOP]: - assert isinstance(result[-1], (OperationWithLValue)) - result.append(Condition(result[-1].lvalue)) - elif node.type == NodeType.RETURN: - # May return None - if isinstance(result[-1], (OperationWithLValue)): - result.append(Return(result[-1].lvalue)) + return irs - return result diff --git a/slither/slithir/utils/ssa.py b/slither/slithir/utils/ssa.py index 30db0a9d33..3b16555bb4 100644 --- a/slither/slithir/utils/ssa.py +++ b/slither/slithir/utils/ssa.py @@ -3,28 +3,32 @@ from slither.core.cfg.node import NodeType from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable -from slither.slithir.operations import (Assignment, Balance, Binary, - BinaryType, Condition, Delete, - EventCall, HighLevelCall, Index, - InitArray, InternalCall, +from slither.slithir.operations import (Assignment, Balance, Binary, Condition, + Delete, EventCall, HighLevelCall, + Index, InitArray, InternalCall, InternalDynamicCall, Length, LibraryCall, LowLevelCall, Member, NewArray, NewContract, NewElementaryType, NewStructure, - OperationWithLValue, Phi, PhiCallback, Push, Return, - Send, SolidityCall, Transfer, - TypeConversion, Unary, Unpack) -from slither.slithir.variables import (Constant, LocalIRVariable, StateIRVariable, - ReferenceVariable, TemporaryVariable, + OperationWithLValue, Phi, PhiCallback, + Push, Return, Send, SolidityCall, + Transfer, TypeConversion, Unary, + Unpack) +from slither.slithir.variables import (LocalIRVariable, ReferenceVariable, + StateIRVariable, TemporaryVariable, TupleVariable) logger = logging.getLogger('SSA_Conversion') - +################################################################################### +################################################################################### +# region SlihtIR variables to SSA +################################################################################### +################################################################################### def transform_slithir_vars_to_ssa(function): """ - Transform slithIR vars to SSA + Transform slithIR vars to SSA (TemporaryVariable, ReferenceVariable, TupleVariable) """ variables = [] for node in function.nodes: @@ -42,6 +46,12 @@ def transform_slithir_vars_to_ssa(function): for idx in range(len(tuple_variables)): tuple_variables[idx].index = idx +################################################################################### +################################################################################### +# region SSA conversion +################################################################################### +################################################################################### + def add_ssa_ir(function, all_state_variables_instances): ''' Add SSA version of the IR @@ -134,98 +144,6 @@ def add_ssa_ir(function, all_state_variables_instances): all_state_variables_instances, init_local_variables_instances) - -def last_name(n, var, init_vars): - candidates = [] - # Todo optimize by creating a variables_ssa_written attribute - for ir_ssa in n.irs_ssa: - if isinstance(ir_ssa, OperationWithLValue): - lvalue = ir_ssa.lvalue - while isinstance(lvalue, ReferenceVariable): - lvalue = lvalue.points_to - if lvalue and lvalue.name == var.name: - candidates.append(lvalue) - if n.variable_declaration and n.variable_declaration.name == var.name: - candidates.append(LocalIRVariable(n.variable_declaration)) - if n.type == NodeType.ENTRYPOINT: - if var.name in init_vars: - candidates.append(init_vars[var.name]) - assert candidates - return max(candidates, key=lambda v: v.index) - -def update_lvalue(new_ir, node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances): - if isinstance(new_ir, OperationWithLValue): - lvalue = new_ir.lvalue - update_through_ref = False - if isinstance(new_ir, (Assignment, Binary)): - if isinstance(lvalue, ReferenceVariable): - update_through_ref = True - while isinstance(lvalue, ReferenceVariable): - lvalue = lvalue.points_to - if isinstance(lvalue, (LocalIRVariable, StateIRVariable)): - if isinstance(lvalue, LocalIRVariable): - new_var = LocalIRVariable(lvalue) - new_var.index = all_local_variables_instances[lvalue.name].index + 1 - all_local_variables_instances[lvalue.name] = new_var - local_variables_instances[lvalue.name] = new_var - else: - new_var = StateIRVariable(lvalue) - new_var.index = all_state_variables_instances[lvalue.canonical_name].index + 1 - all_state_variables_instances[lvalue.canonical_name] = new_var - state_variables_instances[lvalue.canonical_name] = new_var - if update_through_ref: - phi_operation = Phi(new_var, {node}) - phi_operation.rvalues = [lvalue] - node.add_ssa_ir(phi_operation) - if not isinstance(new_ir.lvalue, ReferenceVariable): - new_ir.lvalue = new_var - else: - to_update = new_ir.lvalue - while isinstance(to_update.points_to, ReferenceVariable): - to_update = to_update.points_to - to_update.points_to = new_var - -def is_used_later(initial_node, variable): - # TODO: does not handle the case where its read and written in the declaration node - # It can be problematic if this happens in a loop/if structure - # Ex: - # for(;true;){ - # if(true){ - # uint a = a; - # } - # .. - to_explore = {initial_node} - explored = set() - - while to_explore: - node = to_explore.pop() - explored.add(node) - if isinstance(variable, LocalVariable): - if any(v.name == variable.name for v in node.local_variables_read): - return True - if any(v.name == variable.name for v in node.local_variables_written): - return False - if isinstance(variable, StateVariable): - if any(v.name == variable.name and v.contract == variable.contract for v in node.state_variables_read): - return True - if any(v.name == variable.name and v.contract == variable.contract for v in node.state_variables_written): - return False - for son in node.sons: - if not son in explored: - to_explore.add(son) - - return False - - -def initiate_all_local_variables_instances(nodes, local_variables_instances, all_local_variables_instances): - for node in nodes: - if node.variable_declaration: - new_var = LocalIRVariable(node.variable_declaration) - if new_var.name in all_local_variables_instances: - new_var.index = all_local_variables_instances[new_var.name].index + 1 - local_variables_instances[node.variable_declaration.name] = new_var - all_local_variables_instances[node.variable_declaration.name] = new_var - def generate_ssa_irs(node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances, init_local_variables_instances, visited): if node in visited: @@ -308,6 +226,129 @@ def generate_ssa_irs(node, local_variables_instances, all_local_variables_instan init_local_variables_instances, visited) +# endregion +################################################################################### +################################################################################### +# region Helpers +################################################################################### +################################################################################### + +def last_name(n, var, init_vars): + candidates = [] + # Todo optimize by creating a variables_ssa_written attribute + for ir_ssa in n.irs_ssa: + if isinstance(ir_ssa, OperationWithLValue): + lvalue = ir_ssa.lvalue + while isinstance(lvalue, ReferenceVariable): + lvalue = lvalue.points_to + if lvalue and lvalue.name == var.name: + candidates.append(lvalue) + if n.variable_declaration and n.variable_declaration.name == var.name: + candidates.append(LocalIRVariable(n.variable_declaration)) + if n.type == NodeType.ENTRYPOINT: + if var.name in init_vars: + candidates.append(init_vars[var.name]) + assert candidates + return max(candidates, key=lambda v: v.index) + +def is_used_later(initial_node, variable): + # TODO: does not handle the case where its read and written in the declaration node + # It can be problematic if this happens in a loop/if structure + # Ex: + # for(;true;){ + # if(true){ + # uint a = a; + # } + # .. + to_explore = {initial_node} + explored = set() + + while to_explore: + node = to_explore.pop() + explored.add(node) + if isinstance(variable, LocalVariable): + if any(v.name == variable.name for v in node.local_variables_read): + return True + if any(v.name == variable.name for v in node.local_variables_written): + return False + if isinstance(variable, StateVariable): + if any(v.name == variable.name and v.contract == variable.contract for v in node.state_variables_read): + return True + if any(v.name == variable.name and v.contract == variable.contract for v in node.state_variables_written): + return False + for son in node.sons: + if not son in explored: + to_explore.add(son) + + return False + + +# endregion +################################################################################### +################################################################################### +# region Update operation +################################################################################### +################################################################################### + +def update_lvalue(new_ir, node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances): + if isinstance(new_ir, OperationWithLValue): + lvalue = new_ir.lvalue + update_through_ref = False + if isinstance(new_ir, (Assignment, Binary)): + if isinstance(lvalue, ReferenceVariable): + update_through_ref = True + while isinstance(lvalue, ReferenceVariable): + lvalue = lvalue.points_to + if isinstance(lvalue, (LocalIRVariable, StateIRVariable)): + if isinstance(lvalue, LocalIRVariable): + new_var = LocalIRVariable(lvalue) + new_var.index = all_local_variables_instances[lvalue.name].index + 1 + all_local_variables_instances[lvalue.name] = new_var + local_variables_instances[lvalue.name] = new_var + else: + new_var = StateIRVariable(lvalue) + new_var.index = all_state_variables_instances[lvalue.canonical_name].index + 1 + all_state_variables_instances[lvalue.canonical_name] = new_var + state_variables_instances[lvalue.canonical_name] = new_var + if update_through_ref: + phi_operation = Phi(new_var, {node}) + phi_operation.rvalues = [lvalue] + node.add_ssa_ir(phi_operation) + if not isinstance(new_ir.lvalue, ReferenceVariable): + new_ir.lvalue = new_var + else: + to_update = new_ir.lvalue + while isinstance(to_update.points_to, ReferenceVariable): + to_update = to_update.points_to + to_update.points_to = new_var + + +# endregion +################################################################################### +################################################################################### +# region Initialization +################################################################################### +################################################################################### + +def initiate_all_local_variables_instances(nodes, local_variables_instances, all_local_variables_instances): + for node in nodes: + if node.variable_declaration: + new_var = LocalIRVariable(node.variable_declaration) + if new_var.name in all_local_variables_instances: + new_var.index = all_local_variables_instances[new_var.name].index + 1 + local_variables_instances[node.variable_declaration.name] = new_var + all_local_variables_instances[node.variable_declaration.name] = new_var + + + +# endregion +################################################################################### +################################################################################### +# region Phi Operations +################################################################################### +################################################################################### + + def fix_phi_rvalues_and_storage_ref(node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances, init_local_variables_instances): for ir in node.irs_ssa: if isinstance(ir, (Phi)) and not ir.rvalues: @@ -367,6 +408,14 @@ def add_phi_origins(node, local_variables_definition, state_variables_definition for succ in node.dominator_successors: add_phi_origins(succ, local_variables_definition, state_variables_definition) + +# endregion +################################################################################### +################################################################################### +# region IR copy +################################################################################### +################################################################################### + def copy_ir(ir, local_variables_instances, state_variables_instances, temporary_variables_instances, reference_variables_instances, all_local_variables_instances): ''' Args: @@ -591,4 +640,4 @@ def traversal(values): logger.error('Impossible ir copy on {} ({})'.format(ir, type(ir))) exit(-1) - +# endregion diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index c895a3eb25..953210bf5a 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -2,16 +2,13 @@ from slither.core.declarations.contract import Contract from slither.core.declarations.enum import Enum - -from slither.solc_parsing.declarations.structure import StructureSolc +from slither.slithir.variables import StateIRVariable from slither.solc_parsing.declarations.event import EventSolc -from slither.solc_parsing.declarations.modifier import ModifierSolc from slither.solc_parsing.declarations.function import FunctionSolc - -from slither.solc_parsing.variables.state_variable import StateVariableSolc +from slither.solc_parsing.declarations.modifier import ModifierSolc +from slither.solc_parsing.declarations.structure import StructureSolc from slither.solc_parsing.solidity_types.type_parsing import parse_type - -from slither.slithir.variables import StateIRVariable +from slither.solc_parsing.variables.state_variable import StateVariableSolc logger = logging.getLogger("ContractSolcParsing") @@ -54,10 +51,25 @@ def __init__(self, slitherSolc, data): self._parse_contract_items() + ################################################################################### + ################################################################################### + # region General Properties + ################################################################################### + ################################################################################### + @property def is_analyzed(self): return self._is_analyzed + def set_is_analyzed(self, is_analyzed): + self._is_analyzed = is_analyzed + + ################################################################################### + ################################################################################### + # region AST + ################################################################################### + ################################################################################### + def get_key(self): return self.slither.get_key() @@ -74,8 +86,12 @@ def remapping(self): def is_compact_ast(self): return self.slither.is_compact_ast - def set_is_analyzed(self, is_analyzed): - self._is_analyzed = is_analyzed + # endregion + ################################################################################### + ################################################################################### + # region SlithIR + ################################################################################### + ################################################################################### def _parse_contract_info(self): if self.is_compact_ast: @@ -174,70 +190,6 @@ def _parse_contract_items(self): exit(-1) return - def analyze_using_for(self): - for father in self.inheritance: - self._using_for.update(father.using_for) - - if self.is_compact_ast: - for using_for in self._usingForNotParsed: - lib_name = parse_type(using_for['libraryName'], self) - if 'typeName' in using_for and using_for['typeName']: - type_name = parse_type(using_for['typeName'], self) - else: - type_name = '*' - if not type_name in self._using_for: - self.using_for[type_name] = [] - self._using_for[type_name].append(lib_name) - else: - for using_for in self._usingForNotParsed: - children = using_for[self.get_children()] - assert children and len(children) <= 2 - if len(children) == 2: - new = parse_type(children[0], self) - old = parse_type(children[1], self) - else: - new = parse_type(children[0], self) - old = '*' - if not old in self._using_for: - self.using_for[old] = [] - self._using_for[old].append(new) - self._usingForNotParsed = [] - - def analyze_enums(self): - - for father in self.inheritance: - self._enums.update(father.enums_as_dict()) - - for enum in self._enumsNotParsed: - # for enum, we can parse and analyze it - # at the same time - self._analyze_enum(enum) - self._enumsNotParsed = None - - def _analyze_enum(self, enum): - # Enum can be parsed in one pass - if self.is_compact_ast: - name = enum['name'] - canonicalName = enum['canonicalName'] - else: - name = enum['attributes'][self.get_key()] - if 'canonicalName' in enum['attributes']: - canonicalName = enum['attributes']['canonicalName'] - else: - canonicalName = self.name + '.' + name - values = [] - for child in enum[self.get_children('members')]: - assert child[self.get_key()] == 'EnumValue' - if self.is_compact_ast: - values.append(child['name']) - else: - values.append(child['attributes'][self.get_key()]) - - new_enum = Enum(name, canonicalName, values) - new_enum.set_contract(self) - new_enum.set_offset(enum['src'], self.slither) - self._enums[canonicalName] = new_enum - def _parse_struct(self, struct): if self.is_compact_ast: name = struct['name'] @@ -259,9 +211,6 @@ def _parse_struct(self, struct): st.set_offset(struct['src'], self.slither) self._structures[name] = st - def _analyze_struct(self, struct): - struct.analyze() - def parse_structs(self): for father in self.inheritance_reverse: self._structures.update(father.structures_as_dict()) @@ -270,24 +219,6 @@ def parse_structs(self): self._parse_struct(struct) self._structuresNotParsed = None - def analyze_structs(self): - for struct in self.structures: - self._analyze_struct(struct) - - - def analyze_events(self): - for father in self.inheritance_reverse: - self._events.update(father.events_as_dict()) - - for event_to_parse in self._eventsNotParsed: - event = EventSolc(event_to_parse, self) - event.analyze(self) - event.set_contract(self) - event.set_offset(event_to_parse['src'], self.slither) - self._events[event.full_name] = event - - self._eventsNotParsed = None - def parse_state_variables(self): for father in self.inheritance_reverse: self._variables.update(father.variables_as_dict()) @@ -299,22 +230,6 @@ def parse_state_variables(self): self._variables[var.name] = var - def analyze_constant_state_variables(self): - from slither.solc_parsing.expressions.expression_parsing import VariableNotFound - for var in self.variables: - if var.is_constant: - # cant parse constant expression based on function calls - try: - var.analyze(self) - except VariableNotFound: - pass - return - - def analyze_state_variables(self): - for var in self.variables: - var.analyze(self) - return - def _parse_modifier(self, modifier): modif = ModifierSolc(modifier, self) @@ -347,6 +262,24 @@ def parse_functions(self): return + # endregion + ################################################################################### + ################################################################################### + # region Analyze + ################################################################################### + ################################################################################### + + def analyze_content_modifiers(self): + for modifier in self.modifiers: + modifier.analyze_content() + return + + def analyze_content_functions(self): + for function in self.functions: + function.analyze_content() + + return + def analyze_params_modifiers(self): for father in self.inheritance_reverse: self._modifiers.update(father.modifiers_as_dict()) @@ -390,17 +323,114 @@ def analyze_params_functions(self): self._functions_no_params = [] return - def analyze_content_modifiers(self): - for modifier in self.modifiers: - modifier.analyze_content() + def analyze_constant_state_variables(self): + from slither.solc_parsing.expressions.expression_parsing import VariableNotFound + for var in self.variables: + if var.is_constant: + # cant parse constant expression based on function calls + try: + var.analyze(self) + except VariableNotFound: + pass return - def analyze_content_functions(self): - for function in self.functions: - function.analyze_content() - + def analyze_state_variables(self): + for var in self.variables: + var.analyze(self) return + def analyze_using_for(self): + for father in self.inheritance: + self._using_for.update(father.using_for) + + if self.is_compact_ast: + for using_for in self._usingForNotParsed: + lib_name = parse_type(using_for['libraryName'], self) + if 'typeName' in using_for and using_for['typeName']: + type_name = parse_type(using_for['typeName'], self) + else: + type_name = '*' + if not type_name in self._using_for: + self.using_for[type_name] = [] + self._using_for[type_name].append(lib_name) + else: + for using_for in self._usingForNotParsed: + children = using_for[self.get_children()] + assert children and len(children) <= 2 + if len(children) == 2: + new = parse_type(children[0], self) + old = parse_type(children[1], self) + else: + new = parse_type(children[0], self) + old = '*' + if not old in self._using_for: + self.using_for[old] = [] + self._using_for[old].append(new) + self._usingForNotParsed = [] + + def analyze_enums(self): + + for father in self.inheritance: + self._enums.update(father.enums_as_dict()) + + for enum in self._enumsNotParsed: + # for enum, we can parse and analyze it + # at the same time + self._analyze_enum(enum) + self._enumsNotParsed = None + + def _analyze_enum(self, enum): + # Enum can be parsed in one pass + if self.is_compact_ast: + name = enum['name'] + canonicalName = enum['canonicalName'] + else: + name = enum['attributes'][self.get_key()] + if 'canonicalName' in enum['attributes']: + canonicalName = enum['attributes']['canonicalName'] + else: + canonicalName = self.name + '.' + name + values = [] + for child in enum[self.get_children('members')]: + assert child[self.get_key()] == 'EnumValue' + if self.is_compact_ast: + values.append(child['name']) + else: + values.append(child['attributes'][self.get_key()]) + + new_enum = Enum(name, canonicalName, values) + new_enum.set_contract(self) + new_enum.set_offset(enum['src'], self.slither) + self._enums[canonicalName] = new_enum + + def _analyze_struct(self, struct): + struct.analyze() + + def analyze_structs(self): + for struct in self.structures: + self._analyze_struct(struct) + + def analyze_events(self): + for father in self.inheritance_reverse: + self._events.update(father.events_as_dict()) + + for event_to_parse in self._eventsNotParsed: + event = EventSolc(event_to_parse, self) + event.analyze(self) + event.set_contract(self) + event.set_offset(event_to_parse['src'], self.slither) + self._events[event.full_name] = event + + self._eventsNotParsed = None + + + + # endregion + ################################################################################### + ################################################################################### + # region SlithIR + ################################################################################### + ################################################################################### def convert_expression_to_slithir(self): for func in self.functions + self.modifiers: @@ -442,5 +472,14 @@ def fix_phi(self): func.fix_phi(last_state_variables_instances, initial_state_variables_instances) + # endregion + ################################################################################### + ################################################################################### + # region Built in definitions + ################################################################################### + ################################################################################### + def __hash__(self): return self._id + + # endregion diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 73974cf07c..45f90a72f4 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -1,20 +1,19 @@ """ - Event module """ import logging from slither.core.cfg.node import NodeType, link_nodes +from slither.core.declarations.contract import Contract from slither.core.declarations.function import Function from slither.core.dominators.utils import (compute_dominance_frontier, compute_dominators) from slither.core.expressions import AssignmentOperation from slither.core.variables.state_variable import StateVariable -from slither.slithir.operations import (Assignment, HighLevelCall, - InternalCall, InternalDynamicCall, - LowLevelCall, OperationWithLValue, Phi, - PhiCallback, LibraryCall) +from slither.slithir.operations import (InternalCall, OperationWithLValue, Phi, + PhiCallback) from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa -from slither.slithir.variables import LocalIRVariable, ReferenceVariable +from slither.slithir.variables import (Constant, ReferenceVariable, + StateIRVariable) from slither.solc_parsing.cfg.node import NodeSolc from slither.solc_parsing.expressions.expression_parsing import \ parse_expression @@ -24,18 +23,14 @@ from slither.solc_parsing.variables.variable_declaration import \ MultipleVariablesDeclaration from slither.utils.expression_manipulations import SplitTernaryExpression +from slither.utils.utils import unroll from slither.visitors.expression.export_values import ExportValues from slither.visitors.expression.has_conditional import HasConditional -from slither.core.declarations.contract import Contract - -from slither.slithir.variables import StateIRVariable, LocalIRVariable, Constant -from slither.utils.utils import unroll logger = logging.getLogger("FunctionSolc") class FunctionSolc(Function): """ - Event class """ # elems = [(type, name)] @@ -60,6 +55,12 @@ def __init__(self, function, contract): # which is only possible with solc > 0.5 self._variables_renamed = {} + ################################################################################### + ################################################################################### + # region AST format + ################################################################################### + ################################################################################### + def get_key(self): return self.slither.get_key() @@ -72,6 +73,13 @@ def get_children(self, key): def is_compact_ast(self): return self.slither.is_compact_ast + # endregion + ################################################################################### + ################################################################################### + # region Variables + ################################################################################### + ################################################################################### + @property def variables_renamed(self): return self._variables_renamed @@ -89,6 +97,12 @@ def _add_local_variable(self, local_var): self._variables_renamed[local_var.reference_id] = local_var self._variables[local_var.name] = local_var + # endregion + ################################################################################### + ################################################################################### + # region Analyses + ################################################################################### + ################################################################################### def _analyze_attributes(self): if self.is_compact_ast: @@ -133,6 +147,77 @@ def _analyze_attributes(self): if 'payable' in attributes: self._payable = attributes['payable'] + def analyze_params(self): + # Can be re-analyzed due to inheritance + if self._params_was_analyzed: + return + + self._params_was_analyzed = True + + self._analyze_attributes() + + if self.is_compact_ast: + params = self._functionNotParsed['parameters'] + returns = self._functionNotParsed['returnParameters'] + else: + children = self._functionNotParsed[self.get_children('children')] + params = children[0] + returns = children[1] + + if params: + self._parse_params(params) + if returns: + self._parse_returns(returns) + + def analyze_content(self): + if self._content_was_analyzed: + return + + self._content_was_analyzed = True + + if self.is_compact_ast: + body = self._functionNotParsed['body'] + + if body and body[self.get_key()] == 'Block': + self._is_implemented = True + self._parse_cfg(body) + + for modifier in self._functionNotParsed['modifiers']: + self._parse_modifier(modifier) + + else: + children = self._functionNotParsed[self.get_children('children')] + self._is_implemented = False + for child in children[2:]: + if child[self.get_key()] == 'Block': + self._is_implemented = True + self._parse_cfg(child) + + # Parse modifier after parsing all the block + # In the case a local variable is used in the modifier + for child in children[2:]: + if child[self.get_key()] == 'ModifierInvocation': + self._parse_modifier(child) + + for local_vars in self.variables: + local_vars.analyze(self) + + for node in self.nodes: + node.analyze_expressions(self) + + self._filter_ternary() + self._remove_alone_endif() + + + + + # endregion + ################################################################################### + ################################################################################### + # region Nodes + ################################################################################### + ################################################################################### + def _new_node(self, node_type, src): node = NodeSolc(node_type, self._counter_nodes) node.set_offset(src, self.slither) @@ -141,6 +226,13 @@ def _new_node(self, node_type, src): self._nodes.append(node) return node + # endregion + ################################################################################### + ################################################################################### + # region Parsing function + ################################################################################### + ################################################################################### + def _parse_if(self, ifStatement, node): # IfStatement = 'if' '(' Expression ')' Statement ( 'else' Statement )? falseStatement = None @@ -662,6 +754,13 @@ def _parse_cfg(self, cfg): self._remove_incorrect_edges() self._remove_alone_endif() + # endregion + ################################################################################### + ################################################################################### + # region Loops + ################################################################################### + ################################################################################### + def _find_end_loop(self, node, visited, counter): # counter allows to explore nested loop if node in visited: @@ -723,48 +822,6 @@ def _fix_continue_node(self, node): node.set_sons([start_node]) start_node.add_father(node) - def _remove_incorrect_edges(self): - for node in self._nodes: - if node.type in [NodeType.RETURN, NodeType.THROW]: - for son in node.sons: - son.remove_father(node) - node.set_sons([]) - if node.type in [NodeType.BREAK]: - self._fix_break_node(node) - if node.type in [NodeType.CONTINUE]: - self._fix_continue_node(node) - - def _remove_alone_endif(self): - """ - Can occur on: - if(..){ - return - } - else{ - return - } - - Iterate until a fix point to remove the ENDIF node - creates on the following pattern - if(){ - return - } - else if(){ - return - } - """ - prev_nodes = [] - while set(prev_nodes) != set(self.nodes): - prev_nodes = self.nodes - to_remove = [] - for node in self.nodes: - if node.type == NodeType.ENDIF and not node.fathers: - for son in node.sons: - son.remove_father(node) - node.set_sons([]) - to_remove.append(node) - self._nodes = [n for n in self.nodes if not n in to_remove] -# def _parse_params(self, params): assert params[self.get_key()] == 'ParameterList' @@ -824,67 +881,61 @@ def _parse_modifier(self, modifier): elif isinstance(m, Contract): self._explicit_base_constructor_calls.append(m) + # endregion + ################################################################################### + ################################################################################### + # region Edges + ################################################################################### + ################################################################################### - def analyze_params(self): - # Can be re-analyzed due to inheritance - if self._params_was_analyzed: - return - - self._params_was_analyzed = True - - self._analyze_attributes() - - if self.is_compact_ast: - params = self._functionNotParsed['parameters'] - returns = self._functionNotParsed['returnParameters'] - else: - children = self._functionNotParsed[self.get_children('children')] - params = children[0] - returns = children[1] - - if params: - self._parse_params(params) - if returns: - self._parse_returns(returns) - - def analyze_content(self): - if self._content_was_analyzed: - return - - self._content_was_analyzed = True - - if self.is_compact_ast: - body = self._functionNotParsed['body'] - - if body and body[self.get_key()] == 'Block': - self._is_implemented = True - self._parse_cfg(body) - - for modifier in self._functionNotParsed['modifiers']: - self._parse_modifier(modifier) - - else: - children = self._functionNotParsed[self.get_children('children')] - self._is_implemented = False - for child in children[2:]: - if child[self.get_key()] == 'Block': - self._is_implemented = True - self._parse_cfg(child) - - # Parse modifier after parsing all the block - # In the case a local variable is used in the modifier - for child in children[2:]: - if child[self.get_key()] == 'ModifierInvocation': - self._parse_modifier(child) + def _remove_incorrect_edges(self): + for node in self._nodes: + if node.type in [NodeType.RETURN, NodeType.THROW]: + for son in node.sons: + son.remove_father(node) + node.set_sons([]) + if node.type in [NodeType.BREAK]: + self._fix_break_node(node) + if node.type in [NodeType.CONTINUE]: + self._fix_continue_node(node) - for local_vars in self.variables: - local_vars.analyze(self) + def _remove_alone_endif(self): + """ + Can occur on: + if(..){ + return + } + else{ + return + } - for node in self.nodes: - node.analyze_expressions(self) + Iterate until a fix point to remove the ENDIF node + creates on the following pattern + if(){ + return + } + else if(){ + return + } + """ + prev_nodes = [] + while set(prev_nodes) != set(self.nodes): + prev_nodes = self.nodes + to_remove = [] + for node in self.nodes: + if node.type == NodeType.ENDIF and not node.fathers: + for son in node.sons: + son.remove_father(node) + node.set_sons([]) + to_remove.append(node) + self._nodes = [n for n in self.nodes if not n in to_remove] - self._filter_ternary() - self._remove_alone_endif() + # endregion + ################################################################################### + ################################################################################### + # region Ternary + ################################################################################### + ################################################################################### def _filter_ternary(self): ternary_found = True @@ -902,6 +953,64 @@ def _filter_ternary(self): ternary_found = True break + def split_ternary_node(self, node, condition, true_expr, false_expr): + condition_node = self._new_node(NodeType.IF, node.source_mapping) + condition_node.add_expression(condition) + condition_node.analyze_expressions(self) + + if node.type == NodeType.VARIABLE: + condition_node.add_variable_declaration(node.variable_declaration) + + true_node = self._new_node(NodeType.EXPRESSION, node.source_mapping) + if node.type == NodeType.VARIABLE: + assert isinstance(true_expr, AssignmentOperation) + #true_expr = true_expr.expression_right + elif node.type == NodeType.RETURN: + true_node.type = NodeType.RETURN + true_node.add_expression(true_expr) + true_node.analyze_expressions(self) + + false_node = self._new_node(NodeType.EXPRESSION, node.source_mapping) + if node.type == NodeType.VARIABLE: + assert isinstance(false_expr, AssignmentOperation) + elif node.type == NodeType.RETURN: + false_node.type = NodeType.RETURN + #false_expr = false_expr.expression_right + false_node.add_expression(false_expr) + false_node.analyze_expressions(self) + + endif_node = self._new_node(NodeType.ENDIF, node.source_mapping) + + for father in node.fathers: + father.remove_son(node) + father.add_son(condition_node) + condition_node.add_father(father) + + for son in node.sons: + son.remove_father(node) + son.add_father(endif_node) + endif_node.add_son(son) + + link_nodes(condition_node, true_node) + link_nodes(condition_node, false_node) + + + if not true_node.type in [NodeType.THROW, NodeType.RETURN]: + link_nodes(true_node, endif_node) + if not false_node.type in [NodeType.THROW, NodeType.RETURN]: + link_nodes(false_node, endif_node) + + self._nodes = [n for n in self._nodes if n.node_id != node.node_id] + + + + # endregion + ################################################################################### + ################################################################################### + # region SlithIr and SSA + ################################################################################### + ################################################################################### + def get_last_ssa_state_variables_instances(self): if not self.is_implemented: return dict() @@ -1006,53 +1115,3 @@ def update_read_write_using_ssa(self): node.update_read_write_using_ssa() self._analyze_read_write() - def split_ternary_node(self, node, condition, true_expr, false_expr): - condition_node = self._new_node(NodeType.IF, node.source_mapping) - condition_node.add_expression(condition) - condition_node.analyze_expressions(self) - - if node.type == NodeType.VARIABLE: - condition_node.add_variable_declaration(node.variable_declaration) - - true_node = self._new_node(NodeType.EXPRESSION, node.source_mapping) - if node.type == NodeType.VARIABLE: - assert isinstance(true_expr, AssignmentOperation) - #true_expr = true_expr.expression_right - elif node.type == NodeType.RETURN: - true_node.type = NodeType.RETURN - true_node.add_expression(true_expr) - true_node.analyze_expressions(self) - - false_node = self._new_node(NodeType.EXPRESSION, node.source_mapping) - if node.type == NodeType.VARIABLE: - assert isinstance(false_expr, AssignmentOperation) - elif node.type == NodeType.RETURN: - false_node.type = NodeType.RETURN - #false_expr = false_expr.expression_right - false_node.add_expression(false_expr) - false_node.analyze_expressions(self) - - endif_node = self._new_node(NodeType.ENDIF, node.source_mapping) - - for father in node.fathers: - father.remove_son(node) - father.add_son(condition_node) - condition_node.add_father(father) - - for son in node.sons: - son.remove_father(node) - son.add_father(endif_node) - endif_node.add_son(son) - - link_nodes(condition_node, true_node) - link_nodes(condition_node, false_node) - - - if not true_node.type in [NodeType.THROW, NodeType.RETURN]: - link_nodes(true_node, endif_node) - if not false_node.type in [NodeType.THROW, NodeType.RETURN]: - link_nodes(false_node, endif_node) - - self._nodes = [n for n in self._nodes if n.node_id != node.node_id] - - diff --git a/slither/solc_parsing/expressions/expression_parsing.py b/slither/solc_parsing/expressions/expression_parsing.py index 14bb73e039..592b7c5b8d 100644 --- a/slither/solc_parsing/expressions/expression_parsing.py +++ b/slither/solc_parsing/expressions/expression_parsing.py @@ -1,38 +1,59 @@ import logging import re -from slither.core.expressions.unary_operation import UnaryOperation, UnaryOperationType -from slither.core.expressions.binary_operation import BinaryOperation, BinaryOperationType -from slither.core.expressions.literal import Literal + +from slither.core.declarations.contract import Contract +from slither.core.declarations.function import Function +from slither.core.declarations.solidity_variables import (SOLIDITY_FUNCTIONS, + SOLIDITY_VARIABLES, + SOLIDITY_VARIABLES_COMPOSED, + SolidityFunction, + SolidityVariable, + SolidityVariableComposed) +from slither.core.expressions.assignment_operation import (AssignmentOperation, + AssignmentOperationType) +from slither.core.expressions.binary_operation import (BinaryOperation, + BinaryOperationType) +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.conditional_expression import \ + ConditionalExpression +from slither.core.expressions.elementary_type_name_expression import \ + ElementaryTypeNameExpression from slither.core.expressions.identifier import Identifier -from slither.core.expressions.super_identifier import SuperIdentifier from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.literal import Literal from slither.core.expressions.member_access import MemberAccess -from slither.core.expressions.tuple_expression import TupleExpression -from slither.core.expressions.conditional_expression import ConditionalExpression -from slither.core.expressions.assignment_operation import AssignmentOperation, AssignmentOperationType -from slither.core.expressions.type_conversion import TypeConversion -from slither.core.expressions.call_expression import CallExpression -from slither.core.expressions.super_call_expression import SuperCallExpression from slither.core.expressions.new_array import NewArray from slither.core.expressions.new_contract import NewContract from slither.core.expressions.new_elementary_type import NewElementaryType -from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression - -from slither.solc_parsing.solidity_types.type_parsing import parse_type, UnknownType - -from slither.core.declarations.contract import Contract -from slither.core.declarations.function import Function - -from slither.core.declarations.solidity_variables import SOLIDITY_VARIABLES, SOLIDITY_FUNCTIONS, SOLIDITY_VARIABLES_COMPOSED -from slither.core.declarations.solidity_variables import SolidityVariable, SolidityFunction, SolidityVariableComposed, solidity_function_signature +from slither.core.expressions.super_call_expression import SuperCallExpression +from slither.core.expressions.super_identifier import SuperIdentifier +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.type_conversion import TypeConversion +from slither.core.expressions.unary_operation import (UnaryOperation, + UnaryOperationType) +from slither.core.solidity_types import (ArrayType, ElementaryType, + FunctionType, MappingType) +from slither.solc_parsing.solidity_types.type_parsing import (UnknownType, + parse_type) -from slither.core.solidity_types import ElementaryType, ArrayType, MappingType, FunctionType +logger = logging.getLogger("ExpressionParsing") -logger = logging.getLogger("ExpressionParsing") +################################################################################### +################################################################################### +# region Exception +################################################################################### +################################################################################### class VariableNotFound(Exception): pass +# endregion +################################################################################### +################################################################################### +# region Helpers +################################################################################### +################################################################################### + def get_pointer_name(variable): curr_type = variable.type while(isinstance(curr_type, (ArrayType, MappingType))): @@ -135,6 +156,92 @@ def find_variable(var_name, caller_context, referenced_declaration=None): raise VariableNotFound('Variable not found: {}'.format(var_name)) +# endregion +################################################################################### +################################################################################### +# region Filtering +################################################################################### +################################################################################### + +def filter_name(value): + value = value.replace(' memory', '') + value = value.replace(' storage', '') + value = value.replace(' external', '') + value = value.replace(' internal', '') + value = value.replace('struct ', '') + value = value.replace('contract ', '') + value = value.replace('enum ', '') + value = value.replace(' ref', '') + value = value.replace(' pointer', '') + value = value.replace(' pure', '') + value = value.replace(' view', '') + value = value.replace(' constant', '') + value = value.replace(' payable', '') + value = value.replace('function (', 'function(') + value = value.replace('returns (', 'returns(') + + # remove the text remaining after functio(...) + # which should only be ..returns(...) + # nested parenthesis so we use a system of counter on parenthesis + idx = value.find('(') + if idx: + counter = 1 + max_idx = len(value) + while counter: + assert idx < max_idx + idx = idx +1 + if value[idx] == '(': + counter += 1 + elif value[idx] == ')': + counter -= 1 + value = value[:idx+1] + return value + +# endregion +################################################################################### +################################################################################### +# region Conversion +################################################################################### +################################################################################### + +def convert_subdenomination(value, sub): + if sub is None: + return value + # to allow 0.1 ether conversion + if value[0:2] == "0x": + value = float(int(value, 16)) + else: + value = float(value) + if sub == 'wei': + return int(value) + if sub == 'szabo': + return int(value * int(1e12)) + if sub == 'finney': + return int(value * int(1e15)) + if sub == 'ether': + return int(value * int(1e18)) + if sub == 'seconds': + return int(value) + if sub == 'minutes': + return int(value * 60) + if sub == 'hours': + return int(value * 60 * 60) + if sub == 'days': + return int(value * 60 * 60 * 24) + if sub == 'weeks': + return int(value * 60 * 60 * 24 * 7) + if sub == 'years': + return int(value * 60 * 60 * 24 * 7 * 365) + + logger.error('Subdemoniation not found {}'.format(sub)) + return int(value) + +# endregion +################################################################################### +################################################################################### +# region Parsing +################################################################################### +################################################################################### def parse_call(expression, caller_context): @@ -208,72 +315,6 @@ def parse_super_name(expression, is_compact_ast): return base_name+arguments -def filter_name(value): - value = value.replace(' memory', '') - value = value.replace(' storage', '') - value = value.replace(' external', '') - value = value.replace(' internal', '') - value = value.replace('struct ', '') - value = value.replace('contract ', '') - value = value.replace('enum ', '') - value = value.replace(' ref', '') - value = value.replace(' pointer', '') - value = value.replace(' pure', '') - value = value.replace(' view', '') - value = value.replace(' constant', '') - value = value.replace(' payable', '') - value = value.replace('function (', 'function(') - value = value.replace('returns (', 'returns(') - - # remove the text remaining after functio(...) - # which should only be ..returns(...) - # nested parenthesis so we use a system of counter on parenthesis - idx = value.find('(') - if idx: - counter = 1 - max_idx = len(value) - while counter: - assert idx < max_idx - idx = idx +1 - if value[idx] == '(': - counter += 1 - elif value[idx] == ')': - counter -= 1 - value = value[:idx+1] - return value - -def convert_subdenomination(value, sub): - if sub is None: - return value - # to allow 0.1 ether conversion - if value[0:2] == "0x": - value = float(int(value, 16)) - else: - value = float(value) - if sub == 'wei': - return int(value) - if sub == 'szabo': - return int(value * int(1e12)) - if sub == 'finney': - return int(value * int(1e15)) - if sub == 'ether': - return int(value * int(1e18)) - if sub == 'seconds': - return int(value) - if sub == 'minutes': - return int(value * 60) - if sub == 'hours': - return int(value * 60 * 60) - if sub == 'days': - return int(value * 60 * 60 * 24) - if sub == 'weeks': - return int(value * 60 * 60 * 24 * 7) - if sub == 'years': - return int(value * 60 * 60 * 24 * 7 * 365) - - logger.error('Subdemoniation not found {}'.format(sub)) - return int(value) - def parse_expression(expression, caller_context): """ diff --git a/slither/solc_parsing/slitherSolc.py b/slither/solc_parsing/slitherSolc.py index d75a437978..c2bffdd4c2 100644 --- a/slither/solc_parsing/slitherSolc.py +++ b/slither/solc_parsing/slitherSolc.py @@ -26,6 +26,13 @@ def __init__(self, filename): self._is_compact_ast = False + + ################################################################################### + ################################################################################### + # region AST + ################################################################################### + ################################################################################### + def get_key(self): if self._is_compact_ast: return 'nodeType' @@ -40,6 +47,13 @@ def get_children(self): def is_compact_ast(self): return self._is_compact_ast + # endregion + ################################################################################### + ################################################################################### + # region Parsing + ################################################################################### + ################################################################################### + def _parse_contracts_from_json(self, json_data): try: data_loaded = json.loads(json_data) @@ -148,6 +162,16 @@ def _parse_source_unit(self, data, filename): source_code = f.read() self.source_code[name] = source_code + # endregion + ################################################################################### + ################################################################################### + # region Analyze + ################################################################################### + ################################################################################### + + @property + def analyzed(self): + return self._analyzed def _analyze_contracts(self): if not self._contractsNotParsed: @@ -234,11 +258,6 @@ def _analyze_contracts(self): compute_dependency(self) - # TODO refactor the following functions, and use a lambda function - - @property - def analyzed(self): - return self._analyzed def _analyze_all_enums(self, contracts_to_be_analyzed): while contracts_to_be_analyzed: @@ -362,4 +381,4 @@ def _convert_to_slithir(self): contract.fix_phi() contract.update_read_write_using_ssa() - + # endregion