diff --git a/tests/rest/test_repeat_on_error.py b/tests/rest/test_repeat_on_error.py new file mode 100644 index 000000000..f7fddeedf --- /dev/null +++ b/tests/rest/test_repeat_on_error.py @@ -0,0 +1,106 @@ +import asyncio + +import aiohttp +import pytest + +import ya_activity +import ya_market +import ya_payment + +from yapapi.rest.common import repeat_on_error, SuppressedExceptions, is_intermittent_error + + +@pytest.mark.parametrize( + "max_tries, exceptions, calls_expected, expected_error", + [ + (1, [], 1, None), + (1, [asyncio.TimeoutError()], 1, asyncio.TimeoutError), + (1, [ya_activity.ApiException(408)], 1, ya_activity.ApiException), + (1, [ya_activity.ApiException(500)], 1, ya_activity.ApiException), + (1, [ValueError()], 1, ValueError), + # + (2, [], 1, None), + (2, [asyncio.TimeoutError()], 2, None), + (2, [ya_activity.ApiException(408)], 2, None), + (2, [ya_market.ApiException(408)], 2, None), + (2, [ya_payment.ApiException(408)], 2, None), + (2, [ya_activity.ApiException(500)], 1, ya_activity.ApiException), + (2, [aiohttp.ServerDisconnectedError()], 2, None), + (2, [aiohttp.ClientOSError(32, "Broken pipe")], 2, None), + (2, [aiohttp.ClientOSError(1132, "UnBroken pipe")], 1, aiohttp.ClientOSError), + (2, [ValueError()], 1, ValueError), + (2, [asyncio.TimeoutError()] * 2, 2, asyncio.TimeoutError), + # + (3, [], 1, None), + (3, [asyncio.TimeoutError()], 2, None), + (3, [ya_activity.ApiException(408)], 2, None), + (3, [asyncio.TimeoutError()] * 2, 3, None), + (3, [asyncio.TimeoutError()] * 3, 3, asyncio.TimeoutError), + (3, [ya_activity.ApiException(500)], 1, ya_activity.ApiException), + (3, [asyncio.TimeoutError(), ValueError()], 2, ValueError), + ], +) +@pytest.mark.asyncio +async def test_repeat_on_error(max_tries, exceptions, calls_expected, expected_error): + + calls_made = 0 + + @repeat_on_error(max_tries=max_tries) + async def request(): + nonlocal calls_made, exceptions + calls_made += 1 + if exceptions: + e = exceptions[0] + exceptions = exceptions[1:] + raise e + return True + + try: + await request() + except Exception as e: + assert expected_error is not None, f"Unexpected exception: {e}" + assert isinstance(e, expected_error), f"Expected an {expected_error}, got {e}" + assert ( + calls_made == calls_expected + ), f"{calls_made} attempts were made, expected {calls_expected}" + + +@pytest.mark.asyncio +async def test_suppressed_exceptions(): + + async with SuppressedExceptions(is_intermittent_error) as se: + pass + assert se.exception is None + + async with SuppressedExceptions(is_intermittent_error) as se: + raise asyncio.TimeoutError() + assert isinstance(se.exception, asyncio.TimeoutError) + + async with SuppressedExceptions(is_intermittent_error) as se: + raise aiohttp.ClientOSError(32, "Broken pipe") + assert isinstance(se.exception, aiohttp.ClientOSError) + + async with SuppressedExceptions(is_intermittent_error) as se: + raise aiohttp.ServerDisconnectedError() + assert isinstance(se.exception, aiohttp.ServerDisconnectedError) + + with pytest.raises(AssertionError): + async with SuppressedExceptions(is_intermittent_error): + raise AssertionError() + + +@pytest.mark.asyncio +async def test_suppressed_exceptions_with_return(): + async def success(): + return "success" + + async def failure(): + raise asyncio.TimeoutError() + + async def func(request): + async with SuppressedExceptions(is_intermittent_error): + return await request + return "failure" # noqa + + assert await func(success()) == "success" + assert await func(failure()) == "failure" diff --git a/yapapi/rest/common.py b/yapapi/rest/common.py new file mode 100644 index 000000000..a957a602a --- /dev/null +++ b/yapapi/rest/common.py @@ -0,0 +1,91 @@ +import asyncio +import functools +import logging +from typing import Callable, Optional + +import aiohttp + +import ya_market +import ya_activity +import ya_payment + + +logger = logging.getLogger(__name__) + + +def is_intermittent_error(e: Exception) -> bool: + """Check if `e` indicates an intermittent communication failure such as network timeout.""" + + is_timeout_exception = isinstance(e, asyncio.TimeoutError) or ( + isinstance(e, (ya_activity.ApiException, ya_market.ApiException, ya_payment.ApiException)) + and e.status in (408, 504) + ) + + return ( + is_timeout_exception + or isinstance(e, aiohttp.ServerDisconnectedError) + # OS error with errno 32 is "Broken pipe" + or (isinstance(e, aiohttp.ClientOSError) and e.errno == 32) + ) + + +class SuppressedExceptions: + """An async context manager for suppressing exceptions satisfying given condition.""" + + exception: Optional[Exception] + + def __init__(self, condition: Callable[[Exception], bool], report_exceptions: bool = True): + self._condition = condition + self._report_exceptions = report_exceptions + self.exception = None + + async def __aenter__(self) -> "SuppressedExceptions": + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if exc_value and self._condition(exc_value): + self.exception = exc_value + if self._report_exceptions: + logger.debug( + "Exception suppressed: %r", exc_value, exc_info=(exc_type, exc_value, traceback) + ) + return True + return False + + +def repeat_on_error( + max_tries: int, + condition: Callable[[Exception], bool] = is_intermittent_error, + interval: float = 0.0, +): + """Decorate a function to repeat calls up to `max_tries` times when errors occur. + + Only exceptions satisfying the given `condition` will cause the decorated function + to be retried. All remaining exceptions will fall through. + """ + + def decorator(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + """Make at most `max_tries` attempts to call `func`.""" + + for try_num in range(1, max_tries + 1): + + if try_num > 1: + await asyncio.sleep(interval) + + async with SuppressedExceptions(condition, False) as se: + return await func(*args, **kwargs) + + assert se.exception # noqa (unreachable) + repeat = try_num < max_tries + msg = f"API call timed out (attempt {try_num}/{max_tries}), " + msg += f"retrying in {interval} s" if repeat else "giving up" + # Don't print traceback if this was the last attempt, let the caller do it. + logger.debug(msg, exc_info=repeat) + if not repeat: + raise se.exception + + return wrapper + + return decorator diff --git a/yapapi/rest/market.py b/yapapi/rest/market.py index f75d27288..8fdfcfd0b 100644 --- a/yapapi/rest/market.py +++ b/yapapi/rest/market.py @@ -1,14 +1,17 @@ import asyncio +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone import logging from types import TracebackType from typing import AsyncIterator, Optional, TypeVar, Type, Generator, Any, Generic from typing_extensions import Awaitable, AsyncContextManager -from ..props import Model from ya_market import ApiClient, ApiException, RequestorApi, models # type: ignore -from dataclasses import dataclass -from datetime import datetime, timedelta, timezone + +from .common import is_intermittent_error, SuppressedExceptions +from ..props import Model + _ModelType = TypeVar("_ModelType", bound=Model) @@ -191,8 +194,10 @@ async def events(self) -> AsyncIterator[OfferProposal]: """Yield counter-proposals based on the incoming, matching Offers.""" while self._open: + proposals = [] try: - proposals = await self._api.collect_offers(self._id, timeout=10, max_events=10) + async with SuppressedExceptions(is_intermittent_error): + proposals = await self._api.collect_offers(self._id, timeout=5, max_events=10) except ApiException as ex: if ex.status == 404: logger.debug( diff --git a/yapapi/rest/payment.py b/yapapi/rest/payment.py index 644d61b06..d6149238f 100644 --- a/yapapi/rest/payment.py +++ b/yapapi/rest/payment.py @@ -7,6 +7,7 @@ from decimal import Decimal from datetime import datetime, timezone, timedelta from dataclasses import dataclass +from .common import repeat_on_error, is_intermittent_error, SuppressedExceptions from .resource import ResourceCtx @@ -18,6 +19,7 @@ def __init__(self, _api: RequestorApi, _base: yap.Invoice): self.__dict__.update(**_base.__dict__) self._api: RequestorApi = _api + @repeat_on_error(max_tries=5) async def accept(self, *, amount: Union[Decimal, str], allocation: "Allocation"): acceptance = yap.Acceptance(total_amount_accepted=str(amount), allocation_id=allocation.id) await self._api.accept_invoice(self.invoice_id, acceptance) @@ -28,6 +30,7 @@ def __init__(self, _api: RequestorApi, _base: yap.DebitNote): self.__dict__.update(**_base.__dict__) self._api: RequestorApi = _api + @repeat_on_error(max_tries=5) async def accept(self, *, amount: Union[Decimal, str], allocation: "Allocation"): acceptance = yap.Acceptance(total_amount_accepted=str(amount), allocation_id=allocation.id) await self._api.accept_debit_note(self.debit_note_id, acceptance) @@ -67,6 +70,7 @@ class Allocation(_Link): expires: Optional[datetime] "Allocation expiration timestamp" + @repeat_on_error(max_tries=5) async def details(self) -> AllocationDetails: details: yap.Allocation = await self._api.get_allocation(self.id) return AllocationDetails( @@ -74,6 +78,7 @@ async def details(self) -> AllocationDetails: remaining_amount=Decimal(details.remaining_amount), ) + @repeat_on_error(max_tries=5) async def delete(self): await self._api.release_allocation(self.id) @@ -191,6 +196,7 @@ async def accounts(self) -> AsyncIterator[Account]: async def decorate_demand(self, ids: List[str]) -> yap.MarketDecoration: return await self._api.get_demand_decorations(ids) + @repeat_on_error(max_tries=5) async def debit_note(self, debit_note_id: str) -> DebitNote: debit_note = await self._api.get_debit_note(debit_note_id) return DebitNote(_api=self._api, _base=debit_note) @@ -200,13 +206,13 @@ async def invoices(self) -> AsyncIterator[Invoice]: for invoice_obj in cast(Iterable[yap.Invoice], await self._api.get_invoices()): yield Invoice(_api=self._api, _base=invoice_obj) + @repeat_on_error(max_tries=5) async def invoice(self, invoice_id: str) -> Invoice: invoice_obj = await self._api.get_invoice(invoice_id) return Invoice(_api=self._api, _base=invoice_obj) def incoming_invoices(self) -> AsyncIterator[Invoice]: ts = datetime.now(timezone.utc) - api = self._api async def fetch(init_ts: datetime): ts = init_ts @@ -214,7 +220,9 @@ async def fetch(init_ts: datetime): # In the current version of `ya-aioclient` the method `get_invoice_events` # incorrectly accepts `timeout` parameter, while the server uses `pollTimeout` # events = await api.get_invoice_events(poll_timeout=5, after_timestamp=ts) - events = await api.get_invoice_events(after_timestamp=ts) + events = [] + async with SuppressedExceptions(is_intermittent_error): + events = await self._api.get_invoice_events(after_timestamp=ts) for ev in events: logger.debug("Received invoice event: %r, type: %s", ev, ev.__class__) if isinstance(ev, yap.InvoiceReceivedEvent): @@ -235,7 +243,9 @@ def incoming_debit_notes(self) -> AsyncIterator[DebitNote]: async def fetch(init_ts: datetime): ts = init_ts while True: - events = await self._api.get_debit_note_events(after_timestamp=ts) + events = [] + async with SuppressedExceptions(is_intermittent_error): + events = await self._api.get_debit_note_events(after_timestamp=ts) for ev in events: logger.debug("Received debit note event: %r, type: %s", ev, ev.__class__) if isinstance(ev, yap.DebitNoteReceivedEvent):