Skip to content

Commit

Permalink
feat: add rate limiting (#40)
Browse files Browse the repository at this point in the history
* feat: add rate limiting

* feat: add redoc

* feat: extends rate limiting

---------

Co-authored-by: leoguillaume <[email protected]>
  • Loading branch information
leoguillaume and leoguillaumegouv authored Oct 21, 2024
1 parent 420ca6a commit 4d1211e
Show file tree
Hide file tree
Showing 32 changed files with 365 additions and 223 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ Tous les changements notables de l'application sont documentés dans ce fichier.
- 🔄 Refactoring
- ❌ Deprecated

## [Alpha] - 2024-10-21

- 🎉 Ajout de la limitation de débit (*rate limiting*) lorsque l'authentification est activée.
- 📚 Ajout d'une documentation (./docs/security.md) sur l'authentification et la limitation de débit.
- 🧪 Ajout de tests pour la limitation de débit.
- 📚 Amélioration de la documentation [README.md](./README.md).
- 📚 La documentation est maintenant accessible à l'URL `/documentation` et le swagger à l'URL `/swagger`.
- 🔄 Optimisation du comptage des documents dans Qdrant.

## [Alpha] - 2024-10-09

- 🎉 Ajout d'un status du modèle dans le retour du endpoint GET `/v1/models`. Ce status permet de vérifier si le modèle est disponible ou non.
Expand Down
22 changes: 14 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
# Albert API
![](https://img.shields.io/badge/python-3.12-green) ![](https://img.shields.io/badge/vLLM-v0.5.5-blue) ![](https://img.shields.io/badge/HuggingFace%20Text%20Embeddings%20Inference-1.5-red)
<div id="toc"><ul align="center" style="list-style: none">
<summary><h1>Albert API</h1></summary>

Albert API est une API open source d'IA générative développée par Etalab. Elle permet d'être un proxy entre des modèles de langage et vos données. Elle agrège les services suivants :
![](https://img.shields.io/badge/version-alpha-yellow) ![](https://img.shields.io/badge/Python-3.12-green) ![](https://img.shields.io/badge/vLLM-v0.5.5-blue) ![](https://img.shields.io/badge/HuggingFace%20Text%20Embeddings%20Inference-1.5-red)<br>
<a href="https://albert.api.etalab.gouv.fr/documentation"><b>Documentation</b></a> | <a href="https://github.com/etalab-ia/albert-api/blob/main/CHANGELOG.md"><b>Changelog</b></a> | <a href="https://huggingface.co/AgentPublic"><b>HuggingFace</b></a>
<br><br>
</ul></div>

Albert API est une initiative d'[Etalab](https://www.etalab.gouv.fr/). Il s'agit d'une API open source d'IA générative développée par Etalab. Elle permet d'être un proxy entre des modèles de langage et vos données. Elle agrège les services suivants :
- servir des modèles de langage avec [vLLM](https://github.com/vllm-project/vllm)
- servir des modèles d'embeddings avec [HuggingFace Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)
- accès un *vector store* avec [Qdrant](https://qdrant.tech/) pour la recherche de similarité

En se basant sur les conventions définies par OpenAI, l'API Albert expose des endpoints qui peuvent être appelés avec le [client officiel python d'OpenAI](https://github.com/openai/openai-python/tree/main). Ce formalisme permet d'intégrer facilement l'API Albert avec des bibliothèques tierces comme [Langchain](https://www.langchain.com/) ou [LlamaIndex](https://www.llamaindex.ai/).

## 🚀 Nouveautés
## ⚙️ Fonctionnalités

Vous trouverez les changelogs des différentes versions d'Albert API dans le fichier [CHANGELOG.md](./CHANGELOG.md).
### Interface utilisateur (playground)

## ⚙️ Fonctionnalités
L'API Albert expose une interface utilisateur permettant de tester les différentes fonctionnalités, consultable ici [ici](https://albert.api.etalab.gouv.fr).

### Converser avec un modèle de langage (chat memory)

Expand Down Expand Up @@ -50,6 +55,7 @@ L'API Albert permet d'importer sa base de connaissances dans une base vectoriell

Albert API est un projet open source, vous pouvez contribuer au projet en lisant notre [guide de contribution](./CONTRIBUTING.md).

## Installation
## 🚀 Installation

Pour déployer l'API Albert sur votre propre infrastructure, suivez la [documentation](./docs/deployment.md).

Pour déployer l'API Albert sur votre propre infrastructure, suivez la [documentation](./docs/deployment.md).
23 changes: 12 additions & 11 deletions app/endpoints/chat.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,34 @@
from typing import Union

from fastapi import APIRouter, Security
from fastapi import APIRouter, Request, Security
from fastapi.responses import StreamingResponse
import httpx

from app.schemas.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionRequest
from app.schemas.security import User
from app.utils.lifespan import clients
from app.utils.security import check_api_key

from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit

router = APIRouter()


@router.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest, user: User = Security(check_api_key)) -> Union[ChatCompletion, ChatCompletionChunk]:
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
async def chat_completions(
request: Request, body: ChatCompletionRequest, user: User = Security(check_api_key)
) -> Union[ChatCompletion, ChatCompletionChunk]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create for the API specification.
"""

request = dict(request)
client = clients.models[request["model"]]
client = clients.models[body.model]
url = f"{client.base_url}chat/completions"
headers = {"Authorization": f"Bearer {client.api_key}"}

# non stream case
if not request["stream"]:
if not body.stream:
async with httpx.AsyncClient(timeout=20) as async_client:
response = await async_client.request(method="POST", url=url, headers=headers, json=request)
response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump())
response.raise_for_status()

data = response.json()
Expand All @@ -40,4 +41,4 @@ async def forward_stream(url: str, headers: dict, request: dict):
async for chunk in response.aiter_raw():
yield chunk

return StreamingResponse(forward_stream(url, headers, request), media_type="text/event-stream")
return StreamingResponse(forward_stream(url, headers, body.model_dump()), media_type="text/event-stream")
8 changes: 6 additions & 2 deletions app/endpoints/chunks.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from typing import Optional
from uuid import UUID

from fastapi import APIRouter, Security, Query
from fastapi import APIRouter, Request, Security, Query

from app.schemas.chunks import Chunks
from app.schemas.security import User
from app.utils.lifespan import clients
from app.utils.security import check_api_key
from app.utils.security import check_api_key, check_rate_limit
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.lifespan import limiter

router = APIRouter()


@router.get("/chunks/{collection}/{document}")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
async def get_chunks(
request: Request,
collection: UUID,
document: UUID,
limit: Optional[int] = Query(default=10, ge=1, le=100),
Expand Down
19 changes: 11 additions & 8 deletions app/endpoints/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,36 @@
import uuid
from uuid import UUID

from fastapi import APIRouter, Response, Security
from fastapi import APIRouter, Request, Response, Security
from fastapi.responses import JSONResponse


from app.schemas.collections import Collection, CollectionRequest, Collections
from app.schemas.security import User
from app.utils.lifespan import clients
from app.utils.security import check_api_key
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.variables import INTERNET_COLLECTION_ID, PUBLIC_COLLECTION_TYPE

router = APIRouter()


@router.post("/collections")
async def create_collection(request: CollectionRequest, user: User = Security(check_api_key)) -> Response:
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
async def create_collection(request: Request, body: CollectionRequest, user: User = Security(check_api_key)) -> Response:
"""
Create a new collection.
"""
collection_id = str(uuid.uuid4())
clients.vectors.create_collection(
collection_id=collection_id, collection_name=request.name, collection_model=request.model, collection_type=request.type, user=user
collection_id=collection_id, collection_name=body.name, collection_model=body.model, collection_type=body.type, user=user
)

return JSONResponse(status_code=201, content={"id": collection_id})


@router.get("/collections")
async def get_collections(user: User = Security(check_api_key)) -> Union[Collection, Collections]:
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
async def get_collections(request: Request, user: User = Security(check_api_key)) -> Union[Collection, Collections]:
"""
Get list of collections.
"""
Expand All @@ -47,7 +49,8 @@ async def get_collections(user: User = Security(check_api_key)) -> Union[Collect


@router.delete("/collections/{collection}")
async def delete_collections(collection: UUID, user: User = Security(check_api_key)) -> Response:
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
async def delete_collections(request: Request, collection: UUID, user: User = Security(check_api_key)) -> Response:
"""
Delete a collection.
"""
Expand Down
16 changes: 8 additions & 8 deletions app/endpoints/completions.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
from fastapi import APIRouter, Security
from fastapi import APIRouter, Request, Security
import httpx

from app.schemas.completions import CompletionRequest, Completions
from app.schemas.security import User
from app.utils.lifespan import clients
from app.utils.security import check_api_key
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit

router = APIRouter()


@router.post("/completions")
async def completions(request: CompletionRequest, user: User = Security(check_api_key)) -> Completions:
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
async def completions(request: Request, body: CompletionRequest, user: User = Security(check_api_key)) -> Completions:
"""
Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create for the API specification.
"""

request = dict(request)
client = clients.models[request["model"]]
client = clients.models[body.model]
url = f"{client.base_url}completions"
headers = {"Authorization": f"Bearer {client.api_key}"}

async with httpx.AsyncClient(timeout=20) as async_client:
response = await async_client.request(method="POST", url=url, headers=headers, json=request)
response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump())
response.raise_for_status()

data = response.json()
Expand Down
22 changes: 12 additions & 10 deletions app/endpoints/documents.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
from typing import Optional
from uuid import UUID

from fastapi import APIRouter, Response, Security, Query

from fastapi import APIRouter, Query, Request, Response, Security

from app.schemas.documents import Documents
from app.schemas.security import User
from app.utils.lifespan import clients
from app.utils.security import check_api_key
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit

router = APIRouter()


@router.get("/documents/{collection}")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
async def get_documents(
collection: UUID, limit: Optional[int] = Query(default=10, ge=1, le=100), offset: Optional[UUID] = None, user: User = Security(check_api_key)
request: Request,
collection: UUID,
limit: Optional[int] = Query(default=10, ge=1, le=100),
offset: Optional[UUID] = None,
user: User = Security(check_api_key),
) -> Documents:
"""
Get all documents ID from a collection.
Expand All @@ -27,11 +32,8 @@ async def get_documents(


@router.delete("/documents/{collection}/{document}")
async def delete_document(
collection: UUID,
document: UUID,
user: User = Security(check_api_key),
) -> Response:
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
async def delete_document(request: Request, collection: UUID, document: UUID, user: User = Security(check_api_key)) -> Response:
"""
Delete a document and relative collections.
"""
Expand Down
17 changes: 9 additions & 8 deletions app/endpoints/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,34 @@
from fastapi import APIRouter, Security
from fastapi import APIRouter, Request, Security
import httpx

from app.schemas.embeddings import Embeddings, EmbeddingsRequest
from app.schemas.security import User
from app.utils.lifespan import clients
from app.utils.security import check_api_key
from app.utils.variables import EMBEDDINGS_MODEL_TYPE
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.exceptions import ContextLengthExceededException, WrongModelTypeException
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.variables import EMBEDDINGS_MODEL_TYPE

router = APIRouter()


@router.post("/embeddings")
async def embeddings(request: EmbeddingsRequest, user: User = Security(check_api_key)) -> Embeddings:
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Security(check_api_key)) -> Embeddings:
"""
Embedding API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create for the API specification.
"""

request = dict(request)
client = clients.models[request["model"]]
client = clients.models[body.model]
if client.type != EMBEDDINGS_MODEL_TYPE:
raise WrongModelTypeException()

url = f"{client.base_url}embeddings"
headers = {"Authorization": f"Bearer {client.api_key}"}

async with httpx.AsyncClient(timeout=20) as async_client:
response = await async_client.request(method="POST", url=url, headers=headers, json=request)
response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump())
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
Expand Down
2 changes: 1 addition & 1 deletion app/endpoints/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def upload_file(file: UploadFile = File(...), request: FilesRequest = Body
chunker_args = ChunkerArgs().model_dump()
chunker_name = None

chunker_args["length_function"] = len if chunker_args["length_function"] == "len" else None
chunker_args["length_function"] = len if chunker_args["length_function"] == "len" else chunker_args["length_function"]

uploader = FileUploader(vectors=clients.vectors, user=user, collection_id=request.collection)
output = uploader.parse(file=file)
Expand Down
10 changes: 6 additions & 4 deletions app/endpoints/models.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from typing import Optional, Union

from fastapi import APIRouter, Security
from fastapi import APIRouter, Request, Security

from app.utils.config import DEFAULT_RATE_LIMIT
from app.schemas.models import Model, Models
from app.schemas.security import User
from app.utils.lifespan import clients
from app.utils.security import check_api_key
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit

router = APIRouter()


@router.get("/models/{model:path}")
@router.get("/models")
async def models(model: Optional[str] = None, user: User = Security(check_api_key)) -> Union[Models, Model]:
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
async def models(request: Request, model: Optional[str] = None, user: User = Security(check_api_key)) -> Union[Models, Model]:
"""
Model API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/models/list for the API specification.
Expand Down
28 changes: 14 additions & 14 deletions app/endpoints/search.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
from fastapi import APIRouter, Security
from fastapi import APIRouter, Request, Security

from app.helpers import SearchOnInternet
from app.schemas.search import Searches, SearchRequest
from app.schemas.security import User
from app.utils.lifespan import clients
from app.utils.security import check_api_key
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.variables import INTERNET_COLLECTION_ID

router = APIRouter()


@router.post("/search")
async def search(request: SearchRequest, user: User = Security(check_api_key)) -> Searches:
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
async def search(request: Request, body: SearchRequest, user: User = Security(check_api_key)) -> Searches:
"""
Similarity search for chunks in the vector store or on the internet.
"""

data = []
if INTERNET_COLLECTION_ID in request.collections:
request.collections.remove(INTERNET_COLLECTION_ID)
if INTERNET_COLLECTION_ID in body.collections:
body.collections.remove(INTERNET_COLLECTION_ID)
internet = SearchOnInternet(models=clients.models)
if len(request.collections) > 0:
collection_model = clients.vectors.get_collections(collection_ids=request.collections, user=user)[0].model
if len(body.collections) > 0:
collection_model = clients.vectors.get_collections(collection_ids=body.collections, user=user)[0].model
else:
collection_model = None
data.extend(internet.search(prompt=request.prompt, n=4, model_id=collection_model, score_threshold=request.score_threshold))
data.extend(internet.search(prompt=body.prompt, n=4, model_id=collection_model, score_threshold=body.score_threshold))

if len(request.collections) > 0:
if len(body.collections) > 0:
data.extend(
clients.vectors.search(
prompt=request.prompt, collection_ids=request.collections, k=request.k, score_threshold=request.score_threshold, user=user
)
clients.vectors.search(prompt=body.prompt, collection_ids=body.collections, k=body.k, score_threshold=body.score_threshold, user=user)
)

data = sorted(data, key=lambda x: x.score, reverse=False)[: request.k]
data = sorted(data, key=lambda x: x.score, reverse=False)[: body.k]

return Searches(data=data)
Loading

0 comments on commit 4d1211e

Please sign in to comment.