Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][ENHANCEMENT] argilla server: better OAuth2 integration #5689

Open
wants to merge 31 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
793beae
chore: Remove oauth.enabled attribute
frascuchon Nov 13, 2024
6cd157e
fix: Prefix backend name to state cookie
frascuchon Nov 13, 2024
c81d386
chore: Move provider to provider.py
frascuchon Nov 13, 2024
62915e0
refactor: Add user create validator and remove local schema for oauth…
frascuchon Nov 13, 2024
1e4bda5
chore: Remove Claims
frascuchon Nov 13, 2024
2fbd716
chore: update tests
frascuchon Nov 13, 2024
4d69126
chore: Add more supported backends
frascuchon Nov 13, 2024
ac23249
chore: Add discord oauth
frascuchon Nov 13, 2024
08d3f88
fix: Read argilla env vars first
frascuchon Nov 14, 2024
8773c08
chore: Add default scope for huggingface backend
frascuchon Nov 14, 2024
125889b
fix: Using backend.settings
frascuchon Nov 14, 2024
6e2bba5
fix: Revert code refactor
frascuchon Nov 14, 2024
9ba0f00
chore: Add last-fm support
frascuchon Nov 14, 2024
1f346c7
Merge branch 'refactor/argilla-server/better-oauth2-integration' of g…
frascuchon Nov 14, 2024
7b919d2
refactor: Moving backend logic to a separate module and load backends…
frascuchon Nov 14, 2024
4e91069
refactor: Passing config dict to OAuthSettings constructor
frascuchon Nov 14, 2024
b9f092c
feat: Allow add extra backends
frascuchon Nov 14, 2024
054703a
chore: Configure minimal set of oauth backends
frascuchon Nov 14, 2024
6c671d9
fix: Use backend to determine token auth mode: basic or post
frascuchon Nov 15, 2024
6cfe284
chore: Code refactor
frascuchon Nov 15, 2024
a1dbc01
feat: Fail when backend is not supported
frascuchon Nov 18, 2024
8700b81
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 18, 2024
207e361
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 19, 2024
a481482
[ENHANCEMENT] `Argilla frontend`: Add default OAuth button (#5695)
frascuchon Nov 20, 2024
9a2f374
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 20, 2024
7f11478
[DOCS] OAuth2 configuration (#5694)
frascuchon Nov 21, 2024
d8ec3e7
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 21, 2024
3bde023
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 21, 2024
0bea71e
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 21, 2024
1c5089c
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Nov 29, 2024
b3e2cbe
Merge branch 'develop' into refactor/argilla-server/better-oauth2-int…
frascuchon Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions argilla-frontend/components/features/login/OAuthLogin.vue
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
<template>
<div class="oauth__container" v-if="providers.length">
<BaseSeparator />

<div v-for="provider in providers" :key="provider.name">
<HuggingFaceButton
v-if="provider.isHuggingFace"
@click="authorize(provider.name)"
/>
</div>
<ul class="oauth__container__providers">
<li v-for="provider in providers" :key="provider.name">
<HuggingFaceButton
v-if="provider.isHuggingFace"
@click="authorize(provider.name)"
/>
<OAuthLoginButton
v-else
:provider="provider.name"
@click="authorize(provider.name)"
/>
</li>
</ul>
</div>
</template>

Expand All @@ -29,6 +35,14 @@ export default {
display: flex;
flex-direction: column;
gap: $base-space * 3;
&__providers {
display: flex;
flex-direction: column;
gap: $base-space;
justify-content: center;
padding: 0;
list-style: none;
}
}
}
</style>
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export default {
background: var(--color-black);
color: var(--color-white);
width: 100%;
min-height: $base-space * 6;
padding: calc($base-space / 2) $base-space * 4;
justify-content: center;
&:hover {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
<template>
<BaseButton class="sign-in-button" @click="$emit('click')">
{{ signinText }}
</BaseButton>
</template>
<script>
export default {
name: "OAuthLoginButton",
props: {
provider: {
type: String,
required: true,
},
},
computed: {
providerName() {
return this.provider.charAt(0).toUpperCase() + this.provider.slice(1);
},
signinText() {
return this.$t("button.signin_with_provider", {
provider: this.providerName,
});
},
},
};
</script>

<style lang="scss" scoped>
.sign-in-button {
@extend %button !optional;
background: var(--color-black);
color: var(--color-white);
width: 100%;
min-height: $base-space * 6;
padding: calc($base-space / 2) $base-space * 4;
justify-content: center;
&:hover {
background: hsl(from var(--color-black) h s l / 80%);
}
svg {
width: 30px;
height: auto;
}
}
</style>
1 change: 1 addition & 0 deletions argilla-frontend/translation/de.js
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ export default {
button: {
ignore_and_continue: "Ignorieren und fortfahren",
login: "Anmelden",
signin_with_provider: "Anmeldung bei {provider} starten",
"hf-login": "Mit Hugging Face anmelden",
sign_in_with_username: "Mit Benutzername anmelden",
cancel: "Abbrechen",
Expand Down
1 change: 1 addition & 0 deletions argilla-frontend/translation/en.js
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ export default {
button: {
ignore_and_continue: "Ignore and continue",
login: "Sign in",
signin_with_provider: "Sign in with {provider}",
"hf-login": "Sign in with Hugging Face",
sign_in_with_username: "Sign in with username",
cancel: "Cancel",
Expand Down
1 change: 1 addition & 0 deletions argilla-frontend/translation/es.js
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ export default {
button: {
ignore_and_continue: "Ignorar y continuar",
login: "Iniciar sesión",
signin_with_provider: "Iniciar sesión con {provider}",
"hf-login": "Iniciar sesión con Hugging Face",
sign_in_with_username: "Iniciar sesión con nombre de usuario",
cancel: "Cancelar",
Expand Down
3 changes: 0 additions & 3 deletions argilla-server/src/argilla_server/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,6 @@ def _show_telemetry_warning():
async def _create_oauth_allowed_workspaces(db: AsyncSession):
from argilla_server.security.settings import settings as security_settings

if not security_settings.oauth.enabled:
return

for allowed_workspace in security_settings.oauth.allowed_workspaces:
if await Workspace.get_by(db, name=allowed_workspace.name) is None:
_LOGGER.info(f"Creating workspace with name {allowed_workspace.name!r}")
Expand Down
40 changes: 11 additions & 29 deletions argilla-server/src/argilla_server/api/handlers/v1/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,55 +11,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from fastapi import APIRouter, Depends, Request, Path
from fastapi.responses import RedirectResponse
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.api.schemas.v1.oauth2 import Provider, Providers, Token
from argilla_server.api.schemas.v1.users import UserCreate
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.enums import UserRole
from argilla_server.errors.future import NotFoundError
from argilla_server.models import User
from pydantic import Field
from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider
from argilla_server.security.authentication.userinfo import UserInfo
from argilla_server.security.settings import settings

router = APIRouter(prefix="/oauth2", tags=["Authentication"])


class UserOAuthCreate(UserCreate):
"""This schema is used to validate the creation of a new user by using the oauth userinfo"""

username: str = Field(min_length=1)
role: Optional[UserRole]
password: Optional[str] = None


def get_provider_by_name_or_raise(provider: str = Path()) -> OAuth2ClientProvider:
if not settings.oauth.enabled:
raise NotFoundError(message="OAuth2 is not enabled")

if provider in settings.oauth.providers:
try:
return settings.oauth.providers[provider]

raise NotFoundError(message=f"OAuth Provider '{provider}' not found")
except KeyError:
raise NotFoundError(message=f"OAuth Provider '{provider}' not found")


@router.get("/providers", response_model=Providers)
def list_providers() -> Providers:
if not settings.oauth.enabled:
return Providers(items=[])

return Providers(items=[Provider(name=provider_name) for provider_name in settings.oauth.providers])
providers = [Provider(name=provider_name) for provider_name in settings.oauth.providers]
return Providers(items=providers)


@router.get("/providers/{provider}/authentication")
def get_authentication(
async def get_authentication(
request: Request,
provider: OAuth2ClientProvider = Depends(get_provider_by_name_or_raise),
) -> RedirectResponse:
Expand All @@ -72,7 +55,8 @@ async def get_access_token(
provider: OAuth2ClientProvider = Depends(get_provider_by_name_or_raise),
db: AsyncSession = Depends(get_async_db),
) -> Token:
userinfo = UserInfo(await provider.get_user_data(request)).use_claims(provider.claims)
user_data = await provider.get_user_data(request)
userinfo = UserInfo(user_data)

if not userinfo.username:
raise RuntimeError("OAuth error: Missing username")
Expand All @@ -81,11 +65,9 @@ async def get_access_token(
if user is None:
user = await accounts.create_user_with_random_password(
db,
**UserOAuthCreate(
username=userinfo.username,
first_name=userinfo.first_name,
role=userinfo.role,
).model_dump(exclude_unset=True),
username=userinfo.username,
first_name=userinfo.first_name,
role=userinfo.role,
workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces],
)

Expand Down
25 changes: 13 additions & 12 deletions argilla-server/src/argilla_server/contexts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from argilla_server.models import User, Workspace, WorkspaceUser
from argilla_server.security.authentication.jwt import JWT
from argilla_server.security.authentication.userinfo import UserInfo
from argilla_server.validators.users import UserCreateValidator


async def create_workspace_user(db: AsyncSession, workspace_user_attrs: dict) -> WorkspaceUser:
Expand All @@ -52,7 +53,7 @@ async def list_workspaces(db: AsyncSession) -> List[Workspace]:
return result.scalars().all()


async def list_workspaces_by_user_id(db: AsyncSession, user_id: UUID) -> List[Workspace]:
async def list_workspaces_by_user_id(db: AsyncSession, user_id: UUID) -> Sequence[Workspace]:
result = await db.execute(
select(Workspace)
.join(WorkspaceUser)
Expand Down Expand Up @@ -102,22 +103,22 @@ async def list_users_by_ids(db: AsyncSession, ids: Iterable[UUID]) -> Sequence[U
return result.scalars().all()


# TODO: After removing API v0 implementation we can remove the workspaces attribute.
# With API v1 the workspaces will be created doing additional requests to other endpoints for it.
async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List[str], None] = None) -> User:
if await get_user_by_username(db, user_attrs["username"]) is not None:
raise NotUniqueError(f"User username `{user_attrs['username']}` is not unique")

user = await User.create(
db,
async def create_user(
db: AsyncSession,
user_attrs: dict,
workspaces: Union[List[str], None] = None,
) -> User:
new_user = User(
first_name=user_attrs["first_name"],
last_name=user_attrs["last_name"],
username=user_attrs["username"],
role=user_attrs["role"],
password_hash=hash_password(user_attrs["password"]),
autocommit=False,
)

await UserCreateValidator.validate(db, user=new_user)

await new_user.save(db, autocommit=False)
if workspaces is not None:
for workspace_name in workspaces:
workspace = await Workspace.get_by(db, name=workspace_name)
Expand All @@ -127,13 +128,13 @@ async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List
await WorkspaceUser.create(
db,
workspace_id=workspace.id,
user_id=user.id,
user_id=new_user.id,
autocommit=False,
)

await db.commit()

return user
return new_user


async def create_user_with_random_password(
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla_server.security.authentication.oauth2.providers import OAuth2ClientProvider # noqa
from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider # noqa
from argilla_server.security.authentication.oauth2.settings import OAuth2Settings # noqa

__all__ = ["OAuth2Settings", "OAuth2ClientProvider"]
Loading
Loading