Skip to content

Commit

Permalink
Issue #7 EP-4046 WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Oct 7, 2021
1 parent 74dd184 commit 68f7fbc
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 37 deletions.
6 changes: 4 additions & 2 deletions src/openeo_aggregator/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,12 @@ def __init__(self, backends: MultiBackendConnection, config: AggregatorConfig):
)
self._cache = TtlCache(default_ttl=CACHE_TTL_DEFAULT)
self._auth_entitlement_check = config.auth_entitlement_check
self._oidc_providers: List[OidcProvider] = config.oidc_providers

def oidc_providers(self) -> List[OidcProvider]:
return self._backends.get_oidc_providers()
# Get predefined providers for intersection of issuers supported by back-end
# TODO caching
return self._backends.build_oidc_handling(configured_providers=self._oidc_providers)

def file_formats(self) -> dict:
return self._cache.get_or_call(key="file_formats", callback=self._file_formats)
Expand Down Expand Up @@ -531,4 +534,3 @@ def user_access_validation(self, user: User, request: flask.Request) -> User:
user.info["roles"] = ["EarlyAdopter"]

return user

40 changes: 39 additions & 1 deletion src/openeo_aggregator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import os
import urllib.parse
from pathlib import Path
from typing import Any
from typing import Any, List
from typing import Union

from openeo_driver.users.oidc import OidcProvider
from openeo_driver.utils import dict_item

_log = logging.getLogger(__name__)
Expand All @@ -31,6 +32,8 @@ class AggregatorConfig(dict):
flask_error_handling = dict_item(default=True)
streaming_chunk_size = dict_item(default=STREAM_CHUNK_SIZE_DEFAULT)

# TODO: add validation/normalization to make sure we have a real list of OidcProvider objects?
oidc_providers: List[OidcProvider] = dict_item(default=[])
auth_entitlement_check = dict_item(default=True)

@classmethod
Expand All @@ -43,6 +46,15 @@ def from_json_file(cls, path: Union[str, Path]):
return cls(json.load(f))


_DEFAULT_OIDC_CLIENT_EGI = {
"id": "openeo-platform-default-client",
"grant_types": [
"authorization_code+pkce",
"urn:ietf:params:oauth:grant-type:device_code+pkce",
"refresh_token",
]
}

DEFAULT_CONFIG = AggregatorConfig(
aggregator_backends={
"vito": "https://openeo.vito.be/openeo/1.0/",
Expand All @@ -51,6 +63,32 @@ def from_json_file(cls, path: Union[str, Path]):
# "eodcdev": "https://openeo-dev.eodc.eu/v1.0/",
},
auth_entitlement_check=True,
oidc_default_clients=[
OidcProvider(
id="egi",
issuer="https://aai.egi.eu/oidc/",
scopes=[
"openid", "email",
"eduperson_entitlement",
"eduperson_scoped_affiliation",
],
title="EGI Check-in",
default_client=_DEFAULT_OIDC_CLIENT_EGI, # TODO: remove this legacy experimental field
default_clients=[_DEFAULT_OIDC_CLIENT_EGI],
),
OidcProvider(
id="egi-dev",
issuer="https://aai-dev.egi.eu/oidc/",
scopes=[
"openid", "email",
"eduperson_entitlement",
"eduperson_scoped_affiliation",
],
title="EGI Check-in (dev)",
default_client=_DEFAULT_OIDC_CLIENT_EGI, # TODO: remove this legacy experimental field
default_clients=[_DEFAULT_OIDC_CLIENT_EGI],
),
],
)


