Skip to content

Commit

Permalink
Retry GetExecBatchResults on ApiExceptions caused by GSB Errors (#588)
Browse files Browse the repository at this point in the history
* Implement re-tries for ApiExceptions caused by GSB Errors

* Endpoint -> endpoint

* Apply fixes after code review

* Fixes after code review: part II

* debug -> warning

* Improve logs when activity is prematurely terminated on the provider

* Formatting

* Raise BatchError when an activity is terminated by the provider

* Add unit tests for PollingBatch behavior when GSB errors occur

Co-authored-by: filipgolem <[email protected]>
Co-authored-by: Filip <[email protected]>
  • Loading branch information
3 people authored Aug 11, 2021
1 parent f9ac257 commit 4bd79db
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 21 deletions.
126 changes: 126 additions & 0 deletions tests/rest/test_activity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from typing import List, Optional, Tuple, Type
from unittest.mock import Mock

import pytest

from ya_activity.exceptions import ApiException
from yapapi.rest.activity import BatchError, PollingBatch


# Behavior of a single `get_exec_batch_results()` call: the error it raises (or `None`)
# paired with the value reported afterwards as the activity's `.state`.
GetExecBatchResultsSpec = Tuple[Optional[Exception], List[str]]


def mock_activity(specs: List[GetExecBatchResultsSpec]) -> Mock:
    """Create a mock activity.

    The argument `specs` is a list of pairs specifying the behavior of subsequent calls
    to `get_exec_batch_results()`: i-th pair corresponds to the i-th call.
    The first element of the pair is an optional error raised by the call, the second element
    is the activity state (the `.state` component of the object returned by `Activity.state()`).

    The returned mock exposes an async `state()` and an `_api` mock with an async
    `get_exec_batch_results()`. `state()` reports the state paired with the most recent
    `get_exec_batch_results()` call, or the first pair's state if no call has been made yet.
    Calling `get_exec_batch_results()` more times than there are pairs raises `IndexError`.
    """
    # Index of the spec selected by the most recent `get_exec_batch_results()` call;
    # -1 means "no call yet".
    call_idx = -1

    async def mock_results(*_args, **_kwargs):
        nonlocal call_idx
        call_idx += 1
        error = specs[call_idx][0]
        if error is not None:
            raise error
        return [Mock(index=0)]

    async def mock_state():
        # Clamp to 0 so that querying the state before the first results call does not
        # wrap around (index -1) and report the state of the *last* spec.
        state = specs[max(call_idx, 0)][1]
        return Mock(state=state)

    return Mock(state=mock_state, _api=Mock(get_exec_batch_results=mock_results))


# Error message shaped like a GSB "endpoint address not found" failure.
GSB_ERROR = ":( GSB error: some endpoint address not found :("


@pytest.mark.parametrize(
    "specs, expected_error",
    [
        # No errors
        ([(None, ["Running", "Running"])], None),
        # Exception other than ApiException should stop iteration over batch results
        (
            [(ValueError("!?"), ["Running", "Running"])],
            ValueError,
        ),
        # ApiException not related to GSB should stop iteration over batch results
        (
            [(ApiException(status=400), ["Running", "Running"])],
            ApiException,
        ),
        # As above, but with status 500
        (
            [
                (
                    ApiException(http_resp=Mock(status=500, data='{"message": "???"}')),
                    ["Running", "Running"],
                )
            ],
            ApiException,
        ),
        # ApiException not related to GSB should raise BatchError if activity is terminated
        (
            [
                (
                    ApiException(http_resp=Mock(status=500, data='{"message": "???"}')),
                    ["Running", "Terminated"],
                )
            ],
            BatchError,
        ),
        # GSB-related ApiException should cause retrying if the activity is running
        (
            [
                (
                    ApiException(http_resp=Mock(status=500, data=f'{{"message": "{GSB_ERROR}"}}')),
                    ["Running", "Running"],
                ),
                (None, ["Running", "Running"]),
            ],
            None,
        ),
        # As above, but max number of tries is reached
        (
            [
                (
                    ApiException(http_resp=Mock(status=500, data=f'{{"message": "{GSB_ERROR}"}}')),
                    ["Running", "Running"],
                )
            ]
            * PollingBatch.GET_EXEC_BATCH_RESULTS_MAX_TRIES,
            ApiException,
        ),
        # GSB-related ApiException should raise BatchError if activity is terminated
        (
            [
                (
                    ApiException(http_resp=Mock(status=500, data=f'{{"message": "{GSB_ERROR}"}}')),
                    ["Running", "Terminated"],
                )
            ],
            BatchError,
        ),
    ],
)
@pytest.mark.asyncio
async def test_polling_batch_on_gsb_error(
    specs: List[GetExecBatchResultsSpec],
    expected_error: Optional[Type[Exception]],
    monkeypatch,
) -> None:
    """Test the behavior of PollingBatch when get_exec_batch_results() raises exceptions.

    `specs` scripts the consecutive `get_exec_batch_results()` calls (see `mock_activity`);
    `expected_error` is the exception type that iterating over the batch must raise,
    or `None` if the iteration must complete without an error.
    """

    # Shorten the retry interval for this test only. `monkeypatch` restores the original
    # class attribute afterwards, so the change does not leak into other tests.
    monkeypatch.setattr(PollingBatch, "GET_EXEC_BATCH_RESULTS_INTERVAL", 0.1)

    activity = mock_activity(specs)
    batch = PollingBatch(activity, "batch_id", 1)

    if expected_error is None:
        # Any exception propagating from here fails the test with its own traceback.
        async for _ in batch:
            pass
    else:
        # Fails with a clear "DID NOT RAISE" message if the iteration completes silently.
        with pytest.raises(expected_error):
            async for _ in batch:
                pass
96 changes: 75 additions & 21 deletions yapapi/rest/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime, timedelta, timezone
import json
import logging
from typing import AsyncIterator, List, Optional, Type, Any, Dict
from typing import AsyncIterator, List, Optional, Tuple, Type, Any, Dict

from typing_extensions import AsyncContextManager, AsyncIterable

Expand All @@ -21,6 +21,8 @@
)

