Skip to content

Commit

Permalink
Merge pull request #1508 from crytic/binary-constant-folding
Browse files Browse the repository at this point in the history
fold binary expressions with constant operands for fuzzing guidance
  • Loading branch information
montyly authored Jan 9, 2023
2 parents b0b1c6a + 811dd78 commit 2e41679
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 26 deletions.
21 changes: 15 additions & 6 deletions slither/core/expressions/literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,25 @@


class Literal(Expression):
def __init__(self, value, custom_type, subdenomination=None):
def __init__(
self, value: Union[int, str], custom_type: "Type", subdenomination: Optional[str] = None
):
super().__init__()
self._value: Union[int, str] = value
self._value = value
self._type = custom_type
self._subdenomination: Optional[str] = subdenomination
self._subdenomination = subdenomination

@property
def value(self) -> Union[int, str]:
return self._value

@property
def converted_value(self) -> Union[int, str]:
"""Return the value of the literal, accounting for subdenomination e.g. ether"""
if self.subdenomination:
return convert_subdenomination(self._value, self.subdenomination)
return self._value

@property
def type(self) -> "Type":
return self._type
Expand All @@ -28,17 +37,17 @@ def type(self) -> "Type":
def subdenomination(self) -> Optional[str]:
return self._subdenomination

def __str__(self):
def __str__(self) -> str:
if self.subdenomination:
return str(convert_subdenomination(self._value, self.subdenomination))
return str(self.converted_value)

if self.type in Int + Uint + Fixed + Ufixed + ["address"]:
return str(convert_string_to_int(self._value))

# be sure to handle any character
return str(self._value)

def __eq__(self, other):
def __eq__(self, other) -> bool:
if not isinstance(other, Literal):
return False
return (self.value, self.subdenomination) == (other.value, other.subdenomination)
6 changes: 6 additions & 0 deletions slither/printers/guidance/echidna.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from slither.slithir.operations.binary import Binary
from slither.slithir.variables import Constant
from slither.visitors.expression.constants_folding import ConstantFolding


def _get_name(f: Union[Function, Variable]) -> str:
Expand Down Expand Up @@ -175,6 +176,11 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n
all_cst_used_in_binary[str(ir.type)].append(
ConstantValue(str(r.value), str(r.type))
)
if isinstance(ir.variable_left, Constant) and isinstance(ir.variable_right, Constant):
if ir.lvalue:
type_ = ir.lvalue.type
cst = ConstantFolding(ir.expression, type_).result()
all_cst_used.append(ConstantValue(str(cst.value), str(type_)))
if isinstance(ir, TypeConversion):
if isinstance(ir.variable, Constant):
all_cst_used.append(ConstantValue(str(ir.variable.value), str(ir.type)))
Expand Down
82 changes: 62 additions & 20 deletions slither/visitors/expression/constants_folding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from slither.core.expressions import BinaryOperationType, Literal, UnaryOperationType
from fractions import Fraction
from slither.core.expressions import (
BinaryOperationType,
Literal,
UnaryOperationType,
Identifier,
BinaryOperation,
UnaryOperation,
)
from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int
from slither.visitors.expression.expression import ExpressionVisitor

Expand Down Expand Up @@ -27,19 +35,26 @@ def __init__(self, expression, custom_type):
super().__init__(expression)

def result(self):
return Literal(int(get_val(self._expression)), self._type)

def _post_identifier(self, expression):
value = get_val(self._expression)
if isinstance(value, Fraction):
value = int(value)
# emulate 256-bit wrapping
if str(self._type).startswith("uint"):
value = value & (2**256 - 1)
return Literal(value, self._type)

def _post_identifier(self, expression: Identifier):
if not expression.value.is_constant:
raise NotConstant
expr = expression.value.expression
# assumption that we won't have infinite loop
if not isinstance(expr, Literal):
cf = ConstantFolding(expr, self._type)
expr = cf.result()
set_val(expression, convert_string_to_int(expr.value))
set_val(expression, convert_string_to_int(expr.converted_value))

def _post_binary_operation(self, expression):
# pylint: disable=too-many-branches
def _post_binary_operation(self, expression: BinaryOperation):
left = get_val(expression.expression_left)
right = get_val(expression.expression_right)
if expression.type == BinaryOperationType.POWER:
Expand All @@ -53,34 +68,58 @@ def _post_binary_operation(self, expression):
elif expression.type == BinaryOperationType.ADDITION:
set_val(expression, left + right)
elif expression.type == BinaryOperationType.SUBTRACTION:
if (left - right) < 0:
# Could trigger underflow
raise NotConstant
set_val(expression, left - right)
# Convert to int for operations not supported by Fraction
elif expression.type == BinaryOperationType.LEFT_SHIFT:
set_val(expression, left << right)
set_val(expression, int(left) << int(right))
elif expression.type == BinaryOperationType.RIGHT_SHIFT:
set_val(expression, left >> right)
set_val(expression, int(left) >> int(right))
elif expression.type == BinaryOperationType.AND:
set_val(expression, int(left) & int(right))
elif expression.type == BinaryOperationType.CARET:
set_val(expression, int(left) ^ int(right))
elif expression.type == BinaryOperationType.OR:
set_val(expression, int(left) | int(right))
elif expression.type == BinaryOperationType.LESS:
set_val(expression, int(left) < int(right))
elif expression.type == BinaryOperationType.LESS_EQUAL:
set_val(expression, int(left) <= int(right))
elif expression.type == BinaryOperationType.GREATER:
set_val(expression, int(left) > int(right))
elif expression.type == BinaryOperationType.GREATER_EQUAL:
set_val(expression, int(left) >= int(right))
elif expression.type == BinaryOperationType.EQUAL:
set_val(expression, int(left) == int(right))
elif expression.type == BinaryOperationType.NOT_EQUAL:
set_val(expression, int(left) != int(right))
# Convert boolean literals from string to bool
elif expression.type == BinaryOperationType.ANDAND:
set_val(expression, left == "true" and right == "true")
elif expression.type == BinaryOperationType.OROR:
set_val(expression, left == "true" or right == "true")
else:
raise NotConstant

