Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improved contract receipt related logic #1548

Merged
merged 8 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/ape/api/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ def txn_hash(self) -> HexBytes:
@property
def receipt(self) -> Optional["ReceiptAPI"]:
"""
This transaction's associated published receipt,
if it exists.
This transaction's associated published receipt, if it exists.
"""

try:
Expand Down
46 changes: 31 additions & 15 deletions src/ape/contracts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,22 +584,34 @@ def range(
Iterator[:class:`~ape.contracts.base.ContractLog`]
"""

if not hasattr(self.contract, "address"):
return

start_block = None
stop_block = None

if stop is None:
start_block = 0
contract = None
try:
contract = self.chain_manager.contracts.instance_at(self.contract.address)
except Exception:
pass

if contract:
start_block = contract.receipt.block_number
else:
start_block = self.chain_manager.contracts.get_creation_receipt(
self.contract.address
).block_number

stop_block = start_or_stop
elif start_or_stop is not None and stop is not None:
start_block = start_or_stop
stop_block = stop - 1

stop_block = min(stop_block, self.chain_manager.blocks.height)

addresses = set(
([self.contract.address] if hasattr(self.contract, "address") else [])
+ (extra_addresses or [])
)
addresses = set([self.contract.address] + (extra_addresses or []))
contract_event_query = ContractEventQuery(
columns=list(ContractLog.__fields__.keys()),
contract=addresses,
Expand Down Expand Up @@ -822,25 +834,29 @@ def from_receipt(cls, receipt: ReceiptAPI, contract_type: ContractType) -> "Cont
return instance

@property
def receipt(self) -> Optional[ReceiptAPI]:
def receipt(self) -> ReceiptAPI:
"""
The receipt associated with deploying the contract instance,
if it is known and exists.
"""

if not self._cached_receipt and self.txn_hash:
if self._cached_receipt:
return self._cached_receipt

if self.txn_hash:
# Hash is known. Use that to get the receipt.
try:
receipt = self.chain_manager.get_receipt(self.txn_hash)
except (TransactionNotFoundError, ValueError, ChainError):
return None

self._cached_receipt = receipt
return receipt

elif self._cached_receipt:
return self._cached_receipt
pass
else:
self._cached_receipt = receipt
return receipt

return None
# Brute force find the receipt.
receipt = self.chain_manager.contracts.get_creation_receipt(self.address)
self._cached_receipt = receipt
return receipt

def __repr__(self) -> str:
contract_name = self.contract_type.name or "Unnamed contract"
Expand Down
54 changes: 54 additions & 0 deletions src/ape/managers/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,6 +1333,51 @@ def _write_deployments_mapping(self, deployments_map: Dict):
with self._deployments_mapping_cache.open("w") as fp:
json.dump(deployments_map, fp, sort_keys=True, indent=2, default=sorted)

def get_creation_receipt(
self, address: AddressType, start_block: int = 0, stop_block: Optional[int] = None
) -> ReceiptAPI:
"""
Get the receipt responsible for the initial creation of the contract.

Args:
address (``AddressType``): The address of the contract.
start_block (int): The block to start looking from.
stop_block (Optional[int]): The block to stop looking at.

Returns:
:class:`~ape.apt.transactions.ReceiptAPI`
"""
if stop_block is None and (stop := self.chain_manager.blocks.head.number):
stop_block = stop
elif stop_block is None:
raise ChainError("Chain missing blocks.")

mid_block = (stop_block - start_block) // 2 + start_block
# NOTE: biased towards mid_block == start_block

if start_block == mid_block:
for tx in self.chain_manager.blocks[mid_block].transactions:
if (receipt := tx.receipt) and receipt.contract_address == address:
return receipt

if mid_block + 1 <= stop_block:
return self.get_creation_receipt(
address, start_block=mid_block + 1, stop_block=stop_block
)
else:
raise ChainError(f"Failed to find a contract-creation receipt for '{address}'.")

elif self.provider.get_code(address, block_id=mid_block):
return self.get_creation_receipt(address, start_block=start_block, stop_block=mid_block)

elif start_block + 1 <= mid_block:
return self.get_creation_receipt(
address, start_block=start_block + 1, stop_block=stop_block
)

else:
raise ChainError(f"Failed to find a contract-creation receipt for '{address}'.")


class ReportManager(BaseManager):
"""
Expand Down Expand Up @@ -1645,4 +1690,13 @@ def set_balance(self, account: Union[BaseAddress, AddressType], amount: Union[in
return self.provider.set_balance(account, amount)

def get_receipt(self, transaction_hash: str) -> ReceiptAPI:
"""
Get a transaction receipt from the chain.

Args:
transaction_hash (str): The hash of the transaction.

Returns:
:class:`~ape.apt.transactions.ReceiptAPI`
"""
return self.chain_manager.history[transaction_hash]
11 changes: 7 additions & 4 deletions src/ape/managers/project/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ape.api import DependencyAPI, ProjectAPI
from ape.api.networks import LOCAL_NETWORK_NAME
from ape.contracts import ContractContainer, ContractInstance, ContractNamespace
from ape.exceptions import ApeAttributeError, APINotImplementedError, ProjectError
from ape.exceptions import ApeAttributeError, APINotImplementedError, ChainError, ProjectError
from ape.logging import logger
from ape.managers.base import BaseManager
from ape.managers.project.types import ApeProject, BrownieProject
Expand Down Expand Up @@ -730,9 +730,12 @@ def track_deployment(self, contract: ContractInstance):
raise ProjectError("Can only publish deployments on a live network.")

contract_name = contract.contract_type.name
receipt = contract.receipt
if not receipt:
raise ProjectError(f"Contract '{contract_name}' transaction receipt is unknown.")
try:
receipt = contract.receipt
except ChainError as err:
raise ProjectError(
f"Contract '{contract_name}' transaction receipt is unknown."
) from err

block_number = receipt.block_number
block_hash_bytes = self.provider.get_block(block_number).hash
Expand Down
10 changes: 10 additions & 0 deletions tests/functional/test_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,3 +626,13 @@ def test_cache_non_checksum_address(chain, vyper_contract_instance):
lowered_address = vyper_contract_instance.address.lower()
chain.contracts[lowered_address] = vyper_contract_instance.contract_type
assert chain.contracts[vyper_contract_instance.address] == vyper_contract_instance.contract_type


def test_get_contract_receipt(chain, vyper_contract_instance):
address = vyper_contract_instance.address
receipt = chain.contracts.get_creation_receipt(address)
assert receipt.contract_address == address

chain.mine()
receipt = chain.contracts.get_creation_receipt(address)
assert receipt.contract_address == address
10 changes: 10 additions & 0 deletions tests/functional/test_contract_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,16 @@ def test_receipt(contract_instance, owner):
assert receipt.sender == owner


def test_receipt_when_needs_brute_force(vyper_contract_instance, owner):
# Force it to use the brute-force approach.
vyper_contract_instance._cached_receipt = None
vyper_contract_instance.txn_hash = None

actual = vyper_contract_instance.receipt.contract_address
expected = vyper_contract_instance.address
assert actual == expected


def test_from_receipt_when_receipt_not_deploy(contract_instance, owner):
receipt = contract_instance.setNumber(555, sender=owner)
expected_err = (
Expand Down
10 changes: 10 additions & 0 deletions tests/functional/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ def test_transaction_contract_event_query(contract_instance, owner, eth_tester_p
assert df_events.event_name[0] == "FooHappened"


def test_transaction_contract_event_query_starts_query_at_deploy_tx(
contract_instance, owner, eth_tester_provider
):
contract_instance.fooAndBar(sender=owner)
time.sleep(0.1)
df_events = contract_instance.FooHappened.query("*")
assert isinstance(df_events, pd.DataFrame)
assert df_events.event_name[0] == "FooHappened"


class Model(BaseInterfaceModel):
number: int
timestamp: int
Expand Down