diff --git a/CHANGELOG.md b/CHANGELOG.md index 68b159e..be99497 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,15 @@ Tous les changements notables de l'application sont documentés dans ce fichier. - 🔄 Refactoring - ❌ Deprecated +## [Alpha] - 2024-10-21 + +- 🎉 Ajout de la limitation de débit (*rate limiting*) lorsque l'authentification est activée. +- 📚 Ajout d'une documentation (./docs/security.md) sur l'authentification et la limitation de débit. +- 🧪 Ajout de tests pour la limitation de débit. +- 📚 Amélioration de la documentation [README.md](./README.md). +- 📚 La documentation est maintenant accessible à l'URL `/documentation` et le swagger à l'URL `/swagger`. +- 🔄 Optimisation du comptage des documents dans Qdrant. + ## [Alpha] - 2024-10-09 - 🎉 Ajout d'un status du modèle dans le retour du endpoint GET `/v1/models`. Ce status permet de vérifier si le modèle est disponible ou non. diff --git a/README.md b/README.md index ff99796..39cbffc 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,23 @@ -# Albert API -![](https://img.shields.io/badge/python-3.12-green) ![](https://img.shields.io/badge/vLLM-v0.5.5-blue) ![](https://img.shields.io/badge/HuggingFace%20Text%20Embeddings%20Inference-1.5-red) +
+ +Albert API est une initiative d'[Etalab](https://www.etalab.gouv.fr/). Il s'agit d'une API open source d'IA générative développée par Etalab. Elle permet d'être un proxy entre des modèles de langage et vos données. Elle agrège les services suivants : - servir des modèles de langage avec [vLLM](https://github.com/vllm-project/vllm) - servir des modèles d'embeddings avec [HuggingFace Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference) - accès un *vector store* avec [Qdrant](https://qdrant.tech/) pour la recherche de similarité En se basant sur les conventions définies par OpenAI, l'API Albert expose des endpoints qui peuvent être appelés avec le [client officiel python d'OpenAI](https://github.com/openai/openai-python/tree/main). Ce formalisme permet d'intégrer facilement l'API Albert avec des bibliothèques tierces comme [Langchain](https://www.langchain.com/) ou [LlamaIndex](https://www.llamaindex.ai/). -## 🚀 Nouveautés +## ⚙️ Fonctionnalités -Vous trouverez les changelogs des différentes versions d'Albert API dans le fichier [CHANGELOG.md](./CHANGELOG.md). +### Interface utilisateur (playground) -## ⚙️ Fonctionnalités +L'API Albert expose une interface utilisateur permettant de tester les différentes fonctionnalités, consultable ici [ici](https://albert.api.etalab.gouv.fr). ### Converser avec un modèle de langage (chat memory) @@ -50,6 +55,7 @@ L'API Albert permet d'importer sa base de connaissances dans une base vectoriell Albert API est un projet open source, vous pouvez contribuer au projet en lisant notre [guide de contribution](./CONTRIBUTING.md). -## Installation +## 🚀 Installation + +Pour déployer l'API Albert sur votre propre infrastructure, suivez la [documentation](./docs/deployment.md). -Pour déployer l'API Albert sur votre propre infrastructure, suivez la [documentation](./docs/deployment.md). \ No newline at end of file diff --git a/app/endpoints/chat.py b/app/endpoints/chat.py index e7816b4..6201d88 100644 --- a/app/endpoints/chat.py +++ b/app/endpoints/chat.py @@ -1,33 +1,34 @@ from typing import Union -from fastapi import APIRouter, Security +from fastapi import APIRouter, Request, Security from fastapi.responses import StreamingResponse import httpx from app.schemas.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionRequest from app.schemas.security import User -from app.utils.lifespan import clients -from app.utils.security import check_api_key - +from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.lifespan import clients, limiter +from app.utils.security import check_api_key, check_rate_limit router = APIRouter() @router.post("/chat/completions") -async def chat_completions(request: ChatCompletionRequest, user: User = Security(check_api_key)) -> Union[ChatCompletion, ChatCompletionChunk]: +@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +async def chat_completions( + request: Request, body: ChatCompletionRequest, user: User = Security(check_api_key) +) -> Union[ChatCompletion, ChatCompletionChunk]: """Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/chat/create for the API specification. """ - - request = dict(request) - client = clients.models[request["model"]] + client = clients.models[body.model] url = f"{client.base_url}chat/completions" headers = {"Authorization": f"Bearer {client.api_key}"} # non stream case - if not request["stream"]: + if not body.stream: async with httpx.AsyncClient(timeout=20) as async_client: - response = await async_client.request(method="POST", url=url, headers=headers, json=request) + response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump()) response.raise_for_status() data = response.json() @@ -40,4 +41,4 @@ async def forward_stream(url: str, headers: dict, request: dict): async for chunk in response.aiter_raw(): yield chunk - return StreamingResponse(forward_stream(url, headers, request), media_type="text/event-stream") + return StreamingResponse(forward_stream(url, headers, body.model_dump()), media_type="text/event-stream") diff --git a/app/endpoints/chunks.py b/app/endpoints/chunks.py index 64409f2..8c52983 100644 --- a/app/endpoints/chunks.py +++ b/app/endpoints/chunks.py @@ -1,18 +1,22 @@ from typing import Optional from uuid import UUID -from fastapi import APIRouter, Security, Query +from fastapi import APIRouter, Request, Security, Query from app.schemas.chunks import Chunks from app.schemas.security import User from app.utils.lifespan import clients -from app.utils.security import check_api_key +from app.utils.security import check_api_key, check_rate_limit +from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.lifespan import limiter router = APIRouter() @router.get("/chunks/{collection}/{document}") +@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) async def get_chunks( + request: Request, collection: UUID, document: UUID, limit: Optional[int] = Query(default=10, ge=1, le=100), diff --git a/app/endpoints/collections.py b/app/endpoints/collections.py index db7391f..96299e4 100644 --- a/app/endpoints/collections.py +++ b/app/endpoints/collections.py @@ -2,34 +2,36 @@ import uuid from uuid import UUID -from fastapi import APIRouter, Response, Security +from fastapi import APIRouter, Request, Response, Security from fastapi.responses import JSONResponse - from app.schemas.collections import Collection, CollectionRequest, Collections from app.schemas.security import User -from app.utils.lifespan import clients -from app.utils.security import check_api_key +from app.utils.lifespan import clients, limiter +from app.utils.security import check_api_key, check_rate_limit +from app.utils.config import DEFAULT_RATE_LIMIT from app.utils.variables import INTERNET_COLLECTION_ID, PUBLIC_COLLECTION_TYPE router = APIRouter() @router.post("/collections") -async def create_collection(request: CollectionRequest, user: User = Security(check_api_key)) -> Response: +@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +async def create_collection(request: Request, body: CollectionRequest, user: User = Security(check_api_key)) -> Response: """ Create a new collection. """ collection_id = str(uuid.uuid4()) clients.vectors.create_collection( - collection_id=collection_id, collection_name=request.name, collection_model=request.model, collection_type=request.type, user=user + collection_id=collection_id, collection_name=body.name, collection_model=body.model, collection_type=body.type, user=user ) return JSONResponse(status_code=201, content={"id": collection_id}) @router.get("/collections") -async def get_collections(user: User = Security(check_api_key)) -> Union[Collection, Collections]: +@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +async def get_collections(request: Request, user: User = Security(check_api_key)) -> Union[Collection, Collections]: """ Get list of collections. """ @@ -47,7 +49,8 @@ async def get_collections(user: User = Security(check_api_key)) -> Union[Collect @router.delete("/collections/{collection}") -async def delete_collections(collection: UUID, user: User = Security(check_api_key)) -> Response: +@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +async def delete_collections(request: Request, collection: UUID, user: User = Security(check_api_key)) -> Response: """ Delete a collection. """ diff --git a/app/endpoints/completions.py b/app/endpoints/completions.py index ef61075..119e39a 100644 --- a/app/endpoints/completions.py +++ b/app/endpoints/completions.py @@ -1,28 +1,28 @@ -from fastapi import APIRouter, Security +from fastapi import APIRouter, Request, Security import httpx from app.schemas.completions import CompletionRequest, Completions from app.schemas.security import User -from app.utils.lifespan import clients -from app.utils.security import check_api_key +from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.lifespan import clients, limiter +from app.utils.security import check_api_key, check_rate_limit router = APIRouter() @router.post("/completions") -async def completions(request: CompletionRequest, user: User = Security(check_api_key)) -> Completions: +@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +async def completions(request: Request, body: CompletionRequest, user: User = Security(check_api_key)) -> Completions: """ Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/completions/create for the API specification. """ - - request = dict(request) - client = clients.models[request["model"]] + client = clients.models[body.model] url = f"{client.base_url}completions" headers = {"Authorization": f"Bearer {client.api_key}"} async with httpx.AsyncClient(timeout=20) as async_client: - response = await async_client.request(method="POST", url=url, headers=headers, json=request) + response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump()) response.raise_for_status() data = response.json() diff --git a/app/endpoints/documents.py b/app/endpoints/documents.py index 1815762..a84cc0d 100644 --- a/app/endpoints/documents.py +++ b/app/endpoints/documents.py @@ -1,20 +1,25 @@ from typing import Optional from uuid import UUID -from fastapi import APIRouter, Response, Security, Query - +from fastapi import APIRouter, Query, Request, Response, Security from app.schemas.documents import Documents from app.schemas.security import User -from app.utils.lifespan import clients -from app.utils.security import check_api_key +from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.lifespan import clients, limiter +from app.utils.security import check_api_key, check_rate_limit router = APIRouter() @router.get("/documents/{collection}") +@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) async def get_documents( - collection: UUID, limit: Optional[int] = Query(default=10, ge=1, le=100), offset: Optional[UUID] = None, user: User = Security(check_api_key) + request: Request, + collection: UUID, + limit: Optional[int] = Query(default=10, ge=1, le=100), + offset: Optional[UUID] = None, + user: User = Security(check_api_key), ) -> Documents: """ Get all documents ID from a collection. @@ -27,11 +32,8 @@ async def get_documents( @router.delete("/documents/{collection}/{document}") -async def delete_document( - collection: UUID, - document: UUID, - user: User = Security(check_api_key), -) -> Response: +@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +async def delete_document(request: Request, collection: UUID, document: UUID, user: User = Security(check_api_key)) -> Response: """ Delete a document and relative collections. """ diff --git a/app/endpoints/embeddings.py b/app/endpoints/embeddings.py index 3de42cc..fae6a90 100644 --- a/app/endpoints/embeddings.py +++ b/app/endpoints/embeddings.py @@ -1,25 +1,26 @@ -from fastapi import APIRouter, Security +from fastapi import APIRouter, Request, Security import httpx from app.schemas.embeddings import Embeddings, EmbeddingsRequest from app.schemas.security import User -from app.utils.lifespan import clients -from app.utils.security import check_api_key -from app.utils.variables import EMBEDDINGS_MODEL_TYPE +from app.utils.config import DEFAULT_RATE_LIMIT from app.utils.exceptions import ContextLengthExceededException, WrongModelTypeException +from app.utils.lifespan import clients, limiter +from app.utils.security import check_api_key, check_rate_limit +from app.utils.variables import EMBEDDINGS_MODEL_TYPE router = APIRouter() @router.post("/embeddings") -async def embeddings(request: EmbeddingsRequest, user: User = Security(check_api_key)) -> Embeddings: +@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Security(check_api_key)) -> Embeddings: """ Embedding API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/embeddings/create for the API specification. """ - request = dict(request) - client = clients.models[request["model"]] + client = clients.models[body.model] if client.type != EMBEDDINGS_MODEL_TYPE: raise WrongModelTypeException() @@ -27,7 +28,7 @@ async def embeddings(request: EmbeddingsRequest, user: User = Security(check_api headers = {"Authorization": f"Bearer {client.api_key}"} async with httpx.AsyncClient(timeout=20) as async_client: - response = await async_client.request(method="POST", url=url, headers=headers, json=request) + response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump()) try: response.raise_for_status() except httpx.HTTPStatusError as e: diff --git a/app/endpoints/files.py b/app/endpoints/files.py index dce70b4..db600fd 100644 --- a/app/endpoints/files.py +++ b/app/endpoints/files.py @@ -31,7 +31,7 @@ async def upload_file(file: UploadFile = File(...), request: FilesRequest = Body chunker_args = ChunkerArgs().model_dump() chunker_name = None - chunker_args["length_function"] = len if chunker_args["length_function"] == "len" else None + chunker_args["length_function"] = len if chunker_args["length_function"] == "len" else chunker_args["length_function"] uploader = FileUploader(vectors=clients.vectors, user=user, collection_id=request.collection) output = uploader.parse(file=file) diff --git a/app/endpoints/models.py b/app/endpoints/models.py index 423df5d..9b3a9eb 100644 --- a/app/endpoints/models.py +++ b/app/endpoints/models.py @@ -1,18 +1,20 @@ from typing import Optional, Union -from fastapi import APIRouter, Security +from fastapi import APIRouter, Request, Security +from app.utils.config import DEFAULT_RATE_LIMIT from app.schemas.models import Model, Models from app.schemas.security import User -from app.utils.lifespan import clients -from app.utils.security import check_api_key +from app.utils.lifespan import clients, limiter +from app.utils.security import check_api_key, check_rate_limit router = APIRouter() @router.get("/models/{model:path}") @router.get("/models") -async def models(model: Optional[str] = None, user: User = Security(check_api_key)) -> Union[Models, Model]: +@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +async def models(request: Request, model: Optional[str] = None, user: User = Security(check_api_key)) -> Union[Models, Model]: """ Model API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/models/list for the API specification. diff --git a/app/endpoints/search.py b/app/endpoints/search.py index ce5ad49..77dd6d0 100644 --- a/app/endpoints/search.py +++ b/app/endpoints/search.py @@ -1,38 +1,38 @@ -from fastapi import APIRouter, Security +from fastapi import APIRouter, Request, Security from app.helpers import SearchOnInternet from app.schemas.search import Searches, SearchRequest from app.schemas.security import User -from app.utils.lifespan import clients -from app.utils.security import check_api_key +from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.lifespan import clients, limiter +from app.utils.security import check_api_key, check_rate_limit from app.utils.variables import INTERNET_COLLECTION_ID router = APIRouter() @router.post("/search") -async def search(request: SearchRequest, user: User = Security(check_api_key)) -> Searches: +@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +async def search(request: Request, body: SearchRequest, user: User = Security(check_api_key)) -> Searches: """ Similarity search for chunks in the vector store or on the internet. """ data = [] - if INTERNET_COLLECTION_ID in request.collections: - request.collections.remove(INTERNET_COLLECTION_ID) + if INTERNET_COLLECTION_ID in body.collections: + body.collections.remove(INTERNET_COLLECTION_ID) internet = SearchOnInternet(models=clients.models) - if len(request.collections) > 0: - collection_model = clients.vectors.get_collections(collection_ids=request.collections, user=user)[0].model + if len(body.collections) > 0: + collection_model = clients.vectors.get_collections(collection_ids=body.collections, user=user)[0].model else: collection_model = None - data.extend(internet.search(prompt=request.prompt, n=4, model_id=collection_model, score_threshold=request.score_threshold)) + data.extend(internet.search(prompt=body.prompt, n=4, model_id=collection_model, score_threshold=body.score_threshold)) - if len(request.collections) > 0: + if len(body.collections) > 0: data.extend( - clients.vectors.search( - prompt=request.prompt, collection_ids=request.collections, k=request.k, score_threshold=request.score_threshold, user=user - ) + clients.vectors.search(prompt=body.prompt, collection_ids=body.collections, k=body.k, score_threshold=body.score_threshold, user=user) ) - data = sorted(data, key=lambda x: x.score, reverse=False)[: request.k] + data = sorted(data, key=lambda x: x.score, reverse=False)[: body.k] return Searches(data=data) diff --git a/app/helpers/__init__.py b/app/helpers/__init__.py index d64f89e..b735a42 100644 --- a/app/helpers/__init__.py +++ b/app/helpers/__init__.py @@ -1,9 +1,9 @@ -from ._authmanager import AuthManager +from ._authenticationclient import AuthenticationClient from ._clientsmanager import ClientsManager from ._contentsizelimitmiddleware import ContentSizeLimitMiddleware +from ._fileuploader import FileUploader from ._modelclients import ModelClients from ._searchoninternet import SearchOnInternet from ._vectorstore import VectorStore -from ._fileuploader import FileUploader -__all__ = ["AuthManager", "ClientsManager", "ContentSizeLimitMiddleware", "FileUploader", "ModelClients", "SearchOnInternet", "VectorStore"] +__all__ = ["AuthenticationClient", "ClientsManager", "ContentSizeLimitMiddleware", "FileUploader", "ModelClients", "SearchOnInternet", "VectorStore"] diff --git a/app/helpers/_authmanager.py b/app/helpers/_authenticationclient.py similarity index 74% rename from app/helpers/_authmanager.py rename to app/helpers/_authenticationclient.py index 18d4d4a..fe44f71 100644 --- a/app/helpers/_authmanager.py +++ b/app/helpers/_authenticationclient.py @@ -1,19 +1,25 @@ import datetime as dt +import json from typing import Optional +import uuid from grist_api import GristDocAPI from redis import Redis -import json -from app.utils.variables import USER_ROLE + +from app.utils.variables import ROLE_LEVEL_0, ROLE_LEVEL_1, ROLE_LEVEL_2 -class AuthManager(GristDocAPI): +class AuthenticationClient(GristDocAPI): CACHE_EXPIRATION = 3600 # 1h + ROLE_DICT = { + "user": ROLE_LEVEL_0, + "client": ROLE_LEVEL_1, + "admin": ROLE_LEVEL_2, + } def __init__(self, cache: Redis, table_id: str, *args, **kwargs): super().__init__(*args, **kwargs) - self.user = kwargs.get("user") - self.doc_id = kwargs.get("doc_id") + self.session_id = str(uuid.uuid4()) self.table_id = table_id self.redis = cache @@ -37,7 +43,7 @@ def cache(func): """ def wrapper(self): - key = f"auth-{self.doc_id}-{self.table_id}" + key = f"auth-{self.session_id}" result = self.redis.get(key) if result: result = json.loads(result) @@ -53,12 +59,15 @@ def wrapper(self): def _get_api_keys(self): """ Get all keys from a table in the Grist document. + + Returns: + dict: dictionary of keys and their corresponding access level """ records = self.fetch_table(self.table_id) keys = dict() for record in records: if record.EXPIRATION > dt.datetime.now().timestamp(): - keys[record.KEY] = record.ROLE or USER_ROLE + keys[record.KEY] = self.ROLE_DICT.get(record.ROLE, ROLE_LEVEL_0) return keys diff --git a/app/helpers/_clientsmanager.py b/app/helpers/_clientsmanager.py index a339dd4..9f7a904 100644 --- a/app/helpers/_clientsmanager.py +++ b/app/helpers/_clientsmanager.py @@ -1,9 +1,10 @@ from redis import Redis as CacheManager +from redis.connection import ConnectionPool from app.schemas.config import Config from ._modelclients import ModelClients -from ._authmanager import AuthManager +from ._authenticationclient import AuthenticationClient from ._vectorstore import VectorStore @@ -16,10 +17,13 @@ def set(self): self.models = ModelClients(config=self.config) # set cache - self.cache = CacheManager(**self.config.databases.cache.args) + self.cache = CacheManager(connection_pool=ConnectionPool(**self.config.databases.cache.args)) # set vectors self.vectors = VectorStore(models=self.models, **self.config.databases.vectors.args) # set auth - self.auth = AuthManager(cache=self.cache, **self.config.auth.args) if self.config.auth else None + self.auth = AuthenticationClient(cache=self.cache, **self.config.auth.args) if self.config.auth else None + + def clear(self): + self.vectors.close() diff --git a/app/helpers/_contentsizelimitmiddleware.py b/app/helpers/_contentsizelimitmiddleware.py index 8e0ef85..9ccf4ac 100644 --- a/app/helpers/_contentsizelimitmiddleware.py +++ b/app/helpers/_contentsizelimitmiddleware.py @@ -12,13 +12,9 @@ class ContentSizeLimitMiddleware: max_content_size (optional): the maximum content size allowed in bytes, default is MAX_CONTENT_SIZE """ - MAX_CONTENT_SIZE = 10 * 1024 * 1024 # 10MB + MAX_CONTENT_SIZE = 20 * 1024 * 1024 # 20MB - def __init__( - self, - app, - max_content_size: Optional[int] = None, - ): + def __init__(self, app, max_content_size: Optional[int] = None): self.app = app self.max_content_size = max_content_size or self.MAX_CONTENT_SIZE diff --git a/app/helpers/_vectorstore.py b/app/helpers/_vectorstore.py index 9c8612b..8251c7c 100644 --- a/app/helpers/_vectorstore.py +++ b/app/helpers/_vectorstore.py @@ -19,29 +19,29 @@ from app.schemas.documents import Document from app.schemas.search import Search from app.schemas.security import User -from app.utils.variables import EMBEDDINGS_MODEL_TYPE, PUBLIC_COLLECTION_TYPE, USER_ROLE from app.utils.exceptions import ( + CollectionNotFoundException, DifferentCollectionsModelsException, - WrongModelTypeException, WrongCollectionTypeException, - CollectionNotFoundException, + WrongModelTypeException, ) +from app.utils.variables import EMBEDDINGS_MODEL_TYPE, PUBLIC_COLLECTION_TYPE, ROLE_LEVEL_2 -class VectorStore: +class VectorStore(QdrantClient): BATCH_SIZE = 48 METADATA_COLLECTION_ID = "collections" DOCUMENT_COLLECTION_ID = "documents" def __init__(self, models: dict, *args, **kwargs): - self.qdrant = QdrantClient(*args, **kwargs) + super().__init__(*args, **kwargs) self.models = models - if not self.qdrant.collection_exists(collection_name=self.METADATA_COLLECTION_ID): - self.qdrant.create_collection(collection_name=self.METADATA_COLLECTION_ID, vectors_config={}, on_disk_payload=False) + if not super().collection_exists(collection_name=self.METADATA_COLLECTION_ID): + super().create_collection(collection_name=self.METADATA_COLLECTION_ID, vectors_config={}, on_disk_payload=False) - if not self.qdrant.collection_exists(collection_name=self.DOCUMENT_COLLECTION_ID): - self.qdrant.create_collection(collection_name=self.DOCUMENT_COLLECTION_ID, vectors_config={}, on_disk_payload=False) + if not super().collection_exists(collection_name=self.DOCUMENT_COLLECTION_ID): + super().create_collection(collection_name=self.DOCUMENT_COLLECTION_ID, vectors_config={}, on_disk_payload=False) def upsert(self, chunks: List[Chunk], collection_id: str, user: User) -> None: """ @@ -54,14 +54,14 @@ def upsert(self, chunks: List[Chunk], collection_id: str, user: User) -> None: """ collection = self.get_collections(collection_ids=[collection_id], user=user)[0] - if user.role == USER_ROLE and collection.type == PUBLIC_COLLECTION_TYPE: + if user.role != ROLE_LEVEL_2 and collection.type == PUBLIC_COLLECTION_TYPE: raise WrongCollectionTypeException() for i in range(0, len(chunks), self.BATCH_SIZE): batch = chunks[i : i + self.BATCH_SIZE] # insert documents - self.qdrant.upsert( + super().upsert( collection_name=self.DOCUMENT_COLLECTION_ID, points=[ PointStruct( @@ -83,7 +83,7 @@ def upsert(self, chunks: List[Chunk], collection_id: str, user: User) -> None: vectors = [vector.embedding for vector in response.data] # insert chunks and vectors - self.qdrant.upsert( + super().upsert( collection_name=collection_id, points=[ PointStruct(id=chunk.id, vector=vector, payload={"content": chunk.content, "metadata": chunk.metadata.model_dump()}) @@ -91,6 +91,19 @@ def upsert(self, chunks: List[Chunk], collection_id: str, user: User) -> None: ], ) + # update collection documents count + payload = collection.model_dump() + payload["documents"] = ( + super() + .count( + collection_name=self.DOCUMENT_COLLECTION_ID, + count_filter=Filter(must=[FieldCondition(key="collection_id", match=MatchAny(any=[collection.id]))]), + ) + .count + ) + payload.pop("id") + super().upsert(collection_name=self.METADATA_COLLECTION_ID, points=[PointStruct(id=collection.id, payload=payload, vector={})]) + def search( self, prompt: str, @@ -124,7 +137,7 @@ def search( chunks = [] for collection in collections: - results = self.qdrant.search( + results = super().search( collection_name=collection.id, query_vector=vector, limit=k, score_threshold=score_threshold, with_payload=True, query_filter=filter ) for result in results: @@ -148,7 +161,6 @@ def get_collections(self, user: User, collection_ids: List[str] = []) -> List[Co user (User): The user retrieving the collections. collection_ids (List[str]): List of collection ids to retrieve metadata for. If is an empty list, all collections will be considered. - Returns: List[Collection]: A list of Collection objects containing the metadata for the specified collections. """ @@ -160,15 +172,15 @@ def get_collections(self, user: User, collection_ids: List[str] = []) -> List[Co ] filter = Filter(must=must, should=should) - records = self.qdrant.scroll(collection_name=self.METADATA_COLLECTION_ID, scroll_filter=filter, limit=1000, offset=None) + records = super().scroll(collection_name=self.METADATA_COLLECTION_ID, scroll_filter=filter, limit=1000, offset=None) data, offset = records[0], records[1] while offset is not None: - records = self.qdrant.scroll(collection_name=self.METADATA_COLLECTION_ID, scroll_filter=filter, limit=1000, offset=offset) + records = super().scroll(collection_name=self.METADATA_COLLECTION_ID, scroll_filter=filter, limit=1000, offset=offset) data.extend(records[0]) offset = records[1] # sanity check: remove collection that does not exist - existing_collection_ids = [collection.name for collection in self.qdrant.get_collections().collections] + existing_collection_ids = [collection.name for collection in super().get_collections().collections] data = [collection for collection in data if collection.id in existing_collection_ids] # check if collection ids are valid @@ -179,10 +191,6 @@ def get_collections(self, user: User, collection_ids: List[str] = []) -> List[Co collections = list() for collection in data: - document_count = self.qdrant.count( - collection_name=self.DOCUMENT_COLLECTION_ID, - count_filter=Filter(must=[FieldCondition(key="collection_id", match=MatchAny(any=[collection.id]))]), - ).count collections.append( Collection( id=collection.id, @@ -192,7 +200,7 @@ def get_collections(self, user: User, collection_ids: List[str] = []) -> List[Co user=collection.payload.get("user"), description=collection.payload.get("description"), created_at=collection.payload.get("created_at"), - documents=document_count, + documents=collection.payload.get("documents"), ) ) @@ -212,7 +220,7 @@ def create_collection(self, collection_id: str, collection_name: str, collection if self.models[collection_model].type != EMBEDDINGS_MODEL_TYPE: raise WrongModelTypeException() - if user.role == USER_ROLE and collection_type == PUBLIC_COLLECTION_TYPE: + if user.role != ROLE_LEVEL_2 and collection_type == PUBLIC_COLLECTION_TYPE: raise WrongCollectionTypeException() # create metadata @@ -223,11 +231,12 @@ def create_collection(self, collection_id: str, collection_name: str, collection "user": user.id, "description": None, "created_at": round(time.time()), + "documents": 0, } - self.qdrant.upsert(collection_name=self.METADATA_COLLECTION_ID, points=[PointStruct(id=collection_id, payload=dict(metadata), vector={})]) + super().upsert(collection_name=self.METADATA_COLLECTION_ID, points=[PointStruct(id=collection_id, payload=dict(metadata), vector={})]) # create collection - self.qdrant.create_collection( + super().create_collection( collection_name=collection_id, vectors_config=VectorParams(size=self.models[collection_model].vector_size, distance=Distance.COSINE) ) @@ -241,11 +250,11 @@ def delete_collection(self, collection_id: str, user: User) -> None: """ collection = self.get_collections(collection_ids=[collection_id], user=user)[0] - if user.role == USER_ROLE and collection.type == PUBLIC_COLLECTION_TYPE: + if user.role != ROLE_LEVEL_2 and collection.type == PUBLIC_COLLECTION_TYPE: raise WrongCollectionTypeException() - self.qdrant.delete_collection(collection_name=collection.id) - self.qdrant.delete(collection_name=self.METADATA_COLLECTION_ID, points_selector=PointIdsList(points=[collection.id])) + super().delete_collection(collection_name=collection.id) + super().delete(collection_name=self.METADATA_COLLECTION_ID, points_selector=PointIdsList(points=[collection.id])) def get_chunks(self, collection_id: str, document_id: str, user: User, limit: Optional[int] = 10, offset: Optional[int] = None) -> List[Chunk]: """ @@ -264,7 +273,7 @@ def get_chunks(self, collection_id: str, document_id: str, user: User, limit: Op collection = self.get_collections(collection_ids=[collection_id], user=user)[0] filter = Filter(must=[FieldCondition(key="metadata.document_id", match=MatchAny(any=[document_id]))]) - data = self.qdrant.scroll(collection_name=collection.id, scroll_filter=filter, limit=limit, offset=offset)[0] + data = super().scroll(collection_name=collection.id, scroll_filter=filter, limit=limit, offset=offset)[0] chunks = [Chunk(id=chunk.id, content=chunk.payload["content"], metadata=ChunkMetadata(**chunk.payload["metadata"])) for chunk in data] return chunks @@ -285,13 +294,17 @@ def get_documents(self, collection_id: str, user: User, limit: Optional[int] = 1 collection = self.get_collections(collection_ids=[collection_id], user=user)[0] filter = Filter(must=[FieldCondition(key="collection_id", match=MatchAny(any=[collection_id]))]) - data = self.qdrant.scroll(collection_name=self.DOCUMENT_COLLECTION_ID, scroll_filter=filter, limit=limit, offset=offset)[0] + data = super().scroll(collection_name=self.DOCUMENT_COLLECTION_ID, scroll_filter=filter, limit=limit, offset=offset)[0] documents = list() for document in data: - chunks_count = self.qdrant.count( - collection_name=collection.id, - count_filter=Filter(must=[FieldCondition(key="metadata.document_id", match=MatchAny(any=[document.id]))]), - ).count + chunks_count = ( + super() + .count( + collection_name=collection.id, + count_filter=Filter(must=[FieldCondition(key="metadata.document_id", match=MatchAny(any=[document.id]))]), + ) + .count + ) documents.append(Document(id=document.id, name=document.payload["name"], created_at=document.payload["created_at"], chunks=chunks_count)) return documents @@ -307,12 +320,12 @@ def delete_document(self, collection_id: str, document_id: str, user: User): """ collection = self.get_collections(collection_ids=[collection_id], user=user)[0] - if user.role == USER_ROLE and collection.type == PUBLIC_COLLECTION_TYPE: + if user.role != ROLE_LEVEL_2 and collection.type == PUBLIC_COLLECTION_TYPE: raise WrongCollectionTypeException() # delete chunks filter = Filter(must=[FieldCondition(key="metadata.document_id", match=MatchAny(any=[document_id]))]) - self.qdrant.delete(collection_name=collection.id, points_selector=FilterSelector(filter=filter)) + super().delete(collection_name=collection.id, points_selector=FilterSelector(filter=filter)) # delete document - self.qdrant.delete(collection_name=self.DOCUMENT_COLLECTION_ID, points_selector=PointIdsList(points=[document_id])) + super().delete(collection_name=self.DOCUMENT_COLLECTION_ID, points_selector=PointIdsList(points=[document_id])) diff --git a/app/main.py b/app/main.py index bd66142..900dd3a 100644 --- a/app/main.py +++ b/app/main.py @@ -1,6 +1,9 @@ from fastapi import FastAPI, Response, Security -from app.endpoints import chat, chunks, collections, completions, embeddings, files, models, search, documents + +from slowapi.middleware import SlowAPIASGIMiddleware + +from app.endpoints import chat, chunks, collections, completions, documents, embeddings, files, models, search from app.helpers import ContentSizeLimitMiddleware from app.schemas.security import User from app.utils.config import APP_CONTACT_EMAIL, APP_CONTACT_URL, APP_DESCRIPTION, APP_VERSION @@ -14,10 +17,17 @@ contact={"url": APP_CONTACT_URL, "email": APP_CONTACT_EMAIL}, licence_info={"name": "MIT License", "identifier": "MIT"}, lifespan=lifespan, + docs_url="/swagger", + redoc_url="/documentation", ) +# Middlewares +app.add_middleware(ContentSizeLimitMiddleware) +app.add_middleware(SlowAPIASGIMiddleware) + -@app.get("/health") +# Monitoring +@app.get("/health", tags=["Monitoring"]) def health(user: User = Security(check_api_key)): """ Health check. @@ -26,8 +36,6 @@ def health(user: User = Security(check_api_key)): return Response(status_code=200) -app.add_middleware(ContentSizeLimitMiddleware) - # Core app.include_router(models.router, tags=["Core"], prefix="/v1") app.include_router(chat.router, tags=["Core"], prefix="/v1") diff --git a/app/schemas/security.py b/app/schemas/security.py index 1679786..cbf3068 100644 --- a/app/schemas/security.py +++ b/app/schemas/security.py @@ -3,4 +3,4 @@ class User(BaseModel): id: str - role: str + role: int diff --git a/app/tests/assets/pdf_too_large.pdf b/app/tests/assets/pdf_too_large.pdf index fddec4c..623a009 100644 Binary files a/app/tests/assets/pdf_too_large.pdf and b/app/tests/assets/pdf_too_large.pdf differ diff --git a/app/tests/conftest.py b/app/tests/conftest.py index e6251e4..7d1fb06 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -1,4 +1,5 @@ import logging +import time import pytest import requests @@ -67,3 +68,10 @@ def cleanup_collections(args, session_user, session_admin): for collection_id in collection_ids: session_admin.delete(f"{args["base_url"]}/collections/{collection_id}") + + +@pytest.fixture(scope="module", autouse=True) +def sleep_between_tests(): + # Sleep between tests to avoid rate limit errors + yield + time.sleep(30) diff --git a/app/tests/test_chat.py b/app/tests/test_chat.py index 69c6728..4436116 100644 --- a/app/tests/test_chat.py +++ b/app/tests/test_chat.py @@ -86,7 +86,7 @@ def test_chat_completions_max_tokens_too_large(self, args, session_user, setup): "messages": [{"role": "user", "content": prompt}], "stream": True, "n": 1, - "max_tokens": MAX_CONTEXT_LENGTH + 10, + "max_tokens": MAX_CONTEXT_LENGTH + 100, } response = session_user.post(f"{args['base_url']}/chat/completions", json=params) assert response.status_code == 422, f"error: retrieve chat completions ({response.status_code})" @@ -94,7 +94,7 @@ def test_chat_completions_max_tokens_too_large(self, args, session_user, setup): def test_chat_completions_context_too_large(self, args, session_user, setup): MODEL_ID, MAX_CONTEXT_LENGTH = setup - prompt = "test" * (MAX_CONTEXT_LENGTH + 10) + prompt = "test" * (MAX_CONTEXT_LENGTH + 100) params = { "model": MODEL_ID, "messages": [{"role": "user", "content": prompt}], diff --git a/app/tests/test_models.py b/app/tests/test_models.py index d8bb2c8..ff275e6 100644 --- a/app/tests/test_models.py +++ b/app/tests/test_models.py @@ -1,32 +1,47 @@ +import time + import pytest from app.schemas.models import Model, Models +from app.utils.config import DEFAULT_RATE_LIMIT -@pytest.mark.usefixtures("args", "session_user") +@pytest.mark.usefixtures("args", "session_user", "session_admin") class TestModels: - def test_get_models_response_status_code(self, args, session_user): + def test_get_models_response_status_code(self, args, session_admin): """Test the GET /models response status code.""" - response = session_user.get(f"{args['base_url']}/models") + response = session_admin.get(f"{args["base_url"]}/models") assert response.status_code == 200, f"error: retrieve models ({response.status_code})" models = Models(data=[Model(**model) for model in response.json()["data"]]) assert isinstance(models, Models) assert all(isinstance(model, Model) for model in models.data) - def test_get_model_retrieve_model(self, args, session_user): - """Test the GET /models/{model_id} response status code.""" - response = session_user.get(f"{args['base_url']}/models") - assert response.status_code == 200, f"error: retrieve models ({response.status_code})" - - model = response.json()["data"][0]["id"] - response = session_user.get(f"{args['base_url']}/models/{model}") + model = models.data[0].id + response = session_admin.get(f"{args["base_url"]}/models/{model}") assert response.status_code == 200, f"error: retrieve model ({response.status_code})" model = Model(**response.json()) assert isinstance(model, Model) - def test_get_models_non_existing_model(self, args, session_user): + def test_get_models_non_existing_model(self, args, session_admin): """Test the GET /models response status code for a non-existing model.""" - response = session_user.get(f"{args['base_url']}/models/non-existing-model") + response = session_admin.get(f"{args["base_url"]}/models/non-existing-model") assert response.status_code == 404, f"error: retrieve non-existing model ({response.status_code})" + + def test_get_models_rate_limit(self, args, session_user): + """Test the GET /models rate limiting.""" + start = time.time() + limit = int(DEFAULT_RATE_LIMIT.replace("/minute", "")) + i = 0 + while time.time() - start < 60: + i += 1 + response = session_user.get(f"{args["base_url"]}/models") + if i == limit: + assert response.status_code == 429 + break + else: + assert response.status_code == 200 + + # sanity check to make sure the rate limiting is tested + assert i == limit diff --git a/app/utils/config.py b/app/utils/config.py index 928e32d..1fcc212 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -23,3 +23,7 @@ "APP_DESCRIPTION", "[See documentation](https://github.com/etalab-ia/albert-api/blob/main/README.md)", ) + +# Rate limit +GLOBAL_RATE_LIMIT = os.getenv("GLOBAL_RATE_LIMIT", "100/minute") +DEFAULT_RATE_LIMIT = os.getenv("DEFAULT_RATE_LIMIT", "10/minute") diff --git a/app/utils/lifespan.py b/app/utils/lifespan.py index 674a65a..ad6a0fb 100644 --- a/app/utils/lifespan.py +++ b/app/utils/lifespan.py @@ -1,18 +1,27 @@ from contextlib import asynccontextmanager from fastapi import FastAPI +from slowapi import Limiter +from slowapi.util import get_ipaddr from app.helpers import ClientsManager -from app.utils.config import CONFIG +from app.utils.config import CONFIG, GLOBAL_RATE_LIMIT clients = ClientsManager(config=CONFIG) +limiter = Limiter( + key_func=get_ipaddr, + storage_uri=f"redis://{CONFIG.databases.cache.args.get("username", "")}:{CONFIG.databases.cache.args.get("password", "")}@{CONFIG.databases.cache.args["host"]}:{CONFIG.databases.cache.args["port"]}", + default_limits=[GLOBAL_RATE_LIMIT], +) -# @TODO: test to move into main.py @asynccontextmanager async def lifespan(app: FastAPI): """Lifespan event to initialize clients (models API and databases).""" + app.state.limiter = limiter clients.set() yield + + clients.clear() diff --git a/app/utils/security.py b/app/utils/security.py index 80ad8b1..5d4480a 100644 --- a/app/utils/security.py +++ b/app/utils/security.py @@ -1,13 +1,15 @@ import base64 import hashlib -from typing import Annotated +from typing import Annotated, Optional -from fastapi import Depends +from fastapi import Depends, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from app.utils.lifespan import clients from app.schemas.security import User -from app.utils.exceptions import InvalidAuthenticationSchemeException, InvalidAPIKeyException +from app.utils.config import CONFIG +from app.utils.exceptions import InvalidAPIKeyException, InvalidAuthenticationSchemeException +from app.utils.lifespan import clients +from app.utils.variables import ROLE_LEVEL_0, ROLE_LEVEL_2 def encode_string(input: str) -> str: @@ -28,30 +30,53 @@ def encode_string(input: str) -> str: return hash -def check_api_key( - api_key: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key"))], -) -> str: - """ - Check if the API key is valid. +if CONFIG.auth: - Args: - api_key (Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key")]): The API key to check. + def check_api_key(api_key: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key"))]) -> str: + """ + Check if the API key is valid. + + Args: + api_key (Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key")]): The API key to check. + + Returns: + str: User ID, corresponding to the encoded API key or "no-auth" if no authentication is set in the configuration file. + """ - Returns: - str: User ID, corresponding to the encoded API key or "no-auth" if no authentication is set in the configuration file. - """ - if clients.auth: if api_key.scheme != "Bearer": raise InvalidAuthenticationSchemeException() role = clients.auth.check_api_key(api_key.credentials) - if role is None: raise InvalidAPIKeyException() user_id = encode_string(input=api_key.credentials) - else: - user_id = "no-auth" + return User(id=user_id, role=role) + +else: - return User(id=user_id, role=role) + def check_api_key(api_key: Optional[str] = None) -> str: + return User(id="no-auth", role=ROLE_LEVEL_2) + + +def check_rate_limit(request: Request) -> Optional[str]: + """ + Check the rate limit for the user. + + Args: + request (Request): The request object. + + Returns: + Optional[str]: user_id if the access level is 0, None otherwise (no rate limit applied). + """ + + authorization = request.headers.get("Authorization") + scheme, credentials = authorization.split(" ") if authorization else ("", "") + api_key = HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) + user = check_api_key(api_key=api_key) + + if user.role > ROLE_LEVEL_0: + return None + else: + return user.id diff --git a/app/utils/variables.py b/app/utils/variables.py index d269e65..c76d05e 100644 --- a/app/utils/variables.py +++ b/app/utils/variables.py @@ -9,5 +9,7 @@ JSON_TYPE = "application/json" TXT_TYPE = "text/plain" HTML_TYPE = "text/html" -USER_ROLE = "user" +ROLE_LEVEL_0 = 0 +ROLE_LEVEL_1 = 1 +ROLE_LEVEL_2 = 2 # @TODO : add DOCX_TYPE (application/vnd.openxmlformats-officedocument.wordprocessingml.document) diff --git a/docs/deploiement.md b/docs/deploiement.md index f43eac1..469d78c 100644 --- a/docs/deploiement.md +++ b/docs/deploiement.md @@ -1,31 +1,45 @@ -## Déployer l'API Albert +# Déployer l'API Albert -### Quickstart +## Quickstart -1. Installez [libmagic](https://man7.org/linux/man-pages/man3/libmagic.3.html) +1. Créez un fichier *config.yml* à la racine du dépot sur la base du fichier d'exemple *[config.example.yml](./config.example.yml)* Voir la section [Configuration](#configuration) pour plus d'informations. -2. Installez les packages Python dans un environnement virtuel dédié +2. Déployez l'API avec Docker à l'aide du fichier [compose.yml](../compose.yml) à la racine du dépot. - ```bash - pip install ".[app]" + ```bash + docker compose up -d ``` -3. Créez un fichier *config.yml* à la racine du repository sur la base du fichier d'exemple *[config.example.yml](./config.example.yml)* +## Configuration - Si vous souhaitez configurer les accès aux modèles et aux bases de données, consultez la [Configuration](#configuration). +### Variables d'environnements - Pour lancer l'API : - ```bash - uvicorn app.main:app --reload --port 8080 --log-level debug - ``` +| Variable | Description | +| --- | --- | +| APP_CONTACT_URL | URL for app contact information (default: None) | +| APP_CONTACT_EMAIL | Email for app contact (default: None) | +| APP_VERSION | Version of the application (default: "0.0.0") | +| APP_DESCRIPTION | Description of the application (default: None) | +| GLOBAL_RATE_LIMIT | Global rate limit for API requests per IP address (default: "100/minute") | +| DEFAULT_RATE_LIMIT | Default rate limit for API requests per user (default: "10/minute") | +| CONFIG_FILE | Path to the configuration file (default: "config.yml") | +| LOG_LEVEL | Logging level (default: DEBUG) | + +### Clients tiers -### Configuration +Pour fonctionner, l'API Albert nécessite des clients tiers : -Toute la configuration de l'API Albert se fait dans fichier de configuration qui doit respecter les spécifications suivantes (voir *[config.example.yml](./config.example.yml)* pour un exemple) : +* [Optionnel] Auth : [Grist](https://www.getgrist.com/)* +* Cache : [Redis](https://redis.io/) +* Vectors : [Qdrant](https://qdrant.tech/) + +\* *Pour plus d'information sur l'authentification Grist, voir la [documentation](./security.md).* + +Ces clients sont déclarés dans un fichier de configuration qui doit respecter les spécifications suivantes (voir *[config.example.yml](./config.example.yml)* pour un exemple) : ```yaml auth: [optional] - type: [optional] + type: grist args: [optional] [arg_name]: [value] ... @@ -34,45 +48,28 @@ models: - url: [required] key: [optional] search_internet: [optional] + type: [required] # at least one of embedding model (text-embeddings-inference) + + - url: [required] + key: [optional] + search_internet: [optional] + type: [required] # at least one of language model (text-generation) ... databases: cache: [required] - type: [required] # see following Database section for the list of supported db type + type: redis args: [required] [arg_name]: [value] ... vectors: [required] - type: [required] # see following Database section for the list of supported db type + type: qdrant args: [required] [arg_name]: [value] ... ``` -**Par défaut, l'API va chercher un fichier nommé *config.yml* la racine du dépot.** Néanmoins, vous pouvez spécifier un autre fichier de config comme ceci : - -```bash -CONFIG_FILE= uvicorn main:app --reload --port 8080 --log-level debug -``` - -La configuration permet de spéficier le token d'accès à l'API, les API de modèles auquel à accès l'API d'Albert ainsi que les bases de données nécessaires à sont fonctionnement. - -#### Auth - -Les IAM supportés, de nouveaux seront disponibles prochainement : - -* [Grist](https://www.getgrist.com/) - -#### Databases - -Voici les types de base de données supportées, à configurer dans le fichier de configuration (*[config.example.yml](./config.example.yml)*) : - -| Database | Type | -| --- | --- | -| vectors | [qdrant](https://qdrant.tech/) | -| cache | [redis](https://redis.io/) | - ## Déploiement de l'interface Streamlit 1. Installez les packages Python dans un environnement virtuel diff --git a/docs/security.md b/docs/security.md new file mode 100644 index 0000000..882d209 --- /dev/null +++ b/docs/security.md @@ -0,0 +1,23 @@ +# Sécurité + +## Authentification + +L'authentification est réalisée par le biais d'un client [Grist](https://www.getgrist.com/). Vous devez créer une table dans Grist pou stocker les clefs d'API keys que vous créez. Cette table doit avoir la structure suivante + +| KEY | ROLE | EXPIRATION | +| --- | --- | --- | +| my_key | admin \| client \| user | 2099-01-01 | + +Si vous souhaitez déployer l'API sans authentication par Grist, ne déclarez pas de section *auth* dans le fichier de configuration. L'authentification sera alors désactivée et l'utilisateur est le rôle de niveau 2 (admin). + +## Droits d'accès + +L'API implémente un système de rôle à 3 niveaux : + +| Niveau du rôle | Description | +| --- | --- | +| 0 (user) | Accès limité à l'API (rate limiting) et aucun droits d'édition sur les collections publiques | +| 1 (client) | Accès illimité à l'API et aucun droits d'édition sur les collections publiques | +| 2 (admin) | Accès illimité à l'API et droits d'édition sur toutes les collections | + +Par défaut, le rate limiting est de 100 requêtes par minute pour tous les niveaux. Il est de 10 requêtes par minute pour le niveau 0 (user) pour les endpoints tagués *Core*. diff --git a/docs/tutorials/import_knowledge_database.ipynb b/docs/tutorials/import_knowledge_database.ipynb index 0934b2e..35ec336 100644 --- a/docs/tutorials/import_knowledge_database.ipynb +++ b/docs/tutorials/import_knowledge_database.ipynb @@ -86,7 +86,7 @@ "source": [ "Le format du fichier doit être JSON, ici nous n'avons pas besoin de le convertir. En revanche, il est possible que le fichier ne respecte pas la structure définit dans la documentation de l'API.\n", "\n", - "De plus le fichier ne doit pas dépasser 10MB, il est donc nécessaire de le découper en plusieurs fichiers." + "De plus le fichier ne doit pas dépasser 20MB, il est donc nécessaire de le découper en plusieurs fichiers." ] }, { @@ -252,7 +252,7 @@ "id": "092c9bb7", "metadata": {}, "source": [ - "La taille du fichier dépasse 10MB, il est donc nécessaire de le découper en plusieurs fichiers." + "La taille du fichier dépasse 20MB, il est donc nécessaire de le découper en plusieurs fichiers." ] }, { @@ -268,7 +268,7 @@ "\n", " batch_file_path = f\"tmp_{i}.json\"\n", " json.dump(batch_file, open(batch_file_path, \"w\"))\n", - " assert os.path.getsize(batch_file_path) < 10 * 1024 * 1024, \"Le fichier ne doit pas dépasser 5MB\"\n", + " assert os.path.getsize(batch_file_path) < 20 * 1024 * 1024\n", "\n", " files = {\"file\": (os.path.basename(batch_file_path), open(batch_file_path, \"rb\"), \"application/json\")}\n", " data = {\"request\": '{\"collection\": \"%s\"}' % collection_id}\n", @@ -352,7 +352,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 99fafe9..fb555a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ app = [ "beautifulsoup4==4.12.3", "duckduckgo-search==6.2.13", "numpy==1.26.4", + "slowapi==0.1.9", ] dev = [ "ruff==0.6.5", diff --git a/ui/chat.py b/ui/chat.py index 10adf05..a0e2d58 100644 --- a/ui/chat.py +++ b/ui/chat.py @@ -35,11 +35,11 @@ st.title("RAG parameters") params["rag"]["embeddings_model"] = st.selectbox("Embeddings model", embeddings_models) model_collections = [ - f"{collection["name"]} - {collection["id"]}" for collection in collections if collection["model"] == params["rag"]["embeddings_model"] + f"{collection["id"]} - {collection["name"]}" for collection in collections if collection["model"] == params["rag"]["embeddings_model"] ] + ["internet"] if model_collections: selected_collections = st.multiselect(label="Collections", options=model_collections, default=[model_collections[0]]) - params["rag"]["collections"] = [collection.split(" - ")[-1] for collection in selected_collections] + params["rag"]["collections"] = [collection.split(" - ")[0] for collection in selected_collections] params["rag"]["k"] = st.number_input("Top K", value=3) # Main diff --git a/ui/pages/documents.py b/ui/pages/documents.py index 2450120..bd4fde3 100644 --- a/ui/pages/documents.py +++ b/ui/pages/documents.py @@ -57,10 +57,10 @@ with st.expander("Delete a collection", icon="📦"): collection = st.selectbox( "Select collection to delete", - [f"{collection["name"]} - {collection["id"]}" for collection in collections], + [f"{collection["id"]} - {collection["name"]}" for collection in collections if collection["type"] == PRIVATE_COLLECTION_TYPE], key="delete_collection_selectbox", ) - collection_id = collection.split(" - ")[-1] if collection else None + collection_id = collection.split(" - ")[0] if collection else None submit_delete = st.button("Delete", disabled=not collection_id, key="delete_collection_button") if submit_delete: delete_collection(api_key=API_KEY, collection_id=collection_id) @@ -93,10 +93,10 @@ with st.expander("Upload a file", icon="📑"): collection = st.selectbox( "Select a collection", - [f"{collection["name"]} - {collection["id"]}" for collection in collections if collection["type"] == PRIVATE_COLLECTION_TYPE], + [f"{collection["id"]} - {collection["name"]}" for collection in collections if collection["type"] == PRIVATE_COLLECTION_TYPE], key="upload_file_selectbox", ) - collection_id = collection.split(" - ")[-1] + collection_id = collection.split(" - ")[0] file_to_upload = st.file_uploader("File", type=["pdf", "html", "json"]) submit_upload = st.button("Upload", disabled=not collection_id or not file_to_upload) if file_to_upload and submit_upload and collection_id: @@ -108,8 +108,8 @@ ## Delete files with col2: with st.expander("Delete a document", icon="🗑️"): - document = st.selectbox("Select document to delete", [f"{document["name"]} - {document["id"]}" for document in documents]) - document_id = document.split(" - ")[-1] if document else None + document = st.selectbox("Select document to delete", [f"{document["id"]} - {document["name"]}" for document in documents]) + document_id = document.split(" - ")[0] if document else None submit_delete = st.button("Delete", disabled=not document_id, key="delete_document_button") if submit_delete: document_collection = [document["collection_id"] for document in documents if document["id"] == document_id][0]