Skip to content

Commit

Permalink
feat: add settings for config
Browse files Browse the repository at this point in the history
  • Loading branch information
leoguillaumegouv committed Nov 21, 2024
1 parent 19f3263 commit 13f2973
Show file tree
Hide file tree
Showing 26 changed files with 180 additions and 155 deletions.
4 changes: 2 additions & 2 deletions app/endpoints/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from app.schemas.audio import AudioTranscription, AudioTranscriptionVerbose
from app.schemas.config import AUDIO_MODEL_TYPE
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.config import settings
from app.utils.security import check_api_key, check_rate_limit, User
from app.utils.lifespan import clients, limiter
from app.utils.exceptions import ModelNotFoundException
Expand All @@ -16,7 +16,7 @@


@router.post("/audio/transcriptions")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def audio_transcriptions(
request: Request,
file: UploadFile = File(...),
Expand Down
4 changes: 2 additions & 2 deletions app/endpoints/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

from app.schemas.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionRequest
from app.schemas.security import User
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.config import settings
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit

router = APIRouter()


@router.post("/chat/completions")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def chat_completions(
request: Request, body: ChatCompletionRequest, user: User = Security(check_api_key)
) -> Union[ChatCompletion, ChatCompletionChunk]:
Expand Down
4 changes: 2 additions & 2 deletions app/endpoints/chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from app.schemas.security import User
from app.utils.lifespan import clients
from app.utils.security import check_api_key, check_rate_limit
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.config import settings
from app.utils.lifespan import limiter

router = APIRouter()


@router.get("/chunks/{collection}/{document}")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def get_chunks(
request: Request,
collection: UUID,
Expand Down
8 changes: 4 additions & 4 deletions app/endpoints/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from app.schemas.security import User
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.config import settings
from app.utils.variables import INTERNET_COLLECTION_ID, PUBLIC_COLLECTION_TYPE

router = APIRouter()


@router.post("/collections")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def create_collection(request: Request, body: CollectionRequest, user: User = Security(check_api_key)) -> Response:
"""
Create a new collection.
Expand All @@ -35,7 +35,7 @@ async def create_collection(request: Request, body: CollectionRequest, user: Use


@router.get("/collections")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def get_collections(request: Request, user: User = Security(check_api_key)) -> Union[Collection, Collections]:
"""
Get list of collections.
Expand All @@ -54,7 +54,7 @@ async def get_collections(request: Request, user: User = Security(check_api_key)


@router.delete("/collections/{collection}")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def delete_collections(request: Request, collection: UUID, user: User = Security(check_api_key)) -> Response:
"""
Delete a collection.
Expand Down
4 changes: 2 additions & 2 deletions app/endpoints/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@

from app.schemas.completions import CompletionRequest, Completions
from app.schemas.security import User
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.config import settings
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit

router = APIRouter()


@router.post("/completions")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def completions(request: Request, body: CompletionRequest, user: User = Security(check_api_key)) -> Completions:
"""
Completion API similar to OpenAI's API.
Expand Down
6 changes: 3 additions & 3 deletions app/endpoints/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

from app.schemas.documents import Documents
from app.schemas.security import User
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.config import settings
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit

router = APIRouter()


@router.get("/documents/{collection}")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def get_documents(
request: Request,
collection: UUID,
Expand All @@ -32,7 +32,7 @@ async def get_documents(


@router.delete("/documents/{collection}/{document}")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def delete_document(request: Request, collection: UUID, document: UUID, user: User = Security(check_api_key)) -> Response:
"""
Delete a document and relative collections.
Expand Down
4 changes: 2 additions & 2 deletions app/endpoints/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from app.schemas.embeddings import Embeddings, EmbeddingsRequest
from app.schemas.security import User
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.config import settings
from app.utils.exceptions import ContextLengthExceededException, WrongModelTypeException
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
Expand All @@ -13,7 +13,7 @@


@router.post("/embeddings")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Security(check_api_key)) -> Embeddings:
"""
Embedding API similar to OpenAI's API.
Expand Down
4 changes: 2 additions & 2 deletions app/endpoints/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from fastapi import APIRouter, Request, Security

from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.config import settings
from app.schemas.models import Model, Models
from app.schemas.security import User
from app.utils.lifespan import clients, limiter
Expand All @@ -13,7 +13,7 @@

@router.get("/models/{model:path}")
@router.get("/models")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def models(request: Request, model: Optional[str] = None, user: User = Security(check_api_key)) -> Union[Models, Model]:
"""
Model API similar to OpenAI's API.
Expand Down
4 changes: 2 additions & 2 deletions app/endpoints/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from app.helpers import SearchOnInternet
from app.schemas.search import Searches, SearchRequest
from app.schemas.security import User
from app.utils.config import DEFAULT_RATE_LIMIT
from app.utils.config import settings
from app.utils.lifespan import clients, limiter
from app.utils.security import check_api_key, check_rate_limit
from app.utils.variables import INTERNET_COLLECTION_ID
Expand All @@ -12,7 +12,7 @@


@router.post("/search")
@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request))
@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request))
async def search(request: Request, body: SearchRequest, user: User = Security(check_api_key)) -> Searches:
"""
Similarity search for chunks in the vector store or on the internet.
Expand Down
15 changes: 8 additions & 7 deletions app/helpers/_clientsmanager.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
from redis import Redis as CacheManager
from redis.connection import ConnectionPool

from app.schemas.config import Config

from app.schemas.config import Settings

from ._modelclients import ModelClients
from ._authenticationclient import AuthenticationClient
from ._vectorstore import VectorStore


class ClientsManager:
def __init__(self, config: Config):
self.config = config
def __init__(self, settings: Settings) -> None:
self.settings = settings

def set(self):
# set models
self.models = ModelClients(config=self.config)
self.models = ModelClients(settings=self.settings)

# set cache
self.cache = CacheManager(connection_pool=ConnectionPool(**self.config.databases.cache.args))
self.cache = CacheManager(connection_pool=ConnectionPool(**self.settings.config.databases.cache.args))

# set vectors
self.vectors = VectorStore(models=self.models, **self.config.databases.vectors.args)
self.vectors = VectorStore(models=self.models, **self.settings.config.databases.vectors.args)

# set auth
self.auth = AuthenticationClient(cache=self.cache, **self.config.auth.args) if self.config.auth else None
self.auth = AuthenticationClient(cache=self.cache, **self.settings.config.auth.args) if self.settings.config.auth else None

def clear(self):
self.vectors.close()
44 changes: 22 additions & 22 deletions app/helpers/_modelclients.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from functools import partial
import time
from typing import Dict, List, Literal
from typing import Dict, List, Literal, Any

from openai import OpenAI, AsyncOpenAI
import requests

from app.schemas.config import Config
from app.schemas.config import Settings
from app.schemas.embeddings import Embeddings
from app.schemas.models import Model, Models
from app.utils.config import logger, DEFAULT_INTERNET_EMBEDDINGS_MODEL_URL, DEFAULT_INTERNET_LANGUAGE_MODEL_URL
from app.utils.logging import logger
from app.utils.exceptions import ContextLengthExceededException, ModelNotAvailableException, ModelNotFoundException
from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE


def get_models_list(self, *args, **kwargs):
def get_models_list(self, *args, **kwargs) -> Models:
"""
Custom method to override OpenAI's list method (client.models.list()). This method supports
embeddings API models deployed with HuggingFace Text Embeddings Inference (see: https://github.com/huggingface/text-embeddings-inference).
Expand Down Expand Up @@ -73,8 +73,8 @@ def get_models_list(self, *args, **kwargs):
return Models(data=[data])


def check_context_length(self, messages: List[Dict[str, str]], add_special_tokens: bool = True):
# TODO: remove this method and use better context length handling
def check_context_length(self, messages: List[Dict[str, str]], add_special_tokens: bool = True) -> bool:
# TODO: remove this method and use better context-length handling (by catching the model's context-length error)
headers = {"Authorization": f"Bearer {self.api_key}"}
prompt = "\n".join([message["role"] + ": " + message["content"] for message in messages])

Expand All @@ -83,7 +83,7 @@ def check_context_length(self, messages: List[Dict[str, str]], add_special_token
elif self.type == EMBEDDINGS_MODEL_TYPE:
data = {"inputs": prompt, "add_special_tokens": add_special_tokens}

response = requests.post(str(self.base_url).replace("/v1/", "/tokenize"), json=data, headers=headers)
response = requests.post(url=str(self.base_url).replace("/v1/", "/tokenize"), json=data, headers=headers)
response.raise_for_status()
response = response.json()

Expand All @@ -110,7 +110,7 @@ def create_embeddings(self, *args, **kwargs):
class ModelClient(OpenAI):
DEFAULT_TIMEOUT = 120

def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], *args, **kwargs):
def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], *args, **kwargs) -> None:
"""
ModelClient class extends OpenAI class to support custom methods.
"""
Expand All @@ -120,7 +120,7 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], *ar
# set attributes for unavailable models
self.id = ""
self.owned_by = ""
self.created = round(time.time())
self.created = round(number=time.time())
self.max_context_length = None

# set real attributes if model is available
Expand All @@ -139,18 +139,18 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], *ar
class AsyncModelClient(AsyncOpenAI):
DEFAULT_TIMEOUT = 120

def __init__(self, type=Literal[AUDIO_MODEL_TYPE], *args, **kwargs):
def __init__(self, type=Literal[AUDIO_MODEL_TYPE], *args, **kwargs) -> None:
"""
AsyncModelClient class extends AsyncOpenAI class to support custom methods.
"""
timeout = 60 if type == AUDIO_MODEL_TYPE else self.DEFAULT_TIMEOUT
super().__init__(timeout=timeout, *args, **kwargs)

super().__init__(timeout=self.DEFAULT_TIMEOUT, *args, **kwargs)
self.type = type

# set attributes for unavailable models
self.id = ""
self.owned_by = ""
self.created = round(time.time())
self.created = round(number=time.time())
self.max_context_length = None

# set real attributes if model is available
Expand All @@ -170,29 +170,29 @@ class ModelClients(dict):
Overwrite __getitem__ method to raise a 404 error if model is not found.
"""

def __init__(self, config: Config):
for model_config in config.models:
def __init__(self, settings: Settings) -> None:
for model_config in settings.config.models:
model_client_class = ModelClient if model_config.type != AUDIO_MODEL_TYPE else AsyncModelClient
model = model_client_class(base_url=model_config.url, api_key=model_config.key, type=model_config.type)
if model.status == "unavailable":
logger.error(f"unavailable model API on {model_config.url}, skipping.")
logger.error(msg=f"unavailable model API on {model_config.url}, skipping.")
continue
self.__setitem__(model.id, model)
self.__setitem__(key=model.id, value=model)

if model_config.url == DEFAULT_INTERNET_EMBEDDINGS_MODEL_URL:
if model_config.url == settings.default_internet_embeddings_model_url:
self.DEFAULT_INTERNET_EMBEDDINGS_MODEL_ID = model.id
elif model_config.url == DEFAULT_INTERNET_LANGUAGE_MODEL_URL:
elif model_config.url == settings.default_internet_language_model_url:
self.DEFAULT_INTERNET_LANGUAGE_MODEL_ID = model.id

assert "DEFAULT_INTERNET_EMBEDDINGS_MODEL_ID" in self.__dict__, "Default internet embeddings model is unavailable."
assert "DEFAULT_INTERNET_LANGUAGE_MODEL_ID" in self.__dict__, "Default internet language model is unavailable."

def __setitem__(self, key: str, value):
def __setitem__(self, key: str, value) -> None:
if any(key == k for k in self.keys()):
raise KeyError(f"Model id {key} is duplicated, not allowed.")
raise KeyError(msg=f"Model id {key} is duplicated, not allowed.")
super().__setitem__(key, value)

def __getitem__(self, key: str):
def __getitem__(self, key: str) -> Any:
try:
item = super().__getitem__(key)
assert item.status == "available", "Model not available."
Expand Down
2 changes: 1 addition & 1 deletion app/helpers/_searchoninternet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from app.helpers.chunkers import LangchainRecursiveCharacterTextSplitter
from app.helpers.parsers import HTMLParser
from app.schemas.search import Search
from app.utils.config import logger
from app.utils.logging import logger
from app.utils.variables import INTERNET_COLLECTION_ID


Expand Down
Loading

0 comments on commit 13f2973

Please sign in to comment.