Skip to content

Commit

Permalink
Merge pull request #2051 from crytic/dev-fix-enum-max-min
Browse files Browse the repository at this point in the history
Fix enum.max/min when enum in other contract
  • Loading branch information
montyly authored Sep 15, 2023
2 parents cc9e65f + d128b6d commit 46630b7
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 4 deletions.
25 changes: 21 additions & 4 deletions slither/visitors/slithir/expression_to_slithir.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ def _post_member_access(self, expression: MemberAccess) -> None:
# Look for type(X).max / min
# Because we looked at the AST structure, we need to look into the nested expression
# Hopefully this is always on a direct sub field, and there is no weird construction
# pylint: disable=too-many-nested-blocks
if isinstance(expression.expression, CallExpression) and expression.member_name in [
"min",
"max",
Expand All @@ -474,10 +475,22 @@ def _post_member_access(self, expression: MemberAccess) -> None:
constant_type = type_found
else:
# type(enum).max/min
assert isinstance(type_expression_found, Identifier)
type_found_in_expression = type_expression_found.value
assert isinstance(type_found_in_expression, (EnumContract, EnumTopLevel))
type_found = UserDefinedType(type_found_in_expression)
# Case when enum is in another contract e.g. type(C.E).max
if isinstance(type_expression_found, MemberAccess):
contract = type_expression_found.expression.value
assert isinstance(contract, Contract)
for enum in contract.enums:
if enum.name == type_expression_found.member_name:
type_found_in_expression = enum
type_found = UserDefinedType(enum)
break
else:
assert isinstance(type_expression_found, Identifier)
type_found_in_expression = type_expression_found.value
assert isinstance(
type_found_in_expression, (EnumContract, EnumTopLevel)
)
type_found = UserDefinedType(type_found_in_expression)
constant_type = None
min_value = type_found_in_expression.min
max_value = type_found_in_expression.max
Expand Down Expand Up @@ -535,6 +548,10 @@ def _post_member_access(self, expression: MemberAccess) -> None:
if expression.member_name in expr.custom_errors_as_dict:
set_val(expression, expr.custom_errors_as_dict[expression.member_name])
return
# Lookup enums when in a different contract e.g. C.E
if str(expression) in expr.enums_as_dict:
set_val(expression, expr.enums_as_dict[str(expression)])
return

val_ref = ReferenceVariable(self._node)
member = Member(expr, Constant(expression.member_name), val_ref)
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/solc_parsing/test_ast_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ def make_version(minor: int, patch_min: int, patch_max: int) -> List[str]:
),
Test("user_defined_operators-0.8.19.sol", ["0.8.19"]),
Test("type-aliases.sol", ["0.8.19"]),
Test("enum-max-min.sol", ["0.8.19"]),
]
# create the output folder if needed
try:
Expand Down
Binary file not shown.
37 changes: 37 additions & 0 deletions tests/e2e/solc_parsing/test_data/enum-max-min.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

library Q {
enum E {a}
}

contract Z {
enum E {a,b}
}

contract D {
enum E {a,b,c}

function a() public returns(uint){
return uint(type(E).max);
}

function b() public returns(uint){
return uint(type(Q.E).max);
}

function c() public returns(uint){
return uint(type(Z.E).max);
}

function d() public returns(uint){
return uint(type(E).min);
}

function e() public returns(uint){
return uint(type(Q.E).min);
}

function f() public returns(uint){
return uint(type(Z.E).min);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"Q": {},
"Z": {},
"D": {
"a()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n",
"b()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n",
"c()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n",
"d()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n",
"e()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n",
"f()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: RETURN 1\n\"];\n}\n"
}
}
37 changes: 37 additions & 0 deletions tests/unit/slithir/test_data/enum_max_min.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

library Q {
enum E {a}
}

contract Z {
enum E {a,b}
}

contract D {
enum E {a,b,c}

function a() public returns(uint){
return uint(type(E).max);
}

function b() public returns(uint){
return uint(type(Q.E).max);
}

function c() public returns(uint){
return uint(type(Z.E).max);
}

function d() public returns(uint){
return uint(type(E).min);
}

function e() public returns(uint){
return uint(type(Q.E).min);
}

function f() public returns(uint){
return uint(type(Z.E).min);
}

}
67 changes: 67 additions & 0 deletions tests/unit/slithir/test_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from pathlib import Path
from slither import Slither
from slither.slithir.operations import Assignment
from slither.slithir.variables import Constant

TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data"


def test_enum_max_min(solc_binary_path) -> None:
solc_path = solc_binary_path("0.8.19")
slither = Slither(Path(TEST_DATA_DIR, "enum_max_min.sol").as_posix(), solc=solc_path)

contract = slither.get_contract_from_name("D")[0]

f = contract.get_function_from_full_name("a()")
# TMP_1(uint256) := 2(uint256)
assignment = f.slithir_operations[1]
assert (
isinstance(assignment, Assignment)
and isinstance(assignment.rvalue, Constant)
and assignment.rvalue.value == 2
)

f = contract.get_function_from_full_name("b()")
# TMP_4(uint256) := 0(uint256)
assignment = f.slithir_operations[1]
assert (
isinstance(assignment, Assignment)
and isinstance(assignment.rvalue, Constant)
and assignment.rvalue.value == 0
)

f = contract.get_function_from_full_name("c()")
# TMP_7(uint256) := 1(uint256)
assignment = f.slithir_operations[1]
assert (
isinstance(assignment, Assignment)
and isinstance(assignment.rvalue, Constant)
and assignment.rvalue.value == 1
)

f = contract.get_function_from_full_name("d()")
# TMP_10(uint256) := 0(uint256)
assignment = f.slithir_operations[1]
assert (
isinstance(assignment, Assignment)
and isinstance(assignment.rvalue, Constant)
and assignment.rvalue.value == 0
)

f = contract.get_function_from_full_name("e()")
# TMP_13(uint256) := 0(uint256)
assignment = f.slithir_operations[1]
assert (
isinstance(assignment, Assignment)
and isinstance(assignment.rvalue, Constant)
and assignment.rvalue.value == 0
)

f = contract.get_function_from_full_name("f()")
# TMP_16(uint256) := 0(uint256)
assignment = f.slithir_operations[1]
assert (
isinstance(assignment, Assignment)
and isinstance(assignment.rvalue, Constant)
and assignment.rvalue.value == 0
)

0 comments on commit 46630b7

Please sign in to comment.