def _post_unary_operation(self, expression):
def _post_unary_operation(self, expression: UnaryOperation):
# Case of uint a = -7; uint[-a] arr;
if expression.type == UnaryOperationType.MINUS_PRE:
expr = expression.expression
if not isinstance(expr, Literal):
cf = ConstantFolding(expr, self._type)
expr = cf.result()
assert isinstance(expr, Literal)
set_val(expression, -convert_string_to_fraction(expr.value))
set_val(expression, -convert_string_to_fraction(expr.converted_value))
else:
raise NotConstant

def _post_literal(self, expression):
try:
set_val(expression, convert_string_to_fraction(expression.value))
except ValueError as e:
raise NotConstant from e
def _post_literal(self, expression: Literal):
if expression.converted_value in ["true", "false"]:
set_val(expression, expression.converted_value)
else:
try:
set_val(expression, convert_string_to_fraction(expression.converted_value))
except ValueError as e:
raise NotConstant from e

def _post_assignement_operation(self, expression):
raise NotConstant
Expand Down Expand Up @@ -115,9 +154,12 @@ def _post_tuple_expression(self, expression):
cf = ConstantFolding(expression.expressions[0], self._type)
expr = cf.result()
assert isinstance(expr, Literal)
set_val(expression, convert_string_to_fraction(expr.value))
set_val(expression, convert_string_to_fraction(expr.converted_value))
return
raise NotConstant

def _post_type_conversion(self, expression):
raise NotConstant
cf = ConstantFolding(expression.expression, self._type)
expr = cf.result()
assert isinstance(expr, Literal)
set_val(expression, convert_string_to_fraction(expr.converted_value))
14 changes: 14 additions & 0 deletions tests/constant_folding_binop.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
contract BinOp {
uint a = 1 & 2;
uint b = 1 ^ 2;
uint c = 1 | 2;
bool d = 2 < 1;
bool e = 1 > 2;
bool f = 1 <= 2;
bool g = 1 >= 2;
bool h = 1 == 2;
bool i = 1 != 2;
bool j = true && false;
bool k = true || false;
uint l = uint(1) - uint(2);
}
56 changes: 56 additions & 0 deletions tests/test_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,59 @@ def test_constant_folding_rational():
variable_g = contract.get_state_variable_from_name("g")
assert str(variable_g.type) == "int64"
assert str(ConstantFolding(variable_g.expression, "int64").result()) == "-7"


def test_constant_folding_binary_expressions():
sl = Slither("./tests/constant_folding_binop.sol")
contract = sl.get_contract_from_name("BinOp")[0]

variable_a = contract.get_state_variable_from_name("a")
assert str(variable_a.type) == "uint256"
assert str(ConstantFolding(variable_a.expression, "uint256").result()) == "0"

variable_b = contract.get_state_variable_from_name("b")
assert str(variable_b.type) == "uint256"
assert str(ConstantFolding(variable_b.expression, "uint256").result()) == "3"

variable_c = contract.get_state_variable_from_name("c")
assert str(variable_c.type) == "uint256"
assert str(ConstantFolding(variable_c.expression, "uint256").result()) == "3"

variable_d = contract.get_state_variable_from_name("d")
assert str(variable_d.type) == "bool"
assert str(ConstantFolding(variable_d.expression, "bool").result()) == "False"

variable_e = contract.get_state_variable_from_name("e")
assert str(variable_e.type) == "bool"
assert str(ConstantFolding(variable_e.expression, "bool").result()) == "False"

variable_f = contract.get_state_variable_from_name("f")
assert str(variable_f.type) == "bool"
assert str(ConstantFolding(variable_f.expression, "bool").result()) == "True"

variable_g = contract.get_state_variable_from_name("g")
assert str(variable_g.type) == "bool"
assert str(ConstantFolding(variable_g.expression, "bool").result()) == "False"

variable_h = contract.get_state_variable_from_name("h")
assert str(variable_h.type) == "bool"
assert str(ConstantFolding(variable_h.expression, "bool").result()) == "False"

variable_i = contract.get_state_variable_from_name("i")
assert str(variable_i.type) == "bool"
assert str(ConstantFolding(variable_i.expression, "bool").result()) == "True"

variable_j = contract.get_state_variable_from_name("j")
assert str(variable_j.type) == "bool"
assert str(ConstantFolding(variable_j.expression, "bool").result()) == "False"

variable_k = contract.get_state_variable_from_name("k")
assert str(variable_k.type) == "bool"
assert str(ConstantFolding(variable_k.expression, "bool").result()) == "True"

variable_l = contract.get_state_variable_from_name("l")
assert str(variable_l.type) == "uint256"
assert (
str(ConstantFolding(variable_l.expression, "uint256").result())
== "115792089237316195423570985008687907853269984665640564039457584007913129639935"
)

0 comments on commit 2e41679

Please sign in to comment.