Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adding support for login with OpenIDC (#50) #87

Merged
merged 1 commit into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions backend/alembic/versions/c8009ed33089_init_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
branch_labels = None
depends_on = None

#: Long tokens -- 64kbytes should be enough for everyone
TOKEN_SIZE = 64 * 1024


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
Expand All @@ -35,9 +38,9 @@ def upgrade():
sa.Column("id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False),
sa.Column("user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False),
sa.Column("oauth_name", sa.String(length=100), nullable=False),
sa.Column("access_token", sa.String(length=1024), nullable=False),
sa.Column("access_token", sa.String(length=TOKEN_SIZE), nullable=False),
sa.Column("expires_at", sa.Integer(), nullable=True),
sa.Column("refresh_token", sa.String(length=1024), nullable=True),
sa.Column("refresh_token", sa.String(length=TOKEN_SIZE), nullable=True),
sa.Column("account_id", sa.String(length=320), nullable=False),
sa.Column("account_email", sa.String(length=320), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="cascade"),
Expand Down
26 changes: 24 additions & 2 deletions backend/app/api/api_v1/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from fastapi import APIRouter
from httpx_oauth.clients.openid import OpenID

from app.api.api_v1.endpoints import adminmsgs
from app.api.api_v1.endpoints import adminmsgs, auth
from app.core.auth import auth_backend_bearer, auth_backend_cookie, fastapi_users
from app.schemas.user import UserCreate, UserRead, UserUpdate
from app.core.config import settings
from app.schemas.user import UserRead, UserUpdate

api_router = APIRouter()
api_router.include_router(adminmsgs.router, prefix="/adminmsgs", tags=["adminmsgs"])
Expand All @@ -13,6 +15,7 @@
api_router.include_router(
fastapi_users.get_auth_router(auth_backend_cookie), prefix="/auth/cookie", tags=["auth"]
)
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
# api_router.include_router(
# fastapi_users.get_register_router(UserRead, UserCreate),
# prefix="/auth",
Expand All @@ -33,3 +36,22 @@
prefix="/users",
tags=["users"],
)

# For now, we only provide oauth clients for cookie-based authentication.
for config in settings.OAUTH2_PROVIDERS:
oauth_client = OpenID(
client_id=config.client_id,
client_secret=config.client_secret,
openid_configuration_endpoint=str(config.config_url),
)
oauth_router = fastapi_users.get_oauth_router(
oauth_client=oauth_client,
backend=auth_backend_cookie,
state_secret=settings.SECRET_KEY,
associate_by_email=True,
)
api_router.include_router(
oauth_router,
prefix=f"/auth/external/cookie/{config.name}",
tags=["auth"],
)
18 changes: 18 additions & 0 deletions backend/app/api/api_v1/endpoints/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from app import crud, models, schemas
from app.api import deps
from app.core import config

router = APIRouter()


@router.get("/oauth2-providers", response_model=list[schemas.OAuth2ProviderPublic])
async def list_oauth2_providers() -> list[schemas.OAuth2ProviderPublic]:
"""Retrieve all admin messages"""
providers = [
schemas.OAuth2ProviderPublic.model_validate(obj.model_dump())
for obj in config.settings.OAUTH2_PROVIDERS
]
return providers
14 changes: 12 additions & 2 deletions backend/app/core/auth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import uuid
from typing import Any

import redis.asyncio
from fastapi import Depends, Request
from fastapi import Depends, Request, Response
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin
from fastapi_users.authentication import (
AuthenticationBackend,
Expand Down Expand Up @@ -43,7 +44,16 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db

bearer_transport = BearerTransport(tokenUrl=f"{settings.API_V1_STR}/auth/login")

cookie_transport = CookieTransport(cookie_max_age=settings.SESSION_EXPIRE_MINUTES * 60)

class CookieRedirectTransport(CookieTransport):
async def get_login_response(self, token: str) -> Response:
response = await super().get_login_response(token)
response.status_code = 302
response.headers["Location"] = "/profile"
return response


cookie_transport = CookieRedirectTransport(cookie_max_age=settings.SESSION_EXPIRE_MINUTES * 60)

redis_obj = redis.asyncio.from_url(settings.REDIS_URL, decode_responses=True)

Expand Down
7 changes: 6 additions & 1 deletion backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import secrets
from typing import Any

from pydantic import AnyHttpUrl, EmailStr, HttpUrl, PostgresDsn, field_validator
from pydantic import AnyHttpUrl, BaseModel, EmailStr, HttpUrl, PostgresDsn, field_validator
from pydantic_core.core_schema import ValidationInfo
from pydantic_settings import BaseSettings, SettingsConfigDict

from app.schemas import OAuth2ProviderConfig


class Settings(BaseSettings):
model_config = SettingsConfigDict(
Expand Down Expand Up @@ -95,6 +97,9 @@ def assemble_cors_origins(cls, v: str | list[str]) -> list[str] | str: # pragma
#: Email of test users, ignored.
EMAIL_TEST_USER: EmailStr = "[email protected]" # type: ignore

#: OAuth2 providers
OAUTH2_PROVIDERS: list[OAuth2ProviderConfig] = []

# -- Database Configuration ----------------------------------------------

# Note that when os.environ["CI"] is "true" then we will use an in-memory
Expand Down
3 changes: 3 additions & 0 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from app.core.config import settings
from app.db.init_db import create_superuser

if settings.DEBUG:
logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger(__name__)

app = FastAPI(
Expand Down
16 changes: 13 additions & 3 deletions backend/app/models/user.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
from typing import List
from typing import TYPE_CHECKING, List, Optional

from fastapi_users_db_sqlalchemy import (
SQLAlchemyBaseOAuthAccountTableUUID,
SQLAlchemyBaseUserTableUUID,
)
from sqlalchemy.orm import Mapped, relationship
from sqlalchemy import Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.db.base import Base

#: Long tokens -- 64kbytes should be enough for everyone
TOKEN_SIZE = 64 * 1024


class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
pass
if TYPE_CHECKING: # pragma: no cover
access_token: str
refresh_token: Optional[str]
else:
# We need to increase the token size for the OAuthAccount table.
access_token: Mapped[str] = mapped_column(String(TOKEN_SIZE), nullable=False)
refresh_token: Mapped[Optional[str]] = mapped_column(String(TOKEN_SIZE), nullable=True)


class User(SQLAlchemyBaseUserTableUUID, Base):
Expand Down
1 change: 1 addition & 0 deletions backend/app/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from app.schemas.adminmsg import AdminMessageCreate, AdminMessageRead, AdminMessageUpdate # noqa
from app.schemas.auth import OAuth2ProviderConfig, OAuth2ProviderPublic # noqa
from app.schemas.user import UserCreate, UserRead, UserUpdate # noqa
25 changes: 25 additions & 0 deletions backend/app/schemas/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pydantic import BaseModel, HttpUrl


class OAuth2ProviderBase(BaseModel):
"""Base class for OAuth2 providers infos."""

#: Name of the identity provider.
name: str
#: Label to display to users
label: str


class OAuth2ProviderPublic(OAuth2ProviderBase):
"""Information exposed via API."""


class OAuth2ProviderConfig(OAuth2ProviderBase):
"""OAuth2 provider configuration with client secrets."""

#: Configuration URL of the provider.
config_url: HttpUrl
#: Client ID to use.
client_id: str
#: Client secret to use.
client_secret: str
5 changes: 4 additions & 1 deletion backend/env.dev
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Application configuration
SERVER_NAME=localhost
SERVER_HOST=http://localhost:8080
SERVER_HOST=http://localhost:8081
BACKEND_CORS_ORIGINS=["http://localhost:8081"]
DEBUG=1

Expand All @@ -17,6 +17,9 @@ BACKEND_PREFIX_MEHARI=http://localhost:3002
BACKEND_PREFIX_VIGUNO=http://localhost:3003
BACKEND_PREFIX_NGINX=http://localhost:3004

# Access to redis as it runs Docker Compose.
REDIS_URL=redis://localhost:3030

# Superuser to setup on startup
[email protected]
FIRST_SUPERUSER_PASSWORD=password
27 changes: 27 additions & 0 deletions frontend/src/api/auth.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import { API_V1_BASE_PREFIX } from '@/api/common'

export interface OAuth2Provider {
name: string
label: string
url: string
}

interface OAuth2LoginUrlResponse {
authorization_url: string
}

/** Access to the authentication-related part of the API.
*/
export class AuthClient {
Expand Down Expand Up @@ -37,4 +47,21 @@ export class AuthClient {
})
return await response.text()
}

async fetchOAuth2Providers(): Promise<OAuth2Provider[]> {
const response = await fetch(`${this.apiBaseUrl}auth/oauth2-providers`, {
method: 'GET'
})
return await response.json()
}

async fetchOAuth2LoginUrl(provider: OAuth2Provider, redirectTo?: string | null): Promise<string> {
let url = `${this.apiBaseUrl}/auth/external/cookie/${provider.name}/authorize`
if (redirectTo) {
url += `?redirect_to=${encodeURIComponent(redirectTo)}`
}
const response = await fetch(url, { method: 'GET' })
const response_json: OAuth2LoginUrlResponse = await response.json()
return response_json.authorization_url
}
}
12 changes: 11 additions & 1 deletion frontend/src/stores/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import { defineStore } from 'pinia'
import { computed, ref } from 'vue'

import { AuthClient, type OAuth2Provider } from '@/api/auth'
import { UnauthenticatedError, UsersClient } from '@/api/users'
import { StoreState } from '@/stores/misc'

Expand All @@ -19,6 +20,9 @@ export const useUserStore = defineStore('user', () => {
/* The current store state. */
const storeState = ref<StoreState>(StoreState.Initial)

/* The available OAuth2 providers. */
const oauth2Providers = ref<OAuth2Provider[]>([])

/* The current user, null if none, undefined if not set already */
const currentUser = ref<UserData | null | undefined>(undefined)

Expand All @@ -33,7 +37,12 @@ export const useUserStore = defineStore('user', () => {
return // do not initialize twice
}

await loadCurrentUser()
await Promise.all([loadOAuth2Endpoints(), loadCurrentUser()])
}

const loadOAuth2Endpoints = async () => {
const client = new AuthClient()
oauth2Providers.value = await client.fetchOAuth2Providers()
}

const loadCurrentUser = async () => {
Expand All @@ -57,6 +66,7 @@ export const useUserStore = defineStore('user', () => {

return {
storeState,
oauth2Providers,
currentUser,
isAuthenticated,
loadCurrentUser,
Expand Down
Loading