Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement asynchronous AuthorizedSession class #1580

77 changes: 61 additions & 16 deletions google/auth/_exponential_backoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import random
import time

Expand All @@ -38,9 +39,8 @@
"""


class ExponentialBackoff:
"""An exponential backoff iterator. This can be used in a for loop to
perform requests with exponential backoff.
class _BaseExponentialBackoff:
"""An exponential backoff iterator base class.

Args:
total_attempts Optional[int]:
Expand Down Expand Up @@ -84,9 +84,40 @@ def __init__(
self._multiplier = multiplier
self._backoff_count = 0

def __iter__(self):
@property
def total_attempts(self):
"""The total amount of backoff attempts that will be made."""
return self._total_attempts

@property
def backoff_count(self):
"""The current amount of backoff attempts that have been made."""
return self._backoff_count

def _reset(self):
self._backoff_count = 0
self._current_wait_in_seconds = self._initial_wait_seconds

def _calculate_jitter(self):
jitter_variance = self._current_wait_in_seconds * self._randomization_factor
jitter = random.uniform(
self._current_wait_in_seconds - jitter_variance,
self._current_wait_in_seconds + jitter_variance,
)

return jitter


class ExponentialBackoff(_BaseExponentialBackoff):
"""An exponential backoff iterator. This can be used in a for loop to
perform requests with exponential backoff.
"""

def __init__(self, *args, **kwargs):
super(ExponentialBackoff, self).__init__(*args, **kwargs)

def __iter__(self):
self._reset()
return self

def __next__(self):
Expand All @@ -97,23 +128,37 @@ def __next__(self):
if self._backoff_count <= 1:
return self._backoff_count

jitter_variance = self._current_wait_in_seconds * self._randomization_factor
jitter = random.uniform(
self._current_wait_in_seconds - jitter_variance,
self._current_wait_in_seconds + jitter_variance,
)
jitter = self._calculate_jitter()

time.sleep(jitter)

self._current_wait_in_seconds *= self._multiplier
return self._backoff_count

@property
def total_attempts(self):
"""The total amount of backoff attempts that will be made."""
return self._total_attempts

@property
def backoff_count(self):
"""The current amount of backoff attempts that have been made."""
class AsyncExponentialBackoff(_BaseExponentialBackoff):
"""An async exponential backoff iterator. This can be used in a for loop to
perform async requests with exponential backoff.
"""

def __init__(self, *args, **kwargs):
super(AsyncExponentialBackoff, self).__init__(*args, **kwargs)

def __aiter__(self):
self._reset()
return self

async def __anext__(self):
if self._backoff_count >= self._total_attempts:
raise StopAsyncIteration
self._backoff_count += 1

if self._backoff_count <= 1:
return self._backoff_count

jitter = self._calculate_jitter()

await asyncio.sleep(jitter)

self._current_wait_in_seconds *= self._multiplier
return self._backoff_count
29 changes: 23 additions & 6 deletions google/auth/aio/transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,24 @@
"""

import abc
from typing import AsyncGenerator, Dict, Mapping, Optional
from typing import AsyncGenerator, Mapping, Optional

import google.auth.transport


_DEFAULT_TIMEOUT_SECONDS = 180

DEFAULT_RETRYABLE_STATUS_CODES = google.auth.transport.DEFAULT_RETRYABLE_STATUS_CODES
"""Sequence[int]: HTTP status codes indicating a request can be retried.
"""

DEFAULT_REFRESH_STATUS_CODES = google.auth.transport.DEFAULT_REFRESH_STATUS_CODES
"""Sequence[int]: Which HTTP status code indicate that credentials should be
refreshed.
"""

DEFAULT_MAX_REFRESH_ATTEMPTS = 3
"""int: How many times to refresh the credentials and retry a request."""


class Response(metaclass=abc.ABCMeta):
Expand All @@ -35,7 +52,7 @@ class Response(metaclass=abc.ABCMeta):
@abc.abstractmethod
def status_code(self) -> int:
"""
The HTTP response status code..
The HTTP response status code.

Returns:
int: The HTTP response status code.
Expand All @@ -45,11 +62,11 @@ def status_code(self) -> int:

@property
@abc.abstractmethod
def headers(self) -> Dict[str, str]:
def headers(self) -> Mapping[str, str]:
"""The HTTP response headers.

