diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py index 951bf4702c..bb8344d8fb 100644 --- a/slither/utils/code_generation.py +++ b/slither/utils/code_generation.py @@ -1,62 +1,149 @@ # Functions for generating Solidity code from typing import TYPE_CHECKING, Optional -from slither.utils.type import convert_type_for_solidity_signature_to_string +from slither.utils.type import ( + convert_type_for_solidity_signature_to_string, + export_nested_types_from_variable, + export_return_type_from_variable, +) +from slither.core.solidity_types import ( + Type, + UserDefinedType, + MappingType, + ArrayType, + ElementaryType, +) +from slither.core.declarations import Structure, Enum, Contract if TYPE_CHECKING: - from slither.core.declarations import FunctionContract, Structure, Contract + from slither.core.declarations import FunctionContract, CustomErrorContract + from slither.core.variables.state_variable import StateVariable + from slither.core.variables.local_variable import LocalVariable -def generate_interface(contract: "Contract") -> str: +# pylint: disable=too-many-arguments +def generate_interface( + contract: "Contract", + unroll_structs: bool = True, + include_events: bool = True, + include_errors: bool = True, + include_enums: bool = True, + include_structs: bool = True, +) -> str: """ Generates code for a Solidity interface to the contract. Args: - contract: A Contract object + contract: A Contract object. + unroll_structs: Whether to use structures' underlying types instead of the user-defined type (default: True). + include_events: Whether to include event signatures in the interface (default: True). + include_errors: Whether to include custom error signatures in the interface (default: True). + include_enums: Whether to include enum definitions in the interface (default: True). + include_structs: Whether to include struct definitions in the interface (default: True). Returns: A string with the code for an interface, with function stubs for all public or external functions and state variables, as well as any events, custom errors and/or structs declared in the contract. """ interface = f"interface I{contract.name} {{\n" - for event in contract.events: - name, args = event.signature - interface += f" event {name}({', '.join(args)});\n" - for error in contract.custom_errors: - args = [ - convert_type_for_solidity_signature_to_string(arg.type) - .replace("(", "") - .replace(")", "") - for arg in error.parameters - ] - interface += f" error {error.name}({', '.join(args)});\n" - for enum in contract.enums: - interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n" - for struct in contract.structures: - interface += generate_struct_interface_str(struct) + if include_events: + for event in contract.events: + name, args = event.signature + interface += f" event {name}({', '.join(args)});\n" + if include_errors: + for error in contract.custom_errors: + interface += f" error {generate_custom_error_interface(error, unroll_structs)};\n" + if include_enums: + for enum in contract.enums: + interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n" + if include_structs: + for struct in contract.structures: + interface += generate_struct_interface_str(struct, indent=4) for var in contract.state_variables_entry_points: - interface += f" function {var.signature_str.replace('returns', 'external returns ')};\n" + interface += f" function {generate_interface_variable_signature(var, unroll_structs)};\n" for func in contract.functions_entry_points: if func.is_constructor or func.is_fallback or func.is_receive: continue - interface += f" function {generate_interface_function_signature(func)};\n" + interface += ( + f" function {generate_interface_function_signature(func, unroll_structs)};\n" + ) interface += "}\n\n" return interface -def generate_interface_function_signature(func: "FunctionContract") -> Optional[str]: +def generate_interface_variable_signature( + var: "StateVariable", unroll_structs: bool = True +) -> Optional[str]: + if var.visibility in ["private", "internal"]: + return None + if unroll_structs: + params = [ + convert_type_for_solidity_signature_to_string(x).replace("(", "").replace(")", "") + for x in export_nested_types_from_variable(var) + ] + returns = [ + convert_type_for_solidity_signature_to_string(x).replace("(", "").replace(")", "") + for x in export_return_type_from_variable(var) + ] + else: + _, params, _ = var.signature + params = [p + " memory" if p in ["bytes", "string"] else p for p in params] + returns = [] + _type = var.type + while isinstance(_type, MappingType): + _type = _type.type_to + while isinstance(_type, (ArrayType, UserDefinedType)): + _type = _type.type + ret = str(_type) + if isinstance(_type, Structure) or (isinstance(_type, Type) and _type.is_dynamic): + ret += " memory" + elif isinstance(_type, Contract): + ret = "address" + returns.append(ret) + return f"{var.name}({','.join(params)}) external returns ({', '.join(returns)})" + + +def generate_interface_function_signature( + func: "FunctionContract", unroll_structs: bool = True +) -> Optional[str]: """ Generates a string of the form: func_name(type1,type2) external {payable/view/pure} returns (type3) Args: func: A FunctionContract object + unroll_structs: Determines whether structs are unrolled into underlying types (default: True) Returns: The function interface as a str (contains the return values). Returns None if the function is private or internal, or is a constructor/fallback/receive. """ - name, parameters, return_vars = func.signature + def format_var(var: "LocalVariable", unroll: bool) -> str: + if unroll: + return ( + convert_type_for_solidity_signature_to_string(var.type) + .replace("(", "") + .replace(")", "") + ) + if isinstance(var.type, ArrayType) and isinstance( + var.type.type, (UserDefinedType, ElementaryType) + ): + return ( + convert_type_for_solidity_signature_to_string(var.type) + .replace("(", "") + .replace(")", "") + + f" {var.location}" + ) + if isinstance(var.type, UserDefinedType): + if isinstance(var.type.type, (Structure, Enum)): + return f"{str(var.type.type)} memory" + if isinstance(var.type.type, Contract): + return "address" + if var.type.is_dynamic: + return f"{var.type} {var.location}" + return str(var.type) + + name, _, _ = func.signature if ( func not in func.contract.functions_entry_points or func.is_constructor @@ -64,26 +151,20 @@ def generate_interface_function_signature(func: "FunctionContract") -> Optional[ or func.is_receive ): return None - view = " view" if func.view else "" + view = " view" if func.view and not func.pure else "" pure = " pure" if func.pure else "" payable = " payable" if func.payable else "" - returns = [ - convert_type_for_solidity_signature_to_string(ret.type).replace("(", "").replace(")", "") - for ret in func.returns - ] - parameters = [ - convert_type_for_solidity_signature_to_string(param.type).replace("(", "").replace(")", "") - for param in func.parameters - ] + returns = [format_var(ret, unroll_structs) for ret in func.returns] + parameters = [format_var(param, unroll_structs) for param in func.parameters] _interface_signature_str = ( name + "(" + ",".join(parameters) + ") external" + payable + pure + view ) - if len(return_vars) > 0: + if len(returns) > 0: _interface_signature_str += " returns (" + ",".join(returns) + ")" return _interface_signature_str -def generate_struct_interface_str(struct: "Structure") -> str: +def generate_struct_interface_str(struct: "Structure", indent: int = 0) -> str: """ Generates code for a structure declaration in an interface of the form: struct struct_name { @@ -92,13 +173,37 @@ def generate_struct_interface_str(struct: "Structure") -> str: ... ... } Args: - struct: A Structure object + struct: A Structure object. + indent: Number of spaces to indent the code block with. Returns: The structure declaration code as a string. """ - definition = f" struct {struct.name} {{\n" + spaces = "" + for _ in range(0, indent): + spaces += " " + definition = f"{spaces}struct {struct.name} {{\n" for elem in struct.elems_ordered: - definition += f" {elem.type} {elem.name};\n" - definition += " }\n" + if isinstance(elem.type, UserDefinedType): + if isinstance(elem.type.type, (Structure, Enum)): + definition += f"{spaces} {elem.type.type} {elem.name};\n" + elif isinstance(elem.type.type, Contract): + definition += f"{spaces} address {elem.name};\n" + else: + definition += f"{spaces} {elem.type} {elem.name};\n" + definition += f"{spaces}}}\n" return definition + + +def generate_custom_error_interface( + error: "CustomErrorContract", unroll_structs: bool = True +) -> str: + args = [ + convert_type_for_solidity_signature_to_string(arg.type).replace("(", "").replace(")", "") + if unroll_structs + else str(arg.type.type) + if isinstance(arg.type, UserDefinedType) and isinstance(arg.type.type, (Structure, Enum)) + else str(arg.type) + for arg in error.parameters + ] + return f"{error.name}({', '.join(args)})" diff --git a/tests/unit/utils/test_code_generation.py b/tests/unit/utils/test_code_generation.py index 6794896340..3684b10b46 100644 --- a/tests/unit/utils/test_code_generation.py +++ b/tests/unit/utils/test_code_generation.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from solc_select import solc_select @@ -21,3 +22,11 @@ def test_interface_generation() -> None: expected = file.read() assert actual == expected + + actual = generate_interface(sl.get_contract_from_name("TestContract")[0], unroll_structs=False) + expected_path = os.path.join(TEST_DATA_DIR, "TEST_generated_code_not_unrolled.sol") + + with open(expected_path, "r", encoding="utf-8") as file: + expected = file.read() + + assert actual == expected diff --git a/tests/unit/utils/test_data/code_generation/CodeGeneration.sol b/tests/unit/utils/test_data/code_generation/CodeGeneration.sol index c15017abd5..6f1f63c72f 100644 --- a/tests/unit/utils/test_data/code_generation/CodeGeneration.sol +++ b/tests/unit/utils/test_data/code_generation/CodeGeneration.sol @@ -8,7 +8,9 @@ contract TestContract is I { uint public stateA; uint private stateB; address public immutable owner = msg.sender; - mapping(address => mapping(uint => St)) public structs; + mapping(address => mapping(uint => St)) public structsMap; + St[] public structsArray; + I public otherI; event NoParams(); event Anonymous() anonymous; @@ -23,6 +25,10 @@ contract TestContract is I { uint v; } + struct Nested{ + St st; + } + function err0() public { revert ErrorSimple(); } @@ -44,13 +50,16 @@ contract TestContract is I { function newSt(uint x) public returns (St memory) { St memory st; st.v = x; - structs[msg.sender][x] = st; + structsMap[msg.sender][x] = st; return st; } function getSt(uint x) public view returns (St memory) { - return structs[msg.sender][x]; + return structsMap[msg.sender][x]; } function removeSt(St memory st) public { - delete structs[msg.sender][st.v]; + delete structsMap[msg.sender][st.v]; + } + function setOtherI(I _i) public { + otherI = _i; } } \ No newline at end of file diff --git a/tests/unit/utils/test_data/code_generation/TEST_generated_code.sol b/tests/unit/utils/test_data/code_generation/TEST_generated_code.sol index 62e08bd74c..373fba9ca2 100644 --- a/tests/unit/utils/test_data/code_generation/TEST_generated_code.sol +++ b/tests/unit/utils/test_data/code_generation/TEST_generated_code.sol @@ -11,14 +11,20 @@ interface ITestContract { struct St { uint256 v; } + struct Nested { + St st; + } function stateA() external returns (uint256); function owner() external returns (address); - function structs(address,uint256) external returns (uint256); + function structsMap(address,uint256) external returns (uint256); + function structsArray(uint256) external returns (uint256); + function otherI() external returns (address); function err0() external; function err1() external; function err2(uint256,uint256) external; function newSt(uint256) external returns (uint256); function getSt(uint256) external view returns (uint256); function removeSt(uint256) external; + function setOtherI(address) external; } diff --git a/tests/unit/utils/test_data/code_generation/TEST_generated_code_not_unrolled.sol b/tests/unit/utils/test_data/code_generation/TEST_generated_code_not_unrolled.sol new file mode 100644 index 0000000000..0cc4dc0404 --- /dev/null +++ b/tests/unit/utils/test_data/code_generation/TEST_generated_code_not_unrolled.sol @@ -0,0 +1,30 @@ +interface ITestContract { + event NoParams(); + event Anonymous(); + event OneParam(address); + event OneParamIndexed(address); + error ErrorWithEnum(SomeEnum); + error ErrorSimple(); + error ErrorWithArgs(uint256, uint256); + error ErrorWithStruct(St); + enum SomeEnum { ONE, TWO, THREE } + struct St { + uint256 v; + } + struct Nested { + St st; + } + function stateA() external returns (uint256); + function owner() external returns (address); + function structsMap(address,uint256) external returns (St memory); + function structsArray(uint256) external returns (St memory); + function otherI() external returns (address); + function err0() external; + function err1() external; + function err2(uint256,uint256) external; + function newSt(uint256) external returns (St memory); + function getSt(uint256) external view returns (St memory); + function removeSt(St memory) external; + function setOtherI(address) external; +} +