Skip to content

Commit

Permalink
Issue #18/EP-4049 initial implementation of refreshing MultiBackendCo…
Browse files Browse the repository at this point in the history
…nnection
  • Loading branch information
soxofaan committed Nov 4, 2021
1 parent abec086 commit e97f5f9
Showing 1 changed file with 54 additions and 11 deletions.
65 changes: 54 additions & 11 deletions src/openeo_aggregator/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import collections
import contextlib
import functools
import logging
import re
import time
from typing import List, Dict, Any, Iterator, Callable, Tuple, Set, Union

import flask
Expand Down Expand Up @@ -32,6 +34,9 @@ class BackendConnection(Connection):
- authentication is locked down: only short term authentication is allowed (during lifetime of a flask request)
"""

# TODO: subclass from RestApiConnection to avoid inheriting feature set
# designed for single-user use case (e.g. caching, working local config files, ...)

def __init__(self, id: str, url: str, default_timeout: int = CONNECTION_TIMEOUT_DEFAULT):
# Temporarily unlock `_auth` for `super().__init__()`
self._auth_locked = False
Expand All @@ -58,7 +63,10 @@ def _set_auth(self, auth: OpenEoApiAuthBase):

def set_oidc_provider_map(self, pid_map: Dict[str, str]):
if len(self._oidc_provider_map) > 0 and self._oidc_provider_map != pid_map:
_log.warning(f"Changing OIDC provider mapping in connection {self.id} from {self._oidc_provider_map} to {pid_map}")
_log.warning(
f"Changing OIDC provider mapping in connection {self.id}"
f" from {self._oidc_provider_map} to {pid_map}"
)
_log.info(f"Setting OIDC provider mapping for connection {self.id}: {pid_map}")
self._oidc_provider_map = pid_map

Expand Down Expand Up @@ -109,35 +117,70 @@ def override(self, default_timeout: int = _UNSET, default_headers: dict = _UNSET
self.default_headers = orig_default_headers


_ConnectionsCache = collections.namedtuple("_ConnectionsCache", ["expiry", "connections"])


class MultiBackendConnection:
"""
Collection of multiple connections to different backends
"""

# TODO: move this caching ttl to config?
_CONNECTION_CACHING_TTL = 5 * 60

_TIMEOUT = 5

# TODO: keep track of (recent) backend failures, e.g. to automatically blacklist a backend
# TODO: synchronized backend connection caching/flushing across gunicorn workers, for better consistency?

def __init__(self, backends: Dict[str, str]):
if any(not re.match(r"^[a-z0-9]+$", bid) for bid in backends.keys()):
raise ValueError(
f"Backend ids should be alphanumeric only (no dots, dashes, ...) "
f"to avoid collision issues when used as prefix. Got: {list(backends.keys())}"
)
# TODO: backend_urls as dict does not have explicit order, while this is important.
self._backend_urls = backends
self._connections: List[BackendConnection] = []
for (bid, url) in self._backend_urls.items():
_log.info(f"Setting up backend {bid!r} Connection: {url!r}")
self._connections.append(BackendConnection(id=bid, url=url))
self._connections_cache = _ConnectionsCache(expiry=0, connections=[])
# TODO: API version management: just do single-version aggregation, or also handle version discovery?
self.api_version = self._get_api_version()
self._cache = TtlCache(default_ttl=CACHE_TTL_DEFAULT)

def _get_connections(self, skip_failure=False) -> Iterator[BackendConnection]:
"""Create new backend connections."""
for (bid, url) in self._backend_urls.items():
try:
_log.info(f"Create backend {bid!r} connection to {url!r}")
# TODO: also do a health check on the connection?
yield BackendConnection(id=bid, url=url)
except Exception as e:
if skip_failure:
_log.warning(f"Failed to create backend {bid!r} connection to {url!r}: {e!r}")
else:
raise

def get_connections(self) -> List[BackendConnection]:
"""Get backend connections (re-created automatically if cache ttl expired)"""
now = time.time()
if now > self._connections_cache.expiry:
_log.info(f"Connections cache miss: setting up new connections")
self._connections_cache = _ConnectionsCache(
expiry=now + self._CONNECTION_CACHING_TTL,
connections=list(self._get_connections(skip_failure=True))
)
_log.info(
f"Created {len(self._connections_cache.connections)} actual"
f" of {len(self._backend_urls)} configured connections"
)
return self._connections_cache.connections

def __iter__(self) -> Iterator[BackendConnection]:
return iter(self._connections)
return iter(self.get_connections())

def first(self) -> BackendConnection:
"""Get first backend in the list"""
# TODO: rename this to main_backend (if it makes sense to have a general main backend)?
return self._connections[0]
return self.get_connections()[0]

def get_connection(self, backend_id: str) -> BackendConnection:
for con in self:
Expand All @@ -153,7 +196,7 @@ def get_status(self) -> dict:
"root_url": c._root_url,
"orig_url": c._orig_url,
}
for c in self._connections
for c in self.get_connections()
}

def _get_api_version(self) -> ComparableVersion:
Expand All @@ -169,7 +212,7 @@ def map(self, callback: Callable[[BackendConnection], Any]) -> Iterator[Tuple[st
:param callback: function to apply to the connection
"""
for con in self._connections:
for con in self.get_connections():
res = callback(con)
# TODO: customizable exception handling: skip, warn, re-raise?
yield con.id, res
Expand All @@ -180,7 +223,7 @@ def get_oidc_providers_per_backend(self) -> Dict[str, List[OidcProvider]]:
def _get_oidc_providers_per_backend(self) -> Dict[str, List[OidcProvider]]:
# Collect provider info per backend
providers_per_backend: Dict[str, List[OidcProvider]] = {}
for con in self._connections:
for con in self.get_connections():
providers_per_backend[con.id] = []
for provider_data in con.get("/credentials/oidc", expected_status=200).json()["providers"]:
# Normalize issuer for sensible comparison operations.
Expand Down Expand Up @@ -213,7 +256,7 @@ def build_oidc_handling(self, configured_providers: List[OidcProvider]) -> List[
_log.info(f"Actual aggregator providers: {agg_providers}")

# Set up provider id mapping (aggregator pid to original backend pid) for the connections
for con in self._connections:
for con in self.get_connections():
backend_providers = providers_per_backend[con.id]
pid_map = {}
for agg_provider in agg_providers:
Expand Down

0 comments on commit e97f5f9

Please sign in to comment.