Skip to content

Commit

Permalink
perf: lazy source traceback in transaction errors (#2211)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Aug 16, 2024
1 parent 36ff7cc commit 3f85a75
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 67 deletions.
2 changes: 1 addition & 1 deletion src/ape/contracts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,7 +1514,7 @@ def _cache_wrap(self, function: Callable) -> ReceiptAPI:
except ContractLogicError as err:
if address := err.address:
self.chain_manager.contracts[address] = self.contract_type
err._set_tb() # Re-try setting source traceback
err = err.with_ape_traceback() # Re-try setting source traceback
new_err = None
try:
# Try enrichment again now that the contract type is cached.
Expand Down
89 changes: 72 additions & 17 deletions src/ape/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from inspect import getframeinfo, stack
from pathlib import Path
from types import CodeType, TracebackType
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast

import click
from eth_typing import Hash32
Expand Down Expand Up @@ -163,6 +163,12 @@ class MethodNonPayableError(ContractDataError):
"""


_TRACE_ARG = Optional[Union["TraceAPI", Callable[[], Optional["TraceAPI"]]]]
_SOURCE_TRACEBACK_ARG = Optional[
Union["SourceTraceback", Callable[[], Optional["SourceTraceback"]]]
]


class TransactionError(ApeException):
"""
Raised when issues occur related to transactions.
Expand All @@ -176,25 +182,28 @@ def __init__(
base_err: Optional[Exception] = None,
code: Optional[int] = None,
txn: Optional[FailedTxn] = None,
trace: Optional["TraceAPI"] = None,
trace: _TRACE_ARG = None,
contract_address: Optional["AddressType"] = None,
source_traceback: Optional["SourceTraceback"] = None,
source_traceback: _SOURCE_TRACEBACK_ARG = None,
project: Optional["ProjectManager"] = None,
set_ape_traceback: bool = False, # Overriden in ContractLogicError
):
message = message or (str(base_err) if base_err else self.DEFAULT_MESSAGE)
self.message = message
self.base_err = base_err
self.code = code
self.txn = txn
self.trace = trace
self._trace = trace
self.contract_address = contract_address
self.source_traceback: Optional["SourceTraceback"] = source_traceback
self._source_traceback = source_traceback
self._project = project
ex_message = f"({code}) {message}" if code else message

# Finalizes expected revert message.
super().__init__(ex_message)
self._set_tb()

if set_ape_traceback:
self.with_ape_traceback()

@property
def address(self) -> Optional["AddressType"]:
Expand Down Expand Up @@ -223,15 +232,51 @@ def contract_type(self) -> Optional[ContractType]:
except (RecursionError, ProviderNotConnectedError):
return None

def _set_tb(self):
if not self.source_traceback and self.txn:
self.source_traceback = _get_ape_traceback_from_tx(self.txn)
@property
def trace(self) -> Optional["TraceAPI"]:
tr = self._trace
if callable(tr):
result = tr()
self._trace = result
return result

return tr

@trace.setter
def trace(self, value):
self._trace = value

@property
def source_traceback(self) -> Optional["SourceTraceback"]:
tb = self._source_traceback
result: Optional["SourceTraceback"]
if callable(tb):
result = tb()
self._source_traceback = result
else:
result = tb

return result

if src_tb := self.source_traceback:
@source_traceback.setter
def source_traceback(self, value):
self._source_traceback = value

def _get_ape_traceback(self) -> Optional[TracebackType]:
source_tb = self.source_traceback
if not source_tb and self.txn:
source_tb = _get_ape_traceback_from_tx(self.txn)

if src_tb := source_tb:
# Create a custom Pythonic traceback using lines from the sources
# found from analyzing the trace of the transaction.
if py_tb := _get_custom_python_traceback(self, src_tb, project=self._project):
self.__traceback__ = py_tb
return py_tb

return None

def with_ape_traceback(self):
return self.with_traceback(self._get_ape_traceback())


class VirtualMachineError(TransactionError):
Expand All @@ -250,19 +295,22 @@ def __init__(
self,
revert_message: Optional[str] = None,
txn: Optional[FailedTxn] = None,
trace: Optional["TraceAPI"] = None,
trace: _TRACE_ARG = None,
contract_address: Optional["AddressType"] = None,
source_traceback: Optional["SourceTraceback"] = None,
source_traceback: _SOURCE_TRACEBACK_ARG = None,
base_err: Optional[Exception] = None,
project: Optional["ProjectManager"] = None,
set_ape_traceback: bool = True, # Overriden default.
):
self.txn = txn
self.trace = trace
self.contract_address = contract_address

super().__init__(
base_err=base_err,
contract_address=contract_address,
message=revert_message,
project=project,
set_ape_traceback=set_ape_traceback,
source_traceback=source_traceback,
trace=trace,
txn=txn,
Expand Down Expand Up @@ -313,8 +361,15 @@ def __init__(
code: Optional[int] = None,
txn: Optional[FailedTxn] = None,
base_err: Optional[Exception] = None,
set_ape_traceback: bool = False,
):
super().__init__("The transaction ran out of gas.", code=code, txn=txn, base_err=base_err)
super().__init__(
"The transaction ran out of gas.",
code=code,
txn=txn,
base_err=base_err,
set_ape_traceback=set_ape_traceback,
)


class NetworkError(ApeException):
Expand Down Expand Up @@ -786,10 +841,10 @@ def __init__(
abi: ErrorABI,
inputs: dict[str, Any],
txn: Optional[FailedTxn] = None,
trace: Optional["TraceAPI"] = None,
trace: _TRACE_ARG = None,
contract_address: Optional["AddressType"] = None,
base_err: Optional[Exception] = None,
source_traceback: Optional["SourceTraceback"] = None,
source_traceback: _SOURCE_TRACEBACK_ARG = None,
):
self.abi = abi
self.inputs = inputs
Expand Down
2 changes: 1 addition & 1 deletion src/ape/managers/compilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def get_custom_error(self, err: ContractLogicError) -> Optional[CustomError]:
HexBytes(message),
address,
base_err=err.base_err,
source_traceback=err.source_traceback,
source_traceback=lambda: err.source_traceback,
trace=err.trace,
txn=err.txn,
)
Expand Down
45 changes: 35 additions & 10 deletions src/ape/pytest/coverage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Iterable
from pathlib import Path
from typing import Optional
from typing import Callable, Optional, Union

import click
from ethpm_types.abi import MethodABI
Expand All @@ -23,16 +23,33 @@


class CoverageData(ManagerAccessMixin):
def __init__(self, project: ProjectManager, sources: Iterable[ContractSource]):
def __init__(
self,
project: ProjectManager,
sources: Union[Iterable[ContractSource], Callable[[], Iterable[ContractSource]]],
):
self.project = project
self.sources = list(sources)
self._sources: Union[Iterable[ContractSource], Callable[[], Iterable[ContractSource]]] = (
sources
)
self._report: Optional[CoverageReport] = None
self._init_coverage_profile() # Inits self._report.

@property
def sources(self) -> list[ContractSource]:
if isinstance(self._sources, list):
return self._sources

elif callable(self._sources):
# Lazily evaluated.
self._sources = self._sources()

self._sources = [src for src in self._sources]
return self._sources

@property
def report(self) -> CoverageReport:
if self._report is None:
return self._init_coverage_profile()
self._report = self._init_coverage_profile()

return self._report

Expand Down Expand Up @@ -69,7 +86,6 @@ def _init_coverage_profile(
for project in report.projects:
project.sources = [x for x in project.sources if len(x.statements) > 0]

self._report = report
return report

def cover(
Expand Down Expand Up @@ -142,11 +158,20 @@ def __init__(
else:
self._output_path = Path.cwd()

sources = self._project._contract_sources
# Data gets initialized lazily (if coverage is needed).
self._data: Optional[CoverageData] = None

self.data: Optional[CoverageData] = (
CoverageData(self._project, sources) if self.config_wrapper.track_coverage else None
)
@property
def data(self) -> Optional[CoverageData]:
if not self.config_wrapper.track_coverage:
return None

elif self._data is None:
# First time being initialized.
self._data = CoverageData(self._project, lambda: self._project._contract_sources)
return self._data

return self._data

@property
def enabled(self) -> bool:
Expand Down
Loading

0 comments on commit 3f85a75

Please sign in to comment.