Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Expose immediate base contracts and base constructors #132

Merged
merged 17 commits into from
Jan 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions scripts/json_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
print('Usage: python json_diff.py 1.json 2.json')
exit(-1)

with open(sys.argv[1]) as f:
with open(sys.argv[1], encoding='utf8') as f:
d1 = json.load(f)

with open(sys.argv[2]) as f:
with open(sys.argv[2], encoding='utf8') as f:
d2 = json.load(f)


Expand Down
4 changes: 2 additions & 2 deletions slither/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def process_files(filenames, args, detector_classes, printer_classes):
all_contracts = []

for filename in filenames:
with open(filename) as f:
with open(filename, encoding='utf8') as f:
contract_loaded = json.load(f)
all_contracts.append(contract_loaded['ast'])

Expand All @@ -93,7 +93,7 @@ def process_files(filenames, args, detector_classes, printer_classes):


def output_json(results, filename):
with open(filename, 'w') as f:
with open(filename, 'w', encoding='utf8') as f:
json.dump(results, f)


Expand Down
59 changes: 56 additions & 3 deletions slither/core/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ def __init__(self):

self._name = None
self._id = None
self._inheritance = []
self._inheritance = [] # all contract inherited, c3 linearization
self._immediate_inheritance = [] # immediate inheritance

# Constructors called on contract's definition
# contract B is A(1) { ..
self._explicit_base_constructor_calls = []

self._enums = {}
self._structures = {}
Expand Down Expand Up @@ -60,15 +65,24 @@ def inheritance(self):
'''
return list(self._inheritance)

@property
def immediate_inheritance(self):
'''
list(Contract): List of contracts immediately inherited from (fathers). Order: order of declaration.
'''
return list(self._immediate_inheritance)

@property
def inheritance_reverse(self):
'''
list(Contract): Inheritance list. Order: the last elem is the first father to be executed
'''
return reversed(self._inheritance)

def setInheritance(self, inheritance):
def setInheritance(self, inheritance, immediate_inheritance, called_base_constructor_contracts):
self._inheritance = inheritance
self._immediate_inheritance = immediate_inheritance
self._explicit_base_constructor_calls = called_base_constructor_contracts

@property
def derived_contracts(self):
Expand Down Expand Up @@ -101,7 +115,32 @@ def modifiers_as_dict(self):

@property
def constructor(self):
return next((func for func in self.functions if func.is_constructor), None)
'''
Return the contract's immediate constructor.
If there is no immediate constructor, returns the first constructor
executed, following the c3 linearization
Return None if there is no constructor.
'''
cst = self.constructor_not_inherited
if cst:
return cst
for inherited_contract in self.inheritance:
cst = inherited_contract.constructor_not_inherited
if cst:
return cst
return None

@property
def constructor_not_inherited(self):
return next((func for func in self.functions if func.is_constructor and func.contract == self), None)

@property
def constructors(self):
'''
Return the list of constructors (including inherited)
'''
return [func for func in self.functions if func.is_constructor]


@property
def functions(self):
Expand Down Expand Up @@ -131,6 +170,19 @@ def functions_entry_points(self):
'''
return [f for f in self.functions if f.visibility in ['public', 'external']]

@property
def explicit_base_constructor_calls(self):
"""
list(Function): List of the base constructors called explicitly by this contract definition.

Base constructors called by any constructor definition will not be included.
Base constructors implicitly called by the contract definition (without
parenthesis) will not be included.

On "contract B is A(){..}" it returns the constructor of A
"""
return [c.constructor for c in self._explicit_base_constructor_calls if c.constructor]

@property
def modifiers(self):
'''
Expand Down Expand Up @@ -210,6 +262,7 @@ def all_state_variables_written(self):
all_state_variables_written = [f.all_state_variables_written() for f in self.functions + self.modifiers]
all_state_variables_written = [item for sublist in all_state_variables_written for item in sublist]
return list(set(all_state_variables_written))

@property
def all_state_variables_read(self):
'''
Expand Down
18 changes: 15 additions & 3 deletions slither/core/declarations/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self):
self._expression_calls = []
self._expression_modifiers = []
self._modifiers = []
self._explicit_base_constructor_calls = []
self._payable = False
self._contains_assembly = False

Expand Down Expand Up @@ -198,6 +199,17 @@ def modifiers(self):
"""
return list(self._modifiers)

@property
def explicit_base_constructor_calls(self):
"""
list(Function): List of the base constructors called explicitly by this presumed constructor definition.

