Skip to content

Commit

Permalink
Merge pull request #1559 from crytic/dev-fix-yul-parsing
Browse files Browse the repository at this point in the history
Improve yul parsing
  • Loading branch information
montyly authored Jan 9, 2023
2 parents 45c5ed9 + 17d3e8f commit aee2a78
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 45 deletions.
4 changes: 2 additions & 2 deletions slither/core/children/child_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@


class ChildFunction:
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._function = None

def set_function(self, function: "Function"):
def set_function(self, function: "Function") -> None:
self._function = function

@property
Expand Down
4 changes: 2 additions & 2 deletions slither/core/expressions/identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@


class Identifier(ExpressionTyped):
def __init__(self, value):
def __init__(self, value) -> None:
super().__init__()
self._value: "Variable" = value

@property
def value(self) -> "Variable":
return self._value

def __str__(self):
def __str__(self) -> str:
return str(self._value)
4 changes: 2 additions & 2 deletions slither/core/variables/local_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@


class LocalVariable(ChildFunction, Variable):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._location: Optional[str] = None

def set_location(self, loc: str):
def set_location(self, loc: str) -> None:
self._location = loc

@property
Expand Down
1 change: 0 additions & 1 deletion slither/solc_parsing/declarations/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,6 @@ def _new_yul_block(
node,
[self._function.name, f"asm_{len(self._node_to_yulobject)}"],
scope,
parent_func=self._function,
)
self._node_to_yulobject[node] = yul_object
return yul_object
Expand Down
4 changes: 2 additions & 2 deletions slither/solc_parsing/yul/evm_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,9 @@ def format_function_descriptor(name):


class YulBuiltin: # pylint: disable=too-few-public-methods
def __init__(self, name):
def __init__(self, name: str) -> None:
self._name = name

@property
def name(self):
def name(self) -> str:
return self._name
86 changes: 50 additions & 36 deletions slither/solc_parsing/yul/parse_yul.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
UnaryOperation,
)
from slither.core.expressions.expression import Expression
from slither.core.scope.scope import FileScope
from slither.core.solidity_types import ElementaryType
from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.local_variable import LocalVariable
Expand Down Expand Up @@ -51,30 +52,35 @@ def __init__(self, node: Node, scope: "YulScope"):
def underlying_node(self) -> Node:
return self._node

def add_unparsed_expression(self, expression: Dict):
def add_unparsed_expression(self, expression: Dict) -> None:
assert self._unparsed_expression is None
self._unparsed_expression = expression

def analyze_expressions(self):
def analyze_expressions(self) -> None:
if self._node.type == NodeType.VARIABLE and not self._node.expression:
self._node.add_expression(self._node.variable_declaration.expression)
expression = self._node.variable_declaration.expression
if expression:
self._node.add_expression(expression)
if self._unparsed_expression:
expression = parse_yul(self._scope, self, self._unparsed_expression)
self._node.add_expression(expression)
if expression:
self._node.add_expression(expression)

if self._node.expression:
if self._node.type == NodeType.VARIABLE:
# Update the expression to be an assignement to the variable
_expression = AssignmentOperation(
Identifier(self._node.variable_declaration),
self._node.expression,
AssignmentOperationType.ASSIGN,
self._node.variable_declaration.type,
)
_expression.set_offset(
self._node.expression.source_mapping, self._node.compilation_unit
)
self._node.add_expression(_expression, bypass_verif_empty=True)
variable_declaration = self._node.variable_declaration
if variable_declaration:
_expression = AssignmentOperation(
Identifier(self._node.variable_declaration),
self._node.expression,
AssignmentOperationType.ASSIGN,
variable_declaration.type,
)
_expression.set_offset(
self._node.expression.source_mapping, self._node.compilation_unit
)
self._node.add_expression(_expression, bypass_verif_empty=True)

expression = self._node.expression
read_var = ReadVar(expression)
Expand Down Expand Up @@ -122,13 +128,13 @@ class YulScope(metaclass=abc.ABCMeta):
]

def __init__(
self, contract: Optional[Contract], yul_id: List[str], parent_func: Function = None
):
self, contract: Optional[Contract], yul_id: List[str], parent_func: Function
) -> None:
self._contract = contract
self._id: List[str] = yul_id
self._yul_local_variables: List[YulLocalVariable] = []
self._yul_local_functions: List[YulFunction] = []
self._parent_func = parent_func
self._parent_func: Function = parent_func

@property
def id(self) -> List[str]:
Expand All @@ -155,10 +161,14 @@ def function(self) -> Function:
def new_node(self, node_type: NodeType, src: Union[str, Dict]) -> YulNode:
pass

def add_yul_local_variable(self, var):
@property
def file_scope(self) -> FileScope:
return self._parent_func.file_scope

def add_yul_local_variable(self, var: "YulLocalVariable") -> None:
self._yul_local_variables.append(var)

