-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from shsms/streaming-helper
Add a grpc streaming helper
- Loading branch information
Showing
9 changed files
with
500 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,4 @@ | ||
# License: MIT | ||
# Copyright © 2023 Frequenz Energy-as-a-Service GmbH | ||
|
||
"""Utilities for writing Frequenz API clients. | ||
TODO(cookiecutter): Add a more descriptive module description. | ||
""" | ||
|
||
|
||
# TODO(cookiecutter): Remove this function | ||
def delete_me(*, blow_up: bool = False) -> bool: | ||
"""Do stuff for demonstration purposes. | ||
Args: | ||
blow_up: If True, raise an exception. | ||
Returns: | ||
True if no exception was raised. | ||
Raises: | ||
RuntimeError: if blow_up is True. | ||
""" | ||
if blow_up: | ||
raise RuntimeError("This function should be removed!") | ||
return True | ||
"""Utilities for writing Frequenz API clients.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# License: MIT | ||
# Copyright © 2023 Frequenz Energy-as-a-Service GmbH | ||
|
||
"""Implementation of the grpc streaming helper.""" | ||
|
||
import asyncio | ||
import logging | ||
import typing | ||
|
||
import grpc | ||
from grpc.aio import UnaryStreamCall # type: ignore[import] | ||
|
||
from frequenz import channels | ||
|
||
from . import retry_strategy | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
_InputT = typing.TypeVar("_InputT") | ||
_OutputT = typing.TypeVar("_OutputT") | ||
|
||
|
||
class GrpcStreamingHelper(typing.Generic[_InputT, _OutputT]): | ||
"""Helper class to handle grpc streaming methods.""" | ||
|
||
def __init__( | ||
self, | ||
stream_name: str, | ||
stream_method: typing.Callable[[], UnaryStreamCall[typing.Any, _InputT]], | ||
transform: typing.Callable[[_InputT], _OutputT], | ||
retry_spec: retry_strategy.RetryStrategy | None = None, | ||
): | ||
"""Initialize the streaming helper. | ||
Args: | ||
stream_name: A name to identify the stream in the logs. | ||
stream_method: A function that returns the grpc stream. This function is | ||
called everytime the connection is lost and we want to retry. | ||
transform: A function to transform the input type to the output type. | ||
retry_spec: The retry strategy to use, when the connection is lost. Defaults | ||
to retries every 3 seconds, with a jitter of 1 second, indefinitely. | ||
""" | ||
self._stream_name = stream_name | ||
self._stream_method = stream_method | ||
self._transform = transform | ||
self._retry_spec = ( | ||
retry_strategy.LinearBackoff() if retry_spec is None else retry_spec.copy() | ||
) | ||
|
||
self._channel: channels.Broadcast[_OutputT] = channels.Broadcast( | ||
f"GrpcStreamingHelper-{stream_name}" | ||
) | ||
self._task = asyncio.create_task(self._run()) | ||
|
||
def new_receiver(self, maxsize: int = 50) -> channels.Receiver[_OutputT]: | ||
"""Create a new receiver for the stream. | ||
Args: | ||
maxsize: The maximum number of messages to buffer. | ||
Returns: | ||
A new receiver. | ||
""" | ||
return self._channel.new_receiver(maxsize=maxsize) | ||
|
||
async def stop(self) -> None: | ||
"""Stop the streaming helper.""" | ||
if self._task.done(): | ||
return | ||
self._task.cancel() | ||
try: | ||
await self._task | ||
except asyncio.CancelledError: | ||
pass | ||
await self._channel.close() | ||
|
||
async def _run(self) -> None: | ||
"""Run the streaming helper.""" | ||
sender = self._channel.new_sender() | ||
|
||
while True: | ||
_logger.debug("Making call to grpc streaming method: %s", self._stream_name) | ||
|
||
try: | ||
call = self._stream_method() | ||
async for msg in call: | ||
await sender.send(self._transform(msg)) | ||
except grpc.aio.AioRpcError: | ||
_logger.exception( | ||
"Error in grpc streaming method: %s", self._stream_name | ||
) | ||
if interval := self._retry_spec.next_interval(): | ||
_logger.warning( | ||
"`%s`, connection ended, retrying %s in %0.3f seconds.", | ||
self._stream_name, | ||
self._retry_spec.get_progress(), | ||
interval, | ||
) | ||
await asyncio.sleep(interval) | ||
else: | ||
_logger.warning( | ||
"`%s`, connection ended, retry limit exceeded %s.", | ||
self._stream_name, | ||
self._retry_spec.get_progress(), | ||
) | ||
await self._channel.close() | ||
break |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# License: MIT | ||
# Copyright © 2022 Frequenz Energy-as-a-Service GmbH | ||
|
||
"""Implementations for retry strategies.""" | ||
|
||
from __future__ import annotations | ||
|
||
import random | ||
from abc import ABC, abstractmethod | ||
from collections.abc import Iterator | ||
from copy import deepcopy | ||
|
||
_DEFAULT_RETRY_INTERVAL = 3.0 | ||
"""Default retry interval, in seconds.""" | ||
|
||
_DEFAULT_RETRY_JITTER = 1.0 | ||
"""Default retry jitter, in seconds.""" | ||
|
||
|
||
class RetryStrategy(ABC): | ||
"""Interface for implementing retry strategies.""" | ||
|
||
_limit: int | None | ||
_count: int | ||
|
||
@abstractmethod | ||
def next_interval(self) -> float | None: | ||
"""Return the time to wait before the next retry. | ||
Returns `None` if the retry limit has been reached, and no more retries | ||
are possible. | ||
Returns: | ||
Time until next retry when below retry limit, and None otherwise. | ||
""" | ||
|
||
def get_progress(self) -> str: | ||
"""Return a string denoting the retry progress. | ||
Returns: | ||
String denoting retry progress in the form "(count/limit)" | ||
""" | ||
if self._limit is None: | ||
return f"({self._count}/∞)" | ||
|
||
return f"({self._count}/{self._limit})" | ||
|
||
def reset(self) -> None: | ||
"""Reset the retry counter. | ||
To be called as soon as a connection is successful. | ||
""" | ||
self._count = 0 | ||
|
||
def copy(self) -> RetryStrategy: | ||
"""Create a new instance of `self`. | ||
Returns: | ||
A deepcopy of `self`. | ||
""" | ||
ret = deepcopy(self) | ||
ret.reset() | ||
return ret | ||
|
||
def __iter__(self) -> Iterator[float]: | ||
"""Return an iterator over the retry intervals. | ||
Yields: | ||
Next retry interval in seconds. | ||
""" | ||
while True: | ||
interval = self.next_interval() | ||
if interval is None: | ||
break | ||
yield interval | ||
|
||
|
||
class LinearBackoff(RetryStrategy): | ||
"""Provides methods for calculating the interval between retries.""" | ||
|
||
def __init__( | ||
self, | ||
interval: float = _DEFAULT_RETRY_INTERVAL, | ||
jitter: float = _DEFAULT_RETRY_JITTER, | ||
limit: int | None = None, | ||
) -> None: | ||
"""Create a `LinearBackoff` instance. | ||
Args: | ||
interval: time to wait for before the next retry, in seconds. | ||
jitter: a jitter to add to the retry interval. | ||
limit: max number of retries before giving up. `None` means no | ||
limit, and `0` means no retry. | ||
""" | ||
self._interval = interval | ||
self._jitter = jitter | ||
self._limit = limit | ||
|
||
self._count = 0 | ||
|
||
def next_interval(self) -> float | None: | ||
"""Return the time to wait before the next retry. | ||
Returns `None` if the retry limit has been reached, and no more retries | ||
are possible. | ||
Returns: | ||
Time until next retry when below retry limit, and None otherwise. | ||
""" | ||
if self._limit is not None and self._count >= self._limit: | ||
return None | ||
self._count += 1 | ||
return self._interval + random.uniform(0.0, self._jitter) | ||
|
||
|
||
class ExponentialBackoff(RetryStrategy): | ||
"""Provides methods for calculating the exponential interval between retries.""" | ||
|
||
DEFAULT_INTERVAL = _DEFAULT_RETRY_INTERVAL | ||
"""Default retry interval, in seconds.""" | ||
|
||
DEFAULT_MAX_INTERVAL = 60.0 | ||
"""Default maximum retry interval, in seconds.""" | ||
|
||
DEFAULT_MULTIPLIER = 2.0 | ||
"""Default multiplier for exponential increment.""" | ||
|
||
# pylint: disable=too-many-arguments | ||
def __init__( | ||
self, | ||
initial_interval: float = DEFAULT_INTERVAL, | ||
max_interval: float = DEFAULT_MAX_INTERVAL, | ||
multiplier: float = DEFAULT_MULTIPLIER, | ||
jitter: float = _DEFAULT_RETRY_JITTER, | ||
limit: int | None = None, | ||
) -> None: | ||
"""Create a `ExponentialBackoff` instance. | ||
Args: | ||
initial_interval: time to wait for before the first retry, in | ||
seconds. | ||
max_interval: maximum interval, in seconds. | ||
multiplier: exponential increment for interval. | ||
jitter: a jitter to add to the retry interval. | ||
limit: max number of retries before giving up. `None` means no | ||
limit, and `0` means no retry. | ||
""" | ||
self._initial = initial_interval | ||
self._max = max_interval | ||
self._multiplier = multiplier | ||
self._jitter = jitter | ||
self._limit = limit | ||
|
||
self._count = 0 | ||
|
||
def next_interval(self) -> float | None: | ||
"""Return the time to wait before the next retry. | ||
Returns `None` if the retry limit has been reached, and no more retries | ||
are possible. | ||
Returns: | ||
Time until next retry when below retry limit, and None otherwise. | ||
""" | ||
if self._limit is not None and self._count >= self._limit: | ||
return None | ||
self._count += 1 | ||
exp_backoff_interval = self._initial * self._multiplier ** (self._count - 1) | ||
return min(exp_backoff_interval + random.uniform(0.0, self._jitter), self._max) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# License: MIT | ||
# Copyright © 2023 Frequenz Energy-as-a-Service GmbH | ||
|
||
"""Tests for the frequenz.client.base package.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# License: MIT | ||
# Copyright © 2023 Frequenz Energy-as-a-Service GmbH | ||
|
||
"""Tests for the retry_strategy package.""" |
Oops, something went wrong.