From 25d7470d40741ab728317555af31c38b508ffa61 Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Wed, 20 Oct 2021 16:12:34 +0200 Subject: [PATCH] Issue #18/EP-4049 initial implementation of refreshing MultiBackendConnection --- src/openeo_aggregator/connection.py | 65 ++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 11 deletions(-) diff --git a/src/openeo_aggregator/connection.py b/src/openeo_aggregator/connection.py index 36ae8736..928538d8 100644 --- a/src/openeo_aggregator/connection.py +++ b/src/openeo_aggregator/connection.py @@ -1,7 +1,9 @@ +import collections import contextlib import functools import logging import re +import time from typing import List, Dict, Any, Iterator, Callable, Tuple, Set, Union import flask @@ -32,6 +34,9 @@ class BackendConnection(Connection): - authentication is locked down: only short term authentication is allowed (during lifetime of a flask request) """ + # TODO: subclass from RestApiConnection to avoid inheriting feature set + # designed for single-user use case (e.g. caching, working local config files, ...) + def __init__(self, id: str, url: str, default_timeout: int = CONNECTION_TIMEOUT_DEFAULT): # Temporarily unlock `_auth` for `super().__init__()` self._auth_locked = False @@ -58,7 +63,10 @@ def _set_auth(self, auth: OpenEoApiAuthBase): def set_oidc_provider_map(self, pid_map: Dict[str, str]): if len(self._oidc_provider_map) > 0 and self._oidc_provider_map != pid_map: - _log.warning(f"Changing OIDC provider mapping in connection {self.id} from {self._oidc_provider_map} to {pid_map}") + _log.warning( + f"Changing OIDC provider mapping in connection {self.id}" + f" from {self._oidc_provider_map} to {pid_map}" + ) _log.info(f"Setting OIDC provider mapping for connection {self.id}: {pid_map}") self._oidc_provider_map = pid_map @@ -109,35 +117,70 @@ def override(self, default_timeout: int = _UNSET, default_headers: dict = _UNSET self.default_headers = orig_default_headers +_ConnectionsCache = collections.namedtuple("_ConnectionsCache", ["expiry", "connections"]) + + class MultiBackendConnection: """ Collection of multiple connections to different backends """ + # TODO: move this caching ttl to config? + _CONNECTION_CACHING_TTL = 5 * 60 + _TIMEOUT = 5 + # TODO: keep track of (recent) backend failures, e.g. to automatically blacklist a backend + # TODO: synchronized backend connection caching/flushing across gunicorn workers, for better consistency? + def __init__(self, backends: Dict[str, str]): if any(not re.match(r"^[a-z0-9]+$", bid) for bid in backends.keys()): raise ValueError( f"Backend ids should be alphanumeric only (no dots, dashes, ...) " f"to avoid collision issues when used as prefix. Got: {list(backends.keys())}" ) + # TODO: backend_urls as dict does not have explicit order, while this is important. self._backend_urls = backends - self._connections: List[BackendConnection] = [] - for (bid, url) in self._backend_urls.items(): - _log.info(f"Setting up backend {bid!r} Connection: {url!r}") - self._connections.append(BackendConnection(id=bid, url=url)) + self._connections_cache = _ConnectionsCache(expiry=0, connections=[]) # TODO: API version management: just do single-version aggregation, or also handle version discovery? self.api_version = self._get_api_version() self._cache = TtlCache(default_ttl=CACHE_TTL_DEFAULT) + def _get_connections(self, skip_failure=False) -> Iterator[BackendConnection]: + """Create new backend connections.""" + for (bid, url) in self._backend_urls.items(): + try: + _log.info(f"Create backend {bid!r} connection to {url!r}") + # TODO: also do a health check on the connection? + yield BackendConnection(id=bid, url=url) + except Exception as e: + if skip_failure: + _log.warning(f"Failed to create backend {bid!r} connection to {url!r}: {e!r}") + else: + raise + + def get_connections(self) -> List[BackendConnection]: + """Get backend connections (re-created automatically if cache ttl expired)""" + now = time.time() + if now > self._connections_cache.expiry: + _log.info(f"Connections cache miss: setting up new connections") + self._connections_cache = _ConnectionsCache( + expiry=now + self._CONNECTION_CACHING_TTL, + connections=list(self._get_connections(skip_failure=True)) + ) + _log.info( + f"Created {len(self._connections_cache.connections)} actual" + f" of {len(self._backend_urls)} configured connections" + ) + return self._connections_cache.connections + def __iter__(self) -> Iterator[BackendConnection]: - return iter(self._connections) + return iter(self.get_connections()) def first(self) -> BackendConnection: """Get first backend in the list""" # TODO: rename this to main_backend (if it makes sense to have a general main backend)? - return self._connections[0] + return self.get_connections()[0] def get_connection(self, backend_id: str) -> BackendConnection: for con in self: @@ -153,7 +196,7 @@ def get_status(self) -> dict: "root_url": c._root_url, "orig_url": c._orig_url, } - for c in self._connections + for c in self.get_connections() } def _get_api_version(self) -> ComparableVersion: @@ -169,7 +212,7 @@ def map(self, callback: Callable[[BackendConnection], Any]) -> Iterator[Tuple[st :param callback: function to apply to the connection """ - for con in self._connections: + for con in self.get_connections(): res = callback(con) # TODO: customizable exception handling: skip, warn, re-raise? yield con.id, res @@ -180,7 +223,7 @@ def get_oidc_providers_per_backend(self) -> Dict[str, List[OidcProvider]]: def _get_oidc_providers_per_backend(self) -> Dict[str, List[OidcProvider]]: # Collect provider info per backend providers_per_backend: Dict[str, List[OidcProvider]] = {} - for con in self._connections: + for con in self.get_connections(): providers_per_backend[con.id] = [] for provider_data in con.get("/credentials/oidc", expected_status=200).json()["providers"]: # Normalize issuer for sensible comparison operations. @@ -213,7 +256,7 @@ def build_oidc_handling(self, configured_providers: List[OidcProvider]) -> List[ _log.info(f"Actual aggregator providers: {agg_providers}") # Set up provider id mapping (aggregator pid to original backend pid) for the connections - for con in self._connections: + for con in self.get_connections(): backend_providers = providers_per_backend[con.id] pid_map = {} for agg_provider in agg_providers: