Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][ENHANCEMENT] argilla server: better OAuth2 integration #5689

Open
wants to merge 31 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
793beae
chore: Remove oauth.enabled attribute
frascuchon Nov 13, 2024
6cd157e
fix: Prefix backend name to state cookie
frascuchon Nov 13, 2024
c81d386
chore: Move provider to provider.py
frascuchon Nov 13, 2024
62915e0
refactor: Add user create validator and remove local schema for oauth…
frascuchon Nov 13, 2024
1e4bda5
chore: Remove Claims
frascuchon Nov 13, 2024
2fbd716
chore: update tests
frascuchon Nov 13, 2024
4d69126
chore: Add more supported backends
frascuchon Nov 13, 2024
ac23249
chore: Add discord oauth
frascuchon Nov 13, 2024
08d3f88
fix: Read argilla env vars first
frascuchon Nov 14, 2024
8773c08
chore: Add default scope for huggingface backend
frascuchon Nov 14, 2024
125889b
fix: Using backend.settings
frascuchon Nov 14, 2024
6e2bba5
fix: Revert code refactor
frascuchon Nov 14, 2024
9ba0f00
chore: Add last-fm support
frascuchon Nov 14, 2024
1f346c7
Merge branch 'refactor/argilla-server/better-oauth2-integration' of g…
frascuchon Nov 14, 2024
7b919d2
refactor: Moving backend logic to a separate module and load backends…
frascuchon Nov 14, 2024
4e91069
refactor: Passing config dict to OAuthSettings constructor
frascuchon Nov 14, 2024
b9f092c
feat: Allow add extra backends
frascuchon Nov 14, 2024
054703a
chore: Configure minimal set of oauth backends
frascuchon Nov 14, 2024
6c671d9
fix: Use backend to determine token auth mode: basic or post
frascuchon Nov 15, 2024
6cfe284
chore: Code refactor
frascuchon Nov 15, 2024
a1dbc01
feat: Fail when backend is not supported
frascuchon Nov 18, 2024
8700b81
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 18, 2024
207e361
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 19, 2024
a481482
[ENHANCEMENT] `Argilla frontend`: Add default OAuth button (#5695)
frascuchon Nov 20, 2024
9a2f374
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 20, 2024
7f11478
[DOCS] OAuth2 configuration (#5694)
frascuchon Nov 21, 2024
d8ec3e7
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 21, 2024
3bde023
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 21, 2024
0bea71e
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 21, 2024
1c5089c
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 29, 2024
b3e2cbe
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions argilla-server/src/argilla_server/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,6 @@ def _show_telemetry_warning():
async def _create_oauth_allowed_workspaces(db: AsyncSession):
from argilla_server.security.settings import settings as security_settings

if not security_settings.oauth.enabled:
return

for allowed_workspace in security_settings.oauth.allowed_workspaces:
if await Workspace.get_by(db, name=allowed_workspace.name) is None:
_LOGGER.info(f"Creating workspace with name {allowed_workspace.name!r}")
Expand Down
41 changes: 11 additions & 30 deletions argilla-server/src/argilla_server/api/handlers/v1/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,56 +11,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from fastapi import APIRouter, Depends, Request, Path
from fastapi.responses import RedirectResponse
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server import telemetry
from argilla_server.api.schemas.v1.oauth2 import Provider, Providers, Token
from argilla_server.api.schemas.v1.users import UserCreate
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.enums import UserRole
from argilla_server.errors.future import NotFoundError
from argilla_server.models import User
from argilla_server.pydantic_v1 import Field
from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider
from argilla_server.security.authentication.userinfo import UserInfo
from argilla_server.security.settings import settings

router = APIRouter(prefix="/oauth2", tags=["Authentication"])


class UserOAuthCreate(UserCreate):
"""This schema is used to validate the creation of a new user by using the oauth userinfo"""

username: str = Field(min_length=1)
role: Optional[UserRole]
password: Optional[str] = None


def get_provider_by_name_or_raise(provider: str = Path()) -> OAuth2ClientProvider:
if not settings.oauth.enabled:
raise NotFoundError(message="OAuth2 is not enabled")

if provider in settings.oauth.providers:
try:
return settings.oauth.providers[provider]

raise NotFoundError(message=f"OAuth Provider '{provider}' not found")
except KeyError:
raise NotFoundError(message=f"OAuth Provider '{provider}' not found")


@router.get("/providers", response_model=Providers)
def list_providers() -> Providers:
if not settings.oauth.enabled:
return Providers(items=[])

return Providers(items=[Provider(name=provider_name) for provider_name in settings.oauth.providers])
providers = [Provider(name=provider_name) for provider_name in settings.oauth.providers]
return Providers(items=providers)


@router.get("/providers/{provider}/authentication")
def get_authentication(
async def get_authentication(
request: Request,
provider: OAuth2ClientProvider = Depends(get_provider_by_name_or_raise),
) -> RedirectResponse:
Expand All @@ -73,7 +55,8 @@ async def get_access_token(
provider: OAuth2ClientProvider = Depends(get_provider_by_name_or_raise),
db: AsyncSession = Depends(get_async_db),
) -> Token:
userinfo = UserInfo(await provider.get_user_data(request)).use_claims(provider.claims)
user_data = await provider.get_user_data(request)
userinfo = UserInfo(user_data)

if not userinfo.username:
raise RuntimeError("OAuth error: Missing username")
Expand All @@ -82,11 +65,9 @@ async def get_access_token(
if user is None:
user = await accounts.create_user_with_random_password(
db,
**UserOAuthCreate(
username=userinfo.username,
first_name=userinfo.first_name,
role=userinfo.role,
).dict(exclude_unset=True),
username=userinfo.username,
first_name=userinfo.first_name,
role=userinfo.role,
workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces],
)

Expand Down
27 changes: 14 additions & 13 deletions argilla-server/src/argilla_server/contexts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from argilla_server.models import User, Workspace, WorkspaceUser
from argilla_server.security.authentication.jwt import JWT
from argilla_server.security.authentication.userinfo import UserInfo
from argilla_server.validators.users import UserCreateValidator

_CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")

Expand Down Expand Up @@ -54,7 +55,7 @@ async def list_workspaces(db: AsyncSession) -> List[Workspace]:
return result.scalars().all()


async def list_workspaces_by_user_id(db: AsyncSession, user_id: UUID) -> List[Workspace]:
async def list_workspaces_by_user_id(db: AsyncSession, user_id: UUID) -> Sequence[Workspace]:
result = await db.execute(
select(Workspace)
.join(WorkspaceUser)
Expand Down Expand Up @@ -104,22 +105,22 @@ async def list_users_by_ids(db: AsyncSession, ids: Iterable[UUID]) -> Sequence[U
return result.scalars().all()


# TODO: After removing API v0 implementation we can remove the workspaces attribute.
# With API v1 the workspaces will be created doing additional requests to other endpoints for it.
async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List[str], None] = None) -> User:
if await get_user_by_username(db, user_attrs["username"]) is not None:
raise NotUniqueError(f"User username `{user_attrs['username']}` is not unique")

user = await User.create(
db,
async def create_user(
db: AsyncSession,
user_attrs: dict,
workspaces: Union[List[str], None] = None,
) -> User:
new_user = User(
first_name=user_attrs["first_name"],
last_name=user_attrs["last_name"],
username=user_attrs["username"],
role=user_attrs["role"],
password_hash=hash_password(user_attrs["password"]),
autocommit=False,
)

await UserCreateValidator.validate(db, user=new_user)

await new_user.save(db, autocommit=False)
if workspaces is not None:
for workspace_name in workspaces:
workspace = await Workspace.get_by(db, name=workspace_name)
Expand All @@ -128,14 +129,14 @@ async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List

await WorkspaceUser.create(
db,
workspace_id=workspace.id,
user_id=user.id,
workspace=workspace,
user=new_user,
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
autocommit=False,
)

await db.commit()

return user
return new_user


async def create_user_with_random_password(
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla_server.security.authentication.oauth2.providers import OAuth2ClientProvider # noqa
from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider # noqa
from argilla_server.security.authentication.oauth2.settings import OAuth2Settings # noqa

__all__ = ["OAuth2Settings", "OAuth2ClientProvider"]
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser

from argilla_server.security.authentication.jwt import JWT
from argilla_server.security.authentication.oauth2.providers import OAuth2ClientProvider
from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider

Check warning on line 22 in argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py#L22

Added line #L22 was not covered by tests
from argilla_server.security.authentication.userinfo import UserInfo


Expand All @@ -39,7 +39,4 @@
token_data = JWT.decode(credentials.credentials)
user = UserInfo(token_data)

provider = self.providers.get(user.get("provider"))
claims = provider.claims if provider else {}

return AuthCredentials(user.pop("scope", [])), user.use_claims(claims)
return AuthCredentials(user.pop("scope", [])), user

Check warning on line 42 in argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py#L42

Added line #L42 was not covered by tests
Loading