Skip to content

Commit

Permalink
Issue #18/EP-4049 rework OIDC metadata building, responsibility and c…
Browse files Browse the repository at this point in the history
…aching
  • Loading branch information
soxofaan committed Oct 27, 2021
1 parent 26d5a41 commit 6e512c3
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 146 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
install_requires=[
"requests",
"openeo>=0.9.1a2.*",
"openeo_driver>=0.14.7a1.*",
"openeo_driver>=0.14.8a1.*",
"flask~=2.0",
"gunicorn~=20.0",
],
Expand Down
2 changes: 1 addition & 1 deletion src/openeo_aggregator/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def create_app(config: Any = None, auto_logging_setup=True) -> flask.Flask:
_log.info(f"Using config: {config}")

_log.info(f"Creating MultiBackendConnection with {config.aggregator_backends}")
backends = MultiBackendConnection(backends=config.aggregator_backends)
backends = MultiBackendConnection.from_config(config)

_log.info("Creating AggregatorBackendImplementation")
backend_implementation = AggregatorBackendImplementation(backends=backends, config=config)
Expand Down
11 changes: 4 additions & 7 deletions src/openeo_aggregator/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ class AggregatorBackendImplementation(OpenEoBackendImplementation):
enable_basic_auth = False

# Simplify mocking time for unit tests.
_clock = time.time
_clock = time.time # TODO: centralized helper for this test pattern

def __init__(self, backends: MultiBackendConnection, config: AggregatorConfig):
self._backends = backends
Expand All @@ -564,15 +564,12 @@ def __init__(self, backends: MultiBackendConnection, config: AggregatorConfig):
self._cache = TtlCache(default_ttl=CACHE_TTL_DEFAULT)
self._backends.on_connections_change.add(self._cache.flush_all)
self._auth_entitlement_check: Union[bool, dict] = config.auth_entitlement_check
self._configured_oidc_providers: List[OidcProvider] = config.configured_oidc_providers

def oidc_providers(self) -> List[OidcProvider]:
# TODO: openeo-python-driver is not ready for changing oidc_providers in HttpAuthHandler
key = "oidc_providers"
if key not in self._cache:
providers = self._backends.build_oidc_handling(configured_providers=self._configured_oidc_providers)
self._cache.set(key, value=providers)
return self._cache[key]
return self._cache.get_or_call(
key="oidc_providers", callback=self._backends.get_oidc_providers
)

def file_formats(self) -> dict:
return self._cache.get_or_call(key="file_formats", callback=self._file_formats)
Expand Down
117 changes: 65 additions & 52 deletions src/openeo_aggregator/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from openeo import Connection
from openeo.capabilities import ComparableVersion
from openeo.rest.auth.auth import BearerAuth, OpenEoApiAuthBase
from openeo_aggregator.config import CACHE_TTL_DEFAULT, CONNECTION_TIMEOUT_DEFAULT, STREAM_CHUNK_SIZE_DEFAULT
from openeo_aggregator.config import CACHE_TTL_DEFAULT, CONNECTION_TIMEOUT_DEFAULT, STREAM_CHUNK_SIZE_DEFAULT, \
AggregatorConfig
from openeo_aggregator.utils import TtlCache, _UNSET, EventHandler
from openeo_driver.backend import OidcProvider
from openeo_driver.errors import OpenEOApiException, AuthenticationRequiredException, \
Expand All @@ -23,7 +24,13 @@


class LockedAuthException(InternalException):
"""Implementation tries to do permanent authentication on connection"""
def __init__(self):
super().__init__(message="Setting auth while locked.")


class InvalidatedConnection(InternalException):
def __init__(self):
super().__init__(message="Usage of invalidated connection")


class BackendConnection(Connection):
Expand All @@ -37,7 +44,13 @@ class BackendConnection(Connection):
# TODO: subclass from RestApiConnection to avoid inheriting feature set
# designed for single-user use case (e.g. caching, working local config files, ...)

