Skip to content

Commit

Permalink
Improve transient storage support
Browse files Browse the repository at this point in the history
  • Loading branch information
smonicas committed Oct 10, 2024
1 parent 98af7fd commit ef03cb7
Show file tree
Hide file tree
Showing 11 changed files with 152 additions and 101 deletions.
60 changes: 40 additions & 20 deletions slither/core/compilation_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit
# Memoize
self._all_state_variables: Optional[Set[StateVariable]] = None

self._storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {}
self._persistent_storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {}
self._transient_storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {}

self._contract_with_missing_inheritance: Set[Contract] = set()

Expand Down Expand Up @@ -297,33 +298,52 @@ def get_scope(self, filename_str: str) -> FileScope:

def compute_storage_layout(self) -> None:
assert self.is_solidity

for contract in self.contracts_derived:
self._storage_layouts[contract.name] = {}

slot = 0
offset = 0
for var in contract.stored_state_variables_ordered:
assert var.type
size, new_slot = var.type.storage_size

if new_slot:
if offset > 0:
slot += 1
offset = 0
elif size + offset > 32:
self._compute_storage_layout(contract.name, contract.storage_variables_ordered, False)
self._compute_storage_layout(contract.name, contract.transient_variables_ordered, True)

def _compute_storage_layout(
self, contract_name: str, state_variables_ordered: List[StateVariable], is_transient: bool
):
if is_transient:
self._transient_storage_layouts[contract_name] = {}
else:
self._persistent_storage_layouts[contract_name] = {}

slot = 0
offset = 0
for var in state_variables_ordered:
assert var.type
size, new_slot = var.type.storage_size

if new_slot:
if offset > 0:
slot += 1
offset = 0
elif size + offset > 32:
slot += 1
offset = 0

