Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch encoder/decoder to convert ABI addresses #133

Merged
merged 6 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions boa/contracts/abi/abi_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
_BaseEVMContract,
_handle_child_trace,
)
from boa.contracts.utils import decode_addresses, encode_addresses
from boa.environment import Address
from boa.util.abi import ABIError, abi_decode, abi_encode, is_abi_encodable
from boa.util.abi import ABIError, Address, abi_decode, abi_encode, is_abi_encodable


class ABIFunction:
Expand Down Expand Up @@ -94,14 +92,11 @@ def _merge_kwargs(self, *args, **kwargs) -> list:
)
try:
kwarg_inputs = self._abi["inputs"][len(args) :]
merged = list(args) + [kwargs.pop(i["name"]) for i in kwarg_inputs]
return list(args) + [kwargs.pop(i["name"]) for i in kwarg_inputs]
except KeyError as e:
error = f"Missing keyword argument {e} for `{self.signature}`. Passed {args} {kwargs}"
raise TypeError(error)

# allow address objects to be passed in place of addresses
return encode_addresses(merged)

def __call__(self, *args, value=0, gas=None, sender=None, **kwargs):
"""Calls the function with the given arguments based on the ABI contract."""
if not self.contract or not self.contract.env:
Expand Down Expand Up @@ -235,12 +230,10 @@ def marshal_to_python(self, computation, abi_type: list[str]) -> tuple[Any, ...]

schema = f"({_format_abi_type(abi_type)})"
try:
decoded = abi_decode(schema, computation.output)
return abi_decode(schema, computation.output)
except ABIError as e:
raise BoaError(self.stack_trace(computation)) from e

return tuple(decode_addresses(typ, val) for typ, val in zip(abi_type, decoded))