from yapapi import events
from yapapi.rest.common import is_intermittent_error, SuppressedExceptions


_log = logging.getLogger("yapapi.rest")

Expand Down Expand Up @@ -73,8 +75,8 @@ async def send(self, script: List[dict], deadline: Optional[datetime] = None) ->
batch_id = await self._api.call_exec(self._id, yaa.ExeScriptRequest(text=script_txt))

if self._stream_events:
return StreamingBatch(self._api, self._id, batch_id, len(script), deadline)
return PollingBatch(self._api, self._id, batch_id, len(script), deadline)
return StreamingBatch(self, batch_id, len(script), deadline)
return PollingBatch(self, batch_id, len(script), deadline)

async def __aenter__(self) -> "Activity":
return self
Expand Down Expand Up @@ -146,22 +148,19 @@ class BatchTimeoutError(BatchError):
class Batch(abc.ABC, AsyncIterable[events.CommandEventContext]):
"""Abstract base class for iterating over events related to a batch running on provider."""

_api: RequestorControlApi
_activity_id: str
_activity: Activity
_batch_id: str
_size: int
_deadline: datetime

def __init__(
self,
api: RequestorControlApi,
activity_id: str,
activity: Activity,
batch_id: str,
batch_size: int,
deadline: Optional[datetime] = None,
) -> None:
self._api = api
self._activity_id = activity_id
self._activity = activity
self._batch_id = batch_id
self._size = batch_size
self._deadline = (
Expand All @@ -179,25 +178,80 @@ def id(self):
return self._batch_id


def _is_gsb_endpoint_not_found_error(err: ApiException) -> bool:
"""Check if `err` is caused by "Endpoint address not found" GSB error."""

if err.status != 500:
return False
try:
msg = json.loads(err.body)["message"]
return "GSB error" in msg and "endpoint address not found" in msg
except Exception:
_log.debug("Cannot read error message from ApiException", exc_info=True)
return False


class PollingBatch(Batch):
"""A `Batch` implementation that polls the server repeatedly for command status."""

GET_EXEC_BATCH_RESULTS_MAX_TRIES = 3
"""Max number of attempts to call GetExecBatchResults if a GSB error occurs."""

GET_EXEC_BATCH_RESULTS_INTERVAL = 3.0
"""Time in seconds before retrying GetExecBatchResults after a GSB error occurs."""

async def _activity_terminated(self) -> Tuple[bool, Optional[str], Optional[str]]:
    """Query the activity state and report whether it is "Terminated".

    Returns a triple `(terminated, reason, error_message)` built from the object
    returned by `self._activity.state()`. If the state cannot be queried for any
    reason, the failure is logged at debug level and `(False, None, None)` is returned.
    """
    try:
        activity_state = await self._activity.state()
        is_terminated = "Terminated" in activity_state.state
        return is_terminated, activity_state.reason, activity_state.error_message
    except Exception:
        # Best effort: failing to read the state is treated as "not terminated".
        _log.debug("Cannot query activity state", exc_info=True)
    return False, None, None

async def _get_results(self, timeout: float) -> List[yaa.ExeScriptCommandResult]:
    """Call GetExecBatchResults, retrying on "Endpoint address not found" GSB errors.

    Makes at most `GET_EXEC_BATCH_RESULTS_MAX_TRIES` attempts, sleeping
    `GET_EXEC_BATCH_RESULTS_INTERVAL` seconds between them. The request timeout of
    each call is `timeout` capped at 5 seconds.

    Raises `BatchError` if a call fails with `ApiException` and the activity turns out
    to be terminated. Re-raises the `ApiException` if it is not the GSB error or if
    the last attempt has failed. Exceptions other than `ApiException` are not handled.
    """
    attempts_left = self.GET_EXEC_BATCH_RESULTS_MAX_TRIES

    while attempts_left > 0:
        try:
            return await self._activity._api.get_exec_batch_results(
                self._activity._id, self._batch_id, _request_timeout=min(timeout, 5)
            )
        except ApiException as err:
            # TODO: add and use a new Exception class (subclass of BatchError)
            # to indicate closing the activity by the provider
            terminated, reason, error_msg = await self._activity_terminated()
            if terminated:
                raise BatchError("Activity terminated by provider", reason, error_msg)

            # Only the GSB "endpoint address not found" error is worth retrying.
            if not _is_gsb_endpoint_not_found_error(err):
                raise err

            msg = "GetExecBatchResults failed due to GSB error"
            if attempts_left <= 1:
                _log.debug(
                    "%s, giving up after %d attempts",
                    msg,
                    self.GET_EXEC_BATCH_RESULTS_MAX_TRIES,
                )
                raise err

            _log.debug("%s, retrying in %s s", msg, self.GET_EXEC_BATCH_RESULTS_INTERVAL)
            await asyncio.sleep(self.GET_EXEC_BATCH_RESULTS_INTERVAL)

        attempts_left -= 1

    # Reached only when no attempt was made at all (max tries lower than 1).
    return []

async def __aiter__(self) -> AsyncIterator[events.CommandEventContext]:
last_idx = 0

while last_idx < self._size:
timeout = self.seconds_left()
if timeout <= 0:
raise BatchTimeoutError()
try:
results: List[yaa.ExeScriptCommandResult] = await self._api.get_exec_batch_results(
self._activity_id, self._batch_id, _request_timeout=min(timeout, 5)
)
except asyncio.TimeoutError:
continue
except ApiException as err:
if err.status == 408:
continue
raise

results: List[yaa.ExeScriptCommandResult] = []
async with SuppressedExceptions(is_intermittent_error):
results = await self._get_results(timeout=min(timeout, 5))

any_new: bool = False
results = results[last_idx:]
for result in results:
Expand Down Expand Up @@ -227,13 +281,13 @@ class StreamingBatch(Batch):
async def __aiter__(self) -> AsyncIterator[events.CommandEventContext]:
from aiohttp_sse_client import client as sse_client # type: ignore

api_client = self._api.api_client
api_client = self._activity._api.api_client
host = api_client.configuration.host
headers = api_client.default_headers

api_client.update_params_for_auth(headers, None, ["app_key"])

activity_id = self._activity_id
activity_id = self._activity._id
batch_id = self._batch_id
last_idx = self._size - 1

Expand Down

0 comments on commit 4bd79db

Please sign in to comment.