self._storage_layouts[contract.name][var.canonical_name] = (
if is_transient:
self._transient_storage_layouts[contract_name][var.canonical_name] = (
slot,
offset,
)
if new_slot:
slot += math.ceil(size / 32)
else:
offset += size
else:
self._persistent_storage_layouts[contract_name][var.canonical_name] = (
slot,
offset,
)

if new_slot:
slot += math.ceil(size / 32)
else:
offset += size

def storage_layout_of(self, contract: Contract, var: StateVariable) -> Tuple[int, int]:
return self._storage_layouts[contract.name][var.canonical_name]
if var.is_stored:
return self._persistent_storage_layouts[contract.name][var.canonical_name]
return self._transient_storage_layouts[contract.name][var.canonical_name]

# endregion
46 changes: 17 additions & 29 deletions slither/core/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,55 +440,43 @@ def variables_as_dict(self) -> Dict[str, "StateVariable"]:
def state_variables(self) -> List["StateVariable"]:
"""
Returns all the accessible variables (do not include private variable from inherited contract).
Use state_variables_ordered for all the variables following the storage order
Use stored_state_variables_ordered for all the storage variables following the storage order
Use transient_state_variables_ordered for all the transient variables following the storage order
list(StateVariable): List of the state variables.
"""
return list(self._variables.values())

@property
def stored_state_variables(self) -> List["StateVariable"]:
def state_variables_entry_points(self) -> List["StateVariable"]:
"""
Returns state variables with storage locations, excluding private variables from inherited contracts.
Use stored_state_variables_ordered to access variables with storage locations in their declaration order.
This implementation filters out state variables if they are constant or immutable. It will be
updated to accommodate any new non-storage keywords that might replace 'constant' and 'immutable' in the future.
Returns:
List[StateVariable]: A list of state variables with storage locations.
list(StateVariable): List of the state variables that are public.
"""
return [variable for variable in self.state_variables if variable.is_stored]
return [var for var in self._variables.values() if var.visibility == "public"]

@property
def stored_state_variables_ordered(self) -> List["StateVariable"]:
def state_variables_ordered(self) -> List["StateVariable"]:
"""
list(StateVariable): List of the state variables with storage locations by order of declaration.
This implementation filters out state variables if they are constant or immutable. It will be
updated to accommodate any new non-storage keywords that might replace 'constant' and 'immutable' in the future.
Returns:
List[StateVariable]: A list of state variables with storage locations ordered by declaration.
list(StateVariable): List of the state variables by order of declaration.
"""
return [variable for variable in self.state_variables_ordered if variable.is_stored]
return self._variables_ordered

def add_state_variables_ordered(self, new_vars: List["StateVariable"]) -> None:
self._variables_ordered += new_vars

@property
def state_variables_entry_points(self) -> List["StateVariable"]:
def storage_variables_ordered(self) -> List["StateVariable"]:
"""
list(StateVariable): List of the state variables that are public.
list(StateVariable): List of the state variables in storage location by order of declaration.
"""
return [var for var in self._variables.values() if var.visibility == "public"]
return [v for v in self._variables_ordered if v.is_stored]

@property
def state_variables_ordered(self) -> List["StateVariable"]:
def transient_variables_ordered(self) -> List["StateVariable"]:
"""
list(StateVariable): List of the state variables by order of declaration.
list(StateVariable): List of the state variables in transient location by order of declaration.
"""
return list(self._variables_ordered)

def add_variables_ordered(self, new_vars: List["StateVariable"]) -> None:
self._variables_ordered += new_vars
return [v for v in self._variables_ordered if v.is_transient]

@property
def state_variables_inherited(self) -> List["StateVariable"]:
Expand Down
16 changes: 16 additions & 0 deletions slither/core/variables/state_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,22 @@ def location(self) -> Optional[str]:
"""
return self._location

@property
def is_stored(self) -> bool:
"""
Checks if the state variable is stored, based on it not being constant or immutable or transient.
"""
return (
not self._is_constant and not self._is_immutable and not self._location == "transient"
)

@property
def is_transient(self) -> bool:
"""
Checks if the state variable is transient. A transient variable can not be constant or immutable.
"""
return self._location == "transient"

# endregion
###################################################################################
###################################################################################
Expand Down
7 changes: 0 additions & 7 deletions slither/core/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,6 @@ def is_constant(self) -> bool:
def is_constant(self, is_cst: bool) -> None:
self._is_constant = is_cst

@property
def is_stored(self) -> bool:
"""
Checks if a variable is stored, based on it not being constant or immutable. Future updates may adjust for new non-storage keywords.
"""
return not self._is_constant and not self._is_immutable

@property
def is_reentrant(self) -> bool:
return self._is_reentrant
Expand Down
2 changes: 1 addition & 1 deletion slither/detectors/variables/unchanged_state_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def detect(self) -> None:
variables = []
functions = []

variables.append(c.stored_state_variables)
variables.append(c.storage_variables_ordered)
functions.append(c.all_functions_called)

valid_candidates: Set[StateVariable] = {
Expand Down
13 changes: 10 additions & 3 deletions slither/printers/summary/variable_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,17 @@ def output(self, _filename: str) -> Output:

for contract in self.slither.contracts_derived:
txt += f"\n{contract.name}:\n"
table = MyPrettyTable(["Name", "Type", "Slot", "Offset"])
for variable in contract.stored_state_variables_ordered:
table = MyPrettyTable(["Name", "Type", "Slot", "Offset", "State"])
for variable in contract.storage_variables_ordered:
slot, offset = contract.compilation_unit.storage_layout_of(contract, variable)
table.add_row([variable.canonical_name, str(variable.type), slot, offset])
table.add_row(
[variable.canonical_name, str(variable.type), slot, offset, "Storage"]
)
for variable in contract.transient_variables_ordered:
slot, offset = contract.compilation_unit.storage_layout_of(contract, variable)
table.add_row(
[variable.canonical_name, str(variable.type), slot, offset, "Transient"]
)

all_tables.append((contract.name, table))
txt += str(table) + "\n"
Expand Down
4 changes: 2 additions & 2 deletions slither/solc_parsing/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def parse_state_variables(self) -> None:
if v.visibility != "private"
}
)
self._contract.add_variables_ordered(
self._contract.add_state_variables_ordered(
[
var
for var in father.state_variables_ordered
Expand All @@ -370,7 +370,7 @@ def parse_state_variables(self) -> None:
if var_parser.reference_id is not None:
self._contract.state_variables_by_ref_id[var_parser.reference_id] = var
self._contract.variables_as_dict[var.name] = var
self._contract.add_variables_ordered([var])
self._contract.add_state_variables_ordered([var])

def _parse_modifier(self, modifier_data: Dict) -> None:
modif = Modifier(self._contract.compilation_unit)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class VariableWithInit(AbstractCheck):

def _check(self) -> List[Output]:
results = []
for s in self.contract.stored_state_variables_ordered:
for s in self.contract.storage_variables_ordered:
if s.initialized:
info: CHECK_INFO = [s, " is a state variable with an initial value.\n"]
json = self.generate_result(info)
Expand Down
93 changes: 60 additions & 33 deletions slither/tools/upgradeability/checks/variables_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,29 +115,43 @@ def _contract2(self) -> Contract:
def _check(self) -> List[Output]:
contract1 = self._contract1()
contract2 = self._contract2()
order1 = contract1.stored_state_variables_ordered
order2 = contract2.stored_state_variables_ordered

results: List[Output] = []
for idx, _ in enumerate(order1):
if len(order2) <= idx:
# Handle by MissingVariable
return results

variable1 = order1[idx]
variable2 = order2[idx]
if (variable1.name != variable2.name) or (variable1.type != variable2.type):
info: CHECK_INFO = [
"Different variables between ",
contract1,
" and ",
contract2,
"\n",
]
info += ["\t ", variable1, "\n"]
info += ["\t ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)

def _check_internal(
contract1: Contract, contract2: Contract, results: List[Output], is_transient: bool
):
if is_transient:
order1 = contract1.transient_variables_ordered
order2 = contract2.transient_variables_ordered
else:
order1 = contract1.storage_variables_ordered
order2 = contract2.storage_variables_ordered

for idx, _ in enumerate(order1):
if len(order2) <= idx:
# Handle by MissingVariable
return

variable1 = order1[idx]
variable2 = order2[idx]
if (variable1.name != variable2.name) or (variable1.type != variable2.type):
info: CHECK_INFO = [
"Different variables between ",
contract1,
" and ",
contract2,
"\n",
]
info += ["\t ", variable1, "\n"]
info += ["\t ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)

# Checking state variables with storage location
_check_internal(contract1, contract2, results, False)
# Checking state variables with transient location
_check_internal(contract1, contract2, results, True)

return results

Expand Down Expand Up @@ -236,22 +250,35 @@ def _contract2(self) -> Contract:
def _check(self) -> List[Output]:
contract1 = self._contract1()
contract2 = self._contract2()
order1 = contract1.stored_state_variables_ordered
order2 = contract2.stored_state_variables_ordered

results = []
results: List[Output] = []

if len(order2) <= len(order1):
return []
def _check_internal(
contract1: Contract, contract2: Contract, results: List[Output], is_transient: bool
):
if is_transient:
order1 = contract1.transient_variables_ordered
order2 = contract2.transient_variables_ordered
else:
order1 = contract1.storage_variables_ordered
order2 = contract2.storage_variables_ordered

idx = len(order1)
if len(order2) <= len(order1):
return

while idx < len(order2):
variable2 = order2[idx]
info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)
idx = idx + 1
idx = len(order1)

while idx < len(order2):
variable2 = order2[idx]
info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"]
json = self.generate_result(info)
results.append(json)
idx = idx + 1

# Checking state variables with storage location
_check_internal(contract1, contract2, results, False)
# Checking state variables with transient location
_check_internal(contract1, contract2, results, True)

return results

Expand Down
8 changes: 4 additions & 4 deletions slither/utils/upgradeability.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def compare(
tainted-contracts: list[TaintedExternalContract]
"""

order_vars1 = v1.stored_state_variables_ordered
order_vars2 = v2.stored_state_variables_ordered
order_vars1 = v1.storage_variables_ordered + v1.transient_variables_ordered
order_vars2 = v2.storage_variables_ordered + v2.transient_variables_ordered
func_sigs1 = [function.solidity_signature for function in v1.functions]
func_sigs2 = [function.solidity_signature for function in v2.functions]

Expand Down Expand Up @@ -306,8 +306,8 @@ def get_missing_vars(v1: Contract, v2: Contract) -> List[StateVariable]:
List of StateVariables from v1 missing in v2
"""
results = []
order_vars1 = v1.stored_state_variables_ordered
order_vars2 = v2.stored_state_variables_ordered
order_vars1 = v1.storage_variables_ordered + v1.transient_variables_ordered
order_vars2 = v2.storage_variables_ordered + v2.transient_variables_ordered
if len(order_vars2) < len(order_vars1):
for variable in order_vars1:
if variable.name not in [v.name for v in order_vars2]:
Expand Down
2 changes: 1 addition & 1 deletion slither/vyper_parsing/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def parse_state_variables(self) -> None:

assert var.name
self._contract.variables_as_dict[var.name] = var
self._contract.add_variables_ordered([var])
self._contract.add_state_variables_ordered([var])
# Interfaces can refer to constants
self._contract.file_scope.variables[var.name] = var

Expand Down

0 comments on commit ef03cb7

Please sign in to comment.