def __init__(self, id: str, url: str, default_timeout: int = CONNECTION_TIMEOUT_DEFAULT):
def __init__(
self,
id: str,
url: str,
configured_oidc_providers: List[OidcProvider],
default_timeout: int = CONNECTION_TIMEOUT_DEFAULT
):
# Temporarily unlock `_auth` for `super().__init__()`
self._auth_locked = False
super(BackendConnection, self).__init__(url, default_timeout=default_timeout)
Expand All @@ -49,26 +62,35 @@ def __init__(self, id: str, url: str, default_timeout: int = CONNECTION_TIMEOUT_
v=openeo_aggregator.about.__version__,
)
# Mapping of aggregator provider id to backend's provider id
self._oidc_provider_map: Dict[str, str] = {}
self._oidc_provider_map: Dict[str, str] = self._build_oidc_provider_map(configured_oidc_providers)

def _get_auth(self) -> Union[None, OpenEoApiAuthBase]:
return None if self._auth_locked else self._auth

def _set_auth(self, auth: OpenEoApiAuthBase):
if self._auth_locked:
raise LockedAuthException("Setting auth while locked.")
raise LockedAuthException
self._auth = auth

auth = property(_get_auth, _set_auth)

def set_oidc_provider_map(self, pid_map: Dict[str, str]):
if len(self._oidc_provider_map) > 0 and self._oidc_provider_map != pid_map:
_log.warning(
f"Changing OIDC provider mapping in connection {self.id}"
f" from {self._oidc_provider_map} to {pid_map}"
)
_log.info(f"Setting OIDC provider mapping for connection {self.id}: {pid_map}")
self._oidc_provider_map = pid_map
def _build_oidc_provider_map(self, configured_providers: List[OidcProvider]) -> Dict[str, str]:
"""Construct mapping from aggregator OIDC provider id to backend OIDC provider id"""
pid_map = {}
if configured_providers:
backend_providers = [
OidcProvider.from_dict(p)
for p in self.get("/credentials/oidc", expected_status=200).json()["providers"]
]
for agg_provider in configured_providers:
targets = [bp.id for bp in backend_providers if bp.get_issuer() == agg_provider.get_issuer()]
if targets:
pid_map[agg_provider.id] = targets[0]
return pid_map

@property
def oidc_provider_map(self) -> Dict[str, str]:
return self._oidc_provider_map

