diff --git a/slither/core/expressions/literal.py b/slither/core/expressions/literal.py index 87090b93f4..2eaeb715d7 100644 --- a/slither/core/expressions/literal.py +++ b/slither/core/expressions/literal.py @@ -10,16 +10,25 @@ class Literal(Expression): - def __init__(self, value, custom_type, subdenomination=None): + def __init__( + self, value: Union[int, str], custom_type: "Type", subdenomination: Optional[str] = None + ): super().__init__() - self._value: Union[int, str] = value + self._value = value self._type = custom_type - self._subdenomination: Optional[str] = subdenomination + self._subdenomination = subdenomination @property def value(self) -> Union[int, str]: return self._value + @property + def converted_value(self) -> Union[int, str]: + """Return the value of the literal, accounting for subdenomination e.g. ether""" + if self.subdenomination: + return convert_subdenomination(self._value, self.subdenomination) + return self._value + @property def type(self) -> "Type": return self._type @@ -28,9 +37,9 @@ def type(self) -> "Type": def subdenomination(self) -> Optional[str]: return self._subdenomination - def __str__(self): + def __str__(self) -> str: if self.subdenomination: - return str(convert_subdenomination(self._value, self.subdenomination)) + return str(self.converted_value) if self.type in Int + Uint + Fixed + Ufixed + ["address"]: return str(convert_string_to_int(self._value)) @@ -38,7 +47,7 @@ def __str__(self): # be sure to handle any character return str(self._value) - def __eq__(self, other): + def __eq__(self, other) -> bool: if not isinstance(other, Literal): return False return (self.value, self.subdenomination) == (other.value, other.subdenomination) diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index dbfa541218..45b3ab3327 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -30,6 +30,7 @@ ) from slither.slithir.operations.binary import Binary from slither.slithir.variables import Constant +from slither.visitors.expression.constants_folding import ConstantFolding def _get_name(f: Union[Function, Variable]) -> str: @@ -175,6 +176,11 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n all_cst_used_in_binary[str(ir.type)].append( ConstantValue(str(r.value), str(r.type)) ) + if isinstance(ir.variable_left, Constant) and isinstance(ir.variable_right, Constant): + if ir.lvalue: + type_ = ir.lvalue.type + cst = ConstantFolding(ir.expression, type_).result() + all_cst_used.append(ConstantValue(str(cst.value), str(type_))) if isinstance(ir, TypeConversion): if isinstance(ir.variable, Constant): all_cst_used.append(ConstantValue(str(ir.variable.value), str(ir.type))) diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index b324ed8425..797d1f46e4 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -1,4 +1,12 @@ -from slither.core.expressions import BinaryOperationType, Literal, UnaryOperationType +from fractions import Fraction +from slither.core.expressions import ( + BinaryOperationType, + Literal, + UnaryOperationType, + Identifier, + BinaryOperation, + UnaryOperation, +) from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int from slither.visitors.expression.expression import ExpressionVisitor @@ -27,9 +35,15 @@ def __init__(self, expression, custom_type): super().__init__(expression) def result(self): - return Literal(int(get_val(self._expression)), self._type) - - def _post_identifier(self, expression): + value = get_val(self._expression) + if isinstance(value, Fraction): + value = int(value) + # emulate 256-bit wrapping + if str(self._type).startswith("uint"): + value = value & (2**256 - 1) + return Literal(value, self._type) + + def _post_identifier(self, expression: Identifier): if not expression.value.is_constant: raise NotConstant expr = expression.value.expression @@ -37,9 +51,10 @@ def _post_identifier(self, expression): if not isinstance(expr, Literal): cf = ConstantFolding(expr, self._type) expr = cf.result() - set_val(expression, convert_string_to_int(expr.value)) + set_val(expression, convert_string_to_int(expr.converted_value)) - def _post_binary_operation(self, expression): + # pylint: disable=too-many-branches + def _post_binary_operation(self, expression: BinaryOperation): left = get_val(expression.expression_left) right = get_val(expression.expression_right) if expression.type == BinaryOperationType.POWER: @@ -53,18 +68,39 @@ def _post_binary_operation(self, expression): elif expression.type == BinaryOperationType.ADDITION: set_val(expression, left + right) elif expression.type == BinaryOperationType.SUBTRACTION: - if (left - right) < 0: - # Could trigger underflow - raise NotConstant set_val(expression, left - right) + # Convert to int for operations not supported by Fraction elif expression.type == BinaryOperationType.LEFT_SHIFT: - set_val(expression, left << right) + set_val(expression, int(left) << int(right)) elif expression.type == BinaryOperationType.RIGHT_SHIFT: - set_val(expression, left >> right) + set_val(expression, int(left) >> int(right)) + elif expression.type == BinaryOperationType.AND: + set_val(expression, int(left) & int(right)) + elif expression.type == BinaryOperationType.CARET: + set_val(expression, int(left) ^ int(right)) + elif expression.type == BinaryOperationType.OR: + set_val(expression, int(left) | int(right)) + elif expression.type == BinaryOperationType.LESS: + set_val(expression, int(left) < int(right)) + elif expression.type == BinaryOperationType.LESS_EQUAL: + set_val(expression, int(left) <= int(right)) + elif expression.type == BinaryOperationType.GREATER: + set_val(expression, int(left) > int(right)) + elif expression.type == BinaryOperationType.GREATER_EQUAL: + set_val(expression, int(left) >= int(right)) + elif expression.type == BinaryOperationType.EQUAL: + set_val(expression, int(left) == int(right)) + elif expression.type == BinaryOperationType.NOT_EQUAL: + set_val(expression, int(left) != int(right)) + # Convert boolean literals from string to bool + elif expression.type == BinaryOperationType.ANDAND: + set_val(expression, left == "true" and right == "true") + elif expression.type == BinaryOperationType.OROR: + set_val(expression, left == "true" or right == "true") else: raise NotConstant - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: UnaryOperation): # Case of uint a = -7; uint[-a] arr; if expression.type == UnaryOperationType.MINUS_PRE: expr = expression.expression @@ -72,15 +108,18 @@ def _post_unary_operation(self, expression): cf = ConstantFolding(expr, self._type) expr = cf.result() assert isinstance(expr, Literal) - set_val(expression, -convert_string_to_fraction(expr.value)) + set_val(expression, -convert_string_to_fraction(expr.converted_value)) else: raise NotConstant - def _post_literal(self, expression): - try: - set_val(expression, convert_string_to_fraction(expression.value)) - except ValueError as e: - raise NotConstant from e + def _post_literal(self, expression: Literal): + if expression.converted_value in ["true", "false"]: + set_val(expression, expression.converted_value) + else: + try: + set_val(expression, convert_string_to_fraction(expression.converted_value)) + except ValueError as e: + raise NotConstant from e def _post_assignement_operation(self, expression): raise NotConstant @@ -115,9 +154,12 @@ def _post_tuple_expression(self, expression): cf = ConstantFolding(expression.expressions[0], self._type) expr = cf.result() assert isinstance(expr, Literal) - set_val(expression, convert_string_to_fraction(expr.value)) + set_val(expression, convert_string_to_fraction(expr.converted_value)) return raise NotConstant def _post_type_conversion(self, expression): - raise NotConstant + cf = ConstantFolding(expression.expression, self._type) + expr = cf.result() + assert isinstance(expr, Literal) + set_val(expression, convert_string_to_fraction(expr.converted_value)) diff --git a/tests/constant_folding_binop.sol b/tests/constant_folding_binop.sol new file mode 100644 index 0000000000..923418ce71 --- /dev/null +++ b/tests/constant_folding_binop.sol @@ -0,0 +1,14 @@ +contract BinOp { + uint a = 1 & 2; + uint b = 1 ^ 2; + uint c = 1 | 2; + bool d = 2 < 1; + bool e = 1 > 2; + bool f = 1 <= 2; + bool g = 1 >= 2; + bool h = 1 == 2; + bool i = 1 != 2; + bool j = true && false; + bool k = true || false; + uint l = uint(1) - uint(2); +} \ No newline at end of file diff --git a/tests/test_constant_folding.py b/tests/test_constant_folding.py index efc3119a84..21517ddc4c 100644 --- a/tests/test_constant_folding.py +++ b/tests/test_constant_folding.py @@ -43,3 +43,59 @@ def test_constant_folding_rational(): variable_g = contract.get_state_variable_from_name("g") assert str(variable_g.type) == "int64" assert str(ConstantFolding(variable_g.expression, "int64").result()) == "-7" + + +def test_constant_folding_binary_expressions(): + sl = Slither("./tests/constant_folding_binop.sol") + contract = sl.get_contract_from_name("BinOp")[0] + + variable_a = contract.get_state_variable_from_name("a") + assert str(variable_a.type) == "uint256" + assert str(ConstantFolding(variable_a.expression, "uint256").result()) == "0" + + variable_b = contract.get_state_variable_from_name("b") + assert str(variable_b.type) == "uint256" + assert str(ConstantFolding(variable_b.expression, "uint256").result()) == "3" + + variable_c = contract.get_state_variable_from_name("c") + assert str(variable_c.type) == "uint256" + assert str(ConstantFolding(variable_c.expression, "uint256").result()) == "3" + + variable_d = contract.get_state_variable_from_name("d") + assert str(variable_d.type) == "bool" + assert str(ConstantFolding(variable_d.expression, "bool").result()) == "False" + + variable_e = contract.get_state_variable_from_name("e") + assert str(variable_e.type) == "bool" + assert str(ConstantFolding(variable_e.expression, "bool").result()) == "False" + + variable_f = contract.get_state_variable_from_name("f") + assert str(variable_f.type) == "bool" + assert str(ConstantFolding(variable_f.expression, "bool").result()) == "True" + + variable_g = contract.get_state_variable_from_name("g") + assert str(variable_g.type) == "bool" + assert str(ConstantFolding(variable_g.expression, "bool").result()) == "False" + + variable_h = contract.get_state_variable_from_name("h") + assert str(variable_h.type) == "bool" + assert str(ConstantFolding(variable_h.expression, "bool").result()) == "False" + + variable_i = contract.get_state_variable_from_name("i") + assert str(variable_i.type) == "bool" + assert str(ConstantFolding(variable_i.expression, "bool").result()) == "True" + + variable_j = contract.get_state_variable_from_name("j") + assert str(variable_j.type) == "bool" + assert str(ConstantFolding(variable_j.expression, "bool").result()) == "False" + + variable_k = contract.get_state_variable_from_name("k") + assert str(variable_k.type) == "bool" + assert str(ConstantFolding(variable_k.expression, "bool").result()) == "True" + + variable_l = contract.get_state_variable_from_name("l") + assert str(variable_l.type) == "uint256" + assert ( + str(ConstantFolding(variable_l.expression, "uint256").result()) + == "115792089237316195423570985008687907853269984665640564039457584007913129639935" + )