Skip to content

Commit

Permalink
refactor: move transaction convert methods to conversion manager [APE…
Browse files Browse the repository at this point in the history
…-1416] (#1676)
  • Loading branch information
antazoey authored Sep 26, 2023
1 parent d3ba065 commit 803d952
Show file tree
Hide file tree
Showing 35 changed files with 220 additions and 225 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ exclude =
docs
build
.eggs
tests/integration/cli/projects
per-file-ignores =
# Need signal handler before imports
src/ape/__init__.py: E402
Expand Down
3 changes: 1 addition & 2 deletions src/ape/api/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ape.exceptions import ConversionError
from ape.types import AddressType, ContractCode
from ape.utils import BaseInterface, abstractmethod, cached_property
from ape.utils.abi import _convert_kwargs

if TYPE_CHECKING:
from ape.api.transactions import ReceiptAPI, TransactionAPI
Expand Down Expand Up @@ -167,7 +166,7 @@ def history(self) -> "AccountHistory":
return self.chain_manager.history[self.address]

def as_transaction(self, **kwargs) -> "TransactionAPI":
converted_kwargs = _convert_kwargs(kwargs, self.conversion_manager.convert)
converted_kwargs = self.conversion_manager.convert_method_kwargs(kwargs)
return self.provider.network.ecosystem.create_transaction(
receiver=self.address, **converted_kwargs
)
Expand Down
4 changes: 2 additions & 2 deletions src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def encode_contract_blueprint( # type: ignore[empty-body]
or Starknet's ``Declare`` transaction type.
Args:
contract (``ContractType``): The type of contract to create a blueprint for.
contract_type (``ContractType``): The type of contract to create a blueprint for.
This is the type of contract that will get created by factory contracts.
*args: Calldata, if applicable.
**kwargs: Transaction specifications, such as ``value``.
Expand Down Expand Up @@ -686,7 +686,7 @@ def create_adhoc_network(cls) -> "NetworkAPI":
return cls(
name="adhoc",
ecosystem=ethereum,
data_folder=data_folder,
data_folder=Path(data_folder),
request_header=request_header,
_default_provider="geth",
)
Expand Down
4 changes: 2 additions & 2 deletions src/ape/api/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ethpm_types import Checksum, ContractType, PackageManifest, Source
from ethpm_types.manifest import PackageName
from ethpm_types.utils import AnyUrl, compute_checksum
from ethpm_types.utils import Algorithm, AnyUrl, compute_checksum
from packaging.version import InvalidVersion, Version
from pydantic import ValidationError

Expand Down Expand Up @@ -210,7 +210,7 @@ def _create_source_dict(

source_dict[key] = Source(
checksum=Checksum(
algorithm="md5",
algorithm=Algorithm.MD5,
hash=compute_checksum(source_path.read_bytes()),
),
urls=[],
Expand Down
2 changes: 1 addition & 1 deletion src/ape/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def convert_parent_hash(cls, data):
@validator("hash", "parent_hash", pre=True)
def validate_hexbytes(cls, value):
# NOTE: pydantic treats these values as bytes and throws an error
if value and not isinstance(value, HexBytes):
if value and not isinstance(value, bytes):
return HexBytes(value)

return value
Expand Down
2 changes: 1 addition & 1 deletion src/ape/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def extract_fields(item, columns):


class _BaseQuery(BaseModel):
columns: List[str]
columns: Sequence[str]

# TODO: Support "*" from getting the EcosystemAPI fields

Expand Down
2 changes: 1 addition & 1 deletion src/ape/api/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def total_transfer_value(self) -> int:
to submit the transaction.
"""
if self.max_fee is None:
raise TransactionError("Max fee must not be null.")
raise TransactionError("`self.max_fee` must not be None.")

return self.value + self.max_fee

Expand Down
52 changes: 24 additions & 28 deletions src/ape/contracts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ape.logging import logger
from ape.types import AddressType, ContractLog, LogFilter, MockContractLog
from ape.utils import ManagerAccessMixin, cached_property, singledispatchmethod
from ape.utils.abi import StructParser, _convert_args, _convert_kwargs
from ape.utils.abi import StructParser


class ContractConstructor(ManagerAccessMixin):
Expand Down Expand Up @@ -54,10 +54,10 @@ def decode_input(self, calldata: bytes) -> Tuple[str, Dict[str, Any]]:
return self.abi.selector, decoded_inputs

def serialize_transaction(self, *args, **kwargs) -> TransactionAPI:
arguments = _convert_args(args, self.conversion_manager.convert, self.abi)
kwargs = _convert_kwargs(kwargs, self.conversion_manager.convert)
arguments = self.conversion_manager.convert_method_args(self.abi, args)
converted_kwargs = self.conversion_manager.convert_method_kwargs(kwargs)
return self.provider.network.ecosystem.encode_deployment(
self.deployment_bytecode, self.abi, *arguments, **kwargs
self.deployment_bytecode, self.abi, *arguments, **converted_kwargs
)

def __call__(self, private: bool = False, *args, **kwargs) -> ReceiptAPI:
Expand Down Expand Up @@ -86,9 +86,9 @@ def __repr__(self) -> str:
return self.abi.signature

def serialize_transaction(self, *args, **kwargs) -> TransactionAPI:
kwargs = _convert_kwargs(kwargs, self.conversion_manager.convert)
converted_kwargs = self.conversion_manager.convert_method_kwargs(kwargs)
return self.provider.network.ecosystem.encode_transaction(
self.address, self.abi, *args, **kwargs
self.address, self.abi, *args, **converted_kwargs
)

def __call__(self, *args, **kwargs) -> Any:
Expand Down Expand Up @@ -130,9 +130,9 @@ def __str__(self) -> str:

def encode_input(self, *args) -> HexBytes:
selected_abi = _select_method_abi(self.abis, args)
args = self._convert_tuple(args, selected_abi)
arguments = self.conversion_manager.convert_method_args(selected_abi, args)
ecosystem = self.provider.network.ecosystem
encoded_calldata = ecosystem.encode_calldata(selected_abi, *args)
encoded_calldata = ecosystem.encode_calldata(selected_abi, *arguments)
method_id = ecosystem.get_method_selector(selected_abi)
return HexBytes(method_id + encoded_calldata)

Expand Down Expand Up @@ -179,9 +179,6 @@ def decode_input(self, calldata: bytes) -> Tuple[str, Dict[str, Any]]:

raise err

def _convert_tuple(self, v: tuple, abi) -> tuple:
return _convert_args(v, self.conversion_manager.convert, abi)


class ContractCallHandler(ContractMethodHandler):
def __call__(self, *args, **kwargs) -> Any:
Expand All @@ -190,12 +187,12 @@ def __call__(self, *args, **kwargs) -> Any:
raise _get_non_contract_error(self.contract.address, network)

selected_abi = _select_method_abi(self.abis, args)
args = self._convert_tuple(args, selected_abi)
arguments = self.conversion_manager.convert_method_args(selected_abi, args)

return ContractCall(
abi=selected_abi,
address=self.contract.address,
)(*args, **kwargs)
)(*arguments, **kwargs)

def as_transaction(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -236,7 +233,7 @@ def estimate_gas_cost(self, *args, **kwargs) -> int:
"""

selected_abi = _select_method_abi(self.abis, args)
arguments = _convert_args(args, self.conversion_manager.convert, selected_abi)
arguments = self.conversion_manager.convert_method_args(selected_abi, args)
return self.transact.estimate_gas_cost(*arguments, **kwargs)


Expand Down Expand Up @@ -271,10 +268,10 @@ def serialize_transaction(self, *args, **kwargs) -> TransactionAPI:
# Automatically impersonate contracts (if API available) when sender
kwargs["sender"] = self.account_manager.test_accounts[kwargs["sender"].address]

arguments = _convert_args(args, self.conversion_manager.convert, self.abi)
kwargs = _convert_kwargs(kwargs, self.conversion_manager.convert)
arguments = self.conversion_manager.convert_method_args(self.abi, args)
converted_kwargs = self.conversion_manager.convert_method_kwargs(kwargs)
return self.provider.network.ecosystem.encode_transaction(
self.address, self.abi, *arguments, **kwargs
self.address, self.abi, *arguments, **converted_kwargs
)

def __call__(self, *args, **kwargs) -> ReceiptAPI:
Expand Down Expand Up @@ -328,7 +325,7 @@ def estimate_gas_cost(self, *args, **kwargs) -> int:
reported in the fee-currency's smallest unit, e.g. Wei.
"""
selected_abi = _select_method_abi(self.abis, args)
arguments = _convert_args(args, self.conversion_manager.convert, selected_abi)
arguments = self.conversion_manager.convert_method_args(selected_abi, args)
txn = self.as_transaction(*arguments, **kwargs)
return self.provider.estimate_gas_cost(txn)

Expand Down Expand Up @@ -356,7 +353,6 @@ def _as_transaction(self, *args) -> ContractTransaction:
raise _get_non_contract_error(self.contract.address, network)

selected_abi = _select_method_abi(self.abis, args)
args = self._convert_tuple(args, selected_abi)

return ContractTransaction(
abi=selected_abi,
Expand Down Expand Up @@ -500,7 +496,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> MockContractLog:
else:
converted_args[key] = self.conversion_manager.convert(value, py_type)

properties = {"event_arguments": converted_args, "event_name": self.abi.name}
properties: Dict = {"event_arguments": converted_args, "event_name": self.abi.name}
if hasattr(self.contract, "address"):
# Only address if this is off an instance.
properties["contract_address"] = self.contract.address
Expand Down Expand Up @@ -551,7 +547,7 @@ def query(
if columns[0] == "*":
columns = list(ContractLog.__fields__) # type: ignore

query = {
query: Dict = {
"columns": columns,
"event": self.abi,
"start_block": start_block,
Expand Down Expand Up @@ -595,7 +591,7 @@ def range(
Iterator[:class:`~ape.contracts.base.ContractLog`]
"""

if not hasattr(self.contract, "address"):
if not (contract_address := getattr(self.contract, "address", None)):
return

start_block = None
Expand All @@ -604,16 +600,16 @@ def range(
if stop is None:
contract = None
try:
contract = self.chain_manager.contracts.instance_at(self.contract.address)
contract = self.chain_manager.contracts.instance_at(contract_address)
except Exception:
pass

if contract:
start_block = contract.receipt.block_number
else:
start_block = self.chain_manager.contracts.get_creation_receipt(
self.contract.address
).block_number
cache = self.chain_manager.contracts
receipt = cache.get_creation_receipt(contract_address)
start_block = receipt.block_number

stop_block = start_or_stop
elif start_or_stop is not None and stop is not None:
Expand All @@ -622,7 +618,7 @@ def range(

stop_block = min(stop_block, self.chain_manager.blocks.height)

addresses = set([self.contract.address] + (extra_addresses or []))
addresses = list(set([contract_address] + (extra_addresses or [])))
contract_event_query = ContractEventQuery(
columns=list(ContractLog.__fields__.keys()),
contract=addresses,
Expand Down Expand Up @@ -1144,7 +1140,7 @@ def __getattr__(self, attr_name: str) -> Any:
if attr_name in set(super(BaseAddress, self).__dir__()):
return super(BaseAddress, self).__getattribute__(attr_name)

if attr_name not in {
elif attr_name not in {
*self._view_methods_,
*self._mutable_methods_,
*self._events_,
Expand Down
10 changes: 5 additions & 5 deletions src/ape/managers/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __iter__(self) -> Iterator[BlockAPI]:

def query(
self,
*columns: List[str],
*columns: str,
start_block: int = 0,
stop_block: Optional[int] = None,
step: int = 1,
Expand All @@ -130,7 +130,7 @@ def query(
than the chain length.
Args:
columns (List[str]): columns in the DataFrame to return
*columns (str): columns in the DataFrame to return
start_block (int): The first block, by number, to include in the
query. Defaults to 0.
stop_block (Optional[int]): The last block, by number, to include
Expand Down Expand Up @@ -438,12 +438,12 @@ def outgoing(self) -> Iterator[ReceiptAPI]:
start_nonce = receipt.nonce + 1 # start next loop on the next item

if start_nonce != stop_nonce:
# NOTE: there is no more sessional history, so just return query engine iterator
# NOTE: there is no more session history, so just return query engine iterator
yield from iter(self[start_nonce : stop_nonce + 1]) # noqa: E203

def query(
self,
*columns: List[str],
*columns: str,
start_nonce: int = 0,
stop_nonce: Optional[int] = None,
engine_to_use: Optional[str] = None,
Expand All @@ -459,7 +459,7 @@ def query(
than the account's current nonce.
Args:
columns (List[str]): columns in the DataFrame to return
*columns (str): columns in the DataFrame to return
start_nonce (int): The first transaction, by nonce, to include in the
query. Defaults to 0.
stop_nonce (Optional[int]): The last transaction, by nonce, to include
Expand Down
37 changes: 28 additions & 9 deletions src/ape/managers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
if TYPE_CHECKING:
from .project import ProjectManager

from ethpm_types import PackageMeta
from ethpm_types import BaseModel, PackageMeta

CONFIG_FILE_NAME = "ape-config.yaml"

Expand All @@ -28,9 +28,16 @@ class CompilerConfig(PluginConfig):
"""List of globular files to ignore"""


class DeploymentConfigCollection(dict):
def __init__(self, data: Dict, valid_ecosystems: Dict, valid_networks: List[str]):
for ecosystem_name, networks in data.items():
class DeploymentConfigCollection(BaseModel):
__root__: Dict

@root_validator(pre=True)
def validate_deployments(cls, data: Dict):
root_data = data.get("__root__", data)
valid_ecosystems = root_data.pop("valid_ecosystems", {})
valid_networks = root_data.pop("valid_networks", {})
valid_data: Dict = {}
for ecosystem_name, networks in root_data.items():
if ecosystem_name not in valid_ecosystems:
logger.warning(f"Invalid ecosystem '{ecosystem_name}' in deployments config.")
continue
Expand All @@ -41,21 +48,29 @@ def __init__(self, data: Dict, valid_ecosystems: Dict, valid_networks: List[str]
logger.warning(f"Invalid network '{network_name}' in deployments config.")
continue

valid_deployments = []
for deployment in [d for d in contract_deployments]:
address = deployment.get("address", None)
if "address" not in deployment:
if not (address := deployment.get("address")):
logger.warning(
f"Missing 'address' field in deployment "
f"(ecosystem={ecosystem_name}, network={network_name})"
)
continue

valid_deployment = {**deployment}
try:
deployment["address"] = ecosystem.decode_address(address)
valid_deployment["address"] = ecosystem.decode_address(address)
except ValueError as err:
logger.warning(str(err))

super().__init__(data)
valid_deployments.append(valid_deployment)

valid_data[ecosystem_name] = {
**valid_data.get(ecosystem_name, {}),
network_name: valid_deployments,
}

return {"__root__": valid_data}


class ConfigManager(BaseInterfaceModel):
Expand Down Expand Up @@ -194,7 +209,11 @@ def _plugin_configs(self) -> Dict[str, PluginConfig]:
valid_ecosystems = dict(self.plugin_manager.ecosystems)
valid_network_names = [n[1] for n in [e[1] for e in self.plugin_manager.networks]]
self.deployments = configs["deployments"] = DeploymentConfigCollection(
deployments, valid_ecosystems, valid_network_names
__root__={
**deployments,
"valid_ecosystems": valid_ecosystems,
"valid_networks": valid_network_names,
}
)

for plugin_name, config_class in self.plugin_manager.config_class:
Expand Down
Loading

0 comments on commit 803d952

Please sign in to comment.