Skip to content

Commit

Permalink
Fix incorrect slithIR conversion: internal call to inherited contract…
Browse files Browse the repository at this point in the history
… where considered as external call

Simplify removal of dupplicate node.* elements
  • Loading branch information
montyly committed Feb 28, 2019
1 parent 9edf5ad commit c6e090e
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 35 deletions.
29 changes: 6 additions & 23 deletions slither/core/declarations/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,51 +997,34 @@ def _analyze_calls(self):
calls = [x.calls_as_expression for x in self.nodes]
calls = [x for x in calls if x]
calls = [item for sublist in calls for item in sublist]
# Remove dupplicate if they share the same string representation
# TODO: check if groupby is still necessary here
calls = [next(obj) for i, obj in\
groupby(sorted(calls, key=lambda x: str(x)), lambda x: str(x))]
self._expression_calls = calls
self._expression_calls = list(set(calls))

internal_calls = [x.internal_calls for x in self.nodes]
internal_calls = [x for x in internal_calls if x]
internal_calls = [item for sublist in internal_calls for item in sublist]
internal_calls = [next(obj) for i, obj in
groupby(sorted(internal_calls, key=lambda x: str(x)), lambda x: str(x))]
self._internal_calls = internal_calls
self._internal_calls = list(set(internal_calls))

self._solidity_calls = [c for c in internal_calls if isinstance(c, SolidityFunction)]

low_level_calls = [x.low_level_calls for x in self.nodes]
low_level_calls = [x for x in low_level_calls if x]
low_level_calls = [item for sublist in low_level_calls for item in sublist]
low_level_calls = [next(obj) for i, obj in
groupby(sorted(low_level_calls, key=lambda x: str(x)), lambda x: str(x))]

self._low_level_calls = low_level_calls
self._low_level_calls = list(set(low_level_calls))

high_level_calls = [x.high_level_calls for x in self.nodes]
high_level_calls = [x for x in high_level_calls if x]
high_level_calls = [item for sublist in high_level_calls for item in sublist]
high_level_calls = [next(obj) for i, obj in
groupby(sorted(high_level_calls, key=lambda x: str(x)), lambda x: str(x))]

self._high_level_calls = high_level_calls
self._high_level_calls = list(set(high_level_calls))

library_calls = [x.library_calls for x in self.nodes]
library_calls = [x for x in library_calls if x]
library_calls = [item for sublist in library_calls for item in sublist]
library_calls = [next(obj) for i, obj in
groupby(sorted(library_calls, key=lambda x: str(x)), lambda x: str(x))]

self._library_calls = library_calls
self._library_calls = list(set(library_calls))

external_calls_as_expressions = [x.external_calls_as_expressions for x in self.nodes]
external_calls_as_expressions = [x for x in external_calls_as_expressions if x]
external_calls_as_expressions = [item for sublist in external_calls_as_expressions for item in sublist]
external_calls_as_expressions = [next(obj) for i, obj in
groupby(sorted(external_calls_as_expressions, key=lambda x: str(x)), lambda x: str(x))]
self._external_calls_as_expressions = external_calls_as_expressions
self._external_calls_as_expressions = list(set(external_calls_as_expressions))



Expand Down
19 changes: 13 additions & 6 deletions slither/slithir/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def propagate_type_and_convert_call(result, node):
ins = result[idx]

if isinstance(ins, TmpCall):
new_ins = extract_tmp_call(ins)
new_ins = extract_tmp_call(ins, node.function.contract)
if new_ins:
new_ins.set_node(ins.node)
ins = new_ins
Expand Down Expand Up @@ -323,12 +323,12 @@ def propagate_types(ir, node):
t_type = t.type
if isinstance(t_type, Contract):
contract = node.slither.get_contract_from_name(t_type.name)
return convert_type_of_high_level_call(ir, contract)
return convert_type_of_high_and_internal_level_call(ir, contract)

# Convert HighLevelCall to LowLevelCall
if isinstance(t, ElementaryType) and t.name == 'address':
if ir.destination.name == 'this':
return convert_type_of_high_level_call(ir, node.function.contract)
return convert_type_of_high_and_internal_level_call(ir, node.function.contract)
return convert_to_low_level(ir)

