Skip to content

Commit

Permalink
slither-read-storage native POA support (#1843)
Browse files Browse the repository at this point in the history
* Native support for POA networks in read_storage

* Set `srs.rpc` before `srs.block`

* Type hint

* New RpcInfo class w/ RpcInfo.web3 and RpcInfo.block
In SlitherReadStorage, self.rpc_info: Optional[RpcInfo]
replaces self.rpc, self._block, self._web3

* Black

* Update test_read_storage.py

* Add import in __init__.py

* Avoid instantiating SRS twice

* Add comment about `get_block` for POA networks

* Pylint

* Black

* Allow other valid block string arguments
["latest", "earliest", "pending", "safe", "finalized"]

* `args.block` can be in ["latest", "earliest", "pending", "safe", "finalized"]

* Use BlockTag enum class for valid `str` arguments

* Tweak `RpcInfo.__init__()` signature

* get rid of `or "latest"`

* Import BlockTag

* Use `web3.types.BlockIdentifier`

* Revert BlockTag enum

* Pylint and black

* Replace missing newline

* Update slither/tools/read_storage/__main__.py

Better, cleaner python

Co-authored-by: alpharush <[email protected]>

* Drop try/except around args.block parsing
allow ValueError if user provides invalid block arg

* Remove unused import

---------

Co-authored-by: alpharush <[email protected]>
  • Loading branch information
webthethird and 0xalpharush authored Jun 2, 2023
1 parent 48e3466 commit 00461aa
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 37 deletions.
2 changes: 1 addition & 1 deletion slither/tools/read_storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .read_storage import SlitherReadStorage
from .read_storage import SlitherReadStorage, RpcInfo
29 changes: 13 additions & 16 deletions slither/tools/read_storage/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from crytic_compile import cryticparser

from slither import Slither
from slither.tools.read_storage.read_storage import SlitherReadStorage
from slither.tools.read_storage.read_storage import SlitherReadStorage, RpcInfo


def parse_args() -> argparse.Namespace:
Expand Down Expand Up @@ -126,22 +126,19 @@ def main() -> None:
else:
contracts = slither.contracts

srs = SlitherReadStorage(contracts, args.max_depth)

try:
srs.block = int(args.block)
except ValueError:
srs.block = str(args.block or "latest")

rpc_info = None
if args.rpc_url:
# Remove target prefix e.g. rinkeby:0x0 -> 0x0.
address = target[target.find(":") + 1 :]
# Default to implementation address unless a storage address is given.
if not args.storage_address:
args.storage_address = address
srs.storage_address = args.storage_address

srs.rpc = args.rpc_url
valid = ["latest", "earliest", "pending", "safe", "finalized"]
block = args.block if args.block in valid else int(args.block)
rpc_info = RpcInfo(args.rpc_url, block)

srs = SlitherReadStorage(contracts, args.max_depth, rpc_info)
# Remove target prefix e.g. rinkeby:0x0 -> 0x0.
address = target[target.find(":") + 1 :]
# Default to implementation address unless a storage address is given.
if not args.storage_address:
args.storage_address = address
srs.storage_address = args.storage_address

if args.variable_name:
# Use a lambda func to only return variables that have same name as target.
Expand Down
57 changes: 40 additions & 17 deletions slither/tools/read_storage/read_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

from eth_abi import decode, encode
from eth_typing.evm import ChecksumAddress
from eth_utils import keccak
from eth_utils import keccak, to_checksum_address
from web3 import Web3
from web3.types import BlockIdentifier
from web3.exceptions import ExtraDataLengthError
from web3.middleware import geth_poa_middleware

from slither.core.declarations import Contract, Structure
from slither.core.solidity_types import ArrayType, ElementaryType, MappingType, UserDefinedType
Expand Down Expand Up @@ -42,18 +45,43 @@ class SlitherReadStorageException(Exception):
pass


class RpcInfo:
def __init__(self, rpc_url: str, block: BlockIdentifier = "latest") -> None:
assert isinstance(block, int) or block in [
"latest",
"earliest",
"pending",
"safe",
"finalized",
]
self.rpc: str = rpc_url
self._web3: Web3 = Web3(Web3.HTTPProvider(self.rpc))
"""If the RPC is for a POA network, the first call to get_block fails, so we inject geth_poa_middleware"""
try:
self._block: int = self.web3.eth.get_block(block)["number"]
except ExtraDataLengthError:
self._web3.middleware_onion.inject(geth_poa_middleware, layer=0)
self._block: int = self.web3.eth.get_block(block)["number"]

@property
def web3(self) -> Web3:
return self._web3

@property
def block(self) -> int:
return self._block


# pylint: disable=too-many-instance-attributes
class SlitherReadStorage:
def __init__(self, contracts: List[Contract], max_depth: int) -> None:
def __init__(self, contracts: List[Contract], max_depth: int, rpc_info: RpcInfo = None) -> None:
self._checksum_address: Optional[ChecksumAddress] = None
self._contracts: List[Contract] = contracts
self._log: str = ""
self._max_depth: int = max_depth
self._slot_info: Dict[str, SlotInfo] = {}
self._target_variables: List[Tuple[Contract, StateVariable]] = []
self._web3: Optional[Web3] = None
self.block: Union[str, int] = "latest"
self.rpc: Optional[str] = None
self.rpc_info: Optional[RpcInfo] = rpc_info
self.storage_address: Optional[str] = None
self.table: Optional[MyPrettyTable] = None

Expand All @@ -73,18 +101,12 @@ def log(self) -> str:
def log(self, log: str) -> None:
self._log = log

@property
def web3(self) -> Web3:
if not self._web3:
self._web3 = Web3(Web3.HTTPProvider(self.rpc))
return self._web3

@property
def checksum_address(self) -> ChecksumAddress:
if not self.storage_address:
raise ValueError
if not self._checksum_address:
self._checksum_address = self.web3.to_checksum_address(self.storage_address)
self._checksum_address = to_checksum_address(self.storage_address)
return self._checksum_address

@property
Expand Down Expand Up @@ -223,11 +245,12 @@ def get_slot_values(self, slot_info: SlotInfo) -> None:
"""Fetches the slot value of `SlotInfo` object
:param slot_info:
"""
assert self.rpc_info is not None
hex_bytes = get_storage_data(
self.web3,
self.rpc_info.web3,
self.checksum_address,
int.to_bytes(slot_info.slot, 32, byteorder="big"),
self.block,
self.rpc_info.block,
)
slot_info.value = self.convert_value_to_type(
hex_bytes, slot_info.size, slot_info.offset, slot_info.type_string
Expand Down Expand Up @@ -600,15 +623,15 @@ def _get_array_length(self, type_: Type, slot: int) -> int:
(int): The length of the array.
"""
val = 0
if self.rpc:
if self.rpc_info:
# The length of dynamic arrays is stored at the starting slot.
# Convert from hexadecimal to decimal.
val = int(
get_storage_data(
self.web3,
self.rpc_info.web3,
self.checksum_address,
int.to_bytes(slot, 32, byteorder="big"),
self.block,
self.rpc_info.block,
).hex(),
16,
)
Expand Down
6 changes: 3 additions & 3 deletions tests/tools/read-storage/test_read_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from web3.contract import Contract

from slither import Slither
from slither.tools.read_storage import SlitherReadStorage
from slither.tools.read_storage import SlitherReadStorage, RpcInfo

TEST_DATA_DIR = Path(__file__).resolve().parent / "test_data"

Expand Down Expand Up @@ -105,8 +105,8 @@ def test_read_storage(web3, ganache, solc_binary_path) -> None:
sl = Slither(Path(TEST_DATA_DIR, "storage_layout-0.8.10.sol").as_posix(), solc=solc_path)
contracts = sl.contracts

srs = SlitherReadStorage(contracts, 100)
srs.rpc = ganache.provider
rpc_info: RpcInfo = RpcInfo(ganache.provider)
srs = SlitherReadStorage(contracts, 100, rpc_info)
srs.storage_address = address
srs.get_all_storage_variables()
srs.get_storage_layout()
Expand Down

0 comments on commit 00461aa

Please sign in to comment.