diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index 034879997e..52abaa5d19 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -2,6 +2,7 @@ Node module """ import logging +from typing import Optional from slither.core.children.child_function import ChildFunction from slither.core.declarations import Contract @@ -24,6 +25,7 @@ logger = logging.getLogger("Node") + ################################################################################### ################################################################################### # region NodeType @@ -31,23 +33,22 @@ ################################################################################### class NodeType: - ENTRYPOINT = 0x0 # no expression # Node with expression EXPRESSION = 0x10 # normal case - RETURN = 0x11 # RETURN may contain an expression + RETURN = 0x11 # RETURN may contain an expression IF = 0x12 - VARIABLE = 0x13 # Declaration of variable + VARIABLE = 0x13 # Declaration of variable ASSEMBLY = 0x14 IFLOOP = 0x15 # Merging nodes # Can have phi IR operation - ENDIF = 0x50 # ENDIF node source mapping points to the if/else body - STARTLOOP = 0x51 # STARTLOOP node source mapping points to the entire loop body - ENDLOOP = 0x52 # ENDLOOP node source mapping points to the entire loop body + ENDIF = 0x50 # ENDIF node source mapping points to the if/else body + STARTLOOP = 0x51 # STARTLOOP node source mapping points to the entire loop body + ENDLOOP = 0x52 # ENDLOOP node source mapping points to the entire loop body # Below the nodes have no expression # But are used to expression CFG structure @@ -69,8 +70,7 @@ class NodeType: # Use for state variable declaration OTHER_ENTRYPOINT = 0x50 - -# @staticmethod + # @staticmethod def str(t): if t == NodeType.ENTRYPOINT: return 'ENTRY_POINT' @@ -118,6 +118,7 @@ def link_nodes(n1, n2): n1.add_son(n2) n2.add_father(n1) + def insert_node(origin, node_inserted): sons = origin.sons link_nodes(origin, node_inserted) @@ -127,6 +128,7 @@ def insert_node(origin, node_inserted): link_nodes(node_inserted, son) + def recheable(node): ''' Return the set of nodes reacheable from the node @@ -167,7 +169,7 @@ def __init__(self, node_type, node_id): self._dominators = set() self._immediate_dominator = None ## Nodes of the dominators tree - #self._dom_predecessors = set() + # self._dom_predecessors = set() self._dom_successors = set() # Dominance frontier self._dominance_frontier = set() @@ -189,7 +191,7 @@ def __init__(self, node_type, node_id): self._internal_calls = [] self._solidity_calls = [] - self._high_level_calls = [] # contains library calls + self._high_level_calls = [] # contains library calls self._library_calls = [] self._low_level_calls = [] self._external_calls_as_expressions = [] @@ -207,7 +209,7 @@ def __init__(self, node_type, node_id): self._local_vars_read = [] self._local_vars_written = [] - self._slithir_vars = set() # non SSA + self._slithir_vars = set() # non SSA self._ssa_local_vars_read = [] self._ssa_local_vars_written = [] @@ -396,6 +398,7 @@ def library_calls(self): Include library calls """ return list(self._library_calls) + @property def low_level_calls(self): """ @@ -547,7 +550,6 @@ def inline_asm(self): def add_inline_asm(self, asm): self._asm_source_code = asm - # endregion ################################################################################### ################################################################################### @@ -621,6 +623,20 @@ def sons(self): """ return list(self._sons) + @property + def son_true(self) -> Optional["Node"]: + if self.type == NodeType.IF: + return self._sons[0] + else: + return None + + @property + def son_false(self) -> Optional["Node"]: + if self.type == NodeType.IF and len(self._sons) >= 1: + return self._sons[1] + else: + return None + # endregion ################################################################################### ################################################################################### @@ -648,7 +664,7 @@ def irs_ssa(self): @irs_ssa.setter def irs_ssa(self, irs): - self._irs_ssa = irs + self._irs_ssa = irs def add_ssa_ir(self, ir): ''' @@ -748,9 +764,6 @@ def add_phi_origin_state_variable(self, variable, node): assert v == variable nodes.add(node) - - - # endregion ################################################################################### ################################################################################### @@ -789,7 +802,7 @@ def _find_read_write_call(self): if isinstance(var, (ReferenceVariable)): var = var.points_to_origin if var and self._is_non_slithir_var(var): - self._vars_written.append(var) + self._vars_written.append(var) if isinstance(ir, InternalCall): self._internal_calls.append(ir.function) @@ -809,7 +822,8 @@ def _find_read_write_call(self): try: self._high_level_calls.append((ir.destination.type.type, ir.function)) except AttributeError: - raise SlitherException(f'Function not found on {ir}. Please try compiling with a recent Solidity version.') + raise SlitherException( + f'Function not found on {ir}. Please try compiling with a recent Solidity version.') elif isinstance(ir, LibraryCall): assert isinstance(ir.destination, Contract) self._high_level_calls.append((ir.destination, ir.function)) @@ -884,7 +898,6 @@ def update_read_write_using_ssa(self): vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read] vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written] - self._vars_read += [v for v in vars_read if v not in self._vars_read] self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] @@ -893,7 +906,6 @@ def update_read_write_using_ssa(self): self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] - # endregion ################################################################################### ################################################################################### @@ -902,7 +914,7 @@ def update_read_write_using_ssa(self): ################################################################################### def __str__(self): - txt = NodeType.str(self._node_type) + ' '+ str(self.expression) + txt = NodeType.str(self._node_type) + ' ' + str(self.expression) return txt # endregion diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 1cb9bb2f52..04219b3f4e 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -1111,8 +1111,16 @@ def slithir_cfg_to_dot_str(self) -> str: if node.irs: label += '\nIRs:\n' + '\n'.join([str(ir) for ir in node.irs]) content += '{}[label="{}"];\n'.format(node.node_id, label) - for son in node.sons: - content += '{}->{};\n'.format(node.node_id, son.node_id) + if node.type == NodeType.IF: + true_node = node.son_true + if true_node: + content += '{}->{}[label="True"];\n'.format(node.node_id, true_node.node_id) + false_node = node.son_false + if false_node: + content += '{}->{}[label="False"];\n'.format(node.node_id, false_node.node_id) + else: + for son in node.sons: + content += '{}->{};\n'.format(node.node_id, son.node_id) content += "}\n" return content