Skip to content

Commit

Permalink
Clean call-graph printer (+ export every contract individually)
Browse files Browse the repository at this point in the history
Add call_graph.py script that is similar to the call-graph printer, but without the view/pure functions
  • Loading branch information
montyly committed Mar 1, 2019
1 parent fba467e commit 833e390
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 50 deletions.
54 changes: 54 additions & 0 deletions examples/scripts/call_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import os
import logging
import argparse
from slither import Slither
from slither.printers.all_printers import PrinterCallGraph
from slither.core.declarations.function import Function

logging.basicConfig()
logging.getLogger("Slither").setLevel(logging.INFO)
logging.getLogger("Printers").setLevel(logging.INFO)

class PrinterCallGraphStateChange(PrinterCallGraph):

def _process_function(self, contract, function, contract_functions, contract_calls, solidity_functions, solidity_calls, external_calls, all_contracts):
if function.view or function.pure:
return
super()._process_function(contract, function, contract_functions, contract_calls, solidity_functions, solidity_calls, external_calls, all_contracts)

def _process_internal_call(self, contract, function, internal_call, contract_calls, solidity_functions, solidity_calls):
if isinstance(internal_call, Function):
if internal_call.view or internal_call.pure:
return
super()._process_internal_call(contract, function, internal_call, contract_calls, solidity_functions, solidity_calls)

def _process_external_call(self, contract, function, external_call, contract_functions, external_calls, all_contracts):
if isinstance(external_call[1], Function):
if external_call[1].view or external_call[1].pure:
return
super()._process_external_call(contract, function, external_call, contract_functions, external_calls, all_contracts)

def parse_args():
"""
"""
parser = argparse.ArgumentParser(description='Call graph printer. Similar to --print call-graph, but without printing the view/pure functions',
usage='call_graph.py filename')

parser.add_argument('filename',
help='The filename of the contract or truffle directory to analyze.')

parser.add_argument('--solc', help='solc path', default='solc')

return parser.parse_args()

def main():

args = parse_args()
slither = Slither(args.filename, is_truffle=os.path.isdir(args.filename), solc=args.solc)

slither.register_printer(PrinterCallGraphStateChange)

slither.run_printers()