Base constructors implicitly or explicitly called by the contract definition will not be
included.
"""
# This is a list of contracts internally, so we convert it to a list of constructor functions.
return [c.constructor_not_inherited for c in self._explicit_base_constructor_calls if c.constructor_not_inherited]

def __str__(self):
return self._name

Expand Down Expand Up @@ -642,7 +654,7 @@ def cfg_to_dot(self, filename):
Args:
filename (str)
"""
with open(filename, 'w') as f:
with open(filename, 'w', encoding='utf8') as f:
f.write('digraph{\n')
for node in self.nodes:
f.write('{}[label="{}"];\n'.format(node.node_id, str(node)))
Expand All @@ -658,7 +670,7 @@ def slithir_cfg_to_dot(self, filename):
filename (str)
"""
from slither.core.cfg.node import NodeType
with open(filename, 'w') as f:
with open(filename, 'w', encoding='utf8') as f:
f.write('digraph{\n')
for node in self.nodes:
label = 'Node Type: {} {}\n'.format(NodeType.str(node.type), node.node_id)
Expand All @@ -684,7 +696,7 @@ def description(node):
if node.dominance_frontier:
desc += '\ndominance frontier: {}'.format([n.node_id for n in node.dominance_frontier])
return desc
with open(filename, 'w') as f:
with open(filename, 'w', encoding='utf8') as f:
f.write('digraph{\n')
for node in self.nodes:
f.write('{}[label="{}"];\n'.format(node.node_id, description(node)))
Expand Down
3 changes: 3 additions & 0 deletions slither/core/solidity_types/array_type.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from slither.core.variables.variable import Variable
from slither.core.solidity_types.type import Type
from slither.core.expressions.expression import Expression
from slither.core.expressions import Literal

class ArrayType(Type):

def __init__(self, t, length):
assert isinstance(t, Type)
if length:
if isinstance(length, int):
length = Literal(length)
assert isinstance(length, Expression)
super(ArrayType, self).__init__()
self._type = t
Expand Down
2 changes: 1 addition & 1 deletion slither/printers/call/call_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def output(self, filename):

self.info(f'Call Graph: {filename}')

with open(filename, 'w') as f:
with open(filename, 'w', encoding='utf8') as f:
f.write('\n'.join([
'strict digraph {',
self._render_internal_calls(),
Expand Down
2 changes: 1 addition & 1 deletion slither/printers/inheritance/inheritance_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def output(self, filename):
filename += ".dot"
info = 'Inheritance Graph: ' + filename
self.info(info)
with open(filename, 'w') as f:
with open(filename, 'w', encoding='utf8') as f:
f.write('digraph{\n')
for c in self.contracts:
f.write(self._summary(c))
Expand Down
2 changes: 1 addition & 1 deletion slither/slither.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _run_solc(self, filename, solc, disable_solc_warnings, solc_arguments, ast_f
raise Exception('Incorrect file format')

if is_ast_file:
with open(filename) as astFile:
with open(filename, encoding='utf8') as astFile:
stdout = astFile.read()
if not stdout:
logger.info('Empty AST file: %s', filename)
Expand Down
9 changes: 6 additions & 3 deletions slither/slithir/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def propage_type_and_convert_call(result, node):
assert ins.get_type() == ArgumentType.CALL
call_data.append(ins.argument)

if isinstance(ins, (HighLevelCall, NewContract)):
if isinstance(ins, (HighLevelCall, NewContract, InternalDynamicCall)):
if ins.call_id in calls_value:
ins.call_value = calls_value[ins.call_id]
if ins.call_id in calls_gas:
Expand Down Expand Up @@ -668,6 +668,11 @@ def remove_unused(result):

def extract_tmp_call(ins):
assert isinstance(ins, TmpCall)

if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType):
call = InternalDynamicCall(ins.lvalue, ins.called, ins.called.type)
call.call_id = ins.call_id
return call
if isinstance(ins.ori, Member):
if isinstance(ins.ori.variable_left, Contract):
st = ins.ori.variable_left.get_structure_from_name(ins.ori.variable_right)
Expand Down Expand Up @@ -713,8 +718,6 @@ def extract_tmp_call(ins):
if isinstance(ins.called, Event):
return EventCall(ins.called.name)

if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType):
return InternalDynamicCall(ins.lvalue, ins.called, ins.called.type)

raise Exception('Not extracted {} {}'.format(type(ins.called), ins))

Expand Down
41 changes: 39 additions & 2 deletions slither/slithir/operations/internal_dynamic_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ def __init__(self, lvalue, function, function_type):
self._function_type = function_type
self._lvalue = lvalue

self._callid = None # only used if gas/value != 0
self._call_value = None
self._call_gas = None

@property
def read(self):
return self._unroll(self.arguments) + [self.function]
Expand All @@ -29,16 +33,49 @@ def function(self):
def function_type(self):
return self._function_type

@property
def call_value(self):
return self._call_value

@call_value.setter
def call_value(self, v):
self._call_value = v


@property
def call_gas(self):
return self._call_gas

@call_gas.setter
def call_gas(self, v):
self._call_gas = v

@property
def call_id(self):
return self._callid

@call_id.setter
def call_id(self, c):
self._callid = c

def __str__(self):
value = ''
gas = ''
args = [str(a) for a in self.arguments]
if self.call_value:
value = 'value:{}'.format(self.call_value)
if self.call_gas:
gas = 'gas:{}'.format(self.call_gas)
if not self.lvalue:
lvalue = ''
elif isinstance(self.lvalue.type, (list,)):
lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type))
else:
lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type)
txt = '{}INTERNAL_DYNAMIC_CALL {}({})'
txt = '{}INTERNAL_DYNAMIC_CALL {}({}) {} {}'
return txt.format(lvalue,
self.function.name,
','.join(args))
','.join(args),
value,
gas)

50 changes: 50 additions & 0 deletions slither/solc_parsing/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,62 @@ def _parse_contract_info(self):
self.linearizedBaseContracts = attributes['linearizedBaseContracts']
self.fullyImplemented = attributes['fullyImplemented']

# Parse base contract information
self._parse_base_contract_info()

# trufle does some re-mapping of id
if 'baseContracts' in self._data:
for elem in self._data['baseContracts']:
if elem['nodeType'] == 'InheritanceSpecifier':
self._remapping[elem['baseName']['referencedDeclaration']] = elem['baseName']['name']

def _parse_base_contract_info(self):
# Parse base contracts (immediate, non-linearized)
self.baseContracts = []
self.baseConstructorContractsCalled = []
if self.is_compact_ast:
# Parse base contracts + constructors in compact-ast
if 'baseContracts' in self._data:
for base_contract in self._data['baseContracts']:
if base_contract['nodeType'] != 'InheritanceSpecifier':
continue
if 'baseName' not in base_contract or 'referencedDeclaration' not in base_contract['baseName']:
continue

# Obtain our contract reference and add it to our base contract list
referencedDeclaration = base_contract['baseName']['referencedDeclaration']
self.baseContracts.append(referencedDeclaration)

# If we have defined arguments in our arguments object, this is a constructor invocation.
# (note: 'arguments' can be [], which is not the same as None. [] implies a constructor was
# called with no arguments, while None implies no constructor was called).
if 'arguments' in base_contract and base_contract['arguments'] is not None:
self.baseConstructorContractsCalled.append(referencedDeclaration)
else:
# Parse base contracts + constructors in legacy-ast
if 'children' in self._data:
for base_contract in self._data['children']:
if base_contract['name'] != 'InheritanceSpecifier':
continue
if 'children' not in base_contract or len(base_contract['children']) == 0:
continue
# Obtain all items for this base contract specification (base contract, followed by arguments)
base_contract_items = base_contract['children']
if 'name' not in base_contract_items[0] or base_contract_items[0]['name'] != 'UserDefinedTypeName':
continue
if 'attributes' not in base_contract_items[0] or 'referencedDeclaration' not in \
base_contract_items[0]['attributes']:
continue

# Obtain our contract reference and add it to our base contract list
referencedDeclaration = base_contract_items[0]['attributes']['referencedDeclaration']
self.baseContracts.append(referencedDeclaration)

# If we have an 'attributes'->'arguments' which is None, this is not a constructor call.
if 'attributes' not in base_contract or 'arguments' not in base_contract['attributes'] or \
base_contract['attributes']['arguments'] is not None:
self.baseConstructorContractsCalled.append(referencedDeclaration)

def _parse_contract_items(self):
if not self.get_children() in self._data: # empty contract
return
Expand Down
Loading