From fe8561da948dfa6f93da6c594b3a558a802fbdfc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 16 May 2024 22:19:45 +0200 Subject: [PATCH 01/10] add session based auth workflow --- ragna/_utils.py | 54 +++- ragna/core/_rag.py | 45 ++-- ragna/core/_utils.py | 10 - ragna/deploy/__init__.py | 12 +- ragna/deploy/_api.py | 30 +-- ragna/deploy/_auth.py | 344 ++++++++++++++++++++++++ ragna/deploy/_authentication.py | 132 --------- ragna/deploy/_config.py | 20 +- ragna/deploy/_core.py | 8 + ragna/deploy/_database.py | 74 +++-- ragna/deploy/_engine.py | 16 +- ragna/deploy/_key_value_store.py | 84 ++++++ ragna/deploy/_orm.py | 3 +- ragna/deploy/_schemas.py | 5 + ragna/deploy/_templates/__init__.py | 11 + ragna/deploy/_templates/base.html | 27 ++ ragna/deploy/_templates/basic_auth.html | 24 ++ ragna/deploy/_templates/oauth.html | 7 + ragna/deploy/_ui/api_wrapper.py | 4 +- ragna/deploy/_ui/app.py | 3 + scripts/add_chats.py | 27 +- tests/deploy/api/utils.py | 35 +++ 22 files changed, 735 insertions(+), 240 deletions(-) create mode 100644 ragna/deploy/_auth.py delete mode 100644 ragna/deploy/_authentication.py create mode 100644 ragna/deploy/_key_value_store.py create mode 100644 ragna/deploy/_templates/__init__.py create mode 100644 ragna/deploy/_templates/base.html create mode 100644 ragna/deploy/_templates/basic_auth.html create mode 100644 ragna/deploy/_templates/oauth.html create mode 100644 tests/deploy/api/utils.py diff --git a/ragna/_utils.py b/ragna/_utils.py index efa6f95c..fe52349f 100644 --- a/ragna/_utils.py +++ b/ragna/_utils.py @@ -1,10 +1,26 @@ +import contextlib import functools +import getpass import inspect import os import sys import threading from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Iterator, + Optional, + TypeVar, + Union, + cast, +) + +from starlette.concurrency import iterate_in_threadpool, run_in_threadpool + +T = TypeVar("T") _LOCAL_ROOT = ( Path(os.environ.get("RAGNA_LOCAL_ROOT", "~/.cache/ragna")).expanduser().resolve() @@ -110,3 +126,39 @@ def is_debugging() -> bool: if any(part.startswith(name) for part in parts): return True return False + + +def as_awaitable( + fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any, **kwargs: Any +) -> Awaitable[T]: + if inspect.iscoroutinefunction(fn): + fn = cast(Callable[..., Awaitable[T]], fn) + awaitable = fn(*args, **kwargs) + else: + fn = cast(Callable[..., T], fn) + awaitable = run_in_threadpool(fn, *args, **kwargs) + + return awaitable + + +def as_async_iterator( + fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]], + *args: Any, + **kwargs: Any, +) -> AsyncIterator[T]: + if inspect.isasyncgenfunction(fn): + fn = cast(Callable[..., AsyncIterator[T]], fn) + async_iterator = fn(*args, **kwargs) + else: + fn = cast(Callable[..., Iterator[T]], fn) + async_iterator = iterate_in_threadpool(fn(*args, **kwargs)) + + return async_iterator + + +def default_user() -> str: + with contextlib.suppress(Exception): + return getpass.getuser() + with contextlib.suppress(Exception): + return os.getlogin() + return "Bodil" diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index b11e1061..b3a0ec3d 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -2,7 +2,6 @@ import contextlib import datetime -import inspect import itertools import uuid from collections import defaultdict @@ -24,11 +23,12 @@ import pydantic import pydantic_core from fastapi import status -from starlette.concurrency import iterate_in_threadpool, run_in_threadpool + +from ragna._utils import as_async_iterator, as_awaitable, default_user from ._components import Assistant, Component, Message, MessageRole, SourceStorage from ._document import Document, LocalDocument -from ._utils import RagnaException, default_user, merge_models +from ._utils import RagnaException, merge_models if TYPE_CHECKING: from ragna.deploy import Config @@ -241,11 +241,11 @@ async def prepare(self) -> Message: raise RagnaException( "Chat is already prepared", chat=self, - http_status_code=400, + http_status_code=status.HTTP_400_BAD_REQUEST, detail=RagnaException.EVENT, ) - await self._run(self.source_storage.store, self.documents) + await self._as_awaitable(self.source_storage.store, self.documents) self._prepared = True welcome = Message( @@ -269,16 +269,18 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message: raise RagnaException( "Chat is not prepared", chat=self, - http_status_code=400, + http_status_code=status.HTTP_400_BAD_REQUEST, detail=RagnaException.EVENT, ) self._messages.append(Message(content=prompt, role=MessageRole.USER)) - sources = await self._run(self.source_storage.retrieve, self.documents, prompt) + sources = await self._as_awaitable( + self.source_storage.retrieve, self.documents, prompt + ) answer = Message( - content=self._run_gen(self.assistant.answer, prompt, sources), + content=self._as_async_iterator(self.assistant.answer, prompt, sources), role=MessageRole.ASSISTANT, sources=sources, ) @@ -416,34 +418,17 @@ def format_error( raise RagnaException("\n".join(parts)) - async def _run( + def _as_awaitable( self, fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any - ) -> T: - kwargs = self._unpacked_params[fn] - if inspect.iscoroutinefunction(fn): - fn = cast(Callable[..., Awaitable[T]], fn) - coro = fn(*args, **kwargs) - else: - fn = cast(Callable[..., T], fn) - coro = run_in_threadpool(fn, *args, **kwargs) + ) -> Awaitable[T]: + return as_awaitable(fn, *args, **self._unpacked_params[fn]) - return await coro - - async def _run_gen( + def _as_async_iterator( self, fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]], *args: Any, ) -> AsyncIterator[T]: - kwargs = self._unpacked_params[fn] - if inspect.isasyncgenfunction(fn): - fn = cast(Callable[..., AsyncIterator[T]], fn) - async_gen = fn(*args, **kwargs) - else: - fn = cast(Callable[..., Iterator[T]], fn) - async_gen = iterate_in_threadpool(fn(*args, **kwargs)) - - async for item in async_gen: - yield item + return as_async_iterator(fn, *args, **self._unpacked_params[fn]) async def __aenter__(self) -> Chat: await self.prepare() diff --git a/ragna/core/_utils.py b/ragna/core/_utils.py index 897a69d6..b70c4f2a 100644 --- a/ragna/core/_utils.py +++ b/ragna/core/_utils.py @@ -1,10 +1,8 @@ from __future__ import annotations import abc -import contextlib import enum import functools -import getpass import importlib import importlib.metadata import os @@ -125,14 +123,6 @@ def __repr__(self) -> str: return self._name -def default_user() -> str: - with contextlib.suppress(Exception): - return getpass.getuser() - with contextlib.suppress(Exception): - return os.getlogin() - return "Ragna" - - def merge_models( model_name: str, *models: Type[pydantic.BaseModel], diff --git a/ragna/deploy/__init__.py b/ragna/deploy/__init__.py index f3a86255..70bf42ce 100644 --- a/ragna/deploy/__init__.py +++ b/ragna/deploy/__init__.py @@ -1,11 +1,17 @@ __all__ = [ - "Authentication", + "Auth", "Config", - "RagnaDemoAuthentication", + "DummyBasicAuth", + "GithubOAuth", + "InMemoryKeyValueStore", + "KeyValueStore", + "NoAuth", + "RedisKeyValueStore", ] -from ._authentication import Authentication, RagnaDemoAuthentication +from ._auth import Auth, DummyBasicAuth, GithubOAuth, NoAuth from ._config import Config +from ._key_value_store import InMemoryKeyValueStore, KeyValueStore, RedisKeyValueStore # isort: split diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index 4de2737c..b76def56 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -2,35 +2,25 @@ from typing import Annotated, AsyncIterator import pydantic -from fastapi import ( - APIRouter, - Body, - Depends, - UploadFile, -) +from fastapi import APIRouter, Body, UploadFile from fastapi.responses import StreamingResponse from ragna._compat import anext -from ragna.core._utils import default_user from . import _schemas as schemas +from ._auth import UserDependency from ._engine import Engine def make_router(engine: Engine) -> APIRouter: router = APIRouter(tags=["API"]) - def get_user() -> str: - return default_user() - - UserDependency = Annotated[str, Depends(get_user)] - @router.post("/documents") def register_documents( user: UserDependency, document_registrations: list[schemas.DocumentRegistration] ) -> list[schemas.Document]: return engine.register_documents( - user=user, document_registrations=document_registrations + user=user.name, document_registrations=document_registrations ) @router.put("/documents") @@ -45,7 +35,7 @@ async def content_stream() -> AsyncIterator[bytes]: return content_stream() await engine.store_documents( - user=user, + user=user.name, ids_and_streams=[ (uuid.UUID(document.filename), make_content_stream(document)) for document in documents @@ -61,19 +51,19 @@ async def create_chat( user: UserDependency, chat_creation: schemas.ChatCreation, ) -> schemas.Chat: - return engine.create_chat(user=user, chat_creation=chat_creation) + return engine.create_chat(user=user.name, chat_creation=chat_creation) @router.get("/chats") async def get_chats(user: UserDependency) -> list[schemas.Chat]: - return engine.get_chats(user=user) + return engine.get_chats(user=user.name) @router.get("/chats/{id}") async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat: - return engine.get_chat(user=user, id=id) + return engine.get_chat(user=user.name, id=id) @router.post("/chats/{id}/prepare") async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message: - return await engine.prepare_chat(user=user, id=id) + return await engine.prepare_chat(user=user.name, id=id) @router.post("/chats/{id}/answer") async def answer( @@ -82,7 +72,7 @@ async def answer( prompt: Annotated[str, Body(..., embed=True)], stream: Annotated[bool, Body(..., embed=True)] = False, ) -> schemas.Message: - message_stream = engine.answer_stream(user=user, chat_id=id, prompt=prompt) + message_stream = engine.answer_stream(user=user.name, chat_id=id, prompt=prompt) answer = await anext(message_stream) if not stream: @@ -107,6 +97,6 @@ async def to_jsonl( @router.delete("/chats/{id}") async def delete_chat(user: UserDependency, id: uuid.UUID) -> None: - engine.delete_chat(user=user, id=id) + engine.delete_chat(user=user.name, id=id) return router diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py new file mode 100644 index 00000000..a115cb0e --- /dev/null +++ b/ragna/deploy/_auth.py @@ -0,0 +1,344 @@ +import abc +import base64 +import json +import os +import uuid +from typing import Annotated, Awaitable, Callable, Optional, Union, cast + +import httpx +import pydantic +from fastapi import Depends, FastAPI, Request, status +from fastapi.responses import HTMLResponse, RedirectResponse, Response +from fastapi.security.utils import get_authorization_scheme_param +from starlette.middleware.base import BaseHTTPMiddleware +from tornado.web import create_signed_value + +from ragna._utils import as_awaitable, default_user +from ragna.core import RagnaException + +from . import _schemas as schemas +from . import _templates as templates +from ._config import Config +from ._engine import Engine +from ._key_value_store import KeyValueStore +from ._utils import redirect + + +class Session(pydantic.BaseModel): + user: schemas.User + + +CallNext = Callable[[Request], Awaitable[Response]] + + +class SessionMiddleware(BaseHTTPMiddleware): + def __init__( + self, app: FastAPI, *, config: Config, engine: Engine, api: bool, ui: bool + ) -> None: + super().__init__(app) + self._config = config + self._engine = engine + self._api = api + self._ui = ui + self._sessions: KeyValueStore[Session] = config.key_value_store() + + _COOKIE_NAME = "ragna" + + async def dispatch(self, request: Request, call_next: CallNext) -> Response: + if (authorization := request.headers.get("Authorization")) is not None: + return await self._api_token_dispatch( + request, call_next, authorization=authorization + ) + elif (cookie := request.cookies.get(self._COOKIE_NAME)) is not None: + return await self._cookie_dispatch(request, call_next, cookie=cookie) + elif request.url.path in {"/login", "/oauth-callback"}: + return await self._login_dispatch(request, call_next) + elif self._api and request.url.path.startswith("/api"): + return self._unauthorized() + elif self._ui and request.url.path.startswith("/ui"): + return redirect("/login") + else: + # Either an unknown route or something on the default router. In any case, + # this doesn't need a session and so we let it pass. + request.state.session = None + return await call_next(request) + + async def _api_token_dispatch( + self, request: Request, call_next: CallNext, authorization: str + ) -> Response: + scheme, api_key = get_authorization_scheme_param(authorization) + if scheme.lower() != "bearer": + return self._unauthorized() + + user = self._engine.get_user(api_key=api_key) + if user is None: + # Unknown API key + return self._unauthorized() + + session = self._sessions.get(api_key) + if session is None: + # First time the API key is used + session = self._sessions[api_key] = Session(user=user) + + request.state.session = session + return await call_next(request) + + # panel uses cookies to transfer user information (see _cookie_dispatch() below) and + # signs them for security. However, since this happens after our authentication + # check, we can use an arbitrary, hardcoded value here. + PANEL_COOKIE_SECRET = "ragna" + + async def _cookie_dispatch( + self, request: Request, call_next: CallNext, *, cookie: str + ) -> Response: + session = self._sessions.get(cookie) + response: Response + if session is None: + # Invalid cookie + response = redirect("/login") + self._delete_cookie(response) + return response + + request.state.session = session + if self._ui and request.method == "GET" and request.url.path == "/ui": + # panel.state.user and panel.state.user_info are based on the two cookies + # below that the panel auth flow sets. Since we don't want extra cookies + # just for panel, we just inject them into the scope here, which will be + # parsed by panel down the line. After this initial request, the values are + # tied to the active session and don't have to be set again. + extra_cookies = { + "user": session.user.name, + "id_token": base64.b64encode(json.dumps(session.user.data).encode()), + } + extra_values = [ + ( + f"{key}=".encode() + + create_signed_value( + self.PANEL_COOKIE_SECRET, key, value, version=1 + ) + ) + for key, value in extra_cookies.items() + ] + + cookie_key = b"cookie" + idx, value = next( + (idx, value) + for idx, (key, value) in enumerate(request.scope["headers"]) + if key == cookie_key + ) + # We are not setting request.cookies or request.headers here, because any + # changes to them are not reflected back to the scope, which is the only + # safe way to transfer data between the middleware and an endpoint. + request.scope["headers"][idx] = ( + cookie_key, + b";".join([value, *extra_values]), + ) + + response = await call_next(request) + + if request.url.path == "/logout": + del self._sessions[cookie] + self._delete_cookie(response) + else: + self._add_cookie(response, cookie) + + return response + + async def _login_dispatch(self, request: Request, call_next: CallNext) -> Response: + request.state.session = None + response = await call_next(request) + session = request.state.session + + if session is not None: + cookie = str(uuid.uuid4()) + self._sessions[cookie] = session + self._add_cookie(response, cookie=cookie) + + return response + + def _unauthorized(self) -> Response: + return Response( + content="Not authenticated", + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + ) + + def _add_cookie(self, response: Response, cookie: str) -> None: + response.set_cookie( + key=self._COOKIE_NAME, + value=cookie, + # FIXME + max_age=3600, + # max_age=self._config.deploy.cookie_expires, + httponly=True, + samesite="lax", + ) + + def _delete_cookie(self, response: Response) -> None: + response.delete_cookie( + key=self._COOKIE_NAME, + httponly=True, + samesite="lax", + ) + + +async def _get_session(request: Request) -> Session: + session = cast(Optional[Session], request.state.session) + if session is None: + raise RagnaException("ADDME") + return session + + +SessionDependency = Annotated[Session, Depends(_get_session)] + + +async def _get_user(session: SessionDependency) -> schemas.User: + return session.user + + +UserDependency = Annotated[schemas.User, Depends(_get_user)] + + +class Auth(abc.ABC): + @classmethod + def _add_to_app( + cls, app: FastAPI, *, config: Config, engine: Engine, api: bool, ui: bool + ) -> None: + self = cls() + + @app.get("/login", include_in_schema=False) + async def login_page(request: Request) -> Response: + return await as_awaitable(self.login_page, request) + + async def _login(request: Request) -> Response: + result = await as_awaitable(self.login, request) + if not isinstance(result, schemas.User): + return result + + engine.maybe_add_user(result) + request.state.session = Session(user=result) + return redirect("/") + + @app.post("/login", include_in_schema=False) + async def login(request: Request) -> Response: + return await _login(request) + + @app.get("/oauth-callback", include_in_schema=False) + async def oauth_callback(request: Request) -> Response: + return await _login(request) + + @app.get("/logout", include_in_schema=False) + async def logout() -> RedirectResponse: + return redirect("/") + + app.add_middleware( + SessionMiddleware, + config=config, + engine=engine, + api=api, + ui=ui, + ) + + @abc.abstractmethod + def login_page(self, request: Request) -> Response: ... + + @abc.abstractmethod + def login(self, request: Request) -> Union[schemas.User, Response]: ... + + +class NoAuth(Auth): + def login_page(self, request: Request) -> Response: + # To invoke the login() method below, the client either needs to + # - POST /login or + # - GET /oauth-callback + # Since we cannot instruct a browser to post when sending redirect response, we + # use the OAuth callback endpoint here, although this has nothing to do with + # OAuth. + return redirect("/oauth-callback") + + def login(self, request: Request) -> schemas.User: + return schemas.User(name=request.headers.get("X-User", default_user())) + + +class DummyBasicAuth(Auth): + """Dummy OAuth2 password authentication without requirements. + + !!! danger + + As the name implies, this authentication is just testing or demo purposes and + should not be used in production. + """ + + def __init__(self) -> None: + self._password = os.environ.get("RAGNA_DUMMY_BASIC_AUTH_PASSWORD") + + def login_page( + self, request: Request, *, fail_reason: Optional[str] = None + ) -> HTMLResponse: + return HTMLResponse(templates.render("basic_auth.html")) + + async def login(self, request: Request) -> Union[schemas.User, HTMLResponse]: + async with request.form() as form: + username = cast(str, form.get("username")) + password = cast(str, form.get("password")) + + if username is None or password is None: + # This can only happen if the endpoint is not hit through the login page. + # Thus, instead of returning the failed login page like below, we just + # return an error. + raise RagnaException( + "Field 'username' or 'password' is missing from the form data.", + http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + http_detail=RagnaException.MESSAGE, + ) + + if (self._password is not None and password != self._password) or ( + self._password is None and password != username + ): + # FIXME: send the login page again with a failure message + return HTMLResponse("Unauthorized!") + + return schemas.User(name=username) + + +class GithubOAuth(Auth): + def __init__(self) -> None: + # FIXME: requirements + self._client_id = os.environ["RAGNA_GITHUB_OAUTH_CLIENT_ID"] + self._client_secret = os.environ["RAGNA_GITHUB_OAUTH_CLIENT_SECRET"] + + def login_page(self, request: Request) -> HTMLResponse: + return HTMLResponse( + templates.render( + "oauth.html", + service="GitHub", + url=f"https://github.com/login/oauth/authorize?client_id={self._client_id}", + ) + ) + + async def login(self, request: Request) -> Union[schemas.User, Response]: + async with httpx.AsyncClient(headers={"Accept": "application/json"}) as client: + response = await client.post( + "https://github.com/login/oauth/access_token", + json={ + "client_id": self._client_id, + "client_secret": self._client_secret, + "code": request.query_params["code"], + }, + ) + access_token = response.json()["access_token"] + client.headers["Authorization"] = f"Bearer {access_token}" + + user_data = (await client.get("https://api.github.com/user")).json() + + organizations_data = ( + await client.get(user_data["organizations_url"]) + ).json() + organizations = { + organization_data["login"] for organization_data in organizations_data + } + if not (organizations & {"Quansight", "Quansight-Labs"}): + # FIXME: send the login page again with a failure message + return HTMLResponse("Unauthorized!") + + return schemas.User(name=user_data["login"]) diff --git a/ragna/deploy/_authentication.py b/ragna/deploy/_authentication.py deleted file mode 100644 index b8a4cbb8..00000000 --- a/ragna/deploy/_authentication.py +++ /dev/null @@ -1,132 +0,0 @@ -import abc -import os -import secrets -import time -from typing import cast - -import jwt -import rich -from fastapi import HTTPException, Request, status -from fastapi.security.utils import get_authorization_scheme_param - - -class Authentication(abc.ABC): - """Abstract base class for authentication used by the REST API.""" - - @abc.abstractmethod - async def create_token(self, request: Request) -> str: - """Authenticate user and create an authorization token. - - Args: - request: Request send to the `/token` endpoint of the REST API. - - Returns: - Authorization token. - """ - pass - - @abc.abstractmethod - async def get_user(self, request: Request) -> str: - """ - Args: - request: Request send to any endpoint of the REST API that requires - authorization. - - Returns: - Authorized user. - """ - pass - - -class RagnaDemoAuthentication(Authentication): - """Demo OAuth2 password authentication without requirements. - - !!! danger - - As the name implies, this authentication is just for demo purposes and should - not be used in production. - """ - - def __init__(self) -> None: - msg = f"INFO:\t{type(self).__name__}: You can log in with any username" - self._password = os.environ.get("RAGNA_DEMO_AUTHENTICATION_PASSWORD") - if self._password is None: - msg = f"{msg} and a matching password." - else: - msg = f"{msg} and the password {self._password}" - rich.print(msg) - - _JWT_SECRET = os.environ.get( - "RAGNA_DEMO_AUTHENTICATION_SECRET", secrets.token_urlsafe(32)[:32] - ) - _JWT_ALGORITHM = "HS256" - _JWT_TTL = int(os.environ.get("RAGNA_DEMO_AUTHENTICATION_TTL", 60 * 60 * 24 * 7)) - - async def create_token(self, request: Request) -> str: - """Authenticate user and create an authorization token. - - User name is arbitrary. Authentication is possible in two ways: - - 1. If the `RAGNA_DEMO_AUTHENTICATION_PASSWORD` environment variable is set, the - password is checked against that. - 2. Otherwise, the password has to match the user name. - - Args: - request: Request send to the `/token` endpoint of the REST API. Must include - the `"username"` and `"password"` as form data. - - Returns: - Authorization [JWT](https://jwt.io/) that expires after one week. - """ - async with request.form() as form: - username = form.get("username") - password = form.get("password") - - if username is None or password is None: - raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY) - - if (self._password is not None and password != self._password) or ( - self._password is None and password != username - ): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - - return jwt.encode( - payload={"user": username, "exp": time.time() + self._JWT_TTL}, - key=self._JWT_SECRET, - algorithm=self._JWT_ALGORITHM, - ) - - async def get_user(self, request: Request) -> str: - """Get user from an authorization token. - - Token has to be supplied in the - [Bearer authentication scheme](https://swagger.io/docs/specification/authentication/bearer-authentication/), - i.e. including a `Authorization: Bearer {token}` header. - - Args: - request: Request send to any endpoint of the REST API that requires - authorization. - - Returns: - Authorized user. - """ - - unauthorized = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not authenticated", - headers={"WWW-Authenticate": "Bearer"}, - ) - - authorization = request.headers.get("Authorization") - scheme, token = get_authorization_scheme_param(authorization) - if not authorization or scheme.lower() != "bearer": - raise unauthorized - - try: - payload = jwt.decode( - token, key=self._JWT_SECRET, algorithms=[self._JWT_ALGORITHM] - ) - except (jwt.InvalidSignatureError, jwt.ExpiredSignatureError): - raise unauthorized - - return cast(str, payload["user"]) diff --git a/ragna/deploy/_config.py b/ragna/deploy/_config.py index e960f831..46b70e72 100644 --- a/ragna/deploy/_config.py +++ b/ragna/deploy/_config.py @@ -2,7 +2,16 @@ import itertools from pathlib import Path -from typing import Annotated, Any, Callable, Generic, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Generic, + Type, + TypeVar, + Union, +) import tomlkit import tomlkit.container @@ -18,7 +27,9 @@ from ragna._utils import make_directory from ragna.core import Assistant, Document, RagnaException, SourceStorage -from ._authentication import Authentication +if TYPE_CHECKING: + from ._auth import Auth + from ._key_value_store import KeyValueStore T = TypeVar("T") @@ -79,8 +90,9 @@ def settings_customise_sources( default_factory=ragna.local_root ) - authentication: ImportString[type[Authentication]] = ( - "ragna.deploy.RagnaDemoAuthentication" # type: ignore[assignment] + auth: ImportString[type[Auth]] = "ragna.deploy.NoAuth" # type: ignore[assignment] + key_value_store: ImportString[type[KeyValueStore]] = ( + "ragna.deploy.InMemoryKeyValueStore" # type: ignore[assignment] ) document: ImportString[type[Document]] = "ragna.core.LocalDocument" # type: ignore[assignment] diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py index 44c672a8..47597fc6 100644 --- a/ragna/deploy/_core.py +++ b/ragna/deploy/_core.py @@ -12,7 +12,9 @@ import ragna from ragna.core import RagnaException +from . import _schemas as schemas from ._api import make_router as make_api_router +from ._auth import UserDependency from ._config import Config from ._engine import Engine from ._ui import app as make_ui_app @@ -75,6 +77,8 @@ def server_available() -> bool: ignore_unavailable_components=ignore_unavailable_components, ) + config.auth._add_to_app(app, config=config, engine=engine, api=api, ui=ui) + if api: app.include_router(make_api_router(engine), prefix="/api") @@ -94,6 +98,10 @@ async def health() -> Response: async def version() -> str: return ragna.__version__ + @app.get("/user") + async def user(user: UserDependency) -> schemas.User: + return user + @app.exception_handler(RagnaException) async def ragna_exception_handler( request: Request, exc: RagnaException diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py index 529fa3b6..6855c389 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -1,7 +1,8 @@ from __future__ import annotations +import secrets import uuid -from typing import Any, Collection, Optional +from typing import Any, Collection, Optional, cast from urllib.parse import urlsplit from sqlalchemy import create_engine, select @@ -13,6 +14,14 @@ from . import _schemas as schemas +class UnknownUser(Exception): + def __init__( + self, name: Optional[str] = None, api_key: Optional[str] = None + ) -> None: + self.name = name + self.api_key = api_key + + class Database: def __init__(self, url: str) -> None: components = urlsplit(url) @@ -28,20 +37,52 @@ def __init__(self, url: str) -> None: self._to_orm = SchemaToOrmConverter() self._to_schema = OrmToSchemaConverter() - def _get_user(self, session: Session, *, username: str) -> orm.User: - user: Optional[orm.User] = session.execute( - select(orm.User).where(orm.User.name == username) - ).scalar_one_or_none() + def _get_orm_user( + self, + session: Session, + *, + name: Optional[str] = None, + api_key: Optional[str] = None, + ) -> orm.User: + selector = select(orm.User) + if name is None and api_key is None: + raise RagnaException + elif name is not None: + selector = selector.where(orm.User.name == name) + elif api_key is not None: + selector = selector.where(orm.User.api_key == api_key) + + user = cast(Optional[orm.User], session.execute(selector).scalar_one_or_none()) if user is None: - # Add a new user if the current username is not registered yet. Since this - # is behind the authentication layer, we don't need any extra security here. - user = orm.User(id=uuid.uuid4(), name=username) - session.add(user) - session.commit() + raise UnknownUser(name=name, api_key=api_key) return user + def maybe_add_user(self, session: Session, *, user: schemas.User) -> None: + try: + self._get_orm_user(session, name=user.name) + except UnknownUser: + orm_user = orm.User( + id=uuid.uuid4(), name=user.name, api_key=secrets.token_urlsafe(32)[:32] + ) + session.add(orm_user) + session.commit() + + def get_user( + self, + session: Session, + *, + name: Optional[str] = None, + api_key: Optional[str] = None, + ) -> Optional[schemas.User]: + try: + return self._to_schema.user( + self._get_orm_user(session, name=name, api_key=api_key) + ) + except UnknownUser: + return None + def add_documents( self, session: Session, @@ -49,7 +90,7 @@ def add_documents( user: str, documents: list[schemas.Document], ) -> None: - user_id = self._get_user(session, username=user).id + user_id = self._get_orm_user(session, name=user).id session.add_all( [self._to_orm.document(document, user_id=user_id) for document in documents] ) @@ -82,7 +123,7 @@ def get_documents( def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None: orm_chat = self._to_orm.chat( - chat, user_id=self._get_user(session, username=user).id + chat, user_id=self._get_orm_user(session, name=user).id ) # We need to merge and not add here, because the documents are already in the DB session.merge(orm_chat) @@ -102,7 +143,7 @@ def get_chats(self, session: Session, *, user: str) -> list[schemas.Chat]: self._to_schema.chat(chat) for chat in session.execute( self._select_chat(eager=True).where( - orm.Chat.user_id == self._get_user(session, username=user).id + orm.Chat.user_id == self._get_orm_user(session, name=user).id ) ) .scalars() @@ -117,7 +158,7 @@ def _get_orm_chat( session.execute( self._select_chat(eager=eager).where( (orm.Chat.id == id) - & (orm.Chat.user_id == self._get_user(session, username=user).id) + & (orm.Chat.user_id == self._get_orm_user(session, name=user).id) ) ) .unique() @@ -134,7 +175,7 @@ def get_chat(self, session: Session, *, user: str, id: uuid.UUID) -> schemas.Cha def update_chat(self, session: Session, user: str, chat: schemas.Chat) -> None: orm_chat = self._to_orm.chat( - chat, user_id=self._get_user(session, username=user).id + chat, user_id=self._get_orm_user(session, name=user).id ) session.merge(orm_chat) session.commit() @@ -199,6 +240,9 @@ def chat( class OrmToSchemaConverter: + def user(self, user: orm.User) -> schemas.User: + return schemas.User(name=user.name) + def document(self, document: orm.Document) -> schemas.Document: return schemas.Document( id=document.id, name=document.name, metadata=document.metadata_ diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 2209a61f..f0430177 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -4,14 +4,14 @@ from fastapi import status as http_status_code import ragna -from ragna import Rag, core +from ragna import core from ragna._compat import aiter, anext from ragna._utils import make_directory -from ragna.core import RagnaException +from ragna.core import Rag, RagnaException from ragna.core._rag import SpecialChatParams -from ragna.deploy import Config from . import _schemas as schemas +from ._config import Config from ._database import Database @@ -34,6 +34,16 @@ def __init__(self, *, config: Config, ignore_unavailable_components: bool) -> No self._to_core = SchemaToCoreConverter(config=self._config, rag=self._rag) self._to_schema = CoreToSchemaConverter() + def maybe_add_user(self, user: schemas.User) -> None: + with self._database.get_session() as session: + return self._database.maybe_add_user(session, user=user) + + def get_user( + self, name: Optional[str] = None, api_key: Optional[str] = None + ) -> Optional[schemas.User]: + with self._database.get_session() as session: + return self._database.get_user(session, name=name, api_key=api_key) + def _get_component_json_schema( self, component: Type[core.Component], diff --git a/ragna/deploy/_key_value_store.py b/ragna/deploy/_key_value_store.py new file mode 100644 index 00000000..a22e53ca --- /dev/null +++ b/ragna/deploy/_key_value_store.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import abc +import os +from typing import Any, Generic, Optional, TypeVar + +import pydantic + +from ragna.core import PackageRequirement, Requirement +from ragna.core._utils import RequirementsMixin + +M = TypeVar("M", bound=pydantic.BaseModel) + + +class SerializableModel(pydantic.BaseModel, Generic[M]): + cls: pydantic.ImportString[type[M]] + obj: dict[str, Any] + + @classmethod + def from_model(cls, model: M) -> SerializableModel[M]: + return SerializableModel(cls=type(model), obj=model.model_dump(mode="json")) + + def to_model(self) -> M: + return self.cls.model_validate(self.obj) + + +class KeyValueStore(abc.ABC, RequirementsMixin, Generic[M]): + def serialize(self, model: M) -> str: + return SerializableModel.from_model(model).model_dump_json() + + def deserialize(self, json_str: str) -> M: + return SerializableModel.model_validate_json(json_str).to_model() + + @abc.abstractmethod + def __setitem__(self, key: str, model: M) -> None: ... + + @abc.abstractmethod + def __getitem__(self, key: str) -> M: ... + + @abc.abstractmethod + def __delitem__(self, key: str) -> None: ... + + def get(self, key: str, default: Optional[M] = None) -> Optional[M]: + try: + return self[key] + except KeyError: + return default + + +class InMemoryKeyValueStore(KeyValueStore[M]): + def __init__(self) -> None: + self._store: dict[str, M] = {} + + def __setitem__(self, key: str, model: M) -> None: + self._store[key] = model + + def __getitem__(self, key: str) -> M: + return self._store[key] + + def __delitem__(self, key: str) -> None: + del self._store[key] + + +class RedisKeyValueStore(KeyValueStore[M]): + @classmethod + def requirements(cls) -> list[Requirement]: + return [PackageRequirement("redis")] + + def __init__(self) -> None: + import redis + + self._r = redis.Redis( + host=os.environ.get("RAGNA_REDIS_HOST", "localhost"), + port=int(os.environ.get("RAGNA_REDIS_PORT", 6379)), + ) + + def __setitem__(self, key: str, model: M) -> None: + self._r[key] = self.serialize(model) + + def __getitem__(self, key: str) -> M: + return self.deserialize(self._r[key]) + + def __delitem__(self, key: str) -> None: + del self._r[key] diff --git a/ragna/deploy/_orm.py b/ragna/deploy/_orm.py index 04a3583e..7d3b6176 100644 --- a/ragna/deploy/_orm.py +++ b/ragna/deploy/_orm.py @@ -41,7 +41,8 @@ class User(Base): __tablename__ = "users" id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined] - name = Column(types.String, nullable=False) + name = Column(types.String, nullable=False, unique=True) + api_key = Column(types.String, nullable=False, unique=True) document_chat_association_table = Table( diff --git a/ragna/deploy/_schemas.py b/ragna/deploy/_schemas.py index cc5490b7..7a13192c 100644 --- a/ragna/deploy/_schemas.py +++ b/ragna/deploy/_schemas.py @@ -9,6 +9,11 @@ import ragna.core +class User(BaseModel): + name: str + data: dict[str, Any] = Field(default_factory=dict) + + class Components(BaseModel): documents: list[str] source_storages: list[dict[str, Any]] diff --git a/ragna/deploy/_templates/__init__.py b/ragna/deploy/_templates/__init__.py new file mode 100644 index 00000000..9b7f9c20 --- /dev/null +++ b/ragna/deploy/_templates/__init__.py @@ -0,0 +1,11 @@ +from pathlib import Path +from typing import Any + +from jinja2 import Environment, FileSystemLoader + +ENVIRONMENT = Environment(loader=FileSystemLoader(Path(__file__).parent)) + + +def render(template: str, **context: Any) -> str: + template = ENVIRONMENT.get_template(template) + return template.render(**context) diff --git a/ragna/deploy/_templates/base.html b/ragna/deploy/_templates/base.html new file mode 100644 index 00000000..cd1e49a2 --- /dev/null +++ b/ragna/deploy/_templates/base.html @@ -0,0 +1,27 @@ + + + + + Ragna + + + + + + + + {% block content %}{% endblock %} + + diff --git a/ragna/deploy/_templates/basic_auth.html b/ragna/deploy/_templates/basic_auth.html new file mode 100644 index 00000000..039fbed2 --- /dev/null +++ b/ragna/deploy/_templates/basic_auth.html @@ -0,0 +1,24 @@ +{% extends "base.html" %} {% block content %} +
+
+ +
+
+ +
+
+
+ +
+{% endblock %} diff --git a/ragna/deploy/_templates/oauth.html b/ragna/deploy/_templates/oauth.html new file mode 100644 index 00000000..7a21f0cd --- /dev/null +++ b/ragna/deploy/_templates/oauth.html @@ -0,0 +1,7 @@ +{% extends "base.html" %} {% block content %} +
+ + + +
+{% endblock %} diff --git a/ragna/deploy/_ui/api_wrapper.py b/ragna/deploy/_ui/api_wrapper.py index 170e8bbd..776de870 100644 --- a/ragna/deploy/_ui/api_wrapper.py +++ b/ragna/deploy/_ui/api_wrapper.py @@ -2,9 +2,9 @@ from datetime import datetime import emoji +import panel as pn import param -from ragna.core._utils import default_user from ragna.deploy import _schemas as schemas from ragna.deploy._engine import Engine @@ -12,7 +12,7 @@ class ApiWrapper(param.Parameterized): def __init__(self, engine: Engine): super().__init__() - self._user = default_user() + self._user = pn.state.user self._engine = engine async def get_chats(self): diff --git a/ragna/deploy/_ui/app.py b/ragna/deploy/_ui/app.py index 052ff36d..7d8f386a 100644 --- a/ragna/deploy/_ui/app.py +++ b/ragna/deploy/_ui/app.py @@ -6,6 +6,7 @@ from fastapi import FastAPI from fastapi.staticfiles import StaticFiles +from ragna.deploy._auth import SessionMiddleware from ragna.deploy._engine import Engine from . import js @@ -138,6 +139,8 @@ def serve_with_fastapi(self, app: FastAPI, endpoint: str): name=dir, ) + pn.config.cookie_secret = SessionMiddleware.PANEL_COOKIE_SECRET + def app(engine: Engine) -> App: return App(engine) diff --git a/scripts/add_chats.py b/scripts/add_chats.py index 5f550289..14d8827a 100644 --- a/scripts/add_chats.py +++ b/scripts/add_chats.py @@ -8,25 +8,14 @@ def main(): client = httpx.Client(base_url="http://127.0.0.1:31476") client.get("/health").raise_for_status() - # ## authentication - # - # username = default_user() - # token = ( - # client.post( - # "/token", - # data={ - # "username": username, - # "password": os.environ.get( - # "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username - # ), - # }, - # ) - # .raise_for_status() - # .json() - # ) - # client.headers["Authorization"] = f"Bearer {token}" - - print() + ## authentication + + # This only works if Ragna was deployed with ragna.core.NoAuth + # If that is not the case, login in whatever way is required, grab the API token and + # use the following instead + # client.headers["Authorization"] = f"Bearer {api_token}" + + client.get("/login", follow_redirects=True).raise_for_status() ## documents diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py new file mode 100644 index 00000000..6dd3fb78 --- /dev/null +++ b/tests/deploy/api/utils.py @@ -0,0 +1,35 @@ +import os + +from fastapi.testclient import TestClient + +from ragna._utils import default_user +from ragna.deploy._core import make_app + + +def make_api_app(*, config, ignore_unavailable_components): + return make_app( + config, + api=True, + ui=False, + ignore_unavailable_components=ignore_unavailable_components, + open_browser=False, + ) + + +def authenticate(client: TestClient) -> None: + return + username = default_user() + token = ( + client.post( + "/token", + data={ + "username": username, + "password": os.environ.get( + "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username + ), + }, + ) + .raise_for_status() + .json() + ) + client.headers["Authorization"] = f"Bearer {token}" From a9b718616528be7293c003dda93bf8779b996301 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 2 Aug 2024 09:32:58 +0200 Subject: [PATCH 02/10] cleanup basic auth --- ragna/deploy/_auth.py | 21 ++++++++++++++++----- ragna/deploy/_templates/basic_auth.html | 8 ++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py index a115cb0e..8b36120e 100644 --- a/ragna/deploy/_auth.py +++ b/ragna/deploy/_auth.py @@ -273,9 +273,17 @@ def __init__(self) -> None: self._password = os.environ.get("RAGNA_DUMMY_BASIC_AUTH_PASSWORD") def login_page( - self, request: Request, *, fail_reason: Optional[str] = None + self, + request: Request, + *, + username: Optional[str] = None, + fail_reason: Optional[str] = None, ) -> HTMLResponse: - return HTMLResponse(templates.render("basic_auth.html")) + return HTMLResponse( + templates.render( + "basic_auth.html", username=username, fail_reason=fail_reason + ) + ) async def login(self, request: Request) -> Union[schemas.User, HTMLResponse]: async with request.form() as form: @@ -292,11 +300,14 @@ async def login(self, request: Request) -> Union[schemas.User, HTMLResponse]: http_detail=RagnaException.MESSAGE, ) - if (self._password is not None and password != self._password) or ( + if not username: + return self.login_page(request, fail_reason="Username cannot be empty") + elif (self._password is not None and password != self._password) or ( self._password is None and password != username ): - # FIXME: send the login page again with a failure message - return HTMLResponse("Unauthorized!") + return self.login_page( + request, username=username, fail_reason="Password incorrect" + ) return schemas.User(name=username) diff --git a/ragna/deploy/_templates/basic_auth.html b/ragna/deploy/_templates/basic_auth.html index 039fbed2..930be227 100644 --- a/ragna/deploy/_templates/basic_auth.html +++ b/ragna/deploy/_templates/basic_auth.html @@ -1,4 +1,9 @@ {% extends "base.html" %} {% block content %} +{% if fail_reason %} +
+ {{ fail_reason }} +
+{% endif %}
From 48e1f30c6a9b28c9f6fb322bcf17e8981d6f5ece Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 2 Aug 2024 10:15:37 +0200 Subject: [PATCH 03/10] mypy --- pyproject.toml | 7 +++++-- ragna/deploy/_auth.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 70627e4c..692f8a82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,9 +170,12 @@ disable_error_code = [ ] [[tool.mypy.overrides]] -# It is a fundamental feature of the components to request more parameters than the base -# class. Thus, we just silence mypy here. +# 1. We automatically handle user-defined sync and async methods +# 2. It is a fundamental feature of the RAG components to request more parameters than +# the base class. +# Thus, we just silence mypy where it would complain about the points above. module = [ + "ragna.deploy._auth", "ragna.source_storages.*", "ragna.assistants.*" ] diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py index 8b36120e..283fb371 100644 --- a/ragna/deploy/_auth.py +++ b/ragna/deploy/_auth.py @@ -106,7 +106,7 @@ async def _cookie_dispatch( # just for panel, we just inject them into the scope here, which will be # parsed by panel down the line. After this initial request, the values are # tied to the active session and don't have to be set again. - extra_cookies = { + extra_cookies: dict[str, Union[str, bytes]] = { "user": session.user.name, "id_token": base64.b64encode(json.dumps(session.user.data).encode()), } @@ -285,7 +285,7 @@ def login_page( ) ) - async def login(self, request: Request) -> Union[schemas.User, HTMLResponse]: + async def login(self, request: Request) -> Union[schemas.User, Response]: async with request.form() as form: username = cast(str, form.get("username")) password = cast(str, form.get("password")) From 57cd2a4b3ba08f080953315c4b6e371b6cdd9d7e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 5 Aug 2024 23:43:42 +0200 Subject: [PATCH 04/10] initial styling --- ragna/deploy/_templates/__init__.py | 7 ++- ragna/deploy/_templates/base.html | 25 ++++++++++- ragna/deploy/_templates/basic_auth.css | 15 +++++++ ragna/deploy/_templates/basic_auth.html | 59 +++++++++++++------------ 4 files changed, 75 insertions(+), 31 deletions(-) create mode 100644 ragna/deploy/_templates/basic_auth.css diff --git a/ragna/deploy/_templates/__init__.py b/ragna/deploy/_templates/__init__.py index 9b7f9c20..7c977794 100644 --- a/ragna/deploy/_templates/__init__.py +++ b/ragna/deploy/_templates/__init__.py @@ -1,11 +1,16 @@ +import contextlib from pathlib import Path from typing import Any -from jinja2 import Environment, FileSystemLoader +from jinja2 import Environment, FileSystemLoader, TemplateNotFound ENVIRONMENT = Environment(loader=FileSystemLoader(Path(__file__).parent)) def render(template: str, **context: Any) -> str: + with contextlib.suppress(TemplateNotFound): + css_template = ENVIRONMENT.get_template(str(Path(template).with_suffix(".css"))) + context["__template_css__"] = css_template.render(**context) + template = ENVIRONMENT.get_template(template) return template.render(**context) diff --git a/ragna/deploy/_templates/base.html b/ragna/deploy/_templates/base.html index cd1e49a2..255044ee 100644 --- a/ragna/deploy/_templates/base.html +++ b/ragna/deploy/_templates/base.html @@ -21,7 +21,30 @@ href="https://cdn.jsdelivr.net/npm/bootstrap-icons@1.11.3/font/bootstrap-icons.min.css" /> + - {% block content %}{% endblock %} +
+ {% block content %}{% endblock %} +
diff --git a/ragna/deploy/_templates/basic_auth.css b/ragna/deploy/_templates/basic_auth.css new file mode 100644 index 00000000..a515d1a0 --- /dev/null +++ b/ragna/deploy/_templates/basic_auth.css @@ -0,0 +1,15 @@ +.basic-auth { + height: 100%; + display: flex; + flex-direction: column; + justify-content: space-between; +} + +.form-control { + background-color: gold; + margin: 10px; +} + +fieldset { + border: 1px solid #000; +} diff --git a/ragna/deploy/_templates/basic_auth.html b/ragna/deploy/_templates/basic_auth.html index 930be227..5269334f 100644 --- a/ragna/deploy/_templates/basic_auth.html +++ b/ragna/deploy/_templates/basic_auth.html @@ -1,32 +1,33 @@ {% extends "base.html" %} {% block content %} -{% if fail_reason %} -
- {{ fail_reason }} -
-{% endif %} - -
- -
-
- -
- -
- +
+

Log in

+ {% if fail_reason %} +
{{ fail_reason }}
+ {% endif %} +
+
+ Username + +
+
+ Password + +
+
+
+ +
{% endblock %} From 7b57a79d4569db2d5f1ebdbbdd298000bebe4137 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 26 Aug 2024 15:29:57 +0200 Subject: [PATCH 05/10] fix docs and cleanup --- docs/examples/gallery_streaming.py | 22 +-- docs/references/{rest-api.md => deploy.md} | 2 +- docs/tutorials/gallery_custom_components.py | 36 ++-- docs/tutorials/gallery_rest_api.py | 67 +++----- mkdocs.yml | 2 +- ragna/_docs.py | 181 ++++++++------------ ragna/_utils.py | 56 +++++- ragna/core/__init__.py | 1 - ragna/core/_rag.py | 1 - ragna/deploy/_auth.py | 8 + scripts/docs/gen_files.py | 16 +- tests/assistants/test_api.py | 28 ++- tests/utils.py | 24 --- 13 files changed, 216 insertions(+), 228 deletions(-) rename docs/references/{rest-api.md => deploy.md} (91%) diff --git a/docs/examples/gallery_streaming.py b/docs/examples/gallery_streaming.py index 09a39abd..846897e6 100644 --- a/docs/examples/gallery_streaming.py +++ b/docs/examples/gallery_streaming.py @@ -107,29 +107,30 @@ def answer(self, messages): config = Config(assistants=[DemoStreamingAssistant]) -rest_api = ragna_docs.RestApi() +ragna_deploy = ragna_docs.RagnaDeploy(config) -client, document = rest_api.start(config, authenticate=True, upload_document=True) +client, document = ragna_deploy.get_http_client( + authenticate=True, upload_document=True +) # %% # Start and prepare the chat chat = ( client.post( - "/chats", + "/api/chats", json={ "name": "Tutorial REST API", - "documents": [document], + "document_ids": [document["id"]], "source_storage": source_storages.RagnaDemoSourceStorage.display_name(), "assistant": DemoStreamingAssistant.display_name(), - "params": {}, }, ) .raise_for_status() .json() ) -client.post(f"/chats/{chat['id']}/prepare").raise_for_status() +client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() # %% # Streaming the response is performed with [JSONL](https://jsonlines.org/). Each line @@ -140,7 +141,7 @@ def answer(self, messages): with client.stream( "POST", - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna?", "stream": True}, ) as response: chunks = [json.loads(data) for data in response.iter_lines()] @@ -163,7 +164,8 @@ def answer(self, messages): print("".join(chunk["content"] for chunk in chunks)) # %% -# Before we close the example, let's stop the REST API and have a look at what would -# have printed in the terminal if we had started it with the `ragna api` command. +# Before we close the example, let's terminate the REST API and have a look at what +# would have printed in the terminal if we had started it with the `ragna deploy` +# command. -rest_api.stop() +ragna_deploy.terminate() diff --git a/docs/references/rest-api.md b/docs/references/deploy.md similarity index 91% rename from docs/references/rest-api.md rename to docs/references/deploy.md index bfd77544..6f39d899 100644 --- a/docs/references/rest-api.md +++ b/docs/references/deploy.md @@ -1,7 +1,7 @@ # REST API reference diff --git a/docs/tutorials/gallery_custom_components.py b/docs/tutorials/gallery_custom_components.py index 8a411f81..6c556043 100644 --- a/docs/tutorials/gallery_custom_components.py +++ b/docs/tutorials/gallery_custom_components.py @@ -186,9 +186,11 @@ def answer(self, messages: list[Message]) -> Iterator[str]: assistants=[TutorialAssistant], ) -rest_api = ragna_docs.RestApi() +ragna_deploy = ragna_docs.RagnaDeploy(config) -client, document = rest_api.start(config, authenticate=True, upload_document=True) +client, document = ragna_deploy.get_http_client( + authenticate=True, upload_document=True +) # %% # To select our custom components, we pass their display names to the chat creation. @@ -201,10 +203,10 @@ def answer(self, messages: list[Message]) -> Iterator[str]: import json response = client.post( - "/chats", + "/api/chats", json={ "name": "Tutorial REST API", - "documents": [document], + "document_ids": [document["id"]], "source_storage": TutorialSourceStorage.display_name(), "assistant": TutorialAssistant.display_name(), "params": {}, @@ -212,10 +214,10 @@ def answer(self, messages: list[Message]) -> Iterator[str]: ).raise_for_status() chat = response.json() -client.post(f"/chats/{chat['id']}/prepare").raise_for_status() +client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() response = client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna?"}, ).raise_for_status() answer = response.json() @@ -225,7 +227,7 @@ def answer(self, messages: list[Message]) -> Iterator[str]: # Let's stop the REST API and have a look at what would have printed in the terminal if # we had started it with the `ragna api` command. -rest_api.stop() +ragna_deploy.terminate() # %% # ### Web UI @@ -263,9 +265,7 @@ def answer( my_optional_parameter: str = "foo", ) -> Iterator[str]: print(f"Running {type(self).__name__}().answer()") - yield ( - f"I was given {my_required_parameter=} and {my_optional_parameter=}." - ) + yield f"I was given {my_required_parameter=} and {my_optional_parameter=}." # %% @@ -319,19 +319,21 @@ def answer( assistants=[ElaborateTutorialAssistant], ) -rest_api = ragna_docs.RestApi() +ragna_deploy = ragna_docs.RagnaDeploy(config) -client, document = rest_api.start(config, authenticate=True, upload_document=True) +client, document = ragna_deploy.get_http_client( + authenticate=True, upload_document=True +) # %% # To pass custom parameters, define them in the `params` mapping when creating a new # chat. response = client.post( - "/chats", + "/api/chats", json={ "name": "Tutorial REST API", - "documents": [document], + "document_ids": [document["id"]], "source_storage": TutorialSourceStorage.display_name(), "assistant": ElaborateTutorialAssistant.display_name(), "params": { @@ -344,10 +346,10 @@ def answer( # %% -client.post(f"/chats/{chat['id']}/prepare").raise_for_status() +client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() response = client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna?"}, ).raise_for_status() answer = response.json() @@ -357,7 +359,7 @@ def answer( # Let's stop the REST API and have a look at what would have printed in the terminal if # we had started it with the `ragna api` command. -rest_api.stop() +ragna_deploy.terminate() # %% # ### Web UI diff --git a/docs/tutorials/gallery_rest_api.py b/docs/tutorials/gallery_rest_api.py index befcbfb3..590cbdf5 100644 --- a/docs/tutorials/gallery_rest_api.py +++ b/docs/tutorials/gallery_rest_api.py @@ -3,9 +3,9 @@ Ragna was designed to help you quickly build custom RAG powered web applications. For this you can leverage the built-in -[REST API](../../references/rest-api.md). +[REST API](../../references/deploy.md). -This tutorial walks you through basic steps of using Ragnas REST API. +This tutorial walks you through basic steps of using Ragna's REST API. """ # %% @@ -39,42 +39,28 @@ config = Config() -rest_api = ragna_docs.RestApi() -_ = rest_api.start(config) +ragna_deploy = ragna_docs.RagnaDeploy(config=config) # %% # Let's make sure the REST API is started correctly and can be reached. import httpx -client = httpx.Client(base_url=config.api.url) -client.get("/").raise_for_status() +client = httpx.Client(base_url=f"http://{config.hostname}:{config.port}") +client.get("/health").raise_for_status() # %% # ## Step 2: Authentication # -# In order to use Ragnas REST API, we need to authenticate first. To forge an API token +# In order to use Ragna's REST API, we need to authenticate first. To forge an API token # we send a request to the `/token` endpoint. This is processed by the -# [`Authentication`][ragna.deploy.Authentication], which can be overridden through the +# [ragna.deploy.Auth][] class, which can be overridden through the # config. For this tutorial, we use the default -# [ragna.deploy.RagnaDemoAuthentication][], which requires a matching username and +# [ragna.deploy.NoAuth][], which requires a matching username and # password. -username = password = "Ragna" - -response = client.post( - "/token", - data={"username": username, "password": password}, -).raise_for_status() -token = response.json() - -# %% -# We set the API token on our HTTP client so we don't have to manually supply it for -# each request below. - -client.headers["Authorization"] = f"Bearer {token}" - +client.get("/login", follow_redirects=True) # %% # ## Step 3: Uploading documents @@ -84,7 +70,7 @@ import json -response = client.get("/components").raise_for_status() +response = client.get("/api/components").raise_for_status() print(json.dumps(response.json(), indent=2)) # %% @@ -110,9 +96,9 @@ # 2. Perform the actual upload with the information from step 1. response = client.post( - "/document", json={"name": document_path.name} + "/api/documents", json=[{"name": document_path.name}] ).raise_for_status() -document_upload = response.json() +documents = response.json() print(json.dumps(response.json(), indent=2)) # %% @@ -125,15 +111,10 @@ # # We perform the actual upload with the latter now. -document = document_upload["document"] - -parameters = document_upload["parameters"] -client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": open(document_path, "rb")}, -).raise_for_status() +client.put( + "/api/documents", + files=[("documents", (documents[0]["id"], open(document_path, "rb")))], +) # %% # ## Step 4: Select a source storage and assistant @@ -155,13 +136,12 @@ # be used, we can create a new chat. response = client.post( - "/chats", + "/api/chats", json={ "name": "Tutorial REST API", - "documents": [document], + "document_ids": [document["id"] for document in documents], "source_storage": source_storage, "assistant": assistant, - "params": {}, }, ).raise_for_status() chat = response.json() @@ -171,13 +151,13 @@ # As can be seen by the `"prepared"` field in the `chat` JSON object we still need to # prepare it. -client.post(f"/chats/{chat['id']}/prepare").raise_for_status() +client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() # %% # Finally, we can get answers to our questions. response = client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna?"}, ).raise_for_status() answer = response.json() @@ -188,7 +168,8 @@ print(answer["content"]) # %% -# Before we close the tutorial, let's stop the REST API and have a look at what would -# have printed in the terminal if we had started it with the `ragna api` command. +# Before we close the tutorial, let's terminate the REST API and have a look at what +# would have printed in the terminal if we had started it with the `ragna deploy` +# command. -rest_api.stop() +ragna_deploy.terminate() diff --git a/mkdocs.yml b/mkdocs.yml index a31d6ffe..787bae60 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -91,7 +91,7 @@ nav: - community/contribute.md - References: - references/python-api.md - - references/rest-api.md + - references/deploy.md - references/cli.md - references/config.md - references/faq.md diff --git a/ragna/_docs.py b/ragna/_docs.py index 0d6191d4..d06fd215 100644 --- a/ragna/_docs.py +++ b/ragna/_docs.py @@ -11,11 +11,12 @@ import httpx -from ragna._utils import timeout_after from ragna.core import RagnaException from ragna.deploy import Config -__all__ = ["SAMPLE_CONTENT", "RestApi"] +from ._utils import BackgroundSubprocess + +__all__ = ["SAMPLE_CONTENT", "RagnaDeploy"] SAMPLE_CONTENT = """\ Ragna is an open source project built by Quansight. It is designed to allow @@ -29,51 +30,25 @@ """ -class RestApi: - def __init__(self) -> None: - self._process: Optional[subprocess.Popen] = None - # In case the documentation errors before we call RestApi.stop, we still need to - # stop the server to avoid zombie processes - atexit.register(self.stop, quiet=True) - - def start( - self, - config: Config, - *, - authenticate: bool = False, - upload_document: bool = False, - ) -> tuple[httpx.Client, Optional[dict]]: - if upload_document and not authenticate: - raise RagnaException( - "Cannot upload a document without authenticating first. " - "Set authenticate=True when using upload_document=True." - ) - python_path, config_path = self._prepare_config(config) - - client = httpx.Client(base_url=config.api.url) - - self._process = self._start_api(config_path, python_path, client) +class RagnaDeploy: + def __init__(self, config: Config) -> None: + self.config = config + python_path, config_path = self._prepare_config() + self._process = self._deploy(config, config_path, python_path) + # In case the documentation errors before we call RagnaDeploy.terminate, + # we still need to stop the server to avoid zombie processes + atexit.register(self.terminate, quiet=True) - if authenticate: - self._authenticate(client) - - if upload_document: - document = self._upload_document(client) - else: - document = None - - return client, document - - def _prepare_config(self, config: Config) -> tuple[str, str]: + def _prepare_config(self) -> tuple[str, str]: deploy_directory = Path(tempfile.mkdtemp()) - python_path = ( - f"{deploy_directory}{os.pathsep}{os.environ.get('PYTHONPATH', '')}" + python_path = os.pathsep.join( + [str(deploy_directory), os.environ.get("PYTHONPATH", "")] ) config_path = str(deploy_directory / "ragna.toml") - config.local_root = deploy_directory - config.api.database_url = f"sqlite:///{deploy_directory / 'ragna.db'}" + self.config.local_root = deploy_directory + self.config.database_url = f"sqlite:///{deploy_directory / 'ragna.db'}" sys.modules["__main__"].__file__ = inspect.getouterframes( inspect.currentframe() @@ -88,98 +63,92 @@ def _prepare_config(self, config: Config) -> tuple[str, str]: file.write("from ragna import *\n") file.write("from ragna.core import *\n") - for component in itertools.chain(config.source_storages, config.assistants): + for component in itertools.chain( + self.config.source_storages, self.config.assistants + ): if component.__module__ == "__main__": custom_components.add(component) file.write(f"{textwrap.dedent(inspect.getsource(component))}\n\n") component.__module__ = custom_module - config.to_file(config_path) + self.config.to_file(config_path) for component in custom_components: component.__module__ = "__main__" return python_path, config_path - def _start_api( - self, config_path: str, python_path: str, client: httpx.Client - ) -> subprocess.Popen: + def _deploy( + self, config: Config, config_path: str, python_path: str + ) -> BackgroundSubprocess: env = os.environ.copy() env["PYTHONPATH"] = python_path - process = subprocess.Popen( - [sys.executable, "-m", "ragna", "api", "--config", config_path], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - ) - - def check_api_available() -> bool: + def startup_fn() -> bool: try: - return client.get("/").is_success + return httpx.get(f"{config._url}/health").is_success except httpx.ConnectError: return False - failure_message = "Failed to the start the Ragna REST API." - - @timeout_after(60, message=failure_message) - def wait_for_api() -> None: - print("Starting Ragna REST API") - while not check_api_available(): - try: - stdout, stderr = process.communicate(timeout=1) - except subprocess.TimeoutExpired: - print(".", end="") - continue - else: - parts = [failure_message] - if stdout: - parts.append(f"\n\nSTDOUT:\n\n{stdout.decode()}") - if stderr: - parts.append(f"\n\nSTDERR:\n\n{stderr.decode()}") - - raise RuntimeError("".join(parts)) - - print() - - wait_for_api() - return process - - def _authenticate(self, client: httpx.Client) -> None: - username = password = "Ragna" + if startup_fn(): + raise RagnaException("ragna server is already running") + + return BackgroundSubprocess( + sys.executable, + "-m", + "ragna", + "deploy", + "--api", + "--no-ui", + "--config", + config_path, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + startup_fn=startup_fn, + startup_timeout=60, + ) - response = client.post( - "/token", - data={"username": username, "password": password}, - ).raise_for_status() - token = response.json() + def get_http_client( + self, + *, + authenticate: bool = False, + upload_document: bool = False, + ) -> tuple[httpx.Client, Optional[dict[str, Any]]]: + if upload_document and not authenticate: + raise RagnaException( + "Cannot upload a document without authenticating first. " + "Set authenticate=True when using upload_document=True." + ) - client.headers["Authorization"] = f"Bearer {token}" + client = httpx.Client(base_url=self.config._url) - def _upload_document(self, client: httpx.Client) -> dict[str, Any]: - name, content = "ragna.txt", SAMPLE_CONTENT + if authenticate: + client.get("/login", follow_redirects=True) - response = client.post("/document", json={"name": name}).raise_for_status() - document_upload = response.json() + if upload_document: + name, content = "ragna.txt", SAMPLE_CONTENT - document = cast(dict[str, Any], document_upload["document"]) + response = client.post( + "/api/documents", json=[{"name": name}] + ).raise_for_status() + document = cast(dict[str, Any], response.json()[0]) - parameters = document_upload["parameters"] - client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": content}, - ).raise_for_status() + client.put( + "/api/documents", + files=[("documents", (document["id"], content.encode()))], + ) + else: + document = None - return document + return client, document - def stop(self, *, quiet: bool = False) -> None: + def terminate(self, quiet: bool = False) -> None: if self._process is None: return - self._process.terminate() - stdout, _ = self._process.communicate() + output = self._process.terminate() - if not quiet: - print(stdout.decode()) + if output and not quiet: + stdout, _ = output + print(stdout) diff --git a/ragna/_utils.py b/ragna/_utils.py index fe52349f..32bb24c5 100644 --- a/ragna/_utils.py +++ b/ragna/_utils.py @@ -1,10 +1,15 @@ +from __future__ import annotations + import contextlib import functools import getpass import inspect import os +import shlex +import subprocess import sys import threading +import time from pathlib import Path from typing import ( Any, @@ -78,7 +83,7 @@ def timeout_after( seconds: float = 30, *, message: str = "" ) -> Callable[[Callable], Callable]: timeout = f"Timeout after {seconds:.1f} seconds" - message = timeout if message else f"{timeout}: {message}" + message = f"{timeout}: {message}" if message else timeout def decorator(fn: Callable) -> Callable: if is_debugging(): @@ -162,3 +167,52 @@ def default_user() -> str: with contextlib.suppress(Exception): return os.getlogin() return "Bodil" + + +class BackgroundSubprocess: + def __init__( + self, + *cmd: str, + stdout: Any = sys.stdout, + stderr: Any = sys.stdout, + text: bool = True, + startup_fn: Optional[Callable[[], bool]] = None, + startup_timeout: float = 10, + terminate_timeout: float = 10, + **subprocess_kwargs: Any, + ) -> None: + self._process = subprocess.Popen( + cmd, stdout=stdout, stderr=stderr, **subprocess_kwargs + ) + try: + if startup_fn: + + @timeout_after(startup_timeout, message=shlex.join(cmd)) + def wait() -> None: + while not startup_fn(): + time.sleep(0.2) + + wait() + except Exception: + self.terminate() + raise + + self._terminate_timeout = terminate_timeout + + def terminate(self) -> tuple[str, str]: + @timeout_after(self._terminate_timeout) + def terminate() -> tuple[str, str]: + self._process.terminate() + return self._process.communicate() + + try: + return terminate() # type: ignore[no-any-return] + except TimeoutError: + self._process.kill() + return self._process.communicate() + + def __enter__(self) -> BackgroundSubprocess: + return self + + def __exit__(self, *exc_info: Any) -> None: + self.terminate() diff --git a/ragna/core/__init__.py b/ragna/core/__init__.py index 44449775..1cdbc667 100644 --- a/ragna/core/__init__.py +++ b/ragna/core/__init__.py @@ -4,7 +4,6 @@ "Component", "Document", "DocumentHandler", - "DocumentUploadParameters", "DocxDocumentHandler", "PptxDocumentHandler", "EnvVarRequirement", diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 23e47d78..915e941a 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -145,7 +145,6 @@ def chat( Args: documents: Documents to use. If any item is not a [ragna.core.Document][], [ragna.core.LocalDocument.from_path][] is invoked on it. - FIXME source_storage: Source storage to use. assistant: Assistant to use. **params: Additional parameters passed to the source storage and assistant. diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py index 283fb371..45faefe9 100644 --- a/ragna/deploy/_auth.py +++ b/ragna/deploy/_auth.py @@ -200,6 +200,10 @@ async def _get_user(session: SessionDependency) -> schemas.User: class Auth(abc.ABC): + """ + ADDME + """ + @classmethod def _add_to_app( cls, app: FastAPI, *, config: Config, engine: Engine, api: bool, ui: bool @@ -247,6 +251,10 @@ def login(self, request: Request) -> Union[schemas.User, Response]: ... class NoAuth(Auth): + """ + ADDME + """ + def login_page(self, request: Request) -> Response: # To invoke the login() method below, the client either needs to # - POST /login or diff --git a/scripts/docs/gen_files.py b/scripts/docs/gen_files.py index d3760181..8a89179d 100644 --- a/scripts/docs/gen_files.py +++ b/scripts/docs/gen_files.py @@ -9,7 +9,7 @@ import typer.rich_utils from ragna.deploy import Config -from ragna.deploy._api import app as api_app +from ragna.deploy._core import make_app as make_deploy_app # This is currently needed when using top-level async code in the galleries. It has to # be placed before the ragna.deploy._cli import as this ultimately import panel, which @@ -17,12 +17,12 @@ # See https://github.com/smarie/mkdocs-gallery/issues/93 asyncio.get_event_loop_policy()._local._set_called = False -from ragna.deploy._cli import app as cli_app # noqa: E402 +from ragna._cli import app as cli_app # noqa: E402 def main(): cli_reference() - api_reference() + deploy_reference() config_reference() @@ -51,8 +51,14 @@ def get_doc(command): file.write(get_doc(command.name or command.callback.__name__)) -def api_reference(): - app = api_app(config=Config(), ignore_unavailable_components=False) +def deploy_reference(): + app = make_deploy_app( + config=Config(), + api=True, + ui=True, + ignore_unavailable_components=False, + open_browser=False, + ) openapi_json = fastapi.openapi.utils.get_openapi( title=app.title, version=app.version, diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index 10ec9506..b46e6d0e 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -2,7 +2,6 @@ import itertools import json import os -import time from pathlib import Path import httpx @@ -10,10 +9,10 @@ from ragna import assistants from ragna._compat import anext -from ragna._utils import timeout_after +from ragna._utils import BackgroundSubprocess from ragna.assistants._http_api import HttpApiAssistant, HttpStreamingProtocol from ragna.core import Message, RagnaException -from tests.utils import background_subprocess, get_available_port, skip_on_windows +from tests.utils import get_available_port, skip_on_windows HTTP_API_ASSISTANTS = [ assistant @@ -44,26 +43,19 @@ def streaming_server(): port = get_available_port() base_url = f"http://localhost:{port}" - with background_subprocess( + def check_fn(): + try: + return httpx.get(f"{base_url}/health").is_success + except httpx.ConnectError: + return False + + with BackgroundSubprocess( "uvicorn", f"--app-dir={Path(__file__).parent}", f"--port={port}", "streaming_server:app", + startup_fn=check_fn, ): - - def up(): - try: - return httpx.get(f"{base_url}/health").is_success - except httpx.ConnectError: - return False - - @timeout_after(10, message="Failed to start streaming server") - def wait(): - while not up(): - time.sleep(0.2) - - wait() - yield base_url diff --git a/tests/utils.py b/tests/utils.py index b4d37540..87081f65 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,37 +1,13 @@ -import contextlib import platform import socket -import subprocess -import sys import pytest -from ragna._utils import timeout_after - skip_on_windows = pytest.mark.skipif( platform.system() == "Windows", reason="Test is broken skipped on Windows" ) -@contextlib.contextmanager -def background_subprocess(*args, stdout=sys.stdout, stderr=sys.stdout, **kwargs): - process = subprocess.Popen(args, stdout=stdout, stderr=stderr, **kwargs) - try: - yield process - finally: - - @timeout_after(5) - def terminate(): - process.terminate() - process.communicate() - - try: - terminate() - except TimeoutError: - process.kill() - process.communicate() - - def get_available_port(): with socket.socket() as s: s.bind(("", 0)) From 8abaadeee9690db7c29644c493798febf47d5266 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 26 Aug 2024 16:39:27 +0200 Subject: [PATCH 06/10] more doc fixes --- docs/references/config.md | 38 ++++++------------------ docs/references/release-notes.md | 6 ++-- docs/tutorials/gallery_rest_api.py | 46 +++++++++++++++++------------- 3 files changed, 37 insertions(+), 53 deletions(-) diff --git a/docs/references/config.md b/docs/references/config.md index c6ed18b9..ba83263a 100644 --- a/docs/references/config.md +++ b/docs/references/config.md @@ -69,9 +69,9 @@ is equivalent to `RAGNA_API_ORIGINS='["http://localhost:31477"]'`. Local root directory Ragna uses for storing files. See [ragna.local_root][]. -### `authentication` +### `auth` -[ragna.deploy.Authentication][] class to use for authenticating users. +[ragna.deploy.Auth][] class to use for authenticating users. ### `document` @@ -85,48 +85,26 @@ Local root directory Ragna uses for storing files. See [ragna.local_root][]. [ragna.core.Assistant][]s to be available for the user to use. -### `api` - -#### `hostname` +### `hostname` Hostname the REST API will be bound to. -#### `port` +### `port` Port the REST API will be bound to. -#### `root_path` +### `root_path` A path prefix handled by a proxy that is not seen by the REST API, but is seen by external clients. -#### `url` - -URL of the REST API to be accessed by the web UI. Make sure to include the -[`root_path`](#root_path) if set. - -#### `origins` +### `origins` [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) origins that are allowed -to connect to the REST API. The URL of the web UI is required for it to function. +to connect to the REST API. -#### `database_url` +### `database_url` URL of a SQL database that will be used to store the Ragna state. See [SQLAlchemy documentation](https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls) on how to format the URL. - -### `ui` - -#### `hostname` - -Hostname the web UI will be bound to. - -#### `port` - -Port the web UI will be bound to. - -#### `origins` - -[CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) origins that are allowed -to connect to the web UI. diff --git a/docs/references/release-notes.md b/docs/references/release-notes.md index 17457078..40ab7ff6 100644 --- a/docs/references/release-notes.md +++ b/docs/references/release-notes.md @@ -137,9 +137,9 @@ -- The classes [ragna.deploy.Authentication][], [ragna.deploy.RagnaDemoAuthentication][], - and [ragna.deploy.Config][] moved from the [ragna.core][] module to a new - [ragna.deploy][] module. +- The classes `ragna.deploy.Authentication`, `ragna.deploy.RagnaDemoAuthentication`, and + [ragna.deploy.Config][] moved from the [ragna.core][] module to a new [ragna.deploy][] + module. - [ragna.core.Component][], which is the superclass for [ragna.core.Assistant][] and [ragna.core.SourceStorage][], no longer takes a [ragna.deploy.Config][] to instantiate. For example diff --git a/docs/tutorials/gallery_rest_api.py b/docs/tutorials/gallery_rest_api.py index 590cbdf5..ede8833d 100644 --- a/docs/tutorials/gallery_rest_api.py +++ b/docs/tutorials/gallery_rest_api.py @@ -53,14 +53,25 @@ # %% # ## Step 2: Authentication # -# In order to use Ragna's REST API, we need to authenticate first. To forge an API token -# we send a request to the `/token` endpoint. This is processed by the -# [ragna.deploy.Auth][] class, which can be overridden through the -# config. For this tutorial, we use the default -# [ragna.deploy.NoAuth][], which requires a matching username and -# password. +# In order to use Ragna's REST API, we need to authenticate first. This is handled by +# the [ragna.deploy.Auth][] class, which can be overridden through the config. By +# default, [ragna.deploy.NoAuth][] is used. By hitting the `/login` endpoint, we get a +# session cookie, which is later used to authorize our requests. client.get("/login", follow_redirects=True) +dict(client.cookies) + +# %% +# !!! note +# +# In a regular deployment, you'll have login through your browser and create an API +# key in your profile page. The API key is used as +# [bearer token](https://swagger.io/docs/specification/authentication/bearer-authentication/) +# and can be set with +# +# ```python +# httpx.Client(..., headers={"Authorization": f"Bearer {RAGNA_API_KEY}"}) +# ``` # %% # ## Step 3: Uploading documents @@ -88,28 +99,23 @@ # %% # The upload process in Ragna consists of two parts: # -# 1. Announce the file to be uploaded. Under the hood this pre-registers the document -# in Ragnas database and returns information about how the upload is to be performed. -# This is handled by the [ragna.core.Document][] class. By default, -# [ragna.core.LocalDocument][] is used, which uploads the files to the local file -# system. -# 2. Perform the actual upload with the information from step 1. +# 1. Announce the file to be uploaded. Under the hood this registers the document +# in Ragna's database and returns the document ID, which is needed for the upload. response = client.post( "/api/documents", json=[{"name": document_path.name}] ).raise_for_status() documents = response.json() -print(json.dumps(response.json(), indent=2)) +print(json.dumps(documents, indent=2)) # %% -# The returned JSON contains two parts: the document object that we are later going to -# use to create a chat as well as the upload parameters. -# !!! note -# -# The `"token"` in the response is *not* the Ragna REST API token, but rather a -# separate one to perform the document upload. +# 2. Perform the actual upload with the information from step 1. through a +# [multipart request](https://swagger.io/docs/specification/describing-request-body/multipart-requests/) +# with the following parameters: # -# We perform the actual upload with the latter now. +# - The field is `documents` for all entries +# - The field name is the ID of the document returned by step 1. +# - The field value is the binary content of the document. client.put( "/api/documents", From 876e91d8ad9a38af96e72c369331d7670f3428cb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 29 Aug 2024 12:42:33 +0200 Subject: [PATCH 07/10] dirty --- ragna/deploy/_auth.py | 43 +++++++---- ragna/deploy/_config.py | 1 + ragna/deploy/_core.py | 15 ++++ ragna/deploy/_database.py | 124 +++++++++++++++++++++---------- ragna/deploy/_engine.py | 40 +++++++++- ragna/deploy/_key_value_store.py | 80 ++++++++++++++------ ragna/deploy/_orm.py | 50 +++++++++++-- ragna/deploy/_schemas.py | 50 ++++++++++++- 8 files changed, 314 insertions(+), 89 deletions(-) diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py index 45faefe9..b81f5bed 100644 --- a/ragna/deploy/_auth.py +++ b/ragna/deploy/_auth.py @@ -54,7 +54,7 @@ async def dispatch(self, request: Request, call_next: CallNext) -> Response: elif request.url.path in {"/login", "/oauth-callback"}: return await self._login_dispatch(request, call_next) elif self._api and request.url.path.startswith("/api"): - return self._unauthorized() + return self._unauthorized("Missing authorization header") elif self._ui and request.url.path.startswith("/ui"): return redirect("/login") else: @@ -68,17 +68,25 @@ async def _api_token_dispatch( ) -> Response: scheme, api_key = get_authorization_scheme_param(authorization) if scheme.lower() != "bearer": - return self._unauthorized() + return self._unauthorized("Bearer authentication scheme required") - user = self._engine.get_user(api_key=api_key) - if user is None: - # Unknown API key - return self._unauthorized() + user, expired = self._engine.get_user_by_api_key(api_key) + if user is None or expired: + self._sessions.delete(api_key) + reason = "Invalid" if user is None else "Expired" + return self._unauthorized(f"{reason} API key") session = self._sessions.get(api_key) if session is None: # First time the API key is used - session = self._sessions[api_key] = Session(user=user) + session = Session(user=user) + # We are using the API key value instead of its ID as session key for two + # reasons: + # 1. Similar to its ID, the value is unique and thus can be safely used as + # key. + # 2. If an API key was deleted, we lose its ID, but still need to be able to + # remove its corresponding session. + self._sessions.set(api_key, session, expires_after=3600) request.state.session = session return await call_next(request) @@ -137,9 +145,10 @@ async def _cookie_dispatch( response = await call_next(request) if request.url.path == "/logout": - del self._sessions[cookie] + self._sessions.delete(cookie) self._delete_cookie(response) else: + self._sessions.refresh(cookie, expires_after=self._config.session_lifetime) self._add_cookie(response, cookie) return response @@ -151,14 +160,16 @@ async def _login_dispatch(self, request: Request, call_next: CallNext) -> Respon if session is not None: cookie = str(uuid.uuid4()) - self._sessions[cookie] = session + self._sessions.set( + cookie, session, expires_after=self._config.session_lifetime + ) self._add_cookie(response, cookie=cookie) return response - def _unauthorized(self) -> Response: + def _unauthorized(self, message: str) -> Response: return Response( - content="Not authenticated", + content=message, status_code=status.HTTP_401_UNAUTHORIZED, headers={"WWW-Authenticate": "Bearer"}, ) @@ -167,9 +178,7 @@ def _add_cookie(self, response: Response, cookie: str) -> None: response.set_cookie( key=self._COOKIE_NAME, value=cookie, - # FIXME - max_age=3600, - # max_age=self._config.deploy.cookie_expires, + max_age=self._config.session_lifetime, httponly=True, samesite="lax", ) @@ -185,7 +194,11 @@ def _delete_cookie(self, response: Response) -> None: async def _get_session(request: Request) -> Session: session = cast(Optional[Session], request.state.session) if session is None: - raise RagnaException("ADDME") + raise RagnaException( + "Not authenticated", + http_detail=RagnaException.EVENT, + http_status_code=status.HTTP_401_UNAUTHORIZED, + ) return session diff --git a/ragna/deploy/_config.py b/ragna/deploy/_config.py index 46b70e72..7a30b097 100644 --- a/ragna/deploy/_config.py +++ b/ragna/deploy/_config.py @@ -109,6 +109,7 @@ def settings_customise_sources( origins: list[str] = AfterConfigValidateDefault.make( lambda config: [f"http://{config.hostname}:{config.port}"] ) + session_lifetime: int = 60 * 60 * 24 database_url: str = AfterConfigValidateDefault.make( lambda config: f"sqlite:///{config.local_root}/ragna.db", diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py index 47597fc6..067fe6ed 100644 --- a/ragna/deploy/_core.py +++ b/ragna/deploy/_core.py @@ -1,6 +1,7 @@ import contextlib import threading import time +import uuid import webbrowser from typing import AsyncContextManager, AsyncIterator, Callable, Optional, cast @@ -102,6 +103,20 @@ async def version() -> str: async def user(user: UserDependency) -> schemas.User: return user + @app.get("/api-keys") + def list_api_keys(user: UserDependency) -> list[schemas.ApiKey]: + return engine.list_api_keys(user=user.name) + + @app.post("/api-keys") + def create_api_key( + user: UserDependency, api_key_creation: schemas.ApiKeyCreation + ) -> schemas.ApiKey: + return engine.create_api_key(user=user.name, api_key_creation=api_key_creation) + + @app.delete("/api-keys/{id}") + def delete_api_key(user: UserDependency, id: uuid.UUID) -> None: + return engine.delete_api_key(user=user.name, id=id) + @app.exception_handler(RagnaException) async def ragna_exception_handler( request: Request, exc: RagnaException diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py index 6855c389..54f7b9f2 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -1,6 +1,5 @@ from __future__ import annotations -import secrets import uuid from typing import Any, Collection, Optional, cast from urllib.parse import urlsplit @@ -37,52 +36,88 @@ def __init__(self, url: str) -> None: self._to_orm = SchemaToOrmConverter() self._to_schema = OrmToSchemaConverter() - def _get_orm_user( - self, - session: Session, - *, - name: Optional[str] = None, - api_key: Optional[str] = None, - ) -> orm.User: - selector = select(orm.User) - if name is None and api_key is None: - raise RagnaException - elif name is not None: - selector = selector.where(orm.User.name == name) - elif api_key is not None: - selector = selector.where(orm.User.api_key == api_key) - - user = cast(Optional[orm.User], session.execute(selector).scalar_one_or_none()) + def _get_orm_user_by_name(self, session: Session, *, name: str) -> orm.User: + user = cast( + Optional[orm.User], + session.execute( + select(orm.User).where(orm.User.name == name) + ).scalar_one_or_none(), + ) if user is None: - raise UnknownUser(name=name, api_key=api_key) + raise UnknownUser(name) return user def maybe_add_user(self, session: Session, *, user: schemas.User) -> None: try: - self._get_orm_user(session, name=user.name) + self._get_orm_user_by_name(session, name=user.name) except UnknownUser: - orm_user = orm.User( - id=uuid.uuid4(), name=user.name, api_key=secrets.token_urlsafe(32)[:32] - ) + orm_user = orm.User(id=uuid.uuid4(), name=user.name) session.add(orm_user) session.commit() - def get_user( - self, - session: Session, - *, - name: Optional[str] = None, - api_key: Optional[str] = None, - ) -> Optional[schemas.User]: - try: - return self._to_schema.user( - self._get_orm_user(session, name=name, api_key=api_key) + def add_api_key( + self, session: Session, *, user: str, api_key: schemas.ApiKey + ) -> None: + user_id = self._get_orm_user_by_name(session, name=user).id + orm_api_key = orm.ApiKey( + id=uuid.uuid4(), + user_id=user_id, + name=api_key.name, + value=api_key.value, + expires_at=api_key.expires_at, + ) + session.add(orm_api_key) + session.commit() + + def get_api_keys(self, session: Session, *, user: str) -> list[schemas.ApiKey]: + return [ + self._to_schema.api_key(api_key) + for api_key in session.execute( + select(orm.ApiKey).where( + orm.ApiKey.user_id + == self._get_orm_user_by_name(session, name=user).id + ) ) - except UnknownUser: + .scalars() + .all() + ] + + def delete_api_key(self, session: Session, *, user: str, id: uuid.UUID) -> None: + orm_api_key = session.execute( + select(orm.ApiKey).where( + (orm.ApiKey.id == id) + & ( + orm.ApiKey.user_id + == self._get_orm_user_by_name(session, name=user).id + ) + ) + ).scalar_one_or_none() + + if orm_api_key is None: + raise RagnaException + + session.delete(orm_api_key) # type: ignore[no-untyped-call] + session.commit() + + def get_user_by_api_key( + self, session: Session, api_key_value: str + ) -> Optional[tuple[schemas.User, schemas.ApiKey]]: + orm_api_key = session.execute( + select(orm.ApiKey) # type: ignore[attr-defined] + .options(joinedload(orm.ApiKey.user)) + .where(orm.ApiKey.value == api_key_value) + ).scalar_one_or_none() + + if orm_api_key is None: return None + return ( + self._to_schema.user(orm_api_key.user), + self._to_schema.api_key(orm_api_key), + ) + def add_documents( self, session: Session, @@ -90,7 +125,7 @@ def add_documents( user: str, documents: list[schemas.Document], ) -> None: - user_id = self._get_orm_user(session, name=user).id + user_id = self._get_orm_user_by_name(session, name=user).id session.add_all( [self._to_orm.document(document, user_id=user_id) for document in documents] ) @@ -123,7 +158,7 @@ def get_documents( def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None: orm_chat = self._to_orm.chat( - chat, user_id=self._get_orm_user(session, name=user).id + chat, user_id=self._get_orm_user_by_name(session, name=user).id ) # We need to merge and not add here, because the documents are already in the DB session.merge(orm_chat) @@ -143,7 +178,8 @@ def get_chats(self, session: Session, *, user: str) -> list[schemas.Chat]: self._to_schema.chat(chat) for chat in session.execute( self._select_chat(eager=True).where( - orm.Chat.user_id == self._get_orm_user(session, name=user).id + orm.Chat.user_id + == self._get_orm_user_by_name(session, name=user).id ) ) .scalars() @@ -158,7 +194,10 @@ def _get_orm_chat( session.execute( self._select_chat(eager=eager).where( (orm.Chat.id == id) - & (orm.Chat.user_id == self._get_orm_user(session, name=user).id) + & ( + orm.Chat.user_id + == self._get_orm_user_by_name(session, name=user).id + ) ) ) .unique() @@ -175,7 +214,7 @@ def get_chat(self, session: Session, *, user: str, id: uuid.UUID) -> schemas.Cha def update_chat(self, session: Session, user: str, chat: schemas.Chat) -> None: orm_chat = self._to_orm.chat( - chat, user_id=self._get_orm_user(session, name=user).id + chat, user_id=self._get_orm_user_by_name(session, name=user).id ) session.merge(orm_chat) session.commit() @@ -243,6 +282,15 @@ class OrmToSchemaConverter: def user(self, user: orm.User) -> schemas.User: return schemas.User(name=user.name) + def api_key(self, api_key: orm.ApiKey) -> schemas.ApiKey: + return schemas.ApiKey( + id=api_key.id, + name=api_key.name, + expires_at=api_key.expires_at, + obfuscated=True, + value=api_key.value, + ) + def document(self, document: orm.Document) -> schemas.Document: return schemas.Document( id=document.id, name=document.name, metadata=document.metadata_ diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index f0430177..45e9376e 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -1,3 +1,4 @@ +import secrets import uuid from typing import Any, AsyncIterator, Optional, Type, cast @@ -38,11 +39,42 @@ def maybe_add_user(self, user: schemas.User) -> None: with self._database.get_session() as session: return self._database.maybe_add_user(session, user=user) - def get_user( - self, name: Optional[str] = None, api_key: Optional[str] = None - ) -> Optional[schemas.User]: + def get_user_by_api_key( + self, api_key_value: str + ) -> tuple[Optional[schemas.User], bool]: with self._database.get_session() as session: - return self._database.get_user(session, name=name, api_key=api_key) + data = self._database.get_user_by_api_key( + session, api_key_value=api_key_value + ) + + if data is None: + return None, False + + user, api_key = data + return user, api_key.expired + + def create_api_key( + self, user: str, api_key_creation: schemas.ApiKeyCreation + ) -> schemas.ApiKey: + api_key = schemas.ApiKey( + name=api_key_creation.name, + expires_at=api_key_creation.expires_at, + obfuscated=False, + value=secrets.token_urlsafe(32)[:32], + ) + + with self._database.get_session() as session: + self._database.add_api_key(session, user=user, api_key=api_key) + + return api_key + + def list_api_keys(self, user: str) -> list[schemas.ApiKey]: + with self._database.get_session() as session: + return self._database.get_api_keys(session, user=user) + + def delete_api_key(self, user: str, id: uuid.UUID) -> None: + with self._database.get_session() as session: + self._database.delete_api_key(session, user=user, id=id) def _get_component_json_schema( self, diff --git a/ragna/deploy/_key_value_store.py b/ragna/deploy/_key_value_store.py index a22e53ca..8218aacc 100644 --- a/ragna/deploy/_key_value_store.py +++ b/ragna/deploy/_key_value_store.py @@ -2,7 +2,8 @@ import abc import os -from typing import Any, Generic, Optional, TypeVar +import time +from typing import Any, Callable, Generic, Optional, TypeVar, Union, cast import pydantic @@ -28,38 +29,62 @@ class KeyValueStore(abc.ABC, RequirementsMixin, Generic[M]): def serialize(self, model: M) -> str: return SerializableModel.from_model(model).model_dump_json() - def deserialize(self, json_str: str) -> M: - return SerializableModel.model_validate_json(json_str).to_model() + def deserialize(self, data: Union[str, bytes]) -> M: + return SerializableModel.model_validate_json(data).to_model() @abc.abstractmethod - def __setitem__(self, key: str, model: M) -> None: ... + def set( + self, key: str, model: M, *, expires_after: Optional[int] = None + ) -> None: ... @abc.abstractmethod - def __getitem__(self, key: str) -> M: ... + def get(self, key: str) -> Optional[M]: ... @abc.abstractmethod - def __delitem__(self, key: str) -> None: ... + def delete(self, key: str) -> None: ... - def get(self, key: str, default: Optional[M] = None) -> Optional[M]: - try: - return self[key] - except KeyError: - return default + @abc.abstractmethod + def refresh(self, key: str, *, expires_after: Optional[int] = None) -> None: ... class InMemoryKeyValueStore(KeyValueStore[M]): def __init__(self) -> None: - self._store: dict[str, M] = {} + self._store: dict[str, tuple[M, Optional[float]]] = {} + self._timer: Callable[[], float] = time.monotonic + + def set(self, key: str, model: M, *, expires_after: Optional[int] = None) -> None: + if expires_after is not None: + expires_at = self._timer() + expires_after + else: + expires_at = None + self._store[key] = (model, expires_at) + + def get(self, key: str) -> Optional[M]: + value = self._store.get(key) + if value is None: + return None + + model, expires_at = value + if expires_at is not None and self._timer() >= expires_at: + self.delete(key) + return None - def __setitem__(self, key: str, model: M) -> None: - self._store[key] = model + return model - def __getitem__(self, key: str) -> M: - return self._store[key] + def delete(self, key: str) -> None: + if key not in self._store: + return - def __delitem__(self, key: str) -> None: del self._store[key] + def refresh(self, key: str, *, expires_after: Optional[int] = None) -> None: + value = self._store.get(key) + if value is None: + return + + model, _ = value + self.set(key, model, expires_after=expires_after) + class RedisKeyValueStore(KeyValueStore[M]): @classmethod @@ -74,11 +99,20 @@ def __init__(self) -> None: port=int(os.environ.get("RAGNA_REDIS_PORT", 6379)), ) - def __setitem__(self, key: str, model: M) -> None: - self._r[key] = self.serialize(model) + def set(self, key: str, model: M, *, expires_after: Optional[int] = None) -> None: + self._r.set(key, self.serialize(model), ex=expires_after) + + def get(self, key: str) -> Optional[M]: + value = cast(bytes, self._r.get(key)) + if value is None: + return None + return self.deserialize(value) - def __getitem__(self, key: str) -> M: - return self.deserialize(self._r[key]) + def delete(self, key: str) -> None: + self._r.delete(key) - def __delitem__(self, key: str) -> None: - del self._r[key] + def refresh(self, key: str, *, expires_after: Optional[int] = None) -> None: + if expires_after is None: + self._r.persist(key) + else: + self._r.expire(key, expires_after) diff --git a/ragna/deploy/_orm.py b/ragna/deploy/_orm.py index 7d3b6176..7d672167 100644 --- a/ragna/deploy/_orm.py +++ b/ragna/deploy/_orm.py @@ -1,5 +1,6 @@ import json -from typing import Any +from datetime import datetime, timezone +from typing import Any, Optional from sqlalchemy import Column, ForeignKey, Table, types from sqlalchemy.engine import Dialect @@ -30,19 +31,56 @@ def process_result_value( return json.loads(value) +class UtcAwareDateTime(types.TypeDecorator): + """UTC timezone aware datetime type. + + This is needed because sqlalchemy.types.DateTime(timezone=True) does not + consistently store the timezone. + + """ + + impl = types.DateTime + + def process_bind_param( # type: ignore[override] + self, value: Optional[datetime], dialect: Dialect + ) -> Optional[datetime]: + if value is not None: + assert value.tzinfo == timezone.utc + + return value + + def process_result_value( + self, value: Optional[datetime], dialect: Dialect + ) -> Optional[datetime]: + if value is None: + return None + + return value.replace(tzinfo=timezone.utc) + + class Base(DeclarativeBase): pass -# FIXME: Do we actually need this table? If we are sure that a user is unique and has to -# be authenticated from the API layer, it seems having an extra mapping here is not -# needed? class User(Base): __tablename__ = "users" id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined] name = Column(types.String, nullable=False, unique=True) - api_key = Column(types.String, nullable=False, unique=True) + api_keys = relationship("ApiKey", back_populates="user") + + +class ApiKey(Base): + __tablename__ = "api_keys" + + id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined] + + user_id = Column(ForeignKey("users.id")) + user = relationship("User", back_populates="api_keys") + + name = Column(types.String, nullable=False) + value = Column(types.String, nullable=False, unique=True) + expires_at = Column(UtcAwareDateTime, nullable=False) document_chat_association_table = Table( @@ -135,4 +173,4 @@ class Message(Base): secondary=source_message_association_table, back_populates="messages", ) - timestamp = Column(types.DateTime, nullable=False) + timestamp = Column(types.DateTime(timezone=True), nullable=False) diff --git a/ragna/deploy/_schemas.py b/ragna/deploy/_schemas.py index 7a13192c..f6fd2228 100644 --- a/ragna/deploy/_schemas.py +++ b/ragna/deploy/_schemas.py @@ -1,10 +1,16 @@ from __future__ import annotations -import datetime import uuid +from datetime import datetime, timezone from typing import Any -from pydantic import BaseModel, Field +from pydantic import ( + BaseModel, + Field, + ValidationInfo, + computed_field, + field_validator, +) import ragna.core @@ -14,6 +20,44 @@ class User(BaseModel): data: dict[str, Any] = Field(default_factory=dict) +class ApiKeyCreation(BaseModel): + name: str + expires_at: datetime + + +class ApiKey(BaseModel): + id: uuid.UUID = Field(default_factory=uuid.uuid4) + name: str + expires_at: datetime + obfuscated: bool = True + value: str + + @field_validator("expires_at") + @classmethod + def _set_utc_timezone(cls, v: datetime) -> datetime: + if v.tzinfo is None: + return v.replace(tzinfo=timezone.utc) + else: + return v.astimezone(timezone.utc) + + @computed_field # type: ignore[prop-decorator] + @property + def expired(self) -> bool: + return datetime.now(timezone.utc) >= self.expires_at + + @field_validator("value") + @classmethod + def _maybe_obfuscate(cls, v: str, info: ValidationInfo) -> str: + if not info.data["obfuscated"]: + return v + + i = min(len(v) // 6, 3) + if i > 0: + return f"{v[:i]}***{v[-i:]}" + else: + return "***" + + class Components(BaseModel): documents: list[str] source_storages: list[dict[str, Any]] @@ -45,7 +89,7 @@ class Message(BaseModel): content: str role: ragna.core.MessageRole sources: list[Source] = Field(default_factory=list) - timestamp: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) class ChatCreation(BaseModel): From 4c8b8d4d93403375f64f31a3398af8fb65a21bc3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 4 Sep 2024 16:06:46 +0200 Subject: [PATCH 08/10] add JupyterhubServerProxyAuth --- ragna/deploy/__init__.py | 3 +- ragna/deploy/_auth.py | 71 +++++++++++++++++++++++++++++----------- 2 files changed, 54 insertions(+), 20 deletions(-) diff --git a/ragna/deploy/__init__.py b/ragna/deploy/__init__.py index 70bf42ce..cdb2ba44 100644 --- a/ragna/deploy/__init__.py +++ b/ragna/deploy/__init__.py @@ -4,12 +4,13 @@ "DummyBasicAuth", "GithubOAuth", "InMemoryKeyValueStore", + "JupyterhubServerProxyAuth", "KeyValueStore", "NoAuth", "RedisKeyValueStore", ] -from ._auth import Auth, DummyBasicAuth, GithubOAuth, NoAuth +from ._auth import Auth, DummyBasicAuth, GithubOAuth, JupyterhubServerProxyAuth, NoAuth from ._config import Config from ._key_value_store import InMemoryKeyValueStore, KeyValueStore, RedisKeyValueStore diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py index b81f5bed..62e32827 100644 --- a/ragna/deploy/_auth.py +++ b/ragna/deploy/_auth.py @@ -1,7 +1,9 @@ import abc import base64 +import contextlib import json import os +import re import uuid from typing import Annotated, Awaitable, Callable, Optional, Union, cast @@ -191,18 +193,19 @@ def _delete_cookie(self, response: Response) -> None: ) -async def _get_session(request: Request) -> Session: - session = cast(Optional[Session], request.state.session) - if session is None: - raise RagnaException( - "Not authenticated", - http_detail=RagnaException.EVENT, - http_status_code=status.HTTP_401_UNAUTHORIZED, - ) - return session +class SessionAuth: # (OAuth2) + async def __call__(self, request: Request) -> Session: + session = cast(Optional[Session], request.state.session) + if session is None: + raise RagnaException( + "Not authenticated", + http_detail=RagnaException.EVENT, + http_status_code=status.HTTP_401_UNAUTHORIZED, + ) + return session -SessionDependency = Annotated[Session, Depends(_get_session)] +SessionDependency = Annotated[Session, Depends(SessionAuth())] async def _get_user(session: SessionDependency) -> schemas.User: @@ -263,22 +266,26 @@ def login_page(self, request: Request) -> Response: ... def login(self, request: Request) -> Union[schemas.User, Response]: ... -class NoAuth(Auth): - """ - ADDME - """ - +class _AutomaticLoginAuthBase(Auth): def login_page(self, request: Request) -> Response: - # To invoke the login() method below, the client either needs to + # To invoke the Auth.login() method, the client either needs to # - POST /login or # - GET /oauth-callback # Since we cannot instruct a browser to post when sending redirect response, we - # use the OAuth callback endpoint here, although this has nothing to do with - # OAuth. + # use the OAuth callback endpoint here, although this might have nothing to do + # with OAuth. return redirect("/oauth-callback") + +class NoAuth(_AutomaticLoginAuthBase): + """ + ADDME + """ + def login(self, request: Request) -> schemas.User: - return schemas.User(name=request.headers.get("X-User", default_user())) + return schemas.User( + name=request.headers.get("X-Forwarded-User", default_user()) + ) class DummyBasicAuth(Auth): @@ -374,3 +381,29 @@ async def login(self, request: Request) -> Union[schemas.User, Response]: return HTMLResponse("Unauthorized!") return schemas.User(name=user_data["login"]) + + +class JupyterhubServerProxyAuth(_AutomaticLoginAuthBase): + _JUPYTERHUB_ENV_VAR_PATTERN = re.compile(r"JUPYTERHUB_(?P.+)") + + def __init__(self) -> None: + data = {} + for env_var, value in os.environ.items(): + match = self._JUPYTERHUB_ENV_VAR_PATTERN.match(env_var) + if match is None: + continue + + key = match["key"].lower() + with contextlib.suppress(json.JSONDecodeError): + value = json.loads(value) + + data[key] = value + + name = data.pop("user") + if name is None: + raise RagnaException + + self._user = schemas.User(name=name, data=data) + + def login(self, request: Request) -> schemas.User: + return self._user From 4798540f2bc976fed7f0aefac7e5a5157e2cf301 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 11 Dec 2024 11:12:03 +0100 Subject: [PATCH 09/10] remove styling attempts --- ragna/deploy/_templates/basic_auth.css | 9 --------- 1 file changed, 9 deletions(-) diff --git a/ragna/deploy/_templates/basic_auth.css b/ragna/deploy/_templates/basic_auth.css index a515d1a0..80a7b918 100644 --- a/ragna/deploy/_templates/basic_auth.css +++ b/ragna/deploy/_templates/basic_auth.css @@ -4,12 +4,3 @@ flex-direction: column; justify-content: space-between; } - -.form-control { - background-color: gold; - margin: 10px; -} - -fieldset { - border: 1px solid #000; -} From 9dde033480f1fac455c956f26b4cc81ee4ebbfa1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 11 Dec 2024 13:11:50 +0100 Subject: [PATCH 10/10] cleanup --- environment-dev.yml | 2 +- ragna/deploy/_auth.py | 16 ++++++++++------ ragna/deploy/_templates/base.html | 3 +-- ragna/deploy/_ui/left_sidebar.py | 2 ++ ragna/source_storages/_vector_database.py | 2 +- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/environment-dev.yml b/environment-dev.yml index b389606e..1684b030 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -1,4 +1,4 @@ -name: ragna-deploy-dev +name: ragna-dev channels: - conda-forge dependencies: diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py index bd896674..06492bf4 100644 --- a/ragna/deploy/_auth.py +++ b/ragna/deploy/_auth.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Annotated, Awaitable, Callable, Optional, Union, cast import httpx +import panel as pn import pydantic from fastapi import Depends, FastAPI, Request, status from fastapi.responses import HTMLResponse, RedirectResponse, Response @@ -38,6 +39,11 @@ class Session(pydantic.BaseModel): class SessionMiddleware(BaseHTTPMiddleware): + # panel uses cookies to transfer user information (see _cookie_dispatch() below) and + # signs them for security. However, since this happens after our authentication + # check, we can use an arbitrary, hardcoded value here. + _PANEL_COOKIE_SECRET = "ragna" + def __init__( self, app: FastAPI, *, config: Config, engine: Engine, api: bool, ui: bool ) -> None: @@ -48,6 +54,9 @@ def __init__( self._ui = ui self._sessions: KeyValueStore[Session] = config.key_value_store() + if ui: + pn.config.cookie_secret = self._PANEL_COOKIE_SECRET # type: ignore[misc] + _COOKIE_NAME = "ragna" async def dispatch(self, request: Request, call_next: CallNext) -> Response: @@ -97,11 +106,6 @@ async def _api_token_dispatch( request.state.session = session return await call_next(request) - # panel uses cookies to transfer user information (see _cookie_dispatch() below) and - # signs them for security. However, since this happens after our authentication - # check, we can use an arbitrary, hardcoded value here. - PANEL_COOKIE_SECRET = "ragna" - async def _cookie_dispatch( self, request: Request, call_next: CallNext, *, cookie: str ) -> Response: @@ -128,7 +132,7 @@ async def _cookie_dispatch( ( f"{key}=".encode() + create_signed_value( - self.PANEL_COOKIE_SECRET, key, value, version=1 + self._PANEL_COOKIE_SECRET, key, value, version=1 ) ) for key, value in extra_cookies.items() diff --git a/ragna/deploy/_templates/base.html b/ragna/deploy/_templates/base.html index 255044ee..e4f61d54 100644 --- a/ragna/deploy/_templates/base.html +++ b/ragna/deploy/_templates/base.html @@ -37,13 +37,12 @@ background-color: white; padding: 20px; box-shadow: 20px; - background-color: red; } {{ __template_css__ }} -
+
{% block content %}{% endblock %}
diff --git a/ragna/deploy/_ui/left_sidebar.py b/ragna/deploy/_ui/left_sidebar.py index 267a3b77..3d45849f 100644 --- a/ragna/deploy/_ui/left_sidebar.py +++ b/ragna/deploy/_ui/left_sidebar.py @@ -14,6 +14,7 @@ class LeftSidebar(pn.viewable.Viewer): def __init__(self, api_wrapper, **params): super().__init__(**params) + self.api_wrapper = api_wrapper self.on_click_chat = None self.on_click_new_chat = None @@ -104,6 +105,7 @@ def __panel__(self): + self.chat_buttons + [ pn.layout.VSpacer(), + pn.pane.HTML(f"user: {self.api_wrapper._user}"), pn.pane.HTML(f"version: {ragna_version}"), # self.footer() ] diff --git a/ragna/source_storages/_vector_database.py b/ragna/source_storages/_vector_database.py index 81ec2df5..3ed5ca35 100644 --- a/ragna/source_storages/_vector_database.py +++ b/ragna/source_storages/_vector_database.py @@ -89,7 +89,7 @@ def _chunk_pages( ): tokens, page_numbers = zip(*window) yield Chunk( - text=self._tokenizer.decode(tokens), # type: ignore[arg-type] + text=self._tokenizer.decode(tokens), page_numbers=list(filter(lambda n: n is not None, page_numbers)) or None, num_tokens=len(tokens),