# Convert push operations
Expand All @@ -350,6 +350,8 @@ def propagate_types(ir, node):
ir.lvalue.set_type(ArrayType(t, length))
elif isinstance(ir, InternalCall):
# if its not a tuple, return a singleton
if ir.function is None:
convert_type_of_high_and_internal_level_call(ir, ir.contract)
return_type = ir.function.return_type
if return_type:
if len(return_type) == 1:
Expand Down Expand Up @@ -435,14 +437,19 @@ def propagate_types(ir, node):
logger.error('Not handling {} during type propgation'.format(type(ir)))
exit(-1)

def extract_tmp_call(ins):
def extract_tmp_call(ins, contract):
assert isinstance(ins, TmpCall)

if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType):
call = InternalDynamicCall(ins.lvalue, ins.called, ins.called.type)
call.call_id = ins.call_id
return call
if isinstance(ins.ori, Member):
# If there is a call on an inherited contract, it is an internal call
if ins.ori.variable_left in contract.inheritance + [contract]:
internalcall = InternalCall(ins.ori.variable_right, ins.ori.variable_left, ins.nbr_arguments, ins.lvalue, ins.type_call)
internalcall.call_id = ins.call_id
return internalcall
if isinstance(ins.ori.variable_left, Contract):
st = ins.ori.variable_left.get_structure_from_name(ins.ori.variable_right)
if st:
Expand All @@ -457,7 +464,7 @@ def extract_tmp_call(ins):
return msgcall

if isinstance(ins.ori, TmpCall):
r = extract_tmp_call(ins.ori)
r = extract_tmp_call(ins.ori, contract)
return r
if isinstance(ins.called, SolidityVariableComposed):
if str(ins.called) == 'block.blockhash':
Expand Down Expand Up @@ -671,7 +678,7 @@ def convert_type_library_call(ir, lib_contract):
ir.lvalue = None
return ir

def convert_type_of_high_level_call(ir, contract):
def convert_type_of_high_and_internal_level_call(ir, contract):
func = None
sigs = get_sig(ir)
for sig in sigs:
Expand Down
26 changes: 22 additions & 4 deletions slither/slithir/operations/internal_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@
from slither.slithir.operations.call import Call
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.core.variables.variable import Variable

from slither.slithir.variables import Constant

class InternalCall(Call, OperationWithLValue):

def __init__(self, function, nbr_arguments, result, type_call):
assert isinstance(function, Function)
def __init__(self, function, contract, nbr_arguments, result, type_call):
super(InternalCall, self).__init__()
self._function = function
if isinstance(function, Function):
self._function = function
self._function_name = function.name
else:
isinstance(function, Constant)
self._function = None
self._function_name = function
self._contract = contract
self._nbr_arguments = nbr_arguments
self._type_call = type_call
self._lvalue = result
Expand All @@ -22,6 +28,18 @@ def read(self):
def function(self):
return self._function

@function.setter
def function(self, f):
self._function = f

@property
def contract(self):
return self._contract

@property
def function_name(self):
return self._function_name

@property
def nbr_arguments(self):
return self._nbr_arguments
Expand Down
2 changes: 1 addition & 1 deletion slither/slithir/utils/ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def copy_ir(ir, *instances):
nbr_arguments = ir.nbr_arguments
lvalue = get_variable(ir, lambda x: x.lvalue, *instances)
type_call = ir.type_call
new_ir = InternalCall(function, nbr_arguments, lvalue, type_call)
new_ir = InternalCall(function, function.contract, nbr_arguments, lvalue, type_call)
new_ir.arguments = get_arguments(ir, *instances)
return new_ir
elif isinstance(ir, InternalDynamicCall):
Expand Down
2 changes: 1 addition & 1 deletion slither/visitors/slithir/expression_to_slithir.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _post_call_expression(self, expression):
val = TupleVariable(self._node)
else:
val = TemporaryVariable(self._node)
internal_call = InternalCall(called, len(args), val, expression.type_call)
internal_call = InternalCall(called, called.contract, len(args), val, expression.type_call)
self._result.append(internal_call)
set_val(expression, val)
else:
Expand Down

0 comments on commit c6e090e

Please sign in to comment.