def get_yul_local_variable_from_name(self, variable_name):
def get_yul_local_variable_from_name(self, variable_name: str) -> Optional["YulLocalVariable"]:
return next(
(
v
Expand All @@ -168,10 +178,10 @@ def get_yul_local_variable_from_name(self, variable_name):
None,
)

def add_yul_local_function(self, func):
def add_yul_local_function(self, func: "YulFunction") -> None:
self._yul_local_functions.append(func)

def get_yul_local_function_from_name(self, func_name):
def get_yul_local_function_from_name(self, func_name: str) -> Optional["YulLocalVariable"]:
return next(
(v for v in self._yul_local_functions if v.underlying.name == func_name),
None,
Expand Down Expand Up @@ -242,7 +252,7 @@ def underlying(self) -> Function:
def function(self) -> Function:
return self._function

def convert_body(self):
def convert_body(self) -> None:
node = self.new_node(NodeType.ENTRYPOINT, self._ast["src"])
link_underlying_nodes(self._entrypoint, node)

Expand All @@ -258,7 +268,7 @@ def convert_body(self):

convert_yul(self, node, self._ast["body"], self.node_scope)

def parse_body(self):
def parse_body(self) -> None:
for node in self._nodes:
node.analyze_expressions()

Expand Down Expand Up @@ -289,9 +299,8 @@ def __init__(
entrypoint: Node,
yul_id: List[str],
node_scope: Union[Scope, Function],
**kwargs,
):
super().__init__(contract, yul_id, **kwargs)
super().__init__(contract, yul_id, entrypoint.function)

self._entrypoint: YulNode = YulNode(entrypoint, self)
self._nodes: List[YulNode] = []
Expand All @@ -318,7 +327,7 @@ def new_node(self, node_type: NodeType, src: Union[str, Dict]) -> YulNode:
def convert(self, ast: Dict) -> YulNode:
return convert_yul(self, self._entrypoint, ast, self.node_scope)

def analyze_expressions(self):
def analyze_expressions(self) -> None:
for node in self._nodes:
node.analyze_expressions()

Expand Down Expand Up @@ -361,18 +370,22 @@ def convert_yul_function_definition(
while not isinstance(top_node_scope, Function):
top_node_scope = top_node_scope.father

func: Union[FunctionTopLevel, FunctionContract]
if isinstance(top_node_scope, FunctionTopLevel):
scope = root.contract.file_scope
scope = root.file_scope
func = FunctionTopLevel(root.compilation_unit, scope)
# Note: we do not add the function in the scope
# While its a top level function, it is not accessible outside of the function definition
# In practice we should probably have a specific function type for function defined within a function
else:
func = FunctionContract(root.compilation_unit)

func.function_language = FunctionLanguage.Yul
yul_function = YulFunction(func, root, ast, node_scope)

root.contract.add_function(func)
if root.contract:
root.contract.add_function(func)

root.compilation_unit.add_function(func)
root.add_yul_local_function(yul_function)

Expand Down Expand Up @@ -774,14 +787,15 @@ def parse_yul_identifier(root: YulScope, _node: YulNode, ast: Dict) -> Optional[
# check function-scoped variables
parent_func = root.parent_func
if parent_func:
variable = parent_func.get_local_variable_from_name(name)
if variable:
return Identifier(variable)
local_variable = parent_func.get_local_variable_from_name(name)
if local_variable:
return Identifier(local_variable)

if isinstance(parent_func, FunctionContract):
variable = parent_func.contract.get_state_variable_from_name(name)
if variable:
return Identifier(variable)
assert parent_func.contract
state_variable = parent_func.contract.get_state_variable_from_name(name)
if state_variable:
return Identifier(state_variable)

# check yul-scoped variable
variable = root.get_yul_local_variable_from_name(name)
Expand All @@ -798,7 +812,7 @@ def parse_yul_identifier(root: YulScope, _node: YulNode, ast: Dict) -> Optional[
if magic_suffix:
return magic_suffix

ret, _ = find_top_level(name, root.contract.file_scope)
ret, _ = find_top_level(name, root.file_scope)
if ret:
return Identifier(ret)

Expand Down Expand Up @@ -840,7 +854,7 @@ def parse_yul_unsupported(_root: YulScope, _node: YulNode, ast: Dict) -> Optiona


def parse_yul(root: YulScope, node: YulNode, ast: Dict) -> Optional[Expression]:
op = parsers.get(ast["nodeType"], parse_yul_unsupported)(root, node, ast)
op: Expression = parsers.get(ast["nodeType"], parse_yul_unsupported)(root, node, ast)
if op:
op.set_offset(ast["src"], root.compilation_unit)
return op
Expand Down
Binary file not shown.
16 changes: 16 additions & 0 deletions tests/ast-parsing/yul-top-level-0.8.0.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
function top_level_yul(int256 c) pure returns (uint result) {
assembly {
function internal_yul(a) -> b {
b := a
}

result := internal_yul(c)
}
}


contract Test {
function test() public{
top_level_yul(10);
}
}
1 change: 1 addition & 0 deletions tests/test_ast_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ def make_version(minor: int, patch_min: int, patch_max: int) -> List[str]:
Test("using-for-global-0.8.0.sol", ["0.8.15"]),
Test("library_event-0.8.16.sol", ["0.8.16"]),
Test("top-level-struct-0.8.0.sol", ["0.8.0"]),
Test("yul-top-level-0.8.0.sol", ["0.8.0"]),
Test("complex_imports/import_aliases_issue_1319/test.sol", ["0.5.12"]),
]
# create the output folder if needed
Expand Down

0 comments on commit aee2a78

Please sign in to comment.