From ef03cb7060b67ae5c14251873ae89ada30b759fc Mon Sep 17 00:00:00 2001 From: Simone Date: Thu, 10 Oct 2024 18:32:26 +0200 Subject: [PATCH] Improve transient storage support --- slither/core/compilation_unit.py | 60 ++++++++---- slither/core/declarations/contract.py | 46 ++++----- slither/core/variables/state_variable.py | 16 ++++ slither/core/variables/variable.py | 7 -- .../variables/unchanged_state_variables.py | 2 +- slither/printers/summary/variable_order.py | 13 ++- slither/solc_parsing/declarations/contract.py | 4 +- .../checks/variable_initialization.py | 2 +- .../upgradeability/checks/variables_order.py | 93 ++++++++++++------- slither/utils/upgradeability.py | 8 +- .../vyper_parsing/declarations/contract.py | 2 +- 11 files changed, 152 insertions(+), 101 deletions(-) diff --git a/slither/core/compilation_unit.py b/slither/core/compilation_unit.py index df652dab0c..f4bd07e556 100644 --- a/slither/core/compilation_unit.py +++ b/slither/core/compilation_unit.py @@ -73,7 +73,8 @@ def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit # Memoize self._all_state_variables: Optional[Set[StateVariable]] = None - self._storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {} + self._persistent_storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {} + self._transient_storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {} self._contract_with_missing_inheritance: Set[Contract] = set() @@ -297,33 +298,52 @@ def get_scope(self, filename_str: str) -> FileScope: def compute_storage_layout(self) -> None: assert self.is_solidity + for contract in self.contracts_derived: - self._storage_layouts[contract.name] = {} - - slot = 0 - offset = 0 - for var in contract.stored_state_variables_ordered: - assert var.type - size, new_slot = var.type.storage_size - - if new_slot: - if offset > 0: - slot += 1 - offset = 0 - elif size + offset > 32: + self._compute_storage_layout(contract.name, contract.storage_variables_ordered, False) + self._compute_storage_layout(contract.name, contract.transient_variables_ordered, True) + + def _compute_storage_layout( + self, contract_name: str, state_variables_ordered: List[StateVariable], is_transient: bool + ): + if is_transient: + self._transient_storage_layouts[contract_name] = {} + else: + self._persistent_storage_layouts[contract_name] = {} + + slot = 0 + offset = 0 + for var in state_variables_ordered: + assert var.type + size, new_slot = var.type.storage_size + + if new_slot: + if offset > 0: slot += 1 offset = 0 + elif size + offset > 32: + slot += 1 + offset = 0 - self._storage_layouts[contract.name][var.canonical_name] = ( + if is_transient: + self._transient_storage_layouts[contract_name][var.canonical_name] = ( slot, offset, ) - if new_slot: - slot += math.ceil(size / 32) - else: - offset += size + else: + self._persistent_storage_layouts[contract_name][var.canonical_name] = ( + slot, + offset, + ) + + if new_slot: + slot += math.ceil(size / 32) + else: + offset += size def storage_layout_of(self, contract: Contract, var: StateVariable) -> Tuple[int, int]: - return self._storage_layouts[contract.name][var.canonical_name] + if var.is_stored: + return self._persistent_storage_layouts[contract.name][var.canonical_name] + return self._transient_storage_layouts[contract.name][var.canonical_name] # endregion diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index a2f24fc6cd..8dccc007f9 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -440,55 +440,43 @@ def variables_as_dict(self) -> Dict[str, "StateVariable"]: def state_variables(self) -> List["StateVariable"]: """ Returns all the accessible variables (do not include private variable from inherited contract). - Use state_variables_ordered for all the variables following the storage order + Use stored_state_variables_ordered for all the storage variables following the storage order + Use transient_state_variables_ordered for all the transient variables following the storage order list(StateVariable): List of the state variables. """ return list(self._variables.values()) @property - def stored_state_variables(self) -> List["StateVariable"]: + def state_variables_entry_points(self) -> List["StateVariable"]: """ - Returns state variables with storage locations, excluding private variables from inherited contracts. - Use stored_state_variables_ordered to access variables with storage locations in their declaration order. - - This implementation filters out state variables if they are constant or immutable. It will be - updated to accommodate any new non-storage keywords that might replace 'constant' and 'immutable' in the future. - - Returns: - List[StateVariable]: A list of state variables with storage locations. + list(StateVariable): List of the state variables that are public. """ - return [variable for variable in self.state_variables if variable.is_stored] + return [var for var in self._variables.values() if var.visibility == "public"] @property - def stored_state_variables_ordered(self) -> List["StateVariable"]: + def state_variables_ordered(self) -> List["StateVariable"]: """ - list(StateVariable): List of the state variables with storage locations by order of declaration. - - This implementation filters out state variables if they are constant or immutable. It will be - updated to accommodate any new non-storage keywords that might replace 'constant' and 'immutable' in the future. - - Returns: - List[StateVariable]: A list of state variables with storage locations ordered by declaration. + list(StateVariable): List of the state variables by order of declaration. """ - return [variable for variable in self.state_variables_ordered if variable.is_stored] + return self._variables_ordered + + def add_state_variables_ordered(self, new_vars: List["StateVariable"]) -> None: + self._variables_ordered += new_vars @property - def state_variables_entry_points(self) -> List["StateVariable"]: + def storage_variables_ordered(self) -> List["StateVariable"]: """ - list(StateVariable): List of the state variables that are public. + list(StateVariable): List of the state variables in storage location by order of declaration. """ - return [var for var in self._variables.values() if var.visibility == "public"] + return [v for v in self._variables_ordered if v.is_stored] @property - def state_variables_ordered(self) -> List["StateVariable"]: + def transient_variables_ordered(self) -> List["StateVariable"]: """ - list(StateVariable): List of the state variables by order of declaration. + list(StateVariable): List of the state variables in transient location by order of declaration. """ - return list(self._variables_ordered) - - def add_variables_ordered(self, new_vars: List["StateVariable"]) -> None: - self._variables_ordered += new_vars + return [v for v in self._variables_ordered if v.is_transient] @property def state_variables_inherited(self) -> List["StateVariable"]: diff --git a/slither/core/variables/state_variable.py b/slither/core/variables/state_variable.py index 404cf74ba3..d3e3e60182 100644 --- a/slither/core/variables/state_variable.py +++ b/slither/core/variables/state_variable.py @@ -35,6 +35,22 @@ def location(self) -> Optional[str]: """ return self._location + @property + def is_stored(self) -> bool: + """ + Checks if the state variable is stored, based on it not being constant or immutable or transient. + """ + return ( + not self._is_constant and not self._is_immutable and not self._location == "transient" + ) + + @property + def is_transient(self) -> bool: + """ + Checks if the state variable is transient. A transient variable can not be constant or immutable. + """ + return self._location == "transient" + # endregion ################################################################################### ################################################################################### diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index 63d1a7a838..f9ef190246 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -93,13 +93,6 @@ def is_constant(self) -> bool: def is_constant(self, is_cst: bool) -> None: self._is_constant = is_cst - @property - def is_stored(self) -> bool: - """ - Checks if a variable is stored, based on it not being constant or immutable. Future updates may adjust for new non-storage keywords. - """ - return not self._is_constant and not self._is_immutable - @property def is_reentrant(self) -> bool: return self._is_reentrant diff --git a/slither/detectors/variables/unchanged_state_variables.py b/slither/detectors/variables/unchanged_state_variables.py index 5771d96303..64c4c350f6 100644 --- a/slither/detectors/variables/unchanged_state_variables.py +++ b/slither/detectors/variables/unchanged_state_variables.py @@ -92,7 +92,7 @@ def detect(self) -> None: variables = [] functions = [] - variables.append(c.stored_state_variables) + variables.append(c.storage_variables_ordered) functions.append(c.all_functions_called) valid_candidates: Set[StateVariable] = { diff --git a/slither/printers/summary/variable_order.py b/slither/printers/summary/variable_order.py index 0d8ce2612c..fb19e39857 100644 --- a/slither/printers/summary/variable_order.py +++ b/slither/printers/summary/variable_order.py @@ -27,10 +27,17 @@ def output(self, _filename: str) -> Output: for contract in self.slither.contracts_derived: txt += f"\n{contract.name}:\n" - table = MyPrettyTable(["Name", "Type", "Slot", "Offset"]) - for variable in contract.stored_state_variables_ordered: + table = MyPrettyTable(["Name", "Type", "Slot", "Offset", "State"]) + for variable in contract.storage_variables_ordered: slot, offset = contract.compilation_unit.storage_layout_of(contract, variable) - table.add_row([variable.canonical_name, str(variable.type), slot, offset]) + table.add_row( + [variable.canonical_name, str(variable.type), slot, offset, "Storage"] + ) + for variable in contract.transient_variables_ordered: + slot, offset = contract.compilation_unit.storage_layout_of(contract, variable) + table.add_row( + [variable.canonical_name, str(variable.type), slot, offset, "Transient"] + ) all_tables.append((contract.name, table)) txt += str(table) + "\n" diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 1ccdc57602..06fc03b7a0 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -350,7 +350,7 @@ def parse_state_variables(self) -> None: if v.visibility != "private" } ) - self._contract.add_variables_ordered( + self._contract.add_state_variables_ordered( [ var for var in father.state_variables_ordered @@ -370,7 +370,7 @@ def parse_state_variables(self) -> None: if var_parser.reference_id is not None: self._contract.state_variables_by_ref_id[var_parser.reference_id] = var self._contract.variables_as_dict[var.name] = var - self._contract.add_variables_ordered([var]) + self._contract.add_state_variables_ordered([var]) def _parse_modifier(self, modifier_data: Dict) -> None: modif = Modifier(self._contract.compilation_unit) diff --git a/slither/tools/upgradeability/checks/variable_initialization.py b/slither/tools/upgradeability/checks/variable_initialization.py index b86036c87f..047c652dc4 100644 --- a/slither/tools/upgradeability/checks/variable_initialization.py +++ b/slither/tools/upgradeability/checks/variable_initialization.py @@ -43,7 +43,7 @@ class VariableWithInit(AbstractCheck): def _check(self) -> List[Output]: results = [] - for s in self.contract.stored_state_variables_ordered: + for s in self.contract.storage_variables_ordered: if s.initialized: info: CHECK_INFO = [s, " is a state variable with an initial value.\n"] json = self.generate_result(info) diff --git a/slither/tools/upgradeability/checks/variables_order.py b/slither/tools/upgradeability/checks/variables_order.py index 8d525a6dd3..8f5017d74c 100644 --- a/slither/tools/upgradeability/checks/variables_order.py +++ b/slither/tools/upgradeability/checks/variables_order.py @@ -115,29 +115,43 @@ def _contract2(self) -> Contract: def _check(self) -> List[Output]: contract1 = self._contract1() contract2 = self._contract2() - order1 = contract1.stored_state_variables_ordered - order2 = contract2.stored_state_variables_ordered results: List[Output] = [] - for idx, _ in enumerate(order1): - if len(order2) <= idx: - # Handle by MissingVariable - return results - - variable1 = order1[idx] - variable2 = order2[idx] - if (variable1.name != variable2.name) or (variable1.type != variable2.type): - info: CHECK_INFO = [ - "Different variables between ", - contract1, - " and ", - contract2, - "\n", - ] - info += ["\t ", variable1, "\n"] - info += ["\t ", variable2, "\n"] - json = self.generate_result(info) - results.append(json) + + def _check_internal( + contract1: Contract, contract2: Contract, results: List[Output], is_transient: bool + ): + if is_transient: + order1 = contract1.transient_variables_ordered + order2 = contract2.transient_variables_ordered + else: + order1 = contract1.storage_variables_ordered + order2 = contract2.storage_variables_ordered + + for idx, _ in enumerate(order1): + if len(order2) <= idx: + # Handle by MissingVariable + return + + variable1 = order1[idx] + variable2 = order2[idx] + if (variable1.name != variable2.name) or (variable1.type != variable2.type): + info: CHECK_INFO = [ + "Different variables between ", + contract1, + " and ", + contract2, + "\n", + ] + info += ["\t ", variable1, "\n"] + info += ["\t ", variable2, "\n"] + json = self.generate_result(info) + results.append(json) + + # Checking state variables with storage location + _check_internal(contract1, contract2, results, False) + # Checking state variables with transient location + _check_internal(contract1, contract2, results, True) return results @@ -236,22 +250,35 @@ def _contract2(self) -> Contract: def _check(self) -> List[Output]: contract1 = self._contract1() contract2 = self._contract2() - order1 = contract1.stored_state_variables_ordered - order2 = contract2.stored_state_variables_ordered - results = [] + results: List[Output] = [] - if len(order2) <= len(order1): - return [] + def _check_internal( + contract1: Contract, contract2: Contract, results: List[Output], is_transient: bool + ): + if is_transient: + order1 = contract1.transient_variables_ordered + order2 = contract2.transient_variables_ordered + else: + order1 = contract1.storage_variables_ordered + order2 = contract2.storage_variables_ordered - idx = len(order1) + if len(order2) <= len(order1): + return - while idx < len(order2): - variable2 = order2[idx] - info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"] - json = self.generate_result(info) - results.append(json) - idx = idx + 1 + idx = len(order1) + + while idx < len(order2): + variable2 = order2[idx] + info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"] + json = self.generate_result(info) + results.append(json) + idx = idx + 1 + + # Checking state variables with storage location + _check_internal(contract1, contract2, results, False) + # Checking state variables with transient location + _check_internal(contract1, contract2, results, True) return results diff --git a/slither/utils/upgradeability.py b/slither/utils/upgradeability.py index 59979bca64..bedf08d4b2 100644 --- a/slither/utils/upgradeability.py +++ b/slither/utils/upgradeability.py @@ -80,8 +80,8 @@ def compare( tainted-contracts: list[TaintedExternalContract] """ - order_vars1 = v1.stored_state_variables_ordered - order_vars2 = v2.stored_state_variables_ordered + order_vars1 = v1.storage_variables_ordered + v1.transient_variables_ordered + order_vars2 = v2.storage_variables_ordered + v2.transient_variables_ordered func_sigs1 = [function.solidity_signature for function in v1.functions] func_sigs2 = [function.solidity_signature for function in v2.functions] @@ -306,8 +306,8 @@ def get_missing_vars(v1: Contract, v2: Contract) -> List[StateVariable]: List of StateVariables from v1 missing in v2 """ results = [] - order_vars1 = v1.stored_state_variables_ordered - order_vars2 = v2.stored_state_variables_ordered + order_vars1 = v1.storage_variables_ordered + v1.transient_variables_ordered + order_vars2 = v2.storage_variables_ordered + v2.transient_variables_ordered if len(order_vars2) < len(order_vars1): for variable in order_vars1: if variable.name not in [v.name for v in order_vars2]: diff --git a/slither/vyper_parsing/declarations/contract.py b/slither/vyper_parsing/declarations/contract.py index 64fab1c549..ddf5161509 100644 --- a/slither/vyper_parsing/declarations/contract.py +++ b/slither/vyper_parsing/declarations/contract.py @@ -470,7 +470,7 @@ def parse_state_variables(self) -> None: assert var.name self._contract.variables_as_dict[var.name] = var - self._contract.add_variables_ordered([var]) + self._contract.add_state_variables_ordered([var]) # Interfaces can refer to constants self._contract.file_scope.variables[var.name] = var