From 793beaeca0d24712a05597aec3e5814716bbbe45 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 13 Nov 2024 09:58:53 +0100 Subject: [PATCH 01/22] chore: Remove oauth.enabled attribute --- argilla-server/src/argilla_server/_app.py | 3 - .../argilla_server/api/handlers/v1/oauth2.py | 18 ++--- .../authentication/oauth2/settings.py | 5 +- .../src/argilla_server/security/settings.py | 2 +- .../tests/unit/api/handlers/v1/test_oauth2.py | 66 ------------------- argilla-server/tests/unit/test_app.py | 22 +------ 6 files changed, 11 insertions(+), 105 deletions(-) diff --git a/argilla-server/src/argilla_server/_app.py b/argilla-server/src/argilla_server/_app.py index 05ad3fae04..39187aeb07 100644 --- a/argilla-server/src/argilla_server/_app.py +++ b/argilla-server/src/argilla_server/_app.py @@ -216,9 +216,6 @@ def _show_telemetry_warning(): async def _create_oauth_allowed_workspaces(db: AsyncSession): from argilla_server.security.settings import settings as security_settings - if not security_settings.oauth.enabled: - return - for allowed_workspace in security_settings.oauth.allowed_workspaces: if await Workspace.get_by(db, name=allowed_workspace.name) is None: _LOGGER.info(f"Creating workspace with name {allowed_workspace.name!r}") diff --git a/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py b/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py index 5f34c57072..b0c255f5e1 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py @@ -42,21 +42,16 @@ class UserOAuthCreate(UserCreate): def get_provider_by_name_or_raise(provider: str = Path()) -> OAuth2ClientProvider: - if not settings.oauth.enabled: - raise NotFoundError(message="OAuth2 is not enabled") - - if provider in settings.oauth.providers: + try: return settings.oauth.providers[provider] - - raise NotFoundError(message=f"OAuth Provider '{provider}' not found") + except KeyError: + raise NotFoundError(message=f"OAuth Provider '{provider}' not found") @router.get("/providers", response_model=Providers) def list_providers() -> Providers: - if not settings.oauth.enabled: - return Providers(items=[]) - - return Providers(items=[Provider(name=provider_name) for provider_name in settings.oauth.providers]) + providers = [Provider(name=provider_name) for provider_name in settings.oauth.providers] + return Providers(items=providers) @router.get("/providers/{provider}/authentication") @@ -73,7 +68,8 @@ async def get_access_token( provider: OAuth2ClientProvider = Depends(get_provider_by_name_or_raise), db: AsyncSession = Depends(get_async_db), ) -> Token: - userinfo = UserInfo(await provider.get_user_data(request)).use_claims(provider.claims) + user_data = await provider.get_user_data(request) + userinfo = UserInfo(user_data).use_claims(provider.claims) if not userinfo.username: raise RuntimeError("OAuth error: Missing username") diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py index e4771bad07..904849e065 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py @@ -31,8 +31,6 @@ class OAuth2Settings: OAuth2 settings model. Args: - enabled: - Whether OAuth2 authentication is enabled or not. allow_http_redirect: Whether to allow HTTP scheme on redirect urls (for tests purposes). providers: @@ -46,12 +44,11 @@ class OAuth2Settings: def __init__( self, - enabled: bool = True, allow_http_redirect: bool = False, providers: List[OAuth2ClientProvider] = None, allowed_workspaces: List[AllowedWorkspace] = None, + **kwargs, # Ignore any other key ): - self.enabled = enabled self.allow_http_redirect = allow_http_redirect self.allowed_workspaces = allowed_workspaces or [] self._providers = providers or [] diff --git a/argilla-server/src/argilla_server/security/settings.py b/argilla-server/src/argilla_server/security/settings.py index 2715b9b3b5..b610116c81 100644 --- a/argilla-server/src/argilla_server/security/settings.py +++ b/argilla-server/src/argilla_server/security/settings.py @@ -63,7 +63,7 @@ def oauth(self) -> "OAuth2Settings": if not self._oauth_settings and os.path.exists(self.oauth_cfg): self._oauth_settings = OAuth2Settings.from_yaml(self.oauth_cfg) else: - self._oauth_settings = OAuth2Settings(enabled=False) + self._oauth_settings = OAuth2Settings() return self._oauth_settings diff --git a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py index b5a3d87477..3177c896dd 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py @@ -28,16 +28,10 @@ from tests.factories import AdminFactory, AnnotatorFactory -@pytest.fixture -def disabled_oauth_settings() -> OAuth2Settings: - return OAuth2Settings(enabled=False) - - @pytest.fixture def default_oauth_settings() -> OAuth2Settings: return OAuth2Settings.from_dict( { - "enabled": True, "providers": [ { "name": "huggingface", @@ -57,25 +51,6 @@ async def tests_list_providers_with_default_config(self, async_client: AsyncClie assert response.status_code == 200 assert response.json() == {"items": []} - async def test_list_providers_with_oauth_disabled( - self, async_client: AsyncClient, owner_auth_header: dict, disabled_oauth_settings: OAuth2Settings - ): - with mock.patch( - "argilla_server.security.settings.Settings.oauth", new_callable=lambda: disabled_oauth_settings - ): - response = await async_client.get("/api/v1/oauth2/providers", headers=owner_auth_header) - assert response.status_code == 200 - assert response.json() == {"items": []} - - async def test_list_provider_with_oauth_disabled_from_settings( - self, async_client: AsyncClient, owner_auth_header: dict, default_oauth_settings: OAuth2Settings - ): - default_oauth_settings.enabled = False - with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): - response = await async_client.get("/api/v1/oauth2/providers", headers=owner_auth_header) - assert response.status_code == 200 - assert response.json() == {"items": []} - async def test_list_providers( self, async_client: AsyncClient, owner_auth_header: dict, default_oauth_settings: OAuth2Settings ): @@ -99,33 +74,6 @@ async def test_provider_huggingface_authentication( assert b"/oauth/authorize?response_type=code&client_id=client_id" in redirect_url.target assert b"&extra=params" in redirect_url.target - async def test_provider_authentication_with_oauth_disabled( - self, - async_client: AsyncClient, - owner_auth_header: dict, - disabled_oauth_settings: OAuth2Settings, - ): - with mock.patch( - "argilla_server.security.settings.Settings.oauth", new_callable=lambda: disabled_oauth_settings - ): - response = await async_client.get( - "/api/v1/oauth2/providers/huggingface/authentication", headers=owner_auth_header - ) - assert response.status_code == 404 - - async def test_provider_authentication_with_oauth_disabled_and_provider_defined( - self, - async_client: AsyncClient, - owner_auth_header: dict, - default_oauth_settings: OAuth2Settings, - ): - default_oauth_settings.enabled = False - with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): - response = await async_client.get( - "/api/v1/oauth2/providers/huggingface/authentication", headers=owner_auth_header - ) - assert response.status_code == 404 - async def test_provider_authentication_with_invalid_provider( self, async_client: AsyncClient, owner_auth_header: dict, default_oauth_settings: OAuth2Settings ): @@ -215,20 +163,6 @@ async def test_provider_huggingface_access_token_with_missing_name( assert user.role == UserRole.annotator assert user.first_name == "username" - async def test_provider_access_token_with_oauth_disabled( - self, - async_client: AsyncClient, - owner_auth_header: dict, - disabled_oauth_settings: OAuth2Settings, - ): - with mock.patch( - "argilla_server.security.settings.Settings.oauth", new_callable=lambda: disabled_oauth_settings - ): - response = await async_client.get( - "/api/v1/oauth2/providers/huggingface/access-token", headers=owner_auth_header - ) - assert response.status_code == 404 - async def test_provider_access_token_with_invalid_provider( self, async_client: AsyncClient, owner_auth_header: dict, default_oauth_settings: OAuth2Settings ): diff --git a/argilla-server/tests/unit/test_app.py b/argilla-server/tests/unit/test_app.py index 48fbc82bff..ac3e42514d 100644 --- a/argilla-server/tests/unit/test_app.py +++ b/argilla-server/tests/unit/test_app.py @@ -80,7 +80,6 @@ async def test_create_allowed_workspaces(self, db: AsyncSession): "argilla_server.security.settings.Settings.oauth", new_callable=lambda: OAuth2Settings.from_dict( { - "enabled": True, "allowed_workspaces": [{"name": "ws1"}, {"name": "ws2"}], } ), @@ -91,25 +90,8 @@ async def test_create_allowed_workspaces(self, db: AsyncSession): assert len(workspaces) == 2 assert set([ws.name for ws in workspaces]) == {"ws1", "ws2"} - async def test_create_allowed_workspaces_with_oauth_disabled(self, db: AsyncSession): - with mock.patch( - "argilla_server.security.settings.Settings.oauth", - new_callable=lambda: OAuth2Settings.from_dict( - { - "enabled": False, - "allowed_workspaces": [{"name": "ws1"}, {"name": "ws2"}], - } - ), - ): - await _create_oauth_allowed_workspaces(db) - - workspaces = (await db.scalars(select(Workspace))).all() - assert len(workspaces) == 0 - async def test_create_workspaces_with_empty_workspaces_list(self, db: AsyncSession): - with mock.patch( - "argilla_server.security.settings.Settings.oauth", new_callable=lambda: OAuth2Settings(enabled=True) - ): + with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=OAuth2Settings): await _create_oauth_allowed_workspaces(db) workspaces = (await db.scalars(select(Workspace))).all() @@ -120,7 +102,7 @@ async def test_create_workspaces_with_existing_workspaces(self, db: AsyncSession with mock.patch( "argilla_server.security.settings.Settings.oauth", - new_callable=lambda: OAuth2Settings(enabled=True, allowed_workspaces=[AllowedWorkspace(name=ws.name)]), + new_callable=lambda: OAuth2Settings(allowed_workspaces=[AllowedWorkspace(name=ws.name)]), ): await _create_oauth_allowed_workspaces(db) From 6cd157e7aabeeaf6cc48ade7cd412cf37ee081f1 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 13 Nov 2024 12:06:14 +0100 Subject: [PATCH 02/22] fix: Prefix backend name to state cookie --- .../authentication/oauth2/providers/_base.py | 7 +++++-- .../tests/unit/api/handlers/v1/test_oauth2.py | 14 +++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py index ef6586f92d..40bf216fff 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py @@ -103,7 +103,7 @@ def authorization_redirect(self, request: Request) -> RedirectResponse: response = RedirectResponse(url, 303) response.set_cookie( - self.OAUTH_STATE_COOKIE_NAME, + self._get_state_cookie_name(), value=state, secure=True, httponly=True, @@ -113,6 +113,9 @@ def authorization_redirect(self, request: Request) -> RedirectResponse: return response + def _get_state_cookie_name(self) -> str: + return f"{self.name}_{self.OAUTH_STATE_COOKIE_NAME}" + async def get_user_data(self, request: Request) -> dict: self._check_request_params(request) @@ -131,7 +134,7 @@ def _check_request_params(self, request) -> None: if "state" not in request.query_params: raise ValueError("'state' parameter was not found in callback request") - state = request.cookies.get(self.OAUTH_STATE_COOKIE_NAME) + state = request.cookies.get(self._get_state_cookie_name()) if request.query_params.get("state") != state: raise ValueError("'state' parameter does not match") diff --git a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py index 3177c896dd..2630a6ce65 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py @@ -99,7 +99,7 @@ async def test_provider_huggingface_access_token( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "valid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) assert response.status_code == 200 @@ -128,7 +128,7 @@ async def test_provider_huggingface_access_token_with_missing_username( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "valid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) assert response.status_code == 500 @@ -149,7 +149,7 @@ async def test_provider_huggingface_access_token_with_missing_name( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "valid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) assert response.status_code == 200 @@ -200,7 +200,7 @@ async def test_provider_access_token_with_invalid_state( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "invalid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) assert response.status_code == 422 assert response.json() == {"detail": "'state' parameter does not match"} @@ -217,7 +217,7 @@ async def test_provider_access_token_with_authentication_error( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "valid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) assert response.status_code == 401 assert response.json() == {"detail": "error"} @@ -240,7 +240,7 @@ async def test_provider_access_token_with_already_created_user( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "valid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) assert response.status_code == 200 @@ -266,7 +266,7 @@ async def test_provider_access_token_with_same_username( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "valid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) # This will throw an error once we detect users created by OAuth2 assert response.status_code == 200 From c81d3868b95b9fc2ff4dd90b216d95d5d771c7fa Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 13 Nov 2024 14:07:10 +0100 Subject: [PATCH 03/22] chore: Move provider to provider.py --- .../{providers/_base.py => provider.py} | 113 +++++++++++++----- .../oauth2/providers/__init__.py | 42 ------- .../oauth2/providers/_github.py | 28 ----- .../oauth2/providers/_huggingface.py | 47 -------- 4 files changed, 80 insertions(+), 150 deletions(-) rename argilla-server/src/argilla_server/security/authentication/oauth2/{providers/_base.py => provider.py} (68%) delete mode 100644 argilla-server/src/argilla_server/security/authentication/oauth2/providers/__init__.py delete mode 100644 argilla-server/src/argilla_server/security/authentication/oauth2/providers/_github.py delete mode 100644 argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py similarity index 68% rename from argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py rename to argilla-server/src/argilla_server/security/authentication/oauth2/provider.py index 40bf216fff..f249bc6c27 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py @@ -17,20 +17,21 @@ import random import re import string -from typing import Dict, Any, ClassVar, Type, Optional, Union, List, Tuple +from typing import Dict, Any, ClassVar, Type, Optional, List, Tuple from urllib.parse import urljoin import httpx from oauthlib.oauth2 import WebApplicationClient +from social_core.backends.github import GithubOAuth2 +from social_core.backends.google import GoogleOAuth2 from social_core.backends.oauth import BaseOAuth2 +from social_core.backends.open_id_connect import OpenIdConnectAuth from social_core.exceptions import AuthException - from social_core.strategy import BaseStrategy from starlette.requests import Request -from starlette.responses import RedirectResponse +from starlette.responses import RedirectResponse, Response from argilla_server.errors import future -from argilla_server.security.authentication.claims import Claims from argilla_server.security.settings import settings @@ -42,41 +43,53 @@ def absolute_uri(self, path=None) -> str: return path def get_setting(self, name): - return None + return os.environ[name] class OAuth2ClientProvider: - """OAuth2 flow handler of a certain provider.""" + """OAuth2 flow handler of a certain provider.""" OAUTH_STATE_COOKIE_NAME = "oauth2_state" OAUTH_STATE_COOKIE_MAX_AGE = 90 - name: ClassVar[str] - backend_class: ClassVar[Type[BaseOAuth2]] - claims: ClassVar[Optional[Union[Claims, dict]]] = None backend_strategy: ClassVar[BaseStrategy] = Strategy() def __init__( self, + backend_class: Type[BaseOAuth2], client_id: str = None, client_secret: str = None, scope: Optional[List[str]] = None, redirect_uri: str = None, ) -> None: + self.name = backend_class.name + self._backend = backend_class(strategy=self.backend_strategy) + + self._authorization_endpoint = self._backend.authorization_url() + self._token_endpoint = self._backend.access_token_url() + # Social Core uses the key and secret names for the client_id and client_secret + # These lines allow the use of the same environment variables as the social_core library. + # See https://python-social-auth.readthedocs.io/en/latest/configuration/settings.html for more information. + client_id = client_id or self.backend_strategy.setting("key", default=None, backend=self._backend) + client_secret = client_secret or self.backend_strategy.setting("secret", default=None, backend=self._backend) + scope = scope or self.backend_strategy.setting( + "scope", + default=self._backend.get_scope(), + backend=self._backend, + ) + self.client_id = client_id or self._environment_variable_for_property("client_id") self.client_secret = client_secret or self._environment_variable_for_property("client_secret") self.scope = scope or self._environment_variable_for_property("scope", "") - self.scope = self.scope.split(" ") if self.scope else [] - self.redirect_uri = redirect_uri or self._environment_variable_for_property("redirect_uri") - self.redirect_uri = self.redirect_uri or f"/oauth/{self.name}/callback" + if isinstance(self.scope, str): + self.scope = self.scope.split(" ") + self.scope = self.scope or [] - self._backend = self.backend_class(strategy=self.backend_strategy) - self._authorization_endpoint = self._backend.authorization_url() - self._token_endpoint = self._backend.access_token_url() + self.redirect_uri = redirect_uri or f"/oauth/{self.name}/callback" @classmethod - def from_dict(cls, provider: dict) -> "OAuth2ClientProvider": - return cls(**provider) + def from_dict(cls, provider: dict, backend_class: Type[BaseOAuth2]) -> "OAuth2ClientProvider": + return cls(backend_class=backend_class, **provider) def new_oauth_client(self) -> WebApplicationClient: return WebApplicationClient(self.client_id) @@ -102,19 +115,17 @@ def authorization_redirect(self, request: Request) -> RedirectResponse: url, state = self.authorization_url(request) response = RedirectResponse(url, 303) - response.set_cookie( - self._get_state_cookie_name(), - value=state, - secure=True, - httponly=True, - max_age=self.OAUTH_STATE_COOKIE_MAX_AGE, - samesite="none", - ) + self._set_state(state, response) return response - def _get_state_cookie_name(self) -> str: - return f"{self.name}_{self.OAUTH_STATE_COOKIE_NAME}" + def standardize(self, data: Dict[str, Any]) -> Dict[str, Any]: + data = self._backend.get_user_details(data) + + data["provider"] = self.name + data["scope"] = self.scope + + return data async def get_user_data(self, request: Request) -> dict: self._check_request_params(request) @@ -134,10 +145,13 @@ def _check_request_params(self, request) -> None: if "state" not in request.query_params: raise ValueError("'state' parameter was not found in callback request") - state = request.cookies.get(self._get_state_cookie_name()) + state = self._get_state(request) if request.query_params.get("state") != state: raise ValueError("'state' parameter does not match") + def _get_state(self, request) -> Optional[str]: + return request.cookies.get(self._get_state_cookie_name()) + @staticmethod def _align_url_to_allow_http_redirect(url: str) -> str: """This method is used to align the URL to the HTTP/HTTPS scheme""" @@ -145,11 +159,18 @@ def _align_url_to_allow_http_redirect(url: str) -> str: scheme = "http" if settings.oauth.allow_http_redirect else "https" return re.sub(r"^https?", scheme, url) - def standardize(self, data: Dict[str, Any]) -> Dict[str, Any]: - data["provider"] = self.name - data["scope"] = self.scope + def _set_state(self, state: str, response: Response) -> None: + response.set_cookie( + self._get_state_cookie_name(), + value=state, + secure=True, + httponly=True, + max_age=self.OAUTH_STATE_COOKIE_MAX_AGE, + samesite="none", + ) - return data + def _get_state_cookie_name(self) -> str: + return f"{self.name}_{self.OAUTH_STATE_COOKIE_NAME}" async def _fetch_user_data(self, authorization_response: str, **oauth_query_params) -> dict: oauth_client = self.new_oauth_client() @@ -165,7 +186,6 @@ async def _fetch_user_data(self, authorization_response: str, **oauth_query_para try: response = await session.post(token_url, headers=headers, content=content) oauth_client.parse_request_body_response(json.dumps(response.json())) - return self.standardize(self._backend.user_data(oauth_client.access_token)) except httpx.HTTPError as e: raise ValueError(str(e)) @@ -176,3 +196,30 @@ def _environment_variable_for_property(self, property_name: str, default: str = env_var_name = f"OAUTH2_{self.name.upper()}_{property_name.upper()}" return os.getenv(env_var_name, default) + + +class HuggingfaceOpenId(OpenIdConnectAuth): + """Huggingface OpenID Connect authentication backend.""" + + name = "huggingface" + + AUTHORIZATION_URL = "https://huggingface.co/oauth/authorize" + ACCESS_TOKEN_URL = "https://huggingface.co/oauth/token" + + # OIDC configuration + OIDC_ENDPOINT = "https://huggingface.co" + + +SUPPORTED_BACKENDS = { + GithubOAuth2.name: GithubOAuth2, + HuggingfaceOpenId.name: HuggingfaceOpenId, + GoogleOAuth2.name: GoogleOAuth2, +} + + +def get_supported_backend_by_name(name: str) -> Type[BaseOAuth2]: + """Get a registered oauth provider by name. Raise a ValueError if provided not found.""" + if provider := SUPPORTED_BACKENDS.get(name): + return provider + else: + raise future.NotFoundError(f"Unsupported provider {name}. Supported providers are {SUPPORTED_BACKENDS.keys()}") diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/__init__.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/__init__.py deleted file mode 100644 index 0950bc30d2..0000000000 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Type - -from argilla_server.errors.future import NotFoundError -from argilla_server.security.authentication.oauth2.providers._base import OAuth2ClientProvider -from argilla_server.security.authentication.oauth2.providers._github import GitHubClientProvider -from argilla_server.security.authentication.oauth2.providers._huggingface import HuggingfaceClientProvider - -__all__ = [ - "OAuth2ClientProvider", - "GitHubClientProvider", - "HuggingfaceClientProvider", - "get_provider_by_name", -] - -_ALL_SUPPORTED_OAUTH2_PROVIDERS = { - GitHubClientProvider.name: GitHubClientProvider, - HuggingfaceClientProvider.name: HuggingfaceClientProvider, -} - - -def get_provider_by_name(name: str) -> Type["OAuth2ClientProvider"]: - """Get a registered oauth provider by name. Raise a ValueError if provided not found.""" - if provider_class := _ALL_SUPPORTED_OAUTH2_PROVIDERS.get(name): - return provider_class - else: - raise NotFoundError( - f"Unsupported provider {name}. " f"Supported providers are {_ALL_SUPPORTED_OAUTH2_PROVIDERS.keys()}" - ) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_github.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_github.py deleted file mode 100644 index ea4f3f1918..0000000000 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_github.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from social_core.backends.github import GithubOAuth2 - -from argilla_server.security.authentication.claims import Claims -from argilla_server.security.authentication.oauth2.providers._base import OAuth2ClientProvider - - -class GitHubClientProvider(OAuth2ClientProvider): - claims = Claims( - picture="avatar_url", - identity=lambda user: f"{user.provider}:{user.id}", - username="login", - ) - backend_class = GithubOAuth2 - name = "github" diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py deleted file mode 100644 index 57365930d8..0000000000 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from social_core.backends.open_id_connect import OpenIdConnectAuth - -from argilla_server.logging import LoggingMixin -from argilla_server.security.authentication.claims import Claims -from argilla_server.security.authentication.oauth2.providers._base import OAuth2ClientProvider - -_LOGGER = logging.getLogger("argilla.security.oauth2.providers.huggingface") - - -class HuggingfaceOpenId(OpenIdConnectAuth): - """Huggingface OpenID Connect authentication backend.""" - - name = "huggingface" - - OIDC_ENDPOINT = "https://huggingface.co" - AUTHORIZATION_URL = "https://huggingface.co/oauth/authorize" - ACCESS_TOKEN_URL = "https://huggingface.co/oauth/token" - - def oidc_endpoint(self) -> str: - return self.OIDC_ENDPOINT - - -_HF_PREFERRED_USERNAME = "preferred_username" - - -class HuggingfaceClientProvider(OAuth2ClientProvider, LoggingMixin): - """Specialized HuggingFace OAuth2 provider.""" - - claims = Claims(username=_HF_PREFERRED_USERNAME, first_name="name") - backend_class = HuggingfaceOpenId - name = "huggingface" From 62915e0a484b8784bd435b4751156fbb9fedbb9b Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 13 Nov 2024 14:10:20 +0100 Subject: [PATCH 04/22] refactor: Add user create validator and remove local schema for oauth2 user creation --- .../argilla_server/api/handlers/v1/oauth2.py | 25 +++-------- .../src/argilla_server/contexts/accounts.py | 27 ++++++------ .../src/argilla_server/validators/users.py | 41 +++++++++++++++++++ 3 files changed, 60 insertions(+), 33 deletions(-) create mode 100644 argilla-server/src/argilla_server/validators/users.py diff --git a/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py b/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py index b0c255f5e1..0e16f07e28 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py @@ -11,21 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from fastapi import APIRouter, Depends, Request, Path from fastapi.responses import RedirectResponse from sqlalchemy.ext.asyncio import AsyncSession -from argilla_server import telemetry from argilla_server.api.schemas.v1.oauth2 import Provider, Providers, Token -from argilla_server.api.schemas.v1.users import UserCreate from argilla_server.contexts import accounts from argilla_server.database import get_async_db -from argilla_server.enums import UserRole from argilla_server.errors.future import NotFoundError from argilla_server.models import User -from argilla_server.pydantic_v1 import Field from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider from argilla_server.security.authentication.userinfo import UserInfo from argilla_server.security.settings import settings @@ -33,14 +28,6 @@ router = APIRouter(prefix="/oauth2", tags=["Authentication"]) -class UserOAuthCreate(UserCreate): - """This schema is used to validate the creation of a new user by using the oauth userinfo""" - - username: str = Field(min_length=1) - role: Optional[UserRole] - password: Optional[str] = None - - def get_provider_by_name_or_raise(provider: str = Path()) -> OAuth2ClientProvider: try: return settings.oauth.providers[provider] @@ -55,7 +42,7 @@ def list_providers() -> Providers: @router.get("/providers/{provider}/authentication") -def get_authentication( +async def get_authentication( request: Request, provider: OAuth2ClientProvider = Depends(get_provider_by_name_or_raise), ) -> RedirectResponse: @@ -69,7 +56,7 @@ async def get_access_token( db: AsyncSession = Depends(get_async_db), ) -> Token: user_data = await provider.get_user_data(request) - userinfo = UserInfo(user_data).use_claims(provider.claims) + userinfo = UserInfo(user_data) if not userinfo.username: raise RuntimeError("OAuth error: Missing username") @@ -78,11 +65,9 @@ async def get_access_token( if user is None: user = await accounts.create_user_with_random_password( db, - **UserOAuthCreate( - username=userinfo.username, - first_name=userinfo.first_name, - role=userinfo.role, - ).dict(exclude_unset=True), + username=userinfo.username, + first_name=userinfo.first_name, + role=userinfo.role, workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces], ) diff --git a/argilla-server/src/argilla_server/contexts/accounts.py b/argilla-server/src/argilla_server/contexts/accounts.py index 01aa1fa8d5..e1ec8a873d 100644 --- a/argilla-server/src/argilla_server/contexts/accounts.py +++ b/argilla-server/src/argilla_server/contexts/accounts.py @@ -26,6 +26,7 @@ from argilla_server.models import User, Workspace, WorkspaceUser from argilla_server.security.authentication.jwt import JWT from argilla_server.security.authentication.userinfo import UserInfo +from argilla_server.validators.users import UserCreateValidator _CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -54,7 +55,7 @@ async def list_workspaces(db: AsyncSession) -> List[Workspace]: return result.scalars().all() -async def list_workspaces_by_user_id(db: AsyncSession, user_id: UUID) -> List[Workspace]: +async def list_workspaces_by_user_id(db: AsyncSession, user_id: UUID) -> Sequence[Workspace]: result = await db.execute( select(Workspace) .join(WorkspaceUser) @@ -104,22 +105,22 @@ async def list_users_by_ids(db: AsyncSession, ids: Iterable[UUID]) -> Sequence[U return result.scalars().all() -# TODO: After removing API v0 implementation we can remove the workspaces attribute. -# With API v1 the workspaces will be created doing additional requests to other endpoints for it. -async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List[str], None] = None) -> User: - if await get_user_by_username(db, user_attrs["username"]) is not None: - raise NotUniqueError(f"User username `{user_attrs['username']}` is not unique") - - user = await User.create( - db, +async def create_user( + db: AsyncSession, + user_attrs: dict, + workspaces: Union[List[str], None] = None, +) -> User: + new_user = User( first_name=user_attrs["first_name"], last_name=user_attrs["last_name"], username=user_attrs["username"], role=user_attrs["role"], password_hash=hash_password(user_attrs["password"]), - autocommit=False, ) + await UserCreateValidator.validate(db, user=new_user) + + await new_user.save(db, autocommit=False) if workspaces is not None: for workspace_name in workspaces: workspace = await Workspace.get_by(db, name=workspace_name) @@ -128,14 +129,14 @@ async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List await WorkspaceUser.create( db, - workspace_id=workspace.id, - user_id=user.id, + workspace=workspace, + user=new_user, autocommit=False, ) await db.commit() - return user + return new_user async def create_user_with_random_password( diff --git a/argilla-server/src/argilla_server/validators/users.py b/argilla-server/src/argilla_server/validators/users.py new file mode 100644 index 0000000000..3d506fb032 --- /dev/null +++ b/argilla-server/src/argilla_server/validators/users.py @@ -0,0 +1,41 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError +from argilla_server.models import User + + +class UserCreateValidator: + @classmethod + async def validate(cls, db: AsyncSession, user: User) -> None: + await cls._validate_username(db, user) + + @classmethod + async def _validate_username(cls, db, user: User): + await cls._validate_username_length(user) + await cls._validate_unique_username(db, user) + + @classmethod + async def _validate_unique_username(cls, db, user): + from argilla_server.contexts import accounts + + if await accounts.get_user_by_username(db, user.username) is not None: + raise NotUniqueError(f"User username `{user.username}` is not unique") + + @classmethod + async def _validate_username_length(cls, user: User): + if len(user.username) < 1: + raise UnprocessableEntityError("Username must be at least 1 characters long") From 1e4bda5fa68f334f8ee457c7c7801f696df8f5f9 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 13 Nov 2024 14:11:36 +0100 Subject: [PATCH 05/22] chore: Remove Claims --- .../security/authentication/claims.py | 40 ------------------- .../authentication/oauth2/__init__.py | 2 +- .../authentication/oauth2/auth_backend.py | 7 +--- .../authentication/oauth2/settings.py | 6 +-- .../security/authentication/userinfo.py | 17 ++++---- 5 files changed, 13 insertions(+), 59 deletions(-) delete mode 100644 argilla-server/src/argilla_server/security/authentication/claims.py diff --git a/argilla-server/src/argilla_server/security/authentication/claims.py b/argilla-server/src/argilla_server/security/authentication/claims.py deleted file mode 100644 index a34696a685..0000000000 --- a/argilla-server/src/argilla_server/security/authentication/claims.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from typing import Any, Callable, Union, Optional - -from argilla_server.enums import UserRole - - -def _parse_role_from_environment(userinfo: dict) -> Optional[UserRole]: - """This is a temporal solution, and it will be replaced by a proper Sign up process""" - if userinfo["username"] == os.getenv("USERNAME"): - return UserRole.owner - - -class Claims(dict): - """Claims configuration for a single provider.""" - - display_name: Union[str, Callable[[dict], Any]] - identity: Union[str, Callable[[dict], Any]] - picture: Union[str, Callable[[dict], Any]] - email: Union[str, Callable[[dict], Any]] - - def __init__(self, seq=None, **kwargs) -> None: - super().__init__(seq or {}, **kwargs) - self["display_name"] = kwargs.get("display_name", self.get("display_name", "name")) - self["identity"] = kwargs.get("identity", self.get("identity", "sub")) - self["picture"] = kwargs.get("picture", self.get("picture", "picture")) - self["email"] = kwargs.get("email", self.get("email", "email")) - self["role"] = kwargs.get("role", _parse_role_from_environment) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/__init__.py b/argilla-server/src/argilla_server/security/authentication/oauth2/__init__.py index f8dcc52a18..f5605ab591 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/__init__.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from argilla_server.security.authentication.oauth2.providers import OAuth2ClientProvider # noqa +from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider # noqa from argilla_server.security.authentication.oauth2.settings import OAuth2Settings # noqa __all__ = ["OAuth2Settings", "OAuth2ClientProvider"] diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py b/argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py index 53774a2bf8..41f84d0b35 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py @@ -19,7 +19,7 @@ from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser from argilla_server.security.authentication.jwt import JWT -from argilla_server.security.authentication.oauth2.providers import OAuth2ClientProvider +from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider from argilla_server.security.authentication.userinfo import UserInfo @@ -39,7 +39,4 @@ async def authenticate(self, request: Request) -> Optional[Tuple[AuthCredentials token_data = JWT.decode(credentials.credentials) user = UserInfo(token_data) - provider = self.providers.get(user.get("provider")) - claims = provider.claims if provider else {} - - return AuthCredentials(user.pop("scope", [])), user.use_claims(claims) + return AuthCredentials(user.pop("scope", [])), user diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py index 904849e065..92e000e3e6 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py @@ -16,7 +16,7 @@ import yaml -from argilla_server.security.authentication.oauth2.providers import get_provider_by_name, OAuth2ClientProvider +from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider, get_supported_backend_by_name __all__ = ["OAuth2Settings"] @@ -88,8 +88,8 @@ def _build_providers(cls, settings: dict) -> List["OAuth2ClientProvider"]: for provider in settings.pop("providers", []): name = provider.pop("name") - provider_class = get_provider_by_name(name) - providers.append(provider_class.from_dict(provider)) + backend_class = get_supported_backend_by_name(name) + providers.append(OAuth2ClientProvider.from_dict(provider, backend_class)) return providers diff --git a/argilla-server/src/argilla_server/security/authentication/userinfo.py b/argilla-server/src/argilla_server/security/authentication/userinfo.py index 54220fc027..70173cda7e 100644 --- a/argilla-server/src/argilla_server/security/authentication/userinfo.py +++ b/argilla-server/src/argilla_server/security/authentication/userinfo.py @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from typing import Any, Optional from starlette.authentication import BaseUser from argilla_server.enums import UserRole -from argilla_server.security.authentication.claims import Claims _DEFAULT_USER_ROLE = UserRole.annotator @@ -39,16 +38,14 @@ def first_name(self) -> str: @property def role(self) -> UserRole: - role = self.get("role") or _DEFAULT_USER_ROLE + role = self.get("role") or self._parse_role_from_environment() return UserRole(role) - def use_claims(self, claims: Optional[Claims]) -> "UserInfo": - claims = claims or {} - - for attr, item in claims.items(): - self[attr] = self.__getprop__(item) - - return self + def _parse_role_from_environment(self) -> Optional[UserRole]: + """This is a temporal solution, and it will be replaced by a proper Sign up process""" + if self["username"] == os.getenv("USERNAME"): + return UserRole.owner + return _DEFAULT_USER_ROLE def __getprop__(self, item, default="") -> Any: if callable(item): From 2fbd7167f0ee1468d99b2bb274c53497a71209e9 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 13 Nov 2024 14:30:04 +0100 Subject: [PATCH 06/22] chore: update tests --- .../tests/unit/api/handlers/v1/test_oauth2.py | 20 +++++++++---------- .../security/authentication/test_userinfo.py | 14 +------------ 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py index 2630a6ce65..48d5c2b3a5 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py @@ -92,8 +92,8 @@ async def test_provider_huggingface_access_token( ): with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", - return_value={"preferred_username": "username", "name": "name"}, + "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data", + return_value={"username": "username", "name": "name"}, ): response = await async_client.get( "/api/v1/oauth2/providers/huggingface/access-token", @@ -121,7 +121,7 @@ async def test_provider_huggingface_access_token_with_missing_username( ): with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", + "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data", return_value={"name": "name"}, ): response = await async_client.get( @@ -142,8 +142,8 @@ async def test_provider_huggingface_access_token_with_missing_name( ): with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", - return_value={"preferred_username": "username"}, + "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data", + return_value={"username": "username"}, ): response = await async_client.get( "/api/v1/oauth2/providers/huggingface/access-token", @@ -210,7 +210,7 @@ async def test_provider_access_token_with_authentication_error( ): with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", + "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data", side_effect=AuthenticationError("error"), ): response = await async_client.get( @@ -233,8 +233,8 @@ async def test_provider_access_token_with_already_created_user( with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", - return_value={"preferred_username": admin.username, "name": admin.first_name}, + "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data", + return_value={"username": admin.username, "name": admin.first_name}, ): response = await async_client.get( "/api/v1/oauth2/providers/huggingface/access-token", @@ -259,8 +259,8 @@ async def test_provider_access_token_with_same_username( with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", - return_value={"preferred_username": user.username, "name": user.first_name}, + "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data", + return_value={"username": user.username, "name": user.first_name}, ): response = await async_client.get( "/api/v1/oauth2/providers/huggingface/access-token", diff --git a/argilla-server/tests/unit/security/authentication/test_userinfo.py b/argilla-server/tests/unit/security/authentication/test_userinfo.py index 8203d56c66..200fb940c1 100644 --- a/argilla-server/tests/unit/security/authentication/test_userinfo.py +++ b/argilla-server/tests/unit/security/authentication/test_userinfo.py @@ -18,7 +18,6 @@ from argilla_server.enums import UserRole from argilla_server.security.authentication import UserInfo -from argilla_server.security.authentication.claims import Claims class TestUserInfo: @@ -43,19 +42,8 @@ def test_get_userinfo_role(self): userinfo = UserInfo({"username": "user", "role": "owner"}) assert userinfo.role == UserRole.owner - def test_get_userinfo_with_claims(self): - userinfo = UserInfo({"username": "user"}).use_claims( - Claims( - first_name=lambda user: user["username"].upper(), - last_name=lambda user: "Peter", - ) - ) - - assert userinfo.first_name == "USER" - assert userinfo.last_name == "Peter" - def test_get_userinfo_role_with_username_env(self, mocker: MockerFixture): mocker.patch.dict(os.environ, {"USERNAME": "user"}) - userinfo = UserInfo({"id": "user"}).use_claims(Claims(username="id")) + userinfo = UserInfo({"username": "user"}) assert userinfo.role == UserRole.owner From 4d6912657f38880017c87bc56b0b5a1c778264c3 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 13 Nov 2024 15:11:42 +0100 Subject: [PATCH 07/22] chore: Add more supported backends --- .../authentication/oauth2/provider.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py index f249bc6c27..c420a99a2a 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py @@ -22,8 +22,14 @@ import httpx from oauthlib.oauth2 import WebApplicationClient -from social_core.backends.github import GithubOAuth2 +from social_core.backends.github import GithubOAuth2, GithubOrganizationOAuth2, GithubTeamOAuth2 +from social_core.backends.github_enterprise import ( + GithubEnterpriseOAuth2, + GithubEnterpriseOrganizationOAuth2, + GithubEnterpriseTeamOAuth2, +) from social_core.backends.google import GoogleOAuth2 +from social_core.backends.google_openidconnect import GoogleOpenIdConnect from social_core.backends.oauth import BaseOAuth2 from social_core.backends.open_id_connect import OpenIdConnectAuth from social_core.exceptions import AuthException @@ -210,15 +216,25 @@ class HuggingfaceOpenId(OpenIdConnectAuth): OIDC_ENDPOINT = "https://huggingface.co" -SUPPORTED_BACKENDS = { - GithubOAuth2.name: GithubOAuth2, - HuggingfaceOpenId.name: HuggingfaceOpenId, - GoogleOAuth2.name: GoogleOAuth2, -} +_BACKENDS = [ + HuggingfaceOpenId, + GoogleOAuth2, + GoogleOpenIdConnect, + GithubOAuth2, + GithubEnterpriseOAuth2, + GithubTeamOAuth2, + GithubEnterpriseTeamOAuth2, + GithubEnterpriseTeamOAuth2, + GithubOrganizationOAuth2, + GithubEnterpriseOrganizationOAuth2, +] + +SUPPORTED_BACKENDS = {backend.name: backend for backend in _BACKENDS} def get_supported_backend_by_name(name: str) -> Type[BaseOAuth2]: """Get a registered oauth provider by name. Raise a ValueError if provided not found.""" + if provider := SUPPORTED_BACKENDS.get(name): return provider else: From ac2324955995b058a4a8a1644d8588cfb15f9be5 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 13 Nov 2024 15:43:52 +0100 Subject: [PATCH 08/22] chore: Add discord oauth --- .../argilla_server/security/authentication/oauth2/provider.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py index c420a99a2a..4b9e3b0108 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py @@ -22,6 +22,7 @@ import httpx from oauthlib.oauth2 import WebApplicationClient +from social_core.backends.discord import DiscordOAuth2 from social_core.backends.github import GithubOAuth2, GithubOrganizationOAuth2, GithubTeamOAuth2 from social_core.backends.github_enterprise import ( GithubEnterpriseOAuth2, @@ -227,6 +228,7 @@ class HuggingfaceOpenId(OpenIdConnectAuth): GithubEnterpriseTeamOAuth2, GithubOrganizationOAuth2, GithubEnterpriseOrganizationOAuth2, + DiscordOAuth2, ] SUPPORTED_BACKENDS = {backend.name: backend for backend in _BACKENDS} From 08d3f88e5bcacafd48f839021887b0632a36ea0d Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 14 Nov 2024 10:11:37 +0100 Subject: [PATCH 09/22] fix: Read argilla env vars first --- .../security/authentication/oauth2/provider.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py index 4b9e3b0108..67a6d22dc2 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py @@ -74,23 +74,25 @@ def __init__( self._authorization_endpoint = self._backend.authorization_url() self._token_endpoint = self._backend.access_token_url() + # Social Core uses the key and secret names for the client_id and client_secret # These lines allow the use of the same environment variables as the social_core library. # See https://python-social-auth.readthedocs.io/en/latest/configuration/settings.html for more information. - client_id = client_id or self.backend_strategy.setting("key", default=None, backend=self._backend) - client_secret = client_secret or self.backend_strategy.setting("secret", default=None, backend=self._backend) - scope = scope or self.backend_strategy.setting( + self.client_id = ( + client_id or self._environment_variable_for_property("client_id") + ) or self._backend.strategy.setting("key") + + self.client_secret = ( + client_secret or self._environment_variable_for_property("client_secret") + ) or self._backend.strategy.setting("secret") + + self.scope = (scope or self._environment_variable_for_property("scope")) or self.backend_strategy.setting( "scope", default=self._backend.get_scope(), backend=self._backend, ) - - self.client_id = client_id or self._environment_variable_for_property("client_id") - self.client_secret = client_secret or self._environment_variable_for_property("client_secret") - self.scope = scope or self._environment_variable_for_property("scope", "") if isinstance(self.scope, str): self.scope = self.scope.split(" ") - self.scope = self.scope or [] self.redirect_uri = redirect_uri or f"/oauth/{self.name}/callback" From 8773c088475bfcb2e73e83f674fe1d66a721ffb0 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 14 Nov 2024 10:12:06 +0100 Subject: [PATCH 10/22] chore: Add default scope for huggingface backend --- .../argilla_server/security/authentication/oauth2/provider.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py index 67a6d22dc2..35f976c4a1 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py @@ -218,6 +218,8 @@ class HuggingfaceOpenId(OpenIdConnectAuth): # OIDC configuration OIDC_ENDPOINT = "https://huggingface.co" + DEFAULT_SCOPE = ["openid", "profile"] + _BACKENDS = [ HuggingfaceOpenId, From 125889bb7f010beec7d4822f797bb5843d7c7af9 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 14 Nov 2024 11:32:11 +0100 Subject: [PATCH 11/22] fix: Using backend.settings --- .../security/authentication/oauth2/provider.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py index 35f976c4a1..fcdd07d73c 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py @@ -78,18 +78,17 @@ def __init__( # Social Core uses the key and secret names for the client_id and client_secret # These lines allow the use of the same environment variables as the social_core library. # See https://python-social-auth.readthedocs.io/en/latest/configuration/settings.html for more information. - self.client_id = ( - client_id or self._environment_variable_for_property("client_id") - ) or self._backend.strategy.setting("key") + self.client_id = (client_id or self._environment_variable_for_property("client_id")) or self._backend.setting( + "key" + ) self.client_secret = ( client_secret or self._environment_variable_for_property("client_secret") - ) or self._backend.strategy.setting("secret") + ) or self._backend.setting("secret") - self.scope = (scope or self._environment_variable_for_property("scope")) or self.backend_strategy.setting( + self.scope = (scope or self._environment_variable_for_property("scope")) or self._backend.setting( "scope", default=self._backend.get_scope(), - backend=self._backend, ) if isinstance(self.scope, str): self.scope = self.scope.split(" ") From 6e2bba593a5f9863ca2a12aff8d9243b5062fc93 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Thu, 14 Nov 2024 12:22:19 +0100 Subject: [PATCH 12/22] fix: Revert code refactor --- argilla-server/src/argilla_server/contexts/accounts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/argilla-server/src/argilla_server/contexts/accounts.py b/argilla-server/src/argilla_server/contexts/accounts.py index e1ec8a873d..a4890c19b4 100644 --- a/argilla-server/src/argilla_server/contexts/accounts.py +++ b/argilla-server/src/argilla_server/contexts/accounts.py @@ -129,8 +129,8 @@ async def create_user( await WorkspaceUser.create( db, - workspace=workspace, - user=new_user, + workspace_id=workspace.id, + user_id=new_user.id, autocommit=False, ) From 9ba0f00d3222f562cbdd6cc4040a1a8fb8ad7fbd Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 14 Nov 2024 13:19:20 +0100 Subject: [PATCH 13/22] chore: Add last-fm support --- .../argilla_server/security/authentication/oauth2/provider.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py index fcdd07d73c..a7fe717c27 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py @@ -31,6 +31,7 @@ ) from social_core.backends.google import GoogleOAuth2 from social_core.backends.google_openidconnect import GoogleOpenIdConnect +from social_core.backends.lastfm import LastFmAuth from social_core.backends.oauth import BaseOAuth2 from social_core.backends.open_id_connect import OpenIdConnectAuth from social_core.exceptions import AuthException @@ -231,7 +232,7 @@ class HuggingfaceOpenId(OpenIdConnectAuth): GithubEnterpriseTeamOAuth2, GithubOrganizationOAuth2, GithubEnterpriseOrganizationOAuth2, - DiscordOAuth2, + LastFmAuth, ] SUPPORTED_BACKENDS = {backend.name: backend for backend in _BACKENDS} From 7b919d2edafe22e1b82b1e18da601bba0d848722 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 14 Nov 2024 15:43:57 +0100 Subject: [PATCH 14/22] refactor: Moving backend logic to a separate module and load backends by str --- .../authentication/oauth2/_backends.py | 75 +++++++++++++++++++ .../authentication/oauth2/provider.py | 63 +--------------- .../authentication/oauth2/settings.py | 3 +- 3 files changed, 78 insertions(+), 63 deletions(-) create mode 100644 argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py b/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py new file mode 100644 index 0000000000..9a2aa67f9a --- /dev/null +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py @@ -0,0 +1,75 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Type, Dict, Any + +from social_core.backends.oauth import BaseOAuth2 +from social_core.backends.open_id_connect import OpenIdConnectAuth +from social_core.backends.utils import load_backends +from social_core.strategy import BaseStrategy + +from argilla_server.errors.future import NotFoundError + + +class Strategy(BaseStrategy): + def request_data(self, merge=True) -> Dict[str, Any]: + return {} + + def absolute_uri(self, path=None) -> str: + return path + + def get_setting(self, name): + return os.environ[name] + + +class HuggingfaceOpenId(OpenIdConnectAuth): + """Huggingface OpenID Connect authentication backend.""" + + name = "huggingface" + + AUTHORIZATION_URL = "https://huggingface.co/oauth/authorize" + ACCESS_TOKEN_URL = "https://huggingface.co/oauth/token" + + # OIDC configuration + OIDC_ENDPOINT = "https://huggingface.co" + + DEFAULT_SCOPE = ["openid", "profile"] + + +def _load_backends(): + backends = [ + "argilla_server.security.authentication.oauth2._backends.HuggingfaceOpenId", + "social_core.backends.github.GithubOAuth2", + "social_core.backends.github.GithubOrganizationOAuth2", + "social_core.backends.github.GithubTeamOAuth2", + "social_core.backends.github_enterprise.GithubEnterpriseOAuth2", + "social_core.backends.github_enterprise.GithubEnterpriseOrganizationOAuth2", + "social_core.backends.github_enterprise.GithubEnterpriseTeamOAuth2", + "social_core.backends.google.GoogleOAuth2", + "social_core.backends.google_openidconnect.GoogleOpenIdConnect", + ] + return load_backends(backends, force_load=True) + + +_SUPPORTED_BACKENDS = _load_backends() + + +def get_supported_backend_by_name(name: str) -> Type[BaseOAuth2]: + """Get a registered oauth provider by name. Raise a ValueError if provided not found.""" + + if provider := _SUPPORTED_BACKENDS.get(name): + return provider + else: + raise NotFoundError(f"Unsupported provider {name}. Supported providers are {_SUPPORTED_BACKENDS.keys()}") diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py index a7fe717c27..7faa025a53 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py @@ -22,38 +22,17 @@ import httpx from oauthlib.oauth2 import WebApplicationClient -from social_core.backends.discord import DiscordOAuth2 -from social_core.backends.github import GithubOAuth2, GithubOrganizationOAuth2, GithubTeamOAuth2 -from social_core.backends.github_enterprise import ( - GithubEnterpriseOAuth2, - GithubEnterpriseOrganizationOAuth2, - GithubEnterpriseTeamOAuth2, -) -from social_core.backends.google import GoogleOAuth2 -from social_core.backends.google_openidconnect import GoogleOpenIdConnect -from social_core.backends.lastfm import LastFmAuth from social_core.backends.oauth import BaseOAuth2 -from social_core.backends.open_id_connect import OpenIdConnectAuth from social_core.exceptions import AuthException from social_core.strategy import BaseStrategy from starlette.requests import Request from starlette.responses import RedirectResponse, Response from argilla_server.errors import future +from argilla_server.security.authentication.oauth2._backends import Strategy from argilla_server.security.settings import settings -class Strategy(BaseStrategy): - def request_data(self, merge=True) -> Dict[str, Any]: - return {} - - def absolute_uri(self, path=None) -> str: - return path - - def get_setting(self, name): - return os.environ[name] - - class OAuth2ClientProvider: """OAuth2 flow handler of a certain provider.""" @@ -205,43 +184,3 @@ def _environment_variable_for_property(self, property_name: str, default: str = env_var_name = f"OAUTH2_{self.name.upper()}_{property_name.upper()}" return os.getenv(env_var_name, default) - - -class HuggingfaceOpenId(OpenIdConnectAuth): - """Huggingface OpenID Connect authentication backend.""" - - name = "huggingface" - - AUTHORIZATION_URL = "https://huggingface.co/oauth/authorize" - ACCESS_TOKEN_URL = "https://huggingface.co/oauth/token" - - # OIDC configuration - OIDC_ENDPOINT = "https://huggingface.co" - - DEFAULT_SCOPE = ["openid", "profile"] - - -_BACKENDS = [ - HuggingfaceOpenId, - GoogleOAuth2, - GoogleOpenIdConnect, - GithubOAuth2, - GithubEnterpriseOAuth2, - GithubTeamOAuth2, - GithubEnterpriseTeamOAuth2, - GithubEnterpriseTeamOAuth2, - GithubOrganizationOAuth2, - GithubEnterpriseOrganizationOAuth2, - LastFmAuth, -] - -SUPPORTED_BACKENDS = {backend.name: backend for backend in _BACKENDS} - - -def get_supported_backend_by_name(name: str) -> Type[BaseOAuth2]: - """Get a registered oauth provider by name. Raise a ValueError if provided not found.""" - - if provider := SUPPORTED_BACKENDS.get(name): - return provider - else: - raise future.NotFoundError(f"Unsupported provider {name}. Supported providers are {SUPPORTED_BACKENDS.keys()}") diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py index 92e000e3e6..04270c5c92 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py @@ -16,7 +16,8 @@ import yaml -from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider, get_supported_backend_by_name +from argilla_server.security.authentication.oauth2._backends import get_supported_backend_by_name +from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider __all__ = ["OAuth2Settings"] From 4e91069c560b5054395b7615df84799b37006d44 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 14 Nov 2024 15:57:52 +0100 Subject: [PATCH 15/22] refactor: Passing config dict to OAuthSettings constructor --- .../authentication/oauth2/settings.py | 20 +++------ .../tests/unit/api/handlers/v1/test_oauth2.py | 18 ++++---- .../authentication/oauth2/test_settings.py | 42 +++++++++---------- argilla-server/tests/unit/test_app.py | 8 +--- 4 files changed, 34 insertions(+), 54 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py index 04270c5c92..3635f48ec7 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py @@ -42,17 +42,16 @@ class OAuth2Settings: ALLOWED_WORKSPACES_KEY = "allowed_workspaces" PROVIDERS_KEY = "providers" + EXTRA_BACKENDS_KEY = "extra_backends" def __init__( self, allow_http_redirect: bool = False, - providers: List[OAuth2ClientProvider] = None, - allowed_workspaces: List[AllowedWorkspace] = None, - **kwargs, # Ignore any other key + **settings, ): self.allow_http_redirect = allow_http_redirect - self.allowed_workspaces = allowed_workspaces or [] - self._providers = providers or [] + self.allowed_workspaces = self._build_workspaces(settings) or [] + self._providers = self._build_providers(settings) or [] if self.allow_http_redirect: # See https://stackoverflow.com/questions/27785375/testing-flask-oauthlib-locally-without-https @@ -67,16 +66,7 @@ def from_yaml(cls, yaml_file: str) -> "OAuth2Settings": """Creates an instance of OAuth2Settings from a YAML file.""" with open(yaml_file) as f: - return cls.from_dict(yaml.safe_load(f)) - - @classmethod - def from_dict(cls, settings: dict) -> "OAuth2Settings": - """Creates an instance of OAuth2Settings from a dictionary.""" - - settings[cls.PROVIDERS_KEY] = cls._build_providers(settings) - settings[cls.ALLOWED_WORKSPACES_KEY] = cls._build_workspaces(settings) - - return cls(**settings) + return cls(**yaml.safe_load(f)) @classmethod def _build_workspaces(cls, settings: dict) -> List[AllowedWorkspace]: diff --git a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py index 48d5c2b3a5..76bc4c6b4c 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py @@ -30,16 +30,14 @@ @pytest.fixture def default_oauth_settings() -> OAuth2Settings: - return OAuth2Settings.from_dict( - { - "providers": [ - { - "name": "huggingface", - "client_id": "client_id", - "client_secret": "client_secret", - } - ], - } + return OAuth2Settings( + providers=[ + { + "name": "huggingface", + "client_id": "client_id", + "client_secret": "client_secret", + } + ] ) diff --git a/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py b/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py index c0175f273b..b7d9411adb 100644 --- a/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py +++ b/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py @@ -21,20 +21,18 @@ class TestOAuth2Settings: def test_configure_unsupported_provider(self): with pytest.raises(NotFoundError): - OAuth2Settings.from_dict({"providers": [{"name": "unsupported"}]}) + OAuth2Settings(providers=[{"name": "unsupported"}]) def test_configure_github_provider(self): - settings = OAuth2Settings.from_dict( - { - "providers": [ - { - "name": "github", - "client_id": "github_client_id", - "client_secret": "github_client_secret", - "scope": "user:email", - } - ] - } + settings = OAuth2Settings( + providers=[ + { + "name": "github", + "client_id": "github_client_id", + "client_secret": "github_client_secret", + "scope": "user:email", + } + ] ) github_provider = settings.providers["github"] @@ -44,17 +42,15 @@ def test_configure_github_provider(self): assert github_provider.scope == ["user:email"] def test_configure_huggingface_provider(self): - settings = OAuth2Settings.from_dict( - { - "providers": [ - { - "name": "huggingface", - "client_id": "huggingface_client_id", - "client_secret": "huggingface_client_secret", - "scope": "openid profile email", - } - ] - } + settings = OAuth2Settings( + providers=[ + { + "name": "huggingface", + "client_id": "huggingface_client_id", + "client_secret": "huggingface_client_secret", + "scope": "openid profile email", + } + ] ) huggingface_provider = settings.providers["huggingface"] diff --git a/argilla-server/tests/unit/test_app.py b/argilla-server/tests/unit/test_app.py index ac3e42514d..93de167ed6 100644 --- a/argilla-server/tests/unit/test_app.py +++ b/argilla-server/tests/unit/test_app.py @@ -78,11 +78,7 @@ def test_server_timing_header(self): async def test_create_allowed_workspaces(self, db: AsyncSession): with mock.patch( "argilla_server.security.settings.Settings.oauth", - new_callable=lambda: OAuth2Settings.from_dict( - { - "allowed_workspaces": [{"name": "ws1"}, {"name": "ws2"}], - } - ), + new_callable=lambda: OAuth2Settings(allowed_workspaces=[{"name": "ws1"}, {"name": "ws2"}]), ): await _create_oauth_allowed_workspaces(db) @@ -102,7 +98,7 @@ async def test_create_workspaces_with_existing_workspaces(self, db: AsyncSession with mock.patch( "argilla_server.security.settings.Settings.oauth", - new_callable=lambda: OAuth2Settings(allowed_workspaces=[AllowedWorkspace(name=ws.name)]), + new_callable=lambda: OAuth2Settings(allowed_workspaces=[{"name": ws.name}]), ): await _create_oauth_allowed_workspaces(db) From b9f092c151791801964f05a3301331c17ba4ebee Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 14 Nov 2024 16:13:09 +0100 Subject: [PATCH 16/22] feat: Allow add extra backends --- .../authentication/oauth2/_backends.py | 17 +++++++++++--- .../authentication/oauth2/settings.py | 13 ++++++++--- .../authentication/oauth2/test_settings.py | 22 +++++++++++++++++++ 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py b/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py index 9a2aa67f9a..385c11801f 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py @@ -48,7 +48,12 @@ class HuggingfaceOpenId(OpenIdConnectAuth): DEFAULT_SCOPE = ["openid", "profile"] -def _load_backends(): +_SUPPORTED_BACKENDS = {} + + +def load_supported_backends(extra_backends: list = None) -> Dict[str, Type[BaseOAuth2]]: + global _SUPPORTED_BACKENDS + backends = [ "argilla_server.security.authentication.oauth2._backends.HuggingfaceOpenId", "social_core.backends.github.GithubOAuth2", @@ -60,14 +65,20 @@ def _load_backends(): "social_core.backends.google.GoogleOAuth2", "social_core.backends.google_openidconnect.GoogleOpenIdConnect", ] - return load_backends(backends, force_load=True) + if extra_backends: + backends.extend(extra_backends) -_SUPPORTED_BACKENDS = _load_backends() + _SUPPORTED_BACKENDS = load_backends(backends, force_load=True) + return _SUPPORTED_BACKENDS def get_supported_backend_by_name(name: str) -> Type[BaseOAuth2]: """Get a registered oauth provider by name. Raise a ValueError if provided not found.""" + global _SUPPORTED_BACKENDS + + if not _SUPPORTED_BACKENDS: + _SUPPORTED_BACKENDS = load_supported_backends() if provider := _SUPPORTED_BACKENDS.get(name): return provider diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py index 3635f48ec7..61c702ec69 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py @@ -16,7 +16,10 @@ import yaml -from argilla_server.security.authentication.oauth2._backends import get_supported_backend_by_name +from argilla_server.security.authentication.oauth2._backends import ( + get_supported_backend_by_name, + load_supported_backends, +) from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider __all__ = ["OAuth2Settings"] @@ -47,11 +50,13 @@ class OAuth2Settings: def __init__( self, allow_http_redirect: bool = False, + extra_backends: List[str] = None, **settings, ): self.allow_http_redirect = allow_http_redirect + self.extra_backends = extra_backends or [] self.allowed_workspaces = self._build_workspaces(settings) or [] - self._providers = self._build_providers(settings) or [] + self._providers = self._build_providers(settings, extra_backends) or [] if self.allow_http_redirect: # See https://stackoverflow.com/questions/27785375/testing-flask-oauthlib-locally-without-https @@ -74,9 +79,11 @@ def _build_workspaces(cls, settings: dict) -> List[AllowedWorkspace]: return [AllowedWorkspace(**workspace) for workspace in allowed_workspaces] @classmethod - def _build_providers(cls, settings: dict) -> List["OAuth2ClientProvider"]: + def _build_providers(cls, settings: dict, extra_backends) -> List["OAuth2ClientProvider"]: providers = [] + load_supported_backends(extra_backends=extra_backends) + for provider in settings.pop("providers", []): name = provider.pop("name") diff --git a/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py b/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py index b7d9411adb..370424a9af 100644 --- a/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py +++ b/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py @@ -58,3 +58,25 @@ def test_configure_huggingface_provider(self): assert huggingface_provider.client_id == "huggingface_client_id" assert huggingface_provider.client_secret == "huggingface_client_secret" assert huggingface_provider.scope == ["openid", "profile", "email"] + + def test_configure_extra_backends(self): + from social_core.backends.microsoft import MicrosoftOAuth2 + + provider_name = MicrosoftOAuth2.name + settings = OAuth2Settings( + extra_backends=["social_core.backends.microsoft.MicrosoftOAuth2"], + providers=[ + { + "name": provider_name, + "client_id": "microsoft_client_id", + "client_secret": "microsoft_client_secret", + } + ], + ) + + assert len(settings.providers) == 1 + extra_provider = settings.providers[provider_name] + + assert extra_provider.name == provider_name + assert extra_provider.client_id == "microsoft_client_id" + assert extra_provider.client_secret == "microsoft_client_secret" From 054703aeb5c54b00e19228048434a997ec96fddc Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 14 Nov 2024 16:13:55 +0100 Subject: [PATCH 17/22] chore: Configure minimal set of oauth backends --- .../security/authentication/oauth2/_backends.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py b/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py index 385c11801f..019ba02d00 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py @@ -57,13 +57,7 @@ def load_supported_backends(extra_backends: list = None) -> Dict[str, Type[BaseO backends = [ "argilla_server.security.authentication.oauth2._backends.HuggingfaceOpenId", "social_core.backends.github.GithubOAuth2", - "social_core.backends.github.GithubOrganizationOAuth2", - "social_core.backends.github.GithubTeamOAuth2", - "social_core.backends.github_enterprise.GithubEnterpriseOAuth2", - "social_core.backends.github_enterprise.GithubEnterpriseOrganizationOAuth2", - "social_core.backends.github_enterprise.GithubEnterpriseTeamOAuth2", "social_core.backends.google.GoogleOAuth2", - "social_core.backends.google_openidconnect.GoogleOpenIdConnect", ] if extra_backends: From 6c671d98741d5d0fb20cb814e4f2cbd0128b07ea Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 15 Nov 2024 10:30:49 +0100 Subject: [PATCH 18/22] fix: Use backend to determine token auth mode: basic or post --- .../security/authentication/oauth2/provider.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py index 7faa025a53..1a4aaac6f4 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py @@ -162,14 +162,21 @@ def _get_state_cookie_name(self) -> str: async def _fetch_user_data(self, authorization_response: str, **oauth_query_params) -> dict: oauth_client = self.new_oauth_client() + token_request_params = {**oauth_query_params} + + auth = None + if self._backend.use_basic_auth(): + auth = httpx.BasicAuth(self.client_id, self.client_secret) + else: + token_request_params["client_secret"] = self.client_secret + token_url, headers, content = oauth_client.prepare_token_request( self._token_endpoint, authorization_response=authorization_response, - **oauth_query_params, + **token_request_params, ) headers.update({"Accept": "application/json"}) - auth = httpx.BasicAuth(self.client_id, self.client_secret) async with httpx.AsyncClient(auth=auth) as session: try: response = await session.post(token_url, headers=headers, content=content) From 6cfe28493b383e301fb33b88041af786605d70e8 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 15 Nov 2024 11:19:55 +0100 Subject: [PATCH 19/22] chore: Code refactor --- .../security/authentication/oauth2/provider.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py index 1a4aaac6f4..bad389c09a 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py @@ -90,8 +90,12 @@ def authorization_url(self, request: Request) -> Tuple[str, Optional[str]]: redirect_uri = self.get_redirect_uri(request) state = "".join([random.choice(string.ascii_letters) for _ in range(32)]) - oauth2_query_params = dict(state=state, scope=self.scope, redirect_uri=redirect_uri) - oauth2_query_params.update(request.query_params) + oauth2_query_params = { + "state": state, + "scope": self.scope, + "redirect_uri": redirect_uri, + **request.query_params, + } authorization_url = str( self.new_oauth_client().prepare_request_uri(self._authorization_endpoint, **oauth2_query_params) From a1dbc01925dd91716f4996a374649dd28a9868b6 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 18 Nov 2024 12:18:07 +0100 Subject: [PATCH 20/22] feat: Fail when backend is not supported --- .../security/authentication/oauth2/_backends.py | 8 ++++++++ .../security/authentication/oauth2/test_settings.py | 13 +++++++++++++ 2 files changed, 21 insertions(+) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py b/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py index 019ba02d00..a663034ae3 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py @@ -64,6 +64,14 @@ def load_supported_backends(extra_backends: list = None) -> Dict[str, Type[BaseO backends.extend(extra_backends) _SUPPORTED_BACKENDS = load_backends(backends, force_load=True) + + for backend in _SUPPORTED_BACKENDS.values(): + if not issubclass(backend, BaseOAuth2): + raise ValueError( + f"Backend {backend} is not a supported OAuth2 backend. " + "Please, make sure it is a subclass of BaseOAuth2." + ) + return _SUPPORTED_BACKENDS diff --git a/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py b/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py index 370424a9af..d396434706 100644 --- a/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py +++ b/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py @@ -80,3 +80,16 @@ def test_configure_extra_backends(self): assert extra_provider.name == provider_name assert extra_provider.client_id == "microsoft_client_id" assert extra_provider.client_secret == "microsoft_client_secret" + + def test_configure_non_supported_extra_backends(self): + with pytest.raises(ValueError): + OAuth2Settings( + extra_backends=["social_core.backends.twitter.TwitterOAuth"], + providers=[ + { + "name": "github", + "client_id": "github_client_id", + "client_secret": "github_client_secret", + } + ], + ) From a4814828da80c2e5d0bca425b073a2a483f3b898 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Wed, 20 Nov 2024 09:31:45 +0100 Subject: [PATCH 21/22] [ENHANCEMENT] `Argilla frontend`: Add default OAuth button (#5695) # Description This PR adds logic to the sign-in page to include a default OAuth provider button. So, users can sign using whatever provider is defined. Refs https://github.com/argilla-io/argilla/pull/5689 **Type of change** - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Leire Aguirre --- .../components/features/login/OAuthLogin.vue | 28 +++++++++--- .../login/components/HuggingFaceButton.vue | 1 + .../login/components/OAuthLoginButton.vue | 45 +++++++++++++++++++ argilla-frontend/translation/de.js | 1 + argilla-frontend/translation/en.js | 1 + argilla-frontend/translation/es.js | 1 + 6 files changed, 70 insertions(+), 7 deletions(-) create mode 100644 argilla-frontend/components/features/login/components/OAuthLoginButton.vue diff --git a/argilla-frontend/components/features/login/OAuthLogin.vue b/argilla-frontend/components/features/login/OAuthLogin.vue index 9ca0a29b73..3b4053c77e 100644 --- a/argilla-frontend/components/features/login/OAuthLogin.vue +++ b/argilla-frontend/components/features/login/OAuthLogin.vue @@ -1,13 +1,19 @@ @@ -29,6 +35,14 @@ export default { display: flex; flex-direction: column; gap: $base-space * 3; + &__providers { + display: flex; + flex-direction: column; + gap: $base-space; + justify-content: center; + padding: 0; + list-style: none; + } } } diff --git a/argilla-frontend/components/features/login/components/HuggingFaceButton.vue b/argilla-frontend/components/features/login/components/HuggingFaceButton.vue index e50eaa90fd..d825ce9bb6 100644 --- a/argilla-frontend/components/features/login/components/HuggingFaceButton.vue +++ b/argilla-frontend/components/features/login/components/HuggingFaceButton.vue @@ -16,6 +16,7 @@ export default { background: var(--color-black); color: var(--color-white); width: 100%; + min-height: $base-space * 6; padding: calc($base-space / 2) $base-space * 4; justify-content: center; &:hover { diff --git a/argilla-frontend/components/features/login/components/OAuthLoginButton.vue b/argilla-frontend/components/features/login/components/OAuthLoginButton.vue new file mode 100644 index 0000000000..3f906066e9 --- /dev/null +++ b/argilla-frontend/components/features/login/components/OAuthLoginButton.vue @@ -0,0 +1,45 @@ + + + + diff --git a/argilla-frontend/translation/de.js b/argilla-frontend/translation/de.js index 8226fa127d..df848feb10 100644 --- a/argilla-frontend/translation/de.js +++ b/argilla-frontend/translation/de.js @@ -117,6 +117,7 @@ export default { button: { ignore_and_continue: "Ignorieren und fortfahren", login: "Anmelden", + signin_with_provider: "Anmeldung bei {provider} starten", "hf-login": "Mit Hugging Face anmelden", sign_in_with_username: "Mit Benutzername anmelden", cancel: "Abbrechen", diff --git a/argilla-frontend/translation/en.js b/argilla-frontend/translation/en.js index 448064600b..d5fe9957d8 100644 --- a/argilla-frontend/translation/en.js +++ b/argilla-frontend/translation/en.js @@ -115,6 +115,7 @@ export default { button: { ignore_and_continue: "Ignore and continue", login: "Sign in", + signin_with_provider: "Sign in with {provider}", "hf-login": "Sign in with Hugging Face", sign_in_with_username: "Sign in with username", cancel: "Cancel", diff --git a/argilla-frontend/translation/es.js b/argilla-frontend/translation/es.js index 3bc4d8c922..5e1ac2fee9 100644 --- a/argilla-frontend/translation/es.js +++ b/argilla-frontend/translation/es.js @@ -114,6 +114,7 @@ export default { button: { ignore_and_continue: "Ignorar y continuar", login: "Iniciar sesión", + signin_with_provider: "Iniciar sesión con {provider}", "hf-login": "Iniciar sesión con Hugging Face", sign_in_with_username: "Iniciar sesión con usuario", cancel: "Cancelar", From 7f114781215bbc37bfe00a981af96b2b6a5a49c2 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Thu, 21 Nov 2024 09:25:05 +0100 Subject: [PATCH 22/22] [DOCS] OAuth2 configuration (#5694) # Description Docs for changes included in https://github.com/argilla-io/argilla/pull/5689 **Type of change** - Documentation update **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- ...how-to-configure-argilla-on-huggingface.md | 6 +- .../reference/argilla-server/configuration.md | 3 +- .../argilla-server/oauth2_configuration.md | 163 ++++++++++++++++++ argilla/mkdocs.yml | 1 + 4 files changed, 168 insertions(+), 5 deletions(-) create mode 100644 argilla/docs/reference/argilla-server/oauth2_configuration.md diff --git a/argilla/docs/getting_started/how-to-configure-argilla-on-huggingface.md b/argilla/docs/getting_started/how-to-configure-argilla-on-huggingface.md index 72b5c1ae87..9d55019640 100644 --- a/argilla/docs/getting_started/how-to-configure-argilla-on-huggingface.md +++ b/argilla/docs/getting_started/how-to-configure-argilla-on-huggingface.md @@ -51,8 +51,6 @@ To restrict access or change the default behaviour, there's two options: **Modify the `.oauth.yml` configuration file**. You can find and modify this file under the `Files` tab of your Space. The default file looks like this: ```yaml -# Change to `false` to disable HF oauth integration -#enabled: false providers: - name: huggingface @@ -61,10 +59,10 @@ providers: allowed_workspaces: - name: argilla ``` -You can modify two things: +You can: -- Uncomment `enabled: false` to completely disable the Sign in with Hugging Face. If you disable it make sure to set the `USERNAME` and `PASSWORD` Space secrets to be able to login as an `owner`. - Change the list of `allowed` workspaces. +- Rename the `.oauth.yml` file to disable OAuth access. For example if you want to let users join a new workspace `community-initiative`: diff --git a/argilla/docs/reference/argilla-server/configuration.md b/argilla/docs/reference/argilla-server/configuration.md index af61bbde3b..b34f0ea936 100644 --- a/argilla/docs/reference/argilla-server/configuration.md +++ b/argilla/docs/reference/argilla-server/configuration.md @@ -50,9 +50,10 @@ You can set the following environment variables to further configure your server #### Authentication -- `ARGILLA_AUTH_SECRET_KEY`: The secret key used to sign the API token data. You can use `openssl rand -hex 32` to generate a 32 character string to use with this environment variable. By default a random value is generated, so if you are using more than one server worker (or more than one Argilla server) you will need to set the same value for all of them. - `USERNAME`: If provided, the owner username (Default: `None`). - `PASSWORD`: If provided, the owner password (Default: `None`). +- `ARGILLA_AUTH_SECRET_KEY`: The secret key used to sign the API token data. You can use `openssl rand -hex 32` to generate a 32 character string to use with this environment variable. By default a random value is generated, so if you are using more than one server worker (or more than one Argilla server) you will need to set the same value for all of them. +- `ARGILLA_AUTH_OAUTH_CFG`: Path to the OAuth2 configuration file (Default: `$PWD/.oauth.yml`). If `USERNAME` and `PASSWORD` are provided, the owner user will be created with these credentials on the server startup. diff --git a/argilla/docs/reference/argilla-server/oauth2_configuration.md b/argilla/docs/reference/argilla-server/oauth2_configuration.md new file mode 100644 index 0000000000..da14903ed2 --- /dev/null +++ b/argilla/docs/reference/argilla-server/oauth2_configuration.md @@ -0,0 +1,163 @@ +# OAuth2 configuration + +Argilla supports OAuth2 authentication for users. This allows users to authenticate using other services like Google, +GitHub, or Hugging Face. Next sections will guide you through the configuration of the OAuth2 authentication. + +## The OAuth2 configuration file + +The OAuth2 configuration file is a YAML file that contains the configuration for the OAuth2 providers that you want to +enable. The default file name is `.oauth.yml` and it should be placed in the root directory of the Argilla server. You +can also specify a different file name using the `ARGILLA_AUTH_OAUTH_CFG` environment variable. + +The file should have the following structure: + +```yaml +providers: + - name: huggingface + client_id: "" + client_secret: "" + scope: "openid profile" + + - name: google-oauth2 + client_id: "" + client_secret: "" + scope: "openid email profile" + + - name: github + client_id: "" + client_secret: "" + +allowed_workspaces: + - name: argilla + +allow_http_redirect: false +``` + +### Providers + +The `providers` key is a list of dictionaries, each dictionary represents a provider configuration, including the +following fields: + +- `name`: The name of the provider. The available options by default are `huggingface`, `github` and `google-oauth2`. +We will see later how to add more providers not supported by default. +- `client_id`: The client ID provided by the OAuth2 provider. You can get this value by creating an application in the +provider's developer console. This is a required field, but you can also use the +`ARGILLA_OAUTH2__CLIENT_ID` environment variable to set the value. +- `client_secret`: The client secret provided by the OAuth2 provider. You can get this value by creating an application +in the provider's developer console. This is a required field, but you can also use +the `ARGILLA_OAUTH2__CLIENT_SECRET` environment variable to set the value. +- `scope`: The scope of the OAuth2 provider. This is an optional field, and normally you don't need to set it, but +you can use it to request specific permissions from the user access. + +### Allowed Workspaces + +The `allowed_workspaces` key defines the available workspaces when users log in using the OAuth2 provider. This is +a list of `name` fields that should match the workspace name in the Argilla server. By default, the `argilla` workspace +is allowed to authenticate using the OAuth2 provider. + +If the workspace doesn't exist, it will be created automatically on the first server startup. + +### Allow HTTP Redirect + +The `allow_http_redirect` key is a boolean value that allows the OAuth2 provider to redirect the user to an HTTP URL. +By default, this value is set to `false`, and you should set it to `true` only if you are running the Argilla server +behind a proxy that doesn't support HTTPS or if you are running the server locally. + +Enabling this option is not recommended for production environments and should be used only for development purposes. + +## Supported OAuth2 providers configuration + +The following sections will guide you through the configuration of the supported OAuth2 providers. Before diving into +the configuration, you should create an application in the provider's developer console to get the client ID and client +secret. + +A common step when creating an application in the provider's developer console is to set the redirect URI. The +redirect URI is the URL where the OAuth2 provider will redirect the user after the authentication process. + +The redirect URI should be set to the Argilla server URL, followed by `/oauth//callback`. For example, +if the Argilla server is running on `http://localhost:8000`, the redirect URI for provider application should +be `http://localhost:8000/oauth/huggingface/callback`. + +### Hugging Face OAuth2 configuration + +Argilla supports Hugging Face OAuth2 authentication out of the box, and is already configured when running Argilla +on Hugging Face Spaces (See the [Hugging Face Spaces settings](../../getting_started/how-to-configure-argilla-on-huggingface.md) for more information). + +But, if you want to manually configure the Hugging Face OAuth2 provider, you should define the following +fields in the `.oauth.yml` file: + +```yaml + +providers: + - name: huggingface + client_id: "" # You can use the ARGILLA_OAUTH2_HUGGINGFACE_CLIENT_ID environment variable + client_secret: "" # You can use the ARGILLA_OAUTH2_HUGGINGFACE_CLIENT_SECRET environment variable + scope: "openid profile" # This field is optional. But this value must be aligned your OAuth2 application created in Hugging Face. + +... +``` + +To get your client ID and client secret, you need to create an [OAuth2 application](https://huggingface.co/settings/applications/new) in the Hugging Face +settings page. + +The minimal scope required for the Hugging Face OAuth2 provider is `openid profile`, so you don't need to +change the `scope` when creating the application. + +### GitHub OAuth2 configuration + +Argilla also supports GitHub OAuth2 authentication out of the box. To configure the GitHub OAuth2 provider, you should +define the following fields in the `.oauth.yml` file: + +```yaml + +providers: + - name: github + client_id: "" # You can use the ARGILLA_OAUTH2_GITHUB_CLIENT_ID environment variable + client_secret: "" # You can use the ARGILLA_OAUTH2_GITHUB_CLIENT_SECRET environment variable + +... +``` + +To get your client ID and client secret, you need to register a new [OAuth application](https://github.com/settings/applications/new) in the GitHub settings page. + +### Google OAuth2 configuration + +Argilla also supports Google OAuth2 authentication out of the box. To configure the Google OAuth2 provider, you +should define the following fields in the `.oauth.yml` file: + +```yaml + +providers: + - name: google-oauth2 + client_id: "" # You can use the ARGILLA_OAUTH2_GOOGLE_OAUTH2_CLIENT_ID environment variable + client_secret: "" # You can use the ARGILLA_OAUTH2_GOOGLE_OAUTH2_CLIENT_SECRET environment variable + +... +``` + +To get your client ID and client secret, you need to create a new [OAuth2 client](https://console.cloud.google.com/apis/credentials/oauthclient) in the Google Cloud Console. + +### Adding more OAuth2 providers + +If you want to add more OAuth2 providers that are not supported by default, you can do so by adding a new provider +configuration to the `.oauth.yml` file. The Argilla server uses the [Social Auth backends](https://python-social-auth.readthedocs.io/en/latest/backends/index.html) component to define +the provider configuration. You only need to register the provider backend using the `extra_backends` key in +the `.oauth.yml` file. + +For example, to configure the [Apple OAuth2 provider](https://python-social-auth.readthedocs.io/en/latest/backends/apple.html), you should add the following configuration to +the `.oauth.yml` file: + +```yaml + +providers: + - name: apple-id + client_id: "" # You can use the ARGILLA_OAUTH2_APPLE_ID_CLIENT_ID environment variable + client_secret: "" # You can use the ARGILLA_OAUTH2_APPLE_ID_CLIENT_SECRET environment variable + +extra_backends: + - social_core.backends.apple.AppleIdAuth # Register the Apple OAuth2 provider backend + +``` + +All the `SOCIAL_AUTH_*` environment variables are supported by the Argilla server, so you can customize the provider +configuration using these environment variables. diff --git a/argilla/mkdocs.yml b/argilla/mkdocs.yml index 645b72f514..bba5cc39d8 100644 --- a/argilla/mkdocs.yml +++ b/argilla/mkdocs.yml @@ -187,6 +187,7 @@ nav: - Python SDK: reference/argilla/ - FastAPI Server: - Server configuration: reference/argilla-server/configuration.md + - OAuth2 configuration: reference/argilla-server/oauth2_configuration.md - Telemetry: - Server Telemetry: reference/argilla-server/telemetry.md - Community: