diff --git a/slither/core/solidity_types/array_type.py b/slither/core/solidity_types/array_type.py index 7bcc9c6642..a396c9af01 100644 --- a/slither/core/solidity_types/array_type.py +++ b/slither/core/solidity_types/array_type.py @@ -2,6 +2,7 @@ from slither.core.solidity_types.type import Type from slither.core.expressions.expression import Expression from slither.core.expressions import Literal +from slither.visitors.expression.constants_folding import ConstantFolding class ArrayType(Type): @@ -11,6 +12,9 @@ def __init__(self, t, length): if isinstance(length, int): length = Literal(length) assert isinstance(length, Expression) + if not isinstance(length, Literal): + cf = ConstantFolding(length) + length = cf.result() super(ArrayType, self).__init__() self._type = t self._length = length diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py new file mode 100644 index 0000000000..bca1bfe046 --- /dev/null +++ b/slither/visitors/expression/constants_folding.py @@ -0,0 +1,104 @@ +import logging + +from .expression import ExpressionVisitor +from slither.core.expressions import BinaryOperationType, Literal + +class NotConstant(Exception): + pass + + +KEY = 'ConstantFolding' + +def get_val(expression): + val = expression.context[KEY] + # we delete the item to reduce memory use + del expression.context[KEY] + return val + +def set_val(expression, val): + expression.context[KEY] = val + +class ConstantFolding(ExpressionVisitor): + + def result(self): + return Literal(int(get_val(self._expression))) + + def _post_identifier(self, expression): + if not expression.value.is_constant: + raise NotConstant + expr = expression.value.expression + # assumption that we won't have infinite loop + if not isinstance(expr, Literal): + cf = ConstantFolding(expr) + expr = cf.result() + set_val(expression, int(expr.value)) + + def _post_binary_operation(self, expression): + left = get_val(expression.expression_left) + right = get_val(expression.expression_right) + if expression.type == BinaryOperationType.POWER: + set_val(expression, left ** right) + elif expression.type == BinaryOperationType.MULTIPLICATION: + set_val(expression, left * right) + elif expression.type == BinaryOperationType.DIVISION: + set_val(expression, left / right) + elif expression.type == BinaryOperationType.MODULO: + set_val(expression, left % right) + elif expression.type == BinaryOperationType.ADDITION: + set_val(expression, left + right) + elif expression.type == BinaryOperationType.SUBTRACTION: + if(left-right) <0: + # Could trigger underflow + raise NotConstant + set_val(expression, left - right) + elif expression.type == BinaryOperationType.LEFT_SHIFT: + set_val(expression, left << right) + elif expression.type == BinaryOperationType.RIGHT_SHIFT: + set_val(expression, left >> right) + else: + raise NotConstant + + def _post_unary_operation(self, expression): + raise NotConstant + + def _post_literal(self, expression): + if expression.value.isdigit(): + set_val(expression, int(expression.value)) + else: + raise NotConstant + + def _post_assignement_operation(self, expression): + raise NotConstant + + def _post_call_expression(self, expression): + raise NotConstant + + def _post_conditional_expression(self, expression): + raise NotConstant + + def _post_elementary_type_name_expression(self, expression): + raise NotConstant + + def _post_index_access(self, expression): + raise NotConstant + + def _post_member_access(self, expression): + raise NotConstant + + def _post_new_array(self, expression): + raise NotConstant + + def _post_new_contract(self, expression): + raise NotConstant + + def _post_new_elementary_type(self, expression): + raise NotConstant + + def _post_tuple_expression(self, expression): + raise NotConstant + + def _post_type_conversion(self, expression): + raise NotConstant + + +