Returns:
Dict[str, str]: The HTTP response headers.
Mapping[str, str]: The HTTP response headers.
"""
raise NotImplementedError("headers must be implemented.")

Expand Down Expand Up @@ -95,7 +112,7 @@ async def __call__(
self,
url: str,
method: str,
body: bytes,
body: Optional[bytes],
headers: Optional[Mapping[str, str]],
timeout: float,
**kwargs
Expand All @@ -106,7 +123,7 @@ async def __call__(
url (str): The URI to be requested.
method (str): The HTTP method to use for the request. Defaults
to 'GET'.
body (bytes): The payload / body in HTTP request.
body (Optional[bytes]): The payload / body in HTTP request.
headers (Mapping[str, str]): Request headers.
timeout (float): The number of seconds to wait for a
response from the server. If not specified or if None, the
Expand Down
68 changes: 9 additions & 59 deletions google/auth/aio/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
"""

import asyncio
from contextlib import asynccontextmanager
import time
from typing import AsyncGenerator, Dict, Mapping, Optional
from typing import AsyncGenerator, Mapping, Optional

try:
import aiohttp
import aiohttp # type: ignore
except ImportError as caught_exc: # pragma: NO COVER
raise ImportError(
"The aiohttp library is not installed from please install the aiohttp package to use the aiohttp transport."
Expand All @@ -30,54 +28,6 @@
from google.auth import _helpers
from google.auth import exceptions
from google.auth.aio import transport
from google.auth.exceptions import TimeoutError


_DEFAULT_TIMEOUT_SECONDS = 180


@asynccontextmanager
async def timeout_guard(timeout):
"""
timeout_guard is an asynchronous context manager to apply a timeout to an asynchronous block of code.

Args:
timeout (float): The time in seconds before the context manager times out.

Raises:
google.auth.exceptions.TimeoutError: If the code within the context exceeds the provided timeout.

Usage:
async with timeout_guard(10) as with_timeout:
await with_timeout(async_function())
"""
start = time.monotonic()
total_timeout = timeout

def _remaining_time():
elapsed = time.monotonic() - start
remaining = total_timeout - elapsed
if remaining <= 0:
raise TimeoutError(
f"Context manager exceeded the configured timeout of {total_timeout}s."
)
return remaining

async def with_timeout(coro):
try:
remaining = _remaining_time()
response = await asyncio.wait_for(coro, remaining)
return response
except (asyncio.TimeoutError, TimeoutError) as e:
raise TimeoutError(
f"The operation {coro} exceeded the configured timeout of {total_timeout}s."
) from e

try:
yield with_timeout

finally:
_remaining_time()


class Response(transport.Response):
Expand All @@ -89,7 +39,7 @@ class Response(transport.Response):

Attributes:
status_code (int): The HTTP status code of the response.
headers (Dict[str, str]): A case-insensitive multidict proxy wiht HTTP headers of response.
headers (Mapping[str, str]): The HTTP headers of the response.
"""

def __init__(self, response: aiohttp.ClientResponse):
Expand All @@ -102,7 +52,7 @@ def status_code(self) -> int:

@property
@_helpers.copy_docstring(transport.Response)
def headers(self) -> Dict[str, str]:
def headers(self) -> Mapping[str, str]:
return {key: value for key, value in self._response.headers.items()}

@_helpers.copy_docstring(transport.Response)
Expand Down Expand Up @@ -158,7 +108,7 @@ async def __call__(
method: str = "GET",
body: Optional[bytes] = None,
headers: Optional[Mapping[str, str]] = None,
timeout: float = _DEFAULT_TIMEOUT_SECONDS,
timeout: float = transport._DEFAULT_TIMEOUT_SECONDS,
**kwargs,
) -> transport.Response:
"""
Expand Down Expand Up @@ -199,14 +149,14 @@ async def __call__(
return Response(response)

except aiohttp.ClientError as caught_exc:
new_exc = exceptions.TransportError(f"Failed to send request to {url}.")
raise new_exc from caught_exc
client_exc = exceptions.TransportError(f"Failed to send request to {url}.")
raise client_exc from caught_exc

except asyncio.TimeoutError as caught_exc:
new_exc = exceptions.TimeoutError(
timeout_exc = exceptions.TimeoutError(
f"Request timed out after {timeout} seconds."
)
raise new_exc from caught_exc
raise timeout_exc from caught_exc

async def close(self) -> None:
"""
Expand Down
Loading