def stack_trace(self, computation: ComputationAPI) -> StackTrace:
"""
Create a stack trace for a failed contract call.
Expand Down Expand Up @@ -300,15 +293,16 @@ def at(self, address: Address | str) -> ABIContract:
return contract


def _abi_from_json(abi: dict) -> list | str:
def _abi_from_json(abi: dict) -> str:
"""
Parses an ABI type into a list of types.
Parses an ABI type into its schema string.
:param abi: The ABI type to parse.
:return: A list of types or a single type.
:return: The schema string for the given abi type.
"""
if "components" in abi:
assert abi["type"] == "tuple" # sanity check
return [_abi_from_json(item) for item in abi["components"]]
assert abi["type"] in ("tuple", "tuple[]") # sanity check
components = [_abi_from_json(item) for item in abi["components"]]
return abi["type"].replace("tuple", f"({','.join(components)})")
DanielSchiavini marked this conversation as resolved.
Show resolved Hide resolved
return abi["type"]


Expand Down
3 changes: 2 additions & 1 deletion boa/contracts/base_evm_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from eth.abc import ComputationAPI

from boa.environment import Address, Env
from boa.environment import Env
from boa.util.abi import Address
from boa.util.exceptions import strip_internal_frames


Expand Down
27 changes: 0 additions & 27 deletions boa/contracts/utils.py

This file was deleted.

6 changes: 3 additions & 3 deletions boa/contracts/vyper/vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@
)
from boa.contracts.vyper.event import Event, RawEvent
from boa.contracts.vyper.ir_executor import executor_from_ir
from boa.environment import Address, Env
from boa.environment import Env
from boa.profiling import LineProfile, cache_gas_used_for_computation
from boa.util.abi import abi_decode, abi_encode
from boa.util.abi import Address, abi_decode, abi_encode
from boa.util.lrudict import lrudict
from boa.vm.gas_meters import ProfilingGasMeter
from boa.vm.utils import to_bytes, to_int
Expand Down Expand Up @@ -1056,7 +1056,7 @@ def vyper_object(val, vyper_type):
# and tag it with _vyper_type metadata

vt = type(val)
if vt is bool:
if vt is bool or vt is Address:
# https://stackoverflow.com/q/2172189
# bool is not ambiguous wrt vyper type anyways.
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
return val
Expand Down
43 changes: 3 additions & 40 deletions boa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import random
import sys
import warnings
from typing import Annotated, Any, Iterator, Optional, Tuple
from typing import Any, Iterator, Optional, Tuple

import eth.constants as constants
import eth.tools.builder.chain as chain
Expand All @@ -20,11 +20,10 @@
from eth.vm.opcode_values import STOP
from eth.vm.transaction_context import BaseTransactionContext
from eth_typing import Address as PYEVM_Address # it's just bytes.
from eth_utils import setup_DEBUG2_logging, to_canonical_address, to_checksum_address
from eth_utils import setup_DEBUG2_logging

from boa.util.abi import abi_decode
from boa.util.abi import Address, abi_decode
from boa.util.eip1167 import extract_eip1167_address, is_eip1167_contract
from boa.util.lrudict import lrudict
from boa.vm.fast_accountdb import patch_pyevm_state_object, unpatch_pyevm_state_object
from boa.vm.fork import AccountDBFork
from boa.vm.gas_meters import GasMeter, NoGasMeter, ProfilingGasMeter
Expand Down Expand Up @@ -94,42 +93,6 @@ def anchor(self):
setattr(self, attr, snap[attr])


# XXX: inherit from bytes directly so that we can pass it to py-evm?
# inherit from `str` so that ABI encoder / decoder can work without failing
class Address(str): # (PYEVM_Address):
# converting between checksum and canonical addresses is a hotspot;
# this class contains both and caches recently seen conversions
__slots__ = ("canonical_address",)
_cache = lrudict(1024)

canonical_address: Annotated[PYEVM_Address, "canonical address"]

def __new__(cls, address):
if isinstance(address, Address):
return address

try:
return cls._cache[address]
except KeyError:
pass

checksum_address = to_checksum_address(address)
self = super().__new__(cls, checksum_address)
self.canonical_address = to_canonical_address(address)
cls._cache[address] = self
return self

# def __hash__(self):
# return hash(self.checksum_address)

# def __eq__(self, other):
# return super().__eq__(self, other)

def __repr__(self):
checksum_addr = super().__repr__()
return f"_Address({checksum_addr})"


# make mypy happy
_AddressType = Address | str | bytes | PYEVM_Address

Expand Down
2 changes: 1 addition & 1 deletion boa/interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
VyperContract,
VyperDeployer,
)
from boa.environment import Address
from boa.explorer import fetch_abi_from_etherscan
from boa.util.abi import Address
from boa.util.disk_cache import DiskCache

_Contract = Union[VyperContract, VyperBlueprint]
Expand Down
3 changes: 2 additions & 1 deletion boa/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from eth_account import Account
from requests.exceptions import HTTPError

from boa.environment import Address, Env
from boa.environment import Env
from boa.rpc import (
EthereumRPC,
RPCError,
Expand All @@ -19,6 +19,7 @@
to_int,
trim_dict,
)
from boa.util.abi import Address


class TraceObject:
Expand Down
57 changes: 54 additions & 3 deletions boa/util/abi.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,66 @@
# wrapper module around whatever encoder we are using
from typing import Any
from typing import Annotated, Any

from eth.codecs.abi import nodes
from eth.codecs.abi.decoder import Decoder
from eth.codecs.abi.encoder import Encoder
from eth.codecs.abi.exceptions import ABIError
from eth.codecs.abi.nodes import ABITypeNode
from eth.codecs.abi.parser import Parser
from eth_typing import Address as PYEVM_Address
from eth_utils import to_canonical_address, to_checksum_address

from boa.util.lrudict import lrudict

_parsers: dict[str, ABITypeNode] = {}


# XXX: inherit from bytes directly so that we can pass it to py-evm?
# inherit from `str` so that ABI encoder / decoder can work without failing
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
class Address(str): # (PYEVM_Address):
# converting between checksum and canonical addresses is a hotspot;
# this class contains both and caches recently seen conversions
__slots__ = ("canonical_address",)
_cache = lrudict(1024)

canonical_address: Annotated[PYEVM_Address, "canonical address"]

def __new__(cls, address):
if isinstance(address, Address):
return address

try:
return cls._cache[address]
except KeyError:
pass

checksum_address = to_checksum_address(address)
self = super().__new__(cls, checksum_address)
self.canonical_address = to_canonical_address(address)
cls._cache[address] = self
return self

def __repr__(self):
checksum_addr = super().__repr__()
return f"_Address({checksum_addr})"


class ABIEncoder(Encoder):
DanielSchiavini marked this conversation as resolved.
Show resolved Hide resolved
@classmethod
def visit_AddressNode(cls, node: nodes.AddressNode, value) -> bytes:
value = getattr(value, "address", value)
return super().visit_AddressNode(node, value)


class ABIDecoder(Decoder):
@classmethod
def visit_AddressNode(
cls, node: nodes.AddressNode, value: bytes, checksum: bool = True, **kwargs: Any
) -> "Address":
ret = super().visit_AddressNode(node, value)
return Address(ret)


def _get_parser(schema: str):
try:
return _parsers[schema]
Expand All @@ -19,11 +70,11 @@ def _get_parser(schema: str):


def abi_encode(schema: str, data: Any) -> bytes:
return Encoder.encode(_get_parser(schema), data)
return ABIEncoder.encode(_get_parser(schema), data)


def abi_decode(schema: str, data: bytes) -> Any:
return Decoder.decode(_get_parser(schema), data)
return ABIDecoder.decode(_get_parser(schema), data)


def is_abi_encodable(abi_type: str, data: Any) -> bool:
Expand Down
25 changes: 18 additions & 7 deletions tests/unitary/test_abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import boa
from boa import BoaError
from boa.contracts.abi.abi_contract import ABIContractFactory, ABIFunction
from boa.environment import Address
from boa.util.abi import Address


def load_via_abi(code):
Expand Down Expand Up @@ -89,21 +89,32 @@ def test(_a: address) -> address:
assert isinstance(result, Address)


def test_dynarray():
def test_address_nested():
code = """
struct Test:
address: address
number: uint256

@external
@view
def test(_a: DynArray[uint256, 100]) -> (DynArray[address, 2], uint256):
return [msg.sender, msg.sender], _a[0]
def test(_a: DynArray[uint256, 100]) -> ((DynArray[Test, 2], uint256), uint256):
first: DynArray[Test, 2] = [
Test({address: msg.sender, number: _a[0]}),
Test({address: msg.sender, number: _a[1]}),
]
return (first, _a[2]), _a[3]
"""
abi_contract, vyper_contract = load_via_abi(code)
deployer_contract = abi_contract.deployer.at(abi_contract.address)
given = [1, 2, 3, 4, 5]
sender = Address(boa.env.eoa)
expected = ([sender, sender], 1)
expected = (([(sender, 1), (sender, 2)], 3), 4)
abi_result = abi_contract.test(given)
assert abi_result == expected
DanielSchiavini marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(abi_result[0][0][0][0], Address)

assert vyper_contract.test(given) == expected
assert abi_contract.test(given) == expected
assert deployer_contract.test(given) == expected
assert deployer_contract.test(given) == abi_result


def test_overloading():
Expand Down
Loading