diff --git a/docs/examples/gallery_streaming.py b/docs/examples/gallery_streaming.py index 09a39abd..846897e6 100644 --- a/docs/examples/gallery_streaming.py +++ b/docs/examples/gallery_streaming.py @@ -107,29 +107,30 @@ def answer(self, messages): config = Config(assistants=[DemoStreamingAssistant]) -rest_api = ragna_docs.RestApi() +ragna_deploy = ragna_docs.RagnaDeploy(config) -client, document = rest_api.start(config, authenticate=True, upload_document=True) +client, document = ragna_deploy.get_http_client( + authenticate=True, upload_document=True +) # %% # Start and prepare the chat chat = ( client.post( - "/chats", + "/api/chats", json={ "name": "Tutorial REST API", - "documents": [document], + "document_ids": [document["id"]], "source_storage": source_storages.RagnaDemoSourceStorage.display_name(), "assistant": DemoStreamingAssistant.display_name(), - "params": {}, }, ) .raise_for_status() .json() ) -client.post(f"/chats/{chat['id']}/prepare").raise_for_status() +client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() # %% # Streaming the response is performed with [JSONL](https://jsonlines.org/). Each line @@ -140,7 +141,7 @@ def answer(self, messages): with client.stream( "POST", - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna?", "stream": True}, ) as response: chunks = [json.loads(data) for data in response.iter_lines()] @@ -163,7 +164,8 @@ def answer(self, messages): print("".join(chunk["content"] for chunk in chunks)) # %% -# Before we close the example, let's stop the REST API and have a look at what would -# have printed in the terminal if we had started it with the `ragna api` command. +# Before we close the example, let's terminate the REST API and have a look at what +# would have printed in the terminal if we had started it with the `ragna deploy` +# command. 
-rest_api.stop() +ragna_deploy.terminate() diff --git a/docs/references/config.md b/docs/references/config.md index c6ed18b9..ba83263a 100644 --- a/docs/references/config.md +++ b/docs/references/config.md @@ -69,9 +69,9 @@ is equivalent to `RAGNA_API_ORIGINS='["http://localhost:31477"]'`. Local root directory Ragna uses for storing files. See [ragna.local_root][]. -### `authentication` +### `auth` -[ragna.deploy.Authentication][] class to use for authenticating users. +[ragna.deploy.Auth][] class to use for authenticating users. ### `document` @@ -85,48 +85,26 @@ Local root directory Ragna uses for storing files. See [ragna.local_root][]. [ragna.core.Assistant][]s to be available for the user to use. -### `api` - -#### `hostname` +### `hostname` Hostname the REST API will be bound to. -#### `port` +### `port` Port the REST API will be bound to. -#### `root_path` +### `root_path` A path prefix handled by a proxy that is not seen by the REST API, but is seen by external clients. -#### `url` - -URL of the REST API to be accessed by the web UI. Make sure to include the -[`root_path`](#root_path) if set. - -#### `origins` +### `origins` [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) origins that are allowed -to connect to the REST API. The URL of the web UI is required for it to function. +to connect to the REST API. -#### `database_url` +### `database_url` URL of a SQL database that will be used to store the Ragna state. See [SQLAlchemy documentation](https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls) on how to format the URL. - -### `ui` - -#### `hostname` - -Hostname the web UI will be bound to. - -#### `port` - -Port the web UI will be bound to. - -#### `origins` - -[CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) origins that are allowed -to connect to the web UI. 
diff --git a/docs/references/rest-api.md b/docs/references/deploy.md similarity index 91% rename from docs/references/rest-api.md rename to docs/references/deploy.md index bfd77544..6f39d899 100644 --- a/docs/references/rest-api.md +++ b/docs/references/deploy.md @@ -1,7 +1,7 @@ # REST API reference diff --git a/docs/references/release-notes.md b/docs/references/release-notes.md index 17457078..40ab7ff6 100644 --- a/docs/references/release-notes.md +++ b/docs/references/release-notes.md @@ -137,9 +137,9 @@ -- The classes [ragna.deploy.Authentication][], [ragna.deploy.RagnaDemoAuthentication][], - and [ragna.deploy.Config][] moved from the [ragna.core][] module to a new - [ragna.deploy][] module. +- The classes `ragna.deploy.Authentication`, `ragna.deploy.RagnaDemoAuthentication`, and + [ragna.deploy.Config][] moved from the [ragna.core][] module to a new [ragna.deploy][] + module. - [ragna.core.Component][], which is the superclass for [ragna.core.Assistant][] and [ragna.core.SourceStorage][], no longer takes a [ragna.deploy.Config][] to instantiate. For example diff --git a/docs/tutorials/gallery_custom_components.py b/docs/tutorials/gallery_custom_components.py index 8a411f81..6c556043 100644 --- a/docs/tutorials/gallery_custom_components.py +++ b/docs/tutorials/gallery_custom_components.py @@ -186,9 +186,11 @@ def answer(self, messages: list[Message]) -> Iterator[str]: assistants=[TutorialAssistant], ) -rest_api = ragna_docs.RestApi() +ragna_deploy = ragna_docs.RagnaDeploy(config) -client, document = rest_api.start(config, authenticate=True, upload_document=True) +client, document = ragna_deploy.get_http_client( + authenticate=True, upload_document=True +) # %% # To select our custom components, we pass their display names to the chat creation. 
@@ -201,10 +203,10 @@ def answer(self, messages: list[Message]) -> Iterator[str]: import json response = client.post( - "/chats", + "/api/chats", json={ "name": "Tutorial REST API", - "documents": [document], + "document_ids": [document["id"]], "source_storage": TutorialSourceStorage.display_name(), "assistant": TutorialAssistant.display_name(), "params": {}, @@ -212,10 +214,10 @@ def answer(self, messages: list[Message]) -> Iterator[str]: ).raise_for_status() chat = response.json() -client.post(f"/chats/{chat['id']}/prepare").raise_for_status() +client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() response = client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna?"}, ).raise_for_status() answer = response.json() @@ -225,7 +227,7 @@ def answer(self, messages: list[Message]) -> Iterator[str]: # Let's stop the REST API and have a look at what would have printed in the terminal if # we had started it with the `ragna api` command. -rest_api.stop() +ragna_deploy.terminate() # %% # ### Web UI @@ -263,9 +265,7 @@ def answer( my_optional_parameter: str = "foo", ) -> Iterator[str]: print(f"Running {type(self).__name__}().answer()") - yield ( - f"I was given {my_required_parameter=} and {my_optional_parameter=}." - ) + yield f"I was given {my_required_parameter=} and {my_optional_parameter=}." # %% @@ -319,19 +319,21 @@ def answer( assistants=[ElaborateTutorialAssistant], ) -rest_api = ragna_docs.RestApi() +ragna_deploy = ragna_docs.RagnaDeploy(config) -client, document = rest_api.start(config, authenticate=True, upload_document=True) +client, document = ragna_deploy.get_http_client( + authenticate=True, upload_document=True +) # %% # To pass custom parameters, define them in the `params` mapping when creating a new # chat. 
response = client.post( - "/chats", + "/api/chats", json={ "name": "Tutorial REST API", - "documents": [document], + "document_ids": [document["id"]], "source_storage": TutorialSourceStorage.display_name(), "assistant": ElaborateTutorialAssistant.display_name(), "params": { @@ -344,10 +346,10 @@ def answer( # %% -client.post(f"/chats/{chat['id']}/prepare").raise_for_status() +client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() response = client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna?"}, ).raise_for_status() answer = response.json() @@ -357,7 +359,7 @@ def answer( # Let's stop the REST API and have a look at what would have printed in the terminal if # we had started it with the `ragna api` command. -rest_api.stop() +ragna_deploy.terminate() # %% # ### Web UI diff --git a/docs/tutorials/gallery_rest_api.py b/docs/tutorials/gallery_rest_api.py index befcbfb3..ede8833d 100644 --- a/docs/tutorials/gallery_rest_api.py +++ b/docs/tutorials/gallery_rest_api.py @@ -3,9 +3,9 @@ Ragna was designed to help you quickly build custom RAG powered web applications. For this you can leverage the built-in -[REST API](../../references/rest-api.md). +[REST API](../../references/deploy.md). -This tutorial walks you through basic steps of using Ragnas REST API. +This tutorial walks you through basic steps of using Ragna's REST API. """ # %% @@ -39,42 +39,39 @@ config = Config() -rest_api = ragna_docs.RestApi() -_ = rest_api.start(config) +ragna_deploy = ragna_docs.RagnaDeploy(config=config) # %% # Let's make sure the REST API is started correctly and can be reached. import httpx -client = httpx.Client(base_url=config.api.url) -client.get("/").raise_for_status() +client = httpx.Client(base_url=f"http://{config.hostname}:{config.port}") +client.get("/health").raise_for_status() # %% # ## Step 2: Authentication # -# In order to use Ragnas REST API, we need to authenticate first. 
To forge an API token -# we send a request to the `/token` endpoint. This is processed by the -# [`Authentication`][ragna.deploy.Authentication], which can be overridden through the -# config. For this tutorial, we use the default -# [ragna.deploy.RagnaDemoAuthentication][], which requires a matching username and -# password. +# In order to use Ragna's REST API, we need to authenticate first. This is handled by +# the [ragna.deploy.Auth][] class, which can be overridden through the config. By +# default, [ragna.deploy.NoAuth][] is used. By hitting the `/login` endpoint, we get a +# session cookie, which is later used to authorize our requests. -username = password = "Ragna" - -response = client.post( - "/token", - data={"username": username, "password": password}, -).raise_for_status() -token = response.json() +client.get("/login", follow_redirects=True) +dict(client.cookies) # %% -# We set the API token on our HTTP client so we don't have to manually supply it for -# each request below. - -client.headers["Authorization"] = f"Bearer {token}" - +# !!! note +# +# In a regular deployment, you'll have to log in through your browser and create an +# API key in your profile page. The API key is used as a +# [bearer token](https://swagger.io/docs/specification/authentication/bearer-authentication/) +# and can be set with +# +# ```python +# httpx.Client(..., headers={"Authorization": f"Bearer {RAGNA_API_KEY}"}) +# ``` # %% # ## Step 3: Uploading documents @@ -84,7 +81,7 @@ import json -response = client.get("/components").raise_for_status() +response = client.get("/api/components").raise_for_status() print(json.dumps(response.json(), indent=2)) # %% @@ -102,38 +99,28 @@ # %% # The upload process in Ragna consists of two parts: # -# 1. Announce the file to be uploaded. Under the hood this pre-registers the document -# in Ragnas database and returns information about how the upload is to be performed. -# This is handled by the [ragna.core.Document][] class.
By default, -# [ragna.core.LocalDocument][] is used, which uploads the files to the local file -# system. -# 2. Perform the actual upload with the information from step 1. +# 1. Announce the file to be uploaded. Under the hood this registers the document +# in Ragna's database and returns the document ID, which is needed for the upload. response = client.post( - "/document", json={"name": document_path.name} + "/api/documents", json=[{"name": document_path.name}] ).raise_for_status() -document_upload = response.json() -print(json.dumps(response.json(), indent=2)) +documents = response.json() +print(json.dumps(documents, indent=2)) # %% -# The returned JSON contains two parts: the document object that we are later going to -# use to create a chat as well as the upload parameters. -# !!! note +# 2. Perform the actual upload with the information from step 1 through a +# [multipart request](https://swagger.io/docs/specification/describing-request-body/multipart-requests/) +# with the following parameters: # -# The `"token"` in the response is *not* the Ragna REST API token, but rather a -# separate one to perform the document upload. -# -# We perform the actual upload with the latter now. - -document = document_upload["document"] +# - The field name is `documents` for all entries. +# - The file name is the ID of the document returned by step 1. +# - The file content is the binary content of the document. -parameters = document_upload["parameters"] -client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": open(document_path, "rb")}, -).raise_for_status() +client.put( + "/api/documents", + files=[("documents", (documents[0]["id"], open(document_path, "rb")))], +) # %% # ## Step 4: Select a source storage and assistant @@ -155,13 +142,12 @@ # be used, we can create a new chat.
response = client.post( - "/chats", + "/api/chats", json={ "name": "Tutorial REST API", - "documents": [document], + "document_ids": [document["id"] for document in documents], "source_storage": source_storage, "assistant": assistant, - "params": {}, }, ).raise_for_status() chat = response.json() @@ -171,13 +157,13 @@ # As can be seen by the `"prepared"` field in the `chat` JSON object we still need to # prepare it. -client.post(f"/chats/{chat['id']}/prepare").raise_for_status() +client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() # %% # Finally, we can get answers to our questions. response = client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna?"}, ).raise_for_status() answer = response.json() @@ -188,7 +174,8 @@ print(answer["content"]) # %% -# Before we close the tutorial, let's stop the REST API and have a look at what would -# have printed in the terminal if we had started it with the `ragna api` command. +# Before we close the tutorial, let's terminate the REST API and have a look at what +# would have printed in the terminal if we had started it with the `ragna deploy` +# command. 
-rest_api.stop() +ragna_deploy.terminate() diff --git a/environment-dev.yml b/environment-dev.yml index ae7ffe0c..1684b030 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -1,4 +1,4 @@ -name: ragna-deploy-dev +name: ragna-dev channels: - conda-forge dependencies: @@ -15,6 +15,7 @@ dependencies: - pytest-asyncio - pytest-playwright - mypy ==1.10.0 + - types-redis - pre-commit - types-aiofiles - sqlalchemy-stubs diff --git a/mkdocs.yml b/mkdocs.yml index a31d6ffe..787bae60 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -91,7 +91,7 @@ nav: - community/contribute.md - References: - references/python-api.md - - references/rest-api.md + - references/deploy.md - references/cli.md - references/config.md - references/faq.md diff --git a/pyproject.toml b/pyproject.toml index ad841dcf..83594126 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -172,9 +172,12 @@ disable_error_code = [ ] [[tool.mypy.overrides]] -# It is a fundamental feature of the components to request more parameters than the base -# class. Thus, we just silence mypy here. +# 1. We automatically handle user-defined sync and async methods +# 2. It is a fundamental feature of the RAG components to request more parameters than +# the base class. +# Thus, we just silence mypy where it would complain about the points above. module = [ + "ragna.deploy._auth", "ragna.source_storages.*", "ragna.assistants.*" ] diff --git a/ragna/_docs.py b/ragna/_docs.py index 0d6191d4..d06fd215 100644 --- a/ragna/_docs.py +++ b/ragna/_docs.py @@ -11,11 +11,12 @@ import httpx -from ragna._utils import timeout_after from ragna.core import RagnaException from ragna.deploy import Config -__all__ = ["SAMPLE_CONTENT", "RestApi"] +from ._utils import BackgroundSubprocess + +__all__ = ["SAMPLE_CONTENT", "RagnaDeploy"] SAMPLE_CONTENT = """\ Ragna is an open source project built by Quansight. 
It is designed to allow @@ -29,51 +30,25 @@ """ -class RestApi: - def __init__(self) -> None: - self._process: Optional[subprocess.Popen] = None - # In case the documentation errors before we call RestApi.stop, we still need to - # stop the server to avoid zombie processes - atexit.register(self.stop, quiet=True) - - def start( - self, - config: Config, - *, - authenticate: bool = False, - upload_document: bool = False, - ) -> tuple[httpx.Client, Optional[dict]]: - if upload_document and not authenticate: - raise RagnaException( - "Cannot upload a document without authenticating first. " - "Set authenticate=True when using upload_document=True." - ) - python_path, config_path = self._prepare_config(config) - - client = httpx.Client(base_url=config.api.url) - - self._process = self._start_api(config_path, python_path, client) +class RagnaDeploy: + def __init__(self, config: Config) -> None: + self.config = config + python_path, config_path = self._prepare_config() + self._process = self._deploy(config, config_path, python_path) + # In case the documentation errors before we call RagnaDeploy.terminate, + # we still need to stop the server to avoid zombie processes + atexit.register(self.terminate, quiet=True) - if authenticate: - self._authenticate(client) - - if upload_document: - document = self._upload_document(client) - else: - document = None - - return client, document - - def _prepare_config(self, config: Config) -> tuple[str, str]: + def _prepare_config(self) -> tuple[str, str]: deploy_directory = Path(tempfile.mkdtemp()) - python_path = ( - f"{deploy_directory}{os.pathsep}{os.environ.get('PYTHONPATH', '')}" + python_path = os.pathsep.join( + [str(deploy_directory), os.environ.get("PYTHONPATH", "")] ) config_path = str(deploy_directory / "ragna.toml") - config.local_root = deploy_directory - config.api.database_url = f"sqlite:///{deploy_directory / 'ragna.db'}" + self.config.local_root = deploy_directory + self.config.database_url = 
f"sqlite:///{deploy_directory / 'ragna.db'}" sys.modules["__main__"].__file__ = inspect.getouterframes( inspect.currentframe() @@ -88,98 +63,92 @@ def _prepare_config(self, config: Config) -> tuple[str, str]: file.write("from ragna import *\n") file.write("from ragna.core import *\n") - for component in itertools.chain(config.source_storages, config.assistants): + for component in itertools.chain( + self.config.source_storages, self.config.assistants + ): if component.__module__ == "__main__": custom_components.add(component) file.write(f"{textwrap.dedent(inspect.getsource(component))}\n\n") component.__module__ = custom_module - config.to_file(config_path) + self.config.to_file(config_path) for component in custom_components: component.__module__ = "__main__" return python_path, config_path - def _start_api( - self, config_path: str, python_path: str, client: httpx.Client - ) -> subprocess.Popen: + def _deploy( + self, config: Config, config_path: str, python_path: str + ) -> BackgroundSubprocess: env = os.environ.copy() env["PYTHONPATH"] = python_path - process = subprocess.Popen( - [sys.executable, "-m", "ragna", "api", "--config", config_path], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - ) - - def check_api_available() -> bool: + def startup_fn() -> bool: try: - return client.get("/").is_success + return httpx.get(f"{config._url}/health").is_success except httpx.ConnectError: return False - failure_message = "Failed to the start the Ragna REST API." 
- - @timeout_after(60, message=failure_message) - def wait_for_api() -> None: - print("Starting Ragna REST API") - while not check_api_available(): - try: - stdout, stderr = process.communicate(timeout=1) - except subprocess.TimeoutExpired: - print(".", end="") - continue - else: - parts = [failure_message] - if stdout: - parts.append(f"\n\nSTDOUT:\n\n{stdout.decode()}") - if stderr: - parts.append(f"\n\nSTDERR:\n\n{stderr.decode()}") - - raise RuntimeError("".join(parts)) - - print() - - wait_for_api() - return process - - def _authenticate(self, client: httpx.Client) -> None: - username = password = "Ragna" + if startup_fn(): + raise RagnaException("ragna server is already running") + + return BackgroundSubprocess( + sys.executable, + "-m", + "ragna", + "deploy", + "--api", + "--no-ui", + "--config", + config_path, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + startup_fn=startup_fn, + startup_timeout=60, + ) - response = client.post( - "/token", - data={"username": username, "password": password}, - ).raise_for_status() - token = response.json() + def get_http_client( + self, + *, + authenticate: bool = False, + upload_document: bool = False, + ) -> tuple[httpx.Client, Optional[dict[str, Any]]]: + if upload_document and not authenticate: + raise RagnaException( + "Cannot upload a document without authenticating first. " + "Set authenticate=True when using upload_document=True." 
+ ) - client.headers["Authorization"] = f"Bearer {token}" + client = httpx.Client(base_url=self.config._url) - def _upload_document(self, client: httpx.Client) -> dict[str, Any]: - name, content = "ragna.txt", SAMPLE_CONTENT + if authenticate: + client.get("/login", follow_redirects=True) - response = client.post("/document", json={"name": name}).raise_for_status() - document_upload = response.json() + if upload_document: + name, content = "ragna.txt", SAMPLE_CONTENT - document = cast(dict[str, Any], document_upload["document"]) + response = client.post( + "/api/documents", json=[{"name": name}] + ).raise_for_status() + document = cast(dict[str, Any], response.json()[0]) - parameters = document_upload["parameters"] - client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": content}, - ).raise_for_status() + client.put( + "/api/documents", + files=[("documents", (document["id"], content.encode()))], + ) + else: + document = None - return document + return client, document - def stop(self, *, quiet: bool = False) -> None: + def terminate(self, quiet: bool = False) -> None: if self._process is None: return - self._process.terminate() - stdout, _ = self._process.communicate() + output = self._process.terminate() - if not quiet: - print(stdout.decode()) + if output and not quiet: + stdout, _ = output + print(stdout) diff --git a/ragna/_utils.py b/ragna/_utils.py index 6ef5eb5c..32bb24c5 100644 --- a/ragna/_utils.py +++ b/ragna/_utils.py @@ -1,10 +1,31 @@ +from __future__ import annotations + +import contextlib import functools +import getpass import inspect import os +import shlex +import subprocess import sys import threading +import time from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Iterator, + Optional, + TypeVar, + Union, + cast, +) + +from starlette.concurrency import iterate_in_threadpool, run_in_threadpool + +T = 
TypeVar("T") _LOCAL_ROOT = ( Path(os.environ.get("RAGNA_LOCAL_ROOT", "~/.cache/ragna")).expanduser().resolve() @@ -110,3 +131,90 @@ def is_debugging() -> bool: if any(part.startswith(name) for part in parts): return True return False + + +def as_awaitable( + fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any, **kwargs: Any +) -> Awaitable[T]: + if inspect.iscoroutinefunction(fn): + fn = cast(Callable[..., Awaitable[T]], fn) + awaitable = fn(*args, **kwargs) + else: + fn = cast(Callable[..., T], fn) + awaitable = run_in_threadpool(fn, *args, **kwargs) + + return awaitable + + +def as_async_iterator( + fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]], + *args: Any, + **kwargs: Any, +) -> AsyncIterator[T]: + if inspect.isasyncgenfunction(fn): + fn = cast(Callable[..., AsyncIterator[T]], fn) + async_iterator = fn(*args, **kwargs) + else: + fn = cast(Callable[..., Iterator[T]], fn) + async_iterator = iterate_in_threadpool(fn(*args, **kwargs)) + + return async_iterator + + +def default_user() -> str: + with contextlib.suppress(Exception): + return getpass.getuser() + with contextlib.suppress(Exception): + return os.getlogin() + return "Bodil" + + +class BackgroundSubprocess: + """Subprocess that is started on construction and stopped by terminate().""" + + def __init__( + self, + *cmd: str, + stdout: Any = sys.stdout, + stderr: Any = sys.stdout, + text: bool = True, + startup_fn: Optional[Callable[[], bool]] = None, + startup_timeout: float = 10, + terminate_timeout: float = 10, + **subprocess_kwargs: Any, + ) -> None: + # Set first: the startup error path below calls terminate(), which reads it. + self._terminate_timeout = terminate_timeout + self._process = subprocess.Popen( + cmd, stdout=stdout, stderr=stderr, text=text, **subprocess_kwargs + ) + try: + if startup_fn: + + @timeout_after(startup_timeout, message=shlex.join(cmd)) + def wait() -> None: + while not startup_fn(): + time.sleep(0.2) + + wait() + except Exception: + self.terminate() + raise + + def terminate(self) -> tuple[str, str]: + @timeout_after(self._terminate_timeout) + def terminate() -> tuple[str, str]: +
self._process.terminate() + return self._process.communicate() + + try: + return terminate() # type: ignore[no-any-return] + except TimeoutError: + self._process.kill() + return self._process.communicate() + + def __enter__(self) -> BackgroundSubprocess: + return self + + def __exit__(self, *exc_info: Any) -> None: + self.terminate() diff --git a/ragna/core/__init__.py b/ragna/core/__init__.py index 44449775..1cdbc667 100644 --- a/ragna/core/__init__.py +++ b/ragna/core/__init__.py @@ -4,7 +4,6 @@ "Component", "Document", "DocumentHandler", - "DocumentUploadParameters", "DocxDocumentHandler", "PptxDocumentHandler", "EnvVarRequirement", diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index d963c15b..be5282aa 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -2,7 +2,6 @@ import contextlib import datetime -import inspect import itertools import uuid from collections import defaultdict @@ -24,11 +23,12 @@ import pydantic import pydantic_core from fastapi import status -from starlette.concurrency import iterate_in_threadpool, run_in_threadpool + +from ragna._utils import as_async_iterator, as_awaitable, default_user from ._components import Assistant, Component, Message, MessageRole, SourceStorage from ._document import Document, LocalDocument -from ._utils import RagnaException, default_user, merge_models +from ._utils import RagnaException, merge_models if TYPE_CHECKING: from ragna.deploy import Config @@ -145,7 +145,6 @@ def chat( Args: documents: Documents to use. If any item is not a [ragna.core.Document][], [ragna.core.LocalDocument.from_path][] is invoked on it. - FIXME source_storage: Source storage to use. assistant: Assistant to use. **params: Additional parameters passed to the source storage and assistant.
@@ -153,8 +152,8 @@ def chat( return Chat( self, documents=documents, - source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type] - assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type] + source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type] + assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type] **params, ) @@ -241,11 +240,11 @@ async def prepare(self) -> Message: raise RagnaException( "Chat is already prepared", chat=self, - http_status_code=400, + http_status_code=status.HTTP_400_BAD_REQUEST, detail=RagnaException.EVENT, ) - await self._run(self.source_storage.store, self.documents) + await self._as_awaitable(self.source_storage.store, self.documents) self._prepared = True welcome = Message( @@ -269,17 +268,21 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message: raise RagnaException( "Chat is not prepared", chat=self, - http_status_code=400, + http_status_code=status.HTTP_400_BAD_REQUEST, detail=RagnaException.EVENT, ) - sources = await self._run(self.source_storage.retrieve, self.documents, prompt) + sources = await self._as_awaitable( + self.source_storage.retrieve, self.documents, prompt + ) question = Message(content=prompt, role=MessageRole.USER, sources=sources) self._messages.append(question) answer = Message( - content=self._run_gen(self.assistant.answer, self._messages.copy()), + content=self._as_async_iterator( + self.assistant.answer, self._messages.copy() + ), role=MessageRole.ASSISTANT, sources=sources, ) @@ -361,7 +364,7 @@ def format_error( formatted_error = f"- {param}" if annotation: annotation_ = cast( - type, model_cls.model_fields[param].annotation + type, model_cls.__pydantic_fields__[param].annotation ).__name__ formatted_error += f": {annotation_}" @@ -417,34 +420,17 @@ def format_error( raise RagnaException("\n".join(parts)) - async def _run( + def _as_awaitable( self, 
fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any - ) -> T: - kwargs = self._unpacked_params[fn] - if inspect.iscoroutinefunction(fn): - fn = cast(Callable[..., Awaitable[T]], fn) - coro = fn(*args, **kwargs) - else: - fn = cast(Callable[..., T], fn) - coro = run_in_threadpool(fn, *args, **kwargs) + ) -> Awaitable[T]: + return as_awaitable(fn, *args, **self._unpacked_params[fn]) - return await coro - - async def _run_gen( + def _as_async_iterator( self, fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]], *args: Any, ) -> AsyncIterator[T]: - kwargs = self._unpacked_params[fn] - if inspect.isasyncgenfunction(fn): - fn = cast(Callable[..., AsyncIterator[T]], fn) - async_gen = fn(*args, **kwargs) - else: - fn = cast(Callable[..., Iterator[T]], fn) - async_gen = iterate_in_threadpool(fn(*args, **kwargs)) - - async for item in async_gen: - yield item + return as_async_iterator(fn, *args, **self._unpacked_params[fn]) async def __aenter__(self) -> Chat: await self.prepare() diff --git a/ragna/core/_utils.py b/ragna/core/_utils.py index 972b0926..34ac2e7d 100644 --- a/ragna/core/_utils.py +++ b/ragna/core/_utils.py @@ -1,15 +1,13 @@ from __future__ import annotations import abc -import contextlib import enum import functools -import getpass import importlib import importlib.metadata import os from collections import defaultdict -from typing import Any, Collection, Optional, Type, Union, cast +from typing import Any, Callable, Collection, Optional, Type, Union, cast import packaging.requirements import pydantic @@ -121,14 +119,6 @@ def __repr__(self) -> str: return self._name -def default_user() -> str: - with contextlib.suppress(Exception): - return getpass.getuser() - with contextlib.suppress(Exception): - return os.getlogin() - return "Ragna" - - def merge_models( model_name: str, *models: Type[pydantic.BaseModel], @@ -136,14 +126,14 @@ def merge_models( ) -> Type[pydantic.BaseModel]: raw_field_definitions = defaultdict(list) for 
model_cls in models: - for name, field in model_cls.model_fields.items(): + for name, field in model_cls.__pydantic_fields__.items(): type_ = field.annotation default: Any if field.is_required(): default = ... elif field.default is pydantic_core.PydanticUndefined: - default = field.default_factory() # type: ignore[misc] + default = cast(Callable[[], Any], field.default_factory)() else: default = field.default diff --git a/ragna/deploy/__init__.py b/ragna/deploy/__init__.py index f3a86255..cdb2ba44 100644 --- a/ragna/deploy/__init__.py +++ b/ragna/deploy/__init__.py @@ -1,11 +1,18 @@ __all__ = [ - "Authentication", + "Auth", "Config", - "RagnaDemoAuthentication", + "DummyBasicAuth", + "GithubOAuth", + "InMemoryKeyValueStore", + "JupyterhubServerProxyAuth", + "KeyValueStore", + "NoAuth", + "RedisKeyValueStore", ] -from ._authentication import Authentication, RagnaDemoAuthentication +from ._auth import Auth, DummyBasicAuth, GithubOAuth, JupyterhubServerProxyAuth, NoAuth from ._config import Config +from ._key_value_store import InMemoryKeyValueStore, KeyValueStore, RedisKeyValueStore # isort: split diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index 242730f6..788bfc38 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -2,34 +2,23 @@ from typing import Annotated, AsyncIterator import pydantic -from fastapi import ( - APIRouter, - Body, - Depends, - UploadFile, -) +from fastapi import APIRouter, Body, UploadFile from fastapi.responses import StreamingResponse -from ragna.core._utils import default_user - from . 
import _schemas as schemas +from ._auth import UserDependency from ._engine import Engine def make_router(engine: Engine) -> APIRouter: router = APIRouter(tags=["API"]) - def get_user() -> str: - return default_user() - - UserDependency = Annotated[str, Depends(get_user)] - @router.post("/documents") def register_documents( user: UserDependency, document_registrations: list[schemas.DocumentRegistration] ) -> list[schemas.Document]: return engine.register_documents( - user=user, document_registrations=document_registrations + user=user.name, document_registrations=document_registrations ) @router.put("/documents") @@ -44,7 +33,7 @@ async def content_stream() -> AsyncIterator[bytes]: return content_stream() await engine.store_documents( - user=user, + user=user.name, ids_and_streams=[ (uuid.UUID(document.filename), make_content_stream(document)) for document in documents @@ -60,19 +49,19 @@ async def create_chat( user: UserDependency, chat_creation: schemas.ChatCreation, ) -> schemas.Chat: - return engine.create_chat(user=user, chat_creation=chat_creation) + return engine.create_chat(user=user.name, chat_creation=chat_creation) @router.get("/chats") async def get_chats(user: UserDependency) -> list[schemas.Chat]: - return engine.get_chats(user=user) + return engine.get_chats(user=user.name) @router.get("/chats/{id}") async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat: - return engine.get_chat(user=user, id=id) + return engine.get_chat(user=user.name, id=id) @router.post("/chats/{id}/prepare") async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message: - return await engine.prepare_chat(user=user, id=id) + return await engine.prepare_chat(user=user.name, id=id) @router.post("/chats/{id}/answer") async def answer( @@ -81,7 +70,7 @@ async def answer( prompt: Annotated[str, Body(..., embed=True)], stream: Annotated[bool, Body(..., embed=True)] = False, ) -> schemas.Message: - message_stream = engine.answer_stream(user=user, 
chat_id=id, prompt=prompt) + message_stream = engine.answer_stream(user=user.name, chat_id=id, prompt=prompt) answer = await anext(message_stream) if not stream: @@ -106,6 +95,6 @@ async def to_jsonl( @router.delete("/chats/{id}") async def delete_chat(user: UserDependency, id: uuid.UUID) -> None: - engine.delete_chat(user=user, id=id) + engine.delete_chat(user=user.name, id=id) return router diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py new file mode 100644 index 00000000..06492bf4 --- /dev/null +++ b/ragna/deploy/_auth.py @@ -0,0 +1,416 @@ +from __future__ import annotations + +import abc +import base64 +import contextlib +import json +import os +import re +import uuid +from typing import TYPE_CHECKING, Annotated, Awaitable, Callable, Optional, Union, cast + +import httpx +import panel as pn +import pydantic +from fastapi import Depends, FastAPI, Request, status +from fastapi.responses import HTMLResponse, RedirectResponse, Response +from fastapi.security.utils import get_authorization_scheme_param +from starlette.middleware.base import BaseHTTPMiddleware +from tornado.web import create_signed_value + +from ragna._utils import as_awaitable, default_user +from ragna.core import RagnaException + +from . import _schemas as schemas +from . import _templates as templates +from ._utils import redirect + +if TYPE_CHECKING: + from ._config import Config + from ._engine import Engine + from ._key_value_store import KeyValueStore + + +class Session(pydantic.BaseModel): + user: schemas.User + + +CallNext = Callable[[Request], Awaitable[Response]] + + +class SessionMiddleware(BaseHTTPMiddleware): + # panel uses cookies to transfer user information (see _cookie_dispatch() below) and + # signs them for security. However, since this happens after our authentication + # check, we can use an arbitrary, hardcoded value here. 
+ _PANEL_COOKIE_SECRET = "ragna" + + def __init__( + self, app: FastAPI, *, config: Config, engine: Engine, api: bool, ui: bool + ) -> None: + super().__init__(app) + self._config = config + self._engine = engine + self._api = api + self._ui = ui + self._sessions: KeyValueStore[Session] = config.key_value_store() + + if ui: + pn.config.cookie_secret = self._PANEL_COOKIE_SECRET # type: ignore[misc] + + _COOKIE_NAME = "ragna" + + async def dispatch(self, request: Request, call_next: CallNext) -> Response: + if (authorization := request.headers.get("Authorization")) is not None: + return await self._api_token_dispatch( + request, call_next, authorization=authorization + ) + elif (cookie := request.cookies.get(self._COOKIE_NAME)) is not None: + return await self._cookie_dispatch(request, call_next, cookie=cookie) + elif request.url.path in {"/login", "/oauth-callback"}: + return await self._login_dispatch(request, call_next) + elif self._api and request.url.path.startswith("/api"): + return self._unauthorized("Missing authorization header") + elif self._ui and request.url.path.startswith("/ui"): + return redirect("/login") + else: + # Either an unknown route or something on the default router. In any case, + # this doesn't need a session and so we let it pass. 
+ request.state.session = None + return await call_next(request) + + async def _api_token_dispatch( + self, request: Request, call_next: CallNext, authorization: str + ) -> Response: + scheme, api_key = get_authorization_scheme_param(authorization) + if scheme.lower() != "bearer": + return self._unauthorized("Bearer authentication scheme required") + + user, expired = self._engine.get_user_by_api_key(api_key) + if user is None or expired: + self._sessions.delete(api_key) + reason = "Invalid" if user is None else "Expired" + return self._unauthorized(f"{reason} API key") + + session = self._sessions.get(api_key) + if session is None: + # First time the API key is used + session = Session(user=user) + # We are using the API key value instead of its ID as session key for two + # reasons: + # 1. Similar to its ID, the value is unique and thus can be safely used as + # key. + # 2. If an API key was deleted, we lose its ID, but still need to be able to + # remove its corresponding session. + self._sessions.set(api_key, session, expires_after=3600) + + request.state.session = session + return await call_next(request) + + async def _cookie_dispatch( + self, request: Request, call_next: CallNext, *, cookie: str + ) -> Response: + session = self._sessions.get(cookie) + response: Response + if session is None: + # Invalid cookie + response = redirect("/login") + self._delete_cookie(response) + return response + + request.state.session = session + if self._ui and request.method == "GET" and request.url.path == "/ui": + # panel.state.user and panel.state.user_info are based on the two cookies + # below that the panel auth flow sets. Since we don't want extra cookies + # just for panel, we just inject them into the scope here, which will be + # parsed by panel down the line. After this initial request, the values are + # tied to the active session and don't have to be set again. 
+ extra_cookies: dict[str, Union[str, bytes]] = { + "user": session.user.name, + "id_token": base64.b64encode(json.dumps(session.user.data).encode()), + } + extra_values = [ + ( + f"{key}=".encode() + + create_signed_value( + self._PANEL_COOKIE_SECRET, key, value, version=1 + ) + ) + for key, value in extra_cookies.items() + ] + + cookie_key = b"cookie" + idx, value = next( + (idx, value) + for idx, (key, value) in enumerate(request.scope["headers"]) + if key == cookie_key + ) + # We are not setting request.cookies or request.headers here, because any + # changes to them are not reflected back to the scope, which is the only + # safe way to transfer data between the middleware and an endpoint. + request.scope["headers"][idx] = ( + cookie_key, + b";".join([value, *extra_values]), + ) + + response = await call_next(request) + + if request.url.path == "/logout": + self._sessions.delete(cookie) + self._delete_cookie(response) + else: + self._sessions.refresh(cookie, expires_after=self._config.session_lifetime) + self._add_cookie(response, cookie) + + return response + + async def _login_dispatch(self, request: Request, call_next: CallNext) -> Response: + request.state.session = None + response = await call_next(request) + session = request.state.session + + if session is not None: + cookie = str(uuid.uuid4()) + self._sessions.set( + cookie, session, expires_after=self._config.session_lifetime + ) + self._add_cookie(response, cookie=cookie) + + return response + + def _unauthorized(self, message: str) -> Response: + return Response( + content=message, + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + ) + + def _add_cookie(self, response: Response, cookie: str) -> None: + response.set_cookie( + key=self._COOKIE_NAME, + value=cookie, + max_age=self._config.session_lifetime, + httponly=True, + samesite="lax", + ) + + def _delete_cookie(self, response: Response) -> None: + response.delete_cookie( + key=self._COOKIE_NAME, + 
httponly=True, + samesite="lax", + ) + + +async def _get_session(request: Request) -> Session: + session = cast(Optional[Session], request.state.session) + if session is None: + raise RagnaException( + "Not authenticated", + http_detail=RagnaException.EVENT, + http_status_code=status.HTTP_401_UNAUTHORIZED, + ) + return session + + +SessionDependency = Annotated[Session, Depends(_get_session)] + + +async def _get_user(session: SessionDependency) -> schemas.User: + return session.user + + +UserDependency = Annotated[schemas.User, Depends(_get_user)] + + +class Auth(abc.ABC): + """ + ADDME + """ + + @classmethod + def _add_to_app( + cls, app: FastAPI, *, config: Config, engine: Engine, api: bool, ui: bool + ) -> None: + self = cls() + + @app.get("/login", include_in_schema=False) + async def login_page(request: Request) -> Response: + return await as_awaitable(self.login_page, request) + + async def _login(request: Request) -> Response: + result = await as_awaitable(self.login, request) + if not isinstance(result, schemas.User): + return result + + engine.maybe_add_user(result) + request.state.session = Session(user=result) + return redirect("/") + + @app.post("/login", include_in_schema=False) + async def login(request: Request) -> Response: + return await _login(request) + + @app.get("/oauth-callback", include_in_schema=False) + async def oauth_callback(request: Request) -> Response: + return await _login(request) + + @app.get("/logout", include_in_schema=False) + async def logout() -> RedirectResponse: + return redirect("/") + + app.add_middleware( + SessionMiddleware, + config=config, + engine=engine, + api=api, + ui=ui, + ) + + @abc.abstractmethod + def login_page(self, request: Request) -> Response: ... + + @abc.abstractmethod + def login(self, request: Request) -> Union[schemas.User, Response]: ... 
+ + +class _AutomaticLoginAuthBase(Auth): + def login_page(self, request: Request) -> Response: + # To invoke the Auth.login() method, the client either needs to + # - POST /login or + # - GET /oauth-callback + # Since we cannot instruct a browser to post when sending a redirect response, we + # use the OAuth callback endpoint here, although this might have nothing to do + # with OAuth. + return redirect("/oauth-callback") + + +class NoAuth(_AutomaticLoginAuthBase): + """ + ADDME + """ + + def login(self, request: Request) -> schemas.User: + return schemas.User( + name=request.headers.get("X-Forwarded-User", default_user()) + ) + + +class DummyBasicAuth(Auth): + """Dummy OAuth2 password authentication without requirements. + + !!! danger + + As the name implies, this authentication is just for testing or demo purposes and + should not be used in production. + """ + + def __init__(self) -> None: + self._password = os.environ.get("RAGNA_DUMMY_BASIC_AUTH_PASSWORD") + + def login_page( + self, + request: Request, + *, + username: Optional[str] = None, + fail_reason: Optional[str] = None, + ) -> HTMLResponse: + return HTMLResponse( + templates.render( + "basic_auth.html", username=username, fail_reason=fail_reason + ) + ) + + async def login(self, request: Request) -> Union[schemas.User, Response]: + async with request.form() as form: + username = cast(str, form.get("username")) + password = cast(str, form.get("password")) + + if username is None or password is None: + # This can only happen if the endpoint is not hit through the login page. + # Thus, instead of returning the failed login page like below, we just + # return an error. 
+ raise RagnaException( + "Field 'username' or 'password' is missing from the form data.", + http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + http_detail=RagnaException.MESSAGE, + ) + + if not username: + return self.login_page(request, fail_reason="Username cannot be empty") + elif (self._password is not None and password != self._password) or ( + self._password is None and password != username + ): + return self.login_page( + request, username=username, fail_reason="Password incorrect" + ) + + return schemas.User(name=username) + + +class GithubOAuth(Auth): + def __init__(self) -> None: + # FIXME: requirements + self._client_id = os.environ["RAGNA_GITHUB_OAUTH_CLIENT_ID"] + self._client_secret = os.environ["RAGNA_GITHUB_OAUTH_CLIENT_SECRET"] + + def login_page(self, request: Request) -> HTMLResponse: + return HTMLResponse( + templates.render( + "oauth.html", + service="GitHub", + url=f"https://github.com/login/oauth/authorize?client_id={self._client_id}", + ) + ) + + async def login(self, request: Request) -> Union[schemas.User, Response]: + async with httpx.AsyncClient(headers={"Accept": "application/json"}) as client: + response = await client.post( + "https://github.com/login/oauth/access_token", + json={ + "client_id": self._client_id, + "client_secret": self._client_secret, + "code": request.query_params["code"], + }, + ) + access_token = response.json()["access_token"] + client.headers["Authorization"] = f"Bearer {access_token}" + + user_data = (await client.get("https://api.github.com/user")).json() + + organizations_data = ( + await client.get(user_data["organizations_url"]) + ).json() + organizations = { + organization_data["login"] for organization_data in organizations_data + } + if not (organizations & {"Quansight", "Quansight-Labs"}): + # FIXME: send the login page again with a failure message + return HTMLResponse("Unauthorized!") + + return schemas.User(name=user_data["login"]) + + +class JupyterhubServerProxyAuth(_AutomaticLoginAuthBase): 
+ _JUPYTERHUB_ENV_VAR_PATTERN = re.compile(r"JUPYTERHUB_(?P<key>.+)") + + def __init__(self) -> None: + data = {} + for env_var, value in os.environ.items(): + match = self._JUPYTERHUB_ENV_VAR_PATTERN.match(env_var) + if match is None: + continue + + key = match["key"].lower() + with contextlib.suppress(json.JSONDecodeError): + value = json.loads(value) + + data[key] = value + + name = data.pop("user") + if name is None: + raise RagnaException + + self._user = schemas.User(name=name, data=data) + + def login(self, request: Request) -> schemas.User: + return self._user diff --git a/ragna/deploy/_authentication.py b/ragna/deploy/_authentication.py deleted file mode 100644 index b8a4cbb8..00000000 --- a/ragna/deploy/_authentication.py +++ /dev/null @@ -1,132 +0,0 @@ -import abc -import os -import secrets -import time -from typing import cast - -import jwt -import rich -from fastapi import HTTPException, Request, status -from fastapi.security.utils import get_authorization_scheme_param - - -class Authentication(abc.ABC): - """Abstract base class for authentication used by the REST API.""" - - @abc.abstractmethod - async def create_token(self, request: Request) -> str: - """Authenticate user and create an authorization token. - - Args: - request: Request send to the `/token` endpoint of the REST API. - - Returns: - Authorization token. - """ - pass - - @abc.abstractmethod - async def get_user(self, request: Request) -> str: - """ - Args: - request: Request send to any endpoint of the REST API that requires - authorization. - - Returns: - Authorized user. - """ - pass - - -class RagnaDemoAuthentication(Authentication): - """Demo OAuth2 password authentication without requirements. - - !!! danger - - As the name implies, this authentication is just for demo purposes and should - not be used in production. 
- """ - - def __init__(self) -> None: - msg = f"INFO:\t{type(self).__name__}: You can log in with any username" - self._password = os.environ.get("RAGNA_DEMO_AUTHENTICATION_PASSWORD") - if self._password is None: - msg = f"{msg} and a matching password." - else: - msg = f"{msg} and the password {self._password}" - rich.print(msg) - - _JWT_SECRET = os.environ.get( - "RAGNA_DEMO_AUTHENTICATION_SECRET", secrets.token_urlsafe(32)[:32] - ) - _JWT_ALGORITHM = "HS256" - _JWT_TTL = int(os.environ.get("RAGNA_DEMO_AUTHENTICATION_TTL", 60 * 60 * 24 * 7)) - - async def create_token(self, request: Request) -> str: - """Authenticate user and create an authorization token. - - User name is arbitrary. Authentication is possible in two ways: - - 1. If the `RAGNA_DEMO_AUTHENTICATION_PASSWORD` environment variable is set, the - password is checked against that. - 2. Otherwise, the password has to match the user name. - - Args: - request: Request send to the `/token` endpoint of the REST API. Must include - the `"username"` and `"password"` as form data. - - Returns: - Authorization [JWT](https://jwt.io/) that expires after one week. - """ - async with request.form() as form: - username = form.get("username") - password = form.get("password") - - if username is None or password is None: - raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY) - - if (self._password is not None and password != self._password) or ( - self._password is None and password != username - ): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - - return jwt.encode( - payload={"user": username, "exp": time.time() + self._JWT_TTL}, - key=self._JWT_SECRET, - algorithm=self._JWT_ALGORITHM, - ) - - async def get_user(self, request: Request) -> str: - """Get user from an authorization token. - - Token has to be supplied in the - [Bearer authentication scheme](https://swagger.io/docs/specification/authentication/bearer-authentication/), - i.e. including a `Authorization: Bearer {token}` header. 
- - Args: - request: Request send to any endpoint of the REST API that requires - authorization. - - Returns: - Authorized user. - """ - - unauthorized = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not authenticated", - headers={"WWW-Authenticate": "Bearer"}, - ) - - authorization = request.headers.get("Authorization") - scheme, token = get_authorization_scheme_param(authorization) - if not authorization or scheme.lower() != "bearer": - raise unauthorized - - try: - payload = jwt.decode( - token, key=self._JWT_SECRET, algorithms=[self._JWT_ALGORITHM] - ) - except (jwt.InvalidSignatureError, jwt.ExpiredSignatureError): - raise unauthorized - - return cast(str, payload["user"]) diff --git a/ragna/deploy/_config.py b/ragna/deploy/_config.py index e960f831..3c6f46c3 100644 --- a/ragna/deploy/_config.py +++ b/ragna/deploy/_config.py @@ -2,7 +2,15 @@ import itertools from pathlib import Path -from typing import Annotated, Any, Callable, Generic, Type, TypeVar, Union +from typing import ( + Annotated, + Any, + Callable, + Generic, + Type, + TypeVar, + Union, +) import tomlkit import tomlkit.container @@ -18,7 +26,8 @@ from ragna._utils import make_directory from ragna.core import Assistant, Document, RagnaException, SourceStorage -from ._authentication import Authentication +from ._auth import Auth +from ._key_value_store import KeyValueStore T = TypeVar("T") @@ -79,8 +88,9 @@ def settings_customise_sources( default_factory=ragna.local_root ) - authentication: ImportString[type[Authentication]] = ( - "ragna.deploy.RagnaDemoAuthentication" # type: ignore[assignment] + auth: ImportString[type[Auth]] = "ragna.deploy.NoAuth" # type: ignore[assignment] + key_value_store: ImportString[type[KeyValueStore]] = ( + "ragna.deploy.InMemoryKeyValueStore" # type: ignore[assignment] ) document: ImportString[type[Document]] = "ragna.core.LocalDocument" # type: ignore[assignment] @@ -97,6 +107,7 @@ def settings_customise_sources( origins: list[str] = 
AfterConfigValidateDefault.make( lambda config: [f"http://{config.hostname}:{config.port}"] ) + session_lifetime: int = 60 * 60 * 24 database_url: str = AfterConfigValidateDefault.make( lambda config: f"sqlite:///{config.local_root}/ragna.db", diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py index 65cca3cd..64de4a27 100644 --- a/ragna/deploy/_core.py +++ b/ragna/deploy/_core.py @@ -1,6 +1,7 @@ import contextlib import threading import time +import uuid import webbrowser from pathlib import Path from typing import AsyncContextManager, AsyncIterator, Callable, Optional, cast @@ -15,7 +16,9 @@ import ragna from ragna.core import RagnaException +from . import _schemas as schemas from ._api import make_router as make_api_router +from ._auth import UserDependency from ._config import Config from ._engine import Engine from ._ui import app as make_ui_app @@ -78,6 +81,8 @@ def server_available() -> bool: ignore_unavailable_components=ignore_unavailable_components, ) + config.auth._add_to_app(app, config=config, engine=engine, api=api, ui=ui) + if api: app.include_router(make_api_router(engine), prefix="/api") @@ -103,6 +108,24 @@ async def health() -> Response: async def version() -> str: return ragna.__version__ + @app.get("/user") + async def user(user: UserDependency) -> schemas.User: + return user + + @app.get("/api-keys") + def list_api_keys(user: UserDependency) -> list[schemas.ApiKey]: + return engine.list_api_keys(user=user.name) + + @app.post("/api-keys") + def create_api_key( + user: UserDependency, api_key_creation: schemas.ApiKeyCreation + ) -> schemas.ApiKey: + return engine.create_api_key(user=user.name, api_key_creation=api_key_creation) + + @app.delete("/api-keys/{id}") + def delete_api_key(user: UserDependency, id: uuid.UUID) -> None: + return engine.delete_api_key(user=user.name, id=id) + @app.exception_handler(RagnaException) async def ragna_exception_handler( request: Request, exc: RagnaException diff --git a/ragna/deploy/_database.py 
b/ragna/deploy/_database.py index 529fa3b6..54f7b9f2 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -1,7 +1,7 @@ from __future__ import annotations import uuid -from typing import Any, Collection, Optional +from typing import Any, Collection, Optional, cast from urllib.parse import urlsplit from sqlalchemy import create_engine, select @@ -13,6 +13,14 @@ from . import _schemas as schemas +class UnknownUser(Exception): + def __init__( + self, name: Optional[str] = None, api_key: Optional[str] = None + ) -> None: + self.name = name + self.api_key = api_key + + class Database: def __init__(self, url: str) -> None: components = urlsplit(url) @@ -28,20 +36,88 @@ def __init__(self, url: str) -> None: self._to_orm = SchemaToOrmConverter() self._to_schema = OrmToSchemaConverter() - def _get_user(self, session: Session, *, username: str) -> orm.User: - user: Optional[orm.User] = session.execute( - select(orm.User).where(orm.User.name == username) - ).scalar_one_or_none() + def _get_orm_user_by_name(self, session: Session, *, name: str) -> orm.User: + user = cast( + Optional[orm.User], + session.execute( + select(orm.User).where(orm.User.name == name) + ).scalar_one_or_none(), + ) if user is None: - # Add a new user if the current username is not registered yet. Since this - # is behind the authentication layer, we don't need any extra security here. 
- user = orm.User(id=uuid.uuid4(), name=username) - session.add(user) - session.commit() + raise UnknownUser(name) return user + def maybe_add_user(self, session: Session, *, user: schemas.User) -> None: + try: + self._get_orm_user_by_name(session, name=user.name) + except UnknownUser: + orm_user = orm.User(id=uuid.uuid4(), name=user.name) + session.add(orm_user) + session.commit() + + def add_api_key( + self, session: Session, *, user: str, api_key: schemas.ApiKey + ) -> None: + user_id = self._get_orm_user_by_name(session, name=user).id + orm_api_key = orm.ApiKey( + id=uuid.uuid4(), + user_id=user_id, + name=api_key.name, + value=api_key.value, + expires_at=api_key.expires_at, + ) + session.add(orm_api_key) + session.commit() + + def get_api_keys(self, session: Session, *, user: str) -> list[schemas.ApiKey]: + return [ + self._to_schema.api_key(api_key) + for api_key in session.execute( + select(orm.ApiKey).where( + orm.ApiKey.user_id + == self._get_orm_user_by_name(session, name=user).id + ) + ) + .scalars() + .all() + ] + + def delete_api_key(self, session: Session, *, user: str, id: uuid.UUID) -> None: + orm_api_key = session.execute( + select(orm.ApiKey).where( + (orm.ApiKey.id == id) + & ( + orm.ApiKey.user_id + == self._get_orm_user_by_name(session, name=user).id + ) + ) + ).scalar_one_or_none() + + if orm_api_key is None: + raise RagnaException + + session.delete(orm_api_key) # type: ignore[no-untyped-call] + session.commit() + + def get_user_by_api_key( + self, session: Session, api_key_value: str + ) -> Optional[tuple[schemas.User, schemas.ApiKey]]: + orm_api_key = session.execute( + select(orm.ApiKey) # type: ignore[attr-defined] + .options(joinedload(orm.ApiKey.user)) + .where(orm.ApiKey.value == api_key_value) + ).scalar_one_or_none() + + if orm_api_key is None: + return None + + return ( + self._to_schema.user(orm_api_key.user), + self._to_schema.api_key(orm_api_key), + ) + def add_documents( self, session: Session, @@ -49,7 +125,7 @@ def 
add_documents( user: str, documents: list[schemas.Document], ) -> None: - user_id = self._get_user(session, username=user).id + user_id = self._get_orm_user_by_name(session, name=user).id session.add_all( [self._to_orm.document(document, user_id=user_id) for document in documents] ) @@ -82,7 +158,7 @@ def get_documents( def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None: orm_chat = self._to_orm.chat( - chat, user_id=self._get_user(session, username=user).id + chat, user_id=self._get_orm_user_by_name(session, name=user).id ) # We need to merge and not add here, because the documents are already in the DB session.merge(orm_chat) @@ -102,7 +178,8 @@ def get_chats(self, session: Session, *, user: str) -> list[schemas.Chat]: self._to_schema.chat(chat) for chat in session.execute( self._select_chat(eager=True).where( - orm.Chat.user_id == self._get_user(session, username=user).id + orm.Chat.user_id + == self._get_orm_user_by_name(session, name=user).id ) ) .scalars() @@ -117,7 +194,10 @@ def _get_orm_chat( session.execute( self._select_chat(eager=eager).where( (orm.Chat.id == id) - & (orm.Chat.user_id == self._get_user(session, username=user).id) + & ( + orm.Chat.user_id + == self._get_orm_user_by_name(session, name=user).id + ) ) ) .unique() @@ -134,7 +214,7 @@ def get_chat(self, session: Session, *, user: str, id: uuid.UUID) -> schemas.Cha def update_chat(self, session: Session, user: str, chat: schemas.Chat) -> None: orm_chat = self._to_orm.chat( - chat, user_id=self._get_user(session, username=user).id + chat, user_id=self._get_orm_user_by_name(session, name=user).id ) session.merge(orm_chat) session.commit() @@ -199,6 +279,18 @@ def chat( class OrmToSchemaConverter: + def user(self, user: orm.User) -> schemas.User: + return schemas.User(name=user.name) + + def api_key(self, api_key: orm.ApiKey) -> schemas.ApiKey: + return schemas.ApiKey( + id=api_key.id, + name=api_key.name, + expires_at=api_key.expires_at, + obfuscated=True, + 
value=api_key.value, + ) + def document(self, document: orm.Document) -> schemas.Document: return schemas.Document( id=document.id, name=document.name, metadata=document.metadata_ diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 6df48460..7d668c78 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -1,16 +1,17 @@ +import secrets import uuid from typing import Any, AsyncIterator, Optional, Type, cast from fastapi import status as http_status_code import ragna -from ragna import Rag, core +from ragna import core from ragna._utils import make_directory -from ragna.core import RagnaException +from ragna.core import Rag, RagnaException from ragna.core._rag import SpecialChatParams -from ragna.deploy import Config from . import _schemas as schemas +from ._config import Config from ._database import Database @@ -33,6 +34,47 @@ def __init__(self, *, config: Config, ignore_unavailable_components: bool) -> No self._to_core = SchemaToCoreConverter(config=self._config, rag=self._rag) self._to_schema = CoreToSchemaConverter() + def maybe_add_user(self, user: schemas.User) -> None: + with self._database.get_session() as session: + return self._database.maybe_add_user(session, user=user) + + def get_user_by_api_key( + self, api_key_value: str + ) -> tuple[Optional[schemas.User], bool]: + with self._database.get_session() as session: + data = self._database.get_user_by_api_key( + session, api_key_value=api_key_value + ) + + if data is None: + return None, False + + user, api_key = data + return user, api_key.expired + + def create_api_key( + self, user: str, api_key_creation: schemas.ApiKeyCreation + ) -> schemas.ApiKey: + api_key = schemas.ApiKey( + name=api_key_creation.name, + expires_at=api_key_creation.expires_at, + obfuscated=False, + value=secrets.token_urlsafe(32)[:32], + ) + + with self._database.get_session() as session: + self._database.add_api_key(session, user=user, api_key=api_key) + + return api_key + + def list_api_keys(self, 
user: str) -> list[schemas.ApiKey]: + with self._database.get_session() as session: + return self._database.get_api_keys(session, user=user) + + def delete_api_key(self, user: str, id: uuid.UUID) -> None: + with self._database.get_session() as session: + self._database.delete_api_key(session, user=user, id=id) + def _get_component_json_schema( self, component: Type[core.Component], diff --git a/ragna/deploy/_key_value_store.py b/ragna/deploy/_key_value_store.py new file mode 100644 index 00000000..8218aacc --- /dev/null +++ b/ragna/deploy/_key_value_store.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import abc +import os +import time +from typing import Any, Callable, Generic, Optional, TypeVar, Union, cast + +import pydantic + +from ragna.core import PackageRequirement, Requirement +from ragna.core._utils import RequirementsMixin + +M = TypeVar("M", bound=pydantic.BaseModel) + + +class SerializableModel(pydantic.BaseModel, Generic[M]): + cls: pydantic.ImportString[type[M]] + obj: dict[str, Any] + + @classmethod + def from_model(cls, model: M) -> SerializableModel[M]: + return SerializableModel(cls=type(model), obj=model.model_dump(mode="json")) + + def to_model(self) -> M: + return self.cls.model_validate(self.obj) + + +class KeyValueStore(abc.ABC, RequirementsMixin, Generic[M]): + def serialize(self, model: M) -> str: + return SerializableModel.from_model(model).model_dump_json() + + def deserialize(self, data: Union[str, bytes]) -> M: + return SerializableModel.model_validate_json(data).to_model() + + @abc.abstractmethod + def set( + self, key: str, model: M, *, expires_after: Optional[int] = None + ) -> None: ... + + @abc.abstractmethod + def get(self, key: str) -> Optional[M]: ... + + @abc.abstractmethod + def delete(self, key: str) -> None: ... + + @abc.abstractmethod + def refresh(self, key: str, *, expires_after: Optional[int] = None) -> None: ... 
+ + +class InMemoryKeyValueStore(KeyValueStore[M]): + def __init__(self) -> None: + self._store: dict[str, tuple[M, Optional[float]]] = {} + self._timer: Callable[[], float] = time.monotonic + + def set(self, key: str, model: M, *, expires_after: Optional[int] = None) -> None: + if expires_after is not None: + expires_at = self._timer() + expires_after + else: + expires_at = None + self._store[key] = (model, expires_at) + + def get(self, key: str) -> Optional[M]: + value = self._store.get(key) + if value is None: + return None + + model, expires_at = value + if expires_at is not None and self._timer() >= expires_at: + self.delete(key) + return None + + return model + + def delete(self, key: str) -> None: + if key not in self._store: + return + + del self._store[key] + + def refresh(self, key: str, *, expires_after: Optional[int] = None) -> None: + value = self._store.get(key) + if value is None: + return + + model, _ = value + self.set(key, model, expires_after=expires_after) + + +class RedisKeyValueStore(KeyValueStore[M]): + @classmethod + def requirements(cls) -> list[Requirement]: + return [PackageRequirement("redis")] + + def __init__(self) -> None: + import redis + + self._r = redis.Redis( + host=os.environ.get("RAGNA_REDIS_HOST", "localhost"), + port=int(os.environ.get("RAGNA_REDIS_PORT", 6379)), + ) + + def set(self, key: str, model: M, *, expires_after: Optional[int] = None) -> None: + self._r.set(key, self.serialize(model), ex=expires_after) + + def get(self, key: str) -> Optional[M]: + value = cast(bytes, self._r.get(key)) + if value is None: + return None + return self.deserialize(value) + + def delete(self, key: str) -> None: + self._r.delete(key) + + def refresh(self, key: str, *, expires_after: Optional[int] = None) -> None: + if expires_after is None: + self._r.persist(key) + else: + self._r.expire(key, expires_after) diff --git a/ragna/deploy/_orm.py b/ragna/deploy/_orm.py index 74f3560f..3017e660 100644 --- a/ragna/deploy/_orm.py +++ 
b/ragna/deploy/_orm.py @@ -31,7 +31,7 @@ def process_result_value( return json.loads(value) -class UtcDateTime(types.TypeDecorator): +class UtcAwareDateTime(types.TypeDecorator): """UTC timezone aware datetime type. This is needed because sqlalchemy.types.DateTime(timezone=True) does not @@ -63,14 +63,25 @@ class Base(DeclarativeBase): pass -# FIXME: Do we actually need this table? If we are sure that a user is unique and has to -# be authenticated from the API layer, it seems having an extra mapping here is not -# needed? class User(Base): __tablename__ = "users" id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined] + name = Column(types.String, nullable=False, unique=True) + api_keys = relationship("ApiKey", back_populates="user") + + +class ApiKey(Base): + __tablename__ = "api_keys" + + id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined] + + user_id = Column(ForeignKey("users.id")) + user = relationship("User", back_populates="api_keys") + name = Column(types.String, nullable=False) + value = Column(types.String, nullable=False, unique=True) + expires_at = Column(UtcAwareDateTime, nullable=False) document_chat_association_table = Table( @@ -163,4 +174,5 @@ class Message(Base): secondary=source_message_association_table, back_populates="messages", ) - timestamp = Column(UtcDateTime, nullable=False) + + timestamp = Column(UtcAwareDateTime, nullable=False) diff --git a/ragna/deploy/_schemas.py b/ragna/deploy/_schemas.py index 3c080f43..0b73aa23 100644 --- a/ragna/deploy/_schemas.py +++ b/ragna/deploy/_schemas.py @@ -4,11 +4,61 @@ from datetime import datetime, timezone from typing import Annotated, Any -from pydantic import AfterValidator, BaseModel, Field +from pydantic import ( + AfterValidator, + BaseModel, + Field, + ValidationInfo, + computed_field, + field_validator, +) import ragna.core +class User(BaseModel): + name: str + data: dict[str, Any] = Field(default_factory=dict) + + +class ApiKeyCreation(BaseModel): + name: str 
+ expires_at: datetime + + +class ApiKey(BaseModel): + id: uuid.UUID = Field(default_factory=uuid.uuid4) + name: str + expires_at: datetime + obfuscated: bool = True + value: str + + @field_validator("expires_at") + @classmethod + def _set_utc_timezone(cls, v: datetime) -> datetime: + if v.tzinfo is None: + return v.replace(tzinfo=timezone.utc) + else: + return v.astimezone(timezone.utc) + + @computed_field # type: ignore[misc] + @property + def expired(self) -> bool: + return datetime.now(timezone.utc) >= self.expires_at + + @field_validator("value") + @classmethod + def _maybe_obfuscate(cls, v: str, info: ValidationInfo) -> str: + if not info.data["obfuscated"]: + return v + + i = min(len(v) // 6, 3) + if i > 0: + return f"{v[:i]}***{v[-i:]}" + else: + return "***" + + def _set_utc_timezone(v: datetime) -> datetime: if v.tzinfo is None: return v.replace(tzinfo=timezone.utc) diff --git a/ragna/deploy/_templates/__init__.py b/ragna/deploy/_templates/__init__.py new file mode 100644 index 00000000..7c977794 --- /dev/null +++ b/ragna/deploy/_templates/__init__.py @@ -0,0 +1,16 @@ +import contextlib +from pathlib import Path +from typing import Any + +from jinja2 import Environment, FileSystemLoader, TemplateNotFound + +ENVIRONMENT = Environment(loader=FileSystemLoader(Path(__file__).parent)) + + +def render(template: str, **context: Any) -> str: + with contextlib.suppress(TemplateNotFound): + css_template = ENVIRONMENT.get_template(str(Path(template).with_suffix(".css"))) + context["__template_css__"] = css_template.render(**context) + + template = ENVIRONMENT.get_template(template) + return template.render(**context) diff --git a/ragna/deploy/_templates/base.html b/ragna/deploy/_templates/base.html new file mode 100644 index 00000000..e4f61d54 --- /dev/null +++ b/ragna/deploy/_templates/base.html @@ -0,0 +1,49 @@ + + + + + Ragna + + + + + + + + +
+ {% block content %}{% endblock %} +
+ + diff --git a/ragna/deploy/_templates/basic_auth.css b/ragna/deploy/_templates/basic_auth.css new file mode 100644 index 00000000..80a7b918 --- /dev/null +++ b/ragna/deploy/_templates/basic_auth.css @@ -0,0 +1,6 @@ +.basic-auth { + height: 100%; + display: flex; + flex-direction: column; + justify-content: space-between; +} diff --git a/ragna/deploy/_templates/basic_auth.html b/ragna/deploy/_templates/basic_auth.html new file mode 100644 index 00000000..5269334f --- /dev/null +++ b/ragna/deploy/_templates/basic_auth.html @@ -0,0 +1,33 @@ +{% extends "base.html" %} {% block content %} +
+

Log in

+ {% if fail_reason %} +
{{ fail_reason }}
+ {% endif %} +
+
+ Username + +
+
+ Password + +
+
+
+ +
+
+{% endblock %} diff --git a/ragna/deploy/_templates/oauth.html b/ragna/deploy/_templates/oauth.html new file mode 100644 index 00000000..7a21f0cd --- /dev/null +++ b/ragna/deploy/_templates/oauth.html @@ -0,0 +1,7 @@ +{% extends "base.html" %} {% block content %} +
+ + + +
+{% endblock %} diff --git a/ragna/deploy/_ui/api_wrapper.py b/ragna/deploy/_ui/api_wrapper.py index 9dcc6b16..b03f2232 100644 --- a/ragna/deploy/_ui/api_wrapper.py +++ b/ragna/deploy/_ui/api_wrapper.py @@ -2,9 +2,9 @@ from datetime import datetime import emoji +import panel as pn import param -from ragna.core._utils import default_user from ragna.deploy import _schemas as schemas from ragna.deploy._engine import Engine @@ -12,7 +12,7 @@ class ApiWrapper(param.Parameterized): def __init__(self, engine: Engine): super().__init__() - self._user = default_user() + self._user = pn.state.user self._engine = engine async def get_chats(self): diff --git a/ragna/deploy/_ui/left_sidebar.py b/ragna/deploy/_ui/left_sidebar.py index 267a3b77..3d45849f 100644 --- a/ragna/deploy/_ui/left_sidebar.py +++ b/ragna/deploy/_ui/left_sidebar.py @@ -14,6 +14,7 @@ class LeftSidebar(pn.viewable.Viewer): def __init__(self, api_wrapper, **params): super().__init__(**params) + self.api_wrapper = api_wrapper self.on_click_chat = None self.on_click_new_chat = None @@ -104,6 +105,7 @@ def __panel__(self): + self.chat_buttons + [ pn.layout.VSpacer(), + pn.pane.HTML(f"user: {self.api_wrapper._user}"), pn.pane.HTML(f"version: {ragna_version}"), # self.footer() ] diff --git a/ragna/source_storages/_vector_database.py b/ragna/source_storages/_vector_database.py index 81ec2df5..3ed5ca35 100644 --- a/ragna/source_storages/_vector_database.py +++ b/ragna/source_storages/_vector_database.py @@ -89,7 +89,7 @@ def _chunk_pages( ): tokens, page_numbers = zip(*window) yield Chunk( - text=self._tokenizer.decode(tokens), # type: ignore[arg-type] + text=self._tokenizer.decode(tokens), page_numbers=list(filter(lambda n: n is not None, page_numbers)) or None, num_tokens=len(tokens), diff --git a/scripts/add_chats.py b/scripts/add_chats.py index 5f550289..14d8827a 100644 --- a/scripts/add_chats.py +++ b/scripts/add_chats.py @@ -8,25 +8,14 @@ def main(): client = httpx.Client(base_url="http://127.0.0.1:31476") 
client.get("/health").raise_for_status() - # ## authentication - # - # username = default_user() - # token = ( - # client.post( - # "/token", - # data={ - # "username": username, - # "password": os.environ.get( - # "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username - # ), - # }, - # ) - # .raise_for_status() - # .json() - # ) - # client.headers["Authorization"] = f"Bearer {token}" - - print() + ## authentication + + # This only works if Ragna was deployed with ragna.core.NoAuth + # If that is not the case, login in whatever way is required, grab the API token and + # use the following instead + # client.headers["Authorization"] = f"Bearer {api_token}" + + client.get("/login", follow_redirects=True).raise_for_status() ## documents diff --git a/scripts/docs/gen_files.py b/scripts/docs/gen_files.py index a350f035..42bd6107 100644 --- a/scripts/docs/gen_files.py +++ b/scripts/docs/gen_files.py @@ -7,14 +7,14 @@ import mkdocs_gen_files import typer.rich_utils +from ragna._cli import app as cli_app # noqa: E402 from ragna.deploy import Config -from ragna.deploy._api import app as api_app -from ragna.deploy._cli import app as cli_app +from ragna.deploy._core import make_app as make_deploy_app def main(): cli_reference() - api_reference() + deploy_reference() config_reference() @@ -43,8 +43,14 @@ def get_doc(command): file.write(get_doc(command.name or command.callback.__name__)) -def api_reference(): - app = api_app(config=Config(), ignore_unavailable_components=False) +def deploy_reference(): + app = make_deploy_app( + config=Config(), + api=True, + ui=True, + ignore_unavailable_components=False, + open_browser=False, + ) openapi_json = fastapi.openapi.utils.get_openapi( title=app.title, version=app.version, diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index f7c9c594..3f10089b 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -2,17 +2,16 @@ import itertools import json import os -import time from pathlib import Path 
import httpx import pytest from ragna import assistants -from ragna._utils import timeout_after +from ragna._utils import BackgroundSubprocess from ragna.assistants._http_api import HttpApiAssistant, HttpStreamingProtocol from ragna.core import Message, RagnaException -from tests.utils import background_subprocess, get_available_port, skip_on_windows +from tests.utils import get_available_port, skip_on_windows HTTP_API_ASSISTANTS = [ assistant @@ -43,26 +42,19 @@ def streaming_server(): port = get_available_port() base_url = f"http://localhost:{port}" - with background_subprocess( + def check_fn(): + try: + return httpx.get(f"{base_url}/health").is_success + except httpx.ConnectError: + return False + + with BackgroundSubprocess( "uvicorn", f"--app-dir={Path(__file__).parent}", f"--port={port}", "streaming_server:app", + startup_fn=check_fn, ): - - def up(): - try: - return httpx.get(f"{base_url}/health").is_success - except httpx.ConnectError: - return False - - @timeout_after(10, message="Failed to start streaming server") - def wait(): - while not up(): - time.sleep(0.2) - - wait() - yield base_url diff --git a/tests/conftest.py b/tests/conftest.py index 6f380be2..4f44218e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import ragna -@pytest.fixture +@pytest.fixture(autouse=True) def tmp_local_root(tmp_path): old = ragna.local_root() try: diff --git a/tests/deploy/api/test_components.py b/tests/deploy/api/test_components.py index 8e23cdeb..9ccdb2ba 100644 --- a/tests/deploy/api/test_components.py +++ b/tests/deploy/api/test_components.py @@ -1,11 +1,10 @@ import pytest from fastapi import status -from fastapi.testclient import TestClient from ragna import assistants from ragna.core import RagnaException from ragna.deploy import Config -from tests.deploy.utils import authenticate_with_api, make_api_app +from tests.deploy.utils import make_api_app, make_api_client @pytest.mark.parametrize("ignore_unavailable_components", [True, False]) @@ 
-19,14 +18,10 @@ def test_ignore_unavailable_components(ignore_unavailable_components): config = Config(assistants=[available_assistant, unavailable_assistant]) if ignore_unavailable_components: - with TestClient( - make_api_app( - config=config, - ignore_unavailable_components=ignore_unavailable_components, - ) + with make_api_client( + config=config, + ignore_unavailable_components=ignore_unavailable_components, ) as client: - authenticate_with_api(client) - components = client.get("/api/components").raise_for_status().json() assert [assistant["title"] for assistant in components["assistants"]] == [ available_assistant.display_name() @@ -61,11 +56,9 @@ def test_unknown_component(tmp_local_root): with open(document_path, "w") as file: file.write("!\n") - with TestClient( - make_api_app(config=Config(), ignore_unavailable_components=False) + with make_api_client( + config=Config(), ignore_unavailable_components=False ) as client: - authenticate_with_api(client) - document = ( client.post("/api/documents", json=[{"name": document_path.name}]) .raise_for_status() diff --git a/tests/deploy/api/test_e2e.py b/tests/deploy/api/test_e2e.py index e023c0ee..fa342d91 100644 --- a/tests/deploy/api/test_e2e.py +++ b/tests/deploy/api/test_e2e.py @@ -1,10 +1,9 @@ import json import pytest -from fastapi.testclient import TestClient from ragna.deploy import Config -from tests.deploy.utils import TestAssistant, authenticate_with_api, make_api_app +from tests.deploy.utils import TestAssistant, make_api_client from tests.utils import skip_on_windows @@ -20,11 +19,7 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): with open(document_path, "w") as file: file.write("!\n") - with TestClient( - make_api_app(config=config, ignore_unavailable_components=False) - ) as client: - authenticate_with_api(client) - + with make_api_client(config=config, ignore_unavailable_components=False) as client: assert client.get("/api/chats").raise_for_status().json() == [] documents = 
( diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py new file mode 100644 index 00000000..6dd3fb78 --- /dev/null +++ b/tests/deploy/api/utils.py @@ -0,0 +1,35 @@ +import os + +from fastapi.testclient import TestClient + +from ragna._utils import default_user +from ragna.deploy._core import make_app + + +def make_api_app(*, config, ignore_unavailable_components): + return make_app( + config, + api=True, + ui=False, + ignore_unavailable_components=ignore_unavailable_components, + open_browser=False, + ) + + +def authenticate(client: TestClient) -> None: + return + username = default_user() + token = ( + client.post( + "/token", + data={ + "username": username, + "password": os.environ.get( + "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username + ), + }, + ) + .raise_for_status() + .json() + ) + client.headers["Authorization"] = f"Bearer {token}" diff --git a/tests/deploy/utils.py b/tests/deploy/utils.py index f8d1277a..e30fa758 100644 --- a/tests/deploy/utils.py +++ b/tests/deploy/utils.py @@ -1,10 +1,10 @@ -import os +import contextlib import time from fastapi.testclient import TestClient from ragna.assistants import RagnaDemoAssistant -from ragna.core._utils import default_user +from ragna.deploy._auth import SessionMiddleware from ragna.deploy._core import make_app @@ -38,19 +38,17 @@ def make_api_app(*, config, ignore_unavailable_components): def authenticate_with_api(client: TestClient) -> None: - return - username = default_user() - token = ( - client.post( - "/token", - data={ - "username": username, - "password": os.environ.get( - "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username - ), - }, + client.get("/login", follow_redirects=True).raise_for_status() + assert SessionMiddleware._COOKIE_NAME in client.cookies + + +@contextlib.contextmanager +def make_api_client(*, config, ignore_unavailable_components): + with TestClient( + make_api_app( + config=config, + ignore_unavailable_components=ignore_unavailable_components, ) - .raise_for_status() - .json() - 
) - client.headers["Authorization"] = f"Bearer {token}" + ) as client: + authenticate_with_api(client) + yield client diff --git a/tests/utils.py b/tests/utils.py index 3fc31732..87081f65 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,5 @@ -import contextlib import platform import socket -import subprocess -import sys import pytest @@ -11,16 +8,6 @@ ) -@contextlib.contextmanager -def background_subprocess(*args, stdout=sys.stdout, stderr=sys.stdout, **kwargs): - process = subprocess.Popen(args, stdout=stdout, stderr=stderr, **kwargs) - try: - yield process - finally: - process.kill() - process.communicate() - - def get_available_port(): with socket.socket() as s: s.bind(("", 0))