diff --git a/slither/core/expressions/new_array.py b/slither/core/expressions/new_array.py index 162b48f1ec..b1d407793f 100644 --- a/slither/core/expressions/new_array.py +++ b/slither/core/expressions/new_array.py @@ -1,32 +1,23 @@ -from typing import Union, TYPE_CHECKING - +from typing import TYPE_CHECKING from slither.core.expressions.expression import Expression -from slither.core.solidity_types.type import Type if TYPE_CHECKING: - from slither.core.solidity_types.elementary_type import ElementaryType - from slither.core.solidity_types.type_alias import TypeAliasTopLevel + from slither.core.solidity_types.array_type import ArrayType class NewArray(Expression): - - # note: dont conserve the size of the array if provided - def __init__( - self, depth: int, array_type: Union["TypeAliasTopLevel", "ElementaryType"] - ) -> None: + def __init__(self, array_type: "ArrayType") -> None: super().__init__() - assert isinstance(array_type, Type) - self._depth: int = depth - self._array_type: Type = array_type + # pylint: disable=import-outside-toplevel + from slither.core.solidity_types.array_type import ArrayType - @property - def array_type(self) -> Type: - return self._array_type + assert isinstance(array_type, ArrayType) + self._array_type = array_type @property - def depth(self) -> int: - return self._depth + def array_type(self) -> "ArrayType": + return self._array_type def __str__(self): - return "new " + str(self._array_type) + "[]" * self._depth + return "new " + str(self._array_type) diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index 8bc9b3247b..665a7c8f9f 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -1077,7 +1077,7 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]) -> Union[Call, return op if isinstance(ins.ori, TmpNewArray): - n = NewArray(ins.ori.depth, ins.ori.array_type, ins.lvalue) + n = NewArray(ins.ori.array_type, ins.lvalue) n.set_expression(ins.expression) return n diff --git a/slither/slithir/operations/init_array.py b/slither/slithir/operations/init_array.py index 4f6b2f9faf..e0754c7705 100644 --- a/slither/slithir/operations/init_array.py +++ b/slither/slithir/operations/init_array.py @@ -41,7 +41,7 @@ def __str__(self): def convert(elem): if isinstance(elem, (list,)): return str([convert(x) for x in elem]) - return str(elem) + return f"{elem}({elem.type})" init_values = convert(self.init_values) return f"{self.lvalue}({self.lvalue.type}) = {init_values}" diff --git a/slither/slithir/operations/new_array.py b/slither/slithir/operations/new_array.py index 34e8f99f7b..39b989459f 100644 --- a/slither/slithir/operations/new_array.py +++ b/slither/slithir/operations/new_array.py @@ -1,10 +1,10 @@ from typing import List, Union, TYPE_CHECKING -from slither.slithir.operations.lvalue import OperationWithLValue + +from slither.core.solidity_types.array_type import ArrayType from slither.slithir.operations.call import Call -from slither.core.solidity_types.type import Type +from slither.slithir.operations.lvalue import OperationWithLValue if TYPE_CHECKING: - from slither.core.solidity_types.type_alias import TypeAliasTopLevel from slither.slithir.variables.constant import Constant from slither.slithir.variables.temporary import TemporaryVariable from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA @@ -13,32 +13,24 @@ class NewArray(Call, OperationWithLValue): def __init__( self, - depth: int, - array_type: "TypeAliasTopLevel", + array_type: "ArrayType", lvalue: Union["TemporaryVariableSSA", "TemporaryVariable"], ) -> None: super().__init__() - assert isinstance(array_type, Type) - self._depth = depth + assert isinstance(array_type, ArrayType) self._array_type = array_type self._lvalue = lvalue @property - def array_type(self) -> "TypeAliasTopLevel": + def array_type(self) -> "ArrayType": return self._array_type @property def read(self) -> List["Constant"]: return self._unroll(self.arguments) - @property - def depth(self) -> int: - return self._depth - def __str__(self): args = [str(a) for a in self.arguments] lvalue = self.lvalue - return ( - f"{lvalue}({lvalue.type}) = new {self.array_type}{'[]' * self.depth}({','.join(args)})" - ) + return f"{lvalue}{lvalue.type}) = new {self.array_type}({','.join(args)})" diff --git a/slither/slithir/tmp_operations/tmp_new_array.py b/slither/slithir/tmp_operations/tmp_new_array.py index efbdb6242d..04acb4b9ed 100644 --- a/slither/slithir/tmp_operations/tmp_new_array.py +++ b/slither/slithir/tmp_operations/tmp_new_array.py @@ -6,13 +6,11 @@ class TmpNewArray(OperationWithLValue): def __init__( self, - depth: int, array_type: Type, lvalue: TemporaryVariable, ) -> None: super().__init__() assert isinstance(array_type, Type) - self._depth = depth self._array_type = array_type self._lvalue = lvalue @@ -24,9 +22,5 @@ def array_type(self) -> Type: def read(self): return [] - @property - def depth(self) -> int: - return self._depth - def __str__(self): - return f"{self.lvalue} = new {self.array_type}{'[]' * self._depth}" + return f"{self.lvalue} = new {self.array_type}" diff --git a/slither/slithir/utils/ssa.py b/slither/slithir/utils/ssa.py index 9a180d14f6..4c958798b4 100644 --- a/slither/slithir/utils/ssa.py +++ b/slither/slithir/utils/ssa.py @@ -789,10 +789,9 @@ def copy_ir(ir: Operation, *instances) -> Operation: variable_right = get_variable(ir, lambda x: x.variable_right, *instances) return Member(variable_left, variable_right, lvalue) if isinstance(ir, NewArray): - depth = ir.depth array_type = ir.array_type lvalue = get_variable(ir, lambda x: x.lvalue, *instances) - new_ir = NewArray(depth, array_type, lvalue) + new_ir = NewArray(array_type, lvalue) new_ir.arguments = get_rec_values(ir, lambda x: x.arguments, *instances) return new_ir if isinstance(ir, NewElementaryType): diff --git a/slither/solc_parsing/expressions/expression_parsing.py b/slither/solc_parsing/expressions/expression_parsing.py index d0dc4c7e02..945a60b8fe 100644 --- a/slither/solc_parsing/expressions/expression_parsing.py +++ b/slither/solc_parsing/expressions/expression_parsing.py @@ -559,37 +559,9 @@ def parse_expression(expression: Dict, caller_context: CallerContextExpression) type_name = children[0] if type_name[caller_context.get_key()] == "ArrayTypeName": - depth = 0 - while type_name[caller_context.get_key()] == "ArrayTypeName": - # Note: dont conserve the size of the array if provided - # We compute it directly - if is_compact_ast: - type_name = type_name["baseType"] - else: - type_name = type_name["children"][0] - depth += 1 - if type_name[caller_context.get_key()] == "ElementaryTypeName": - if is_compact_ast: - array_type = ElementaryType(type_name["name"]) - else: - array_type = ElementaryType(type_name["attributes"]["name"]) - elif type_name[caller_context.get_key()] == "UserDefinedTypeName": - if is_compact_ast: - if "name" not in type_name: - name_type = type_name["pathNode"]["name"] - else: - name_type = type_name["name"] - - array_type = parse_type(UnknownType(name_type), caller_context) - else: - array_type = parse_type( - UnknownType(type_name["attributes"]["name"]), caller_context - ) - elif type_name[caller_context.get_key()] == "FunctionTypeName": - array_type = parse_type(type_name, caller_context) - else: - raise ParsingError(f"Incorrect type array {type_name}") - array = NewArray(depth, array_type) + array_type = parse_type(type_name, caller_context) + assert isinstance(array_type, ArrayType) + array = NewArray(array_type) array.set_offset(src, caller_context.compilation_unit) return array diff --git a/slither/visitors/expression/expression_printer.py b/slither/visitors/expression/expression_printer.py index 601627c028..61af4f1c9f 100644 --- a/slither/visitors/expression/expression_printer.py +++ b/slither/visitors/expression/expression_printer.py @@ -76,8 +76,7 @@ def _post_member_access(self, expression: expressions.MemberAccess) -> None: def _post_new_array(self, expression: expressions.NewArray) -> None: array = str(expression.array_type) - depth = expression.depth - val = f"new {array}{'[]' * depth}" + val = f"new {array}" set_val(expression, val) def _post_new_contract(self, expression: expressions.NewContract) -> None: diff --git a/slither/visitors/slithir/expression_to_slithir.py b/slither/visitors/slithir/expression_to_slithir.py index e4fa5a1b18..005ad81a44 100644 --- a/slither/visitors/slithir/expression_to_slithir.py +++ b/slither/visitors/slithir/expression_to_slithir.py @@ -532,7 +532,7 @@ def _post_member_access(self, expression: MemberAccess) -> None: def _post_new_array(self, expression: NewArray) -> None: val = TemporaryVariable(self._node) - operation = TmpNewArray(expression.depth, expression.array_type, val) + operation = TmpNewArray(expression.array_type, val) operation.set_expression(expression) self._result.append(operation) set_val(expression, val) diff --git a/tests/unit/slithir/test_ssa_generation.py b/tests/unit/slithir/test_ssa_generation.py index 62218e41aa..3c7e84973f 100644 --- a/tests/unit/slithir/test_ssa_generation.py +++ b/tests/unit/slithir/test_ssa_generation.py @@ -1,15 +1,17 @@ # # pylint: disable=too-many-lines import pathlib -from collections import defaultdict from argparse import ArgumentTypeError +from collections import defaultdict from inspect import getsourcefile from typing import Union, List, Dict, Callable import pytest from solc_select.solc_select import valid_version as solc_valid_version + from slither import Slither from slither.core.cfg.node import Node, NodeType from slither.core.declarations import Function, Contract +from slither.core.solidity_types import ArrayType from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable from slither.slithir.operations import ( @@ -1050,6 +1052,34 @@ def test_issue_1748(slither_from_source): assert isinstance(assign_op, InitArray) +def test_issue_1776(slither_from_source): + source = """ + contract Contract { + function foo() public returns (uint) { + uint[5][10][] memory arr = new uint[5][10][](2); + return 0; + } + } + """ + with slither_from_source(source) as slither: + c = slither.get_contract_from_name("Contract")[0] + f = c.functions[0] + operations = f.slithir_operations + new_op = operations[0] + lvalue = new_op.lvalue + lvalue_type = lvalue.type + assert isinstance(lvalue_type, ArrayType) + assert lvalue_type.is_dynamic + lvalue_type1 = lvalue_type.type + assert isinstance(lvalue_type1, ArrayType) + assert not lvalue_type1.is_dynamic + assert lvalue_type1.length_value.value == "10" + lvalue_type2 = lvalue_type1.type + assert isinstance(lvalue_type2, ArrayType) + assert not lvalue_type2.is_dynamic + assert lvalue_type2.length_value.value == "5" + + def test_issue_1846_ternary_in_if(slither_from_source): source = """ contract Contract {