if __name__ == '__main__':
main()
116 changes: 66 additions & 50 deletions slither/printers/call/call_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
what are the contracts/functions called.
The output is a dot file named filename.dot
"""

from collections import defaultdict
from slither.printers.abstract_printer import AbstractPrinter
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.declarations.function import Function
Expand Down Expand Up @@ -44,117 +44,133 @@ class PrinterCallGraph(AbstractPrinter):

WIKI = 'https://github.com/trailofbits/slither/wiki/Printer-documentation#call-graph'

def __init__(self, slither, logger):
super(PrinterCallGraph, self).__init__(slither, logger)
def _process_functions(self, functions):

contract_functions = defaultdict(set) # contract -> contract functions nodes
contract_calls = defaultdict(set) # contract -> contract calls edges

solidity_functions = set() # solidity function nodes
solidity_calls = set() # solidity calls edges
external_calls = set() # external calls edges

self.contract_functions = {} # contract -> contract functions nodes
self.contract_calls = {} # contract -> contract calls edges
all_contracts = set()

for contract in slither.contracts:
self.contract_functions[contract] = set()
self.contract_calls[contract] = set()
for function in functions:
all_contracts.add(function.contract)
for function in functions:
self._process_function(function.contract,
function,
contract_functions,
contract_calls,
solidity_functions,
solidity_calls,
external_calls,
all_contracts)

self.solidity_functions = set() # solidity function nodes
self.solidity_calls = set() # solidity calls edges
render_internal_calls = ''
for contract in all_contracts:
render_internal_calls += self._render_internal_calls(contract, contract_functions, contract_calls)

self.external_calls = set() # external calls edges
render_solidity_calls = '' #self._render_solidity_calls(solidity_functions, solidity_calls)

self._process_contracts(slither.contracts)
render_external_calls = self._render_external_calls(external_calls)

def _process_contracts(self, contracts):
for contract in contracts:
for function in contract.functions:
self._process_function(contract, function)
return render_internal_calls + render_solidity_calls + render_external_calls

def _process_function(self, contract, function):
self.contract_functions[contract].add(
def _process_function(self, contract, function, contract_functions, contract_calls, solidity_functions, solidity_calls, external_calls, all_contracts):
contract_functions[contract].add(
_node(_function_node(contract, function), function.name),
)

for internal_call in function.internal_calls:
self._process_internal_call(contract, function, internal_call)
self._process_internal_call(contract, function, internal_call, contract_calls, solidity_functions, solidity_calls)
for external_call in function.high_level_calls:
self._process_external_call(contract, function, external_call)
self._process_external_call(contract, function, external_call, contract_functions, external_calls, all_contracts)

def _process_internal_call(self, contract, function, internal_call):
def _process_internal_call(self, contract, function, internal_call, contract_calls, solidity_functions, solidity_calls):
if isinstance(internal_call, (Function)):
self.contract_calls[contract].add(_edge(
contract_calls[contract].add(_edge(
_function_node(contract, function),
_function_node(contract, internal_call),
))
elif isinstance(internal_call, (SolidityFunction)):
self.solidity_functions.add(
solidity_functions.add(
_node(_solidity_function_node(internal_call)),
)
self.solidity_calls.add(_edge(
solidity_calls.add(_edge(
_function_node(contract, function),
_solidity_function_node(internal_call),
))

def _process_external_call(self, contract, function, external_call):
def _process_external_call(self, contract, function, external_call, contract_functions, external_calls, all_contracts):
external_contract, external_function = external_call

if not external_contract in all_contracts:
return

# add variable as node to respective contract
if isinstance(external_function, (Variable)):
self.contract_functions[external_contract].add(_node(
return
contract_functions[external_contract].add(_node(
_function_node(external_contract, external_function),
external_function.name
))

self.external_calls.add(_edge(
external_calls.add(_edge(
_function_node(contract, function),
_function_node(external_contract, external_function),
))

def _render_internal_calls(self):
def _render_internal_calls(self, contract, contract_functions, contract_calls):
lines = []

for contract in self.contract_functions:
lines.append(f'subgraph {_contract_subgraph(contract)} {{')
lines.append(f'label = "{contract.name}"')
lines.append(f'subgraph {_contract_subgraph(contract)} {{')
lines.append(f'label = "{contract.name}"')

lines.extend(self.contract_functions[contract])
lines.extend(self.contract_calls[contract])
lines.extend(contract_functions[contract])
lines.extend(contract_calls[contract])

lines.append('}')
lines.append('}')

return '\n'.join(lines)

def _render_solidity_calls(self):
def _render_solidity_calls(self, solidity_functions, solidity_calls):
lines = []

lines.append('subgraph cluster_solidity {')
lines.append('label = "[Solidity]"')

lines.extend(self.solidity_functions)
lines.extend(self.solidity_calls)
lines.extend(solidity_functions)
lines.extend(solidity_calls)

lines.append('}')

return '\n'.join(lines)

def _render_external_calls(self):
return '\n'.join(self.external_calls)
def _render_external_calls(self, external_calls):
return '\n'.join(external_calls)



def output(self, filename):
"""
Output the graph in filename
Args:
filename(string)
"""
if not filename:
filename = "contracts.dot"

if not filename.endswith('.dot'):
filename += '.dot'

self.info(f'Call Graph: {filename}')
if filename == ".dot":
filename = "all_contracts.dot"

with open(filename, 'w', encoding='utf8') as f:
f.write('\n'.join([
'strict digraph {',
self._render_internal_calls(),
self._render_solidity_calls(),
self._render_external_calls(),
'}',
]))
self.info(f'Call Graph: {filename}')
f.write('\n'.join(['strict digraph {'] + [self._process_functions(self.slither.functions)] + ['}']))


for derived_contract in self.slither.contracts_derived:
with open(f'{derived_contract.name}.dot', 'w', encoding='utf8') as f:
self.info(f'Call Graph: {derived_contract.name}.dot')
f.write('\n'.join(['strict digraph {'] + [self._process_functions(derived_contract.functions)] + ['}']))

0 comments on commit 833e390

Please sign in to comment.