def _get_bearer(self, request: flask.Request) -> str:
"""Extract authorization header from request and (optionally) transform for given backend """
Expand Down Expand Up @@ -116,6 +138,14 @@ def override(self, default_timeout: int = _UNSET, default_headers: dict = _UNSET
self.default_timeout = orig_default_timeout
self.default_headers = orig_default_headers

def invalidate(self):
"""Destroy connection to avoid accidental usage."""

def request(*args, **kwargs):
raise InvalidatedConnection

self.request = request


_ConnectionsCache = collections.namedtuple("_ConnectionsCache", ["expiry", "connections"])

Expand All @@ -134,31 +164,41 @@ class MultiBackendConnection:
_TIMEOUT = 5

# Simplify mocking time for unit tests.
_clock = time.time
_clock = time.time # TODO: centralized helper for this test pattern

def __init__(self, backends: Dict[str, str]):
def __init__(self, backends: Dict[str, str], configured_oidc_providers: List[OidcProvider]):
if any(not re.match(r"^[a-z0-9]+$", bid) for bid in backends.keys()):
raise ValueError(
f"Backend ids should be alphanumeric only (no dots, dashes, ...) "
f"to avoid collision issues when used as prefix. Got: {list(backends.keys())}"
)
# TODO: backend_urls as dict does not have explicit order, while this is important.
self._backend_urls = backends
self._configured_oidc_providers = configured_oidc_providers

self._connections_cache = _ConnectionsCache(expiry=0, connections=[])
# General (metadata/status) caching
self._cache = TtlCache(default_ttl=CACHE_TTL_DEFAULT)

# Caching of connection objects
self._connections_cache = _ConnectionsCache(expiry=0, connections=[])
self.on_connections_change = EventHandler("connections_change")
self.on_connections_change.add(self._cache.flush_all)

@staticmethod
def from_config(config: AggregatorConfig) -> 'MultiBackendConnection':
return MultiBackendConnection(
backends=config.aggregator_backends,
configured_oidc_providers=config.configured_oidc_providers
)

def _get_connections(self, skip_failures=False) -> Iterator[BackendConnection]:
"""Create new backend connections."""
for (bid, url) in self._backend_urls.items():
try:
_log.info(f"Create backend {bid!r} connection to {url!r}")
# TODO: Creating connection usually involves version discovery and request of capability doc.
# Additional health check necessary?
yield BackendConnection(id=bid, url=url)
yield BackendConnection(id=bid, url=url, configured_oidc_providers=self._configured_oidc_providers)
except Exception as e:
_log.warning(f"Failed to create backend {bid!r} connection to {url!r}: {e!r}")
if not skip_failures:
Expand All @@ -170,6 +210,8 @@ def get_connections(self) -> List[BackendConnection]:
if now > self._connections_cache.expiry:
_log.info(f"Connections cache expired ({now:.2f}>{self._connections_cache.expiry:.2f})")
orig_bids = [c.id for c in self._connections_cache.connections]
for con in self._connections_cache.connections:
con.invalidate()
self._connections_cache = _ConnectionsCache(
expiry=now + self._CONNECTIONS_CACHING_TTL,
connections=list(self._get_connections(skip_failures=True))
Expand Down Expand Up @@ -233,54 +275,25 @@ def map(self, callback: Callable[[BackendConnection], Any]) -> Iterator[Tuple[st
# TODO: customizable exception handling: skip, warn, re-raise?
yield con.id, res

def get_oidc_providers_per_backend(self) -> Dict[str, List[OidcProvider]]:
return self._cache.get_or_call(key="oidc_providers_per_backend", callback=self._get_oidc_providers_per_backend)

def _get_oidc_providers_per_backend(self) -> Dict[str, List[OidcProvider]]:
# Collect provider info per backend
providers_per_backend: Dict[str, List[OidcProvider]] = {}
for con in self.get_connections():
providers_per_backend[con.id] = []
for provider_data in con.get("/credentials/oidc", expected_status=200).json()["providers"]:
# Normalize issuer for sensible comparison operations.
provider_data["issuer"] = provider_data["issuer"].rstrip("/").lower()
providers_per_backend[con.id].append(OidcProvider.from_dict(provider_data))
return providers_per_backend

def build_oidc_handling(self, configured_providers: List[OidcProvider]) -> List[OidcProvider]:
def get_oidc_providers(self) -> List[OidcProvider]:
"""
Determine OIDC providers to use in aggregator (based on OIDC issuers supported by all backends)
and set up provider id mapping in the backend connections
:param configured_providers: OIDC providers dedicated/configured for the aggregator
:return: list of actual OIDC providers to use (configured for aggregator and supported by all backends)
"""
providers_per_backend = self.get_oidc_providers_per_backend()

# Find OIDC issuers supported by each backend (intersection of issuer sets).
issuers_per_backend = [
set(p.issuer for p in providers)
for providers in providers_per_backend.values()
]
intersection: Set[str] = functools.reduce((lambda x, y: x.intersection(y)), issuers_per_backend)
# Get intersection of aggregator OIDC provider ids
agg_pids_per_backend = [set(c.oidc_provider_map.keys()) for c in self.get_connections()]
intersection: Set[str] = functools.reduce((lambda x, y: x.intersection(y)), agg_pids_per_backend)
_log.info(f"OIDC provider intersection: {intersection}")
if len(intersection) == 0:
_log.warning(f"Emtpy OIDC provider intersection. Issuers per backend: {issuers_per_backend}")
_log.warning(f"Emtpy OIDC provider intersection. Issuers per backend: {agg_pids_per_backend}")

# Take configured providers for common issuers.
agg_providers = [p for p in configured_providers if p.issuer.rstrip("/").lower() in intersection]
agg_providers = [p for p in self._configured_oidc_providers if p.id in intersection]
_log.info(f"Actual aggregator providers: {agg_providers}")

# Set up provider id mapping (aggregator pid to original backend pid) for the connections
for con in self.get_connections():
backend_providers = providers_per_backend[con.id]
pid_map = {}
for agg_provider in agg_providers:
agg_issuer = agg_provider.issuer.rstrip("/").lower()
orig_pid = next(bp.id for bp in backend_providers if bp.issuer == agg_issuer)
pid_map[agg_provider.id] = orig_pid
con.set_oidc_provider_map(pid_map)

return agg_providers


Expand Down
2 changes: 1 addition & 1 deletion src/openeo_aggregator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TtlCache:
def __init__(self, default_ttl: int = 60, clock: Callable[[], float] = time.time):
self._cache = {}
self.default_ttl = default_ttl
self._clock = clock
self._clock = clock # TODO: centralized helper for this test pattern

def set(self, key, value, ttl=None):
"""Add item to cache"""
Expand Down
29 changes: 16 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import flask
import pytest

Expand Down Expand Up @@ -30,26 +32,22 @@ def backend2(requests_mock):


@pytest.fixture
def multi_backend_connection(backend1, backend2) -> MultiBackendConnection:
return MultiBackendConnection({
"b1": backend1,
"b2": backend2,
})
def configured_oidc_providers() -> List[OidcProvider]:
return [
OidcProvider(id="egi", issuer="https://egi.test", title="EGI"),
OidcProvider(id="x-agg", issuer="https://x.test", title="X (agg)"),
OidcProvider(id="y-agg", issuer="https://y.test", title="Y (agg)"),
OidcProvider(id="z-agg", issuer="https://z.test", title="Z (agg)"),
]


@pytest.fixture
def base_config() -> AggregatorConfig:
def base_config(configured_oidc_providers) -> AggregatorConfig:
"""Base config for tests (without any configured backends)."""
conf = AggregatorConfig()
# conf.flask_error_handling = False # Temporary disable flask error handlers to simplify debugging (better stack traces).

conf.configured_oidc_providers = [
OidcProvider(id="egi", issuer="https://egi.test", title="EGI"),
OidcProvider(id="x-agg", issuer="https://x.test", title="X (agg)"),
OidcProvider(id="y-agg", issuer="https://y.test", title="Y (agg)"),
OidcProvider(id="z-agg", issuer="https://z.test", title="Z (agg)"),
]

conf.configured_oidc_providers = configured_oidc_providers
# Disable OIDC/EGI entitlement check by default.
conf.auth_entitlement_check = False
return conf
Expand All @@ -66,6 +64,11 @@ def config(base_config, backend1, backend2) -> AggregatorConfig:
return conf


@pytest.fixture
def multi_backend_connection(config) -> MultiBackendConnection:
return MultiBackendConnection.from_config(config)


def get_flask_app(config: AggregatorConfig) -> flask.Flask:
app = create_app(config=config, auto_logging_setup=False)
app.config['TESTING'] = True
Expand Down
8 changes: 7 additions & 1 deletion tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import itertools
import time

import pytest

from openeo_aggregator.backend import AggregatorCollectionCatalog, AggregatorProcessing, \
AggregatorBackendImplementation, _InternalCollectionMetadata, JobIdMapping
from openeo_aggregator.connection import MultiBackendConnection
from openeo_driver.errors import OpenEOApiException, CollectionNotFoundException, JobNotFoundException
from openeo_driver.users.oidc import OidcProvider

Expand Down Expand Up @@ -40,8 +44,10 @@ def test_oidc_providers_caching(self, multi_backend_connection, config, backend1
providers = implementation.oidc_providers()
assert providers == [OidcProvider(id="y-agg", issuer="https://y.test", title="Y (agg)")]
assert (m1.call_count, m2.call_count) == (1, 1)

MultiBackendConnection._clock = itertools.count(time.time() + 1000).__next__
implementation._cache.flush_all()
implementation._backends._cache.flush_all()

providers = implementation.oidc_providers()
assert providers == [OidcProvider(id="y-agg", issuer="https://y.test", title="Y (agg)")]
assert (m1.call_count, m2.call_count) == (2, 2)
Expand Down
Loading

0 comments on commit 6e512c3

Please sign in to comment.