diff --git a/app/endpoints/audio.py b/app/endpoints/audio.py index 67170df..66a0a21 100644 --- a/app/endpoints/audio.py +++ b/app/endpoints/audio.py @@ -4,7 +4,7 @@ from app.schemas.audio import AudioTranscription, AudioTranscriptionVerbose from app.schemas.config import AUDIO_MODEL_TYPE -from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.config import settings from app.utils.security import check_api_key, check_rate_limit, User from app.utils.lifespan import clients, limiter from app.utils.exceptions import ModelNotFoundException @@ -16,7 +16,7 @@ @router.post("/audio/transcriptions") -@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) async def audio_transcriptions( request: Request, file: UploadFile = File(...), diff --git a/app/endpoints/chat.py b/app/endpoints/chat.py index 6201d88..e76611e 100644 --- a/app/endpoints/chat.py +++ b/app/endpoints/chat.py @@ -6,7 +6,7 @@ from app.schemas.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionRequest from app.schemas.security import User -from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.config import settings from app.utils.lifespan import clients, limiter from app.utils.security import check_api_key, check_rate_limit @@ -14,7 +14,7 @@ @router.post("/chat/completions") -@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) async def chat_completions( request: Request, body: ChatCompletionRequest, user: User = Security(check_api_key) ) -> Union[ChatCompletion, ChatCompletionChunk]: diff --git a/app/endpoints/chunks.py b/app/endpoints/chunks.py index 8c52983..aabccd8 100644 --- a/app/endpoints/chunks.py +++ b/app/endpoints/chunks.py @@ -7,14 +7,14 @@ from app.schemas.security import User 
from app.utils.lifespan import clients from app.utils.security import check_api_key, check_rate_limit -from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.config import settings from app.utils.lifespan import limiter router = APIRouter() @router.get("/chunks/{collection}/{document}") -@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) async def get_chunks( request: Request, collection: UUID, diff --git a/app/endpoints/collections.py b/app/endpoints/collections.py index d8370be..0bac6ac 100644 --- a/app/endpoints/collections.py +++ b/app/endpoints/collections.py @@ -9,14 +9,14 @@ from app.schemas.security import User from app.utils.lifespan import clients, limiter from app.utils.security import check_api_key, check_rate_limit -from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.config import settings from app.utils.variables import INTERNET_COLLECTION_ID, PUBLIC_COLLECTION_TYPE router = APIRouter() @router.post("/collections") -@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) async def create_collection(request: Request, body: CollectionRequest, user: User = Security(check_api_key)) -> Response: """ Create a new collection. @@ -35,7 +35,7 @@ async def create_collection(request: Request, body: CollectionRequest, user: Use @router.get("/collections") -@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) async def get_collections(request: Request, user: User = Security(check_api_key)) -> Union[Collection, Collections]: """ Get list of collections. 
@@ -54,7 +54,7 @@ async def get_collections(request: Request, user: User = Security(check_api_key) @router.delete("/collections/{collection}") -@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) async def delete_collections(request: Request, collection: UUID, user: User = Security(check_api_key)) -> Response: """ Delete a collection. diff --git a/app/endpoints/completions.py b/app/endpoints/completions.py index 119e39a..051a8d4 100644 --- a/app/endpoints/completions.py +++ b/app/endpoints/completions.py @@ -3,7 +3,7 @@ from app.schemas.completions import CompletionRequest, Completions from app.schemas.security import User -from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.config import settings from app.utils.lifespan import clients, limiter from app.utils.security import check_api_key, check_rate_limit @@ -11,7 +11,7 @@ @router.post("/completions") -@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) async def completions(request: Request, body: CompletionRequest, user: User = Security(check_api_key)) -> Completions: """ Completion API similar to OpenAI's API. 
diff --git a/app/endpoints/documents.py b/app/endpoints/documents.py index a84cc0d..5fbd18a 100644 --- a/app/endpoints/documents.py +++ b/app/endpoints/documents.py @@ -5,7 +5,7 @@ from app.schemas.documents import Documents from app.schemas.security import User -from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.config import settings from app.utils.lifespan import clients, limiter from app.utils.security import check_api_key, check_rate_limit @@ -13,7 +13,7 @@ @router.get("/documents/{collection}") -@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) async def get_documents( request: Request, collection: UUID, @@ -32,7 +32,7 @@ async def get_documents( @router.delete("/documents/{collection}/{document}") -@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) async def delete_document(request: Request, collection: UUID, document: UUID, user: User = Security(check_api_key)) -> Response: """ Delete a document and relative collections. 
diff --git a/app/endpoints/embeddings.py b/app/endpoints/embeddings.py index fae6a90..e97a206 100644 --- a/app/endpoints/embeddings.py +++ b/app/endpoints/embeddings.py @@ -3,7 +3,7 @@ from app.schemas.embeddings import Embeddings, EmbeddingsRequest from app.schemas.security import User -from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.config import settings from app.utils.exceptions import ContextLengthExceededException, WrongModelTypeException from app.utils.lifespan import clients, limiter from app.utils.security import check_api_key, check_rate_limit @@ -13,7 +13,7 @@ @router.post("/embeddings") -@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Security(check_api_key)) -> Embeddings: """ Embedding API similar to OpenAI's API. diff --git a/app/endpoints/models.py b/app/endpoints/models.py index 9b3a9eb..981bb18 100644 --- a/app/endpoints/models.py +++ b/app/endpoints/models.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Request, Security -from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.config import settings from app.schemas.models import Model, Models from app.schemas.security import User from app.utils.lifespan import clients, limiter @@ -13,7 +13,7 @@ @router.get("/models/{model:path}") @router.get("/models") -@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) async def models(request: Request, model: Optional[str] = None, user: User = Security(check_api_key)) -> Union[Models, Model]: """ Model API similar to OpenAI's API. 
diff --git a/app/endpoints/search.py b/app/endpoints/search.py index 77dd6d0..c249a9c 100644 --- a/app/endpoints/search.py +++ b/app/endpoints/search.py @@ -3,7 +3,7 @@ from app.helpers import SearchOnInternet from app.schemas.search import Searches, SearchRequest from app.schemas.security import User -from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.config import settings from app.utils.lifespan import clients, limiter from app.utils.security import check_api_key, check_rate_limit from app.utils.variables import INTERNET_COLLECTION_ID @@ -12,7 +12,7 @@ @router.post("/search") -@limiter.limit(DEFAULT_RATE_LIMIT, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) async def search(request: Request, body: SearchRequest, user: User = Security(check_api_key)) -> Searches: """ Similarity search for chunks in the vector store or on the internet. diff --git a/app/helpers/_clientsmanager.py b/app/helpers/_clientsmanager.py index 9f7a904..429a5f5 100644 --- a/app/helpers/_clientsmanager.py +++ b/app/helpers/_clientsmanager.py @@ -1,7 +1,8 @@ from redis import Redis as CacheManager from redis.connection import ConnectionPool -from app.schemas.config import Config + +from app.schemas.config import Settings from ._modelclients import ModelClients from ._authenticationclient import AuthenticationClient @@ -9,21 +10,21 @@ class ClientsManager: - def __init__(self, config: Config): - self.config = config + def __init__(self, settings: Settings) -> None: + self.settings = settings def set(self): # set models - self.models = ModelClients(config=self.config) + self.models = ModelClients(settings=self.settings) # set cache - self.cache = CacheManager(connection_pool=ConnectionPool(**self.config.databases.cache.args)) + self.cache = CacheManager(connection_pool=ConnectionPool(**self.settings.config.databases.cache.args)) # set vectors - self.vectors = 
VectorStore(models=self.models, **self.config.databases.vectors.args) + self.vectors = VectorStore(models=self.models, **self.settings.config.databases.vectors.args) # set auth - self.auth = AuthenticationClient(cache=self.cache, **self.config.auth.args) if self.config.auth else None + self.auth = AuthenticationClient(cache=self.cache, **self.settings.config.auth.args) if self.settings.config.auth else None def clear(self): self.vectors.close() diff --git a/app/helpers/_modelclients.py b/app/helpers/_modelclients.py index 655520e..ec5d52b 100644 --- a/app/helpers/_modelclients.py +++ b/app/helpers/_modelclients.py @@ -1,19 +1,19 @@ from functools import partial import time -from typing import Dict, List, Literal +from typing import Dict, List, Literal, Any from openai import OpenAI, AsyncOpenAI import requests -from app.schemas.config import Config +from app.schemas.config import Settings from app.schemas.embeddings import Embeddings from app.schemas.models import Model, Models -from app.utils.config import logger, DEFAULT_INTERNET_EMBEDDINGS_MODEL_URL, DEFAULT_INTERNET_LANGUAGE_MODEL_URL +from app.utils.logging import logger from app.utils.exceptions import ContextLengthExceededException, ModelNotAvailableException, ModelNotFoundException from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE -def get_models_list(self, *args, **kwargs): +def get_models_list(self, *args, **kwargs) -> Models: """ Custom method to overwrite OpenAI's list method (client.models.list()). This method support embeddings API models deployed with HuggingFace Text Embeddings Inference (see: https://github.com/huggingface/text-embeddings-inference). 
@@ -73,8 +73,8 @@ def get_models_list(self, *args, **kwargs): return Models(data=[data]) -def check_context_length(self, messages: List[Dict[str, str]], add_special_tokens: bool = True): - # TODO: remove this methode and use better context length handling +def check_context_length(self, messages: List[Dict[str, str]], add_special_tokens: bool = True) -> bool: + # TODO: remove this method and use better context length handling (by catching the model's context length error) headers = {"Authorization": f"Bearer {self.api_key}"} prompt = "\n".join([message["role"] + ": " + message["content"] for message in messages]) @@ -83,7 +83,7 @@ def check_context_length(self, messages: List[Dict[str, str]], add_special_token elif self.type == EMBEDDINGS_MODEL_TYPE: data = {"inputs": prompt, "add_special_tokens": add_special_tokens} - response = requests.post(str(self.base_url).replace("/v1/", "/tokenize"), json=data, headers=headers) + response = requests.post(url=str(self.base_url).replace("/v1/", "/tokenize"), json=data, headers=headers) response.raise_for_status() response = response.json() @@ -110,7 +110,7 @@ def create_embeddings(self, *args, **kwargs): class ModelClient(OpenAI): DEFAULT_TIMEOUT = 120 - def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], *args, **kwargs): + def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], *args, **kwargs) -> None: """ ModelClient class extends OpenAI class to support custom methods. 
""" @@ -120,7 +120,7 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], *ar # set attributes for unavailable models self.id = "" self.owned_by = "" - self.created = round(time.time()) + self.created = round(number=time.time()) self.max_context_length = None # set real attributes if model is available @@ -139,18 +139,18 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], *ar class AsyncModelClient(AsyncOpenAI): DEFAULT_TIMEOUT = 120 - def __init__(self, type=Literal[AUDIO_MODEL_TYPE], *args, **kwargs): + def __init__(self, type=Literal[AUDIO_MODEL_TYPE], *args, **kwargs) -> None: """ AsyncModelClient class extends AsyncOpenAI class to support custom methods. """ - timeout = 60 if type == AUDIO_MODEL_TYPE else self.DEFAULT_TIMEOUT - super().__init__(timeout=timeout, *args, **kwargs) + + super().__init__(timeout=self.DEFAULT_TIMEOUT, *args, **kwargs) self.type = type # set attributes for unavailable models self.id = "" self.owned_by = "" - self.created = round(time.time()) + self.created = round(number=time.time()) self.max_context_length = None # set real attributes if model is available @@ -170,29 +170,29 @@ class ModelClients(dict): Overwrite __getitem__ method to raise a 404 error if model is not found. 
""" - def __init__(self, config: Config): - for model_config in config.models: + def __init__(self, settings: Settings) -> None: + for model_config in settings.config.models: model_client_class = ModelClient if model_config.type != AUDIO_MODEL_TYPE else AsyncModelClient model = model_client_class(base_url=model_config.url, api_key=model_config.key, type=model_config.type) if model.status == "unavailable": - logger.error(f"unavailable model API on {model_config.url}, skipping.") + logger.error(msg=f"unavailable model API on {model_config.url}, skipping.") continue - self.__setitem__(model.id, model) + self.__setitem__(key=model.id, value=model) - if model_config.url == DEFAULT_INTERNET_EMBEDDINGS_MODEL_URL: + if model_config.url == settings.default_internet_embeddings_model_url: self.DEFAULT_INTERNET_EMBEDDINGS_MODEL_ID = model.id - elif model_config.url == DEFAULT_INTERNET_LANGUAGE_MODEL_URL: + elif model_config.url == settings.default_internet_language_model_url: self.DEFAULT_INTERNET_LANGUAGE_MODEL_ID = model.id assert "DEFAULT_INTERNET_EMBEDDINGS_MODEL_ID" in self.__dict__, "Default internet embeddings model is unavailable." assert "DEFAULT_INTERNET_LANGUAGE_MODEL_ID" in self.__dict__, "Default internet language model is unavailable." - def __setitem__(self, key: str, value): + def __setitem__(self, key: str, value) -> None: if any(key == k for k in self.keys()): - raise KeyError(f"Model id {key} is duplicated, not allowed.") + raise KeyError(msg=f"Model id {key} is duplicated, not allowed.") super().__setitem__(key, value) - def __getitem__(self, key: str): + def __getitem__(self, key: str) -> Any: try: item = super().__getitem__(key) assert item.status == "available", "Model not available." 
diff --git a/app/helpers/_searchoninternet.py b/app/helpers/_searchoninternet.py index fd50056..c471f92 100644 --- a/app/helpers/_searchoninternet.py +++ b/app/helpers/_searchoninternet.py @@ -10,7 +10,7 @@ from app.helpers.chunkers import LangchainRecursiveCharacterTextSplitter from app.helpers.parsers import HTMLParser from app.schemas.search import Search -from app.utils.config import logger +from app.utils.logging import logger from app.utils.variables import INTERNET_COLLECTION_ID diff --git a/app/main.py b/app/main.py index 478e591..1cb5760 100644 --- a/app/main.py +++ b/app/main.py @@ -6,16 +6,16 @@ from app.endpoints import audio, chat, chunks, collections, completions, documents, embeddings, files, models, search from app.helpers import ContentSizeLimitMiddleware from app.schemas.security import User -from app.utils.config import APP_CONTACT_EMAIL, APP_CONTACT_URL, APP_DESCRIPTION, APP_VERSION +from app.utils.config import settings from app.utils.lifespan import lifespan from app.utils.security import check_api_key app = FastAPI( - title="Albert API", - version=APP_VERSION, - description=APP_DESCRIPTION, - contact={"url": APP_CONTACT_URL, "email": APP_CONTACT_EMAIL}, - licence_info={"name": "MIT License", "identifier": "MIT"}, + title=settings.app_name, + version=settings.app_version, + description=settings.app_description, + contact={"url": settings.app_contact_url, "email": settings.app_contact_email}, + license_info={"name": "MIT License", "identifier": "MIT"}, lifespan=lifespan, docs_url="/swagger", @@ -23,13 +23,13 @@ ) # Middlewares -app.add_middleware(ContentSizeLimitMiddleware) -app.add_middleware(SlowAPIASGIMiddleware) +app.add_middleware(middleware_class=ContentSizeLimitMiddleware) +app.add_middleware(middleware_class=SlowAPIASGIMiddleware) # Monitoring -@app.get("/health", tags=["Monitoring"]) -def health(user: User = Security(check_api_key)): +@app.get(path="/health", tags=["Monitoring"]) +def health(user: User = Security(dependency=check_api_key)) -> Response: """ Health check. 
""" @@ -38,15 +38,15 @@ def health(user: User = Security(check_api_key)): # Core -app.include_router(models.router, tags=["Core"], prefix="/v1") -app.include_router(chat.router, tags=["Core"], prefix="/v1") -app.include_router(completions.router, tags=["Core"], prefix="/v1") -app.include_router(embeddings.router, tags=["Core"], prefix="/v1") -app.include_router(audio.router, tags=["Core"], prefix="/v1") +app.include_router(router=models.router, tags=["Core"], prefix="/v1") +app.include_router(router=chat.router, tags=["Core"], prefix="/v1") +app.include_router(router=completions.router, tags=["Core"], prefix="/v1") +app.include_router(router=embeddings.router, tags=["Core"], prefix="/v1") +app.include_router(router=audio.router, tags=["Core"], prefix="/v1") # RAG -app.include_router(search.router, tags=["Retrieval Augmented Generation"], prefix="/v1") -app.include_router(collections.router, tags=["Retrieval Augmented Generation"], prefix="/v1") -app.include_router(files.router, tags=["Retrieval Augmented Generation"], prefix="/v1") -app.include_router(documents.router, tags=["Retrieval Augmented Generation"], prefix="/v1") -app.include_router(chunks.router, tags=["Retrieval Augmented Generation"], prefix="/v1") +app.include_router(router=search.router, tags=["Retrieval Augmented Generation"], prefix="/v1") +app.include_router(router=collections.router, tags=["Retrieval Augmented Generation"], prefix="/v1") +app.include_router(router=files.router, tags=["Retrieval Augmented Generation"], prefix="/v1") +app.include_router(router=documents.router, tags=["Retrieval Augmented Generation"], prefix="/v1") +app.include_router(router=chunks.router, tags=["Retrieval Augmented Generation"], prefix="/v1") diff --git a/app/schemas/chat.py b/app/schemas/chat.py index 5f0e76b..33ca542 100644 --- a/app/schemas/chat.py +++ b/app/schemas/chat.py @@ -34,17 +34,17 @@ class ChatCompletionRequest(BaseModel): class ConfigDict: extra = "allow" - @model_validator(mode="before") - def 
validate_model(cls, value): - if clients.models[value["model"]].type != LANGUAGE_MODEL_TYPE: + @model_validator(mode="before") + def validate_model(cls, values): + if clients.models[values["model"]].type != LANGUAGE_MODEL_TYPE: raise WrongModelTypeException() - if not clients.models[value["model"]].check_context_length(messages=value["messages"]): + if not clients.models[values["model"]].check_context_length(messages=values["messages"]): raise ContextLengthExceededException() - if "max_tokens" in value and value["max_tokens"] is not None and value["max_tokens"] > clients.models[value["model"]].max_context_length: + if "max_tokens" in values and values["max_tokens"] is not None and values["max_tokens"] > clients.models[values["model"]].max_context_length: raise MaxTokensExceededException() - return value + return values class ChatCompletion(ChatCompletion): diff --git a/app/schemas/collections.py b/app/schemas/collections.py index 3330ebe..34b9088 100644 --- a/app/schemas/collections.py +++ b/app/schemas/collections.py @@ -28,7 +28,7 @@ class CollectionRequest(BaseModel): description: Optional[str] = Field(None) @field_validator("name", mode="before") - def strip(cls, v): - if isinstance(v, str): - v = v.strip() - return v + def strip(cls, name): + if isinstance(name, str): + name = name.strip() + return name diff --git a/app/schemas/completions.py b/app/schemas/completions.py index 35fd5a9..0459e7d 100644 --- a/app/schemas/completions.py +++ b/app/schemas/completions.py @@ -27,15 +27,15 @@ class CompletionRequest(BaseModel): top_p: Optional[float] = 1.0 user: Optional[str] = None - @model_validator(mode="before") - def validate_model(cls, value): - if clients.models[value["model"]].type != LANGUAGE_MODEL_TYPE: + @model_validator(mode="before") + def validate_model(cls, values): + if clients.models[values["model"]].type != LANGUAGE_MODEL_TYPE: raise WrongModelTypeException() - if not clients.models[value["model"]].check_context_length(messages=value["messages"]): + if 
not clients.models[values["model"]].check_context_length(messages=values["messages"]): raise ContextLengthExceededException() - if value["max_tokens"] is not None and value["max_tokens"] > clients.models[value["model"]].max_context_length: + if values["max_tokens"] is not None and values["max_tokens"] > clients.models[values["model"]].max_context_length: raise MaxTokensExceededException() diff --git a/app/schemas/config.py b/app/schemas/config.py index 6298c0e..649d702 100644 --- a/app/schemas/config.py +++ b/app/schemas/config.py @@ -1,8 +1,11 @@ +import os from typing import List, Literal, Optional -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic_settings import BaseSettings +import yaml -from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDIO_MODEL_TYPE +from app.utils.variables import AUDIO_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE class ConfigBaseModel(BaseModel): @@ -61,3 +64,57 @@ def validate_models(cls, values): raise ValueError("At least one embeddings model is required") return values + + +class Settings(BaseSettings): + # logging + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" + + # config + config_file: str = "config.yml" + + # app + app_name: str = "Albert API" + app_contact_url: Optional[str] = None + app_contact_email: Optional[str] = None + app_version: str = "0.0.0" + app_description: str = "[See documentation](https://github.com/etalab-ia/albert-api/blob/main/README.md)" + + # models + default_internet_language_model_url: Optional[str] = None + default_internet_embeddings_model_url: Optional[str] = None + + # rate_limit + global_rate_limit: str = "100/minute" + default_rate_limit: str = "10/minute" + + class Config: + extra = "allow" + + @field_validator("config_file", mode="before") + def config_file_exists(cls, config_file): + assert os.path.exists(config_file), "Config file not 
found" + return config_file + + @model_validator(mode="after") + def validate_models(cls, values): + config = Config(**yaml.safe_load(stream=open(file=values.config_file, mode="r"))) + if not values.default_internet_language_model_url: + values.default_internet_language_model_url = [model.url for model in config.models if model.type == LANGUAGE_MODEL_TYPE][0] + + else: + assert values.default_internet_language_model_url in [ + model.url for model in config.models if model.type == LANGUAGE_MODEL_TYPE + ], "Wrong default internet language model url" + + if not values.default_internet_embeddings_model_url: + values.default_internet_embeddings_model_url = [model.url for model in config.models if model.type == EMBEDDINGS_MODEL_TYPE][0] + + else: + assert values.default_internet_embeddings_model_url in [ + model.url for model in config.models if model.type == EMBEDDINGS_MODEL_TYPE + ], "Wrong default internet embeddings model url" + + values.config = config + + return values diff --git a/app/schemas/files.py b/app/schemas/files.py index 24b2647..622ad7d 100644 --- a/app/schemas/files.py +++ b/app/schemas/files.py @@ -29,15 +29,15 @@ class FilesRequest(BaseModel): @model_validator(mode="before") @classmethod - def validate_to_json(cls, value): - if isinstance(value, str): - return cls(**json.loads(value)) - return value + def validate_to_json(cls, values): + if isinstance(values, str): + return cls(**json.loads(values)) + return values @field_validator("collection", mode="after") @classmethod - def convert_to_string(cls, value): - return str(value) + def convert_to_string(cls, collection): + return str(collection) class Json(BaseModel): diff --git a/app/schemas/search.py b/app/schemas/search.py index ee6e534..d5f1dd6 100644 --- a/app/schemas/search.py +++ b/app/schemas/search.py @@ -13,16 +13,16 @@ class SearchRequest(BaseModel): score_threshold: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="Score of cosine similarity threshold for filtering results") 
@field_validator("prompt") - def blank_string(value): - if value.strip() == "": + def blank_string(prompt): + if prompt.strip() == "": raise ValueError("Prompt cannot be empty") - return value + return prompt @field_validator("collections") - def convert_to_string(cls, v): - if v is None: + def convert_to_string(cls, collections): + if collections is None: return [] - return list(set(str(collection) for collection in v)) + return list(set(str(collection) for collection in collections)) class Search(BaseModel): diff --git a/app/tests/test_models.py b/app/tests/test_models.py index ff275e6..4e8f80f 100644 --- a/app/tests/test_models.py +++ b/app/tests/test_models.py @@ -3,7 +3,7 @@ import pytest from app.schemas.models import Model, Models -from app.utils.config import DEFAULT_RATE_LIMIT +from app.utils.config import settings @pytest.mark.usefixtures("args", "session_user", "session_admin") @@ -32,7 +32,7 @@ def test_get_models_non_existing_model(self, args, session_admin): def test_get_models_rate_limit(self, args, session_user): """Test the GET /models rate limiting.""" start = time.time() - limit = int(DEFAULT_RATE_LIMIT.replace("/minute", "")) + limit = int(settings.default_rate_limit.replace("/minute", "")) i = 0 while time.time() - start < 60: i += 1 diff --git a/app/tests/test_search.py b/app/tests/test_search.py index 49d6243..9e41969 100644 --- a/app/tests/test_search.py +++ b/app/tests/test_search.py @@ -7,7 +7,7 @@ from app.schemas.search import Search, Searches from app.utils.variables import EMBEDDINGS_MODEL_TYPE, INTERNET_COLLECTION_ID -from app.utils.config import logger +from app.utils.logging import logger @pytest.fixture(scope="module") diff --git a/app/utils/config.py b/app/utils/config.py index cc7355a..b35046b 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -1,49 +1,10 @@ -import logging -import os +from functools import lru_cache +from app.schemas.config import Settings -import yaml -from app.schemas.config import Config -from 
app.schemas.models import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE +@lru_cache +def get_settings() -> Settings: + return Settings() -logging.basicConfig(format="%(levelname)s:%(asctime)s:%(name)s: %(message)s", level=logging.INFO) -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("LOG_LEVEL", logging.DEBUG)) -# Configuration -CONFIG_FILE = os.getenv("CONFIG_FILE", "config.yml") -assert os.path.exists(CONFIG_FILE), f"error: configuration file {CONFIG_FILE} not found" -logger.info(f"loading configuration file: {CONFIG_FILE}") -CONFIG = Config(**yaml.safe_load(open(CONFIG_FILE, "r"))) - -# Metadata -APP_CONTACT_URL = os.getenv("APP_CONTACT_URL") -APP_CONTACT_EMAIL = os.getenv("APP_CONTACT_EMAIL") -APP_VERSION = os.getenv("APP_VERSION", "0.0.0") -APP_DESCRIPTION = os.getenv( - "APP_DESCRIPTION", - "[See documentation](https://github.com/etalab-ia/albert-api/blob/main/README.md)", -) - -# Models -DEFAULT_INTERNET_LANGUAGE_MODEL_URL = os.getenv( - "DEFAULT_INTERNET_LANGUAGE_MODEL_URL", [model.url for model in CONFIG.models if model.type == LANGUAGE_MODEL_TYPE][0] -) -DEFAULT_INTERNET_EMBEDDINGS_MODEL_URL = os.getenv( - "DEFAULT_INTERNET_EMBEDDINGS_MODEL_URL", [model.url for model in CONFIG.models if model.type == EMBEDDINGS_MODEL_TYPE][0] -) -assert DEFAULT_INTERNET_LANGUAGE_MODEL_URL in [model.url for model in CONFIG.models], "Default internet language model not found." -assert DEFAULT_INTERNET_EMBEDDINGS_MODEL_URL in [model.url for model in CONFIG.models], "Default internet embeddings model not found." -assert DEFAULT_INTERNET_LANGUAGE_MODEL_URL in [ - model.url for model in CONFIG.models if model.type == LANGUAGE_MODEL_TYPE -], "Default internet language model wrong type." -assert DEFAULT_INTERNET_EMBEDDINGS_MODEL_URL in [ - model.url for model in CONFIG.models if model.type == EMBEDDINGS_MODEL_TYPE -], "Default internet embeddings model wrong type." 
- -logger.info(f"default internet language model url: {DEFAULT_INTERNET_LANGUAGE_MODEL_URL}") -logger.info(f"default internet embeddings model url: {DEFAULT_INTERNET_EMBEDDINGS_MODEL_URL}") - -# Rate limit -GLOBAL_RATE_LIMIT = os.getenv("GLOBAL_RATE_LIMIT", "100/minute") -DEFAULT_RATE_LIMIT = os.getenv("DEFAULT_RATE_LIMIT", "10/minute") +settings = get_settings() diff --git a/app/utils/lifespan.py b/app/utils/lifespan.py index ad6a0fb..bd2c32b 100644 --- a/app/utils/lifespan.py +++ b/app/utils/lifespan.py @@ -5,13 +5,13 @@ from slowapi.util import get_ipaddr from app.helpers import ClientsManager -from app.utils.config import CONFIG, GLOBAL_RATE_LIMIT +from app.utils.config import settings -clients = ClientsManager(config=CONFIG) +clients = ClientsManager(settings=settings) limiter = Limiter( key_func=get_ipaddr, - storage_uri=f"redis://{CONFIG.databases.cache.args.get("username", "")}:{CONFIG.databases.cache.args.get("password", "")}@{CONFIG.databases.cache.args["host"]}:{CONFIG.databases.cache.args["port"]}", - default_limits=[GLOBAL_RATE_LIMIT], + storage_uri=f"redis://{settings.config.databases.cache.args.get("username", "")}:{settings.config.databases.cache.args.get("password", "")}@{settings.config.databases.cache.args["host"]}:{settings.config.databases.cache.args["port"]}", + default_limits=[settings.global_rate_limit], ) diff --git a/app/utils/logging.py b/app/utils/logging.py new file mode 100644 index 0000000..7bf0467 --- /dev/null +++ b/app/utils/logging.py @@ -0,0 +1,6 @@ +import logging +from app.utils.config import settings + +logging.basicConfig(format="%(levelname)s:%(asctime)s:%(name)s: %(message)s", level=logging.INFO) +logger = logging.getLogger(__name__) +logger.setLevel(settings.log_level) diff --git a/app/utils/security.py b/app/utils/security.py index 5d4480a..7089fea 100644 --- a/app/utils/security.py +++ b/app/utils/security.py @@ -6,7 +6,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from app.schemas.security 
import User -from app.utils.config import CONFIG +from app.utils.config import settings from app.utils.exceptions import InvalidAPIKeyException, InvalidAuthenticationSchemeException from app.utils.lifespan import clients from app.utils.variables import ROLE_LEVEL_0, ROLE_LEVEL_2 @@ -30,7 +30,7 @@ def encode_string(input: str) -> str: return hash -if CONFIG.auth: +if settings.config.auth: def check_api_key(api_key: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key"))]) -> str: """ diff --git a/pyproject.toml b/pyproject.toml index 4205026..2637df6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,15 +20,15 @@ app = [ "redis==5.0.7", "uvicorn==0.30.1", "fastapi==0.111.0", + "pydantic-settings==2.6.1", "pyyaml==6.0.1", - "python-magic==0.4.27", "grist-api==0.1.0", + "six==1.16.0", "pdfminer.six==20240706", "beautifulsoup4==4.12.3", "duckduckgo-search==6.2.13", "numpy==1.26.4", "slowapi==0.1.9", - "six==1.16.0", ] dev = [ "ruff==0.6.5",