Skip to content

Commit

Permalink
feat: Allow add extra backends
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon committed Nov 14, 2024
1 parent 4e91069 commit b9f092c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ class HuggingfaceOpenId(OpenIdConnectAuth):
DEFAULT_SCOPE = ["openid", "profile"]


def _load_backends():
_SUPPORTED_BACKENDS = {}


def load_supported_backends(extra_backends: list = None) -> Dict[str, Type[BaseOAuth2]]:
global _SUPPORTED_BACKENDS

backends = [
"argilla_server.security.authentication.oauth2._backends.HuggingfaceOpenId",
"social_core.backends.github.GithubOAuth2",
Expand All @@ -60,14 +65,20 @@ def _load_backends():
"social_core.backends.google.GoogleOAuth2",
"social_core.backends.google_openidconnect.GoogleOpenIdConnect",
]
return load_backends(backends, force_load=True)

if extra_backends:
backends.extend(extra_backends)

_SUPPORTED_BACKENDS = _load_backends()
_SUPPORTED_BACKENDS = load_backends(backends, force_load=True)
return _SUPPORTED_BACKENDS


def get_supported_backend_by_name(name: str) -> Type[BaseOAuth2]:
"""Get a registered oauth provider by name. Raise a ValueError if provided not found."""
global _SUPPORTED_BACKENDS

if not _SUPPORTED_BACKENDS:
_SUPPORTED_BACKENDS = load_supported_backends()

if provider := _SUPPORTED_BACKENDS.get(name):
return provider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

import yaml

from argilla_server.security.authentication.oauth2._backends import get_supported_backend_by_name
from argilla_server.security.authentication.oauth2._backends import (
get_supported_backend_by_name,
load_supported_backends,
)
from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider

__all__ = ["OAuth2Settings"]
Expand Down Expand Up @@ -47,11 +50,13 @@ class OAuth2Settings:
def __init__(
self,
allow_http_redirect: bool = False,
extra_backends: List[str] = None,
**settings,
):
self.allow_http_redirect = allow_http_redirect
self.extra_backends = extra_backends or []
self.allowed_workspaces = self._build_workspaces(settings) or []
self._providers = self._build_providers(settings) or []
self._providers = self._build_providers(settings, extra_backends) or []

if self.allow_http_redirect:
# See https://stackoverflow.com/questions/27785375/testing-flask-oauthlib-locally-without-https
Expand All @@ -74,9 +79,11 @@ def _build_workspaces(cls, settings: dict) -> List[AllowedWorkspace]:
return [AllowedWorkspace(**workspace) for workspace in allowed_workspaces]

@classmethod
def _build_providers(cls, settings: dict) -> List["OAuth2ClientProvider"]:
def _build_providers(cls, settings: dict, extra_backends) -> List["OAuth2ClientProvider"]:
providers = []

load_supported_backends(extra_backends=extra_backends)

for provider in settings.pop("providers", []):
name = provider.pop("name")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,25 @@ def test_configure_huggingface_provider(self):
assert huggingface_provider.client_id == "huggingface_client_id"
assert huggingface_provider.client_secret == "huggingface_client_secret"
assert huggingface_provider.scope == ["openid", "profile", "email"]

def test_configure_extra_backends(self):
from social_core.backends.microsoft import MicrosoftOAuth2

provider_name = MicrosoftOAuth2.name
settings = OAuth2Settings(
extra_backends=["social_core.backends.microsoft.MicrosoftOAuth2"],
providers=[
{
"name": provider_name,
"client_id": "microsoft_client_id",
"client_secret": "microsoft_client_secret",
}
],
)

assert len(settings.providers) == 1
extra_provider = settings.providers[provider_name]

assert extra_provider.name == provider_name
assert extra_provider.client_id == "microsoft_client_id"
assert extra_provider.client_secret == "microsoft_client_secret"

0 comments on commit b9f092c

Please sign in to comment.