Expand Down
34 changes: 17 additions & 17 deletions src/openeo_aggregator/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,24 +157,24 @@ def map(self, callback: Callable[[BackendConnection], Any]) -> Iterator[Tuple[st
# TODO: customizable exception handling: skip, warn, re-raise?
yield con.id, res

def get_oidc_providers(self) -> List[OidcProvider]:
return self._cache.get_or_call(key="oidc_data", callback=self._build_oidc_data)
def get_oidc_providers_per_backend(self) -> Dict[str, List[OidcProvider]]:
return self._cache.get_or_call(key="oidc_providers_per_backend", callback=self._get_oidc_providers_per_backend)

def _build_oidc_data(self) -> List[OidcProvider]:
"""
Build list of common OIDC providers to advertise as aggregator OIDC provider
and set up the provider mapping in the connections
"""
def _get_oidc_providers_per_backend(self) -> Dict[str, List[OidcProvider]]:
# Collect provider info per backend
providers_per_backend: Dict[str, List[OidcProvider]] = {}
for con in self._connections:
providers_per_backend[con.id] = []
for provider_data in con.get("/credentials/oidc", expected_status=200).json()["providers"]:
# Normalize issuer a bit to have useful intersection later.
provider_data["issuer"] = provider_data["issuer"].rstrip("/")
# Normalize issuer for sensible comparison operations.
provider_data["issuer"] = provider_data["issuer"].rstrip("/").lower()
providers_per_backend[con.id].append(OidcProvider.from_dict(provider_data))
return providers_per_backend

def build_oidc_handling(self, configured_providers: List[OidcProvider]) -> List[OidcProvider]:
providers_per_backend = self.get_oidc_providers_per_backend()

# Calculate intersection (based on issuer URL)
# Find issuers supported by each backend.
issuers_per_backend = [
set(p.issuer for p in providers)
for providers in providers_per_backend.values()
Expand All @@ -183,18 +183,18 @@ def _build_oidc_data(self) -> List[OidcProvider]:
_log.info(f"OIDC provider intersection: {intersection}")
if len(intersection) == 0:
_log.warning(f"Emtpy OIDC provider intersection. Issuers per backend: {issuers_per_backend}")
# Use provider order as used in first backend
agg_providers = [
p for p in providers_per_backend[self.first().id]
if p.issuer in intersection
]

# Build and register mapping of aggregator provider id to backend provider id.
# Take configured providers for common issuers.
agg_providers = [p for p in configured_providers if p.issuer.rstrip("/").lower() in intersection]

# Set up provider id mapping (aggregator pid to original backend pid) for the connections
for con in self._connections:
backend_providers = providers_per_backend[con.id]
pid_map = {}
for agg_provider in agg_providers:
pid_map[agg_provider.id] = next(bp.id for bp in backend_providers if bp.issuer == agg_provider.issuer)
agg_issuer = agg_provider.issuer.rstrip("/").lower()
orig_pid = next(bp.id for bp in backend_providers if bp.issuer == agg_issuer)
pid_map[agg_provider.id] = orig_pid
con.set_oidc_provider_map(pid_map)

return agg_providers
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from openeo_aggregator.backend import MultiBackendConnection
from openeo_aggregator.config import AggregatorConfig
from openeo_driver.testing import ApiTester
from openeo_driver.users.oidc import OidcProvider


@pytest.fixture
Expand Down Expand Up @@ -45,6 +46,12 @@ def config(backend1, backend2) -> AggregatorConfig:
}
# conf.flask_error_handling = False # Temporary disable flask error handlers to simplify debugging (better stack traces).

conf.oidc_providers = [
OidcProvider(id="x-agg", issuer="https://x.test", title="X (agg)"),
OidcProvider(id="y-agg", issuer="https://y.test", title="Y (agg)"),
OidcProvider(id="z-agg", issuer="https://z.test", title="Z (agg)"),
]

# Disable OIDC/EGI entitlement check by default.
conf.auth_entitlement_check = False
return conf
Expand Down
2 changes: 1 addition & 1 deletion tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_oidc_providers(self, multi_backend_connection, config, backend1, backen
providers = implementation.oidc_providers()
assert len(providers) == 1
provider = providers[0]
expected = {"id": "y", "issuer": "https://y.test", "title": "YY", "scopes": ["openid"]}
expected = {"id": "y-agg", "issuer": "https://y.test", "title": "Y (agg)", "scopes": ["openid"]}
assert provider.prepare_for_json() == expected

def test_file_formats_simple(self, multi_backend_connection, config, backend1, backend2, requests_mock):
Expand Down
64 changes: 48 additions & 16 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,18 +160,28 @@ def test_map(self, multi_backend_connection, backend1, backend2, requests_mock):
def test_api_version(self, multi_backend_connection):
assert multi_backend_connection.api_version == ComparableVersion("1.0.0")

def test_get_oidc_providers(self, multi_backend_connection, backend1, backend2):
providers = multi_backend_connection.get_oidc_providers()
@pytest.mark.parametrize(["pid", "issuer", "title"], [
("egi", "https://egi.test", "EGI"),
("agg-egi", "https://EGI.test/", "Agg EGI"),
])
def test_build_oidc_handling_basic(self, multi_backend_connection, backend1, backend2, pid, issuer, title):
providers = multi_backend_connection.build_oidc_handling(configured_providers=[
OidcProvider(id=pid, issuer=issuer, title=title),
OidcProvider(id="egi-dev", issuer="https://egi-dev.test", title="EGI dev"),
])
assert providers == [
OidcProvider(id="egi", issuer="https://egi.test", title="EGI", scopes=["openid"]),
OidcProvider(id=pid, issuer=issuer, title=title, scopes=["openid"]),
]

for con in multi_backend_connection:
assert con._oidc_provider_map == {pid: "egi"}

@pytest.mark.parametrize(["issuer_y1", "issuer_y2"], [
("https://y.test", "https://y.test"),
("https://y.test", "https://y.test/"),
("https://y.test/", "https://y.test/"),
])
def test_oidc_providers_issuer_intersection(
def test_build_oidc_handling_intersection(
self, multi_backend_connection, requests_mock, backend1, backend2, issuer_y1, issuer_y2
):
requests_mock.get(backend1 + "/credentials/oidc", json={"providers": [
Expand All @@ -183,32 +193,51 @@ def test_oidc_providers_issuer_intersection(
{"id": "z2", "issuer": "https://z.test", "title": "ZZZ2"},
]})

providers = multi_backend_connection.get_oidc_providers()
providers = multi_backend_connection.build_oidc_handling(configured_providers=[
OidcProvider("xa", "https://x.test", "A-X"),
OidcProvider("ya", "https://y.test", "A-Y"),
OidcProvider("za", "https://z.test", "A-Z"),
])
assert providers == [
OidcProvider(id="y1", issuer="https://y.test", title="YY1", scopes=["openid"]),
OidcProvider(id="ya", issuer="https://y.test", title="A-Y", scopes=["openid"]),
]

assert [con._oidc_provider_map for con in multi_backend_connection] == [
{"ya": "y1"},
{"ya": "y2"},
]

def test_oidc_providers_issuer_intersection_order(
def test_build_oidc_handling_order(
self, multi_backend_connection, requests_mock, backend1, backend2
):
requests_mock.get(backend1 + "/credentials/oidc", json={"providers": [
{"id": "d1", "issuer": "https://d.test", "title": "D1"},
{"id": "b1", "issuer": "https://b.test", "title": "B1"},
{"id": "c1", "issuer": "https://c.test", "title": "C1"},
{"id": "c1", "issuer": "https://c.test/", "title": "C1"},
{"id": "a1", "issuer": "https://a.test", "title": "A1"},
{"id": "e1", "issuer": "https://e.test", "title": "E1"},
{"id": "e1", "issuer": "https://e.test/", "title": "E1"},
]})
requests_mock.get(backend2 + "/credentials/oidc", json={"providers": [
{"id": "e2", "issuer": "https://e.test", "title": "E2"},
{"id": "b2", "issuer": "https://b.test", "title": "B2"},
{"id": "b2", "issuer": "https://b.test/", "title": "B2"},
{"id": "c2", "issuer": "https://c.test", "title": "C2"},
{"id": "a2", "issuer": "https://a.test", "title": "A2"},
{"id": "d2", "issuer": "https://d.test", "title": "D2"},
]})

providers = multi_backend_connection.get_oidc_providers()
providers = multi_backend_connection.build_oidc_handling(configured_providers=[
OidcProvider("a-b", "https://b.test", "A-B"),
OidcProvider("a-e", "https://e.test/", "A-E"),
OidcProvider("a-a", "https://a.test", "A-A"),
OidcProvider("a-d", "https://d.test", "A-D"),
OidcProvider("a-c", "https://c.test/", "A-C"),
])
assert [p.issuer for p in providers] == [
"https://d.test", "https://b.test", "https://c.test", "https://a.test", "https://e.test"
"https://b.test", "https://e.test/", "https://a.test", "https://d.test", "https://c.test/"
]
assert [con._oidc_provider_map for con in multi_backend_connection] == [
{'a-a': 'a1', 'a-b': 'b1', 'a-c': 'c1', 'a-d': 'd1', 'a-e': 'e1'},
{'a-a': 'a2', 'a-b': 'b2', 'a-c': 'c2', 'a-d': 'd2', 'a-e': 'e2'},
]

def test_oidc_provider_mapping(self, requests_mock):
Expand Down Expand Up @@ -236,9 +265,12 @@ def test_oidc_provider_mapping(self, requests_mock):

multi_backend_connection = MultiBackendConnection({"b1": domain1, "b2": domain2, "b3": domain3})

assert multi_backend_connection.get_oidc_providers() == [
OidcProvider(id="x1", issuer="https://x.test", title="X1", scopes=["openid"]),
OidcProvider(id="y1", issuer="https://y.test", title="Y1", scopes=["openid"]),
assert multi_backend_connection.build_oidc_handling(configured_providers=[
OidcProvider("ax", "https://x.test", "A-X"),
OidcProvider("ay", "https://y.test", "A-Y"),
]) == [
OidcProvider(id="ax", issuer="https://x.test", title="A-X", scopes=["openid"]),
OidcProvider(id="ay", issuer="https://y.test", title="A-Y", scopes=["openid"]),
]

def get_me(request: requests.Request, context):
Expand All @@ -250,7 +282,7 @@ def get_me(request: requests.Request, context):
requests_mock.get("https://b3.test/v1/me", json=get_me)

# Fake aggregator request containing bearer token for aggregator providers
request = flask.Request(environ={"HTTP_AUTHORIZATION": "Bearer oidc/x1/yadayadayada"})
request = flask.Request(environ={"HTTP_AUTHORIZATION": "Bearer oidc/ax/yadayadayada"})

con1 = multi_backend_connection.get_connection("b1")
with con1.authenticated_from_request(request=request):
Expand Down

0 comments on commit 68f7fbc

Please sign in to comment.